From c91ffe1e58483c6579162becac84369e8f4d0c56 Mon Sep 17 00:00:00 2001 From: gitgiggety <79809426+gitgiggety@users.noreply.github.com> Date: Thu, 9 Sep 2021 06:58:43 +0200 Subject: [PATCH] Tag hierarchy (#1519) * Add migration script for tag relations table * Expand hierarchical filter features Expand the features of the hierarchical multi input filter with support for using a relations table, which only has parent_id and child_id columns, and support adding an additional intermediate table to join on, for example for scenes and tags which are linked by the scenes_tags table as well. * Add hierarchical filtering for tags * Add hierarchical tags support to scene markers Refactor filtering of scene markers to filterBuilder and in the process add support for hierarchical tags as well. * List parent and child tags on tag details page * Support setting parent and child tags Add support for setting parent and child tags during tag creation and tag updates. * Validate no loops are created in tags hierarchy * Update tag merging to support tag hierarcy * Add unit tests for tags.EnsureUniqueHierarchy * Fix applying recursive to with clause The SQL `RECURSIVE` of a `WITH` clause only needs to be applied once, imediately after the `WITH`. So this fixes the query building to do just that, automatically applying the `RECURSIVE` keyword when any added with clause is added as recursive. * Rename hierarchical root id column * Rewrite hierarchical filtering for performance Completely rewrite the hierarchical filtering to optimize for performance. Doing the recursive query in combination with a complex query seems to break SQLite optimizing some things which means that the recursive part might be 2,5 second slower than adding a static `VALUES()` list. This is mostly noticable in case of the tag hierarchy where setting an exclusion with any depth (or depth: all) being applied has this performance impact of 2,5 second. "Include" also suffered this issue, but some rewritten query by joining in the *_tags table in one pass and applying a `WHERE x IS NOT NULL` filter did seem to optimize that case. But that optimization isn't applied to the `IS NULL` filter of "exclude". Running a simple query beforehand to get all (recursive) items and then applying them to the query doesn't have this performance penalty. * Remove UI references to child studios and tags * Add parents to tag export * Support importing of parent relationship for tags * Assign stable ids to parent / child badges * Silence Apollo warning on parents/children fields on tags Silence warning triggered by Apollo GraphQL by explicitly instructing it to use the incoming parents/children values. By default it already does this, but it triggers a warning as it might be unintended that it uses the incoming values (instead of for example merging both arrays). Setting merge to false still applies the same behaviour (use only incoming values) but silences the warning as it's explicitly configured to work like this. * Rework detecting unique tag hierarchy Completely rework the unique tag hierarchy to detect invalid hierarchies for which a tag is "added in the middle". So when there are tags A <- B and A <- C, you could previously edit tag B and add tag C as a sub tag without it being noticed as parent A being applied twice (to tag C). While afterwards saving tag C would fail as tag A was applied as parent twice. The updated code correctly detects this scenario as well. Furthermore the error messaging has been reworked a bit and the message now mentions both the direct parent / sub tag as well as the tag which would results in the error. So in aboves example it would now show the message that tag C can't be applied because tag A already is a parent. * Update relations on cached tags when needed Update the relations on cached tags when a tag is created / updated / deleted so these always reflect the correct state. Otherwise (re)opening a tag might still show the old relations untill the page is fully reloaded or the list is navigated. But this obviously is strange when you for example have tag A, create or update tag B to have a relation to tag A, and from tags B page click through to tag A and it doesn't show that it is linked to tag B. --- graphql/documents/data/tag.graphql | 8 + graphql/schema/types/filters.graphql | 20 +- graphql/schema/types/tag.graphql | 9 + pkg/api/resolver_model_tag.go | 22 ++ pkg/api/resolver_mutation_tag.go | 71 ++++ pkg/database/database.go | 2 +- .../migrations/26_tag_hierarchy.up.sql | 7 + pkg/dlna/cds.go | 3 +- pkg/gallery/query.go | 3 +- pkg/image/query.go | 3 +- pkg/manager/jsonschema/tag.go | 1 + pkg/manager/task_import.go | 56 +++- pkg/models/mocks/TagReaderWriter.go | 120 +++++++ pkg/models/tag.go | 6 + pkg/sqlite/filter.go | 148 +++++++-- pkg/sqlite/gallery.go | 56 ++-- pkg/sqlite/gallery_test.go | 18 +- pkg/sqlite/image.go | 56 ++-- pkg/sqlite/image_test.go | 18 +- pkg/sqlite/movies.go | 2 + pkg/sqlite/performer.go | 31 +- pkg/sqlite/performer_test.go | 9 +- pkg/sqlite/query.go | 15 +- pkg/sqlite/repository.go | 8 +- pkg/sqlite/scene.go | 52 ++- pkg/sqlite/scene_marker.go | 197 +++++------- pkg/sqlite/scene_test.go | 18 +- pkg/sqlite/tag.go | 134 ++++++++ pkg/tag/export.go | 7 + pkg/tag/export_test.go | 41 ++- pkg/tag/import.go | 69 +++- pkg/tag/import_test.go | 113 +++++++ pkg/tag/update.go | 162 +++++++++- pkg/tag/update_test.go | 302 ++++++++++++++++++ .../components/Changelog/versions/v0100.md | 2 + .../Filters/HierarchicalLabelValueFilter.tsx | 14 +- .../components/Shared/DeleteEntityDialog.tsx | 5 + .../Studios/StudioDetails/Studio.tsx | 2 +- .../src/components/Tags/TagDetails/Tag.tsx | 27 +- .../Tags/TagDetails/TagDetailsPanel.tsx | 51 ++- .../Tags/TagDetails/TagEditPanel.tsx | 62 +++- .../Tags/TagDetails/TagMarkersPanel.tsx | 9 +- ui/v2.5/src/components/Tags/TagList.tsx | 18 ++ ui/v2.5/src/core/createClient.ts | 11 + ui/v2.5/src/core/tags.ts | 70 +++- ui/v2.5/src/locales/de-DE.json | 3 +- ui/v2.5/src/locales/en-GB.json | 7 +- ui/v2.5/src/locales/pt-BR.json | 3 +- ui/v2.5/src/locales/zh-CN.json | 3 +- ui/v2.5/src/locales/zh-TW.json | 3 +- .../src/models/list-filter/criteria/tags.ts | 7 +- ui/v2.5/src/utils/navigation.ts | 25 +- 52 files changed, 1778 insertions(+), 331 deletions(-) create mode 100644 pkg/database/migrations/26_tag_hierarchy.up.sql create mode 100644 pkg/tag/update_test.go diff --git a/graphql/documents/data/tag.graphql b/graphql/documents/data/tag.graphql index cf1e2c050..ab83adf73 100644 --- a/graphql/documents/data/tag.graphql +++ b/graphql/documents/data/tag.graphql @@ -8,4 +8,12 @@ fragment TagData on Tag { image_count gallery_count performer_count + + parents { + ...SlimTagData + } + + children { + ...SlimTagData + } } diff --git a/graphql/schema/types/filters.graphql b/graphql/schema/types/filters.graphql index 76640bb24..04a508f06 100644 --- a/graphql/schema/types/filters.graphql +++ b/graphql/schema/types/filters.graphql @@ -72,7 +72,7 @@ input PerformerFilterType { """Filter to only include performers missing this property""" is_missing: String """Filter to only include performers with these tags""" - tags: MultiCriterionInput + tags: HierarchicalMultiCriterionInput """Filter by tag count""" tag_count: IntCriterionInput """Filter by scene count""" @@ -99,11 +99,11 @@ input PerformerFilterType { input SceneMarkerFilterType { """Filter to only include scene markers with this tag""" - tag_id: ID + tag_id: ID @deprecated(reason: "use tags filter instead") """Filter to only include scene markers with these tags""" - tags: MultiCriterionInput + tags: HierarchicalMultiCriterionInput """Filter to only include scene markers attached to a scene with these tags""" - scene_tags: MultiCriterionInput + scene_tags: HierarchicalMultiCriterionInput """Filter to only include scene markers with these performers""" performers: MultiCriterionInput } @@ -143,11 +143,11 @@ input SceneFilterType { """Filter to only include scenes with this movie""" movies: MultiCriterionInput """Filter to only include scenes with these tags""" - tags: MultiCriterionInput + tags: HierarchicalMultiCriterionInput """Filter by tag count""" tag_count: IntCriterionInput """Filter to only include scenes with performers with these tags""" - performer_tags: MultiCriterionInput + performer_tags: HierarchicalMultiCriterionInput """Filter to only include scenes with these performers""" performers: MultiCriterionInput """Filter by performer count""" @@ -226,11 +226,11 @@ input GalleryFilterType { """Filter to only include galleries with this studio""" studios: HierarchicalMultiCriterionInput """Filter to only include galleries with these tags""" - tags: MultiCriterionInput + tags: HierarchicalMultiCriterionInput """Filter by tag count""" tag_count: IntCriterionInput """Filter to only include galleries with performers with these tags""" - performer_tags: MultiCriterionInput + performer_tags: HierarchicalMultiCriterionInput """Filter to only include galleries with these performers""" performers: MultiCriterionInput """Filter by performer count""" @@ -295,11 +295,11 @@ input ImageFilterType { """Filter to only include images with this studio""" studios: HierarchicalMultiCriterionInput """Filter to only include images with these tags""" - tags: MultiCriterionInput + tags: HierarchicalMultiCriterionInput """Filter by tag count""" tag_count: IntCriterionInput """Filter to only include images with performers with these tags""" - performer_tags: MultiCriterionInput + performer_tags: HierarchicalMultiCriterionInput """Filter to only include images with these performers""" performers: MultiCriterionInput """Filter by performer count""" diff --git a/graphql/schema/types/tag.graphql b/graphql/schema/types/tag.graphql index 46f8c25cd..fea9f40fe 100644 --- a/graphql/schema/types/tag.graphql +++ b/graphql/schema/types/tag.graphql @@ -11,6 +11,9 @@ type Tag { image_count: Int # Resolver gallery_count: Int # Resolver performer_count: Int + + parents: [Tag!]! + children: [Tag!]! } input TagCreateInput { @@ -19,6 +22,9 @@ input TagCreateInput { """This should be a URL or a base64 encoded data URL""" image: String + + parent_ids: [ID!] + child_ids: [ID!] } input TagUpdateInput { @@ -28,6 +34,9 @@ input TagUpdateInput { """This should be a URL or a base64 encoded data URL""" image: String + + parent_ids: [ID!] + child_ids: [ID!] } input TagDestroyInput { diff --git a/pkg/api/resolver_model_tag.go b/pkg/api/resolver_model_tag.go index a863cd233..3b90463ca 100644 --- a/pkg/api/resolver_model_tag.go +++ b/pkg/api/resolver_model_tag.go @@ -10,6 +10,28 @@ import ( "github.com/stashapp/stash/pkg/models" ) +func (r *tagResolver) Parents(ctx context.Context, obj *models.Tag) (ret []*models.Tag, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Tag().FindByChildTagID(obj.ID) + return err + }); err != nil { + return nil, err + } + + return ret, nil +} + +func (r *tagResolver) Children(ctx context.Context, obj *models.Tag) (ret []*models.Tag, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Tag().FindByParentTagID(obj.ID) + return err + }); err != nil { + return nil, err + } + + return ret, nil +} + func (r *tagResolver) Aliases(ctx context.Context, obj *models.Tag) (ret []string, err error) { if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { ret, err = repo.Tag().GetAliases(obj.ID) diff --git a/pkg/api/resolver_mutation_tag.go b/pkg/api/resolver_mutation_tag.go index 94292a7ba..d1dc230e4 100644 --- a/pkg/api/resolver_mutation_tag.go +++ b/pkg/api/resolver_mutation_tag.go @@ -75,6 +75,28 @@ func (r *mutationResolver) TagCreate(ctx context.Context, input models.TagCreate } } + if input.ParentIds != nil && len(input.ParentIds) > 0 { + ids, err := utils.StringSliceToIntSlice(input.ParentIds) + if err != nil { + return err + } + + if err := qb.UpdateParentTags(t.ID, ids); err != nil { + return err + } + } + + if input.ChildIds != nil && len(input.ChildIds) > 0 { + ids, err := utils.StringSliceToIntSlice(input.ChildIds) + if err != nil { + return err + } + + if err := qb.UpdateChildTags(t.ID, ids); err != nil { + return err + } + } + return nil }); err != nil { return nil, err @@ -161,6 +183,41 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input models.TagUpdate } } + var parentIDs []int + var childIDs []int + + if translator.hasField("parent_ids") { + parentIDs, err = utils.StringSliceToIntSlice(input.ParentIds) + if err != nil { + return err + } + } + + if translator.hasField("child_ids") { + childIDs, err = utils.StringSliceToIntSlice(input.ChildIds) + if err != nil { + return err + } + } + + if parentIDs != nil || childIDs != nil { + if err := tag.EnsureUniqueHierarchy(tagID, parentIDs, childIDs, qb); err != nil { + return err + } + } + + if parentIDs != nil { + if err := qb.UpdateParentTags(tagID, parentIDs); err != nil { + return err + } + } + + if childIDs != nil { + if err := qb.UpdateChildTags(tagID, childIDs); err != nil { + return err + } + } + return nil }); err != nil { return nil, err @@ -242,10 +299,24 @@ func (r *mutationResolver) TagsMerge(ctx context.Context, input models.TagsMerge return fmt.Errorf("Tag with ID %d not found", destination) } + parents, children, err := tag.MergeHierarchy(destination, source, qb) + if err != nil { + return err + } + if err = qb.Merge(source, destination); err != nil { return err } + err = qb.UpdateParentTags(destination, parents) + if err != nil { + return err + } + err = qb.UpdateChildTags(destination, children) + if err != nil { + return err + } + return nil }); err != nil { return nil, err diff --git a/pkg/database/database.go b/pkg/database/database.go index cedcc9428..b4a6a8c29 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -23,7 +23,7 @@ import ( var DB *sqlx.DB var WriteMu *sync.Mutex var dbPath string -var appSchemaVersion uint = 25 +var appSchemaVersion uint = 26 var databaseSchemaVersion uint var ( diff --git a/pkg/database/migrations/26_tag_hierarchy.up.sql b/pkg/database/migrations/26_tag_hierarchy.up.sql new file mode 100644 index 000000000..6e9855fc2 --- /dev/null +++ b/pkg/database/migrations/26_tag_hierarchy.up.sql @@ -0,0 +1,7 @@ +CREATE TABLE tags_relations ( + parent_id integer, + child_id integer, + primary key (parent_id, child_id), + foreign key (parent_id) references tags(id) on delete cascade, + foreign key (child_id) references tags(id) on delete cascade +); diff --git a/pkg/dlna/cds.go b/pkg/dlna/cds.go index e99f63be8..79a80ab51 100644 --- a/pkg/dlna/cds.go +++ b/pkg/dlna/cds.go @@ -570,9 +570,10 @@ func (me *contentDirectoryService) getTags() []interface{} { func (me *contentDirectoryService) getTagScenes(paths []string, host string) []interface{} { sceneFilter := &models.SceneFilterType{ - Tags: &models.MultiCriterionInput{ + Tags: &models.HierarchicalMultiCriterionInput{ Modifier: models.CriterionModifierIncludes, Value: []string{paths[0]}, + Depth: 0, }, } diff --git a/pkg/gallery/query.go b/pkg/gallery/query.go index a354f5aa5..a7fbed856 100644 --- a/pkg/gallery/query.go +++ b/pkg/gallery/query.go @@ -31,9 +31,10 @@ func CountByStudioID(r models.GalleryReader, id int) (int, error) { func CountByTagID(r models.GalleryReader, id int) (int, error) { filter := &models.GalleryFilterType{ - Tags: &models.MultiCriterionInput{ + Tags: &models.HierarchicalMultiCriterionInput{ Value: []string{strconv.Itoa(id)}, Modifier: models.CriterionModifierIncludes, + Depth: 0, }, } diff --git a/pkg/image/query.go b/pkg/image/query.go index b5c70738c..eca6d49a4 100644 --- a/pkg/image/query.go +++ b/pkg/image/query.go @@ -31,9 +31,10 @@ func CountByStudioID(r models.ImageReader, id int) (int, error) { func CountByTagID(r models.ImageReader, id int) (int, error) { filter := &models.ImageFilterType{ - Tags: &models.MultiCriterionInput{ + Tags: &models.HierarchicalMultiCriterionInput{ Value: []string{strconv.Itoa(id)}, Modifier: models.CriterionModifierIncludes, + Depth: 0, }, } diff --git a/pkg/manager/jsonschema/tag.go b/pkg/manager/jsonschema/tag.go index 66fa910ab..fbb85a835 100644 --- a/pkg/manager/jsonschema/tag.go +++ b/pkg/manager/jsonschema/tag.go @@ -12,6 +12,7 @@ type Tag struct { Name string `json:"name,omitempty"` Aliases []string `json:"aliases,omitempty"` Image string `json:"image,omitempty"` + Parents []string `json:"parents,omitempty"` CreatedAt models.JSONTime `json:"created_at,omitempty"` UpdatedAt models.JSONTime `json:"updated_at,omitempty"` } diff --git a/pkg/manager/task_import.go b/pkg/manager/task_import.go index db3d9f76d..1bf2f3476 100644 --- a/pkg/manager/task_import.go +++ b/pkg/manager/task_import.go @@ -376,6 +376,7 @@ func (t *ImportTask) ImportGalleries(ctx context.Context) { } func (t *ImportTask) ImportTags(ctx context.Context) { + pendingParent := make(map[string][]*jsonschema.Tag) logger.Info("[tags] importing") for i, mappingJSON := range t.mappings.Tags { @@ -389,23 +390,64 @@ func (t *ImportTask) ImportTags(ctx context.Context) { logger.Progressf("[tags] %d of %d", index, len(t.mappings.Tags)) if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { - readerWriter := r.Tag() - - tagImporter := &tag.Importer{ - ReaderWriter: readerWriter, - Input: *tagJSON, + return t.ImportTag(tagJSON, pendingParent, false, r.Tag()) + }); err != nil { + if parentError, ok := err.(tag.ParentTagNotExistError); ok { + pendingParent[parentError.MissingParent()] = append(pendingParent[parentError.MissingParent()], tagJSON) + continue } - return performImport(tagImporter, t.DuplicateBehaviour) - }); err != nil { logger.Errorf("[tags] <%s> failed to import: %s", mappingJSON.Checksum, err.Error()) continue } } + for _, s := range pendingParent { + for _, orphanTagJSON := range s { + if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { + return t.ImportTag(orphanTagJSON, nil, true, r.Tag()) + }); err != nil { + logger.Errorf("[tags] <%s> failed to create: %s", orphanTagJSON.Name, err.Error()) + continue + } + } + } + logger.Info("[tags] import complete") } +func (t *ImportTask) ImportTag(tagJSON *jsonschema.Tag, pendingParent map[string][]*jsonschema.Tag, fail bool, readerWriter models.TagReaderWriter) error { + importer := &tag.Importer{ + ReaderWriter: readerWriter, + Input: *tagJSON, + MissingRefBehaviour: t.MissingRefBehaviour, + } + + // first phase: return error if parent does not exist + if !fail { + importer.MissingRefBehaviour = models.ImportMissingRefEnumFail + } + + if err := performImport(importer, t.DuplicateBehaviour); err != nil { + return err + } + + for _, childTagJSON := range pendingParent[tagJSON.Name] { + if err := t.ImportTag(childTagJSON, pendingParent, fail, readerWriter); err != nil { + if parentError, ok := err.(tag.ParentTagNotExistError); ok { + pendingParent[parentError.MissingParent()] = append(pendingParent[parentError.MissingParent()], tagJSON) + continue + } + + return fmt.Errorf("failed to create child tag <%s>: %s", childTagJSON.Name, err.Error()) + } + } + + delete(pendingParent, tagJSON.Name) + + return nil +} + func (t *ImportTask) ImportScrapedItems(ctx context.Context) { if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { logger.Info("[scraped sites] importing") diff --git a/pkg/models/mocks/TagReaderWriter.go b/pkg/models/mocks/TagReaderWriter.go index f7cabb8ce..f6b257ed2 100644 --- a/pkg/models/mocks/TagReaderWriter.go +++ b/pkg/models/mocks/TagReaderWriter.go @@ -130,6 +130,75 @@ func (_m *TagReaderWriter) Find(id int) (*models.Tag, error) { return r0, r1 } +// FindAllAncestors provides a mock function with given fields: tagID, excludeIDs +func (_m *TagReaderWriter) FindAllAncestors(tagID int, excludeIDs []int) ([]*models.Tag, error) { + ret := _m.Called(tagID, excludeIDs) + + var r0 []*models.Tag + if rf, ok := ret.Get(0).(func(int, []int) []*models.Tag); ok { + r0 = rf(tagID, excludeIDs) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.Tag) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(int, []int) error); ok { + r1 = rf(tagID, excludeIDs) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// FindAllDescendants provides a mock function with given fields: tagID, excludeIDs +func (_m *TagReaderWriter) FindAllDescendants(tagID int, excludeIDs []int) ([]*models.Tag, error) { + ret := _m.Called(tagID, excludeIDs) + + var r0 []*models.Tag + if rf, ok := ret.Get(0).(func(int, []int) []*models.Tag); ok { + r0 = rf(tagID, excludeIDs) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.Tag) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(int, []int) error); ok { + r1 = rf(tagID, excludeIDs) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// FindByChildTagID provides a mock function with given fields: childID +func (_m *TagReaderWriter) FindByChildTagID(childID int) ([]*models.Tag, error) { + ret := _m.Called(childID) + + var r0 []*models.Tag + if rf, ok := ret.Get(0).(func(int) []*models.Tag); ok { + r0 = rf(childID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.Tag) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(childID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // FindByGalleryID provides a mock function with given fields: galleryID func (_m *TagReaderWriter) FindByGalleryID(galleryID int) ([]*models.Tag, error) { ret := _m.Called(galleryID) @@ -222,6 +291,29 @@ func (_m *TagReaderWriter) FindByNames(names []string, nocase bool) ([]*models.T return r0, r1 } +// FindByParentTagID provides a mock function with given fields: parentID +func (_m *TagReaderWriter) FindByParentTagID(parentID int) ([]*models.Tag, error) { + ret := _m.Called(parentID) + + var r0 []*models.Tag + if rf, ok := ret.Get(0).(func(int) []*models.Tag); ok { + r0 = rf(parentID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.Tag) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(parentID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // FindByPerformerID provides a mock function with given fields: performerID func (_m *TagReaderWriter) FindByPerformerID(performerID int) ([]*models.Tag, error) { ret := _m.Called(performerID) @@ -464,6 +556,20 @@ func (_m *TagReaderWriter) UpdateAliases(tagID int, aliases []string) error { return r0 } +// UpdateChildTags provides a mock function with given fields: tagID, parentIDs +func (_m *TagReaderWriter) UpdateChildTags(tagID int, parentIDs []int) error { + ret := _m.Called(tagID, parentIDs) + + var r0 error + if rf, ok := ret.Get(0).(func(int, []int) error); ok { + r0 = rf(tagID, parentIDs) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // UpdateFull provides a mock function with given fields: updatedTag func (_m *TagReaderWriter) UpdateFull(updatedTag models.Tag) (*models.Tag, error) { ret := _m.Called(updatedTag) @@ -500,3 +606,17 @@ func (_m *TagReaderWriter) UpdateImage(tagID int, image []byte) error { return r0 } + +// UpdateParentTags provides a mock function with given fields: tagID, parentIDs +func (_m *TagReaderWriter) UpdateParentTags(tagID int, parentIDs []int) error { + ret := _m.Called(tagID, parentIDs) + + var r0 error + if rf, ok := ret.Get(0).(func(int, []int) error); ok { + r0 = rf(tagID, parentIDs) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/pkg/models/tag.go b/pkg/models/tag.go index 4d3e0d84e..bcb7096f0 100644 --- a/pkg/models/tag.go +++ b/pkg/models/tag.go @@ -10,6 +10,8 @@ type TagReader interface { FindByGalleryID(galleryID int) ([]*Tag, error) FindByName(name string, nocase bool) (*Tag, error) FindByNames(names []string, nocase bool) ([]*Tag, error) + FindByParentTagID(parentID int) ([]*Tag, error) + FindByChildTagID(childID int) ([]*Tag, error) Count() (int, error) All() ([]*Tag, error) // TODO - this interface is temporary until the filter schema can fully @@ -18,6 +20,8 @@ type TagReader interface { Query(tagFilter *TagFilterType, findFilter *FindFilterType) ([]*Tag, int, error) GetImage(tagID int) ([]byte, error) GetAliases(tagID int) ([]string, error) + FindAllAncestors(tagID int, excludeIDs []int) ([]*Tag, error) + FindAllDescendants(tagID int, excludeIDs []int) ([]*Tag, error) } type TagWriter interface { @@ -29,6 +33,8 @@ type TagWriter interface { DestroyImage(tagID int) error UpdateAliases(tagID int, aliases []string) error Merge(source []int, destination int) error + UpdateParentTags(tagID int, parentIDs []int) error + UpdateChildTags(tagID int, parentIDs []int) error } type TagReaderWriter interface { diff --git a/pkg/sqlite/filter.go b/pkg/sqlite/filter.go index d83ce12b0..0d47c2fb2 100644 --- a/pkg/sqlite/filter.go +++ b/pkg/sqlite/filter.go @@ -3,7 +3,9 @@ package sqlite import ( "errors" "fmt" + "github.com/stashapp/stash/pkg/logger" "regexp" + "strconv" "strings" "github.com/stashapp/stash/pkg/models" @@ -97,6 +99,7 @@ type filterBuilder struct { whereClauses []sqlClause havingClauses []sqlClause withClauses []sqlClause + recursiveWith bool err error } @@ -188,6 +191,16 @@ func (f *filterBuilder) addWith(sql string, args ...interface{}) { f.withClauses = append(f.withClauses, makeClause(sql, args...)) } +// addRecursiveWith adds a with clause and arguments to the filter, and sets it to recursive +func (f *filterBuilder) addRecursiveWith(sql string, args ...interface{}) { + if sql == "" { + return + } + + f.addWith(sql, args...) + f.recursiveWith = true +} + func (f *filterBuilder) getSubFilterClause(clause, subFilterClause string) string { ret := clause @@ -510,18 +523,41 @@ func (m *stringListCriterionHandlerBuilder) handler(criterion *models.StringCrit } type hierarchicalMultiCriterionHandlerBuilder struct { + tx dbi + primaryTable string foreignTable string foreignFK string - derivedTable string - parentFK string + derivedTable string + parentFK string + relationsTable string } -func addHierarchicalWithClause(f *filterBuilder, value []string, derivedTable, table, parentFK string, depth int) { +func getHierarchicalValues(tx dbi, values []string, table, relationsTable, parentFK string, depth int) string { var args []interface{} - for _, value := range value { + if depth == 0 { + valid := true + var valuesClauses []string + for _, value := range values { + id, err := strconv.Atoi(value) + // In case of invalid value just run the query. + // Building VALUES() based on provided values just saves a query when depth is 0. + if err != nil { + valid = false + break + } + + valuesClauses = append(valuesClauses, fmt.Sprintf("(%d,%d)", id, id)) + } + + if valid { + return "VALUES" + strings.Join(valuesClauses, ",") + } + } + + for _, value := range values { args = append(args, value) } inCount := len(args) @@ -532,45 +568,107 @@ func addHierarchicalWithClause(f *filterBuilder, value []string, derivedTable, t } withClauseMap := utils.StrFormatMap{ - "derivedTable": derivedTable, - "table": table, - "inBinding": getInBinding(inCount), - "parentFK": parentFK, - "depthCondition": depthCondition, - "unionClause": "", + "table": table, + "relationsTable": relationsTable, + "inBinding": getInBinding(inCount), + "recursiveSelect": "", + "parentFK": parentFK, + "depthCondition": depthCondition, + "unionClause": "", + } + + if relationsTable != "" { + withClauseMap["recursiveSelect"] = utils.StrFormat(`SELECT p.root_id, c.child_id, depth + 1 FROM {relationsTable} AS c +INNER JOIN items as p ON c.parent_id = p.item_id +`, withClauseMap) + } else { + withClauseMap["recursiveSelect"] = utils.StrFormat(`SELECT p.root_id, c.id, depth + 1 FROM {table} as c +INNER JOIN items as p ON c.{parentFK} = p.item_id +`, withClauseMap) } if depth != 0 { withClauseMap["unionClause"] = utils.StrFormat(` -UNION SELECT p.id, c.id, depth + 1 FROM {table} as c -INNER JOIN {derivedTable} as p ON c.{parentFK} = p.child_id {depthCondition} +UNION {recursiveSelect} {depthCondition} `, withClauseMap) } - withClause := utils.StrFormat(`RECURSIVE {derivedTable} AS ( -SELECT id as id, id as child_id, 0 as depth FROM {table} + withClause := utils.StrFormat(`items AS ( +SELECT id as root_id, id as item_id, 0 as depth FROM {table} WHERE id in {inBinding} {unionClause}) `, withClauseMap) - f.addWith(withClause, args...) + query := fmt.Sprintf("WITH RECURSIVE %s SELECT 'VALUES' || GROUP_CONCAT('(' || root_id || ', ' || item_id || ')') AS val FROM items", withClause) + + var valuesClause string + err := tx.Get(&valuesClause, query, args...) + if err != nil { + logger.Error(err) + // return record which never matches so we don't have to handle error here + return "VALUES(NULL, NULL)" + } + + return valuesClause +} + +func addHierarchicalConditionClauses(f *filterBuilder, criterion *models.HierarchicalMultiCriterionInput, table, idColumn string) { + if criterion.Modifier == models.CriterionModifierIncludes { + f.addWhere(fmt.Sprintf("%s.%s IS NOT NULL", table, idColumn)) + } else if criterion.Modifier == models.CriterionModifierIncludesAll { + f.addWhere(fmt.Sprintf("%s.%s IS NOT NULL", table, idColumn)) + f.addHaving(fmt.Sprintf("count(distinct %s.%s) IS %d", table, idColumn, len(criterion.Value))) + } else if criterion.Modifier == models.CriterionModifierExcludes { + f.addWhere(fmt.Sprintf("%s.%s IS NULL", table, idColumn)) + } } func (m *hierarchicalMultiCriterionHandlerBuilder) handler(criterion *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { return func(f *filterBuilder) { if criterion != nil && len(criterion.Value) > 0 { - addHierarchicalWithClause(f, criterion.Value, m.derivedTable, m.foreignTable, m.parentFK, criterion.Depth) + valuesClause := getHierarchicalValues(m.tx, criterion.Value, m.foreignTable, m.relationsTable, m.parentFK, criterion.Depth) - f.addJoin(m.derivedTable, "", fmt.Sprintf("%s.child_id = %s.%s", m.derivedTable, m.primaryTable, m.foreignFK)) + f.addJoin("(SELECT column1 AS root_id, column2 AS item_id FROM ("+valuesClause+"))", m.derivedTable, fmt.Sprintf("%s.item_id = %s.%s", m.derivedTable, m.primaryTable, m.foreignFK)) - if criterion.Modifier == models.CriterionModifierIncludes { - f.addWhere(fmt.Sprintf("%s.id IS NOT NULL", m.derivedTable)) - } else if criterion.Modifier == models.CriterionModifierIncludesAll { - f.addWhere(fmt.Sprintf("%s.id IS NOT NULL", m.derivedTable)) - f.addHaving(fmt.Sprintf("count(distinct %s.id) IS %d", m.derivedTable, len(criterion.Value))) - } else if criterion.Modifier == models.CriterionModifierExcludes { - f.addWhere(fmt.Sprintf("%s.id IS NULL", m.derivedTable)) - } + addHierarchicalConditionClauses(f, criterion, m.derivedTable, "root_id") + } + } +} + +type joinedHierarchicalMultiCriterionHandlerBuilder struct { + tx dbi + + primaryTable string + foreignTable string + foreignFK string + + parentFK string + relationsTable string + + joinAs string + joinTable string + primaryFK string +} + +func (m *joinedHierarchicalMultiCriterionHandlerBuilder) handler(criterion *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { + return func(f *filterBuilder) { + if criterion != nil && len(criterion.Value) > 0 { + valuesClause := getHierarchicalValues(m.tx, criterion.Value, m.foreignTable, m.relationsTable, m.parentFK, criterion.Depth) + + joinAlias := m.joinAs + joinTable := utils.StrFormat(`( + SELECT j.*, d.column1 AS root_id, d.column2 AS item_id FROM {joinTable} AS j + INNER JOIN ({valuesClause}) AS d ON j.{foreignFK} = d.column2 +) +`, utils.StrFormatMap{ + "joinTable": m.joinTable, + "foreignFK": m.foreignFK, + "valuesClause": valuesClause, + }) + + f.addJoin(joinTable, joinAlias, fmt.Sprintf("%s.%s = %s.id", joinAlias, m.primaryFK, m.primaryTable)) + + addHierarchicalConditionClauses(f, criterion, joinAlias, "root_id") } } } diff --git a/pkg/sqlite/gallery.go b/pkg/sqlite/gallery.go index 30f13c75c..2564c068b 100644 --- a/pkg/sqlite/gallery.go +++ b/pkg/sqlite/gallery.go @@ -311,17 +311,18 @@ func galleryIsMissingCriterionHandler(qb *galleryQueryBuilder, isMissing *string } } -func galleryTagsCriterionHandler(qb *galleryQueryBuilder, tags *models.MultiCriterionInput) criterionHandlerFunc { - h := joinedMultiCriterionHandlerBuilder{ - primaryTable: galleryTable, - joinTable: galleriesTagsTable, - joinAs: "tags_join", - primaryFK: galleryIDColumn, - foreignFK: tagIDColumn, +func galleryTagsCriterionHandler(qb *galleryQueryBuilder, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { + h := joinedHierarchicalMultiCriterionHandlerBuilder{ + tx: qb.tx, - addJoinTable: func(f *filterBuilder) { - qb.tagsRepository().join(f, "tags_join", "galleries.id") - }, + primaryTable: galleryTable, + foreignTable: tagTable, + foreignFK: "tag_id", + + relationsTable: "tags_relations", + joinAs: "image_tag", + joinTable: galleriesTagsTable, + primaryFK: galleryIDColumn, } return h.handler(tags) @@ -375,6 +376,8 @@ func galleryImageCountCriterionHandler(qb *galleryQueryBuilder, imageCount *mode func galleryStudioCriterionHandler(qb *galleryQueryBuilder, studios *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { h := hierarchicalMultiCriterionHandlerBuilder{ + tx: qb.tx, + primaryTable: galleryTable, foreignTable: studioTable, foreignFK: studioIDColumn, @@ -385,31 +388,20 @@ func galleryStudioCriterionHandler(qb *galleryQueryBuilder, studios *models.Hier return h.handler(studios) } -func galleryPerformerTagsCriterionHandler(qb *galleryQueryBuilder, performerTagsFilter *models.MultiCriterionInput) criterionHandlerFunc { +func galleryPerformerTagsCriterionHandler(qb *galleryQueryBuilder, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { return func(f *filterBuilder) { - if performerTagsFilter != nil && len(performerTagsFilter.Value) > 0 { - qb.performersRepository().join(f, "performers_join", "galleries.id") - f.addJoin("performers_tags", "performer_tags_join", "performers_join.performer_id = performer_tags_join.performer_id") + if tags != nil && len(tags.Value) > 0 { + valuesClause := getHierarchicalValues(qb.tx, tags.Value, tagTable, "tags_relations", "", tags.Depth) - var args []interface{} - for _, tagID := range performerTagsFilter.Value { - args = append(args, tagID) - } + f.addWith(`performer_tags AS ( +SELECT pg.gallery_id, t.column1 AS root_tag_id FROM performers_galleries pg +INNER JOIN performers_tags pt ON pt.performer_id = pg.performer_id +INNER JOIN (` + valuesClause + `) t ON t.column2 = pt.tag_id +)`) - if performerTagsFilter.Modifier == models.CriterionModifierIncludes { - // includes any of the provided ids - f.addWhere("performer_tags_join.tag_id IN "+getInBinding(len(performerTagsFilter.Value)), args...) - } else if performerTagsFilter.Modifier == models.CriterionModifierIncludesAll { - // includes all of the provided ids - f.addWhere("performer_tags_join.tag_id IN "+getInBinding(len(performerTagsFilter.Value)), args...) - f.addHaving(fmt.Sprintf("count(distinct performer_tags_join.tag_id) IS %d", len(performerTagsFilter.Value))) - } else if performerTagsFilter.Modifier == models.CriterionModifierExcludes { - f.addWhere(fmt.Sprintf(`not exists - (select performers_galleries.performer_id from performers_galleries - left join performers_tags on performers_tags.performer_id = performers_galleries.performer_id where - performers_galleries.gallery_id = galleries.id AND - performers_tags.tag_id in %s)`, getInBinding(len(performerTagsFilter.Value))), args...) - } + f.addJoin("performer_tags", "", "performer_tags.gallery_id = galleries.id") + + addHierarchicalConditionClauses(f, tags, "performer_tags", "root_tag_id") } } } diff --git a/pkg/sqlite/gallery_test.go b/pkg/sqlite/gallery_test.go index 03f03a544..4ba755b00 100644 --- a/pkg/sqlite/gallery_test.go +++ b/pkg/sqlite/gallery_test.go @@ -621,12 +621,13 @@ func TestGalleryQueryPerformers(t *testing.T) { func TestGalleryQueryTags(t *testing.T) { withTxn(func(r models.Repository) error { sqb := r.Gallery() - tagCriterion := models.MultiCriterionInput{ + tagCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdxWithGallery]), strconv.Itoa(tagIDs[tagIdx1WithGallery]), }, Modifier: models.CriterionModifierIncludes, + Depth: 0, } galleryFilter := models.GalleryFilterType{ @@ -641,12 +642,13 @@ func TestGalleryQueryTags(t *testing.T) { assert.True(t, gallery.ID == galleryIDs[galleryIdxWithTag] || gallery.ID == galleryIDs[galleryIdxWithTwoTags]) } - tagCriterion = models.MultiCriterionInput{ + tagCriterion = models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdx1WithGallery]), strconv.Itoa(tagIDs[tagIdx2WithGallery]), }, Modifier: models.CriterionModifierIncludesAll, + Depth: 0, } galleries = queryGallery(t, sqb, &galleryFilter, nil) @@ -654,11 +656,12 @@ func TestGalleryQueryTags(t *testing.T) { assert.Len(t, galleries, 1) assert.Equal(t, galleryIDs[galleryIdxWithTwoTags], galleries[0].ID) - tagCriterion = models.MultiCriterionInput{ + tagCriterion = models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdx1WithGallery]), }, Modifier: models.CriterionModifierExcludes, + Depth: 0, } q := getGalleryStringValue(galleryIdxWithTwoTags, titleField) @@ -776,12 +779,13 @@ func TestGalleryQueryStudioDepth(t *testing.T) { func TestGalleryQueryPerformerTags(t *testing.T) { withTxn(func(r models.Repository) error { sqb := r.Gallery() - tagCriterion := models.MultiCriterionInput{ + tagCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdxWithPerformer]), strconv.Itoa(tagIDs[tagIdx1WithPerformer]), }, Modifier: models.CriterionModifierIncludes, + Depth: 0, } galleryFilter := models.GalleryFilterType{ @@ -796,12 +800,13 @@ func TestGalleryQueryPerformerTags(t *testing.T) { assert.True(t, gallery.ID == galleryIDs[galleryIdxWithPerformerTag] || gallery.ID == galleryIDs[galleryIdxWithPerformerTwoTags]) } - tagCriterion = models.MultiCriterionInput{ + tagCriterion = models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdx1WithPerformer]), strconv.Itoa(tagIDs[tagIdx2WithPerformer]), }, Modifier: models.CriterionModifierIncludesAll, + Depth: 0, } galleries = queryGallery(t, sqb, &galleryFilter, nil) @@ -809,11 +814,12 @@ func TestGalleryQueryPerformerTags(t *testing.T) { assert.Len(t, galleries, 1) assert.Equal(t, galleryIDs[galleryIdxWithPerformerTwoTags], galleries[0].ID) - tagCriterion = models.MultiCriterionInput{ + tagCriterion = models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdx1WithPerformer]), }, Modifier: models.CriterionModifierExcludes, + Depth: 0, } q := getGalleryStringValue(galleryIdxWithPerformerTwoTags, titleField) diff --git a/pkg/sqlite/image.go b/pkg/sqlite/image.go index d0b6f16f8..01598c51a 100644 --- a/pkg/sqlite/image.go +++ b/pkg/sqlite/image.go @@ -348,17 +348,18 @@ func (qb *imageQueryBuilder) getMultiCriterionHandlerBuilder(foreignTable, joinT } } -func imageTagsCriterionHandler(qb *imageQueryBuilder, tags *models.MultiCriterionInput) criterionHandlerFunc { - h := joinedMultiCriterionHandlerBuilder{ - primaryTable: imageTable, - joinTable: imagesTagsTable, - joinAs: "tags_join", - primaryFK: imageIDColumn, - foreignFK: tagIDColumn, +func imageTagsCriterionHandler(qb *imageQueryBuilder, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { + h := joinedHierarchicalMultiCriterionHandlerBuilder{ + tx: qb.tx, - addJoinTable: func(f *filterBuilder) { - qb.tagsRepository().join(f, "tags_join", "images.id") - }, + primaryTable: imageTable, + foreignTable: tagTable, + foreignFK: "tag_id", + + relationsTable: "tags_relations", + joinAs: "image_tag", + joinTable: imagesTagsTable, + primaryFK: imageIDColumn, } return h.handler(tags) @@ -412,6 +413,8 @@ func imagePerformerCountCriterionHandler(qb *imageQueryBuilder, performerCount * func imageStudioCriterionHandler(qb *imageQueryBuilder, studios *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { h := hierarchicalMultiCriterionHandlerBuilder{ + tx: qb.tx, + primaryTable: imageTable, foreignTable: studioTable, foreignFK: studioIDColumn, @@ -422,31 +425,20 @@ func imageStudioCriterionHandler(qb *imageQueryBuilder, studios *models.Hierarch return h.handler(studios) } -func imagePerformerTagsCriterionHandler(qb *imageQueryBuilder, performerTagsFilter *models.MultiCriterionInput) criterionHandlerFunc { +func imagePerformerTagsCriterionHandler(qb *imageQueryBuilder, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { return func(f *filterBuilder) { - if performerTagsFilter != nil && len(performerTagsFilter.Value) > 0 { - qb.performersRepository().join(f, "performers_join", "images.id") - f.addJoin("performers_tags", "performer_tags_join", "performers_join.performer_id = performer_tags_join.performer_id") + if tags != nil && len(tags.Value) > 0 { + valuesClause := getHierarchicalValues(qb.tx, tags.Value, tagTable, "tags_relations", "", tags.Depth) - var args []interface{} - for _, tagID := range performerTagsFilter.Value { - args = append(args, tagID) - } + f.addWith(`performer_tags AS ( +SELECT pi.image_id, t.column1 AS root_tag_id FROM performers_images pi +INNER JOIN performers_tags pt ON pt.performer_id = pi.performer_id +INNER JOIN (` + valuesClause + `) t ON t.column2 = pt.tag_id +)`) - if performerTagsFilter.Modifier == models.CriterionModifierIncludes { - // includes any of the provided ids - f.addWhere("performer_tags_join.tag_id IN "+getInBinding(len(performerTagsFilter.Value)), args...) - } else if performerTagsFilter.Modifier == models.CriterionModifierIncludesAll { - // includes all of the provided ids - f.addWhere("performer_tags_join.tag_id IN "+getInBinding(len(performerTagsFilter.Value)), args...) - f.addHaving(fmt.Sprintf("count(distinct performer_tags_join.tag_id) IS %d", len(performerTagsFilter.Value))) - } else if performerTagsFilter.Modifier == models.CriterionModifierExcludes { - f.addWhere(fmt.Sprintf(`not exists - (select performers_images.performer_id from performers_images - left join performers_tags on performers_tags.performer_id = performers_images.performer_id where - performers_images.image_id = images.id AND - performers_tags.tag_id in %s)`, getInBinding(len(performerTagsFilter.Value))), args...) - } + f.addJoin("performer_tags", "", "performer_tags.image_id = images.id") + + addHierarchicalConditionClauses(f, tags, "performer_tags", "root_tag_id") } } } diff --git a/pkg/sqlite/image_test.go b/pkg/sqlite/image_test.go index 6d866041f..02e077c8d 100644 --- a/pkg/sqlite/image_test.go +++ b/pkg/sqlite/image_test.go @@ -722,12 +722,13 @@ func TestImageQueryPerformers(t *testing.T) { func TestImageQueryTags(t *testing.T) { withTxn(func(r models.Repository) error { sqb := r.Image() - tagCriterion := models.MultiCriterionInput{ + tagCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdxWithImage]), strconv.Itoa(tagIDs[tagIdx1WithImage]), }, Modifier: models.CriterionModifierIncludes, + Depth: 0, } imageFilter := models.ImageFilterType{ @@ -746,12 +747,13 @@ func TestImageQueryTags(t *testing.T) { assert.True(t, image.ID == imageIDs[imageIdxWithTag] || image.ID == imageIDs[imageIdxWithTwoTags]) } - tagCriterion = models.MultiCriterionInput{ + tagCriterion = models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdx1WithImage]), strconv.Itoa(tagIDs[tagIdx2WithImage]), }, Modifier: models.CriterionModifierIncludesAll, + Depth: 0, } images, _, err = sqb.Query(&imageFilter, nil) @@ -762,11 +764,12 @@ func TestImageQueryTags(t *testing.T) { assert.Len(t, images, 1) assert.Equal(t, imageIDs[imageIdxWithTwoTags], images[0].ID) - tagCriterion = models.MultiCriterionInput{ + tagCriterion = models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdx1WithImage]), }, Modifier: models.CriterionModifierExcludes, + Depth: 0, } q := getImageStringValue(imageIdxWithTwoTags, titleField) @@ -902,12 +905,13 @@ func queryImages(t *testing.T, sqb models.ImageReader, imageFilter *models.Image func TestImageQueryPerformerTags(t *testing.T) { withTxn(func(r models.Repository) error { sqb := r.Image() - tagCriterion := models.MultiCriterionInput{ + tagCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdxWithPerformer]), strconv.Itoa(tagIDs[tagIdx1WithPerformer]), }, Modifier: models.CriterionModifierIncludes, + Depth: 0, } imageFilter := models.ImageFilterType{ @@ -922,12 +926,13 @@ func TestImageQueryPerformerTags(t *testing.T) { assert.True(t, image.ID == imageIDs[imageIdxWithPerformerTag] || image.ID == imageIDs[imageIdxWithPerformerTwoTags]) } - tagCriterion = models.MultiCriterionInput{ + tagCriterion = models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdx1WithPerformer]), strconv.Itoa(tagIDs[tagIdx2WithPerformer]), }, Modifier: models.CriterionModifierIncludesAll, + Depth: 0, } images = queryImages(t, sqb, &imageFilter, nil) @@ -935,11 +940,12 @@ func TestImageQueryPerformerTags(t *testing.T) { assert.Len(t, images, 1) assert.Equal(t, imageIDs[imageIdxWithPerformerTwoTags], images[0].ID) - tagCriterion = models.MultiCriterionInput{ + tagCriterion = models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdx1WithPerformer]), }, Modifier: models.CriterionModifierExcludes, + Depth: 0, } q := getImageStringValue(imageIdxWithPerformerTwoTags, titleField) diff --git a/pkg/sqlite/movies.go b/pkg/sqlite/movies.go index 9e69b8451..92c3636e3 100644 --- a/pkg/sqlite/movies.go +++ b/pkg/sqlite/movies.go @@ -195,6 +195,8 @@ func movieIsMissingCriterionHandler(qb *movieQueryBuilder, isMissing *string) cr func movieStudioCriterionHandler(qb *movieQueryBuilder, studios *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { h := hierarchicalMultiCriterionHandlerBuilder{ + tx: qb.tx, + primaryTable: movieTable, foreignTable: studioTable, foreignFK: studioIDColumn, diff --git a/pkg/sqlite/performer.go b/pkg/sqlite/performer.go index 526a6e360..c8b3f86de 100644 --- a/pkg/sqlite/performer.go +++ b/pkg/sqlite/performer.go @@ -280,7 +280,7 @@ func (qb *performerQueryBuilder) makeFilter(filter *models.PerformerFilterType) query.handleCriterion(performerTagsCriterionHandler(qb, filter.Tags)) - query.handleCriterion(performerStudiosCriterionHandler(filter.Studios)) + query.handleCriterion(performerStudiosCriterionHandler(qb, filter.Studios)) query.handleCriterion(performerTagCountCriterionHandler(qb, filter.TagCount)) query.handleCriterion(performerSceneCountCriterionHandler(qb, filter.SceneCount)) @@ -376,17 +376,18 @@ func performerAgeFilterCriterionHandler(age *models.IntCriterionInput) criterion } } -func performerTagsCriterionHandler(qb *performerQueryBuilder, tags *models.MultiCriterionInput) criterionHandlerFunc { - h := joinedMultiCriterionHandlerBuilder{ - primaryTable: performerTable, - joinTable: performersTagsTable, - joinAs: "tags_join", - primaryFK: performerIDColumn, - foreignFK: tagIDColumn, +func performerTagsCriterionHandler(qb *performerQueryBuilder, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { + h := joinedHierarchicalMultiCriterionHandlerBuilder{ + tx: qb.tx, - addJoinTable: func(f *filterBuilder) { - qb.tagsRepository().join(f, "tags_join", "performers.id") - }, + primaryTable: performerTable, + foreignTable: tagTable, + foreignFK: "tag_id", + + relationsTable: "tags_relations", + joinAs: "image_tag", + joinTable: performersTagsTable, + primaryFK: performerIDColumn, } return h.handler(tags) @@ -432,7 +433,7 @@ func performerGalleryCountCriterionHandler(qb *performerQueryBuilder, count *mod return h.handler(count) } -func performerStudiosCriterionHandler(studios *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { +func performerStudiosCriterionHandler(qb *performerQueryBuilder, studios *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { return func(f *filterBuilder) { if studios != nil { var clauseCondition string @@ -465,13 +466,13 @@ func performerStudiosCriterionHandler(studios *models.HierarchicalMultiCriterion }, } - const derivedStudioTable = "studio" const derivedPerformerStudioTable = "performer_studio" - addHierarchicalWithClause(f, studios.Value, derivedStudioTable, studioTable, "parent_id", studios.Depth) + valuesClause := getHierarchicalValues(qb.tx, studios.Value, studioTable, "", "parent_id", studios.Depth) + f.addWith("studio(root_id, item_id) AS (" + valuesClause + ")") templStr := `SELECT performer_id FROM {primaryTable} INNER JOIN {joinTable} ON {primaryTable}.id = {joinTable}.{primaryFK} - INNER JOIN studio ON {primaryTable}.studio_id = studio.child_id` + INNER JOIN studio ON {primaryTable}.studio_id = studio.item_id` var unions []string for _, c := range formatMaps { diff --git a/pkg/sqlite/performer_test.go b/pkg/sqlite/performer_test.go index a68e09e03..4ce990c35 100644 --- a/pkg/sqlite/performer_test.go +++ b/pkg/sqlite/performer_test.go @@ -500,12 +500,13 @@ func queryPerformers(t *testing.T, qb models.PerformerReader, performerFilter *m func TestPerformerQueryTags(t *testing.T) { withTxn(func(r models.Repository) error { sqb := r.Performer() - tagCriterion := models.MultiCriterionInput{ + tagCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdxWithPerformer]), strconv.Itoa(tagIDs[tagIdx1WithPerformer]), }, Modifier: models.CriterionModifierIncludes, + Depth: 0, } performerFilter := models.PerformerFilterType{ @@ -519,12 +520,13 @@ func TestPerformerQueryTags(t *testing.T) { assert.True(t, performer.ID == performerIDs[performerIdxWithTag] || performer.ID == performerIDs[performerIdxWithTwoTags]) } - tagCriterion = models.MultiCriterionInput{ + tagCriterion = models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdx1WithPerformer]), strconv.Itoa(tagIDs[tagIdx2WithPerformer]), }, Modifier: models.CriterionModifierIncludesAll, + Depth: 0, } performers = queryPerformers(t, sqb, &performerFilter, nil) @@ -532,11 +534,12 @@ func TestPerformerQueryTags(t *testing.T) { assert.Len(t, performers, 1) assert.Equal(t, sceneIDs[performerIdxWithTwoTags], performers[0].ID) - tagCriterion = models.MultiCriterionInput{ + tagCriterion = models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdx1WithPerformer]), }, Modifier: models.CriterionModifierExcludes, + Depth: 0, } q := getSceneStringValue(performerIdxWithTwoTags, titleField) diff --git a/pkg/sqlite/query.go b/pkg/sqlite/query.go index 2d31ee583..158e7f971 100644 --- a/pkg/sqlite/query.go +++ b/pkg/sqlite/query.go @@ -18,6 +18,7 @@ type queryBuilder struct { havingClauses []string args []interface{} withClauses []string + recursiveWith bool sortAndPagination string @@ -32,7 +33,7 @@ func (qb queryBuilder) executeFind() ([]int, int, error) { body := qb.body body += qb.joins.toSQL() - return qb.repository.executeFindQuery(body, qb.args, qb.sortAndPagination, qb.whereClauses, qb.havingClauses, qb.withClauses) + return qb.repository.executeFindQuery(body, qb.args, qb.sortAndPagination, qb.whereClauses, qb.havingClauses, qb.withClauses, qb.recursiveWith) } func (qb queryBuilder) executeCount() (int, error) { @@ -45,7 +46,11 @@ func (qb queryBuilder) executeCount() (int, error) { withClause := "" if len(qb.withClauses) > 0 { - withClause = "WITH " + strings.Join(qb.withClauses, ", ") + " " + var recursive string + if qb.recursiveWith { + recursive = " RECURSIVE " + } + withClause = "WITH " + recursive + strings.Join(qb.withClauses, ", ") + " " } body = qb.repository.buildQueryBody(body, qb.whereClauses, qb.havingClauses) @@ -69,12 +74,14 @@ func (qb *queryBuilder) addHaving(clauses ...string) { } } -func (qb *queryBuilder) addWith(clauses ...string) { +func (qb *queryBuilder) addWith(recursive bool, clauses ...string) { for _, clause := range clauses { if len(clause) > 0 { qb.withClauses = append(qb.withClauses, clause) } } + + qb.recursiveWith = qb.recursiveWith || recursive } func (qb *queryBuilder) addArg(args ...interface{}) { @@ -104,7 +111,7 @@ func (qb *queryBuilder) addFilter(f *filterBuilder) { clause, args := f.generateWithClauses() if len(clause) > 0 { - qb.addWith(clause) + qb.addWith(f.recursiveWith, clause) } if len(args) > 0 { diff --git a/pkg/sqlite/repository.go b/pkg/sqlite/repository.go index 6550d0d67..ffe5b78ad 100644 --- a/pkg/sqlite/repository.go +++ b/pkg/sqlite/repository.go @@ -246,12 +246,16 @@ func (r *repository) buildQueryBody(body string, whereClauses []string, havingCl return body } -func (r *repository) executeFindQuery(body string, args []interface{}, sortAndPagination string, whereClauses []string, havingClauses []string, withClauses []string) ([]int, int, error) { +func (r *repository) executeFindQuery(body string, args []interface{}, sortAndPagination string, whereClauses []string, havingClauses []string, withClauses []string, recursiveWith bool) ([]int, int, error) { body = r.buildQueryBody(body, whereClauses, havingClauses) withClause := "" if len(withClauses) > 0 { - withClause = "WITH " + strings.Join(withClauses, ", ") + " " + var recursive string + if recursiveWith { + recursive = " RECURSIVE " + } + withClause = "WITH " + recursive + strings.Join(withClauses, ", ") + " " } countQuery := withClause + r.buildCountQuery(body) diff --git a/pkg/sqlite/scene.go b/pkg/sqlite/scene.go index 92cdd5c2c..5aa0d4722 100644 --- a/pkg/sqlite/scene.go +++ b/pkg/sqlite/scene.go @@ -548,12 +548,19 @@ func (qb *sceneQueryBuilder) getMultiCriterionHandlerBuilder(foreignTable, joinT } } -func sceneTagsCriterionHandler(qb *sceneQueryBuilder, tags *models.MultiCriterionInput) criterionHandlerFunc { - addJoinsFunc := func(f *filterBuilder) { - qb.tagsRepository().join(f, "tags_join", "scenes.id") - f.addJoin("tags", "", "tags_join.tag_id = tags.id") +func sceneTagsCriterionHandler(qb *sceneQueryBuilder, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { + h := joinedHierarchicalMultiCriterionHandlerBuilder{ + tx: qb.tx, + + primaryTable: sceneTable, + foreignTable: tagTable, + foreignFK: "tag_id", + + relationsTable: "tags_relations", + joinAs: "scene_tag", + joinTable: scenesTagsTable, + primaryFK: sceneIDColumn, } - h := qb.getMultiCriterionHandlerBuilder(tagTable, scenesTagsTable, tagIDColumn, addJoinsFunc) return h.handler(tags) } @@ -596,6 +603,8 @@ func scenePerformerCountCriterionHandler(qb *sceneQueryBuilder, performerCount * func sceneStudioCriterionHandler(qb *sceneQueryBuilder, studios *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { h := hierarchicalMultiCriterionHandlerBuilder{ + tx: qb.tx, + primaryTable: sceneTable, foreignTable: studioTable, foreignFK: studioIDColumn, @@ -615,31 +624,20 @@ func sceneMoviesCriterionHandler(qb *sceneQueryBuilder, movies *models.MultiCrit return h.handler(movies) } -func scenePerformerTagsCriterionHandler(qb *sceneQueryBuilder, performerTagsFilter *models.MultiCriterionInput) criterionHandlerFunc { +func scenePerformerTagsCriterionHandler(qb *sceneQueryBuilder, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { return func(f *filterBuilder) { - if performerTagsFilter != nil && len(performerTagsFilter.Value) > 0 { - qb.performersRepository().join(f, "performers_join", "scenes.id") - f.addJoin("performers_tags", "performer_tags_join", "performers_join.performer_id = performer_tags_join.performer_id") + if tags != nil && len(tags.Value) > 0 { + valuesClause := getHierarchicalValues(qb.tx, tags.Value, tagTable, "tags_relations", "", tags.Depth) - var args []interface{} - for _, tagID := range performerTagsFilter.Value { - args = append(args, tagID) - } + f.addWith(`performer_tags AS ( +SELECT ps.scene_id, t.column1 AS root_tag_id FROM performers_scenes ps +INNER JOIN performers_tags pt ON pt.performer_id = ps.performer_id +INNER JOIN (` + valuesClause + `) t ON t.column2 = pt.tag_id +)`) - if performerTagsFilter.Modifier == models.CriterionModifierIncludes { - // includes any of the provided ids - f.addWhere("performer_tags_join.tag_id IN "+getInBinding(len(performerTagsFilter.Value)), args...) - } else if performerTagsFilter.Modifier == models.CriterionModifierIncludesAll { - // includes all of the provided ids - f.addWhere("performer_tags_join.tag_id IN "+getInBinding(len(performerTagsFilter.Value)), args...) - f.addHaving(fmt.Sprintf("count(distinct performer_tags_join.tag_id) IS %d", len(performerTagsFilter.Value))) - } else if performerTagsFilter.Modifier == models.CriterionModifierExcludes { - f.addWhere(fmt.Sprintf(`not exists - (select performers_scenes.performer_id from performers_scenes - left join performers_tags on performers_tags.performer_id = performers_scenes.performer_id where - performers_scenes.scene_id = scenes.id AND - performers_tags.tag_id in %s)`, getInBinding(len(performerTagsFilter.Value))), args...) - } + f.addJoin("performer_tags", "", "performer_tags.scene_id = scenes.id") + + addHierarchicalConditionClauses(f, tags, "performer_tags", "root_tag_id") } } } diff --git a/pkg/sqlite/scene_marker.go b/pkg/sqlite/scene_marker.go index 0ad3cda43..392bfc75d 100644 --- a/pkg/sqlite/scene_marker.go +++ b/pkg/sqlite/scene_marker.go @@ -3,8 +3,6 @@ package sqlite import ( "database/sql" "fmt" - "strconv" - "github.com/stashapp/stash/pkg/database" "github.com/stashapp/stash/pkg/models" ) @@ -127,6 +125,17 @@ func (qb *sceneMarkerQueryBuilder) Wall(q *string) ([]*models.SceneMarker, error return qb.querySceneMarkers(query, nil) } +func (qb *sceneMarkerQueryBuilder) makeFilter(sceneMarkerFilter *models.SceneMarkerFilterType) *filterBuilder { + query := &filterBuilder{} + + query.handleCriterion(sceneMarkerTagIDCriterionHandler(qb, sceneMarkerFilter.TagID)) + query.handleCriterion(sceneMarkerTagsCriterionHandler(qb, sceneMarkerFilter.Tags)) + query.handleCriterion(sceneMarkerSceneTagsCriterionHandler(qb, sceneMarkerFilter.SceneTags)) + query.handleCriterion(sceneMarkerPerformersCriterionHandler(qb, sceneMarkerFilter.Performers)) + + return query +} + func (qb *sceneMarkerQueryBuilder) Query(sceneMarkerFilter *models.SceneMarkerFilterType, findFilter *models.FindFilterType) ([]*models.SceneMarker, int, error) { if sceneMarkerFilter == nil { sceneMarkerFilter = &models.SceneMarkerFilterType{} @@ -135,121 +144,23 @@ func (qb *sceneMarkerQueryBuilder) Query(sceneMarkerFilter *models.SceneMarkerFi findFilter = &models.FindFilterType{} } - var whereClauses []string - var havingClauses []string - var args []interface{} - body := selectDistinctIDs("scene_markers") - body = body + ` - left join tags as primary_tag on primary_tag.id = scene_markers.primary_tag_id - left join scenes as scene on scene.id = scene_markers.scene_id - left join scene_markers_tags as tags_join on tags_join.scene_marker_id = scene_markers.id - left join tags on tags_join.tag_id = tags.id - ` + query := qb.newQuery() - if tagsFilter := sceneMarkerFilter.Tags; tagsFilter != nil && len(tagsFilter.Value) > 0 { - //select `scene_markers`.* from `scene_markers` - //left join `tags` as `primary_tags_join` - // on `primary_tags_join`.`id` = `scene_markers`.`primary_tag_id` - // and `primary_tags_join`.`id` in ('3', '37', '9', '89') - //left join `scene_markers_tags` as `tags_join` - // on `tags_join`.`scene_marker_id` = `scene_markers`.`id` - // and `tags_join`.`tag_id` in ('3', '37', '9', '89') - //group by `scene_markers`.`id` - //having ((count(distinct `primary_tags_join`.`id`) + count(distinct `tags_join`.`tag_id`)) = 4) - - length := len(tagsFilter.Value) - - if tagsFilter.Modifier == models.CriterionModifierIncludes || tagsFilter.Modifier == models.CriterionModifierIncludesAll { - body += " LEFT JOIN tags AS ptj ON ptj.id = scene_markers.primary_tag_id AND ptj.id IN " + getInBinding(length) - body += " LEFT JOIN scene_markers_tags AS tj ON tj.scene_marker_id = scene_markers.id AND tj.tag_id IN " + getInBinding(length) - - // only one required for include any - requiredCount := 1 - - // all required for include all - if tagsFilter.Modifier == models.CriterionModifierIncludesAll { - requiredCount = length - } - - havingClauses = append(havingClauses, "((COUNT(DISTINCT ptj.id) + COUNT(DISTINCT tj.tag_id)) >= "+strconv.Itoa(requiredCount)+")") - } else if tagsFilter.Modifier == models.CriterionModifierExcludes { - // excludes all of the provided ids - whereClauses = append(whereClauses, "scene_markers.primary_tag_id not in "+getInBinding(length)) - whereClauses = append(whereClauses, "not exists (select smt.scene_marker_id from scene_markers_tags as smt where smt.scene_marker_id = scene_markers.id and smt.tag_id in "+getInBinding(length)+")") - } - - for _, tagID := range tagsFilter.Value { - args = append(args, tagID) - } - for _, tagID := range tagsFilter.Value { - args = append(args, tagID) - } - } - - if sceneTagsFilter := sceneMarkerFilter.SceneTags; sceneTagsFilter != nil && len(sceneTagsFilter.Value) > 0 { - length := len(sceneTagsFilter.Value) - - if sceneTagsFilter.Modifier == models.CriterionModifierIncludes || sceneTagsFilter.Modifier == models.CriterionModifierIncludesAll { - body += " LEFT JOIN scenes_tags AS scene_tags_join ON scene_tags_join.scene_id = scene.id AND scene_tags_join.tag_id IN " + getInBinding(length) - - // only one required for include any - requiredCount := 1 - - // all required for include all - if sceneTagsFilter.Modifier == models.CriterionModifierIncludesAll { - requiredCount = length - } - - havingClauses = append(havingClauses, "COUNT(DISTINCT scene_tags_join.tag_id) >= "+strconv.Itoa(requiredCount)) - } else if sceneTagsFilter.Modifier == models.CriterionModifierExcludes { - // excludes all of the provided ids - whereClauses = append(whereClauses, "not exists (select st.scene_id from scenes_tags as st where st.scene_id = scene.id AND st.tag_id IN "+getInBinding(length)+")") - } - - for _, tagID := range sceneTagsFilter.Value { - args = append(args, tagID) - } - } - - if performersFilter := sceneMarkerFilter.Performers; performersFilter != nil && len(performersFilter.Value) > 0 { - length := len(performersFilter.Value) - - if performersFilter.Modifier == models.CriterionModifierIncludes || performersFilter.Modifier == models.CriterionModifierIncludesAll { - body += " LEFT JOIN performers_scenes as scene_performers ON scene.id = scene_performers.scene_id" - whereClauses = append(whereClauses, "scene_performers.performer_id IN "+getInBinding(length)) - - // only one required for include any - requiredCount := 1 - - // all required for include all - if performersFilter.Modifier == models.CriterionModifierIncludesAll { - requiredCount = length - } - - havingClauses = append(havingClauses, "COUNT(DISTINCT scene_performers.performer_id) >= "+strconv.Itoa(requiredCount)) - } else if performersFilter.Modifier == models.CriterionModifierExcludes { - // excludes all of the provided ids - whereClauses = append(whereClauses, "not exists (select sp.scene_id from performers_scenes as sp where sp.scene_id = scene.id AND sp.performer_id IN "+getInBinding(length)+")") - } - - for _, performerID := range performersFilter.Value { - args = append(args, performerID) - } - } + query.body = selectDistinctIDs("scene_markers") if q := findFilter.Q; q != nil && *q != "" { searchColumns := []string{"scene_markers.title", "scene.title"} clause, thisArgs := getSearchBinding(searchColumns, *q, false) - whereClauses = append(whereClauses, clause) - args = append(args, thisArgs...) + query.addWhere(clause) + query.addArg(thisArgs...) } - if tagID := sceneMarkerFilter.TagID; tagID != nil { - whereClauses = append(whereClauses, "(scene_markers.primary_tag_id = "+*tagID+" OR tags.id = "+*tagID+")") - } + filter := qb.makeFilter(sceneMarkerFilter) - sortAndPagination := qb.getSceneMarkerSort(findFilter) + getPagination(findFilter) - idsResult, countResult, err := qb.executeFindQuery(body, args, sortAndPagination, whereClauses, havingClauses, []string{}) + query.addFilter(filter) + + query.sortAndPagination = qb.getSceneMarkerSort(findFilter) + getPagination(findFilter) + idsResult, countResult, err := query.executeFind() if err != nil { return nil, 0, err } @@ -267,6 +178,74 @@ func (qb *sceneMarkerQueryBuilder) Query(sceneMarkerFilter *models.SceneMarkerFi return sceneMarkers, countResult, nil } +func sceneMarkerTagIDCriterionHandler(qb *sceneMarkerQueryBuilder, tagID *string) criterionHandlerFunc { + return func(f *filterBuilder) { + if tagID != nil { + f.addJoin("scene_markers_tags", "", "scene_markers_tags.scene_marker_id = scene_markers.id") + + f.addWhere("(scene_markers.primary_tag_id = ? OR scene_markers_tags.tag_id = ?)", *tagID, *tagID) + } + } +} + +func sceneMarkerTagsCriterionHandler(qb *sceneMarkerQueryBuilder, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { + return func(f *filterBuilder) { + if tags != nil && len(tags.Value) > 0 { + valuesClause := getHierarchicalValues(qb.tx, tags.Value, tagTable, "tags_relations", "", tags.Depth) + + f.addWith(`marker_tags AS ( +SELECT mt.scene_marker_id, t.column1 AS root_tag_id FROM scene_markers_tags mt +INNER JOIN (` + valuesClause + `) t ON t.column2 = mt.tag_id +UNION +SELECT m.id, t.column1 FROM scene_markers m +INNER JOIN (` + valuesClause + `) t ON t.column2 = m.primary_tag_id +)`) + + f.addJoin("marker_tags", "", "marker_tags.scene_marker_id = scene_markers.id") + + addHierarchicalConditionClauses(f, tags, "marker_tags", "root_tag_id") + } + } +} + +func sceneMarkerSceneTagsCriterionHandler(qb *sceneMarkerQueryBuilder, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { + return func(f *filterBuilder) { + if tags != nil && len(tags.Value) > 0 { + valuesClause := getHierarchicalValues(qb.tx, tags.Value, tagTable, "tags_relations", "", tags.Depth) + + f.addWith(`scene_tags AS ( +SELECT st.scene_id, t.column1 AS root_tag_id FROM scenes_tags st +INNER JOIN (` + valuesClause + `) t ON t.column2 = st.tag_id +)`) + + f.addJoin("scene_tags", "", "scene_tags.scene_id = scene_markers.scene_id") + + addHierarchicalConditionClauses(f, tags, "scene_tags", "root_tag_id") + } + } +} + +func sceneMarkerPerformersCriterionHandler(qb *sceneMarkerQueryBuilder, performers *models.MultiCriterionInput) criterionHandlerFunc { + h := joinedMultiCriterionHandlerBuilder{ + primaryTable: sceneTable, + joinTable: performersScenesTable, + joinAs: "performers_join", + primaryFK: sceneIDColumn, + foreignFK: performerIDColumn, + + addJoinTable: func(f *filterBuilder) { + f.addJoin(performersScenesTable, "performers_join", "performers_join.scene_id = scene_markers.scene_id") + }, + } + + handler := h.handler(performers) + return func(f *filterBuilder) { + // Make sure scenes is included, otherwise excludes filter fails + f.addJoin(sceneTable, "", "scenes.id = scene_markers.scene_id") + handler(f) + } +} + func (qb *sceneMarkerQueryBuilder) getSceneMarkerSort(findFilter *models.FindFilterType) string { sort := findFilter.GetSort("title") direction := findFilter.GetDirection() diff --git a/pkg/sqlite/scene_test.go b/pkg/sqlite/scene_test.go index 2be8e4d45..aa895b8b7 100644 --- a/pkg/sqlite/scene_test.go +++ b/pkg/sqlite/scene_test.go @@ -1034,12 +1034,13 @@ func TestSceneQueryPerformers(t *testing.T) { func TestSceneQueryTags(t *testing.T) { withTxn(func(r models.Repository) error { sqb := r.Scene() - tagCriterion := models.MultiCriterionInput{ + tagCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdxWithScene]), strconv.Itoa(tagIDs[tagIdx1WithScene]), }, Modifier: models.CriterionModifierIncludes, + Depth: 0, } sceneFilter := models.SceneFilterType{ @@ -1054,12 +1055,13 @@ func TestSceneQueryTags(t *testing.T) { assert.True(t, scene.ID == sceneIDs[sceneIdxWithTag] || scene.ID == sceneIDs[sceneIdxWithTwoTags]) } - tagCriterion = models.MultiCriterionInput{ + tagCriterion = models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdx1WithScene]), strconv.Itoa(tagIDs[tagIdx2WithScene]), }, Modifier: models.CriterionModifierIncludesAll, + Depth: 0, } scenes = queryScene(t, sqb, &sceneFilter, nil) @@ -1067,11 +1069,12 @@ func TestSceneQueryTags(t *testing.T) { assert.Len(t, scenes, 1) assert.Equal(t, sceneIDs[sceneIdxWithTwoTags], scenes[0].ID) - tagCriterion = models.MultiCriterionInput{ + tagCriterion = models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdx1WithScene]), }, Modifier: models.CriterionModifierExcludes, + Depth: 0, } q := getSceneStringValue(sceneIdxWithTwoTags, titleField) @@ -1089,12 +1092,13 @@ func TestSceneQueryTags(t *testing.T) { func TestSceneQueryPerformerTags(t *testing.T) { withTxn(func(r models.Repository) error { sqb := r.Scene() - tagCriterion := models.MultiCriterionInput{ + tagCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdxWithPerformer]), strconv.Itoa(tagIDs[tagIdx1WithPerformer]), }, Modifier: models.CriterionModifierIncludes, + Depth: 0, } sceneFilter := models.SceneFilterType{ @@ -1109,12 +1113,13 @@ func TestSceneQueryPerformerTags(t *testing.T) { assert.True(t, scene.ID == sceneIDs[sceneIdxWithPerformerTag] || scene.ID == sceneIDs[sceneIdxWithPerformerTwoTags]) } - tagCriterion = models.MultiCriterionInput{ + tagCriterion = models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdx1WithPerformer]), strconv.Itoa(tagIDs[tagIdx2WithPerformer]), }, Modifier: models.CriterionModifierIncludesAll, + Depth: 0, } scenes = queryScene(t, sqb, &sceneFilter, nil) @@ -1122,11 +1127,12 @@ func TestSceneQueryPerformerTags(t *testing.T) { assert.Len(t, scenes, 1) assert.Equal(t, sceneIDs[sceneIdxWithPerformerTwoTags], scenes[0].ID) - tagCriterion = models.MultiCriterionInput{ + tagCriterion = models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdx1WithPerformer]), }, Modifier: models.CriterionModifierExcludes, + Depth: 0, } q := getSceneStringValue(sceneIdxWithPerformerTwoTags, titleField) diff --git a/pkg/sqlite/tag.go b/pkg/sqlite/tag.go index c67a8dea5..28237867c 100644 --- a/pkg/sqlite/tag.go +++ b/pkg/sqlite/tag.go @@ -196,6 +196,28 @@ func (qb *tagQueryBuilder) FindByNames(names []string, nocase bool) ([]*models.T return qb.queryTags(query, args) } +func (qb *tagQueryBuilder) FindByParentTagID(parentID int) ([]*models.Tag, error) { + query := ` + SELECT tags.* FROM tags + INNER JOIN tags_relations ON tags_relations.child_id = tags.id + WHERE tags_relations.parent_id = ? + ` + query += qb.getDefaultTagSort() + args := []interface{}{parentID} + return qb.queryTags(query, args) +} + +func (qb *tagQueryBuilder) FindByChildTagID(parentID int) ([]*models.Tag, error) { + query := ` + SELECT tags.* FROM tags + INNER JOIN tags_relations ON tags_relations.parent_id = tags.id + WHERE tags_relations.child_id = ? + ` + query += qb.getDefaultTagSort() + args := []interface{}{parentID} + return qb.queryTags(query, args) +} + func (qb *tagQueryBuilder) Count() (int, error) { return qb.runCountQuery(qb.buildCountQuery("SELECT tags.id FROM tags"), nil) } @@ -572,3 +594,115 @@ AND NOT EXISTS(SELECT 1 FROM `+table+` o WHERE o.`+idColumn+` = `+table+`.`+idCo return nil } + +func (qb *tagQueryBuilder) UpdateParentTags(tagID int, parentIDs []int) error { + tx := qb.tx + if _, err := tx.Exec("DELETE FROM tags_relations WHERE child_id = ?", tagID); err != nil { + return err + } + + if len(parentIDs) > 0 { + var args []interface{} + var values []string + for _, parentID := range parentIDs { + values = append(values, "(? , ?)") + args = append(args, parentID, tagID) + } + + query := "INSERT INTO tags_relations (parent_id, child_id) VALUES " + strings.Join(values, ", ") + if _, err := tx.Exec(query, args...); err != nil { + return err + } + } + + return nil +} + +func (qb *tagQueryBuilder) UpdateChildTags(tagID int, childIDs []int) error { + tx := qb.tx + if _, err := tx.Exec("DELETE FROM tags_relations WHERE parent_id = ?", tagID); err != nil { + return err + } + + if len(childIDs) > 0 { + var args []interface{} + var values []string + for _, childID := range childIDs { + values = append(values, "(? , ?)") + args = append(args, tagID, childID) + } + + query := "INSERT INTO tags_relations (parent_id, child_id) VALUES " + strings.Join(values, ", ") + if _, err := tx.Exec(query, args...); err != nil { + return err + } + } + + return nil +} + +func (qb *tagQueryBuilder) FindAllAncestors(tagID int, excludeIDs []int) ([]*models.Tag, error) { + inBinding := getInBinding(len(excludeIDs) + 1) + + query := `WITH RECURSIVE +parents AS ( + SELECT t.id AS parent_id, t.id AS child_id FROM tags t WHERE t.id = ? + UNION + SELECT tr.parent_id, tr.child_id FROM tags_relations tr INNER JOIN parents p ON p.parent_id = tr.child_id WHERE tr.parent_id NOT IN` + inBinding + ` +), +children AS ( + SELECT tr.parent_id, tr.child_id FROM tags_relations tr INNER JOIN parents p ON p.parent_id = tr.parent_id WHERE tr.child_id NOT IN` + inBinding + ` + UNION + SELECT tr.parent_id, tr.child_id FROM tags_relations tr INNER JOIN children c ON c.child_id = tr.parent_id WHERE tr.child_id NOT IN` + inBinding + ` +) +SELECT t.* FROM tags t INNER JOIN parents p ON t.id = p.parent_id +UNION +SELECT t.* FROM tags t INNER JOIN children c ON t.id = c.child_id +` + + var ret models.Tags + excludeArgs := []interface{}{tagID} + for _, excludeID := range excludeIDs { + excludeArgs = append(excludeArgs, excludeID) + } + args := []interface{}{tagID} + args = append(args, append(append(excludeArgs, excludeArgs...), excludeArgs...)...) + if err := qb.query(query, args, &ret); err != nil { + return nil, err + } + + return ret, nil +} + +func (qb *tagQueryBuilder) FindAllDescendants(tagID int, excludeIDs []int) ([]*models.Tag, error) { + inBinding := getInBinding(len(excludeIDs) + 1) + + query := `WITH RECURSIVE +children AS ( + SELECT t.id AS parent_id, t.id AS child_id FROM tags t WHERE t.id = ? + UNION + SELECT tr.parent_id, tr.child_id FROM tags_relations tr INNER JOIN children c ON c.child_id = tr.parent_id WHERE tr.child_id NOT IN` + inBinding + ` +), +parents AS ( + SELECT tr.parent_id, tr.child_id FROM tags_relations tr INNER JOIN children c ON c.child_id = tr.child_id WHERE tr.parent_id NOT IN` + inBinding + ` + UNION + SELECT tr.parent_id, tr.child_id FROM tags_relations tr INNER JOIN parents p ON p.parent_id = tr.child_id WHERE tr.parent_id NOT IN` + inBinding + ` +) +SELECT t.* FROM tags t INNER JOIN children c ON t.id = c.child_id +UNION +SELECT t.* FROM tags t INNER JOIN parents p ON t.id = p.parent_id +` + + var ret models.Tags + excludeArgs := []interface{}{tagID} + for _, excludeID := range excludeIDs { + excludeArgs = append(excludeArgs, excludeID) + } + args := []interface{}{tagID} + args = append(args, append(append(excludeArgs, excludeArgs...), excludeArgs...)...) + if err := qb.query(query, args, &ret); err != nil { + return nil, err + } + + return ret, nil +} diff --git a/pkg/tag/export.go b/pkg/tag/export.go index ba9d6da82..54c64990e 100644 --- a/pkg/tag/export.go +++ b/pkg/tag/export.go @@ -32,6 +32,13 @@ func ToJSON(reader models.TagReader, tag *models.Tag) (*jsonschema.Tag, error) { newTagJSON.Image = utils.GetBase64StringFromData(image) } + parents, err := reader.FindByChildTagID(tag.ID) + if err != nil { + return nil, fmt.Errorf("error getting parents: %s", err.Error()) + } + + newTagJSON.Parents = GetNames(parents) + return &newTagJSON, nil } diff --git a/pkg/tag/export_test.go b/pkg/tag/export_test.go index 16e85292d..36b7a0855 100644 --- a/pkg/tag/export_test.go +++ b/pkg/tag/export_test.go @@ -13,10 +13,12 @@ import ( ) const ( - tagID = 1 - noImageID = 2 - errImageID = 3 - errAliasID = 4 + tagID = 1 + noImageID = 2 + errImageID = 3 + errAliasID = 4 + withParentsID = 5 + errParentsID = 6 ) const tagName = "testTag" @@ -37,7 +39,7 @@ func createTag(id int) models.Tag { } } -func createJSONTag(aliases []string, image string) *jsonschema.Tag { +func createJSONTag(aliases []string, image string, parents []string) *jsonschema.Tag { return &jsonschema.Tag{ Name: tagName, Aliases: aliases, @@ -47,7 +49,8 @@ func createJSONTag(aliases []string, image string) *jsonschema.Tag { UpdatedAt: models.JSONTime{ Time: updateTime, }, - Image: image, + Image: image, + Parents: parents, } } @@ -63,12 +66,12 @@ func initTestTable() { scenarios = []testScenario{ { createTag(tagID), - createJSONTag([]string{"alias"}, "PHN2ZwogICB4bWxuczpkYz0iaHR0cDovL3B1cmwub3JnL2RjL2VsZW1lbnRzLzEuMS8iCiAgIHhtbG5zOmNjPSJodHRwOi8vY3JlYXRpdmVjb21tb25zLm9yZy9ucyMiCiAgIHhtbG5zOnJkZj0iaHR0cDovL3d3dy53My5vcmcvMTk5OS8wMi8yMi1yZGYtc3ludGF4LW5zIyIKICAgeG1sbnM6c3ZnPSJodHRwOi8vd3d3LnczLm9yZy8yMDAwL3N2ZyIKICAgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIgogICB4bWxuczpzb2RpcG9kaT0iaHR0cDovL3NvZGlwb2RpLnNvdXJjZWZvcmdlLm5ldC9EVEQvc29kaXBvZGktMC5kdGQiCiAgIHhtbG5zOmlua3NjYXBlPSJodHRwOi8vd3d3Lmlua3NjYXBlLm9yZy9uYW1lc3BhY2VzL2lua3NjYXBlIgogICB3aWR0aD0iMjAwIgogICBoZWlnaHQ9IjIwMCIKICAgaWQ9InN2ZzIiCiAgIHZlcnNpb249IjEuMSIKICAgaW5rc2NhcGU6dmVyc2lvbj0iMC40OC40IHI5OTM5IgogICBzb2RpcG9kaTpkb2NuYW1lPSJ0YWcuc3ZnIj4KICA8ZGVmcwogICAgIGlkPSJkZWZzNCIgLz4KICA8c29kaXBvZGk6bmFtZWR2aWV3CiAgICAgaWQ9ImJhc2UiCiAgICAgcGFnZWNvbG9yPSIjMDAwMDAwIgogICAgIGJvcmRlcmNvbG9yPSIjNjY2NjY2IgogICAgIGJvcmRlcm9wYWNpdHk9IjEuMCIKICAgICBpbmtzY2FwZTpwYWdlb3BhY2l0eT0iMSIKICAgICBpbmtzY2FwZTpwYWdlc2hhZG93PSIyIgogICAgIGlua3NjYXBlOnpvb209IjEiCiAgICAgaW5rc2NhcGU6Y3g9IjE4MS43Nzc3MSIKICAgICBpbmtzY2FwZTpjeT0iMjc5LjcyMzc2IgogICAgIGlua3NjYXBlOmRvY3VtZW50LXVuaXRzPSJweCIKICAgICBpbmtzY2FwZTpjdXJyZW50LWxheWVyPSJsYXllcjEiCiAgICAgc2hvd2dyaWQ9ImZhbHNlIgogICAgIGZpdC1tYXJnaW4tdG9wPSIwIgogICAgIGZpdC1tYXJnaW4tbGVmdD0iMCIKICAgICBmaXQtbWFyZ2luLXJpZ2h0PSIwIgogICAgIGZpdC1tYXJnaW4tYm90dG9tPSIwIgogICAgIGlua3NjYXBlOndpbmRvdy13aWR0aD0iMTkyMCIKICAgICBpbmtzY2FwZTp3aW5kb3ctaGVpZ2h0PSIxMDE3IgogICAgIGlua3NjYXBlOndpbmRvdy14PSItOCIKICAgICBpbmtzY2FwZTp3aW5kb3cteT0iLTgiCiAgICAgaW5rc2NhcGU6d2luZG93LW1heGltaXplZD0iMSIgLz4KICA8bWV0YWRhdGEKICAgICBpZD0ibWV0YWRhdGE3Ij4KICAgIDxyZGY6UkRGPgogICAgICA8Y2M6V29yawogICAgICAgICByZGY6YWJvdXQ9IiI+CiAgICAgICAgPGRjOmZvcm1hdD5pbWFnZS9zdmcreG1sPC9kYzpmb3JtYXQ+CiAgICAgICAgPGRjOnR5cGUKICAgICAgICAgICByZGY6cmVzb3VyY2U9Imh0dHA6Ly9wdXJsLm9yZy9kYy9kY21pdHlwZS9TdGlsbEltYWdlIiAvPgogICAgICAgIDxkYzp0aXRsZT48L2RjOnRpdGxlPgogICAgICA8L2NjOldvcms+CiAgICA8L3JkZjpSREY+CiAgPC9tZXRhZGF0YT4KICA8ZwogICAgIGlua3NjYXBlOmxhYmVsPSJMYXllciAxIgogICAgIGlua3NjYXBlOmdyb3VwbW9kZT0ibGF5ZXIiCiAgICAgaWQ9ImxheWVyMSIKICAgICB0cmFuc2Zvcm09InRyYW5zbGF0ZSgtMTU3Ljg0MzU4LC01MjQuNjk1MjIpIj4KICAgIDxwYXRoCiAgICAgICBpZD0icGF0aDI5ODciCiAgICAgICBkPSJtIDIyOS45NDMxNCw2NjkuMjY1NDkgLTM2LjA4NDY2LC0zNi4wODQ2NiBjIC00LjY4NjUzLC00LjY4NjUzIC00LjY4NjUzLC0xMi4yODQ2OCAwLC0xNi45NzEyMSBsIDM2LjA4NDY2LC0zNi4wODQ2NyBhIDEyLjAwMDQ1MywxMi4wMDA0NTMgMCAwIDEgOC40ODU2LC0zLjUxNDggbCA3NC45MTQ0MywwIGMgNi42Mjc2MSwwIDEyLjAwMDQxLDUuMzcyOCAxMi4wMDA0MSwxMi4wMDA0MSBsIDAsNzIuMTY5MzMgYyAwLDYuNjI3NjEgLTUuMzcyOCwxMi4wMDA0MSAtMTIuMDAwNDEsMTIuMDAwNDEgbCAtNzQuOTE0NDMsMCBhIDEyLjAwMDQ1MywxMi4wMDA0NTMgMCAwIDEgLTguNDg1NiwtMy41MTQ4MSB6IG0gLTEzLjQ1NjM5LC01My4wNTU4NyBjIC00LjY4NjUzLDQuNjg2NTMgLTQuNjg2NTMsMTIuMjg0NjggMCwxNi45NzEyMSA0LjY4NjUyLDQuNjg2NTIgMTIuMjg0NjcsNC42ODY1MiAxNi45NzEyLDAgNC42ODY1MywtNC42ODY1MyA0LjY4NjUzLC0xMi4yODQ2OCAwLC0xNi45NzEyMSAtNC42ODY1MywtNC42ODY1MiAtMTIuMjg0NjgsLTQuNjg2NTIgLTE2Ljk3MTIsMCB6IgogICAgICAgaW5rc2NhcGU6Y29ubmVjdG9yLWN1cnZhdHVyZT0iMCIKICAgICAgIHN0eWxlPSJmaWxsOiNmZmZmZmY7ZmlsbC1vcGFjaXR5OjEiIC8+CiAgPC9nPgo8L3N2Zz4="), + createJSONTag([]string{"alias"}, image, nil), false, }, { createTag(noImageID), - createJSONTag(nil, ""), + createJSONTag(nil, "", nil), false, }, { @@ -81,6 +84,16 @@ func initTestTable() { nil, true, }, + { + createTag(withParentsID), + createJSONTag(nil, image, []string{"parent"}), + false, + }, + { + createTag(errParentsID), + nil, + true, + }, } } @@ -91,15 +104,25 @@ func TestToJSON(t *testing.T) { imageErr := errors.New("error getting image") aliasErr := errors.New("error getting aliases") + parentsErr := errors.New("error getting parents") mockTagReader.On("GetAliases", tagID).Return([]string{"alias"}, nil).Once() mockTagReader.On("GetAliases", noImageID).Return(nil, nil).Once() mockTagReader.On("GetAliases", errImageID).Return(nil, nil).Once() mockTagReader.On("GetAliases", errAliasID).Return(nil, aliasErr).Once() + mockTagReader.On("GetAliases", withParentsID).Return(nil, nil).Once() + mockTagReader.On("GetAliases", errParentsID).Return(nil, nil).Once() - mockTagReader.On("GetImage", tagID).Return(models.DefaultTagImage, nil).Once() + mockTagReader.On("GetImage", tagID).Return(imageBytes, nil).Once() mockTagReader.On("GetImage", noImageID).Return(nil, nil).Once() mockTagReader.On("GetImage", errImageID).Return(nil, imageErr).Once() + mockTagReader.On("GetImage", withParentsID).Return(imageBytes, nil).Once() + mockTagReader.On("GetImage", errParentsID).Return(nil, nil).Once() + + mockTagReader.On("FindByChildTagID", tagID).Return(nil, nil).Once() + mockTagReader.On("FindByChildTagID", noImageID).Return(nil, nil).Once() + mockTagReader.On("FindByChildTagID", withParentsID).Return([]*models.Tag{{Name: "parent"}}, nil).Once() + mockTagReader.On("FindByChildTagID", errParentsID).Return(nil, parentsErr).Once() for i, s := range scenarios { tag := s.tag diff --git a/pkg/tag/import.go b/pkg/tag/import.go index 54de4bd0e..8fe0410d2 100644 --- a/pkg/tag/import.go +++ b/pkg/tag/import.go @@ -8,9 +8,22 @@ import ( "github.com/stashapp/stash/pkg/utils" ) +type ParentTagNotExistError struct { + missingParent string +} + +func (e ParentTagNotExistError) Error() string { + return fmt.Sprintf("parent tag <%s> does not exist", e.missingParent) +} + +func (e ParentTagNotExistError) MissingParent() string { + return e.missingParent +} + type Importer struct { - ReaderWriter models.TagReaderWriter - Input jsonschema.Tag + ReaderWriter models.TagReaderWriter + Input jsonschema.Tag + MissingRefBehaviour models.ImportMissingRefEnum tag models.Tag imageData []byte @@ -45,6 +58,15 @@ func (i *Importer) PostImport(id int) error { return fmt.Errorf("error setting tag aliases: %s", err.Error()) } + parents, err := i.getParents() + if err != nil { + return err + } + + if err := i.ReaderWriter.UpdateParentTags(id, parents); err != nil { + return fmt.Errorf("error setting parents: %s", err.Error()) + } + return nil } @@ -87,3 +109,46 @@ func (i *Importer) Update(id int) error { return nil } + +func (i *Importer) getParents() ([]int, error) { + var parents []int + for _, parent := range i.Input.Parents { + tag, err := i.ReaderWriter.FindByName(parent, false) + if err != nil { + return nil, fmt.Errorf("error finding parent by name: %s", err.Error()) + } + + if tag == nil { + if i.MissingRefBehaviour == models.ImportMissingRefEnumFail { + return nil, ParentTagNotExistError{missingParent: parent} + } + + if i.MissingRefBehaviour == models.ImportMissingRefEnumIgnore { + continue + } + + if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate { + parentID, err := i.createParent(parent) + if err != nil { + return nil, err + } + parents = append(parents, parentID) + } + } else { + parents = append(parents, tag.ID) + } + } + + return parents, nil +} + +func (i *Importer) createParent(name string) (int, error) { + newTag := *models.NewTag(name) + + created, err := i.ReaderWriter.Create(newTag) + if err != nil { + return 0, err + } + + return created.ID, nil +} diff --git a/pkg/tag/import_test.go b/pkg/tag/import_test.go index ea29e47c3..c2f29d8e5 100644 --- a/pkg/tag/import_test.go +++ b/pkg/tag/import_test.go @@ -8,6 +8,7 @@ import ( "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models/mocks" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" ) const image = "aW1hZ2VCeXRlcw==" @@ -64,13 +65,25 @@ func TestImporterPostImport(t *testing.T) { updateTagImageErr := errors.New("UpdateImage error") updateTagAliasErr := errors.New("UpdateAlias error") + updateTagParentsErr := errors.New("UpdateParentTags error") readerWriter.On("UpdateAliases", tagID, i.Input.Aliases).Return(nil).Once() readerWriter.On("UpdateAliases", errAliasID, i.Input.Aliases).Return(updateTagAliasErr).Once() + readerWriter.On("UpdateAliases", withParentsID, i.Input.Aliases).Return(nil).Once() + readerWriter.On("UpdateAliases", errParentsID, i.Input.Aliases).Return(nil).Once() readerWriter.On("UpdateImage", tagID, imageBytes).Return(nil).Once() readerWriter.On("UpdateImage", errAliasID, imageBytes).Return(nil).Once() readerWriter.On("UpdateImage", errImageID, imageBytes).Return(updateTagImageErr).Once() + readerWriter.On("UpdateImage", withParentsID, imageBytes).Return(nil).Once() + readerWriter.On("UpdateImage", errParentsID, imageBytes).Return(nil).Once() + + var parentTags []int + readerWriter.On("UpdateParentTags", tagID, parentTags).Return(nil).Once() + readerWriter.On("UpdateParentTags", withParentsID, []int{100}).Return(nil).Once() + readerWriter.On("UpdateParentTags", errParentsID, []int{100}).Return(updateTagParentsErr).Once() + + readerWriter.On("FindByName", "Parent", false).Return(&models.Tag{ID: 100}, nil) err := i.PostImport(tagID) assert.Nil(t, err) @@ -81,6 +94,106 @@ func TestImporterPostImport(t *testing.T) { err = i.PostImport(errAliasID) assert.NotNil(t, err) + i.Input.Parents = []string{"Parent"} + err = i.PostImport(withParentsID) + assert.Nil(t, err) + + err = i.PostImport(errParentsID) + assert.NotNil(t, err) + + readerWriter.AssertExpectations(t) +} + +func TestImporterPostImportParentMissing(t *testing.T) { + readerWriter := &mocks.TagReaderWriter{} + + i := Importer{ + ReaderWriter: readerWriter, + Input: jsonschema.Tag{}, + imageData: imageBytes, + } + + createID := 1 + createErrorID := 2 + createFindErrorID := 3 + createFoundID := 4 + failID := 5 + failFindErrorID := 6 + failFoundID := 7 + ignoreID := 8 + ignoreFindErrorID := 9 + ignoreFoundID := 10 + + findError := errors.New("failed finding parent") + + var emptyParents []int + + readerWriter.On("UpdateImage", mock.Anything, mock.Anything).Return(nil) + readerWriter.On("UpdateAliases", mock.Anything, mock.Anything).Return(nil) + + readerWriter.On("FindByName", "Create", false).Return(nil, nil).Once() + readerWriter.On("FindByName", "CreateError", false).Return(nil, nil).Once() + readerWriter.On("FindByName", "CreateFindError", false).Return(nil, findError).Once() + readerWriter.On("FindByName", "CreateFound", false).Return(&models.Tag{ID: 101}, nil).Once() + readerWriter.On("FindByName", "Fail", false).Return(nil, nil).Once() + readerWriter.On("FindByName", "FailFindError", false).Return(nil, findError) + readerWriter.On("FindByName", "FailFound", false).Return(&models.Tag{ID: 102}, nil).Once() + readerWriter.On("FindByName", "Ignore", false).Return(nil, nil).Once() + readerWriter.On("FindByName", "IgnoreFindError", false).Return(nil, findError) + readerWriter.On("FindByName", "IgnoreFound", false).Return(&models.Tag{ID: 103}, nil).Once() + + readerWriter.On("UpdateParentTags", createID, []int{100}).Return(nil).Once() + readerWriter.On("UpdateParentTags", createFoundID, []int{101}).Return(nil).Once() + readerWriter.On("UpdateParentTags", failFoundID, []int{102}).Return(nil).Once() + readerWriter.On("UpdateParentTags", ignoreID, emptyParents).Return(nil).Once() + readerWriter.On("UpdateParentTags", ignoreFoundID, []int{103}).Return(nil).Once() + + readerWriter.On("Create", mock.MatchedBy(func(t models.Tag) bool { return t.Name == "Create" })).Return(&models.Tag{ID: 100}, nil).Once() + readerWriter.On("Create", mock.MatchedBy(func(t models.Tag) bool { return t.Name == "CreateError" })).Return(nil, errors.New("failed creating parent")).Once() + + i.MissingRefBehaviour = models.ImportMissingRefEnumCreate + i.Input.Parents = []string{"Create"} + err := i.PostImport(createID) + assert.Nil(t, err) + + i.Input.Parents = []string{"CreateError"} + err = i.PostImport(createErrorID) + assert.NotNil(t, err) + + i.Input.Parents = []string{"CreateFindError"} + err = i.PostImport(createFindErrorID) + assert.NotNil(t, err) + + i.Input.Parents = []string{"CreateFound"} + err = i.PostImport(createFoundID) + assert.Nil(t, err) + + i.MissingRefBehaviour = models.ImportMissingRefEnumFail + i.Input.Parents = []string{"Fail"} + err = i.PostImport(failID) + assert.NotNil(t, err) + + i.Input.Parents = []string{"FailFindError"} + err = i.PostImport(failFindErrorID) + assert.NotNil(t, err) + + i.Input.Parents = []string{"FailFound"} + err = i.PostImport(failFoundID) + assert.Nil(t, err) + + i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore + i.Input.Parents = []string{"Ignore"} + err = i.PostImport(ignoreID) + assert.Nil(t, err) + + i.Input.Parents = []string{"IgnoreFindError"} + err = i.PostImport(ignoreFindErrorID) + assert.NotNil(t, err) + + i.Input.Parents = []string{"IgnoreFound"} + err = i.PostImport(ignoreFoundID) + assert.Nil(t, err) + readerWriter.AssertExpectations(t) } diff --git a/pkg/tag/update.go b/pkg/tag/update.go index 4f5b9b18b..d4a03f8f5 100644 --- a/pkg/tag/update.go +++ b/pkg/tag/update.go @@ -2,7 +2,6 @@ package tag import ( "fmt" - "github.com/stashapp/stash/pkg/models" ) @@ -23,6 +22,20 @@ func (e *NameUsedByAliasError) Error() string { return fmt.Sprintf("name '%s' is used as alias for '%s'", e.Name, e.OtherTag) } +type InvalidTagHierarchyError struct { + Direction string + InvalidTag string + ApplyingTag string +} + +func (e *InvalidTagHierarchyError) Error() string { + if e.InvalidTag == e.ApplyingTag { + return fmt.Sprintf("Cannot apply tag \"%s\" as it already is a %s", e.InvalidTag, e.Direction) + } else { + return fmt.Sprintf("Cannot apply tag \"%s\" as it is linked to \"%s\" which already is a %s", e.ApplyingTag, e.InvalidTag, e.Direction) + } +} + // EnsureTagNameUnique returns an error if the tag name provided // is used as a name or alias of another existing tag. func EnsureTagNameUnique(id int, name string, qb models.TagReader) error { @@ -63,3 +76,150 @@ func EnsureAliasesUnique(id int, aliases []string, qb models.TagReader) error { return nil } + +func EnsureUniqueHierarchy(id int, parentIDs, childIDs []int, qb models.TagReader) error { + allAncestors := make(map[int]*models.Tag) + allDescendants := make(map[int]*models.Tag) + excludeIDs := []int{id} + + validateParent := func(testID, applyingID int) error { + if parentTag, exists := allAncestors[testID]; exists { + applyingTag, err := qb.Find(applyingID) + + if err != nil { + return nil + } + + return &InvalidTagHierarchyError{ + Direction: "parent", + InvalidTag: parentTag.Name, + ApplyingTag: applyingTag.Name, + } + } + + return nil + } + + validateChild := func(testID, applyingID int) error { + if childTag, exists := allDescendants[testID]; exists { + applyingTag, err := qb.Find(applyingID) + + if err != nil { + return nil + } + + return &InvalidTagHierarchyError{ + Direction: "child", + InvalidTag: childTag.Name, + ApplyingTag: applyingTag.Name, + } + } + + return validateParent(testID, applyingID) + } + + if parentIDs == nil { + parentTags, err := qb.FindByChildTagID(id) + if err != nil { + return err + } + + for _, parentTag := range parentTags { + parentIDs = append(parentIDs, parentTag.ID) + } + } + + if childIDs == nil { + childTags, err := qb.FindByParentTagID(id) + if err != nil { + return err + } + + for _, childTag := range childTags { + childIDs = append(childIDs, childTag.ID) + } + } + + for _, parentID := range parentIDs { + parentsAncestors, err := qb.FindAllAncestors(parentID, excludeIDs) + if err != nil { + return err + } + + for _, ancestorTag := range parentsAncestors { + if err := validateParent(ancestorTag.ID, parentID); err != nil { + return err + } + + allAncestors[ancestorTag.ID] = ancestorTag + } + } + + for _, childID := range childIDs { + childsDescendants, err := qb.FindAllDescendants(childID, excludeIDs) + if err != nil { + return err + } + + for _, descendentTag := range childsDescendants { + if err := validateChild(descendentTag.ID, childID); err != nil { + return err + } + + allDescendants[descendentTag.ID] = descendentTag + } + } + + return nil +} + +func MergeHierarchy(destination int, sources []int, qb models.TagReader) ([]int, []int, error) { + var mergedParents, mergedChildren []int + allIds := append([]int{destination}, sources...) + + addTo := func(mergedItems []int, tags []*models.Tag) []int { + Tags: + for _, tag := range tags { + // Ignore tags which are already set + for _, existingItem := range mergedItems { + if tag.ID == existingItem { + continue Tags + } + } + + // Ignore tags which are being merged, as these are rolled up anyway (if A is merged into B any direct link between them can be ignored) + for _, id := range allIds { + if tag.ID == id { + continue Tags + } + } + + mergedItems = append(mergedItems, tag.ID) + } + + return mergedItems + } + + for _, id := range allIds { + parents, err := qb.FindByChildTagID(id) + if err != nil { + return nil, nil, err + } + + mergedParents = addTo(mergedParents, parents) + + children, err := qb.FindByParentTagID(id) + if err != nil { + return nil, nil, err + } + + mergedChildren = addTo(mergedChildren, children) + } + + err := EnsureUniqueHierarchy(destination, mergedParents, mergedChildren, qb) + if err != nil { + return nil, nil, err + } + + return mergedParents, mergedChildren, nil +} diff --git a/pkg/tag/update_test.go b/pkg/tag/update_test.go new file mode 100644 index 000000000..5a9cddf9d --- /dev/null +++ b/pkg/tag/update_test.go @@ -0,0 +1,302 @@ +package tag + +import ( + "fmt" + "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/models/mocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "testing" +) + +var testUniqueHierarchyTags = map[int]*models.Tag{ + 1: { + ID: 1, + Name: "one", + }, + 2: { + ID: 2, + Name: "two", + }, + 3: { + ID: 3, + Name: "three", + }, + 4: { + ID: 4, + Name: "four", + }, +} + +type testUniqueHierarchyCase struct { + id int + parents []*models.Tag + children []*models.Tag + + onFindAllAncestors map[int][]*models.Tag + onFindAllDescendants map[int][]*models.Tag + + expectedError string +} + +var testUniqueHierarchyCases = []testUniqueHierarchyCase{ + { + id: 1, + parents: []*models.Tag{}, + children: []*models.Tag{}, + onFindAllAncestors: map[int][]*models.Tag{ + 1: {}, + }, + onFindAllDescendants: map[int][]*models.Tag{ + 1: {}, + }, + expectedError: "", + }, + { + id: 1, + parents: []*models.Tag{testUniqueHierarchyTags[2]}, + children: []*models.Tag{testUniqueHierarchyTags[3]}, + onFindAllAncestors: map[int][]*models.Tag{ + 2: {testUniqueHierarchyTags[2]}, + }, + onFindAllDescendants: map[int][]*models.Tag{ + 3: {testUniqueHierarchyTags[3]}, + }, + expectedError: "", + }, + { + id: 2, + parents: []*models.Tag{testUniqueHierarchyTags[3]}, + children: make([]*models.Tag, 0), + onFindAllAncestors: map[int][]*models.Tag{ + 3: {testUniqueHierarchyTags[3]}, + }, + onFindAllDescendants: map[int][]*models.Tag{ + 2: {testUniqueHierarchyTags[2]}, + }, + expectedError: "", + }, + { + id: 2, + parents: []*models.Tag{ + testUniqueHierarchyTags[3], + testUniqueHierarchyTags[4], + }, + children: []*models.Tag{}, + onFindAllAncestors: map[int][]*models.Tag{ + 3: {testUniqueHierarchyTags[3], testUniqueHierarchyTags[4]}, + 4: {testUniqueHierarchyTags[4]}, + }, + onFindAllDescendants: map[int][]*models.Tag{ + 2: {testUniqueHierarchyTags[2]}, + }, + expectedError: "Cannot apply tag \"four\" as it already is a parent", + }, + { + id: 2, + parents: []*models.Tag{}, + children: []*models.Tag{testUniqueHierarchyTags[3]}, + onFindAllAncestors: map[int][]*models.Tag{ + 2: {testUniqueHierarchyTags[2]}, + }, + onFindAllDescendants: map[int][]*models.Tag{ + 3: {testUniqueHierarchyTags[3]}, + }, + expectedError: "", + }, + { + id: 2, + parents: []*models.Tag{}, + children: []*models.Tag{ + testUniqueHierarchyTags[3], + testUniqueHierarchyTags[4], + }, + onFindAllAncestors: map[int][]*models.Tag{ + 2: {testUniqueHierarchyTags[2]}, + }, + onFindAllDescendants: map[int][]*models.Tag{ + 3: {testUniqueHierarchyTags[3], testUniqueHierarchyTags[4]}, + 4: {testUniqueHierarchyTags[4]}, + }, + expectedError: "Cannot apply tag \"four\" as it already is a child", + }, + { + id: 1, + parents: []*models.Tag{testUniqueHierarchyTags[2]}, + children: []*models.Tag{testUniqueHierarchyTags[3]}, + onFindAllAncestors: map[int][]*models.Tag{ + 2: {testUniqueHierarchyTags[2], testUniqueHierarchyTags[3]}, + }, + onFindAllDescendants: map[int][]*models.Tag{ + 3: {testUniqueHierarchyTags[3]}, + }, + expectedError: "Cannot apply tag \"three\" as it already is a parent", + }, + { + id: 1, + parents: []*models.Tag{testUniqueHierarchyTags[2]}, + children: []*models.Tag{testUniqueHierarchyTags[3]}, + onFindAllAncestors: map[int][]*models.Tag{ + 2: {testUniqueHierarchyTags[2]}, + }, + onFindAllDescendants: map[int][]*models.Tag{ + 3: {testUniqueHierarchyTags[3], testUniqueHierarchyTags[2]}, + }, + expectedError: "Cannot apply tag \"three\" as it is linked to \"two\" which already is a parent", + }, + { + id: 1, + parents: []*models.Tag{testUniqueHierarchyTags[3]}, + children: []*models.Tag{testUniqueHierarchyTags[3]}, + onFindAllAncestors: map[int][]*models.Tag{ + 3: {testUniqueHierarchyTags[3]}, + }, + onFindAllDescendants: map[int][]*models.Tag{ + 3: {testUniqueHierarchyTags[3]}, + }, + expectedError: "Cannot apply tag \"three\" as it already is a parent", + }, + { + id: 1, + parents: []*models.Tag{ + testUniqueHierarchyTags[2], + }, + children: []*models.Tag{ + testUniqueHierarchyTags[3], + }, + onFindAllAncestors: map[int][]*models.Tag{ + 2: {testUniqueHierarchyTags[2]}, + }, + onFindAllDescendants: map[int][]*models.Tag{ + 3: {testUniqueHierarchyTags[3], testUniqueHierarchyTags[2]}, + }, + expectedError: "Cannot apply tag \"three\" as it is linked to \"two\" which already is a parent", + }, + { + id: 1, + parents: []*models.Tag{testUniqueHierarchyTags[2]}, + children: []*models.Tag{testUniqueHierarchyTags[2]}, + onFindAllAncestors: map[int][]*models.Tag{ + 2: {testUniqueHierarchyTags[2]}, + }, + onFindAllDescendants: map[int][]*models.Tag{ + 2: {testUniqueHierarchyTags[2]}, + }, + expectedError: "Cannot apply tag \"two\" as it already is a parent", + }, + { + id: 2, + parents: []*models.Tag{testUniqueHierarchyTags[1]}, + children: []*models.Tag{testUniqueHierarchyTags[3]}, + onFindAllAncestors: map[int][]*models.Tag{ + 1: {testUniqueHierarchyTags[1]}, + }, + onFindAllDescendants: map[int][]*models.Tag{ + 3: {testUniqueHierarchyTags[3], testUniqueHierarchyTags[1]}, + }, + expectedError: "Cannot apply tag \"three\" as it is linked to \"one\" which already is a parent", + }, +} + +func TestEnsureUniqueHierarchy(t *testing.T) { + for _, tc := range testUniqueHierarchyCases { + testEnsureUniqueHierarchy(t, tc, false, false) + testEnsureUniqueHierarchy(t, tc, true, false) + testEnsureUniqueHierarchy(t, tc, false, true) + testEnsureUniqueHierarchy(t, tc, true, true) + } +} + +func testEnsureUniqueHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents, queryChildren bool) { + mockTagReader := &mocks.TagReaderWriter{} + + var parentIDs, childIDs []int + var find map[int]*models.Tag + find = make(map[int]*models.Tag) + if tc.parents != nil { + parentIDs = make([]int, 0) + for _, parent := range tc.parents { + if parent.ID != tc.id { + find[parent.ID] = parent + parentIDs = append(parentIDs, parent.ID) + } + } + } + + if tc.children != nil { + childIDs = make([]int, 0) + for _, child := range tc.children { + if child.ID != tc.id { + find[child.ID] = child + childIDs = append(childIDs, child.ID) + } + } + } + + if queryParents { + parentIDs = nil + mockTagReader.On("FindByChildTagID", tc.id).Return(tc.parents, nil).Once() + } + + if queryChildren { + childIDs = nil + mockTagReader.On("FindByParentTagID", tc.id).Return(tc.children, nil).Once() + } + + mockTagReader.On("Find", mock.AnythingOfType("int")).Return(func(tagID int) *models.Tag { + for id, tag := range find { + if id == tagID { + return tag + } + } + return nil + }, func(tagID int) error { + return nil + }).Maybe() + + mockTagReader.On("FindAllAncestors", mock.AnythingOfType("int"), []int{tc.id}).Return(func(tagID int, excludeIDs []int) []*models.Tag { + for id, tags := range tc.onFindAllAncestors { + if id == tagID { + return tags + } + } + return nil + }, func(tagID int, excludeIDs []int) error { + for id, _ := range tc.onFindAllAncestors { + if id == tagID { + return nil + } + } + return fmt.Errorf("undefined ancestors for: %d", tagID) + }).Maybe() + + mockTagReader.On("FindAllDescendants", mock.AnythingOfType("int"), []int{tc.id}).Return(func(tagID int, excludeIDs []int) []*models.Tag { + for id, tags := range tc.onFindAllDescendants { + if id == tagID { + return tags + } + } + return nil + }, func(tagID int, excludeIDs []int) error { + for id, _ := range tc.onFindAllDescendants { + if id == tagID { + return nil + } + } + return fmt.Errorf("undefined descendants for: %d", tagID) + }).Maybe() + + res := EnsureUniqueHierarchy(tc.id, parentIDs, childIDs, mockTagReader) + + assert := assert.New(t) + + if tc.expectedError != "" { + if assert.NotNil(res) { + assert.Equal(tc.expectedError, res.Error()) + } + } else { + assert.Nil(res) + } + + mockTagReader.AssertExpectations(t) +} diff --git a/ui/v2.5/src/components/Changelog/versions/v0100.md b/ui/v2.5/src/components/Changelog/versions/v0100.md index 5cb28b991..6fec97614 100644 --- a/ui/v2.5/src/components/Changelog/versions/v0100.md +++ b/ui/v2.5/src/components/Changelog/versions/v0100.md @@ -1,4 +1,6 @@ ### ✨ New Features +* Added support for Tag hierarchies. ([#1519](https://github.com/stashapp/stash/pull/1519)) +* Added native support for Apple Silicon / M1 Macs. ([#1646] https://github.com/stashapp/stash/pull/1646) * Added Movies to Scene bulk edit dialog. ([#1676](https://github.com/stashapp/stash/pull/1676)) * Added Movies tab to Studio and Performer pages. ([#1675](https://github.com/stashapp/stash/pull/1675)) * Support filtering Movies by Performers. ([#1675](https://github.com/stashapp/stash/pull/1675)) diff --git a/ui/v2.5/src/components/List/Filters/HierarchicalLabelValueFilter.tsx b/ui/v2.5/src/components/List/Filters/HierarchicalLabelValueFilter.tsx index 151a813c6..9105beefc 100644 --- a/ui/v2.5/src/components/List/Filters/HierarchicalLabelValueFilter.tsx +++ b/ui/v2.5/src/components/List/Filters/HierarchicalLabelValueFilter.tsx @@ -1,6 +1,6 @@ import React from "react"; import { Form } from "react-bootstrap"; -import { defineMessages, useIntl } from "react-intl"; +import { defineMessages, MessageDescriptor, useIntl } from "react-intl"; import { FilterSelect, ValidTypes } from "../../Shared"; import { Criterion } from "../../../models/list-filter/criteria/criterion"; import { IHierarchicalLabelValue } from "../../../models/list-filter/types"; @@ -49,6 +49,16 @@ export const HierarchicalLabelValueFilter: React.FC @@ -63,7 +73,7 @@ export const HierarchicalLabelValueFilter: React.FC onDepthChanged(criterion.value.depth !== 0 ? 0 : -1)} /> diff --git a/ui/v2.5/src/components/Shared/DeleteEntityDialog.tsx b/ui/v2.5/src/components/Shared/DeleteEntityDialog.tsx index 1ab0432f4..2ba92061c 100644 --- a/ui/v2.5/src/components/Shared/DeleteEntityDialog.tsx +++ b/ui/v2.5/src/components/Shared/DeleteEntityDialog.tsx @@ -20,6 +20,7 @@ interface IDeleteEntityDialogProps { singularEntity: string; pluralEntity: string; destroyMutation: DestroyMutation; + onDeleted?: () => void; } const messages = defineMessages({ @@ -43,6 +44,7 @@ const DeleteEntityDialog: React.FC = ({ singularEntity, pluralEntity, destroyMutation, + onDeleted, }) => { const intl = useIntl(); const Toast = useToast(); @@ -56,6 +58,9 @@ const DeleteEntityDialog: React.FC = ({ setIsDeleting(true); try { await deleteEntities(); + if (onDeleted) { + onDeleted(); + } Toast.success({ content: intl.formatMessage(messages.deleteToast, { count, diff --git a/ui/v2.5/src/components/Studios/StudioDetails/Studio.tsx b/ui/v2.5/src/components/Studios/StudioDetails/Studio.tsx index 52323bc87..c0481ec2c 100644 --- a/ui/v2.5/src/components/Studios/StudioDetails/Studio.tsx +++ b/ui/v2.5/src/components/Studios/StudioDetails/Studio.tsx @@ -283,7 +283,7 @@ export const Studio: React.FC = () => { diff --git a/ui/v2.5/src/components/Tags/TagDetails/Tag.tsx b/ui/v2.5/src/components/Tags/TagDetails/Tag.tsx index 094dc8f91..fe4f47673 100644 --- a/ui/v2.5/src/components/Tags/TagDetails/Tag.tsx +++ b/ui/v2.5/src/components/Tags/TagDetails/Tag.tsx @@ -21,6 +21,7 @@ import { Icon, } from "src/components/Shared"; import { useToast } from "src/hooks"; +import { tagRelationHook } from "src/core/tags"; import { TagScenesPanel } from "./TagScenesPanel"; import { TagMarkersPanel } from "./TagMarkersPanel"; import { TagImagesPanel } from "./TagImagesPanel"; @@ -123,6 +124,10 @@ export const Tag: React.FC = () => { input: Partial ) { try { + const oldRelations = { + parents: tag?.parents ?? [], + children: tag?.children ?? [], + }; if (!isNew) { const result = await updateTag({ variables: { @@ -131,7 +136,12 @@ export const Tag: React.FC = () => { }); if (result.data?.tagUpdate) { setIsEditing(false); - return result.data.tagUpdate.id; + const updated = result.data.tagUpdate; + tagRelationHook(updated, oldRelations, { + parents: updated.parents, + children: updated.children, + }); + return updated.id; } } else { const result = await createTag({ @@ -141,7 +151,12 @@ export const Tag: React.FC = () => { }); if (result.data?.tagCreate?.id) { setIsEditing(false); - return result.data.tagCreate.id; + const created = result.data.tagCreate; + tagRelationHook(created, oldRelations, { + parents: created.parents, + children: created.children, + }); + return created.id; } } } catch (e) { @@ -161,7 +176,15 @@ export const Tag: React.FC = () => { async function onDelete() { try { + const oldRelations = { + parents: tag?.parents ?? [], + children: tag?.children ?? [], + }; await deleteTag(); + tagRelationHook(tag as GQL.TagDataFragment, oldRelations, { + parents: [], + children: [], + }); } catch (e) { Toast.error(e); } diff --git a/ui/v2.5/src/components/Tags/TagDetails/TagDetailsPanel.tsx b/ui/v2.5/src/components/Tags/TagDetails/TagDetailsPanel.tsx index 92d007ebb..c02570b85 100644 --- a/ui/v2.5/src/components/Tags/TagDetails/TagDetailsPanel.tsx +++ b/ui/v2.5/src/components/Tags/TagDetails/TagDetailsPanel.tsx @@ -1,6 +1,7 @@ import React from "react"; import { Badge } from "react-bootstrap"; import { FormattedMessage } from "react-intl"; +import { Link } from "react-router-dom"; import * as GQL from "src/core/generated-graphql"; interface ITagDetails { @@ -29,5 +30,53 @@ export const TagDetailsPanel: React.FC = ({ tag }) => { ); } - return <>{renderAliasesField()}; + function renderParentsField() { + if (!tag.parents?.length) { + return; + } + + return ( +
+
+ +
+
+ {tag.parents.map((p) => ( + + {p.name} + + ))} +
+
+ ); + } + + function renderChildrenField() { + if (!tag.children?.length) { + return; + } + + return ( +
+
+ +
+
+ {tag.children.map((c) => ( + + {c.name} + + ))} +
+
+ ); + } + + return ( + <> + {renderAliasesField()} + {renderParentsField()} + {renderChildrenField()} + + ); }; diff --git a/ui/v2.5/src/components/Tags/TagDetails/TagEditPanel.tsx b/ui/v2.5/src/components/Tags/TagDetails/TagEditPanel.tsx index 9eb5321c0..86b53558b 100644 --- a/ui/v2.5/src/components/Tags/TagDetails/TagEditPanel.tsx +++ b/ui/v2.5/src/components/Tags/TagDetails/TagEditPanel.tsx @@ -2,9 +2,9 @@ import React, { useEffect } from "react"; import { FormattedMessage, useIntl } from "react-intl"; import * as GQL from "src/core/generated-graphql"; import * as yup from "yup"; -import { DetailsEditNavbar } from "src/components/Shared"; +import { DetailsEditNavbar, TagSelect } from "src/components/Shared"; import { Form, Col, Row } from "react-bootstrap"; -import { ImageUtils } from "src/utils"; +import { FormUtils, ImageUtils } from "src/utils"; import { useFormik } from "formik"; import { Prompt, useHistory } from "react-router-dom"; import Mousetrap from "mousetrap"; @@ -51,11 +51,15 @@ export const TagEditPanel: React.FC = ({ }, message: "aliases must be unique", }), + parent_ids: yup.array(yup.string().required()).optional().nullable(), + child_ids: yup.array(yup.string().required()).optional().nullable(), }); const initialValues = { name: tag?.name, aliases: tag?.aliases, + parent_ids: (tag?.parents ?? []).map((t) => t.id), + child_ids: (tag?.children ?? []).map((t) => t.id), }; type InputValues = typeof initialValues; @@ -153,6 +157,60 @@ export const TagEditPanel: React.FC = ({ /> + + + {FormUtils.renderLabel({ + title: intl.formatMessage({ id: "parent_tags" }), + labelProps: { + column: true, + sm: 3, + xl: 12, + }, + })} + + + formik.setFieldValue( + "parent_ids", + items.map((item) => item.id) + ) + } + ids={formik.values.parent_ids} + excludeIds={(tag?.id ? [tag.id] : []).concat( + ...formik.values.child_ids + )} + creatable={false} + /> + + + + + {FormUtils.renderLabel({ + title: intl.formatMessage({ id: "sub_tags" }), + labelProps: { + column: true, + sm: 3, + xl: 12, + }, + })} + + + formik.setFieldValue( + "child_ids", + items.map((item) => item.id) + ) + } + ids={formik.values.child_ids} + excludeIds={(tag?.id ? [tag.id] : []).concat( + ...formik.values.parent_ids + )} + creatable={false} + /> + + = ({ tag }) => { ) { // add the tag if not present if ( - !tagCriterion.value.find((p) => { + !tagCriterion.value.items.find((p) => { return p.id === tag.id; }) ) { - tagCriterion.value.push(tagValue); + tagCriterion.value.items.push(tagValue); } tagCriterion.modifier = GQL.CriterionModifier.IncludesAll; } else { // overwrite tagCriterion = new TagsCriterion(TagsCriterionOption); - tagCriterion.value = [tagValue]; + tagCriterion.value = { + items: [tagValue], + depth: 0, + }; filter.criteria.push(tagCriterion); } diff --git a/ui/v2.5/src/components/Tags/TagList.tsx b/ui/v2.5/src/components/Tags/TagList.tsx index d96320e85..88e804a82 100644 --- a/ui/v2.5/src/components/Tags/TagList.tsx +++ b/ui/v2.5/src/components/Tags/TagList.tsx @@ -24,6 +24,7 @@ import { NavUtils } from "src/utils"; import { Icon, Modal, DeleteEntityDialog } from "src/components/Shared"; import { TagCard } from "./TagCard"; import { ExportDialog } from "../Shared/ExportDialog"; +import { tagRelationHook } from "../../core/tags"; interface ITagList { filterHook?: (filter: ListFilterModel) => ListFilterModel; @@ -138,6 +139,15 @@ export const TagList: React.FC = ({ filterHook }) => { singularEntity={intl.formatMessage({ id: "tag" })} pluralEntity={intl.formatMessage({ id: "tags" })} destroyMutation={useTagsDestroy} + onDeleted={() => { + selectedTags.forEach((t) => + tagRelationHook( + t, + { parents: t.parents ?? [], children: t.children ?? [] }, + { parents: [], children: [] } + ) + ); + }} /> ); @@ -175,7 +185,15 @@ export const TagList: React.FC = ({ filterHook }) => { async function onDelete() { try { + const oldRelations = { + parents: deletingTag?.parents ?? [], + children: deletingTag?.children ?? [], + }; await deleteTag(); + tagRelationHook(deletingTag as GQL.TagDataFragment, oldRelations, { + parents: [], + children: [], + }); Toast.success({ content: intl.formatMessage( { id: "toast.delete_past_tense" }, diff --git a/ui/v2.5/src/core/createClient.ts b/ui/v2.5/src/core/createClient.ts index ce976be77..ec4766f6c 100644 --- a/ui/v2.5/src/core/createClient.ts +++ b/ui/v2.5/src/core/createClient.ts @@ -68,6 +68,17 @@ const typePolicies: TypePolicies = { }, }, }, + + Tag: { + fields: { + parents: { + merge: false, + }, + children: { + merge: false, + }, + }, + }, }; export const getPlatformURL = (ws?: boolean) => { diff --git a/ui/v2.5/src/core/tags.ts b/ui/v2.5/src/core/tags.ts index 0c0afed61..d7365b8ac 100644 --- a/ui/v2.5/src/core/tags.ts +++ b/ui/v2.5/src/core/tags.ts @@ -1,4 +1,6 @@ +import { gql } from "@apollo/client"; import * as GQL from "src/core/generated-graphql"; +import { getClient } from "src/core/StashService"; import { TagsCriterion, TagsCriterionOption, @@ -20,21 +22,83 @@ export const tagFilterHook = (tag: GQL.TagDataFragment) => { ) { // add the tag if not present if ( - !tagCriterion.value.find((p) => { + !tagCriterion.value.items.find((p) => { return p.id === tag.id; }) ) { - tagCriterion.value.push(tagValue); + tagCriterion.value.items.push(tagValue); } tagCriterion.modifier = GQL.CriterionModifier.IncludesAll; } else { // overwrite tagCriterion = new TagsCriterion(TagsCriterionOption); - tagCriterion.value = [tagValue]; + tagCriterion.value = { + items: [tagValue], + depth: 0, + }; filter.criteria.push(tagCriterion); } return filter; }; }; + +interface ITagRelationTuple { + parents: GQL.SlimTagDataFragment[]; + children: GQL.SlimTagDataFragment[]; +} + +export const tagRelationHook = ( + tag: GQL.SlimTagDataFragment | GQL.TagDataFragment, + old: ITagRelationTuple, + updated: ITagRelationTuple +) => { + const { cache } = getClient(); + + const tagRef = cache.writeFragment({ + data: tag, + fragment: gql` + fragment Tag on Tag { + id + } + `, + }); + + function updater( + property: "parents" | "children", + oldTags: GQL.SlimTagDataFragment[], + updatedTags: GQL.SlimTagDataFragment[] + ) { + oldTags.forEach((o) => { + if (!updatedTags.some((u) => u.id === o.id)) { + cache.modify({ + id: cache.identify(o), + fields: { + [property](value, { readField }) { + return value.filter( + (t: GQL.SlimTagDataFragment) => readField("id", t) !== tag.id + ); + }, + }, + }); + } + }); + + updatedTags.forEach((u) => { + if (!oldTags.some((o) => o.id === u.id)) { + cache.modify({ + id: cache.identify(u), + fields: { + [property](value) { + return [...value, tagRef]; + }, + }, + }); + } + }); + } + + updater("children", old.parents, updated.parents); + updater("parents", old.children, updated.children); +}; diff --git a/ui/v2.5/src/locales/de-DE.json b/ui/v2.5/src/locales/de-DE.json index 32bd74623..312206a08 100644 --- a/ui/v2.5/src/locales/de-DE.json +++ b/ui/v2.5/src/locales/de-DE.json @@ -91,7 +91,7 @@ "birthdate": "Geburtsdatum", "bitrate": "Bitrate", "career_length": "Länge der Karriere", - "child_studios": "Tochterstudios", + "subsidiary_studios": "Tochterstudios", "component_tagger": { "config": { "active_instance": "Aktive stash-box Instanz:", @@ -501,7 +501,6 @@ "image_count": "Bilderanzahl", "images": "Bilder", "images-size": "Bildgröße", - "include_child_studios": "Tochterstudios einbeziehen", "instagram": "Instagram", "interactive": "Interaktiv", "isMissing": "Wird vermisst", diff --git a/ui/v2.5/src/locales/en-GB.json b/ui/v2.5/src/locales/en-GB.json index 060032cdc..0251ff1be 100644 --- a/ui/v2.5/src/locales/en-GB.json +++ b/ui/v2.5/src/locales/en-GB.json @@ -95,7 +95,8 @@ "birthdate": "Birthdate", "bitrate": "Bit Rate", "career_length": "Career Length", - "child_studios": "Child Studios", + "subsidiary_studios": "Subsidiary Studios", + "sub_tags": "Sub-Tags", "component_tagger": { "config": { "active_instance": "Active stash-box instance:", @@ -529,7 +530,8 @@ "image_count": "Image Count", "images": "Images", "images-size": "Images size", - "include_child_studios": "Include child studios", + "include_sub_studios": "Include subsidiary studios", + "include_sub_tags": "Include sub-tags", "instagram": "Instagram", "interactive": "Interactive", "isMissing": "Is Missing", @@ -571,6 +573,7 @@ "previous": "Previous" }, "parent_studios": "Parent Studios", + "parent_tags": "Parent Tags", "path": "Path", "performer": "Performer", "performer_count": "Performer Count", diff --git a/ui/v2.5/src/locales/pt-BR.json b/ui/v2.5/src/locales/pt-BR.json index 64ce8e204..75d3af9ea 100644 --- a/ui/v2.5/src/locales/pt-BR.json +++ b/ui/v2.5/src/locales/pt-BR.json @@ -91,7 +91,7 @@ "birthdate": "Data de nascimento", "bitrate": "Taxa de bits", "career_length": "Duração da carreira", - "child_studios": "Estúdios filhos", + "subsidiary_studios": "Estúdios filhos", "component_tagger": { "config": { "active_instance": "Ativar stash-box:", @@ -501,7 +501,6 @@ "image_count": "Contagem de imagem", "images": "Imagens", "images-size": "Tamanho das imagens", - "include_child_studios": "Incluem estúdios filho", "instagram": "Instagram", "interactive": "Interativo", "isMissing": "Está faltando", diff --git a/ui/v2.5/src/locales/zh-CN.json b/ui/v2.5/src/locales/zh-CN.json index 3e12dd8d9..4f39bdc1f 100644 --- a/ui/v2.5/src/locales/zh-CN.json +++ b/ui/v2.5/src/locales/zh-CN.json @@ -95,7 +95,7 @@ "birthdate": "出生日期", "bitrate": "比特率", "career_length": "活跃年代", - "child_studios": "子工作室", + "subsidiary_studios": "子工作室", "component_tagger": { "config": { "active_instance": "目前使用的 Stash-box:", @@ -529,7 +529,6 @@ "image_count": "图片数量", "images": "图片", "images-size": "图片大小", - "include_child_studios": "包含子工作室", "instagram": "Instagram", "interactive": "互动", "isMissing": "缺失", diff --git a/ui/v2.5/src/locales/zh-TW.json b/ui/v2.5/src/locales/zh-TW.json index 07f443c52..999ddbaa3 100644 --- a/ui/v2.5/src/locales/zh-TW.json +++ b/ui/v2.5/src/locales/zh-TW.json @@ -95,7 +95,7 @@ "birthdate": "出生日期", "bitrate": "位元率", "career_length": "活躍年代", - "child_studios": "子工作室", + "subsidiary_studios": "子工作室", "component_tagger": { "config": { "active_instance": "目前使用的 Stash-box:", @@ -614,7 +614,6 @@ "filter": "過濾", "filter_name": "過濾條件名稱", "detail": "詳情", - "include_child_studios": "包含子工作室", "criterion": { "greater_than": "大於", "less_than": "小於", diff --git a/ui/v2.5/src/models/list-filter/criteria/tags.ts b/ui/v2.5/src/models/list-filter/criteria/tags.ts index 04545440e..50646a52a 100644 --- a/ui/v2.5/src/models/list-filter/criteria/tags.ts +++ b/ui/v2.5/src/models/list-filter/criteria/tags.ts @@ -1,6 +1,9 @@ -import { ILabeledIdCriterion, ILabeledIdCriterionOption } from "./criterion"; +import { + IHierarchicalLabeledIdCriterion, + ILabeledIdCriterionOption, +} from "./criterion"; -export class TagsCriterion extends ILabeledIdCriterion {} +export class TagsCriterion extends IHierarchicalLabeledIdCriterion {} export const TagsCriterionOption = new ILabeledIdCriterionOption( "tags", diff --git a/ui/v2.5/src/utils/navigation.ts b/ui/v2.5/src/utils/navigation.ts index f4af524cb..929b76654 100644 --- a/ui/v2.5/src/utils/navigation.ts +++ b/ui/v2.5/src/utils/navigation.ts @@ -143,7 +143,10 @@ const makeTagScenesUrl = (tag: Partial) => { if (!tag.id) return "#"; const filter = new ListFilterModel(GQL.FilterMode.Scenes); const criterion = new TagsCriterion(TagsCriterionOption); - criterion.value = [{ id: tag.id, label: tag.name || `Tag ${tag.id}` }]; + criterion.value = { + items: [{ id: tag.id, label: tag.name || `Tag ${tag.id}` }], + depth: 0, + }; filter.criteria.push(criterion); return `/scenes?${filter.makeQueryParameters()}`; }; @@ -152,7 +155,10 @@ const makeTagPerformersUrl = (tag: Partial) => { if (!tag.id) return "#"; const filter = new ListFilterModel(GQL.FilterMode.Performers); const criterion = new TagsCriterion(TagsCriterionOption); - criterion.value = [{ id: tag.id, label: tag.name || `Tag ${tag.id}` }]; + criterion.value = { + items: [{ id: tag.id, label: tag.name || `Tag ${tag.id}` }], + depth: 0, + }; filter.criteria.push(criterion); return `/performers?${filter.makeQueryParameters()}`; }; @@ -161,7 +167,10 @@ const makeTagSceneMarkersUrl = (tag: Partial) => { if (!tag.id) return "#"; const filter = new ListFilterModel(GQL.FilterMode.SceneMarkers); const criterion = new TagsCriterion(TagsCriterionOption); - criterion.value = [{ id: tag.id, label: tag.name || `Tag ${tag.id}` }]; + criterion.value = { + items: [{ id: tag.id, label: tag.name || `Tag ${tag.id}` }], + depth: 0, + }; filter.criteria.push(criterion); return `/scenes/markers?${filter.makeQueryParameters()}`; }; @@ -170,7 +179,10 @@ const makeTagGalleriesUrl = (tag: Partial) => { if (!tag.id) return "#"; const filter = new ListFilterModel(GQL.FilterMode.Galleries); const criterion = new TagsCriterion(TagsCriterionOption); - criterion.value = [{ id: tag.id, label: tag.name || `Tag ${tag.id}` }]; + criterion.value = { + items: [{ id: tag.id, label: tag.name || `Tag ${tag.id}` }], + depth: 0, + }; filter.criteria.push(criterion); return `/galleries?${filter.makeQueryParameters()}`; }; @@ -179,7 +191,10 @@ const makeTagImagesUrl = (tag: Partial) => { if (!tag.id) return "#"; const filter = new ListFilterModel(GQL.FilterMode.Images); const criterion = new TagsCriterion(TagsCriterionOption); - criterion.value = [{ id: tag.id, label: tag.name || `Tag ${tag.id}` }]; + criterion.value = { + items: [{ id: tag.id, label: tag.name || `Tag ${tag.id}` }], + depth: 0, + }; filter.criteria.push(criterion); return `/images?${filter.makeQueryParameters()}`; };