diff --git a/pkg/manager/task_autotag_test.go b/pkg/autotag/integration_test.go similarity index 71% rename from pkg/manager/task_autotag_test.go rename to pkg/autotag/integration_test.go index 0eb755c4d..32a0e7501 100644 --- a/pkg/manager/task_autotag_test.go +++ b/pkg/autotag/integration_test.go @@ -1,6 +1,6 @@ // +build integration -package manager +package autotag import ( "context" @@ -8,8 +8,6 @@ import ( "fmt" "io/ioutil" "os" - "strings" - "sync" "testing" "github.com/stashapp/stash/pkg/database" @@ -22,49 +20,12 @@ import ( ) const testName = "Foo's Bar" -const testExtension = ".mp4" const existingStudioName = "ExistingStudio" -const existingStudioSceneName = testName + ".dontChangeStudio" + testExtension +const existingStudioSceneName = testName + ".dontChangeStudio.mp4" var existingStudioID int -var testSeparators = []string{ - ".", - "-", - "_", - " ", -} - -var testEndSeparators = []string{ - "{", - "}", - "(", - ")", - ",", -} - -func generateNamePatterns(name, separator string) []string { - var ret []string - ret = append(ret, fmt.Sprintf("%s%saaa"+testExtension, name, separator)) - ret = append(ret, fmt.Sprintf("aaa%s%s"+testExtension, separator, name)) - ret = append(ret, fmt.Sprintf("aaa%s%s%sbbb"+testExtension, separator, name, separator)) - ret = append(ret, fmt.Sprintf("dir/%s%saaa"+testExtension, name, separator)) - ret = append(ret, fmt.Sprintf("dir\\%s%saaa"+testExtension, name, separator)) - ret = append(ret, fmt.Sprintf("%s%saaa/dir/bbb"+testExtension, name, separator)) - ret = append(ret, fmt.Sprintf("%s%saaa\\dir\\bbb"+testExtension, name, separator)) - ret = append(ret, fmt.Sprintf("dir/%s%s/aaa"+testExtension, name, separator)) - ret = append(ret, fmt.Sprintf("dir\\%s%s\\aaa"+testExtension, name, separator)) - - return ret -} - -func generateFalseNamePattern(name string, separator string) string { - splitted := strings.Split(name, " ") - - return fmt.Sprintf("%s%saaa%s%s"+testExtension, splitted[0], separator, separator, splitted[1]) -} - func testTeardown(databaseFile string) { err := database.DB.Close() @@ -126,7 +87,7 @@ func createStudio(qb models.StudioWriter, name string) (*models.Studio, error) { // create the studio studio := models.Studio{ Checksum: name, - Name: sql.NullString{Valid: true, String: testName}, + Name: sql.NullString{Valid: true, String: name}, } return qb.Create(studio) @@ -148,23 +109,7 @@ func createTag(qb models.TagWriter) error { func createScenes(sqb models.SceneReaderWriter) error { // create the scenes - var scenePatterns []string - var falseScenePatterns []string - - separators := append(testSeparators, testEndSeparators...) - - for _, separator := range separators { - scenePatterns = append(scenePatterns, generateNamePatterns(testName, separator)...) - scenePatterns = append(scenePatterns, generateNamePatterns(strings.ToLower(testName), separator)...) - falseScenePatterns = append(falseScenePatterns, generateFalseNamePattern(testName, separator)) - } - - // add test cases for intra-name separators - for _, separator := range testSeparators { - if separator != " " { - scenePatterns = append(scenePatterns, generateNamePatterns(strings.Replace(testName, " ", separator, -1), separator)...) - } - } + scenePatterns, falseScenePatterns := generateScenePaths(testName) for _, fn := range scenePatterns { err := createScene(sqb, makeScene(fn, true)) @@ -278,17 +223,14 @@ func TestParsePerformers(t *testing.T) { return } - task := AutoTagPerformerTask{ - AutoTagTask: AutoTagTask{ - txnManager: sqlite.NewTransactionManager(), - }, - performer: performers[0], + for _, p := range performers { + if err := withTxn(func(r models.Repository) error { + return PerformerScenes(p, nil, r.Scene()) + }); err != nil { + t.Errorf("Error auto-tagging performers: %s", err) + } } - var wg sync.WaitGroup - wg.Add(1) - task.Start(&wg) - // verify that scenes were tagged correctly withTxn(func(r models.Repository) error { pqb := r.Performer() @@ -328,17 +270,14 @@ func TestParseStudios(t *testing.T) { return } - task := AutoTagStudioTask{ - AutoTagTask: AutoTagTask{ - txnManager: sqlite.NewTransactionManager(), - }, - studio: studios[0], + for _, s := range studios { + if err := withTxn(func(r models.Repository) error { + return StudioScenes(s, nil, r.Scene()) + }); err != nil { + t.Errorf("Error auto-tagging performers: %s", err) + } } - var wg sync.WaitGroup - wg.Add(1) - task.Start(&wg) - // verify that scenes were tagged correctly withTxn(func(r models.Repository) error { scenes, err := r.Scene().All() @@ -354,9 +293,14 @@ func TestParseStudios(t *testing.T) { } } else { // title is only set on scenes where we expect studio to be set - if scene.Title.String == scene.Path && scene.StudioID.Int64 != int64(studios[0].ID) { - t.Errorf("Did not set studio '%s' for path '%s'", testName, scene.Path) - } else if scene.Title.String != scene.Path && scene.StudioID.Int64 == int64(studios[0].ID) { + if scene.Title.String == scene.Path { + if !scene.StudioID.Valid { + t.Errorf("Did not set studio '%s' for path '%s'", testName, scene.Path) + } else if scene.StudioID.Int64 != int64(studios[1].ID) { + t.Errorf("Incorrect studio id %d set for path '%s'", scene.StudioID.Int64, scene.Path) + } + + } else if scene.Title.String != scene.Path && scene.StudioID.Int64 == int64(studios[1].ID) { t.Errorf("Incorrectly set studio '%s' for path '%s'", testName, scene.Path) } } @@ -377,17 +321,14 @@ func TestParseTags(t *testing.T) { return } - task := AutoTagTagTask{ - AutoTagTask: AutoTagTask{ - txnManager: sqlite.NewTransactionManager(), - }, - tag: tags[0], + for _, s := range tags { + if err := withTxn(func(r models.Repository) error { + return TagScenes(s, nil, r.Scene()) + }); err != nil { + t.Errorf("Error auto-tagging performers: %s", err) + } } - var wg sync.WaitGroup - wg.Add(1) - task.Start(&wg) - // verify that scenes were tagged correctly withTxn(func(r models.Repository) error { scenes, err := r.Scene().All() diff --git a/pkg/autotag/performer.go b/pkg/autotag/performer.go new file mode 100644 index 000000000..b6f40c1ad --- /dev/null +++ b/pkg/autotag/performer.go @@ -0,0 +1,42 @@ +package autotag + +import ( + "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/scene" +) + +func getMatchingPerformers(path string, performerReader models.PerformerReader) ([]*models.Performer, error) { + words := getPathWords(path) + performers, err := performerReader.QueryForAutoTag(words) + + if err != nil { + return nil, err + } + + var ret []*models.Performer + for _, p := range performers { + // TODO - commenting out alias handling until both sides work correctly + if nameMatchesPath(p.Name.String, path) { // || nameMatchesPath(p.Aliases.String, path) { + ret = append(ret, p) + } + } + + return ret, nil +} + +func getPerformerTagger(p *models.Performer) tagger { + return tagger{ + ID: p.ID, + Type: "performer", + Name: p.Name.String, + } +} + +// PerformerScenes searches for scenes whose path matches the provided performer name and tags the scene with the performer. +func PerformerScenes(p *models.Performer, paths []string, rw models.SceneReaderWriter) error { + t := getPerformerTagger(p) + + return t.tagScenes(paths, rw, func(subjectID, otherID int) (bool, error) { + return scene.AddPerformer(rw, otherID, subjectID) + }) +} diff --git a/pkg/autotag/performer_test.go b/pkg/autotag/performer_test.go new file mode 100644 index 000000000..1e935c9d5 --- /dev/null +++ b/pkg/autotag/performer_test.go @@ -0,0 +1,81 @@ +package autotag + +import ( + "testing" + + "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/models/mocks" + "github.com/stretchr/testify/assert" +) + +func TestPerformerScenes(t *testing.T) { + type test struct { + performerName string + expectedRegex string + } + + performerNames := []test{ + { + "performer name", + `(?i)(?:^|_|[^\w\d])performer[.\-_ ]*name(?:$|_|[^\w\d])`, + }, + { + "performer + name", + `(?i)(?:^|_|[^\w\d])performer[.\-_ ]*\+[.\-_ ]*name(?:$|_|[^\w\d])`, + }, + } + + for _, p := range performerNames { + testPerformerScenes(t, p.performerName, p.expectedRegex) + } +} + +func testPerformerScenes(t *testing.T, performerName, expectedRegex string) { + mockSceneReader := &mocks.SceneReaderWriter{} + + const performerID = 2 + + var scenes []*models.Scene + matchingPaths, falsePaths := generateScenePaths(performerName) + for i, p := range append(matchingPaths, falsePaths...) { + scenes = append(scenes, &models.Scene{ + ID: i + 1, + Path: p, + }) + } + + performer := models.Performer{ + ID: performerID, + Name: models.NullString(performerName), + } + + organized := false + perPage := models.PerPageAll + + expectedSceneFilter := &models.SceneFilterType{ + Organized: &organized, + Path: &models.StringCriterionInput{ + Value: expectedRegex, + Modifier: models.CriterionModifierMatchesRegex, + }, + } + + expectedFindFilter := &models.FindFilterType{ + PerPage: &perPage, + } + + mockSceneReader.On("Query", expectedSceneFilter, expectedFindFilter).Return(scenes, len(scenes), nil).Once() + + for i := range matchingPaths { + sceneID := i + 1 + mockSceneReader.On("GetPerformerIDs", sceneID).Return(nil, nil).Once() + mockSceneReader.On("UpdatePerformers", sceneID, []int{performerID}).Return(nil).Once() + } + + err := PerformerScenes(&performer, nil, mockSceneReader) + + assert := assert.New(t) + + assert.Nil(err) + mockSceneReader.AssertExpectations(t) +} diff --git a/pkg/autotag/scene.go b/pkg/autotag/scene.go new file mode 100644 index 000000000..d9bffd630 --- /dev/null +++ b/pkg/autotag/scene.go @@ -0,0 +1,117 @@ +package autotag + +import ( + "fmt" + "path/filepath" + "strings" + + "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/scene" +) + +func pathsFilter(paths []string) *models.SceneFilterType { + if paths == nil { + return nil + } + + sep := string(filepath.Separator) + + var ret *models.SceneFilterType + var or *models.SceneFilterType + for _, p := range paths { + newOr := &models.SceneFilterType{} + if or != nil { + or.Or = newOr + } else { + ret = newOr + } + + or = newOr + + if !strings.HasSuffix(p, sep) { + p = p + sep + } + + or.Path = &models.StringCriterionInput{ + Modifier: models.CriterionModifierEquals, + Value: p + "%", + } + } + + return ret +} + +func getMatchingScenes(name string, paths []string, sceneReader models.SceneReader) ([]*models.Scene, error) { + regex := getPathQueryRegex(name) + organized := false + filter := models.SceneFilterType{ + Path: &models.StringCriterionInput{ + Value: "(?i)" + regex, + Modifier: models.CriterionModifierMatchesRegex, + }, + Organized: &organized, + } + + filter.And = pathsFilter(paths) + + pp := models.PerPageAll + scenes, _, err := sceneReader.Query(&filter, &models.FindFilterType{ + PerPage: &pp, + }) + + if err != nil { + return nil, fmt.Errorf("error querying scenes with regex '%s': %s", regex, err.Error()) + } + + var ret []*models.Scene + for _, p := range scenes { + if nameMatchesPath(name, p.Path) { + ret = append(ret, p) + } + } + + return ret, nil +} + +func getSceneFileTagger(s *models.Scene) tagger { + return tagger{ + ID: s.ID, + Type: "scene", + Name: s.GetTitle(), + Path: s.Path, + } +} + +// ScenePerformers tags the provided scene with performers whose name matches the scene's path. +func ScenePerformers(s *models.Scene, rw models.SceneReaderWriter, performerReader models.PerformerReader) error { + t := getSceneFileTagger(s) + + return t.tagPerformers(performerReader, func(subjectID, otherID int) (bool, error) { + return scene.AddPerformer(rw, subjectID, otherID) + }) +} + +// SceneStudios tags the provided scene with the first studio whose name matches the scene's path. +// +// Scenes will not be tagged if studio is already set. +func SceneStudios(s *models.Scene, rw models.SceneReaderWriter, studioReader models.StudioReader) error { + if s.StudioID.Valid { + // don't modify + return nil + } + + t := getSceneFileTagger(s) + + return t.tagStudios(studioReader, func(subjectID, otherID int) (bool, error) { + return addSceneStudio(rw, subjectID, otherID) + }) +} + +// SceneTags tags the provided scene with tags whose name matches the scene's path. +func SceneTags(s *models.Scene, rw models.SceneReaderWriter, tagReader models.TagReader) error { + t := getSceneFileTagger(s) + + return t.tagTags(tagReader, func(subjectID, otherID int) (bool, error) { + return scene.AddTag(rw, subjectID, otherID) + }) +} diff --git a/pkg/autotag/scene_test.go b/pkg/autotag/scene_test.go new file mode 100644 index 000000000..a71056885 --- /dev/null +++ b/pkg/autotag/scene_test.go @@ -0,0 +1,276 @@ +package autotag + +import ( + "fmt" + "strings" + "testing" + + "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/models/mocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +var testSeparators = []string{ + ".", + "-", + "_", + " ", +} + +var testEndSeparators = []string{ + "{", + "}", + "(", + ")", + ",", +} + +func generateNamePatterns(name, separator string) []string { + var ret []string + ret = append(ret, fmt.Sprintf("%s%saaa.mp4", name, separator)) + ret = append(ret, fmt.Sprintf("aaa%s%s.mp4", separator, name)) + ret = append(ret, fmt.Sprintf("aaa%s%s%sbbb.mp4", separator, name, separator)) + ret = append(ret, fmt.Sprintf("dir/%s%saaa.mp4", name, separator)) + ret = append(ret, fmt.Sprintf("dir\\%s%saaa.mp4", name, separator)) + ret = append(ret, fmt.Sprintf("%s%saaa/dir/bbb.mp4", name, separator)) + ret = append(ret, fmt.Sprintf("%s%saaa\\dir\\bbb.mp4", name, separator)) + ret = append(ret, fmt.Sprintf("dir/%s%s/aaa.mp4", name, separator)) + ret = append(ret, fmt.Sprintf("dir\\%s%s\\aaa.mp4", name, separator)) + + return ret +} + +func generateSplitNamePatterns(name, separator string) []string { + var ret []string + splitted := strings.Split(name, " ") + // only do this for names that are split into two + if len(splitted) == 2 { + ret = append(ret, fmt.Sprintf("%s%s%s.mp4", splitted[0], separator, splitted[1])) + } + + return ret +} + +func generateFalseNamePatterns(name string, separator string) []string { + splitted := strings.Split(name, " ") + + var ret []string + // only do this for names that are split into two + if len(splitted) == 2 { + ret = append(ret, fmt.Sprintf("%s%saaa%s%s.mp4", splitted[0], separator, separator, splitted[1])) + } + + return ret +} + +func generateScenePaths(testName string) (scenePatterns []string, falseScenePatterns []string) { + separators := append(testSeparators, testEndSeparators...) + + for _, separator := range separators { + scenePatterns = append(scenePatterns, generateNamePatterns(testName, separator)...) + scenePatterns = append(scenePatterns, generateNamePatterns(strings.ToLower(testName), separator)...) + scenePatterns = append(scenePatterns, generateNamePatterns(strings.ReplaceAll(testName, " ", ""), separator)...) + falseScenePatterns = append(falseScenePatterns, generateFalseNamePatterns(testName, separator)...) + } + + // add test cases for intra-name separators + for _, separator := range testSeparators { + if separator != " " { + scenePatterns = append(scenePatterns, generateNamePatterns(strings.Replace(testName, " ", separator, -1), separator)...) + } + } + + // add basic false scenarios + falseScenePatterns = append(falseScenePatterns, fmt.Sprintf("aaa%s.mp4", testName)) + falseScenePatterns = append(falseScenePatterns, fmt.Sprintf("%saaa.mp4", testName)) + + // add path separator false scenarios + falseScenePatterns = append(falseScenePatterns, generateFalseNamePatterns(testName, "/")...) + falseScenePatterns = append(falseScenePatterns, generateFalseNamePatterns(testName, "\\")...) + + // split patterns only valid for ._- and whitespace + for _, separator := range testSeparators { + scenePatterns = append(scenePatterns, generateSplitNamePatterns(testName, separator)...) + } + + // false patterns for other separators + for _, separator := range testEndSeparators { + falseScenePatterns = append(falseScenePatterns, generateSplitNamePatterns(testName, separator)...) + } + + return +} + +type pathTestTable struct { + ScenePath string + Matches bool +} + +func generateTestTable(testName string) []pathTestTable { + var ret []pathTestTable + + var scenePatterns []string + var falseScenePatterns []string + + separators := append(testSeparators, testEndSeparators...) + + for _, separator := range separators { + scenePatterns = append(scenePatterns, generateNamePatterns(testName, separator)...) + scenePatterns = append(scenePatterns, generateNamePatterns(strings.ToLower(testName), separator)...) + falseScenePatterns = append(falseScenePatterns, generateFalseNamePatterns(testName, separator)...) + } + + for _, p := range scenePatterns { + t := pathTestTable{ + ScenePath: p, + Matches: true, + } + + ret = append(ret, t) + } + + for _, p := range falseScenePatterns { + t := pathTestTable{ + ScenePath: p, + Matches: false, + } + + ret = append(ret, t) + } + + return ret +} + +func TestScenePerformers(t *testing.T) { + const sceneID = 1 + const performerName = "performer name" + const performerID = 2 + performer := models.Performer{ + ID: performerID, + Name: models.NullString(performerName), + } + + const reversedPerformerName = "name performer" + const reversedPerformerID = 3 + reversedPerformer := models.Performer{ + ID: reversedPerformerID, + Name: models.NullString(reversedPerformerName), + } + + testTables := generateTestTable(performerName) + + assert := assert.New(t) + + for _, test := range testTables { + mockPerformerReader := &mocks.PerformerReaderWriter{} + mockSceneReader := &mocks.SceneReaderWriter{} + + mockPerformerReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once() + + if test.Matches { + mockSceneReader.On("GetPerformerIDs", sceneID).Return(nil, nil).Once() + mockSceneReader.On("UpdatePerformers", sceneID, []int{performerID}).Return(nil).Once() + } + + scene := models.Scene{ + ID: sceneID, + Path: test.ScenePath, + } + err := ScenePerformers(&scene, mockSceneReader, mockPerformerReader) + + assert.Nil(err) + mockPerformerReader.AssertExpectations(t) + mockSceneReader.AssertExpectations(t) + } +} + +func TestSceneStudios(t *testing.T) { + const sceneID = 1 + const studioName = "studio name" + const studioID = 2 + studio := models.Studio{ + ID: studioID, + Name: models.NullString(studioName), + } + + const reversedStudioName = "name studio" + const reversedStudioID = 3 + reversedStudio := models.Studio{ + ID: reversedStudioID, + Name: models.NullString(reversedStudioName), + } + + testTables := generateTestTable(studioName) + + assert := assert.New(t) + + for _, test := range testTables { + mockStudioReader := &mocks.StudioReaderWriter{} + mockSceneReader := &mocks.SceneReaderWriter{} + + mockStudioReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() + + if test.Matches { + mockSceneReader.On("Find", sceneID).Return(&models.Scene{}, nil).Once() + expectedStudioID := models.NullInt64(studioID) + mockSceneReader.On("Update", models.ScenePartial{ + ID: sceneID, + StudioID: &expectedStudioID, + }).Return(nil, nil).Once() + } + + scene := models.Scene{ + ID: sceneID, + Path: test.ScenePath, + } + err := SceneStudios(&scene, mockSceneReader, mockStudioReader) + + assert.Nil(err) + mockStudioReader.AssertExpectations(t) + mockSceneReader.AssertExpectations(t) + } +} + +func TestSceneTags(t *testing.T) { + const sceneID = 1 + const tagName = "tag name" + const tagID = 2 + tag := models.Tag{ + ID: tagID, + Name: tagName, + } + + const reversedTagName = "name tag" + const reversedTagID = 3 + reversedTag := models.Tag{ + ID: reversedTagID, + Name: reversedTagName, + } + + testTables := generateTestTable(tagName) + + assert := assert.New(t) + + for _, test := range testTables { + mockTagReader := &mocks.TagReaderWriter{} + mockSceneReader := &mocks.SceneReaderWriter{} + + mockTagReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() + + if test.Matches { + mockSceneReader.On("GetTagIDs", sceneID).Return(nil, nil).Once() + mockSceneReader.On("UpdateTags", sceneID, []int{tagID}).Return(nil).Once() + } + + scene := models.Scene{ + ID: sceneID, + Path: test.ScenePath, + } + err := SceneTags(&scene, mockSceneReader, mockTagReader) + + assert.Nil(err) + mockTagReader.AssertExpectations(t) + mockSceneReader.AssertExpectations(t) + } +} diff --git a/pkg/autotag/studio.go b/pkg/autotag/studio.go new file mode 100644 index 000000000..c01eedecc --- /dev/null +++ b/pkg/autotag/studio.go @@ -0,0 +1,66 @@ +package autotag + +import ( + "database/sql" + + "github.com/stashapp/stash/pkg/models" +) + +func getMatchingStudios(path string, reader models.StudioReader) ([]*models.Studio, error) { + words := getPathWords(path) + candidates, err := reader.QueryForAutoTag(words) + + if err != nil { + return nil, err + } + + var ret []*models.Studio + for _, c := range candidates { + if nameMatchesPath(c.Name.String, path) { + ret = append(ret, c) + } + } + + return ret, nil +} + +func addSceneStudio(sceneWriter models.SceneReaderWriter, sceneID, studioID int) (bool, error) { + // don't set if already set + scene, err := sceneWriter.Find(sceneID) + if err != nil { + return false, err + } + + if scene.StudioID.Valid { + return false, nil + } + + // set the studio id + s := sql.NullInt64{Int64: int64(studioID), Valid: true} + scenePartial := models.ScenePartial{ + ID: sceneID, + StudioID: &s, + } + + if _, err := sceneWriter.Update(scenePartial); err != nil { + return false, err + } + return true, nil +} + +func getStudioTagger(p *models.Studio) tagger { + return tagger{ + ID: p.ID, + Type: "studio", + Name: p.Name.String, + } +} + +// StudioScenes searches for scenes whose path matches the provided studio name and tags the scene with the studio, if studio is not already set on the scene. +func StudioScenes(p *models.Studio, paths []string, rw models.SceneReaderWriter) error { + t := getStudioTagger(p) + + return t.tagScenes(paths, rw, func(subjectID, otherID int) (bool, error) { + return addSceneStudio(rw, otherID, subjectID) + }) +} diff --git a/pkg/autotag/studio_test.go b/pkg/autotag/studio_test.go new file mode 100644 index 000000000..d61ba7efb --- /dev/null +++ b/pkg/autotag/studio_test.go @@ -0,0 +1,85 @@ +package autotag + +import ( + "testing" + + "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/models/mocks" + "github.com/stretchr/testify/assert" +) + +func TestStudioScenes(t *testing.T) { + type test struct { + studioName string + expectedRegex string + } + + studioNames := []test{ + { + "studio name", + `(?i)(?:^|_|[^\w\d])studio[.\-_ ]*name(?:$|_|[^\w\d])`, + }, + { + "studio + name", + `(?i)(?:^|_|[^\w\d])studio[.\-_ ]*\+[.\-_ ]*name(?:$|_|[^\w\d])`, + }, + } + + for _, p := range studioNames { + testStudioScenes(t, p.studioName, p.expectedRegex) + } +} + +func testStudioScenes(t *testing.T, studioName, expectedRegex string) { + mockSceneReader := &mocks.SceneReaderWriter{} + + const studioID = 2 + + var scenes []*models.Scene + matchingPaths, falsePaths := generateScenePaths(studioName) + for i, p := range append(matchingPaths, falsePaths...) { + scenes = append(scenes, &models.Scene{ + ID: i + 1, + Path: p, + }) + } + + studio := models.Studio{ + ID: studioID, + Name: models.NullString(studioName), + } + + organized := false + perPage := models.PerPageAll + + expectedSceneFilter := &models.SceneFilterType{ + Organized: &organized, + Path: &models.StringCriterionInput{ + Value: expectedRegex, + Modifier: models.CriterionModifierMatchesRegex, + }, + } + + expectedFindFilter := &models.FindFilterType{ + PerPage: &perPage, + } + + mockSceneReader.On("Query", expectedSceneFilter, expectedFindFilter).Return(scenes, len(scenes), nil).Once() + + for i := range matchingPaths { + sceneID := i + 1 + mockSceneReader.On("Find", sceneID).Return(&models.Scene{}, nil).Once() + expectedStudioID := models.NullInt64(studioID) + mockSceneReader.On("Update", models.ScenePartial{ + ID: sceneID, + StudioID: &expectedStudioID, + }).Return(nil, nil).Once() + } + + err := StudioScenes(&studio, nil, mockSceneReader) + + assert := assert.New(t) + + assert.Nil(err) + mockSceneReader.AssertExpectations(t) +} diff --git a/pkg/autotag/tag.go b/pkg/autotag/tag.go new file mode 100644 index 000000000..4f08394cc --- /dev/null +++ b/pkg/autotag/tag.go @@ -0,0 +1,41 @@ +package autotag + +import ( + "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/scene" +) + +func getMatchingTags(path string, tagReader models.TagReader) ([]*models.Tag, error) { + words := getPathWords(path) + tags, err := tagReader.QueryForAutoTag(words) + + if err != nil { + return nil, err + } + + var ret []*models.Tag + for _, p := range tags { + if nameMatchesPath(p.Name, path) { + ret = append(ret, p) + } + } + + return ret, nil +} + +func getTagTagger(p *models.Tag) tagger { + return tagger{ + ID: p.ID, + Type: "tag", + Name: p.Name, + } +} + +// TagScenes searches for scenes whose path matches the provided tag name and tags the scene with the tag. +func TagScenes(p *models.Tag, paths []string, rw models.SceneReaderWriter) error { + t := getTagTagger(p) + + return t.tagScenes(paths, rw, func(subjectID, otherID int) (bool, error) { + return scene.AddTag(rw, otherID, subjectID) + }) +} diff --git a/pkg/autotag/tag_test.go b/pkg/autotag/tag_test.go new file mode 100644 index 000000000..47dd3356e --- /dev/null +++ b/pkg/autotag/tag_test.go @@ -0,0 +1,81 @@ +package autotag + +import ( + "testing" + + "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/models/mocks" + "github.com/stretchr/testify/assert" +) + +func TestTagScenes(t *testing.T) { + type test struct { + tagName string + expectedRegex string + } + + tagNames := []test{ + { + "tag name", + `(?i)(?:^|_|[^\w\d])tag[.\-_ ]*name(?:$|_|[^\w\d])`, + }, + { + "tag + name", + `(?i)(?:^|_|[^\w\d])tag[.\-_ ]*\+[.\-_ ]*name(?:$|_|[^\w\d])`, + }, + } + + for _, p := range tagNames { + testTagScenes(t, p.tagName, p.expectedRegex) + } +} + +func testTagScenes(t *testing.T, tagName, expectedRegex string) { + mockSceneReader := &mocks.SceneReaderWriter{} + + const tagID = 2 + + var scenes []*models.Scene + matchingPaths, falsePaths := generateScenePaths(tagName) + for i, p := range append(matchingPaths, falsePaths...) { + scenes = append(scenes, &models.Scene{ + ID: i + 1, + Path: p, + }) + } + + tag := models.Tag{ + ID: tagID, + Name: tagName, + } + + organized := false + perPage := models.PerPageAll + + expectedSceneFilter := &models.SceneFilterType{ + Organized: &organized, + Path: &models.StringCriterionInput{ + Value: expectedRegex, + Modifier: models.CriterionModifierMatchesRegex, + }, + } + + expectedFindFilter := &models.FindFilterType{ + PerPage: &perPage, + } + + mockSceneReader.On("Query", expectedSceneFilter, expectedFindFilter).Return(scenes, len(scenes), nil).Once() + + for i := range matchingPaths { + sceneID := i + 1 + mockSceneReader.On("GetTagIDs", sceneID).Return(nil, nil).Once() + mockSceneReader.On("UpdateTags", sceneID, []int{tagID}).Return(nil).Once() + } + + err := TagScenes(&tag, nil, mockSceneReader) + + assert := assert.New(t) + + assert.Nil(err) + mockSceneReader.AssertExpectations(t) +} diff --git a/pkg/autotag/tagger.go b/pkg/autotag/tagger.go new file mode 100644 index 000000000..8690ffc8b --- /dev/null +++ b/pkg/autotag/tagger.go @@ -0,0 +1,198 @@ +// Package autotag provides methods to auto-tag scenes with performers, +// studios and tags. +// +// The autotag engine tags scenes with performers/studios/tags if the scene's +// path matches the performer/studio/tag name. A scene's path is considered +// a match if it contains the performer/studio/tag's full name, ignoring any +// '.', '-', '_' characters in the path. +// +// For example, for a performer "foo bar", the following paths would be +// considered a match: "foo bar.mp4", "foobar.mp4", "foo.bar.mp4", +// "foo-bar.mp4", "aaa.foo bar.bbb.mp4". +// The following would not be considered a match: +// "aafoo bar.mp4", "foo barbb.mp4", "foo/bar.mp4" +package autotag + +import ( + "fmt" + "path/filepath" + "regexp" + "strings" + + "github.com/stashapp/stash/pkg/logger" + "github.com/stashapp/stash/pkg/models" +) + +const separatorChars = `.\-_ ` + +// fixes #1292 +func escapePathRegex(name string) string { + ret := name + + chars := `+*?()|[]{}^$` + for _, c := range chars { + cStr := string(c) + ret = strings.ReplaceAll(ret, cStr, `\`+cStr) + } + + return ret +} + +func getPathQueryRegex(name string) string { + // escape specific regex characters + name = escapePathRegex(name) + + // handle path separators + const separator = `[` + separatorChars + `]` + + ret := strings.Replace(name, " ", separator+"*", -1) + ret = `(?:^|_|[^\w\d])` + ret + `(?:$|_|[^\w\d])` + return ret +} + +func nameMatchesPath(name, path string) bool { + // escape specific regex characters + name = escapePathRegex(name) + + name = strings.ToLower(name) + path = strings.ToLower(path) + + // handle path separators + const separator = `[` + separatorChars + `]` + + reStr := strings.Replace(name, " ", separator+"*", -1) + reStr = `(?:^|_|[^\w\d])` + reStr + `(?:$|_|[^\w\d])` + + re := regexp.MustCompile(reStr) + return re.MatchString(path) +} + +func getPathWords(path string) []string { + retStr := path + + // remove the extension + ext := filepath.Ext(retStr) + if ext != "" { + retStr = strings.TrimSuffix(retStr, ext) + } + + // handle path separators + const separator = `(?:_|[^\w\d])+` + re := regexp.MustCompile(separator) + retStr = re.ReplaceAllString(retStr, " ") + + words := strings.Split(retStr, " ") + + // remove any single letter words + var ret []string + for _, w := range words { + if len(w) > 1 { + ret = append(ret, w) + } + } + + return ret +} + +type tagger struct { + ID int + Type string + Name string + Path string +} + +type addLinkFunc func(subjectID, otherID int) (bool, error) + +func (t *tagger) addError(otherType, otherName string, err error) error { + return fmt.Errorf("error adding %s '%s' to %s '%s': %s", otherType, otherName, t.Type, t.Name, err.Error()) +} + +func (t *tagger) addLog(otherType, otherName string) { + logger.Infof("Added %s '%s' to %s '%s'", otherType, otherName, t.Type, t.Name) +} + +func (t *tagger) tagPerformers(performerReader models.PerformerReader, addFunc addLinkFunc) error { + others, err := getMatchingPerformers(t.Path, performerReader) + if err != nil { + return err + } + + for _, p := range others { + added, err := addFunc(t.ID, p.ID) + + if err != nil { + return t.addError("performer", p.Name.String, err) + } + + if added { + t.addLog("performer", p.Name.String) + } + } + + return nil +} + +func (t *tagger) tagStudios(studioReader models.StudioReader, addFunc addLinkFunc) error { + others, err := getMatchingStudios(t.Path, studioReader) + if err != nil { + return err + } + + // only add first studio + if len(others) > 0 { + studio := others[0] + added, err := addFunc(t.ID, studio.ID) + + if err != nil { + return t.addError("studio", studio.Name.String, err) + } + + if added { + t.addLog("studio", studio.Name.String) + } + } + + return nil +} + +func (t *tagger) tagTags(tagReader models.TagReader, addFunc addLinkFunc) error { + others, err := getMatchingTags(t.Path, tagReader) + if err != nil { + return err + } + + for _, p := range others { + added, err := addFunc(t.ID, p.ID) + + if err != nil { + return t.addError("tag", p.Name, err) + } + + if added { + t.addLog("tag", p.Name) + } + } + + return nil +} + +func (t *tagger) tagScenes(paths []string, sceneReader models.SceneReader, addFunc addLinkFunc) error { + others, err := getMatchingScenes(t.Name, paths, sceneReader) + if err != nil { + return err + } + + for _, p := range others { + added, err := addFunc(t.ID, p.ID) + + if err != nil { + return t.addError("scene", p.GetTitle(), err) + } + + if added { + t.addLog("scene", p.GetTitle()) + } + } + + return nil +} diff --git a/pkg/manager/manager_tasks.go b/pkg/manager/manager_tasks.go index 8f1abf551..ffa7f1722 100644 --- a/pkg/manager/manager_tasks.go +++ b/pkg/manager/manager_tasks.go @@ -5,12 +5,15 @@ import ( "errors" "fmt" "os" + "path/filepath" "strconv" + "strings" "sync" "time" "github.com/remeh/sizedwaitgroup" + "github.com/stashapp/stash/pkg/autotag" "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/manager/config" "github.com/stashapp/stash/pkg/models" @@ -626,6 +629,15 @@ func (s *singleton) generateScreenshot(sceneId string, at *float64) { }() } +func (s *singleton) isFileBasedAutoTag(input models.AutoTagMetadataInput) bool { + const wildcard = "*" + performerIds := input.Performers + studioIds := input.Studios + tagIds := input.Tags + + return (len(performerIds) == 0 || performerIds[0] == wildcard) && (len(studioIds) == 0 || studioIds[0] == wildcard) && (len(tagIds) == 0 || tagIds[0] == wildcard) +} + func (s *singleton) AutoTag(input models.AutoTagMetadataInput) { if s.Status.Status != Idle { return @@ -636,58 +648,160 @@ func (s *singleton) AutoTag(input models.AutoTagMetadataInput) { go func() { defer s.returnToIdleState() - performerIds := input.Performers - studioIds := input.Studios - tagIds := input.Tags - - // calculate work load - performerCount := len(performerIds) - studioCount := len(studioIds) - tagCount := len(tagIds) - - if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { - performerQuery := r.Performer() - studioQuery := r.Studio() - tagQuery := r.Tag() - - const wildcard = "*" - var err error - if performerCount == 1 && performerIds[0] == wildcard { - performerCount, err = performerQuery.Count() - if err != nil { - return fmt.Errorf("Error getting performer count: %s", err.Error()) - } - } - if studioCount == 1 && studioIds[0] == wildcard { - studioCount, err = studioQuery.Count() - if err != nil { - return fmt.Errorf("Error getting studio count: %s", err.Error()) - } - } - if tagCount == 1 && tagIds[0] == wildcard { - tagCount, err = tagQuery.Count() - if err != nil { - return fmt.Errorf("Error getting tag count: %s", err.Error()) - } - } - - return nil - }); err != nil { - logger.Error(err.Error()) - return + if s.isFileBasedAutoTag(input) { + // doing file-based auto-tag + s.autoTagScenes(input.Paths, len(input.Performers) > 0, len(input.Studios) > 0, len(input.Tags) > 0) + } else { + // doing specific performer/studio/tag auto-tag + s.autoTagSpecific(input) } - - total := performerCount + studioCount + tagCount - s.Status.setProgress(0, total) - - s.autoTagPerformers(input.Paths, performerIds) - s.autoTagStudios(input.Paths, studioIds) - s.autoTagTags(input.Paths, tagIds) }() } +func (s *singleton) autoTagScenes(paths []string, performers, studios, tags bool) { + if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + ret := &models.SceneFilterType{} + or := ret + sep := string(filepath.Separator) + + for _, p := range paths { + if !strings.HasSuffix(p, sep) { + p = p + sep + } + + if ret.Path == nil { + or = ret + } else { + newOr := &models.SceneFilterType{} + or.Or = newOr + or = newOr + } + + or.Path = &models.StringCriterionInput{ + Modifier: models.CriterionModifierEquals, + Value: p + "%", + } + } + + organized := false + ret.Organized = &organized + + // batch process scenes + batchSize := 1000 + page := 1 + findFilter := &models.FindFilterType{ + PerPage: &batchSize, + Page: &page, + } + + more := true + processed := 0 + for more { + scenes, total, err := r.Scene().Query(ret, findFilter) + if err != nil { + return err + } + + if processed == 0 { + logger.Infof("Starting autotag of %d scenes", total) + } + + for _, ss := range scenes { + if s.Status.stopping { + logger.Info("Stopping due to user request") + return nil + } + + t := autoTagSceneTask{ + txnManager: s.TxnManager, + scene: ss, + performers: performers, + studios: studios, + tags: tags, + } + + var wg sync.WaitGroup + wg.Add(1) + go t.Start(&wg) + wg.Wait() + + processed++ + s.Status.setProgress(processed, total) + } + + if len(scenes) != batchSize { + more = false + } else { + page++ + } + } + + return nil + }); err != nil { + logger.Error(err.Error()) + } + + logger.Info("Finished autotag") +} + +func (s *singleton) autoTagSpecific(input models.AutoTagMetadataInput) { + performerIds := input.Performers + studioIds := input.Studios + tagIds := input.Tags + + performerCount := len(performerIds) + studioCount := len(studioIds) + tagCount := len(tagIds) + + if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + performerQuery := r.Performer() + studioQuery := r.Studio() + tagQuery := r.Tag() + + const wildcard = "*" + var err error + if performerCount == 1 && performerIds[0] == wildcard { + performerCount, err = performerQuery.Count() + if err != nil { + return fmt.Errorf("error getting performer count: %s", err.Error()) + } + } + if studioCount == 1 && studioIds[0] == wildcard { + studioCount, err = studioQuery.Count() + if err != nil { + return fmt.Errorf("error getting studio count: %s", err.Error()) + } + } + if tagCount == 1 && tagIds[0] == wildcard { + tagCount, err = tagQuery.Count() + if err != nil { + return fmt.Errorf("error getting tag count: %s", err.Error()) + } + } + + return nil + }); err != nil { + logger.Error(err.Error()) + return + } + + total := performerCount + studioCount + tagCount + s.Status.setProgress(0, total) + + logger.Infof("Starting autotag of %d performers, %d studios, %d tags", performerCount, studioCount, tagCount) + + s.autoTagPerformers(input.Paths, performerIds) + s.autoTagStudios(input.Paths, studioIds) + s.autoTagTags(input.Paths, tagIds) + + logger.Info("Finished autotag") +} + func (s *singleton) autoTagPerformers(paths []string, performerIds []string) { - var wg sync.WaitGroup + if s.Status.stopping { + return + } + for _, performerId := range performerIds { var performers []*models.Performer @@ -698,46 +812,53 @@ func (s *singleton) autoTagPerformers(paths []string, performerIds []string) { var err error performers, err = performerQuery.All() if err != nil { - return fmt.Errorf("Error querying performers: %s", err.Error()) + return fmt.Errorf("error querying performers: %s", err.Error()) } } else { performerIdInt, err := strconv.Atoi(performerId) if err != nil { - return fmt.Errorf("Error parsing performer id %s: %s", performerId, err.Error()) + return fmt.Errorf("error parsing performer id %s: %s", performerId, err.Error()) } performer, err := performerQuery.Find(performerIdInt) if err != nil { - return fmt.Errorf("Error finding performer id %s: %s", performerId, err.Error()) + return fmt.Errorf("error finding performer id %s: %s", performerId, err.Error()) + } + + if performer == nil { + return fmt.Errorf("performer with id %s not found", performerId) } performers = append(performers, performer) } + for _, performer := range performers { + if s.Status.stopping { + logger.Info("Stopping due to user request") + return nil + } + + if err := s.TxnManager.WithTxn(context.TODO(), func(r models.Repository) error { + return autotag.PerformerScenes(performer, paths, r.Scene()) + }); err != nil { + return fmt.Errorf("error auto-tagging performer '%s': %s", performer.Name.String, err.Error()) + } + + s.Status.incrementProgress() + } + return nil }); err != nil { logger.Error(err.Error()) continue } - - for _, performer := range performers { - wg.Add(1) - task := AutoTagPerformerTask{ - AutoTagTask: AutoTagTask{ - txnManager: s.TxnManager, - paths: paths, - }, - performer: performer, - } - go task.Start(&wg) - wg.Wait() - - s.Status.incrementProgress() - } } } func (s *singleton) autoTagStudios(paths []string, studioIds []string) { - var wg sync.WaitGroup + if s.Status.stopping { + return + } + for _, studioId := range studioIds { var studios []*models.Studio @@ -747,46 +868,54 @@ func (s *singleton) autoTagStudios(paths []string, studioIds []string) { var err error studios, err = studioQuery.All() if err != nil { - return fmt.Errorf("Error querying studios: %s", err.Error()) + return fmt.Errorf("error querying studios: %s", err.Error()) } } else { studioIdInt, err := strconv.Atoi(studioId) if err != nil { - return fmt.Errorf("Error parsing studio id %s: %s", studioId, err.Error()) + return fmt.Errorf("error parsing studio id %s: %s", studioId, err.Error()) } studio, err := studioQuery.Find(studioIdInt) if err != nil { - return fmt.Errorf("Error finding studio id %s: %s", studioId, err.Error()) + return fmt.Errorf("error finding studio id %s: %s", studioId, err.Error()) } + + if studio == nil { + return fmt.Errorf("studio with id %s not found", studioId) + } + studios = append(studios, studio) } + for _, studio := range studios { + if s.Status.stopping { + logger.Info("Stopping due to user request") + return nil + } + + if err := s.TxnManager.WithTxn(context.TODO(), func(r models.Repository) error { + return autotag.StudioScenes(studio, paths, r.Scene()) + }); err != nil { + return fmt.Errorf("error auto-tagging studio '%s': %s", studio.Name.String, err.Error()) + } + + s.Status.incrementProgress() + } + return nil }); err != nil { logger.Error(err.Error()) continue } - - for _, studio := range studios { - wg.Add(1) - task := AutoTagStudioTask{ - AutoTagTask: AutoTagTask{ - txnManager: s.TxnManager, - paths: paths, - }, - studio: studio, - } - go task.Start(&wg) - wg.Wait() - - s.Status.incrementProgress() - } } } func (s *singleton) autoTagTags(paths []string, tagIds []string) { - var wg sync.WaitGroup + if s.Status.stopping { + return + } + for _, tagId := range tagIds { var tags []*models.Tag if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { @@ -795,41 +924,41 @@ func (s *singleton) autoTagTags(paths []string, tagIds []string) { var err error tags, err = tagQuery.All() if err != nil { - return fmt.Errorf("Error querying tags: %s", err.Error()) + return fmt.Errorf("error querying tags: %s", err.Error()) } } else { tagIdInt, err := strconv.Atoi(tagId) if err != nil { - return fmt.Errorf("Error parsing tag id %s: %s", tagId, err.Error()) + return fmt.Errorf("error parsing tag id %s: %s", tagId, err.Error()) } tag, err := tagQuery.Find(tagIdInt) if err != nil { - return fmt.Errorf("Error finding tag id %s: %s", tagId, err.Error()) + return fmt.Errorf("error finding tag id %s: %s", tagId, err.Error()) } tags = append(tags, tag) } + for _, tag := range tags { + if s.Status.stopping { + logger.Info("Stopping due to user request") + return nil + } + + if err := s.TxnManager.WithTxn(context.TODO(), func(r models.Repository) error { + return autotag.TagScenes(tag, paths, r.Scene()) + }); err != nil { + return fmt.Errorf("error auto-tagging tag '%s': %s", tag.Name, err.Error()) + } + + s.Status.incrementProgress() + } + return nil }); err != nil { logger.Error(err.Error()) continue } - - for _, tag := range tags { - wg.Add(1) - task := AutoTagTagTask{ - AutoTagTask: AutoTagTask{ - txnManager: s.TxnManager, - paths: paths, - }, - tag: tag, - } - go task.Start(&wg) - wg.Wait() - - s.Status.incrementProgress() - } } } diff --git a/pkg/manager/task_autotag.go b/pkg/manager/task_autotag.go index 6632dc2b9..6fb0a08e0 100644 --- a/pkg/manager/task_autotag.go +++ b/pkg/manager/task_autotag.go @@ -2,196 +2,38 @@ package manager import ( "context" - "database/sql" - "fmt" - "path/filepath" - "strings" "sync" + "github.com/stashapp/stash/pkg/autotag" "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" - "github.com/stashapp/stash/pkg/scene" ) -type AutoTagTask struct { - paths []string +type autoTagSceneTask struct { txnManager models.TransactionManager + scene *models.Scene + + performers bool + studios bool + tags bool } -type AutoTagPerformerTask struct { - AutoTagTask - performer *models.Performer -} - -func (t *AutoTagPerformerTask) Start(wg *sync.WaitGroup) { +func (t *autoTagSceneTask) Start(wg *sync.WaitGroup) { defer wg.Done() - - t.autoTagPerformer() -} - -func (t *AutoTagTask) getQueryRegex(name string) string { - const separatorChars = `.\-_ ` - // handle path separators - const separator = `[` + separatorChars + `]` - - ret := strings.Replace(name, " ", separator+"*", -1) - ret = `(?:^|_|[^\w\d])` + ret + `(?:$|_|[^\w\d])` - return ret -} - -func (t *AutoTagTask) getQueryFilter(regex string) *models.SceneFilterType { - organized := false - ret := &models.SceneFilterType{ - Path: &models.StringCriterionInput{ - Modifier: models.CriterionModifierMatchesRegex, - Value: "(?i)" + regex, - }, - Organized: &organized, - } - - sep := string(filepath.Separator) - - var or *models.SceneFilterType - for _, p := range t.paths { - newOr := &models.SceneFilterType{} - if or == nil { - ret.And = newOr - } else { - or.Or = newOr - } - - or = newOr - - if !strings.HasSuffix(p, sep) { - p = p + sep - } - - or.Path = &models.StringCriterionInput{ - Modifier: models.CriterionModifierEquals, - Value: p + "%", - } - } - - return ret -} - -func (t *AutoTagTask) getFindFilter() *models.FindFilterType { - perPage := -1 - return &models.FindFilterType{ - PerPage: &perPage, - } -} - -func (t *AutoTagPerformerTask) autoTagPerformer() { - regex := t.getQueryRegex(t.performer.Name.String) - if err := t.txnManager.WithTxn(context.TODO(), func(r models.Repository) error { - qb := r.Scene() - - scenes, _, err := qb.Query(t.getQueryFilter(regex), t.getFindFilter()) - - if err != nil { - return fmt.Errorf("Error querying scenes with regex '%s': %s", regex, err.Error()) - } - - for _, s := range scenes { - added, err := scene.AddPerformer(qb, s.ID, t.performer.ID) - - if err != nil { - return fmt.Errorf("Error adding performer '%s' to scene '%s': %s", t.performer.Name.String, s.GetTitle(), err.Error()) + if t.performers { + if err := autotag.ScenePerformers(t.scene, r.Scene(), r.Performer()); err != nil { + return err } - - if added { - logger.Infof("Added performer '%s' to scene '%s'", t.performer.Name.String, s.GetTitle()) - } - } - - return nil - }); err != nil { - logger.Error(err.Error()) - } -} - -type AutoTagStudioTask struct { - AutoTagTask - studio *models.Studio -} - -func (t *AutoTagStudioTask) Start(wg *sync.WaitGroup) { - defer wg.Done() - - t.autoTagStudio() -} - -func (t *AutoTagStudioTask) autoTagStudio() { - regex := t.getQueryRegex(t.studio.Name.String) - - if err := t.txnManager.WithTxn(context.TODO(), func(r models.Repository) error { - qb := r.Scene() - scenes, _, err := qb.Query(t.getQueryFilter(regex), t.getFindFilter()) - - if err != nil { - return fmt.Errorf("Error querying scenes with regex '%s': %s", regex, err.Error()) - } - - for _, s := range scenes { - // #306 - don't overwrite studio if already present - if s.StudioID.Valid { - // don't modify - continue - } - - logger.Infof("Adding studio '%s' to scene '%s'", t.studio.Name.String, s.GetTitle()) - - // set the studio id - studioID := sql.NullInt64{Int64: int64(t.studio.ID), Valid: true} - scenePartial := models.ScenePartial{ - ID: s.ID, - StudioID: &studioID, - } - - if _, err := qb.Update(scenePartial); err != nil { - return fmt.Errorf("Error adding studio to scene: %s", err.Error()) - } - } - - return nil - }); err != nil { - logger.Error(err.Error()) - } -} - -type AutoTagTagTask struct { - AutoTagTask - tag *models.Tag -} - -func (t *AutoTagTagTask) Start(wg *sync.WaitGroup) { - defer wg.Done() - - t.autoTagTag() -} - -func (t *AutoTagTagTask) autoTagTag() { - regex := t.getQueryRegex(t.tag.Name) - - if err := t.txnManager.WithTxn(context.TODO(), func(r models.Repository) error { - qb := r.Scene() - scenes, _, err := qb.Query(t.getQueryFilter(regex), t.getFindFilter()) - - if err != nil { - return fmt.Errorf("Error querying scenes with regex '%s': %s", regex, err.Error()) - } - - for _, s := range scenes { - added, err := scene.AddTag(qb, s.ID, t.tag.ID) - - if err != nil { - return fmt.Errorf("Error adding tag '%s' to scene '%s': %s", t.tag.Name, s.GetTitle(), err.Error()) - } - - if added { - logger.Infof("Added tag '%s' to scene '%s'", t.tag.Name, s.GetTitle()) + } + if t.studios { + if err := autotag.SceneStudios(t.scene, r.Scene(), r.Studio()); err != nil { + return err + } + } + if t.tags { + if err := autotag.SceneTags(t.scene, r.Scene(), r.Tag()); err != nil { + return err } } diff --git a/pkg/models/extension_find_filter.go b/pkg/models/extension_find_filter.go index f2e7e7ab1..8dc1ed515 100644 --- a/pkg/models/extension_find_filter.go +++ b/pkg/models/extension_find_filter.go @@ -1,5 +1,9 @@ package models +// PerPageAll is the value used for perPage to indicate all results should be +// returned. +const PerPageAll = -1 + func (ff FindFilterType) GetSort(defaultSort string) string { var sort string if ff.Sort == nil { diff --git a/pkg/models/mocks/PerformerReaderWriter.go b/pkg/models/mocks/PerformerReaderWriter.go index 60575ab3b..20629b3b5 100644 --- a/pkg/models/mocks/PerformerReaderWriter.go +++ b/pkg/models/mocks/PerformerReaderWriter.go @@ -388,6 +388,29 @@ func (_m *PerformerReaderWriter) Query(performerFilter *models.PerformerFilterTy return r0, r1, r2 } +// QueryForAutoTag provides a mock function with given fields: words +func (_m *PerformerReaderWriter) QueryForAutoTag(words []string) ([]*models.Performer, error) { + ret := _m.Called(words) + + var r0 []*models.Performer + if rf, ok := ret.Get(0).(func([]string) []*models.Performer); ok { + r0 = rf(words) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.Performer) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func([]string) error); ok { + r1 = rf(words) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // Update provides a mock function with given fields: updatedPerformer func (_m *PerformerReaderWriter) Update(updatedPerformer models.PerformerPartial) (*models.Performer, error) { ret := _m.Called(updatedPerformer) diff --git a/pkg/models/mocks/StudioReaderWriter.go b/pkg/models/mocks/StudioReaderWriter.go index fb9c02d7d..fbd8a1936 100644 --- a/pkg/models/mocks/StudioReaderWriter.go +++ b/pkg/models/mocks/StudioReaderWriter.go @@ -296,6 +296,29 @@ func (_m *StudioReaderWriter) Query(studioFilter *models.StudioFilterType, findF return r0, r1, r2 } +// QueryForAutoTag provides a mock function with given fields: words +func (_m *StudioReaderWriter) QueryForAutoTag(words []string) ([]*models.Studio, error) { + ret := _m.Called(words) + + var r0 []*models.Studio + if rf, ok := ret.Get(0).(func([]string) []*models.Studio); ok { + r0 = rf(words) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.Studio) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func([]string) error); ok { + r1 = rf(words) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // Update provides a mock function with given fields: updatedStudio func (_m *StudioReaderWriter) Update(updatedStudio models.StudioPartial) (*models.Studio, error) { ret := _m.Called(updatedStudio) diff --git a/pkg/models/mocks/TagReaderWriter.go b/pkg/models/mocks/TagReaderWriter.go index 65dcd8b89..e0d765577 100644 --- a/pkg/models/mocks/TagReaderWriter.go +++ b/pkg/models/mocks/TagReaderWriter.go @@ -367,6 +367,29 @@ func (_m *TagReaderWriter) Query(tagFilter *models.TagFilterType, findFilter *mo return r0, r1, r2 } +// QueryForAutoTag provides a mock function with given fields: words +func (_m *TagReaderWriter) QueryForAutoTag(words []string) ([]*models.Tag, error) { + ret := _m.Called(words) + + var r0 []*models.Tag + if rf, ok := ret.Get(0).(func([]string) []*models.Tag); ok { + r0 = rf(words) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.Tag) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func([]string) error); ok { + r1 = rf(words) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // Update provides a mock function with given fields: updatedTag func (_m *TagReaderWriter) Update(updatedTag models.Tag) (*models.Tag, error) { ret := _m.Called(updatedTag) diff --git a/pkg/models/performer.go b/pkg/models/performer.go index 2c550b720..fabcff44a 100644 --- a/pkg/models/performer.go +++ b/pkg/models/performer.go @@ -11,6 +11,9 @@ type PerformerReader interface { CountByTagID(tagID int) (int, error) Count() (int, error) All() ([]*Performer, error) + // TODO - this interface is temporary until the filter schema can fully + // support the query needed + QueryForAutoTag(words []string) ([]*Performer, error) Query(performerFilter *PerformerFilterType, findFilter *FindFilterType) ([]*Performer, int, error) GetImage(performerID int) ([]byte, error) GetStashIDs(performerID int) ([]*StashID, error) diff --git a/pkg/models/studio.go b/pkg/models/studio.go index 358abf596..7aa2e87b8 100644 --- a/pkg/models/studio.go +++ b/pkg/models/studio.go @@ -7,6 +7,9 @@ type StudioReader interface { FindByName(name string, nocase bool) (*Studio, error) Count() (int, error) All() ([]*Studio, error) + // TODO - this interface is temporary until the filter schema can fully + // support the query needed + QueryForAutoTag(words []string) ([]*Studio, error) Query(studioFilter *StudioFilterType, findFilter *FindFilterType) ([]*Studio, int, error) GetImage(studioID int) ([]byte, error) HasImage(studioID int) (bool, error) diff --git a/pkg/models/tag.go b/pkg/models/tag.go index 5f03e33b5..a675bfbdf 100644 --- a/pkg/models/tag.go +++ b/pkg/models/tag.go @@ -12,6 +12,9 @@ type TagReader interface { FindByNames(names []string, nocase bool) ([]*Tag, error) Count() (int, error) All() ([]*Tag, error) + // TODO - this interface is temporary until the filter schema can fully + // support the query needed + QueryForAutoTag(words []string) ([]*Tag, error) Query(tagFilter *TagFilterType, findFilter *FindFilterType) ([]*Tag, int, error) GetImage(tagID int) ([]byte, error) } diff --git a/pkg/sqlite/performer.go b/pkg/sqlite/performer.go index 17406eb7c..a631f1ad1 100644 --- a/pkg/sqlite/performer.go +++ b/pkg/sqlite/performer.go @@ -4,6 +4,7 @@ import ( "database/sql" "fmt" "strconv" + "strings" "github.com/stashapp/stash/pkg/models" ) @@ -172,6 +173,25 @@ func (qb *performerQueryBuilder) All() ([]*models.Performer, error) { return qb.queryPerformers(selectAll("performers")+qb.getPerformerSort(nil), nil) } +func (qb *performerQueryBuilder) QueryForAutoTag(words []string) ([]*models.Performer, error) { + // TODO - Query needs to be changed to support queries of this type, and + // this method should be removed + query := selectAll(performerTable) + + var whereClauses []string + var args []interface{} + + for _, w := range words { + whereClauses = append(whereClauses, "name like ?") + args = append(args, "%"+w+"%") + whereClauses = append(whereClauses, "aliases like ?") + args = append(args, "%"+w+"%") + } + + where := strings.Join(whereClauses, " OR ") + return qb.queryPerformers(query+" WHERE "+where, args) +} + func (qb *performerQueryBuilder) Query(performerFilter *models.PerformerFilterType, findFilter *models.FindFilterType) ([]*models.Performer, int, error) { if performerFilter == nil { performerFilter = &models.PerformerFilterType{} diff --git a/pkg/sqlite/performer_test.go b/pkg/sqlite/performer_test.go index d2f32182b..71237831c 100644 --- a/pkg/sqlite/performer_test.go +++ b/pkg/sqlite/performer_test.go @@ -100,6 +100,26 @@ func TestPerformerFindByNames(t *testing.T) { }) } +func TestPerformerQueryForAutoTag(t *testing.T) { + withTxn(func(r models.Repository) error { + tqb := r.Performer() + + name := performerNames[performerIdxWithScene] // find a performer by name + + performers, err := tqb.QueryForAutoTag([]string{name}) + + if err != nil { + t.Errorf("Error finding performers: %s", err.Error()) + } + + assert.Len(t, performers, 2) + assert.Equal(t, strings.ToLower(performerNames[performerIdxWithScene]), strings.ToLower(performers[0].Name.String)) + assert.Equal(t, strings.ToLower(performerNames[performerIdxWithScene]), strings.ToLower(performers[1].Name.String)) + + return nil + }) +} + func TestPerformerUpdatePerformerImage(t *testing.T) { if err := withTxn(func(r models.Repository) error { qb := r.Performer() diff --git a/pkg/sqlite/studio.go b/pkg/sqlite/studio.go index cbfb57f4d..2ec5f82c2 100644 --- a/pkg/sqlite/studio.go +++ b/pkg/sqlite/studio.go @@ -3,6 +3,7 @@ package sqlite import ( "database/sql" "fmt" + "strings" "github.com/stashapp/stash/pkg/models" ) @@ -121,6 +122,23 @@ func (qb *studioQueryBuilder) All() ([]*models.Studio, error) { return qb.queryStudios(selectAll("studios")+qb.getStudioSort(nil), nil) } +func (qb *studioQueryBuilder) QueryForAutoTag(words []string) ([]*models.Studio, error) { + // TODO - Query needs to be changed to support queries of this type, and + // this method should be removed + query := selectAll(studioTable) + + var whereClauses []string + var args []interface{} + + for _, w := range words { + whereClauses = append(whereClauses, "name like ?") + args = append(args, "%"+w+"%") + } + + where := strings.Join(whereClauses, " OR ") + return qb.queryStudios(query+" WHERE "+where, args) +} + func (qb *studioQueryBuilder) Query(studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) ([]*models.Studio, int, error) { if studioFilter == nil { studioFilter = &models.StudioFilterType{} diff --git a/pkg/sqlite/studio_test.go b/pkg/sqlite/studio_test.go index ae17c1cf7..de947ee82 100644 --- a/pkg/sqlite/studio_test.go +++ b/pkg/sqlite/studio_test.go @@ -45,6 +45,26 @@ func TestStudioFindByName(t *testing.T) { }) } +func TestStudioQueryForAutoTag(t *testing.T) { + withTxn(func(r models.Repository) error { + tqb := r.Studio() + + name := studioNames[studioIdxWithScene] // find a studio by name + + studios, err := tqb.QueryForAutoTag([]string{name}) + + if err != nil { + t.Errorf("Error finding studios: %s", err.Error()) + } + + assert.Len(t, studios, 2) + assert.Equal(t, strings.ToLower(studioNames[studioIdxWithScene]), strings.ToLower(studios[0].Name.String)) + assert.Equal(t, strings.ToLower(studioNames[studioIdxWithScene]), strings.ToLower(studios[1].Name.String)) + + return nil + }) +} + func TestStudioQueryParent(t *testing.T) { withTxn(func(r models.Repository) error { sqb := r.Studio() diff --git a/pkg/sqlite/tag.go b/pkg/sqlite/tag.go index acb0105d1..a4549acd9 100644 --- a/pkg/sqlite/tag.go +++ b/pkg/sqlite/tag.go @@ -4,6 +4,7 @@ import ( "database/sql" "errors" "fmt" + "strings" "github.com/stashapp/stash/pkg/models" ) @@ -192,6 +193,23 @@ func (qb *tagQueryBuilder) All() ([]*models.Tag, error) { return qb.queryTags(selectAll("tags")+qb.getDefaultTagSort(), nil) } +func (qb *tagQueryBuilder) QueryForAutoTag(words []string) ([]*models.Tag, error) { + // TODO - Query needs to be changed to support queries of this type, and + // this method should be removed + query := selectAll(tagTable) + + var whereClauses []string + var args []interface{} + + for _, w := range words { + whereClauses = append(whereClauses, "name like ?") + args = append(args, "%"+w+"%") + } + + where := strings.Join(whereClauses, " OR ") + return qb.queryTags(query+" WHERE "+where, args) +} + func (qb *tagQueryBuilder) validateFilter(tagFilter *models.TagFilterType) error { const and = "AND" const or = "OR" diff --git a/pkg/sqlite/tag_test.go b/pkg/sqlite/tag_test.go index 272006776..f6a29def2 100644 --- a/pkg/sqlite/tag_test.go +++ b/pkg/sqlite/tag_test.go @@ -70,6 +70,26 @@ func TestTagFindByName(t *testing.T) { }) } +func TestTagQueryForAutoTag(t *testing.T) { + withTxn(func(r models.Repository) error { + tqb := r.Tag() + + name := tagNames[tagIdxWithScene] // find a tag by name + + tags, err := tqb.QueryForAutoTag([]string{name}) + + if err != nil { + t.Errorf("Error finding tags: %s", err.Error()) + } + + assert.Len(t, tags, 2) + assert.Equal(t, strings.ToLower(tagNames[tagIdxWithScene]), strings.ToLower(tags[0].Name)) + assert.Equal(t, strings.ToLower(tagNames[tagIdxWithScene]), strings.ToLower(tags[1].Name)) + + return nil + }) +} + func TestTagFindByNames(t *testing.T) { var names []string diff --git a/ui/v2.5/src/components/Changelog/versions/v070.md b/ui/v2.5/src/components/Changelog/versions/v070.md index 356ef3a1b..ceae9c844 100644 --- a/ui/v2.5/src/components/Changelog/versions/v070.md +++ b/ui/v2.5/src/components/Changelog/versions/v070.md @@ -11,6 +11,7 @@ * Added scene queue. ### 🎨 Improvements +* Improved performance of the auto-tagger. * Clean generation artifacts after generating each scene. * Log message at startup when cleaning the `tmp` and `downloads` generated folders takes more than one second. * Sort movie scenes by scene number by default. @@ -27,6 +28,7 @@ * Change performer text query to search by name and alias only. ### 🐛 Bug fixes +* Fixed error when auto-tagging for performers/studios/tags with regex characters in the name. * Fix scraped performer image not updating after clearing the current image when creating a new performer. * Fix error preventing adding a new library path when an existing library path is missing. * Fix whitespace in query string returning all objects.