Movie/Group tags (#4969)

* Combine common tag control code into hook
* Combine common scraped tag row code into hook
This commit is contained in:
WithoutPants
2024-06-18 11:24:15 +10:00
committed by GitHub
parent f9a624b803
commit fda4776d30
63 changed files with 1586 additions and 450 deletions

View File

@@ -23,6 +23,7 @@ type Movie struct {
BackImage string `json:"back_image,omitempty"`
URLs []string `json:"urls,omitempty"`
Studio string `json:"studio,omitempty"`
Tags []string `json:"tags,omitempty"`
CreatedAt json.JSONTime `json:"created_at,omitempty"`
UpdatedAt json.JSONTime `json:"updated_at,omitempty"`

View File

@@ -312,6 +312,29 @@ func (_m *MovieReaderWriter) GetFrontImage(ctx context.Context, movieID int) ([]
return r0, r1
}
// GetTagIDs provides a mock function with given fields: ctx, relatedID
func (_m *MovieReaderWriter) GetTagIDs(ctx context.Context, relatedID int) ([]int, error) {
ret := _m.Called(ctx, relatedID)
var r0 []int
if rf, ok := ret.Get(0).(func(context.Context, int) []int); ok {
r0 = rf(ctx, relatedID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]int)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
r1 = rf(ctx, relatedID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetURLs provides a mock function with given fields: ctx, relatedID
func (_m *MovieReaderWriter) GetURLs(ctx context.Context, relatedID int) ([]string, error) {
ret := _m.Called(ctx, relatedID)

View File

@@ -266,6 +266,29 @@ func (_m *TagReaderWriter) FindByImageID(ctx context.Context, imageID int) ([]*m
return r0, r1
}
// FindByMovieID provides a mock function with given fields: ctx, movieID
func (_m *TagReaderWriter) FindByMovieID(ctx context.Context, movieID int) ([]*models.Tag, error) {
ret := _m.Called(ctx, movieID)
var r0 []*models.Tag
if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Tag); ok {
r0 = rf(ctx, movieID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Tag)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
r1 = rf(ctx, movieID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// FindByName provides a mock function with given fields: ctx, name, nocase
func (_m *TagReaderWriter) FindByName(ctx context.Context, name string, nocase bool) (*models.Tag, error) {
ret := _m.Called(ctx, name, nocase)

View File

@@ -19,7 +19,8 @@ type Movie struct {
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
URLs RelatedStrings `json:"urls"`
URLs RelatedStrings `json:"urls"`
TagIDs RelatedIDs `json:"tag_ids"`
}
func NewMovie() Movie {
@@ -30,9 +31,15 @@ func NewMovie() Movie {
}
}
func (g *Movie) LoadURLs(ctx context.Context, l URLLoader) error {
return g.URLs.load(func() ([]string, error) {
return l.GetURLs(ctx, g.ID)
func (m *Movie) LoadURLs(ctx context.Context, l URLLoader) error {
return m.URLs.load(func() ([]string, error) {
return l.GetURLs(ctx, m.ID)
})
}
func (m *Movie) LoadTagIDs(ctx context.Context, l TagIDLoader) error {
return m.TagIDs.load(func() ([]int, error) {
return l.GetTagIDs(ctx, m.ID)
})
}
@@ -47,6 +54,7 @@ type MoviePartial struct {
Director OptionalString
Synopsis OptionalString
URLs *UpdateStrings
TagIDs *UpdateIDs
CreatedAt OptionalTime
UpdatedAt OptionalTime
}

View File

@@ -371,6 +371,7 @@ type ScrapedMovie struct {
URLs []string `json:"urls"`
Synopsis *string `json:"synopsis"`
Studio *ScrapedStudio `json:"studio"`
Tags []*ScrapedTag `json:"tags"`
// This should be a base64 encoded data URL
FrontImage *string `json:"front_image"`
// This should be a base64 encoded data URL

View File

@@ -17,6 +17,10 @@ type MovieFilterType struct {
URL *StringCriterionInput `json:"url"`
// Filter to only include movies where performer appears in a scene
Performers *MultiCriterionInput `json:"performers"`
// Filter to only include performers with these tags
Tags *HierarchicalMultiCriterionInput `json:"tags"`
// Filter by tag count
TagCount *IntCriterionInput `json:"tag_count"`
// Filter by date
Date *DateCriterionInput `json:"date"`
// Filter by related scenes that meet this criteria

View File

@@ -65,6 +65,7 @@ type MovieReader interface {
MovieQueryer
MovieCounter
URLLoader
TagIDLoader
All(ctx context.Context) ([]*Movie, error)
GetFrontImage(ctx context.Context, movieID int) ([]byte, error)

View File

@@ -20,6 +20,7 @@ type TagFinder interface {
FindByImageID(ctx context.Context, imageID int) ([]*Tag, error)
FindByGalleryID(ctx context.Context, galleryID int) ([]*Tag, error)
FindByPerformerID(ctx context.Context, performerID int) ([]*Tag, error)
FindByMovieID(ctx context.Context, movieID int) ([]*Tag, error)
FindBySceneMarkerID(ctx context.Context, sceneMarkerID int) ([]*Tag, error)
FindByName(ctx context.Context, name string, nocase bool) (*Tag, error)
FindByNames(ctx context.Context, names []string, nocase bool) ([]*Tag, error)

View File

@@ -20,6 +20,8 @@ type TagFilterType struct {
GalleryCount *IntCriterionInput `json:"gallery_count"`
// Filter by number of performers with this tag
PerformerCount *IntCriterionInput `json:"performer_count"`
// Filter by number of movies with this tag
MovieCount *IntCriterionInput `json:"movie_count"`
// Filter by number of markers with this tag
MarkerCount *IntCriterionInput `json:"marker_count"`
// Filter by parent tags

View File

@@ -3,9 +3,11 @@ package movie
import (
"context"
"fmt"
"strings"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/jsonschema"
"github.com/stashapp/stash/pkg/sliceutil"
"github.com/stashapp/stash/pkg/utils"
)
@@ -17,6 +19,7 @@ type ImporterReaderWriter interface {
type Importer struct {
ReaderWriter ImporterReaderWriter
StudioWriter models.StudioFinderCreator
TagWriter models.TagFinderCreator
Input jsonschema.Movie
MissingRefBehaviour models.ImportMissingRefEnum
@@ -32,6 +35,10 @@ func (i *Importer) PreImport(ctx context.Context) error {
return err
}
if err := i.populateTags(ctx); err != nil {
return err
}
var err error
if len(i.Input.FrontImage) > 0 {
i.frontImageData, err = utils.ProcessBase64Image(i.Input.FrontImage)
@@ -49,6 +56,74 @@ func (i *Importer) PreImport(ctx context.Context) error {
return nil
}
func (i *Importer) populateTags(ctx context.Context) error {
if len(i.Input.Tags) > 0 {
tags, err := importTags(ctx, i.TagWriter, i.Input.Tags, i.MissingRefBehaviour)
if err != nil {
return err
}
for _, p := range tags {
i.movie.TagIDs.Add(p.ID)
}
}
return nil
}
func importTags(ctx context.Context, tagWriter models.TagFinderCreator, names []string, missingRefBehaviour models.ImportMissingRefEnum) ([]*models.Tag, error) {
tags, err := tagWriter.FindByNames(ctx, names, false)
if err != nil {
return nil, err
}
var pluckedNames []string
for _, tag := range tags {
pluckedNames = append(pluckedNames, tag.Name)
}
missingTags := sliceutil.Filter(names, func(name string) bool {
return !sliceutil.Contains(pluckedNames, name)
})
if len(missingTags) > 0 {
if missingRefBehaviour == models.ImportMissingRefEnumFail {
return nil, fmt.Errorf("tags [%s] not found", strings.Join(missingTags, ", "))
}
if missingRefBehaviour == models.ImportMissingRefEnumCreate {
createdTags, err := createTags(ctx, tagWriter, missingTags)
if err != nil {
return nil, fmt.Errorf("error creating tags: %v", err)
}
tags = append(tags, createdTags...)
}
// ignore if MissingRefBehaviour set to Ignore
}
return tags, nil
}
func createTags(ctx context.Context, tagWriter models.TagFinderCreator, names []string) ([]*models.Tag, error) {
var ret []*models.Tag
for _, name := range names {
newTag := models.NewTag()
newTag.Name = name
err := tagWriter.Create(ctx, &newTag)
if err != nil {
return nil, err
}
ret = append(ret, &newTag)
}
return ret, nil
}
func (i *Importer) movieJSONToMovie(movieJSON jsonschema.Movie) models.Movie {
newMovie := models.Movie{
Name: movieJSON.Name,
@@ -57,6 +132,8 @@ func (i *Importer) movieJSONToMovie(movieJSON jsonschema.Movie) models.Movie {
Synopsis: movieJSON.Synopsis,
CreatedAt: movieJSON.CreatedAt.GetTime(),
UpdatedAt: movieJSON.UpdatedAt.GetTime(),
TagIDs: models.NewRelatedIDs([]int{}),
}
if len(movieJSON.URLs) > 0 {

View File

@@ -26,6 +26,13 @@ const (
missingStudioName = "existingStudioName"
errImageID = 3
existingTagID = 105
errTagsID = 106
existingTagName = "existingTagName"
existingTagErr = "existingTagErr"
missingTagName = "missingTagName"
)
var testCtx = context.Background()
@@ -157,6 +164,97 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
db.AssertExpectations(t)
}
func TestImporterPreImportWithTag(t *testing.T) {
db := mocks.NewDatabase()
i := Importer{
ReaderWriter: db.Movie,
TagWriter: db.Tag,
MissingRefBehaviour: models.ImportMissingRefEnumFail,
Input: jsonschema.Movie{
Tags: []string{
existingTagName,
},
},
}
db.Tag.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
{
ID: existingTagID,
Name: existingTagName,
},
}, nil).Once()
db.Tag.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
err := i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, existingTagID, i.movie.TagIDs.List()[0])
i.Input.Tags = []string{existingTagErr}
err = i.PreImport(testCtx)
assert.NotNil(t, err)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingTag(t *testing.T) {
db := mocks.NewDatabase()
i := Importer{
ReaderWriter: db.Movie,
TagWriter: db.Tag,
Input: jsonschema.Movie{
Tags: []string{
missingTagName,
},
},
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
t := args.Get(1).(*models.Tag)
t.ID = existingTagID
}).Return(nil)
err := i.PreImport(testCtx)
assert.NotNil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore
err = i.PreImport(testCtx)
assert.Nil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumCreate
err = i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, existingTagID, i.movie.TagIDs.List()[0])
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
db := mocks.NewDatabase()
i := Importer{
ReaderWriter: db.Movie,
TagWriter: db.Tag,
Input: jsonschema.Movie{
Tags: []string{
missingTagName,
},
},
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)
db.AssertExpectations(t)
}
func TestImporterPostImport(t *testing.T) {
db := mocks.NewDatabase()

View File

@@ -18,3 +18,15 @@ func CountByStudioID(ctx context.Context, r models.MovieQueryer, id int, depth *
return r.QueryCount(ctx, filter, nil)
}
func CountByTagID(ctx context.Context, r models.MovieQueryer, id int, depth *int) (int, error) {
filter := &models.MovieFilterType{
Tags: &models.HierarchicalMultiCriterionInput{
Value: []string{strconv.Itoa(id)},
Modifier: models.CriterionModifierIncludes,
Depth: depth,
},
}
return r.QueryCount(ctx, filter, nil)
}

View File

@@ -284,11 +284,13 @@ type mappedMovieScraperConfig struct {
mappedConfig
Studio mappedConfig `yaml:"Studio"`
Tags mappedConfig `yaml:"Tags"`
}
type _mappedMovieScraperConfig mappedMovieScraperConfig
const (
mappedScraperConfigMovieStudio = "Studio"
mappedScraperConfigMovieTags = "Tags"
)
func (s *mappedMovieScraperConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
@@ -303,9 +305,11 @@ func (s *mappedMovieScraperConfig) UnmarshalYAML(unmarshal func(interface{}) err
thisMap := make(map[string]interface{})
thisMap[mappedScraperConfigMovieStudio] = parentMap[mappedScraperConfigMovieStudio]
delete(parentMap, mappedScraperConfigMovieStudio)
thisMap[mappedScraperConfigMovieTags] = parentMap[mappedScraperConfigMovieTags]
delete(parentMap, mappedScraperConfigMovieTags)
// re-unmarshal the sub-fields
yml, err := yaml.Marshal(thisMap)
if err != nil {
@@ -1086,6 +1090,7 @@ func (s mappedScraper) scrapeMovie(ctx context.Context, q mappedQuery) (*models.
movieMap := movieScraperConfig.mappedConfig
movieStudioMap := movieScraperConfig.Studio
movieTagsMap := movieScraperConfig.Tags
results := movieMap.process(ctx, q, s.Common)
@@ -1100,7 +1105,19 @@ func (s mappedScraper) scrapeMovie(ctx context.Context, q mappedQuery) (*models.
}
}
if len(results) == 0 && ret.Studio == nil {
// now apply the tags
if movieTagsMap != nil {
logger.Debug(`Processing movie tags:`)
tagResults := movieTagsMap.process(ctx, q, s.Common)
for _, p := range tagResults {
tag := &models.ScrapedTag{}
p.apply(tag)
ret.Tags = append(ret.Tags, tag)
}
}
if len(results) == 0 && ret.Studio == nil && len(ret.Tags) == 0 {
return nil, nil
}

View File

@@ -71,13 +71,24 @@ func (c Cache) postScrapePerformer(ctx context.Context, p models.ScrapedPerforme
}
func (c Cache) postScrapeMovie(ctx context.Context, m models.ScrapedMovie) (ScrapedContent, error) {
if m.Studio != nil {
r := c.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
return match.ScrapedStudio(ctx, r.StudioFinder, m.Studio, nil)
}); err != nil {
return nil, err
r := c.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
tqb := r.TagFinder
tags, err := postProcessTags(ctx, tqb, m.Tags)
if err != nil {
return err
}
m.Tags = tags
if m.Studio != nil {
if err := match.ScrapedStudio(ctx, r.StudioFinder, m.Studio, nil); err != nil {
return err
}
}
return nil
}); err != nil {
return nil, err
}
// post-process - set the image if applicable

View File

@@ -30,7 +30,7 @@ const (
dbConnTimeout = 30
)
var appSchemaVersion uint = 60
var appSchemaVersion uint = 61
//go:embed migrations/*.sql
var migrationsBox embed.FS

View File

@@ -0,0 +1,10 @@
CREATE TABLE `movies_tags` (
`movie_id` integer NOT NULL,
`tag_id` integer NOT NULL,
foreign key(`movie_id`) references `movies`(`id`) on delete CASCADE,
foreign key(`tag_id`) references `tags`(`id`) on delete CASCADE,
PRIMARY KEY(`movie_id`, `tag_id`)
);
CREATE INDEX `index_movies_tags_on_tag_id` on `movies_tags` (`tag_id`);
CREATE INDEX `index_movies_tags_on_movie_id` on `movies_tags` (`movie_id`);

View File

@@ -23,6 +23,8 @@ const (
movieFrontImageBlobColumn = "front_image_blob"
movieBackImageBlobColumn = "back_image_blob"
moviesTagsTable = "movies_tags"
movieURLsTable = "movie_urls"
movieURLColumn = "url"
)
@@ -98,6 +100,7 @@ func (r *movieRowRecord) fromPartial(o models.MoviePartial) {
type movieRepositoryType struct {
repository
scenes repository
tags joinRepository
}
var (
@@ -110,11 +113,21 @@ var (
tableName: moviesScenesTable,
idColumn: movieIDColumn,
},
tags: joinRepository{
repository: repository{
tableName: moviesTagsTable,
idColumn: movieIDColumn,
},
fkColumn: tagIDColumn,
foreignTable: tagTable,
orderBy: "tags.name ASC",
},
}
)
type MovieStore struct {
blobJoinQueryBuilder
tagRelationshipStore
tableMgr *table
}
@@ -125,6 +138,11 @@ func NewMovieStore(blobStore *BlobStore) *MovieStore {
blobStore: blobStore,
joinTable: movieTable,
},
tagRelationshipStore: tagRelationshipStore{
idRelationshipStore: idRelationshipStore{
joinTable: moviesTagsTableMgr,
},
},
tableMgr: movieTableMgr,
}
@@ -154,6 +172,10 @@ func (qb *MovieStore) Create(ctx context.Context, newObject *models.Movie) error
}
}
if err := qb.tagRelationshipStore.createRelationships(ctx, id, newObject.TagIDs); err != nil {
return err
}
updated, err := qb.find(ctx, id)
if err != nil {
return fmt.Errorf("finding after create: %w", err)
@@ -185,6 +207,10 @@ func (qb *MovieStore) UpdatePartial(ctx context.Context, id int, partial models.
}
}
if err := qb.tagRelationshipStore.modifyRelationships(ctx, id, partial.TagIDs); err != nil {
return nil, err
}
return qb.find(ctx, id)
}
@@ -202,6 +228,10 @@ func (qb *MovieStore) Update(ctx context.Context, updatedObject *models.Movie) e
}
}
if err := qb.tagRelationshipStore.replaceRelationships(ctx, updatedObject.ID, updatedObject.TagIDs); err != nil {
return err
}
return nil
}
@@ -430,6 +460,7 @@ var movieSortOptions = sortOptions{
"random",
"rating",
"scenes_count",
"tag_count",
"updated_at",
}
@@ -451,6 +482,8 @@ func (qb *MovieStore) getMovieSort(findFilter *models.FindFilterType) (string, e
sortQuery := ""
switch sort {
case "tag_count":
sortQuery += getCountSort(movieTable, moviesTagsTable, movieIDColumn, direction)
case "scenes_count": // generic getSort won't work for this
sortQuery += getCountSort(movieTable, moviesScenesTable, movieIDColumn, direction)
default:

View File

@@ -63,6 +63,8 @@ func (qb *movieFilterHandler) criterionHandler() criterionHandler {
qb.urlsCriterionHandler(movieFilter.URL),
studioCriterionHandler(movieTable, movieFilter.Studios),
qb.performersCriterionHandler(movieFilter.Performers),
qb.tagsCriterionHandler(movieFilter.Tags),
qb.tagCountCriterionHandler(movieFilter.TagCount),
&dateCriterionHandler{movieFilter.Date, "movies.date", nil},
&timestampCriterionHandler{movieFilter.CreatedAt, "movies.created_at", nil},
&timestampCriterionHandler{movieFilter.UpdatedAt, "movies.updated_at", nil},
@@ -162,3 +164,28 @@ func (qb *movieFilterHandler) performersCriterionHandler(performers *models.Mult
}
}
}
func (qb *movieFilterHandler) tagsCriterionHandler(tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc {
h := joinedHierarchicalMultiCriterionHandlerBuilder{
primaryTable: movieTable,
foreignTable: tagTable,
foreignFK: "tag_id",
relationsTable: "tags_relations",
joinAs: "movie_tag",
joinTable: moviesTagsTable,
primaryFK: movieIDColumn,
}
return h.handler(tags)
}
func (qb *movieFilterHandler) tagCountCriterionHandler(count *models.IntCriterionInput) criterionHandlerFunc {
h := countCriterionHandlerBuilder{
primaryTable: movieTable,
joinTable: moviesTagsTable,
primaryFK: movieIDColumn,
}
return h.handler(count)
}

View File

@@ -9,6 +9,7 @@ import (
"strconv"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
@@ -17,7 +18,12 @@ import (
func loadMovieRelationships(ctx context.Context, expected models.Movie, actual *models.Movie) error {
if expected.URLs.Loaded() {
if err := actual.LoadURLs(ctx, db.Gallery); err != nil {
if err := actual.LoadURLs(ctx, db.Movie); err != nil {
return err
}
}
if expected.TagIDs.Loaded() {
if err := actual.LoadTagIDs(ctx, db.Movie); err != nil {
return err
}
}
@@ -25,6 +31,337 @@ func loadMovieRelationships(ctx context.Context, expected models.Movie, actual *
return nil
}
func Test_MovieStore_Create(t *testing.T) {
var (
name = "name"
url = "url"
aliases = "alias1, alias2"
director = "director"
rating = 60
duration = 34
synopsis = "synopsis"
date, _ = models.ParseDate("2003-02-01")
createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC)
updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC)
)
tests := []struct {
name string
newObject models.Movie
wantErr bool
}{
{
"full",
models.Movie{
Name: name,
Duration: &duration,
Date: &date,
Rating: &rating,
StudioID: &studioIDs[studioIdxWithMovie],
Director: director,
Synopsis: synopsis,
URLs: models.NewRelatedStrings([]string{url}),
TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithDupName], tagIDs[tagIdx1WithMovie]}),
Aliases: aliases,
CreatedAt: createdAt,
UpdatedAt: updatedAt,
},
false,
},
{
"invalid tag id",
models.Movie{
Name: name,
TagIDs: models.NewRelatedIDs([]int{invalidID}),
},
true,
},
}
qb := db.Movie
for _, tt := range tests {
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
assert := assert.New(t)
p := tt.newObject
if err := qb.Create(ctx, &p); (err != nil) != tt.wantErr {
t.Errorf("MovieStore.Create() error = %v, wantErr = %v", err, tt.wantErr)
}
if tt.wantErr {
assert.Zero(p.ID)
return
}
assert.NotZero(p.ID)
copy := tt.newObject
copy.ID = p.ID
// load relationships
if err := loadMovieRelationships(ctx, copy, &p); err != nil {
t.Errorf("loadMovieRelationships() error = %v", err)
return
}
assert.Equal(copy, p)
// ensure can find the movie
found, err := qb.Find(ctx, p.ID)
if err != nil {
t.Errorf("MovieStore.Find() error = %v", err)
}
if !assert.NotNil(found) {
return
}
// load relationships
if err := loadMovieRelationships(ctx, copy, found); err != nil {
t.Errorf("loadMovieRelationships() error = %v", err)
return
}
assert.Equal(copy, *found)
return
})
}
}
func Test_movieQueryBuilder_Update(t *testing.T) {
var (
name = "name"
url = "url"
aliases = "alias1, alias2"
director = "director"
rating = 60
duration = 34
synopsis = "synopsis"
date, _ = models.ParseDate("2003-02-01")
createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC)
updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC)
)
tests := []struct {
name string
updatedObject *models.Movie
wantErr bool
}{
{
"full",
&models.Movie{
ID: movieIDs[movieIdxWithTag],
Name: name,
Duration: &duration,
Date: &date,
Rating: &rating,
StudioID: &studioIDs[studioIdxWithMovie],
Director: director,
Synopsis: synopsis,
URLs: models.NewRelatedStrings([]string{url}),
TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithDupName], tagIDs[tagIdx1WithMovie]}),
Aliases: aliases,
CreatedAt: createdAt,
UpdatedAt: updatedAt,
},
false,
},
{
"clear tag ids",
&models.Movie{
ID: movieIDs[movieIdxWithTag],
Name: name,
TagIDs: models.NewRelatedIDs([]int{}),
},
false,
},
{
"invalid studio id",
&models.Movie{
ID: movieIDs[movieIdxWithScene],
Name: name,
StudioID: &invalidID,
},
true,
},
{
"invalid tag id",
&models.Movie{
ID: movieIDs[movieIdxWithScene],
Name: name,
TagIDs: models.NewRelatedIDs([]int{invalidID}),
},
true,
},
}
qb := db.Movie
for _, tt := range tests {
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
assert := assert.New(t)
copy := *tt.updatedObject
if err := qb.Update(ctx, tt.updatedObject); (err != nil) != tt.wantErr {
t.Errorf("movieQueryBuilder.Update() error = %v, wantErr %v", err, tt.wantErr)
}
if tt.wantErr {
return
}
s, err := qb.Find(ctx, tt.updatedObject.ID)
if err != nil {
t.Errorf("movieQueryBuilder.Find() error = %v", err)
}
// load relationships
if err := loadMovieRelationships(ctx, copy, s); err != nil {
t.Errorf("loadMovieRelationships() error = %v", err)
return
}
assert.Equal(copy, *s)
})
}
}
func clearMoviePartial() models.MoviePartial {
// leave mandatory fields
return models.MoviePartial{
Aliases: models.OptionalString{Set: true, Null: true},
Synopsis: models.OptionalString{Set: true, Null: true},
Director: models.OptionalString{Set: true, Null: true},
Duration: models.OptionalInt{Set: true, Null: true},
URLs: &models.UpdateStrings{Mode: models.RelationshipUpdateModeSet},
Date: models.OptionalDate{Set: true, Null: true},
Rating: models.OptionalInt{Set: true, Null: true},
StudioID: models.OptionalInt{Set: true, Null: true},
TagIDs: &models.UpdateIDs{Mode: models.RelationshipUpdateModeSet},
}
}
func Test_movieQueryBuilder_UpdatePartial(t *testing.T) {
var (
name = "name"
url = "url"
aliases = "alias1, alias2"
director = "director"
rating = 60
duration = 34
synopsis = "synopsis"
date, _ = models.ParseDate("2003-02-01")
createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC)
updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC)
)
tests := []struct {
name string
id int
partial models.MoviePartial
want models.Movie
wantErr bool
}{
{
"full",
movieIDs[movieIdxWithScene],
models.MoviePartial{
Name: models.NewOptionalString(name),
Director: models.NewOptionalString(director),
Synopsis: models.NewOptionalString(synopsis),
Aliases: models.NewOptionalString(aliases),
URLs: &models.UpdateStrings{
Values: []string{url},
Mode: models.RelationshipUpdateModeSet,
},
Date: models.NewOptionalDate(date),
Duration: models.NewOptionalInt(duration),
Rating: models.NewOptionalInt(rating),
StudioID: models.NewOptionalInt(studioIDs[studioIdxWithMovie]),
CreatedAt: models.NewOptionalTime(createdAt),
UpdatedAt: models.NewOptionalTime(updatedAt),
TagIDs: &models.UpdateIDs{
IDs: []int{tagIDs[tagIdx1WithMovie], tagIDs[tagIdx1WithDupName]},
Mode: models.RelationshipUpdateModeSet,
},
},
models.Movie{
ID: movieIDs[movieIdxWithScene],
Name: name,
Director: director,
Synopsis: synopsis,
Aliases: aliases,
URLs: models.NewRelatedStrings([]string{url}),
Date: &date,
Duration: &duration,
Rating: &rating,
StudioID: &studioIDs[studioIdxWithMovie],
CreatedAt: createdAt,
UpdatedAt: updatedAt,
TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithDupName], tagIDs[tagIdx1WithMovie]}),
},
false,
},
{
"clear all",
movieIDs[movieIdxWithScene],
clearMoviePartial(),
models.Movie{
ID: movieIDs[movieIdxWithScene],
Name: movieNames[movieIdxWithScene],
TagIDs: models.NewRelatedIDs([]int{}),
},
false,
},
{
"invalid id",
invalidID,
models.MoviePartial{},
models.Movie{},
true,
},
}
for _, tt := range tests {
qb := db.Movie
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
assert := assert.New(t)
got, err := qb.UpdatePartial(ctx, tt.id, tt.partial)
if (err != nil) != tt.wantErr {
t.Errorf("movieQueryBuilder.UpdatePartial() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
return
}
// load relationships
if err := loadMovieRelationships(ctx, tt.want, got); err != nil {
t.Errorf("loadMovieRelationships() error = %v", err)
return
}
assert.Equal(tt.want, *got)
s, err := qb.Find(ctx, tt.id)
if err != nil {
t.Errorf("movieQueryBuilder.Find() error = %v", err)
}
// load relationships
if err := loadMovieRelationships(ctx, tt.want, s); err != nil {
t.Errorf("loadMovieRelationships() error = %v", err)
return
}
assert.Equal(tt.want, *s)
})
}
}
func TestMovieFindByName(t *testing.T) {
withTxn(func(ctx context.Context) error {
mqb := db.Movie
@@ -280,12 +617,12 @@ func TestMovieQueryURLExcludes(t *testing.T) {
Name: &nameCriterion,
}
movies := queryMovie(ctx, t, mqb, &filter, nil)
movies := queryMovies(ctx, t, &filter, nil)
assert.Len(t, movies, 0, "Expected no movies to be found")
// query for movies that exclude the URL "ccc"
urlCriterion.Value = "ccc"
movies = queryMovie(ctx, t, mqb, &filter, nil)
movies = queryMovies(ctx, t, &filter, nil)
if assert.Len(t, movies, 1, "Expected one movie to be found") {
assert.Equal(t, movie.Name, movies[0].Name)
@@ -300,7 +637,7 @@ func verifyMovieQuery(t *testing.T, filter models.MovieFilterType, verifyFn func
t.Helper()
sqb := db.Movie
movies := queryMovie(ctx, t, sqb, &filter, nil)
movies := queryMovies(ctx, t, &filter, nil)
for _, movie := range movies {
if err := movie.LoadURLs(ctx, sqb); err != nil {
@@ -319,7 +656,8 @@ func verifyMovieQuery(t *testing.T, filter models.MovieFilterType, verifyFn func
})
}
func queryMovie(ctx context.Context, t *testing.T, sqb models.MovieReader, movieFilter *models.MovieFilterType, findFilter *models.FindFilterType) []*models.Movie {
func queryMovies(ctx context.Context, t *testing.T, movieFilter *models.MovieFilterType, findFilter *models.FindFilterType) []*models.Movie {
sqb := db.Movie
movies, _, err := sqb.Query(ctx, movieFilter, findFilter)
if err != nil {
t.Errorf("Error querying movie: %s", err.Error())
@@ -328,6 +666,102 @@ func queryMovie(ctx context.Context, t *testing.T, sqb models.MovieReader, movie
return movies
}
func TestMovieQueryTags(t *testing.T) {
withTxn(func(ctx context.Context) error {
tagCriterion := models.HierarchicalMultiCriterionInput{
Value: []string{
strconv.Itoa(tagIDs[tagIdxWithMovie]),
strconv.Itoa(tagIDs[tagIdx1WithMovie]),
},
Modifier: models.CriterionModifierIncludes,
}
movieFilter := models.MovieFilterType{
Tags: &tagCriterion,
}
// ensure ids are correct
movies := queryMovies(ctx, t, &movieFilter, nil)
assert.Len(t, movies, 3)
for _, movie := range movies {
assert.True(t, movie.ID == movieIDs[movieIdxWithTag] || movie.ID == movieIDs[movieIdxWithTwoTags] || movie.ID == movieIDs[movieIdxWithThreeTags])
}
tagCriterion = models.HierarchicalMultiCriterionInput{
Value: []string{
strconv.Itoa(tagIDs[tagIdx1WithMovie]),
strconv.Itoa(tagIDs[tagIdx2WithMovie]),
},
Modifier: models.CriterionModifierIncludesAll,
}
movies = queryMovies(ctx, t, &movieFilter, nil)
if assert.Len(t, movies, 2) {
assert.Equal(t, sceneIDs[movieIdxWithTwoTags], movies[0].ID)
assert.Equal(t, sceneIDs[movieIdxWithThreeTags], movies[1].ID)
}
tagCriterion = models.HierarchicalMultiCriterionInput{
Value: []string{
strconv.Itoa(tagIDs[tagIdx1WithMovie]),
},
Modifier: models.CriterionModifierExcludes,
}
q := getSceneStringValue(movieIdxWithTwoTags, titleField)
findFilter := models.FindFilterType{
Q: &q,
}
movies = queryMovies(ctx, t, &movieFilter, &findFilter)
assert.Len(t, movies, 0)
return nil
})
}
func TestMovieQueryTagCount(t *testing.T) {
const tagCount = 1
tagCountCriterion := models.IntCriterionInput{
Value: tagCount,
Modifier: models.CriterionModifierEquals,
}
verifyMoviesTagCount(t, tagCountCriterion)
tagCountCriterion.Modifier = models.CriterionModifierNotEquals
verifyMoviesTagCount(t, tagCountCriterion)
tagCountCriterion.Modifier = models.CriterionModifierGreaterThan
verifyMoviesTagCount(t, tagCountCriterion)
tagCountCriterion.Modifier = models.CriterionModifierLessThan
verifyMoviesTagCount(t, tagCountCriterion)
}
func verifyMoviesTagCount(t *testing.T, tagCountCriterion models.IntCriterionInput) {
withTxn(func(ctx context.Context) error {
sqb := db.Movie
movieFilter := models.MovieFilterType{
TagCount: &tagCountCriterion,
}
movies := queryMovies(ctx, t, &movieFilter, nil)
assert.Greater(t, len(movies), 0)
for _, movie := range movies {
ids, err := sqb.GetTagIDs(ctx, movie.ID)
if err != nil {
return err
}
verifyInt(t, len(ids), tagCountCriterion)
}
return nil
})
}
func TestMovieQuerySorting(t *testing.T) {
sort := "scenes_count"
direction := models.SortDirectionEnumDesc
@@ -337,8 +771,7 @@ func TestMovieQuerySorting(t *testing.T) {
}
withTxn(func(ctx context.Context) error {
sqb := db.Movie
movies := queryMovie(ctx, t, sqb, nil, &findFilter)
movies := queryMovies(ctx, t, nil, &findFilter)
// scenes should be in same order as indexes
firstMovie := movies[0]
@@ -348,7 +781,7 @@ func TestMovieQuerySorting(t *testing.T) {
// sort in descending order
direction = models.SortDirectionEnumAsc
movies = queryMovie(ctx, t, sqb, nil, &findFilter)
movies = queryMovies(ctx, t, nil, &findFilter)
lastMovie := movies[len(movies)-1]
assert.Equal(t, movieIDs[movieIdxWithScene], lastMovie.ID)

View File

@@ -0,0 +1,41 @@
package sqlite
import (
"context"
"github.com/stashapp/stash/pkg/models"
)
type idRelationshipStore struct {
joinTable *joinTable
}
func (s *idRelationshipStore) createRelationships(ctx context.Context, id int, fkIDs models.RelatedIDs) error {
if fkIDs.Loaded() {
if err := s.joinTable.insertJoins(ctx, id, fkIDs.List()); err != nil {
return err
}
}
return nil
}
func (s *idRelationshipStore) modifyRelationships(ctx context.Context, id int, fkIDs *models.UpdateIDs) error {
if fkIDs != nil {
if err := s.joinTable.modifyJoins(ctx, id, fkIDs.IDs, fkIDs.Mode); err != nil {
return err
}
}
return nil
}
func (s *idRelationshipStore) replaceRelationships(ctx context.Context, id int, fkIDs models.RelatedIDs) error {
if fkIDs.Loaded() {
if err := s.joinTable.replaceJoins(ctx, id, fkIDs.List()); err != nil {
return err
}
}
return nil
}

View File

@@ -150,9 +150,12 @@ const (
const (
movieIdxWithScene = iota
movieIdxWithStudio
movieIdxWithTag
movieIdxWithTwoTags
movieIdxWithThreeTags
// movies with dup names start from the end
// create 10 more basic movies (can remove this if we add more indexes)
movieIdxWithDupName = movieIdxWithStudio + 10
// create 7 more basic movies (can remove this if we add more indexes)
movieIdxWithDupName = movieIdxWithStudio + 7
moviesNameCase = movieIdxWithDupName
moviesNameNoCase = 1
@@ -214,6 +217,10 @@ const (
tagIdxWithParentAndChild
tagIdxWithGrandParent
tagIdx2WithMarkers
tagIdxWithMovie
tagIdx1WithMovie
tagIdx2WithMovie
tagIdx3WithMovie
// new indexes above
// tags with dup names start from the end
tagIdx1WithDupName
@@ -487,6 +494,12 @@ var (
movieStudioLinks = [][2]int{
{movieIdxWithStudio, studioIdxWithMovie},
}
movieTags = linkMap{
movieIdxWithTag: {tagIdxWithMovie},
movieIdxWithTwoTags: {tagIdx1WithMovie, tagIdx2WithMovie},
movieIdxWithThreeTags: {tagIdx1WithMovie, tagIdx2WithMovie, tagIdx3WithMovie},
}
)
var (
@@ -622,14 +635,14 @@ func populateDB() error {
// TODO - link folders to zip files
if err := createMovies(ctx, db.Movie, moviesNameCase, moviesNameNoCase); err != nil {
return fmt.Errorf("error creating movies: %s", err.Error())
}
if err := createTags(ctx, db.Tag, tagsNameCase, tagsNameNoCase); err != nil {
return fmt.Errorf("error creating tags: %s", err.Error())
}
if err := createMovies(ctx, db.Movie, moviesNameCase, moviesNameNoCase); err != nil {
return fmt.Errorf("error creating movies: %s", err.Error())
}
if err := createPerformers(ctx, performersNameCase, performersNameNoCase); err != nil {
return fmt.Errorf("error creating performers: %s", err.Error())
}
@@ -1321,6 +1334,8 @@ func createMovies(ctx context.Context, mqb models.MovieReaderWriter, n int, o in
index := i
name := namePlain
tids := indexesToIDs(tagIDs, movieTags[i])
if i >= n { // i<n tags get normal names
name = nameNoCase // i>=n movies get dup names if case is not checked
index = n + o - (i + 1) // for the name to be the same the number (index) must be the same also
@@ -1333,6 +1348,7 @@ func createMovies(ctx context.Context, mqb models.MovieReaderWriter, n int, o in
URLs: models.NewRelatedStrings([]string{
getMovieEmptyString(i, urlField),
}),
TagIDs: models.NewRelatedIDs(tids),
}
err := mqb.Create(ctx, &movie)

View File

@@ -155,6 +155,10 @@ func (t *table) join(j joiner, as string, parentIDCol string) {
type joinTable struct {
table
fkColumn exp.IdentifierExpression
// required for ordering
foreignTable *table
orderBy exp.OrderedExpression
}
func (t *joinTable) invert() *joinTable {
@@ -170,6 +174,13 @@ func (t *joinTable) invert() *joinTable {
func (t *joinTable) get(ctx context.Context, id int) ([]int, error) {
q := dialect.Select(t.fkColumn).From(t.table.table).Where(t.idColumn.Eq(id))
if t.orderBy != nil {
if t.foreignTable != nil {
q = q.InnerJoin(t.foreignTable.table, goqu.On(t.foreignTable.idColumn.Eq(t.fkColumn)))
}
q = q.Order(t.orderBy)
}
const single = false
var ret []int
if err := queryFunc(ctx, q, single, func(rows *sqlx.Rows) error {

View File

@@ -36,6 +36,7 @@ var (
studiosStashIDsJoinTable = goqu.T("studio_stash_ids")
moviesURLsJoinTable = goqu.T(movieURLsTable)
moviesTagsJoinTable = goqu.T(moviesTagsTable)
tagsAliasesJoinTable = goqu.T(tagAliasesTable)
tagRelationsJoinTable = goqu.T(tagRelationsTable)
@@ -330,6 +331,16 @@ var (
},
valueColumn: moviesURLsJoinTable.Col(movieURLColumn),
}
moviesTagsTableMgr = &joinTable{
table: table{
table: moviesTagsJoinTable,
idColumn: moviesTagsJoinTable.Col(movieIDColumn),
},
fkColumn: moviesTagsJoinTable.Col(tagIDColumn),
foreignTable: tagTableMgr,
orderBy: tagTableMgr.table.Col("name").Asc(),
}
)
var (

View File

@@ -424,6 +424,18 @@ func (qb *TagStore) FindByGalleryID(ctx context.Context, galleryID int) ([]*mode
return qb.queryTags(ctx, query, args)
}
func (qb *TagStore) FindByMovieID(ctx context.Context, movieID int) ([]*models.Tag, error) {
query := `
SELECT tags.* FROM tags
LEFT JOIN movies_tags as movies_join on movies_join.tag_id = tags.id
WHERE movies_join.movie_id = ?
GROUP BY tags.id
`
query += qb.getDefaultTagSort()
args := []interface{}{movieID}
return qb.queryTags(ctx, query, args)
}
func (qb *TagStore) FindBySceneMarkerID(ctx context.Context, sceneMarkerID int) ([]*models.Tag, error) {
query := `
SELECT tags.* FROM tags
@@ -615,6 +627,7 @@ var tagSortOptions = sortOptions{
"galleries_count",
"id",
"images_count",
"movies_count",
"name",
"performers_count",
"random",
@@ -655,6 +668,8 @@ func (qb *TagStore) getTagSort(query *queryBuilder, findFilter *models.FindFilte
sortQuery += getCountSort(tagTable, galleriesTagsTable, tagIDColumn, direction)
case "performers_count":
sortQuery += getCountSort(tagTable, performersTagsTable, tagIDColumn, direction)
case "movies_count":
sortQuery += getCountSort(tagTable, moviesTagsTable, tagIDColumn, direction)
default:
sortQuery += getSort(sort, direction, "tags")
}
@@ -888,3 +903,17 @@ SELECT t.*, c.path FROM tags t INNER JOIN children c ON t.id = c.child_id
return qb.queryTagPaths(ctx, query, args)
}
type tagRelationshipStore struct {
idRelationshipStore
}
func (s *tagRelationshipStore) CountByTagID(ctx context.Context, tagID int) (int, error) {
joinTable := s.joinTable.table.table
q := dialect.Select(goqu.COUNT("*")).From(joinTable).Where(joinTable.Col(tagIDColumn).Eq(tagID))
return count(ctx, q)
}
func (s *tagRelationshipStore) GetTagIDs(ctx context.Context, id int) ([]int, error) {
return s.joinTable.get(ctx, id)
}

View File

@@ -66,6 +66,7 @@ func (qb *tagFilterHandler) criterionHandler() criterionHandler {
qb.imageCountCriterionHandler(tagFilter.ImageCount),
qb.galleryCountCriterionHandler(tagFilter.GalleryCount),
qb.performerCountCriterionHandler(tagFilter.PerformerCount),
qb.movieCountCriterionHandler(tagFilter.MovieCount),
qb.markerCountCriterionHandler(tagFilter.MarkerCount),
qb.parentsCriterionHandler(tagFilter.Parents),
qb.childrenCriterionHandler(tagFilter.Children),
@@ -174,6 +175,17 @@ func (qb *tagFilterHandler) performerCountCriterionHandler(performerCount *model
}
}
func (qb *tagFilterHandler) movieCountCriterionHandler(movieCount *models.IntCriterionInput) criterionHandlerFunc {
return func(ctx context.Context, f *filterBuilder) {
if movieCount != nil {
f.addLeftJoin("movies_tags", "", "movies_tags.tag_id = tags.id")
clause, args := getIntCriterionWhereClause("count(distinct movies_tags.movie_id)", *movieCount)
f.addHaving(clause, args...)
}
}
}
func (qb *tagFilterHandler) markerCountCriterionHandler(markerCount *models.IntCriterionInput) criterionHandlerFunc {
return func(ctx context.Context, f *filterBuilder) {
if markerCount != nil {

View File

@@ -42,6 +42,33 @@ func TestMarkerFindBySceneMarkerID(t *testing.T) {
})
}
func TestTagFindByMovieID(t *testing.T) {
withTxn(func(ctx context.Context) error {
tqb := db.Tag
movieID := movieIDs[movieIdxWithTag]
tags, err := tqb.FindByMovieID(ctx, movieID)
if err != nil {
t.Errorf("Error finding tags: %s", err.Error())
}
assert.Len(t, tags, 1)
assert.Equal(t, tagIDs[tagIdxWithMovie], tags[0].ID)
tags, err = tqb.FindByMovieID(ctx, 0)
if err != nil {
t.Errorf("Error finding tags: %s", err.Error())
}
assert.Len(t, tags, 0)
return nil
})
}
func TestTagFindByName(t *testing.T) {
withTxn(func(ctx context.Context) error {
tqb := db.Tag
@@ -203,6 +230,10 @@ func TestTagQuerySort(t *testing.T) {
tags = queryTags(ctx, t, sqb, nil, findFilter)
assert.Equal(tagIDs[tagIdx2WithPerformer], tags[0].ID)
sortBy = "movies_count"
tags = queryTags(ctx, t, sqb, nil, findFilter)
assert.Equal(tagIDs[tagIdx1WithMovie], tags[0].ID)
return nil
})
}