diff --git a/graphql/documents/data/tag.graphql b/graphql/documents/data/tag.graphql index cf1e2c050..ab83adf73 100644 --- a/graphql/documents/data/tag.graphql +++ b/graphql/documents/data/tag.graphql @@ -8,4 +8,12 @@ fragment TagData on Tag { image_count gallery_count performer_count + + parents { + ...SlimTagData + } + + children { + ...SlimTagData + } } diff --git a/graphql/schema/types/filters.graphql b/graphql/schema/types/filters.graphql index 76640bb24..04a508f06 100644 --- a/graphql/schema/types/filters.graphql +++ b/graphql/schema/types/filters.graphql @@ -72,7 +72,7 @@ input PerformerFilterType { """Filter to only include performers missing this property""" is_missing: String """Filter to only include performers with these tags""" - tags: MultiCriterionInput + tags: HierarchicalMultiCriterionInput """Filter by tag count""" tag_count: IntCriterionInput """Filter by scene count""" @@ -99,11 +99,11 @@ input PerformerFilterType { input SceneMarkerFilterType { """Filter to only include scene markers with this tag""" - tag_id: ID + tag_id: ID @deprecated(reason: "use tags filter instead") """Filter to only include scene markers with these tags""" - tags: MultiCriterionInput + tags: HierarchicalMultiCriterionInput """Filter to only include scene markers attached to a scene with these tags""" - scene_tags: MultiCriterionInput + scene_tags: HierarchicalMultiCriterionInput """Filter to only include scene markers with these performers""" performers: MultiCriterionInput } @@ -143,11 +143,11 @@ input SceneFilterType { """Filter to only include scenes with this movie""" movies: MultiCriterionInput """Filter to only include scenes with these tags""" - tags: MultiCriterionInput + tags: HierarchicalMultiCriterionInput """Filter by tag count""" tag_count: IntCriterionInput """Filter to only include scenes with performers with these tags""" - performer_tags: MultiCriterionInput + performer_tags: HierarchicalMultiCriterionInput """Filter to only include scenes with these performers""" performers: MultiCriterionInput """Filter by performer count""" @@ -226,11 +226,11 @@ input GalleryFilterType { """Filter to only include galleries with this studio""" studios: HierarchicalMultiCriterionInput """Filter to only include galleries with these tags""" - tags: MultiCriterionInput + tags: HierarchicalMultiCriterionInput """Filter by tag count""" tag_count: IntCriterionInput """Filter to only include galleries with performers with these tags""" - performer_tags: MultiCriterionInput + performer_tags: HierarchicalMultiCriterionInput """Filter to only include galleries with these performers""" performers: MultiCriterionInput """Filter by performer count""" @@ -295,11 +295,11 @@ input ImageFilterType { """Filter to only include images with this studio""" studios: HierarchicalMultiCriterionInput """Filter to only include images with these tags""" - tags: MultiCriterionInput + tags: HierarchicalMultiCriterionInput """Filter by tag count""" tag_count: IntCriterionInput """Filter to only include images with performers with these tags""" - performer_tags: MultiCriterionInput + performer_tags: HierarchicalMultiCriterionInput """Filter to only include images with these performers""" performers: MultiCriterionInput """Filter by performer count""" diff --git a/graphql/schema/types/tag.graphql b/graphql/schema/types/tag.graphql index 46f8c25cd..fea9f40fe 100644 --- a/graphql/schema/types/tag.graphql +++ b/graphql/schema/types/tag.graphql @@ -11,6 +11,9 @@ type Tag { image_count: Int # Resolver gallery_count: Int # Resolver performer_count: Int + + parents: [Tag!]! + children: [Tag!]! } input TagCreateInput { @@ -19,6 +22,9 @@ input TagCreateInput { """This should be a URL or a base64 encoded data URL""" image: String + + parent_ids: [ID!] + child_ids: [ID!] } input TagUpdateInput { @@ -28,6 +34,9 @@ input TagUpdateInput { """This should be a URL or a base64 encoded data URL""" image: String + + parent_ids: [ID!] + child_ids: [ID!] } input TagDestroyInput { diff --git a/pkg/api/resolver_model_tag.go b/pkg/api/resolver_model_tag.go index a863cd233..3b90463ca 100644 --- a/pkg/api/resolver_model_tag.go +++ b/pkg/api/resolver_model_tag.go @@ -10,6 +10,28 @@ import ( "github.com/stashapp/stash/pkg/models" ) +func (r *tagResolver) Parents(ctx context.Context, obj *models.Tag) (ret []*models.Tag, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Tag().FindByChildTagID(obj.ID) + return err + }); err != nil { + return nil, err + } + + return ret, nil +} + +func (r *tagResolver) Children(ctx context.Context, obj *models.Tag) (ret []*models.Tag, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Tag().FindByParentTagID(obj.ID) + return err + }); err != nil { + return nil, err + } + + return ret, nil +} + func (r *tagResolver) Aliases(ctx context.Context, obj *models.Tag) (ret []string, err error) { if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { ret, err = repo.Tag().GetAliases(obj.ID) diff --git a/pkg/api/resolver_mutation_tag.go b/pkg/api/resolver_mutation_tag.go index 94292a7ba..d1dc230e4 100644 --- a/pkg/api/resolver_mutation_tag.go +++ b/pkg/api/resolver_mutation_tag.go @@ -75,6 +75,28 @@ func (r *mutationResolver) TagCreate(ctx context.Context, input models.TagCreate } } + if input.ParentIds != nil && len(input.ParentIds) > 0 { + ids, err := utils.StringSliceToIntSlice(input.ParentIds) + if err != nil { + return err + } + + if err := qb.UpdateParentTags(t.ID, ids); err != nil { + return err + } + } + + if input.ChildIds != nil && len(input.ChildIds) > 0 { + ids, err := utils.StringSliceToIntSlice(input.ChildIds) + if err != nil { + return err + } + + if err := qb.UpdateChildTags(t.ID, ids); err != nil { + return err + } + } + return nil }); err != nil { return nil, err @@ -161,6 +183,41 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input models.TagUpdate } } + var parentIDs []int + var childIDs []int + + if translator.hasField("parent_ids") { + parentIDs, err = utils.StringSliceToIntSlice(input.ParentIds) + if err != nil { + return err + } + } + + if translator.hasField("child_ids") { + childIDs, err = utils.StringSliceToIntSlice(input.ChildIds) + if err != nil { + return err + } + } + + if parentIDs != nil || childIDs != nil { + if err := tag.EnsureUniqueHierarchy(tagID, parentIDs, childIDs, qb); err != nil { + return err + } + } + + if parentIDs != nil { + if err := qb.UpdateParentTags(tagID, parentIDs); err != nil { + return err + } + } + + if childIDs != nil { + if err := qb.UpdateChildTags(tagID, childIDs); err != nil { + return err + } + } + return nil }); err != nil { return nil, err @@ -242,10 +299,24 @@ func (r *mutationResolver) TagsMerge(ctx context.Context, input models.TagsMerge return fmt.Errorf("Tag with ID %d not found", destination) } + parents, children, err := tag.MergeHierarchy(destination, source, qb) + if err != nil { + return err + } + if err = qb.Merge(source, destination); err != nil { return err } + err = qb.UpdateParentTags(destination, parents) + if err != nil { + return err + } + err = qb.UpdateChildTags(destination, children) + if err != nil { + return err + } + return nil }); err != nil { return nil, err diff --git a/pkg/database/database.go b/pkg/database/database.go index cedcc9428..b4a6a8c29 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -23,7 +23,7 @@ import ( var DB *sqlx.DB var WriteMu *sync.Mutex var dbPath string -var appSchemaVersion uint = 25 +var appSchemaVersion uint = 26 var databaseSchemaVersion uint var ( diff --git a/pkg/database/migrations/26_tag_hierarchy.up.sql b/pkg/database/migrations/26_tag_hierarchy.up.sql new file mode 100644 index 000000000..6e9855fc2 --- /dev/null +++ b/pkg/database/migrations/26_tag_hierarchy.up.sql @@ -0,0 +1,7 @@ +CREATE TABLE tags_relations ( + parent_id integer, + child_id integer, + primary key (parent_id, child_id), + foreign key (parent_id) references tags(id) on delete cascade, + foreign key (child_id) references tags(id) on delete cascade +); diff --git a/pkg/dlna/cds.go b/pkg/dlna/cds.go index e99f63be8..79a80ab51 100644 --- a/pkg/dlna/cds.go +++ b/pkg/dlna/cds.go @@ -570,9 +570,10 @@ func (me *contentDirectoryService) getTags() []interface{} { func (me *contentDirectoryService) getTagScenes(paths []string, host string) []interface{} { sceneFilter := &models.SceneFilterType{ - Tags: &models.MultiCriterionInput{ + Tags: &models.HierarchicalMultiCriterionInput{ Modifier: models.CriterionModifierIncludes, Value: []string{paths[0]}, + Depth: 0, }, } diff --git a/pkg/gallery/query.go b/pkg/gallery/query.go index a354f5aa5..a7fbed856 100644 --- a/pkg/gallery/query.go +++ b/pkg/gallery/query.go @@ -31,9 +31,10 @@ func CountByStudioID(r models.GalleryReader, id int) (int, error) { func CountByTagID(r models.GalleryReader, id int) (int, error) { filter := &models.GalleryFilterType{ - Tags: &models.MultiCriterionInput{ + Tags: &models.HierarchicalMultiCriterionInput{ Value: []string{strconv.Itoa(id)}, Modifier: models.CriterionModifierIncludes, + Depth: 0, }, } diff --git a/pkg/image/query.go b/pkg/image/query.go index b5c70738c..eca6d49a4 100644 --- a/pkg/image/query.go +++ b/pkg/image/query.go @@ -31,9 +31,10 @@ func CountByStudioID(r models.ImageReader, id int) (int, error) { func CountByTagID(r models.ImageReader, id int) (int, error) { filter := &models.ImageFilterType{ - Tags: &models.MultiCriterionInput{ + Tags: &models.HierarchicalMultiCriterionInput{ Value: []string{strconv.Itoa(id)}, Modifier: models.CriterionModifierIncludes, + Depth: 0, }, } diff --git a/pkg/manager/jsonschema/tag.go b/pkg/manager/jsonschema/tag.go index 66fa910ab..fbb85a835 100644 --- a/pkg/manager/jsonschema/tag.go +++ b/pkg/manager/jsonschema/tag.go @@ -12,6 +12,7 @@ type Tag struct { Name string `json:"name,omitempty"` Aliases []string `json:"aliases,omitempty"` Image string `json:"image,omitempty"` + Parents []string `json:"parents,omitempty"` CreatedAt models.JSONTime `json:"created_at,omitempty"` UpdatedAt models.JSONTime `json:"updated_at,omitempty"` } diff --git a/pkg/manager/task_import.go b/pkg/manager/task_import.go index db3d9f76d..1bf2f3476 100644 --- a/pkg/manager/task_import.go +++ b/pkg/manager/task_import.go @@ -376,6 +376,7 @@ func (t *ImportTask) ImportGalleries(ctx context.Context) { } func (t *ImportTask) ImportTags(ctx context.Context) { + pendingParent := make(map[string][]*jsonschema.Tag) logger.Info("[tags] importing") for i, mappingJSON := range t.mappings.Tags { @@ -389,23 +390,64 @@ func (t *ImportTask) ImportTags(ctx context.Context) { logger.Progressf("[tags] %d of %d", index, len(t.mappings.Tags)) if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { - readerWriter := r.Tag() - - tagImporter := &tag.Importer{ - ReaderWriter: readerWriter, - Input: *tagJSON, + return t.ImportTag(tagJSON, pendingParent, false, r.Tag()) + }); err != nil { + if parentError, ok := err.(tag.ParentTagNotExistError); ok { + pendingParent[parentError.MissingParent()] = append(pendingParent[parentError.MissingParent()], tagJSON) + continue } - return performImport(tagImporter, t.DuplicateBehaviour) - }); err != nil { logger.Errorf("[tags] <%s> failed to import: %s", mappingJSON.Checksum, err.Error()) continue } } + for _, s := range pendingParent { + for _, orphanTagJSON := range s { + if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { + return t.ImportTag(orphanTagJSON, nil, true, r.Tag()) + }); err != nil { + logger.Errorf("[tags] <%s> failed to create: %s", orphanTagJSON.Name, err.Error()) + continue + } + } + } + logger.Info("[tags] import complete") } +func (t *ImportTask) ImportTag(tagJSON *jsonschema.Tag, pendingParent map[string][]*jsonschema.Tag, fail bool, readerWriter models.TagReaderWriter) error { + importer := &tag.Importer{ + ReaderWriter: readerWriter, + Input: *tagJSON, + MissingRefBehaviour: t.MissingRefBehaviour, + } + + // first phase: return error if parent does not exist + if !fail { + importer.MissingRefBehaviour = models.ImportMissingRefEnumFail + } + + if err := performImport(importer, t.DuplicateBehaviour); err != nil { + return err + } + + for _, childTagJSON := range pendingParent[tagJSON.Name] { + if err := t.ImportTag(childTagJSON, pendingParent, fail, readerWriter); err != nil { + if parentError, ok := err.(tag.ParentTagNotExistError); ok { + pendingParent[parentError.MissingParent()] = append(pendingParent[parentError.MissingParent()], tagJSON) + continue + } + + return fmt.Errorf("failed to create child tag <%s>: %s", childTagJSON.Name, err.Error()) + } + } + + delete(pendingParent, tagJSON.Name) + + return nil +} + func (t *ImportTask) ImportScrapedItems(ctx context.Context) { if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { logger.Info("[scraped sites] importing") diff --git a/pkg/models/mocks/TagReaderWriter.go b/pkg/models/mocks/TagReaderWriter.go index f7cabb8ce..f6b257ed2 100644 --- a/pkg/models/mocks/TagReaderWriter.go +++ b/pkg/models/mocks/TagReaderWriter.go @@ -130,6 +130,75 @@ func (_m *TagReaderWriter) Find(id int) (*models.Tag, error) { return r0, r1 } +// FindAllAncestors provides a mock function with given fields: tagID, excludeIDs +func (_m *TagReaderWriter) FindAllAncestors(tagID int, excludeIDs []int) ([]*models.Tag, error) { + ret := _m.Called(tagID, excludeIDs) + + var r0 []*models.Tag + if rf, ok := ret.Get(0).(func(int, []int) []*models.Tag); ok { + r0 = rf(tagID, excludeIDs) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.Tag) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(int, []int) error); ok { + r1 = rf(tagID, excludeIDs) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// FindAllDescendants provides a mock function with given fields: tagID, excludeIDs +func (_m *TagReaderWriter) FindAllDescendants(tagID int, excludeIDs []int) ([]*models.Tag, error) { + ret := _m.Called(tagID, excludeIDs) + + var r0 []*models.Tag + if rf, ok := ret.Get(0).(func(int, []int) []*models.Tag); ok { + r0 = rf(tagID, excludeIDs) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.Tag) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(int, []int) error); ok { + r1 = rf(tagID, excludeIDs) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// FindByChildTagID provides a mock function with given fields: childID +func (_m *TagReaderWriter) FindByChildTagID(childID int) ([]*models.Tag, error) { + ret := _m.Called(childID) + + var r0 []*models.Tag + if rf, ok := ret.Get(0).(func(int) []*models.Tag); ok { + r0 = rf(childID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.Tag) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(childID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // FindByGalleryID provides a mock function with given fields: galleryID func (_m *TagReaderWriter) FindByGalleryID(galleryID int) ([]*models.Tag, error) { ret := _m.Called(galleryID) @@ -222,6 +291,29 @@ func (_m *TagReaderWriter) FindByNames(names []string, nocase bool) ([]*models.T return r0, r1 } +// FindByParentTagID provides a mock function with given fields: parentID +func (_m *TagReaderWriter) FindByParentTagID(parentID int) ([]*models.Tag, error) { + ret := _m.Called(parentID) + + var r0 []*models.Tag + if rf, ok := ret.Get(0).(func(int) []*models.Tag); ok { + r0 = rf(parentID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.Tag) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(parentID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // FindByPerformerID provides a mock function with given fields: performerID func (_m *TagReaderWriter) FindByPerformerID(performerID int) ([]*models.Tag, error) { ret := _m.Called(performerID) @@ -464,6 +556,20 @@ func (_m *TagReaderWriter) UpdateAliases(tagID int, aliases []string) error { return r0 } +// UpdateChildTags provides a mock function with given fields: tagID, parentIDs +func (_m *TagReaderWriter) UpdateChildTags(tagID int, parentIDs []int) error { + ret := _m.Called(tagID, parentIDs) + + var r0 error + if rf, ok := ret.Get(0).(func(int, []int) error); ok { + r0 = rf(tagID, parentIDs) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // UpdateFull provides a mock function with given fields: updatedTag func (_m *TagReaderWriter) UpdateFull(updatedTag models.Tag) (*models.Tag, error) { ret := _m.Called(updatedTag) @@ -500,3 +606,17 @@ func (_m *TagReaderWriter) UpdateImage(tagID int, image []byte) error { return r0 } + +// UpdateParentTags provides a mock function with given fields: tagID, parentIDs +func (_m *TagReaderWriter) UpdateParentTags(tagID int, parentIDs []int) error { + ret := _m.Called(tagID, parentIDs) + + var r0 error + if rf, ok := ret.Get(0).(func(int, []int) error); ok { + r0 = rf(tagID, parentIDs) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/pkg/models/tag.go b/pkg/models/tag.go index 4d3e0d84e..bcb7096f0 100644 --- a/pkg/models/tag.go +++ b/pkg/models/tag.go @@ -10,6 +10,8 @@ type TagReader interface { FindByGalleryID(galleryID int) ([]*Tag, error) FindByName(name string, nocase bool) (*Tag, error) FindByNames(names []string, nocase bool) ([]*Tag, error) + FindByParentTagID(parentID int) ([]*Tag, error) + FindByChildTagID(childID int) ([]*Tag, error) Count() (int, error) All() ([]*Tag, error) // TODO - this interface is temporary until the filter schema can fully @@ -18,6 +20,8 @@ type TagReader interface { Query(tagFilter *TagFilterType, findFilter *FindFilterType) ([]*Tag, int, error) GetImage(tagID int) ([]byte, error) GetAliases(tagID int) ([]string, error) + FindAllAncestors(tagID int, excludeIDs []int) ([]*Tag, error) + FindAllDescendants(tagID int, excludeIDs []int) ([]*Tag, error) } type TagWriter interface { @@ -29,6 +33,8 @@ type TagWriter interface { DestroyImage(tagID int) error UpdateAliases(tagID int, aliases []string) error Merge(source []int, destination int) error + UpdateParentTags(tagID int, parentIDs []int) error + UpdateChildTags(tagID int, parentIDs []int) error } type TagReaderWriter interface { diff --git a/pkg/sqlite/filter.go b/pkg/sqlite/filter.go index d83ce12b0..0d47c2fb2 100644 --- a/pkg/sqlite/filter.go +++ b/pkg/sqlite/filter.go @@ -3,7 +3,9 @@ package sqlite import ( "errors" "fmt" + "github.com/stashapp/stash/pkg/logger" "regexp" + "strconv" "strings" "github.com/stashapp/stash/pkg/models" @@ -97,6 +99,7 @@ type filterBuilder struct { whereClauses []sqlClause havingClauses []sqlClause withClauses []sqlClause + recursiveWith bool err error } @@ -188,6 +191,16 @@ func (f *filterBuilder) addWith(sql string, args ...interface{}) { f.withClauses = append(f.withClauses, makeClause(sql, args...)) } +// addRecursiveWith adds a with clause and arguments to the filter, and sets it to recursive +func (f *filterBuilder) addRecursiveWith(sql string, args ...interface{}) { + if sql == "" { + return + } + + f.addWith(sql, args...) + f.recursiveWith = true +} + func (f *filterBuilder) getSubFilterClause(clause, subFilterClause string) string { ret := clause @@ -510,18 +523,41 @@ func (m *stringListCriterionHandlerBuilder) handler(criterion *models.StringCrit } type hierarchicalMultiCriterionHandlerBuilder struct { + tx dbi + primaryTable string foreignTable string foreignFK string - derivedTable string - parentFK string + derivedTable string + parentFK string + relationsTable string } -func addHierarchicalWithClause(f *filterBuilder, value []string, derivedTable, table, parentFK string, depth int) { +func getHierarchicalValues(tx dbi, values []string, table, relationsTable, parentFK string, depth int) string { var args []interface{} - for _, value := range value { + if depth == 0 { + valid := true + var valuesClauses []string + for _, value := range values { + id, err := strconv.Atoi(value) + // In case of invalid value just run the query. + // Building VALUES() based on provided values just saves a query when depth is 0. + if err != nil { + valid = false + break + } + + valuesClauses = append(valuesClauses, fmt.Sprintf("(%d,%d)", id, id)) + } + + if valid { + return "VALUES" + strings.Join(valuesClauses, ",") + } + } + + for _, value := range values { args = append(args, value) } inCount := len(args) @@ -532,45 +568,107 @@ func addHierarchicalWithClause(f *filterBuilder, value []string, derivedTable, t } withClauseMap := utils.StrFormatMap{ - "derivedTable": derivedTable, - "table": table, - "inBinding": getInBinding(inCount), - "parentFK": parentFK, - "depthCondition": depthCondition, - "unionClause": "", + "table": table, + "relationsTable": relationsTable, + "inBinding": getInBinding(inCount), + "recursiveSelect": "", + "parentFK": parentFK, + "depthCondition": depthCondition, + "unionClause": "", + } + + if relationsTable != "" { + withClauseMap["recursiveSelect"] = utils.StrFormat(`SELECT p.root_id, c.child_id, depth + 1 FROM {relationsTable} AS c +INNER JOIN items as p ON c.parent_id = p.item_id +`, withClauseMap) + } else { + withClauseMap["recursiveSelect"] = utils.StrFormat(`SELECT p.root_id, c.id, depth + 1 FROM {table} as c +INNER JOIN items as p ON c.{parentFK} = p.item_id +`, withClauseMap) } if depth != 0 { withClauseMap["unionClause"] = utils.StrFormat(` -UNION SELECT p.id, c.id, depth + 1 FROM {table} as c -INNER JOIN {derivedTable} as p ON c.{parentFK} = p.child_id {depthCondition} +UNION {recursiveSelect} {depthCondition} `, withClauseMap) } - withClause := utils.StrFormat(`RECURSIVE {derivedTable} AS ( -SELECT id as id, id as child_id, 0 as depth FROM {table} + withClause := utils.StrFormat(`items AS ( +SELECT id as root_id, id as item_id, 0 as depth FROM {table} WHERE id in {inBinding} {unionClause}) `, withClauseMap) - f.addWith(withClause, args...) + query := fmt.Sprintf("WITH RECURSIVE %s SELECT 'VALUES' || GROUP_CONCAT('(' || root_id || ', ' || item_id || ')') AS val FROM items", withClause) + + var valuesClause string + err := tx.Get(&valuesClause, query, args...) + if err != nil { + logger.Error(err) + // return record which never matches so we don't have to handle error here + return "VALUES(NULL, NULL)" + } + + return valuesClause +} + +func addHierarchicalConditionClauses(f *filterBuilder, criterion *models.HierarchicalMultiCriterionInput, table, idColumn string) { + if criterion.Modifier == models.CriterionModifierIncludes { + f.addWhere(fmt.Sprintf("%s.%s IS NOT NULL", table, idColumn)) + } else if criterion.Modifier == models.CriterionModifierIncludesAll { + f.addWhere(fmt.Sprintf("%s.%s IS NOT NULL", table, idColumn)) + f.addHaving(fmt.Sprintf("count(distinct %s.%s) IS %d", table, idColumn, len(criterion.Value))) + } else if criterion.Modifier == models.CriterionModifierExcludes { + f.addWhere(fmt.Sprintf("%s.%s IS NULL", table, idColumn)) + } } func (m *hierarchicalMultiCriterionHandlerBuilder) handler(criterion *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { return func(f *filterBuilder) { if criterion != nil && len(criterion.Value) > 0 { - addHierarchicalWithClause(f, criterion.Value, m.derivedTable, m.foreignTable, m.parentFK, criterion.Depth) + valuesClause := getHierarchicalValues(m.tx, criterion.Value, m.foreignTable, m.relationsTable, m.parentFK, criterion.Depth) - f.addJoin(m.derivedTable, "", fmt.Sprintf("%s.child_id = %s.%s", m.derivedTable, m.primaryTable, m.foreignFK)) + f.addJoin("(SELECT column1 AS root_id, column2 AS item_id FROM ("+valuesClause+"))", m.derivedTable, fmt.Sprintf("%s.item_id = %s.%s", m.derivedTable, m.primaryTable, m.foreignFK)) - if criterion.Modifier == models.CriterionModifierIncludes { - f.addWhere(fmt.Sprintf("%s.id IS NOT NULL", m.derivedTable)) - } else if criterion.Modifier == models.CriterionModifierIncludesAll { - f.addWhere(fmt.Sprintf("%s.id IS NOT NULL", m.derivedTable)) - f.addHaving(fmt.Sprintf("count(distinct %s.id) IS %d", m.derivedTable, len(criterion.Value))) - } else if criterion.Modifier == models.CriterionModifierExcludes { - f.addWhere(fmt.Sprintf("%s.id IS NULL", m.derivedTable)) - } + addHierarchicalConditionClauses(f, criterion, m.derivedTable, "root_id") + } + } +} + +type joinedHierarchicalMultiCriterionHandlerBuilder struct { + tx dbi + + primaryTable string + foreignTable string + foreignFK string + + parentFK string + relationsTable string + + joinAs string + joinTable string + primaryFK string +} + +func (m *joinedHierarchicalMultiCriterionHandlerBuilder) handler(criterion *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { + return func(f *filterBuilder) { + if criterion != nil && len(criterion.Value) > 0 { + valuesClause := getHierarchicalValues(m.tx, criterion.Value, m.foreignTable, m.relationsTable, m.parentFK, criterion.Depth) + + joinAlias := m.joinAs + joinTable := utils.StrFormat(`( + SELECT j.*, d.column1 AS root_id, d.column2 AS item_id FROM {joinTable} AS j + INNER JOIN ({valuesClause}) AS d ON j.{foreignFK} = d.column2 +) +`, utils.StrFormatMap{ + "joinTable": m.joinTable, + "foreignFK": m.foreignFK, + "valuesClause": valuesClause, + }) + + f.addJoin(joinTable, joinAlias, fmt.Sprintf("%s.%s = %s.id", joinAlias, m.primaryFK, m.primaryTable)) + + addHierarchicalConditionClauses(f, criterion, joinAlias, "root_id") } } } diff --git a/pkg/sqlite/gallery.go b/pkg/sqlite/gallery.go index 30f13c75c..2564c068b 100644 --- a/pkg/sqlite/gallery.go +++ b/pkg/sqlite/gallery.go @@ -311,17 +311,18 @@ func galleryIsMissingCriterionHandler(qb *galleryQueryBuilder, isMissing *string } } -func galleryTagsCriterionHandler(qb *galleryQueryBuilder, tags *models.MultiCriterionInput) criterionHandlerFunc { - h := joinedMultiCriterionHandlerBuilder{ - primaryTable: galleryTable, - joinTable: galleriesTagsTable, - joinAs: "tags_join", - primaryFK: galleryIDColumn, - foreignFK: tagIDColumn, +func galleryTagsCriterionHandler(qb *galleryQueryBuilder, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { + h := joinedHierarchicalMultiCriterionHandlerBuilder{ + tx: qb.tx, - addJoinTable: func(f *filterBuilder) { - qb.tagsRepository().join(f, "tags_join", "galleries.id") - }, + primaryTable: galleryTable, + foreignTable: tagTable, + foreignFK: "tag_id", + + relationsTable: "tags_relations", + joinAs: "image_tag", + joinTable: galleriesTagsTable, + primaryFK: galleryIDColumn, } return h.handler(tags) @@ -375,6 +376,8 @@ func galleryImageCountCriterionHandler(qb *galleryQueryBuilder, imageCount *mode func galleryStudioCriterionHandler(qb *galleryQueryBuilder, studios *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { h := hierarchicalMultiCriterionHandlerBuilder{ + tx: qb.tx, + primaryTable: galleryTable, foreignTable: studioTable, foreignFK: studioIDColumn, @@ -385,31 +388,20 @@ func galleryStudioCriterionHandler(qb *galleryQueryBuilder, studios *models.Hier return h.handler(studios) } -func galleryPerformerTagsCriterionHandler(qb *galleryQueryBuilder, performerTagsFilter *models.MultiCriterionInput) criterionHandlerFunc { +func galleryPerformerTagsCriterionHandler(qb *galleryQueryBuilder, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { return func(f *filterBuilder) { - if performerTagsFilter != nil && len(performerTagsFilter.Value) > 0 { - qb.performersRepository().join(f, "performers_join", "galleries.id") - f.addJoin("performers_tags", "performer_tags_join", "performers_join.performer_id = performer_tags_join.performer_id") + if tags != nil && len(tags.Value) > 0 { + valuesClause := getHierarchicalValues(qb.tx, tags.Value, tagTable, "tags_relations", "", tags.Depth) - var args []interface{} - for _, tagID := range performerTagsFilter.Value { - args = append(args, tagID) - } + f.addWith(`performer_tags AS ( +SELECT pg.gallery_id, t.column1 AS root_tag_id FROM performers_galleries pg +INNER JOIN performers_tags pt ON pt.performer_id = pg.performer_id +INNER JOIN (` + valuesClause + `) t ON t.column2 = pt.tag_id +)`) - if performerTagsFilter.Modifier == models.CriterionModifierIncludes { - // includes any of the provided ids - f.addWhere("performer_tags_join.tag_id IN "+getInBinding(len(performerTagsFilter.Value)), args...) - } else if performerTagsFilter.Modifier == models.CriterionModifierIncludesAll { - // includes all of the provided ids - f.addWhere("performer_tags_join.tag_id IN "+getInBinding(len(performerTagsFilter.Value)), args...) - f.addHaving(fmt.Sprintf("count(distinct performer_tags_join.tag_id) IS %d", len(performerTagsFilter.Value))) - } else if performerTagsFilter.Modifier == models.CriterionModifierExcludes { - f.addWhere(fmt.Sprintf(`not exists - (select performers_galleries.performer_id from performers_galleries - left join performers_tags on performers_tags.performer_id = performers_galleries.performer_id where - performers_galleries.gallery_id = galleries.id AND - performers_tags.tag_id in %s)`, getInBinding(len(performerTagsFilter.Value))), args...) - } + f.addJoin("performer_tags", "", "performer_tags.gallery_id = galleries.id") + + addHierarchicalConditionClauses(f, tags, "performer_tags", "root_tag_id") } } } diff --git a/pkg/sqlite/gallery_test.go b/pkg/sqlite/gallery_test.go index 03f03a544..4ba755b00 100644 --- a/pkg/sqlite/gallery_test.go +++ b/pkg/sqlite/gallery_test.go @@ -621,12 +621,13 @@ func TestGalleryQueryPerformers(t *testing.T) { func TestGalleryQueryTags(t *testing.T) { withTxn(func(r models.Repository) error { sqb := r.Gallery() - tagCriterion := models.MultiCriterionInput{ + tagCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdxWithGallery]), strconv.Itoa(tagIDs[tagIdx1WithGallery]), }, Modifier: models.CriterionModifierIncludes, + Depth: 0, } galleryFilter := models.GalleryFilterType{ @@ -641,12 +642,13 @@ func TestGalleryQueryTags(t *testing.T) { assert.True(t, gallery.ID == galleryIDs[galleryIdxWithTag] || gallery.ID == galleryIDs[galleryIdxWithTwoTags]) } - tagCriterion = models.MultiCriterionInput{ + tagCriterion = models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdx1WithGallery]), strconv.Itoa(tagIDs[tagIdx2WithGallery]), }, Modifier: models.CriterionModifierIncludesAll, + Depth: 0, } galleries = queryGallery(t, sqb, &galleryFilter, nil) @@ -654,11 +656,12 @@ func TestGalleryQueryTags(t *testing.T) { assert.Len(t, galleries, 1) assert.Equal(t, galleryIDs[galleryIdxWithTwoTags], galleries[0].ID) - tagCriterion = models.MultiCriterionInput{ + tagCriterion = models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdx1WithGallery]), }, Modifier: models.CriterionModifierExcludes, + Depth: 0, } q := getGalleryStringValue(galleryIdxWithTwoTags, titleField) @@ -776,12 +779,13 @@ func TestGalleryQueryStudioDepth(t *testing.T) { func TestGalleryQueryPerformerTags(t *testing.T) { withTxn(func(r models.Repository) error { sqb := r.Gallery() - tagCriterion := models.MultiCriterionInput{ + tagCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdxWithPerformer]), strconv.Itoa(tagIDs[tagIdx1WithPerformer]), }, Modifier: models.CriterionModifierIncludes, + Depth: 0, } galleryFilter := models.GalleryFilterType{ @@ -796,12 +800,13 @@ func TestGalleryQueryPerformerTags(t *testing.T) { assert.True(t, gallery.ID == galleryIDs[galleryIdxWithPerformerTag] || gallery.ID == galleryIDs[galleryIdxWithPerformerTwoTags]) } - tagCriterion = models.MultiCriterionInput{ + tagCriterion = models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdx1WithPerformer]), strconv.Itoa(tagIDs[tagIdx2WithPerformer]), }, Modifier: models.CriterionModifierIncludesAll, + Depth: 0, } galleries = queryGallery(t, sqb, &galleryFilter, nil) @@ -809,11 +814,12 @@ func TestGalleryQueryPerformerTags(t *testing.T) { assert.Len(t, galleries, 1) assert.Equal(t, galleryIDs[galleryIdxWithPerformerTwoTags], galleries[0].ID) - tagCriterion = models.MultiCriterionInput{ + tagCriterion = models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdx1WithPerformer]), }, Modifier: models.CriterionModifierExcludes, + Depth: 0, } q := getGalleryStringValue(galleryIdxWithPerformerTwoTags, titleField) diff --git a/pkg/sqlite/image.go b/pkg/sqlite/image.go index d0b6f16f8..01598c51a 100644 --- a/pkg/sqlite/image.go +++ b/pkg/sqlite/image.go @@ -348,17 +348,18 @@ func (qb *imageQueryBuilder) getMultiCriterionHandlerBuilder(foreignTable, joinT } } -func imageTagsCriterionHandler(qb *imageQueryBuilder, tags *models.MultiCriterionInput) criterionHandlerFunc { - h := joinedMultiCriterionHandlerBuilder{ - primaryTable: imageTable, - joinTable: imagesTagsTable, - joinAs: "tags_join", - primaryFK: imageIDColumn, - foreignFK: tagIDColumn, +func imageTagsCriterionHandler(qb *imageQueryBuilder, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { + h := joinedHierarchicalMultiCriterionHandlerBuilder{ + tx: qb.tx, - addJoinTable: func(f *filterBuilder) { - qb.tagsRepository().join(f, "tags_join", "images.id") - }, + primaryTable: imageTable, + foreignTable: tagTable, + foreignFK: "tag_id", + + relationsTable: "tags_relations", + joinAs: "image_tag", + joinTable: imagesTagsTable, + primaryFK: imageIDColumn, } return h.handler(tags) @@ -412,6 +413,8 @@ func imagePerformerCountCriterionHandler(qb *imageQueryBuilder, performerCount * func imageStudioCriterionHandler(qb *imageQueryBuilder, studios *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { h := hierarchicalMultiCriterionHandlerBuilder{ + tx: qb.tx, + primaryTable: imageTable, foreignTable: studioTable, foreignFK: studioIDColumn, @@ -422,31 +425,20 @@ func imageStudioCriterionHandler(qb *imageQueryBuilder, studios *models.Hierarch return h.handler(studios) } -func imagePerformerTagsCriterionHandler(qb *imageQueryBuilder, performerTagsFilter *models.MultiCriterionInput) criterionHandlerFunc { +func imagePerformerTagsCriterionHandler(qb *imageQueryBuilder, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { return func(f *filterBuilder) { - if performerTagsFilter != nil && len(performerTagsFilter.Value) > 0 { - qb.performersRepository().join(f, "performers_join", "images.id") - f.addJoin("performers_tags", "performer_tags_join", "performers_join.performer_id = performer_tags_join.performer_id") + if tags != nil && len(tags.Value) > 0 { + valuesClause := getHierarchicalValues(qb.tx, tags.Value, tagTable, "tags_relations", "", tags.Depth) - var args []interface{} - for _, tagID := range performerTagsFilter.Value { - args = append(args, tagID) - } + f.addWith(`performer_tags AS ( +SELECT pi.image_id, t.column1 AS root_tag_id FROM performers_images pi +INNER JOIN performers_tags pt ON pt.performer_id = pi.performer_id +INNER JOIN (` + valuesClause + `) t ON t.column2 = pt.tag_id +)`) - if performerTagsFilter.Modifier == models.CriterionModifierIncludes { - // includes any of the provided ids - f.addWhere("performer_tags_join.tag_id IN "+getInBinding(len(performerTagsFilter.Value)), args...) - } else if performerTagsFilter.Modifier == models.CriterionModifierIncludesAll { - // includes all of the provided ids - f.addWhere("performer_tags_join.tag_id IN "+getInBinding(len(performerTagsFilter.Value)), args...) - f.addHaving(fmt.Sprintf("count(distinct performer_tags_join.tag_id) IS %d", len(performerTagsFilter.Value))) - } else if performerTagsFilter.Modifier == models.CriterionModifierExcludes { - f.addWhere(fmt.Sprintf(`not exists - (select performers_images.performer_id from performers_images - left join performers_tags on performers_tags.performer_id = performers_images.performer_id where - performers_images.image_id = images.id AND - performers_tags.tag_id in %s)`, getInBinding(len(performerTagsFilter.Value))), args...) - } + f.addJoin("performer_tags", "", "performer_tags.image_id = images.id") + + addHierarchicalConditionClauses(f, tags, "performer_tags", "root_tag_id") } } } diff --git a/pkg/sqlite/image_test.go b/pkg/sqlite/image_test.go index 6d866041f..02e077c8d 100644 --- a/pkg/sqlite/image_test.go +++ b/pkg/sqlite/image_test.go @@ -722,12 +722,13 @@ func TestImageQueryPerformers(t *testing.T) { func TestImageQueryTags(t *testing.T) { withTxn(func(r models.Repository) error { sqb := r.Image() - tagCriterion := models.MultiCriterionInput{ + tagCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdxWithImage]), strconv.Itoa(tagIDs[tagIdx1WithImage]), }, Modifier: models.CriterionModifierIncludes, + Depth: 0, } imageFilter := models.ImageFilterType{ @@ -746,12 +747,13 @@ func TestImageQueryTags(t *testing.T) { assert.True(t, image.ID == imageIDs[imageIdxWithTag] || image.ID == imageIDs[imageIdxWithTwoTags]) } - tagCriterion = models.MultiCriterionInput{ + tagCriterion = models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdx1WithImage]), strconv.Itoa(tagIDs[tagIdx2WithImage]), }, Modifier: models.CriterionModifierIncludesAll, + Depth: 0, } images, _, err = sqb.Query(&imageFilter, nil) @@ -762,11 +764,12 @@ func TestImageQueryTags(t *testing.T) { assert.Len(t, images, 1) assert.Equal(t, imageIDs[imageIdxWithTwoTags], images[0].ID) - tagCriterion = models.MultiCriterionInput{ + tagCriterion = models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdx1WithImage]), }, Modifier: models.CriterionModifierExcludes, + Depth: 0, } q := getImageStringValue(imageIdxWithTwoTags, titleField) @@ -902,12 +905,13 @@ func queryImages(t *testing.T, sqb models.ImageReader, imageFilter *models.Image func TestImageQueryPerformerTags(t *testing.T) { withTxn(func(r models.Repository) error { sqb := r.Image() - tagCriterion := models.MultiCriterionInput{ + tagCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdxWithPerformer]), strconv.Itoa(tagIDs[tagIdx1WithPerformer]), }, Modifier: models.CriterionModifierIncludes, + Depth: 0, } imageFilter := models.ImageFilterType{ @@ -922,12 +926,13 @@ func TestImageQueryPerformerTags(t *testing.T) { assert.True(t, image.ID == imageIDs[imageIdxWithPerformerTag] || image.ID == imageIDs[imageIdxWithPerformerTwoTags]) } - tagCriterion = models.MultiCriterionInput{ + tagCriterion = models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdx1WithPerformer]), strconv.Itoa(tagIDs[tagIdx2WithPerformer]), }, Modifier: models.CriterionModifierIncludesAll, + Depth: 0, } images = queryImages(t, sqb, &imageFilter, nil) @@ -935,11 +940,12 @@ func TestImageQueryPerformerTags(t *testing.T) { assert.Len(t, images, 1) assert.Equal(t, imageIDs[imageIdxWithPerformerTwoTags], images[0].ID) - tagCriterion = models.MultiCriterionInput{ + tagCriterion = models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdx1WithPerformer]), }, Modifier: models.CriterionModifierExcludes, + Depth: 0, } q := getImageStringValue(imageIdxWithPerformerTwoTags, titleField) diff --git a/pkg/sqlite/movies.go b/pkg/sqlite/movies.go index 9e69b8451..92c3636e3 100644 --- a/pkg/sqlite/movies.go +++ b/pkg/sqlite/movies.go @@ -195,6 +195,8 @@ func movieIsMissingCriterionHandler(qb *movieQueryBuilder, isMissing *string) cr func movieStudioCriterionHandler(qb *movieQueryBuilder, studios *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { h := hierarchicalMultiCriterionHandlerBuilder{ + tx: qb.tx, + primaryTable: movieTable, foreignTable: studioTable, foreignFK: studioIDColumn, diff --git a/pkg/sqlite/performer.go b/pkg/sqlite/performer.go index 526a6e360..c8b3f86de 100644 --- a/pkg/sqlite/performer.go +++ b/pkg/sqlite/performer.go @@ -280,7 +280,7 @@ func (qb *performerQueryBuilder) makeFilter(filter *models.PerformerFilterType) query.handleCriterion(performerTagsCriterionHandler(qb, filter.Tags)) - query.handleCriterion(performerStudiosCriterionHandler(filter.Studios)) + query.handleCriterion(performerStudiosCriterionHandler(qb, filter.Studios)) query.handleCriterion(performerTagCountCriterionHandler(qb, filter.TagCount)) query.handleCriterion(performerSceneCountCriterionHandler(qb, filter.SceneCount)) @@ -376,17 +376,18 @@ func performerAgeFilterCriterionHandler(age *models.IntCriterionInput) criterion } } -func performerTagsCriterionHandler(qb *performerQueryBuilder, tags *models.MultiCriterionInput) criterionHandlerFunc { - h := joinedMultiCriterionHandlerBuilder{ - primaryTable: performerTable, - joinTable: performersTagsTable, - joinAs: "tags_join", - primaryFK: performerIDColumn, - foreignFK: tagIDColumn, +func performerTagsCriterionHandler(qb *performerQueryBuilder, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { + h := joinedHierarchicalMultiCriterionHandlerBuilder{ + tx: qb.tx, - addJoinTable: func(f *filterBuilder) { - qb.tagsRepository().join(f, "tags_join", "performers.id") - }, + primaryTable: performerTable, + foreignTable: tagTable, + foreignFK: "tag_id", + + relationsTable: "tags_relations", + joinAs: "image_tag", + joinTable: performersTagsTable, + primaryFK: performerIDColumn, } return h.handler(tags) @@ -432,7 +433,7 @@ func performerGalleryCountCriterionHandler(qb *performerQueryBuilder, count *mod return h.handler(count) } -func performerStudiosCriterionHandler(studios *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { +func performerStudiosCriterionHandler(qb *performerQueryBuilder, studios *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { return func(f *filterBuilder) { if studios != nil { var clauseCondition string @@ -465,13 +466,13 @@ func performerStudiosCriterionHandler(studios *models.HierarchicalMultiCriterion }, } - const derivedStudioTable = "studio" const derivedPerformerStudioTable = "performer_studio" - addHierarchicalWithClause(f, studios.Value, derivedStudioTable, studioTable, "parent_id", studios.Depth) + valuesClause := getHierarchicalValues(qb.tx, studios.Value, studioTable, "", "parent_id", studios.Depth) + f.addWith("studio(root_id, item_id) AS (" + valuesClause + ")") templStr := `SELECT performer_id FROM {primaryTable} INNER JOIN {joinTable} ON {primaryTable}.id = {joinTable}.{primaryFK} - INNER JOIN studio ON {primaryTable}.studio_id = studio.child_id` + INNER JOIN studio ON {primaryTable}.studio_id = studio.item_id` var unions []string for _, c := range formatMaps { diff --git a/pkg/sqlite/performer_test.go b/pkg/sqlite/performer_test.go index a68e09e03..4ce990c35 100644 --- a/pkg/sqlite/performer_test.go +++ b/pkg/sqlite/performer_test.go @@ -500,12 +500,13 @@ func queryPerformers(t *testing.T, qb models.PerformerReader, performerFilter *m func TestPerformerQueryTags(t *testing.T) { withTxn(func(r models.Repository) error { sqb := r.Performer() - tagCriterion := models.MultiCriterionInput{ + tagCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdxWithPerformer]), strconv.Itoa(tagIDs[tagIdx1WithPerformer]), }, Modifier: models.CriterionModifierIncludes, + Depth: 0, } performerFilter := models.PerformerFilterType{ @@ -519,12 +520,13 @@ func TestPerformerQueryTags(t *testing.T) { assert.True(t, performer.ID == performerIDs[performerIdxWithTag] || performer.ID == performerIDs[performerIdxWithTwoTags]) } - tagCriterion = models.MultiCriterionInput{ + tagCriterion = models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdx1WithPerformer]), strconv.Itoa(tagIDs[tagIdx2WithPerformer]), }, Modifier: models.CriterionModifierIncludesAll, + Depth: 0, } performers = queryPerformers(t, sqb, &performerFilter, nil) @@ -532,11 +534,12 @@ func TestPerformerQueryTags(t *testing.T) { assert.Len(t, performers, 1) assert.Equal(t, sceneIDs[performerIdxWithTwoTags], performers[0].ID) - tagCriterion = models.MultiCriterionInput{ + tagCriterion = models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdx1WithPerformer]), }, Modifier: models.CriterionModifierExcludes, + Depth: 0, } q := getSceneStringValue(performerIdxWithTwoTags, titleField) diff --git a/pkg/sqlite/query.go b/pkg/sqlite/query.go index 2d31ee583..158e7f971 100644 --- a/pkg/sqlite/query.go +++ b/pkg/sqlite/query.go @@ -18,6 +18,7 @@ type queryBuilder struct { havingClauses []string args []interface{} withClauses []string + recursiveWith bool sortAndPagination string @@ -32,7 +33,7 @@ func (qb queryBuilder) executeFind() ([]int, int, error) { body := qb.body body += qb.joins.toSQL() - return qb.repository.executeFindQuery(body, qb.args, qb.sortAndPagination, qb.whereClauses, qb.havingClauses, qb.withClauses) + return qb.repository.executeFindQuery(body, qb.args, qb.sortAndPagination, qb.whereClauses, qb.havingClauses, qb.withClauses, qb.recursiveWith) } func (qb queryBuilder) executeCount() (int, error) { @@ -45,7 +46,11 @@ func (qb queryBuilder) executeCount() (int, error) { withClause := "" if len(qb.withClauses) > 0 { - withClause = "WITH " + strings.Join(qb.withClauses, ", ") + " " + var recursive string + if qb.recursiveWith { + recursive = " RECURSIVE " + } + withClause = "WITH " + recursive + strings.Join(qb.withClauses, ", ") + " " } body = qb.repository.buildQueryBody(body, qb.whereClauses, qb.havingClauses) @@ -69,12 +74,14 @@ func (qb *queryBuilder) addHaving(clauses ...string) { } } -func (qb *queryBuilder) addWith(clauses ...string) { +func (qb *queryBuilder) addWith(recursive bool, clauses ...string) { for _, clause := range clauses { if len(clause) > 0 { qb.withClauses = append(qb.withClauses, clause) } } + + qb.recursiveWith = qb.recursiveWith || recursive } func (qb *queryBuilder) addArg(args ...interface{}) { @@ -104,7 +111,7 @@ func (qb *queryBuilder) addFilter(f *filterBuilder) { clause, args := f.generateWithClauses() if len(clause) > 0 { - qb.addWith(clause) + qb.addWith(f.recursiveWith, clause) } if len(args) > 0 { diff --git a/pkg/sqlite/repository.go b/pkg/sqlite/repository.go index 6550d0d67..ffe5b78ad 100644 --- a/pkg/sqlite/repository.go +++ b/pkg/sqlite/repository.go @@ -246,12 +246,16 @@ func (r *repository) buildQueryBody(body string, whereClauses []string, havingCl return body } -func (r *repository) executeFindQuery(body string, args []interface{}, sortAndPagination string, whereClauses []string, havingClauses []string, withClauses []string) ([]int, int, error) { +func (r *repository) executeFindQuery(body string, args []interface{}, sortAndPagination string, whereClauses []string, havingClauses []string, withClauses []string, recursiveWith bool) ([]int, int, error) { body = r.buildQueryBody(body, whereClauses, havingClauses) withClause := "" if len(withClauses) > 0 { - withClause = "WITH " + strings.Join(withClauses, ", ") + " " + var recursive string + if recursiveWith { + recursive = " RECURSIVE " + } + withClause = "WITH " + recursive + strings.Join(withClauses, ", ") + " " } countQuery := withClause + r.buildCountQuery(body) diff --git a/pkg/sqlite/scene.go b/pkg/sqlite/scene.go index 92cdd5c2c..5aa0d4722 100644 --- a/pkg/sqlite/scene.go +++ b/pkg/sqlite/scene.go @@ -548,12 +548,19 @@ func (qb *sceneQueryBuilder) getMultiCriterionHandlerBuilder(foreignTable, joinT } } -func sceneTagsCriterionHandler(qb *sceneQueryBuilder, tags *models.MultiCriterionInput) criterionHandlerFunc { - addJoinsFunc := func(f *filterBuilder) { - qb.tagsRepository().join(f, "tags_join", "scenes.id") - f.addJoin("tags", "", "tags_join.tag_id = tags.id") +func sceneTagsCriterionHandler(qb *sceneQueryBuilder, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { + h := joinedHierarchicalMultiCriterionHandlerBuilder{ + tx: qb.tx, + + primaryTable: sceneTable, + foreignTable: tagTable, + foreignFK: "tag_id", + + relationsTable: "tags_relations", + joinAs: "scene_tag", + joinTable: scenesTagsTable, + primaryFK: sceneIDColumn, } - h := qb.getMultiCriterionHandlerBuilder(tagTable, scenesTagsTable, tagIDColumn, addJoinsFunc) return h.handler(tags) } @@ -596,6 +603,8 @@ func scenePerformerCountCriterionHandler(qb *sceneQueryBuilder, performerCount * func sceneStudioCriterionHandler(qb *sceneQueryBuilder, studios *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { h := hierarchicalMultiCriterionHandlerBuilder{ + tx: qb.tx, + primaryTable: sceneTable, foreignTable: studioTable, foreignFK: studioIDColumn, @@ -615,31 +624,20 @@ func sceneMoviesCriterionHandler(qb *sceneQueryBuilder, movies *models.MultiCrit return h.handler(movies) } -func scenePerformerTagsCriterionHandler(qb *sceneQueryBuilder, performerTagsFilter *models.MultiCriterionInput) criterionHandlerFunc { +func scenePerformerTagsCriterionHandler(qb *sceneQueryBuilder, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { return func(f *filterBuilder) { - if performerTagsFilter != nil && len(performerTagsFilter.Value) > 0 { - qb.performersRepository().join(f, "performers_join", "scenes.id") - f.addJoin("performers_tags", "performer_tags_join", "performers_join.performer_id = performer_tags_join.performer_id") + if tags != nil && len(tags.Value) > 0 { + valuesClause := getHierarchicalValues(qb.tx, tags.Value, tagTable, "tags_relations", "", tags.Depth) - var args []interface{} - for _, tagID := range performerTagsFilter.Value { - args = append(args, tagID) - } + f.addWith(`performer_tags AS ( +SELECT ps.scene_id, t.column1 AS root_tag_id FROM performers_scenes ps +INNER JOIN performers_tags pt ON pt.performer_id = ps.performer_id +INNER JOIN (` + valuesClause + `) t ON t.column2 = pt.tag_id +)`) - if performerTagsFilter.Modifier == models.CriterionModifierIncludes { - // includes any of the provided ids - f.addWhere("performer_tags_join.tag_id IN "+getInBinding(len(performerTagsFilter.Value)), args...) - } else if performerTagsFilter.Modifier == models.CriterionModifierIncludesAll { - // includes all of the provided ids - f.addWhere("performer_tags_join.tag_id IN "+getInBinding(len(performerTagsFilter.Value)), args...) - f.addHaving(fmt.Sprintf("count(distinct performer_tags_join.tag_id) IS %d", len(performerTagsFilter.Value))) - } else if performerTagsFilter.Modifier == models.CriterionModifierExcludes { - f.addWhere(fmt.Sprintf(`not exists - (select performers_scenes.performer_id from performers_scenes - left join performers_tags on performers_tags.performer_id = performers_scenes.performer_id where - performers_scenes.scene_id = scenes.id AND - performers_tags.tag_id in %s)`, getInBinding(len(performerTagsFilter.Value))), args...) - } + f.addJoin("performer_tags", "", "performer_tags.scene_id = scenes.id") + + addHierarchicalConditionClauses(f, tags, "performer_tags", "root_tag_id") } } } diff --git a/pkg/sqlite/scene_marker.go b/pkg/sqlite/scene_marker.go index 0ad3cda43..392bfc75d 100644 --- a/pkg/sqlite/scene_marker.go +++ b/pkg/sqlite/scene_marker.go @@ -3,8 +3,6 @@ package sqlite import ( "database/sql" "fmt" - "strconv" - "github.com/stashapp/stash/pkg/database" "github.com/stashapp/stash/pkg/models" ) @@ -127,6 +125,17 @@ func (qb *sceneMarkerQueryBuilder) Wall(q *string) ([]*models.SceneMarker, error return qb.querySceneMarkers(query, nil) } +func (qb *sceneMarkerQueryBuilder) makeFilter(sceneMarkerFilter *models.SceneMarkerFilterType) *filterBuilder { + query := &filterBuilder{} + + query.handleCriterion(sceneMarkerTagIDCriterionHandler(qb, sceneMarkerFilter.TagID)) + query.handleCriterion(sceneMarkerTagsCriterionHandler(qb, sceneMarkerFilter.Tags)) + query.handleCriterion(sceneMarkerSceneTagsCriterionHandler(qb, sceneMarkerFilter.SceneTags)) + query.handleCriterion(sceneMarkerPerformersCriterionHandler(qb, sceneMarkerFilter.Performers)) + + return query +} + func (qb *sceneMarkerQueryBuilder) Query(sceneMarkerFilter *models.SceneMarkerFilterType, findFilter *models.FindFilterType) ([]*models.SceneMarker, int, error) { if sceneMarkerFilter == nil { sceneMarkerFilter = &models.SceneMarkerFilterType{} @@ -135,121 +144,23 @@ func (qb *sceneMarkerQueryBuilder) Query(sceneMarkerFilter *models.SceneMarkerFi findFilter = &models.FindFilterType{} } - var whereClauses []string - var havingClauses []string - var args []interface{} - body := selectDistinctIDs("scene_markers") - body = body + ` - left join tags as primary_tag on primary_tag.id = scene_markers.primary_tag_id - left join scenes as scene on scene.id = scene_markers.scene_id - left join scene_markers_tags as tags_join on tags_join.scene_marker_id = scene_markers.id - left join tags on tags_join.tag_id = tags.id - ` + query := qb.newQuery() - if tagsFilter := sceneMarkerFilter.Tags; tagsFilter != nil && len(tagsFilter.Value) > 0 { - //select `scene_markers`.* from `scene_markers` - //left join `tags` as `primary_tags_join` - // on `primary_tags_join`.`id` = `scene_markers`.`primary_tag_id` - // and `primary_tags_join`.`id` in ('3', '37', '9', '89') - //left join `scene_markers_tags` as `tags_join` - // on `tags_join`.`scene_marker_id` = `scene_markers`.`id` - // and `tags_join`.`tag_id` in ('3', '37', '9', '89') - //group by `scene_markers`.`id` - //having ((count(distinct `primary_tags_join`.`id`) + count(distinct `tags_join`.`tag_id`)) = 4) - - length := len(tagsFilter.Value) - - if tagsFilter.Modifier == models.CriterionModifierIncludes || tagsFilter.Modifier == models.CriterionModifierIncludesAll { - body += " LEFT JOIN tags AS ptj ON ptj.id = scene_markers.primary_tag_id AND ptj.id IN " + getInBinding(length) - body += " LEFT JOIN scene_markers_tags AS tj ON tj.scene_marker_id = scene_markers.id AND tj.tag_id IN " + getInBinding(length) - - // only one required for include any - requiredCount := 1 - - // all required for include all - if tagsFilter.Modifier == models.CriterionModifierIncludesAll { - requiredCount = length - } - - havingClauses = append(havingClauses, "((COUNT(DISTINCT ptj.id) + COUNT(DISTINCT tj.tag_id)) >= "+strconv.Itoa(requiredCount)+")") - } else if tagsFilter.Modifier == models.CriterionModifierExcludes { - // excludes all of the provided ids - whereClauses = append(whereClauses, "scene_markers.primary_tag_id not in "+getInBinding(length)) - whereClauses = append(whereClauses, "not exists (select smt.scene_marker_id from scene_markers_tags as smt where smt.scene_marker_id = scene_markers.id and smt.tag_id in "+getInBinding(length)+")") - } - - for _, tagID := range tagsFilter.Value { - args = append(args, tagID) - } - for _, tagID := range tagsFilter.Value { - args = append(args, tagID) - } - } - - if sceneTagsFilter := sceneMarkerFilter.SceneTags; sceneTagsFilter != nil && len(sceneTagsFilter.Value) > 0 { - length := len(sceneTagsFilter.Value) - - if sceneTagsFilter.Modifier == models.CriterionModifierIncludes || sceneTagsFilter.Modifier == models.CriterionModifierIncludesAll { - body += " LEFT JOIN scenes_tags AS scene_tags_join ON scene_tags_join.scene_id = scene.id AND scene_tags_join.tag_id IN " + getInBinding(length) - - // only one required for include any - requiredCount := 1 - - // all required for include all - if sceneTagsFilter.Modifier == models.CriterionModifierIncludesAll { - requiredCount = length - } - - havingClauses = append(havingClauses, "COUNT(DISTINCT scene_tags_join.tag_id) >= "+strconv.Itoa(requiredCount)) - } else if sceneTagsFilter.Modifier == models.CriterionModifierExcludes { - // excludes all of the provided ids - whereClauses = append(whereClauses, "not exists (select st.scene_id from scenes_tags as st where st.scene_id = scene.id AND st.tag_id IN "+getInBinding(length)+")") - } - - for _, tagID := range sceneTagsFilter.Value { - args = append(args, tagID) - } - } - - if performersFilter := sceneMarkerFilter.Performers; performersFilter != nil && len(performersFilter.Value) > 0 { - length := len(performersFilter.Value) - - if performersFilter.Modifier == models.CriterionModifierIncludes || performersFilter.Modifier == models.CriterionModifierIncludesAll { - body += " LEFT JOIN performers_scenes as scene_performers ON scene.id = scene_performers.scene_id" - whereClauses = append(whereClauses, "scene_performers.performer_id IN "+getInBinding(length)) - - // only one required for include any - requiredCount := 1 - - // all required for include all - if performersFilter.Modifier == models.CriterionModifierIncludesAll { - requiredCount = length - } - - havingClauses = append(havingClauses, "COUNT(DISTINCT scene_performers.performer_id) >= "+strconv.Itoa(requiredCount)) - } else if performersFilter.Modifier == models.CriterionModifierExcludes { - // excludes all of the provided ids - whereClauses = append(whereClauses, "not exists (select sp.scene_id from performers_scenes as sp where sp.scene_id = scene.id AND sp.performer_id IN "+getInBinding(length)+")") - } - - for _, performerID := range performersFilter.Value { - args = append(args, performerID) - } - } + query.body = selectDistinctIDs("scene_markers") if q := findFilter.Q; q != nil && *q != "" { searchColumns := []string{"scene_markers.title", "scene.title"} clause, thisArgs := getSearchBinding(searchColumns, *q, false) - whereClauses = append(whereClauses, clause) - args = append(args, thisArgs...) + query.addWhere(clause) + query.addArg(thisArgs...) } - if tagID := sceneMarkerFilter.TagID; tagID != nil { - whereClauses = append(whereClauses, "(scene_markers.primary_tag_id = "+*tagID+" OR tags.id = "+*tagID+")") - } + filter := qb.makeFilter(sceneMarkerFilter) - sortAndPagination := qb.getSceneMarkerSort(findFilter) + getPagination(findFilter) - idsResult, countResult, err := qb.executeFindQuery(body, args, sortAndPagination, whereClauses, havingClauses, []string{}) + query.addFilter(filter) + + query.sortAndPagination = qb.getSceneMarkerSort(findFilter) + getPagination(findFilter) + idsResult, countResult, err := query.executeFind() if err != nil { return nil, 0, err } @@ -267,6 +178,74 @@ func (qb *sceneMarkerQueryBuilder) Query(sceneMarkerFilter *models.SceneMarkerFi return sceneMarkers, countResult, nil } +func sceneMarkerTagIDCriterionHandler(qb *sceneMarkerQueryBuilder, tagID *string) criterionHandlerFunc { + return func(f *filterBuilder) { + if tagID != nil { + f.addJoin("scene_markers_tags", "", "scene_markers_tags.scene_marker_id = scene_markers.id") + + f.addWhere("(scene_markers.primary_tag_id = ? OR scene_markers_tags.tag_id = ?)", *tagID, *tagID) + } + } +} + +func sceneMarkerTagsCriterionHandler(qb *sceneMarkerQueryBuilder, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { + return func(f *filterBuilder) { + if tags != nil && len(tags.Value) > 0 { + valuesClause := getHierarchicalValues(qb.tx, tags.Value, tagTable, "tags_relations", "", tags.Depth) + + f.addWith(`marker_tags AS ( +SELECT mt.scene_marker_id, t.column1 AS root_tag_id FROM scene_markers_tags mt +INNER JOIN (` + valuesClause + `) t ON t.column2 = mt.tag_id +UNION +SELECT m.id, t.column1 FROM scene_markers m +INNER JOIN (` + valuesClause + `) t ON t.column2 = m.primary_tag_id +)`) + + f.addJoin("marker_tags", "", "marker_tags.scene_marker_id = scene_markers.id") + + addHierarchicalConditionClauses(f, tags, "marker_tags", "root_tag_id") + } + } +} + +func sceneMarkerSceneTagsCriterionHandler(qb *sceneMarkerQueryBuilder, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { + return func(f *filterBuilder) { + if tags != nil && len(tags.Value) > 0 { + valuesClause := getHierarchicalValues(qb.tx, tags.Value, tagTable, "tags_relations", "", tags.Depth) + + f.addWith(`scene_tags AS ( +SELECT st.scene_id, t.column1 AS root_tag_id FROM scenes_tags st +INNER JOIN (` + valuesClause + `) t ON t.column2 = st.tag_id +)`) + + f.addJoin("scene_tags", "", "scene_tags.scene_id = scene_markers.scene_id") + + addHierarchicalConditionClauses(f, tags, "scene_tags", "root_tag_id") + } + } +} + +func sceneMarkerPerformersCriterionHandler(qb *sceneMarkerQueryBuilder, performers *models.MultiCriterionInput) criterionHandlerFunc { + h := joinedMultiCriterionHandlerBuilder{ + primaryTable: sceneTable, + joinTable: performersScenesTable, + joinAs: "performers_join", + primaryFK: sceneIDColumn, + foreignFK: performerIDColumn, + + addJoinTable: func(f *filterBuilder) { + f.addJoin(performersScenesTable, "performers_join", "performers_join.scene_id = scene_markers.scene_id") + }, + } + + handler := h.handler(performers) + return func(f *filterBuilder) { + // Make sure scenes is included, otherwise excludes filter fails + f.addJoin(sceneTable, "", "scenes.id = scene_markers.scene_id") + handler(f) + } +} + func (qb *sceneMarkerQueryBuilder) getSceneMarkerSort(findFilter *models.FindFilterType) string { sort := findFilter.GetSort("title") direction := findFilter.GetDirection() diff --git a/pkg/sqlite/scene_test.go b/pkg/sqlite/scene_test.go index 2be8e4d45..aa895b8b7 100644 --- a/pkg/sqlite/scene_test.go +++ b/pkg/sqlite/scene_test.go @@ -1034,12 +1034,13 @@ func TestSceneQueryPerformers(t *testing.T) { func TestSceneQueryTags(t *testing.T) { withTxn(func(r models.Repository) error { sqb := r.Scene() - tagCriterion := models.MultiCriterionInput{ + tagCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdxWithScene]), strconv.Itoa(tagIDs[tagIdx1WithScene]), }, Modifier: models.CriterionModifierIncludes, + Depth: 0, } sceneFilter := models.SceneFilterType{ @@ -1054,12 +1055,13 @@ func TestSceneQueryTags(t *testing.T) { assert.True(t, scene.ID == sceneIDs[sceneIdxWithTag] || scene.ID == sceneIDs[sceneIdxWithTwoTags]) } - tagCriterion = models.MultiCriterionInput{ + tagCriterion = models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdx1WithScene]), strconv.Itoa(tagIDs[tagIdx2WithScene]), }, Modifier: models.CriterionModifierIncludesAll, + Depth: 0, } scenes = queryScene(t, sqb, &sceneFilter, nil) @@ -1067,11 +1069,12 @@ func TestSceneQueryTags(t *testing.T) { assert.Len(t, scenes, 1) assert.Equal(t, sceneIDs[sceneIdxWithTwoTags], scenes[0].ID) - tagCriterion = models.MultiCriterionInput{ + tagCriterion = models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdx1WithScene]), }, Modifier: models.CriterionModifierExcludes, + Depth: 0, } q := getSceneStringValue(sceneIdxWithTwoTags, titleField) @@ -1089,12 +1092,13 @@ func TestSceneQueryTags(t *testing.T) { func TestSceneQueryPerformerTags(t *testing.T) { withTxn(func(r models.Repository) error { sqb := r.Scene() - tagCriterion := models.MultiCriterionInput{ + tagCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdxWithPerformer]), strconv.Itoa(tagIDs[tagIdx1WithPerformer]), }, Modifier: models.CriterionModifierIncludes, + Depth: 0, } sceneFilter := models.SceneFilterType{ @@ -1109,12 +1113,13 @@ func TestSceneQueryPerformerTags(t *testing.T) { assert.True(t, scene.ID == sceneIDs[sceneIdxWithPerformerTag] || scene.ID == sceneIDs[sceneIdxWithPerformerTwoTags]) } - tagCriterion = models.MultiCriterionInput{ + tagCriterion = models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdx1WithPerformer]), strconv.Itoa(tagIDs[tagIdx2WithPerformer]), }, Modifier: models.CriterionModifierIncludesAll, + Depth: 0, } scenes = queryScene(t, sqb, &sceneFilter, nil) @@ -1122,11 +1127,12 @@ func TestSceneQueryPerformerTags(t *testing.T) { assert.Len(t, scenes, 1) assert.Equal(t, sceneIDs[sceneIdxWithPerformerTwoTags], scenes[0].ID) - tagCriterion = models.MultiCriterionInput{ + tagCriterion = models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdx1WithPerformer]), }, Modifier: models.CriterionModifierExcludes, + Depth: 0, } q := getSceneStringValue(sceneIdxWithPerformerTwoTags, titleField) diff --git a/pkg/sqlite/tag.go b/pkg/sqlite/tag.go index c67a8dea5..28237867c 100644 --- a/pkg/sqlite/tag.go +++ b/pkg/sqlite/tag.go @@ -196,6 +196,28 @@ func (qb *tagQueryBuilder) FindByNames(names []string, nocase bool) ([]*models.T return qb.queryTags(query, args) } +func (qb *tagQueryBuilder) FindByParentTagID(parentID int) ([]*models.Tag, error) { + query := ` + SELECT tags.* FROM tags + INNER JOIN tags_relations ON tags_relations.child_id = tags.id + WHERE tags_relations.parent_id = ? + ` + query += qb.getDefaultTagSort() + args := []interface{}{parentID} + return qb.queryTags(query, args) +} + +func (qb *tagQueryBuilder) FindByChildTagID(parentID int) ([]*models.Tag, error) { + query := ` + SELECT tags.* FROM tags + INNER JOIN tags_relations ON tags_relations.parent_id = tags.id + WHERE tags_relations.child_id = ? + ` + query += qb.getDefaultTagSort() + args := []interface{}{parentID} + return qb.queryTags(query, args) +} + func (qb *tagQueryBuilder) Count() (int, error) { return qb.runCountQuery(qb.buildCountQuery("SELECT tags.id FROM tags"), nil) } @@ -572,3 +594,115 @@ AND NOT EXISTS(SELECT 1 FROM `+table+` o WHERE o.`+idColumn+` = `+table+`.`+idCo return nil } + +func (qb *tagQueryBuilder) UpdateParentTags(tagID int, parentIDs []int) error { + tx := qb.tx + if _, err := tx.Exec("DELETE FROM tags_relations WHERE child_id = ?", tagID); err != nil { + return err + } + + if len(parentIDs) > 0 { + var args []interface{} + var values []string + for _, parentID := range parentIDs { + values = append(values, "(? , ?)") + args = append(args, parentID, tagID) + } + + query := "INSERT INTO tags_relations (parent_id, child_id) VALUES " + strings.Join(values, ", ") + if _, err := tx.Exec(query, args...); err != nil { + return err + } + } + + return nil +} + +func (qb *tagQueryBuilder) UpdateChildTags(tagID int, childIDs []int) error { + tx := qb.tx + if _, err := tx.Exec("DELETE FROM tags_relations WHERE parent_id = ?", tagID); err != nil { + return err + } + + if len(childIDs) > 0 { + var args []interface{} + var values []string + for _, childID := range childIDs { + values = append(values, "(? , ?)") + args = append(args, tagID, childID) + } + + query := "INSERT INTO tags_relations (parent_id, child_id) VALUES " + strings.Join(values, ", ") + if _, err := tx.Exec(query, args...); err != nil { + return err + } + } + + return nil +} + +func (qb *tagQueryBuilder) FindAllAncestors(tagID int, excludeIDs []int) ([]*models.Tag, error) { + inBinding := getInBinding(len(excludeIDs) + 1) + + query := `WITH RECURSIVE +parents AS ( + SELECT t.id AS parent_id, t.id AS child_id FROM tags t WHERE t.id = ? + UNION + SELECT tr.parent_id, tr.child_id FROM tags_relations tr INNER JOIN parents p ON p.parent_id = tr.child_id WHERE tr.parent_id NOT IN` + inBinding + ` +), +children AS ( + SELECT tr.parent_id, tr.child_id FROM tags_relations tr INNER JOIN parents p ON p.parent_id = tr.parent_id WHERE tr.child_id NOT IN` + inBinding + ` + UNION + SELECT tr.parent_id, tr.child_id FROM tags_relations tr INNER JOIN children c ON c.child_id = tr.parent_id WHERE tr.child_id NOT IN` + inBinding + ` +) +SELECT t.* FROM tags t INNER JOIN parents p ON t.id = p.parent_id +UNION +SELECT t.* FROM tags t INNER JOIN children c ON t.id = c.child_id +` + + var ret models.Tags + excludeArgs := []interface{}{tagID} + for _, excludeID := range excludeIDs { + excludeArgs = append(excludeArgs, excludeID) + } + args := []interface{}{tagID} + args = append(args, append(append(excludeArgs, excludeArgs...), excludeArgs...)...) + if err := qb.query(query, args, &ret); err != nil { + return nil, err + } + + return ret, nil +} + +func (qb *tagQueryBuilder) FindAllDescendants(tagID int, excludeIDs []int) ([]*models.Tag, error) { + inBinding := getInBinding(len(excludeIDs) + 1) + + query := `WITH RECURSIVE +children AS ( + SELECT t.id AS parent_id, t.id AS child_id FROM tags t WHERE t.id = ? + UNION + SELECT tr.parent_id, tr.child_id FROM tags_relations tr INNER JOIN children c ON c.child_id = tr.parent_id WHERE tr.child_id NOT IN` + inBinding + ` +), +parents AS ( + SELECT tr.parent_id, tr.child_id FROM tags_relations tr INNER JOIN children c ON c.child_id = tr.child_id WHERE tr.parent_id NOT IN` + inBinding + ` + UNION + SELECT tr.parent_id, tr.child_id FROM tags_relations tr INNER JOIN parents p ON p.parent_id = tr.child_id WHERE tr.parent_id NOT IN` + inBinding + ` +) +SELECT t.* FROM tags t INNER JOIN children c ON t.id = c.child_id +UNION +SELECT t.* FROM tags t INNER JOIN parents p ON t.id = p.parent_id +` + + var ret models.Tags + excludeArgs := []interface{}{tagID} + for _, excludeID := range excludeIDs { + excludeArgs = append(excludeArgs, excludeID) + } + args := []interface{}{tagID} + args = append(args, append(append(excludeArgs, excludeArgs...), excludeArgs...)...) + if err := qb.query(query, args, &ret); err != nil { + return nil, err + } + + return ret, nil +} diff --git a/pkg/tag/export.go b/pkg/tag/export.go index ba9d6da82..54c64990e 100644 --- a/pkg/tag/export.go +++ b/pkg/tag/export.go @@ -32,6 +32,13 @@ func ToJSON(reader models.TagReader, tag *models.Tag) (*jsonschema.Tag, error) { newTagJSON.Image = utils.GetBase64StringFromData(image) } + parents, err := reader.FindByChildTagID(tag.ID) + if err != nil { + return nil, fmt.Errorf("error getting parents: %s", err.Error()) + } + + newTagJSON.Parents = GetNames(parents) + return &newTagJSON, nil } diff --git a/pkg/tag/export_test.go b/pkg/tag/export_test.go index 16e85292d..36b7a0855 100644 --- a/pkg/tag/export_test.go +++ b/pkg/tag/export_test.go @@ -13,10 +13,12 @@ import ( ) const ( - tagID = 1 - noImageID = 2 - errImageID = 3 - errAliasID = 4 + tagID = 1 + noImageID = 2 + errImageID = 3 + errAliasID = 4 + withParentsID = 5 + errParentsID = 6 ) const tagName = "testTag" @@ -37,7 +39,7 @@ func createTag(id int) models.Tag { } } -func createJSONTag(aliases []string, image string) *jsonschema.Tag { +func createJSONTag(aliases []string, image string, parents []string) *jsonschema.Tag { return &jsonschema.Tag{ Name: tagName, Aliases: aliases, @@ -47,7 +49,8 @@ func createJSONTag(aliases []string, image string) *jsonschema.Tag { UpdatedAt: models.JSONTime{ Time: updateTime, }, - Image: image, + Image: image, + Parents: parents, } } @@ -63,12 +66,12 @@ func initTestTable() { scenarios = []testScenario{ { createTag(tagID), - createJSONTag([]string{"alias"}, "PHN2ZwogICB4bWxuczpkYz0iaHR0cDovL3B1cmwub3JnL2RjL2VsZW1lbnRzLzEuMS8iCiAgIHhtbG5zOmNjPSJodHRwOi8vY3JlYXRpdmVjb21tb25zLm9yZy9ucyMiCiAgIHhtbG5zOnJkZj0iaHR0cDovL3d3dy53My5vcmcvMTk5OS8wMi8yMi1yZGYtc3ludGF4LW5zIyIKICAgeG1sbnM6c3ZnPSJodHRwOi8vd3d3LnczLm9yZy8yMDAwL3N2ZyIKICAgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIgogICB4bWxuczpzb2RpcG9kaT0iaHR0cDovL3NvZGlwb2RpLnNvdXJjZWZvcmdlLm5ldC9EVEQvc29kaXBvZGktMC5kdGQiCiAgIHhtbG5zOmlua3NjYXBlPSJodHRwOi8vd3d3Lmlua3NjYXBlLm9yZy9uYW1lc3BhY2VzL2lua3NjYXBlIgogICB3aWR0aD0iMjAwIgogICBoZWlnaHQ9IjIwMCIKICAgaWQ9InN2ZzIiCiAgIHZlcnNpb249IjEuMSIKICAgaW5rc2NhcGU6dmVyc2lvbj0iMC40OC40IHI5OTM5IgogICBzb2RpcG9kaTpkb2NuYW1lPSJ0YWcuc3ZnIj4KICA8ZGVmcwogICAgIGlkPSJkZWZzNCIgLz4KICA8c29kaXBvZGk6bmFtZWR2aWV3CiAgICAgaWQ9ImJhc2UiCiAgICAgcGFnZWNvbG9yPSIjMDAwMDAwIgogICAgIGJvcmRlcmNvbG9yPSIjNjY2NjY2IgogICAgIGJvcmRlcm9wYWNpdHk9IjEuMCIKICAgICBpbmtzY2FwZTpwYWdlb3BhY2l0eT0iMSIKICAgICBpbmtzY2FwZTpwYWdlc2hhZG93PSIyIgogICAgIGlua3NjYXBlOnpvb209IjEiCiAgICAgaW5rc2NhcGU6Y3g9IjE4MS43Nzc3MSIKICAgICBpbmtzY2FwZTpjeT0iMjc5LjcyMzc2IgogICAgIGlua3NjYXBlOmRvY3VtZW50LXVuaXRzPSJweCIKICAgICBpbmtzY2FwZTpjdXJyZW50LWxheWVyPSJsYXllcjEiCiAgICAgc2hvd2dyaWQ9ImZhbHNlIgogICAgIGZpdC1tYXJnaW4tdG9wPSIwIgogICAgIGZpdC1tYXJnaW4tbGVmdD0iMCIKICAgICBmaXQtbWFyZ2luLXJpZ2h0PSIwIgogICAgIGZpdC1tYXJnaW4tYm90dG9tPSIwIgogICAgIGlua3NjYXBlOndpbmRvdy13aWR0aD0iMTkyMCIKICAgICBpbmtzY2FwZTp3aW5kb3ctaGVpZ2h0PSIxMDE3IgogICAgIGlua3NjYXBlOndpbmRvdy14PSItOCIKICAgICBpbmtzY2FwZTp3aW5kb3cteT0iLTgiCiAgICAgaW5rc2NhcGU6d2luZG93LW1heGltaXplZD0iMSIgLz4KICA8bWV0YWRhdGEKICAgICBpZD0ibWV0YWRhdGE3Ij4KICAgIDxyZGY6UkRGPgogICAgICA8Y2M6V29yawogICAgICAgICByZGY6YWJvdXQ9IiI+CiAgICAgICAgPGRjOmZvcm1hdD5pbWFnZS9zdmcreG1sPC9kYzpmb3JtYXQ+CiAgICAgICAgPGRjOnR5cGUKICAgICAgICAgICByZGY6cmVzb3VyY2U9Imh0dHA6Ly9wdXJsLm9yZy9kYy9kY21pdHlwZS9TdGlsbEltYWdlIiAvPgogICAgICAgIDxkYzp0aXRsZT48L2RjOnRpdGxlPgogICAgICA8L2NjOldvcms+CiAgICA8L3JkZjpSREY+CiAgPC9tZXRhZGF0YT4KICA8ZwogICAgIGlua3NjYXBlOmxhYmVsPSJMYXllciAxIgogICAgIGlua3NjYXBlOmdyb3VwbW9kZT0ibGF5ZXIiCiAgICAgaWQ9ImxheWVyMSIKICAgICB0cmFuc2Zvcm09InRyYW5zbGF0ZSgtMTU3Ljg0MzU4LC01MjQuNjk1MjIpIj4KICAgIDxwYXRoCiAgICAgICBpZD0icGF0aDI5ODciCiAgICAgICBkPSJtIDIyOS45NDMxNCw2NjkuMjY1NDkgLTM2LjA4NDY2LC0zNi4wODQ2NiBjIC00LjY4NjUzLC00LjY4NjUzIC00LjY4NjUzLC0xMi4yODQ2OCAwLC0xNi45NzEyMSBsIDM2LjA4NDY2LC0zNi4wODQ2NyBhIDEyLjAwMDQ1MywxMi4wMDA0NTMgMCAwIDEgOC40ODU2LC0zLjUxNDggbCA3NC45MTQ0MywwIGMgNi42Mjc2MSwwIDEyLjAwMDQxLDUuMzcyOCAxMi4wMDA0MSwxMi4wMDA0MSBsIDAsNzIuMTY5MzMgYyAwLDYuNjI3NjEgLTUuMzcyOCwxMi4wMDA0MSAtMTIuMDAwNDEsMTIuMDAwNDEgbCAtNzQuOTE0NDMsMCBhIDEyLjAwMDQ1MywxMi4wMDA0NTMgMCAwIDEgLTguNDg1NiwtMy41MTQ4MSB6IG0gLTEzLjQ1NjM5LC01My4wNTU4NyBjIC00LjY4NjUzLDQuNjg2NTMgLTQuNjg2NTMsMTIuMjg0NjggMCwxNi45NzEyMSA0LjY4NjUyLDQuNjg2NTIgMTIuMjg0NjcsNC42ODY1MiAxNi45NzEyLDAgNC42ODY1MywtNC42ODY1MyA0LjY4NjUzLC0xMi4yODQ2OCAwLC0xNi45NzEyMSAtNC42ODY1MywtNC42ODY1MiAtMTIuMjg0NjgsLTQuNjg2NTIgLTE2Ljk3MTIsMCB6IgogICAgICAgaW5rc2NhcGU6Y29ubmVjdG9yLWN1cnZhdHVyZT0iMCIKICAgICAgIHN0eWxlPSJmaWxsOiNmZmZmZmY7ZmlsbC1vcGFjaXR5OjEiIC8+CiAgPC9nPgo8L3N2Zz4="), + createJSONTag([]string{"alias"}, image, nil), false, }, { createTag(noImageID), - createJSONTag(nil, ""), + createJSONTag(nil, "", nil), false, }, { @@ -81,6 +84,16 @@ func initTestTable() { nil, true, }, + { + createTag(withParentsID), + createJSONTag(nil, image, []string{"parent"}), + false, + }, + { + createTag(errParentsID), + nil, + true, + }, } } @@ -91,15 +104,25 @@ func TestToJSON(t *testing.T) { imageErr := errors.New("error getting image") aliasErr := errors.New("error getting aliases") + parentsErr := errors.New("error getting parents") mockTagReader.On("GetAliases", tagID).Return([]string{"alias"}, nil).Once() mockTagReader.On("GetAliases", noImageID).Return(nil, nil).Once() mockTagReader.On("GetAliases", errImageID).Return(nil, nil).Once() mockTagReader.On("GetAliases", errAliasID).Return(nil, aliasErr).Once() + mockTagReader.On("GetAliases", withParentsID).Return(nil, nil).Once() + mockTagReader.On("GetAliases", errParentsID).Return(nil, nil).Once() - mockTagReader.On("GetImage", tagID).Return(models.DefaultTagImage, nil).Once() + mockTagReader.On("GetImage", tagID).Return(imageBytes, nil).Once() mockTagReader.On("GetImage", noImageID).Return(nil, nil).Once() mockTagReader.On("GetImage", errImageID).Return(nil, imageErr).Once() + mockTagReader.On("GetImage", withParentsID).Return(imageBytes, nil).Once() + mockTagReader.On("GetImage", errParentsID).Return(nil, nil).Once() + + mockTagReader.On("FindByChildTagID", tagID).Return(nil, nil).Once() + mockTagReader.On("FindByChildTagID", noImageID).Return(nil, nil).Once() + mockTagReader.On("FindByChildTagID", withParentsID).Return([]*models.Tag{{Name: "parent"}}, nil).Once() + mockTagReader.On("FindByChildTagID", errParentsID).Return(nil, parentsErr).Once() for i, s := range scenarios { tag := s.tag diff --git a/pkg/tag/import.go b/pkg/tag/import.go index 54de4bd0e..8fe0410d2 100644 --- a/pkg/tag/import.go +++ b/pkg/tag/import.go @@ -8,9 +8,22 @@ import ( "github.com/stashapp/stash/pkg/utils" ) +type ParentTagNotExistError struct { + missingParent string +} + +func (e ParentTagNotExistError) Error() string { + return fmt.Sprintf("parent tag <%s> does not exist", e.missingParent) +} + +func (e ParentTagNotExistError) MissingParent() string { + return e.missingParent +} + type Importer struct { - ReaderWriter models.TagReaderWriter - Input jsonschema.Tag + ReaderWriter models.TagReaderWriter + Input jsonschema.Tag + MissingRefBehaviour models.ImportMissingRefEnum tag models.Tag imageData []byte @@ -45,6 +58,15 @@ func (i *Importer) PostImport(id int) error { return fmt.Errorf("error setting tag aliases: %s", err.Error()) } + parents, err := i.getParents() + if err != nil { + return err + } + + if err := i.ReaderWriter.UpdateParentTags(id, parents); err != nil { + return fmt.Errorf("error setting parents: %s", err.Error()) + } + return nil } @@ -87,3 +109,46 @@ func (i *Importer) Update(id int) error { return nil } + +func (i *Importer) getParents() ([]int, error) { + var parents []int + for _, parent := range i.Input.Parents { + tag, err := i.ReaderWriter.FindByName(parent, false) + if err != nil { + return nil, fmt.Errorf("error finding parent by name: %s", err.Error()) + } + + if tag == nil { + if i.MissingRefBehaviour == models.ImportMissingRefEnumFail { + return nil, ParentTagNotExistError{missingParent: parent} + } + + if i.MissingRefBehaviour == models.ImportMissingRefEnumIgnore { + continue + } + + if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate { + parentID, err := i.createParent(parent) + if err != nil { + return nil, err + } + parents = append(parents, parentID) + } + } else { + parents = append(parents, tag.ID) + } + } + + return parents, nil +} + +func (i *Importer) createParent(name string) (int, error) { + newTag := *models.NewTag(name) + + created, err := i.ReaderWriter.Create(newTag) + if err != nil { + return 0, err + } + + return created.ID, nil +} diff --git a/pkg/tag/import_test.go b/pkg/tag/import_test.go index ea29e47c3..c2f29d8e5 100644 --- a/pkg/tag/import_test.go +++ b/pkg/tag/import_test.go @@ -8,6 +8,7 @@ import ( "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models/mocks" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" ) const image = "aW1hZ2VCeXRlcw==" @@ -64,13 +65,25 @@ func TestImporterPostImport(t *testing.T) { updateTagImageErr := errors.New("UpdateImage error") updateTagAliasErr := errors.New("UpdateAlias error") + updateTagParentsErr := errors.New("UpdateParentTags error") readerWriter.On("UpdateAliases", tagID, i.Input.Aliases).Return(nil).Once() readerWriter.On("UpdateAliases", errAliasID, i.Input.Aliases).Return(updateTagAliasErr).Once() + readerWriter.On("UpdateAliases", withParentsID, i.Input.Aliases).Return(nil).Once() + readerWriter.On("UpdateAliases", errParentsID, i.Input.Aliases).Return(nil).Once() readerWriter.On("UpdateImage", tagID, imageBytes).Return(nil).Once() readerWriter.On("UpdateImage", errAliasID, imageBytes).Return(nil).Once() readerWriter.On("UpdateImage", errImageID, imageBytes).Return(updateTagImageErr).Once() + readerWriter.On("UpdateImage", withParentsID, imageBytes).Return(nil).Once() + readerWriter.On("UpdateImage", errParentsID, imageBytes).Return(nil).Once() + + var parentTags []int + readerWriter.On("UpdateParentTags", tagID, parentTags).Return(nil).Once() + readerWriter.On("UpdateParentTags", withParentsID, []int{100}).Return(nil).Once() + readerWriter.On("UpdateParentTags", errParentsID, []int{100}).Return(updateTagParentsErr).Once() + + readerWriter.On("FindByName", "Parent", false).Return(&models.Tag{ID: 100}, nil) err := i.PostImport(tagID) assert.Nil(t, err) @@ -81,6 +94,106 @@ func TestImporterPostImport(t *testing.T) { err = i.PostImport(errAliasID) assert.NotNil(t, err) + i.Input.Parents = []string{"Parent"} + err = i.PostImport(withParentsID) + assert.Nil(t, err) + + err = i.PostImport(errParentsID) + assert.NotNil(t, err) + + readerWriter.AssertExpectations(t) +} + +func TestImporterPostImportParentMissing(t *testing.T) { + readerWriter := &mocks.TagReaderWriter{} + + i := Importer{ + ReaderWriter: readerWriter, + Input: jsonschema.Tag{}, + imageData: imageBytes, + } + + createID := 1 + createErrorID := 2 + createFindErrorID := 3 + createFoundID := 4 + failID := 5 + failFindErrorID := 6 + failFoundID := 7 + ignoreID := 8 + ignoreFindErrorID := 9 + ignoreFoundID := 10 + + findError := errors.New("failed finding parent") + + var emptyParents []int + + readerWriter.On("UpdateImage", mock.Anything, mock.Anything).Return(nil) + readerWriter.On("UpdateAliases", mock.Anything, mock.Anything).Return(nil) + + readerWriter.On("FindByName", "Create", false).Return(nil, nil).Once() + readerWriter.On("FindByName", "CreateError", false).Return(nil, nil).Once() + readerWriter.On("FindByName", "CreateFindError", false).Return(nil, findError).Once() + readerWriter.On("FindByName", "CreateFound", false).Return(&models.Tag{ID: 101}, nil).Once() + readerWriter.On("FindByName", "Fail", false).Return(nil, nil).Once() + readerWriter.On("FindByName", "FailFindError", false).Return(nil, findError) + readerWriter.On("FindByName", "FailFound", false).Return(&models.Tag{ID: 102}, nil).Once() + readerWriter.On("FindByName", "Ignore", false).Return(nil, nil).Once() + readerWriter.On("FindByName", "IgnoreFindError", false).Return(nil, findError) + readerWriter.On("FindByName", "IgnoreFound", false).Return(&models.Tag{ID: 103}, nil).Once() + + readerWriter.On("UpdateParentTags", createID, []int{100}).Return(nil).Once() + readerWriter.On("UpdateParentTags", createFoundID, []int{101}).Return(nil).Once() + readerWriter.On("UpdateParentTags", failFoundID, []int{102}).Return(nil).Once() + readerWriter.On("UpdateParentTags", ignoreID, emptyParents).Return(nil).Once() + readerWriter.On("UpdateParentTags", ignoreFoundID, []int{103}).Return(nil).Once() + + readerWriter.On("Create", mock.MatchedBy(func(t models.Tag) bool { return t.Name == "Create" })).Return(&models.Tag{ID: 100}, nil).Once() + readerWriter.On("Create", mock.MatchedBy(func(t models.Tag) bool { return t.Name == "CreateError" })).Return(nil, errors.New("failed creating parent")).Once() + + i.MissingRefBehaviour = models.ImportMissingRefEnumCreate + i.Input.Parents = []string{"Create"} + err := i.PostImport(createID) + assert.Nil(t, err) + + i.Input.Parents = []string{"CreateError"} + err = i.PostImport(createErrorID) + assert.NotNil(t, err) + + i.Input.Parents = []string{"CreateFindError"} + err = i.PostImport(createFindErrorID) + assert.NotNil(t, err) + + i.Input.Parents = []string{"CreateFound"} + err = i.PostImport(createFoundID) + assert.Nil(t, err) + + i.MissingRefBehaviour = models.ImportMissingRefEnumFail + i.Input.Parents = []string{"Fail"} + err = i.PostImport(failID) + assert.NotNil(t, err) + + i.Input.Parents = []string{"FailFindError"} + err = i.PostImport(failFindErrorID) + assert.NotNil(t, err) + + i.Input.Parents = []string{"FailFound"} + err = i.PostImport(failFoundID) + assert.Nil(t, err) + + i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore + i.Input.Parents = []string{"Ignore"} + err = i.PostImport(ignoreID) + assert.Nil(t, err) + + i.Input.Parents = []string{"IgnoreFindError"} + err = i.PostImport(ignoreFindErrorID) + assert.NotNil(t, err) + + i.Input.Parents = []string{"IgnoreFound"} + err = i.PostImport(ignoreFoundID) + assert.Nil(t, err) + readerWriter.AssertExpectations(t) } diff --git a/pkg/tag/update.go b/pkg/tag/update.go index 4f5b9b18b..d4a03f8f5 100644 --- a/pkg/tag/update.go +++ b/pkg/tag/update.go @@ -2,7 +2,6 @@ package tag import ( "fmt" - "github.com/stashapp/stash/pkg/models" ) @@ -23,6 +22,20 @@ func (e *NameUsedByAliasError) Error() string { return fmt.Sprintf("name '%s' is used as alias for '%s'", e.Name, e.OtherTag) } +type InvalidTagHierarchyError struct { + Direction string + InvalidTag string + ApplyingTag string +} + +func (e *InvalidTagHierarchyError) Error() string { + if e.InvalidTag == e.ApplyingTag { + return fmt.Sprintf("Cannot apply tag \"%s\" as it already is a %s", e.InvalidTag, e.Direction) + } else { + return fmt.Sprintf("Cannot apply tag \"%s\" as it is linked to \"%s\" which already is a %s", e.ApplyingTag, e.InvalidTag, e.Direction) + } +} + // EnsureTagNameUnique returns an error if the tag name provided // is used as a name or alias of another existing tag. func EnsureTagNameUnique(id int, name string, qb models.TagReader) error { @@ -63,3 +76,150 @@ func EnsureAliasesUnique(id int, aliases []string, qb models.TagReader) error { return nil } + +func EnsureUniqueHierarchy(id int, parentIDs, childIDs []int, qb models.TagReader) error { + allAncestors := make(map[int]*models.Tag) + allDescendants := make(map[int]*models.Tag) + excludeIDs := []int{id} + + validateParent := func(testID, applyingID int) error { + if parentTag, exists := allAncestors[testID]; exists { + applyingTag, err := qb.Find(applyingID) + + if err != nil { + return nil + } + + return &InvalidTagHierarchyError{ + Direction: "parent", + InvalidTag: parentTag.Name, + ApplyingTag: applyingTag.Name, + } + } + + return nil + } + + validateChild := func(testID, applyingID int) error { + if childTag, exists := allDescendants[testID]; exists { + applyingTag, err := qb.Find(applyingID) + + if err != nil { + return nil + } + + return &InvalidTagHierarchyError{ + Direction: "child", + InvalidTag: childTag.Name, + ApplyingTag: applyingTag.Name, + } + } + + return validateParent(testID, applyingID) + } + + if parentIDs == nil { + parentTags, err := qb.FindByChildTagID(id) + if err != nil { + return err + } + + for _, parentTag := range parentTags { + parentIDs = append(parentIDs, parentTag.ID) + } + } + + if childIDs == nil { + childTags, err := qb.FindByParentTagID(id) + if err != nil { + return err + } + + for _, childTag := range childTags { + childIDs = append(childIDs, childTag.ID) + } + } + + for _, parentID := range parentIDs { + parentsAncestors, err := qb.FindAllAncestors(parentID, excludeIDs) + if err != nil { + return err + } + + for _, ancestorTag := range parentsAncestors { + if err := validateParent(ancestorTag.ID, parentID); err != nil { + return err + } + + allAncestors[ancestorTag.ID] = ancestorTag + } + } + + for _, childID := range childIDs { + childsDescendants, err := qb.FindAllDescendants(childID, excludeIDs) + if err != nil { + return err + } + + for _, descendentTag := range childsDescendants { + if err := validateChild(descendentTag.ID, childID); err != nil { + return err + } + + allDescendants[descendentTag.ID] = descendentTag + } + } + + return nil +} + +func MergeHierarchy(destination int, sources []int, qb models.TagReader) ([]int, []int, error) { + var mergedParents, mergedChildren []int + allIds := append([]int{destination}, sources...) + + addTo := func(mergedItems []int, tags []*models.Tag) []int { + Tags: + for _, tag := range tags { + // Ignore tags which are already set + for _, existingItem := range mergedItems { + if tag.ID == existingItem { + continue Tags + } + } + + // Ignore tags which are being merged, as these are rolled up anyway (if A is merged into B any direct link between them can be ignored) + for _, id := range allIds { + if tag.ID == id { + continue Tags + } + } + + mergedItems = append(mergedItems, tag.ID) + } + + return mergedItems + } + + for _, id := range allIds { + parents, err := qb.FindByChildTagID(id) + if err != nil { + return nil, nil, err + } + + mergedParents = addTo(mergedParents, parents) + + children, err := qb.FindByParentTagID(id) + if err != nil { + return nil, nil, err + } + + mergedChildren = addTo(mergedChildren, children) + } + + err := EnsureUniqueHierarchy(destination, mergedParents, mergedChildren, qb) + if err != nil { + return nil, nil, err + } + + return mergedParents, mergedChildren, nil +} diff --git a/pkg/tag/update_test.go b/pkg/tag/update_test.go new file mode 100644 index 000000000..5a9cddf9d --- /dev/null +++ b/pkg/tag/update_test.go @@ -0,0 +1,302 @@ +package tag + +import ( + "fmt" + "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/models/mocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "testing" +) + +var testUniqueHierarchyTags = map[int]*models.Tag{ + 1: { + ID: 1, + Name: "one", + }, + 2: { + ID: 2, + Name: "two", + }, + 3: { + ID: 3, + Name: "three", + }, + 4: { + ID: 4, + Name: "four", + }, +} + +type testUniqueHierarchyCase struct { + id int + parents []*models.Tag + children []*models.Tag + + onFindAllAncestors map[int][]*models.Tag + onFindAllDescendants map[int][]*models.Tag + + expectedError string +} + +var testUniqueHierarchyCases = []testUniqueHierarchyCase{ + { + id: 1, + parents: []*models.Tag{}, + children: []*models.Tag{}, + onFindAllAncestors: map[int][]*models.Tag{ + 1: {}, + }, + onFindAllDescendants: map[int][]*models.Tag{ + 1: {}, + }, + expectedError: "", + }, + { + id: 1, + parents: []*models.Tag{testUniqueHierarchyTags[2]}, + children: []*models.Tag{testUniqueHierarchyTags[3]}, + onFindAllAncestors: map[int][]*models.Tag{ + 2: {testUniqueHierarchyTags[2]}, + }, + onFindAllDescendants: map[int][]*models.Tag{ + 3: {testUniqueHierarchyTags[3]}, + }, + expectedError: "", + }, + { + id: 2, + parents: []*models.Tag{testUniqueHierarchyTags[3]}, + children: make([]*models.Tag, 0), + onFindAllAncestors: map[int][]*models.Tag{ + 3: {testUniqueHierarchyTags[3]}, + }, + onFindAllDescendants: map[int][]*models.Tag{ + 2: {testUniqueHierarchyTags[2]}, + }, + expectedError: "", + }, + { + id: 2, + parents: []*models.Tag{ + testUniqueHierarchyTags[3], + testUniqueHierarchyTags[4], + }, + children: []*models.Tag{}, + onFindAllAncestors: map[int][]*models.Tag{ + 3: {testUniqueHierarchyTags[3], testUniqueHierarchyTags[4]}, + 4: {testUniqueHierarchyTags[4]}, + }, + onFindAllDescendants: map[int][]*models.Tag{ + 2: {testUniqueHierarchyTags[2]}, + }, + expectedError: "Cannot apply tag \"four\" as it already is a parent", + }, + { + id: 2, + parents: []*models.Tag{}, + children: []*models.Tag{testUniqueHierarchyTags[3]}, + onFindAllAncestors: map[int][]*models.Tag{ + 2: {testUniqueHierarchyTags[2]}, + }, + onFindAllDescendants: map[int][]*models.Tag{ + 3: {testUniqueHierarchyTags[3]}, + }, + expectedError: "", + }, + { + id: 2, + parents: []*models.Tag{}, + children: []*models.Tag{ + testUniqueHierarchyTags[3], + testUniqueHierarchyTags[4], + }, + onFindAllAncestors: map[int][]*models.Tag{ + 2: {testUniqueHierarchyTags[2]}, + }, + onFindAllDescendants: map[int][]*models.Tag{ + 3: {testUniqueHierarchyTags[3], testUniqueHierarchyTags[4]}, + 4: {testUniqueHierarchyTags[4]}, + }, + expectedError: "Cannot apply tag \"four\" as it already is a child", + }, + { + id: 1, + parents: []*models.Tag{testUniqueHierarchyTags[2]}, + children: []*models.Tag{testUniqueHierarchyTags[3]}, + onFindAllAncestors: map[int][]*models.Tag{ + 2: {testUniqueHierarchyTags[2], testUniqueHierarchyTags[3]}, + }, + onFindAllDescendants: map[int][]*models.Tag{ + 3: {testUniqueHierarchyTags[3]}, + }, + expectedError: "Cannot apply tag \"three\" as it already is a parent", + }, + { + id: 1, + parents: []*models.Tag{testUniqueHierarchyTags[2]}, + children: []*models.Tag{testUniqueHierarchyTags[3]}, + onFindAllAncestors: map[int][]*models.Tag{ + 2: {testUniqueHierarchyTags[2]}, + }, + onFindAllDescendants: map[int][]*models.Tag{ + 3: {testUniqueHierarchyTags[3], testUniqueHierarchyTags[2]}, + }, + expectedError: "Cannot apply tag \"three\" as it is linked to \"two\" which already is a parent", + }, + { + id: 1, + parents: []*models.Tag{testUniqueHierarchyTags[3]}, + children: []*models.Tag{testUniqueHierarchyTags[3]}, + onFindAllAncestors: map[int][]*models.Tag{ + 3: {testUniqueHierarchyTags[3]}, + }, + onFindAllDescendants: map[int][]*models.Tag{ + 3: {testUniqueHierarchyTags[3]}, + }, + expectedError: "Cannot apply tag \"three\" as it already is a parent", + }, + { + id: 1, + parents: []*models.Tag{ + testUniqueHierarchyTags[2], + }, + children: []*models.Tag{ + testUniqueHierarchyTags[3], + }, + onFindAllAncestors: map[int][]*models.Tag{ + 2: {testUniqueHierarchyTags[2]}, + }, + onFindAllDescendants: map[int][]*models.Tag{ + 3: {testUniqueHierarchyTags[3], testUniqueHierarchyTags[2]}, + }, + expectedError: "Cannot apply tag \"three\" as it is linked to \"two\" which already is a parent", + }, + { + id: 1, + parents: []*models.Tag{testUniqueHierarchyTags[2]}, + children: []*models.Tag{testUniqueHierarchyTags[2]}, + onFindAllAncestors: map[int][]*models.Tag{ + 2: {testUniqueHierarchyTags[2]}, + }, + onFindAllDescendants: map[int][]*models.Tag{ + 2: {testUniqueHierarchyTags[2]}, + }, + expectedError: "Cannot apply tag \"two\" as it already is a parent", + }, + { + id: 2, + parents: []*models.Tag{testUniqueHierarchyTags[1]}, + children: []*models.Tag{testUniqueHierarchyTags[3]}, + onFindAllAncestors: map[int][]*models.Tag{ + 1: {testUniqueHierarchyTags[1]}, + }, + onFindAllDescendants: map[int][]*models.Tag{ + 3: {testUniqueHierarchyTags[3], testUniqueHierarchyTags[1]}, + }, + expectedError: "Cannot apply tag \"three\" as it is linked to \"one\" which already is a parent", + }, +} + +func TestEnsureUniqueHierarchy(t *testing.T) { + for _, tc := range testUniqueHierarchyCases { + testEnsureUniqueHierarchy(t, tc, false, false) + testEnsureUniqueHierarchy(t, tc, true, false) + testEnsureUniqueHierarchy(t, tc, false, true) + testEnsureUniqueHierarchy(t, tc, true, true) + } +} + +func testEnsureUniqueHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents, queryChildren bool) { + mockTagReader := &mocks.TagReaderWriter{} + + var parentIDs, childIDs []int + var find map[int]*models.Tag + find = make(map[int]*models.Tag) + if tc.parents != nil { + parentIDs = make([]int, 0) + for _, parent := range tc.parents { + if parent.ID != tc.id { + find[parent.ID] = parent + parentIDs = append(parentIDs, parent.ID) + } + } + } + + if tc.children != nil { + childIDs = make([]int, 0) + for _, child := range tc.children { + if child.ID != tc.id { + find[child.ID] = child + childIDs = append(childIDs, child.ID) + } + } + } + + if queryParents { + parentIDs = nil + mockTagReader.On("FindByChildTagID", tc.id).Return(tc.parents, nil).Once() + } + + if queryChildren { + childIDs = nil + mockTagReader.On("FindByParentTagID", tc.id).Return(tc.children, nil).Once() + } + + mockTagReader.On("Find", mock.AnythingOfType("int")).Return(func(tagID int) *models.Tag { + for id, tag := range find { + if id == tagID { + return tag + } + } + return nil + }, func(tagID int) error { + return nil + }).Maybe() + + mockTagReader.On("FindAllAncestors", mock.AnythingOfType("int"), []int{tc.id}).Return(func(tagID int, excludeIDs []int) []*models.Tag { + for id, tags := range tc.onFindAllAncestors { + if id == tagID { + return tags + } + } + return nil + }, func(tagID int, excludeIDs []int) error { + for id, _ := range tc.onFindAllAncestors { + if id == tagID { + return nil + } + } + return fmt.Errorf("undefined ancestors for: %d", tagID) + }).Maybe() + + mockTagReader.On("FindAllDescendants", mock.AnythingOfType("int"), []int{tc.id}).Return(func(tagID int, excludeIDs []int) []*models.Tag { + for id, tags := range tc.onFindAllDescendants { + if id == tagID { + return tags + } + } + return nil + }, func(tagID int, excludeIDs []int) error { + for id, _ := range tc.onFindAllDescendants { + if id == tagID { + return nil + } + } + return fmt.Errorf("undefined descendants for: %d", tagID) + }).Maybe() + + res := EnsureUniqueHierarchy(tc.id, parentIDs, childIDs, mockTagReader) + + assert := assert.New(t) + + if tc.expectedError != "" { + if assert.NotNil(res) { + assert.Equal(tc.expectedError, res.Error()) + } + } else { + assert.Nil(res) + } + + mockTagReader.AssertExpectations(t) +} diff --git a/ui/v2.5/src/components/Changelog/versions/v0100.md b/ui/v2.5/src/components/Changelog/versions/v0100.md index 5cb28b991..6fec97614 100644 --- a/ui/v2.5/src/components/Changelog/versions/v0100.md +++ b/ui/v2.5/src/components/Changelog/versions/v0100.md @@ -1,4 +1,6 @@ ### ✨ New Features +* Added support for Tag hierarchies. ([#1519](https://github.com/stashapp/stash/pull/1519)) +* Added native support for Apple Silicon / M1 Macs. ([#1646] https://github.com/stashapp/stash/pull/1646) * Added Movies to Scene bulk edit dialog. ([#1676](https://github.com/stashapp/stash/pull/1676)) * Added Movies tab to Studio and Performer pages. ([#1675](https://github.com/stashapp/stash/pull/1675)) * Support filtering Movies by Performers. ([#1675](https://github.com/stashapp/stash/pull/1675)) diff --git a/ui/v2.5/src/components/List/Filters/HierarchicalLabelValueFilter.tsx b/ui/v2.5/src/components/List/Filters/HierarchicalLabelValueFilter.tsx index 151a813c6..9105beefc 100644 --- a/ui/v2.5/src/components/List/Filters/HierarchicalLabelValueFilter.tsx +++ b/ui/v2.5/src/components/List/Filters/HierarchicalLabelValueFilter.tsx @@ -1,6 +1,6 @@ import React from "react"; import { Form } from "react-bootstrap"; -import { defineMessages, useIntl } from "react-intl"; +import { defineMessages, MessageDescriptor, useIntl } from "react-intl"; import { FilterSelect, ValidTypes } from "../../Shared"; import { Criterion } from "../../../models/list-filter/criteria/criterion"; import { IHierarchicalLabelValue } from "../../../models/list-filter/types"; @@ -49,6 +49,16 @@ export const HierarchicalLabelValueFilter: React.FC @@ -63,7 +73,7 @@ export const HierarchicalLabelValueFilter: React.FC onDepthChanged(criterion.value.depth !== 0 ? 0 : -1)} /> diff --git a/ui/v2.5/src/components/Shared/DeleteEntityDialog.tsx b/ui/v2.5/src/components/Shared/DeleteEntityDialog.tsx index 1ab0432f4..2ba92061c 100644 --- a/ui/v2.5/src/components/Shared/DeleteEntityDialog.tsx +++ b/ui/v2.5/src/components/Shared/DeleteEntityDialog.tsx @@ -20,6 +20,7 @@ interface IDeleteEntityDialogProps { singularEntity: string; pluralEntity: string; destroyMutation: DestroyMutation; + onDeleted?: () => void; } const messages = defineMessages({ @@ -43,6 +44,7 @@ const DeleteEntityDialog: React.FC = ({ singularEntity, pluralEntity, destroyMutation, + onDeleted, }) => { const intl = useIntl(); const Toast = useToast(); @@ -56,6 +58,9 @@ const DeleteEntityDialog: React.FC = ({ setIsDeleting(true); try { await deleteEntities(); + if (onDeleted) { + onDeleted(); + } Toast.success({ content: intl.formatMessage(messages.deleteToast, { count, diff --git a/ui/v2.5/src/components/Studios/StudioDetails/Studio.tsx b/ui/v2.5/src/components/Studios/StudioDetails/Studio.tsx index 52323bc87..c0481ec2c 100644 --- a/ui/v2.5/src/components/Studios/StudioDetails/Studio.tsx +++ b/ui/v2.5/src/components/Studios/StudioDetails/Studio.tsx @@ -283,7 +283,7 @@ export const Studio: React.FC = () => { diff --git a/ui/v2.5/src/components/Tags/TagDetails/Tag.tsx b/ui/v2.5/src/components/Tags/TagDetails/Tag.tsx index 094dc8f91..fe4f47673 100644 --- a/ui/v2.5/src/components/Tags/TagDetails/Tag.tsx +++ b/ui/v2.5/src/components/Tags/TagDetails/Tag.tsx @@ -21,6 +21,7 @@ import { Icon, } from "src/components/Shared"; import { useToast } from "src/hooks"; +import { tagRelationHook } from "src/core/tags"; import { TagScenesPanel } from "./TagScenesPanel"; import { TagMarkersPanel } from "./TagMarkersPanel"; import { TagImagesPanel } from "./TagImagesPanel"; @@ -123,6 +124,10 @@ export const Tag: React.FC = () => { input: Partial ) { try { + const oldRelations = { + parents: tag?.parents ?? [], + children: tag?.children ?? [], + }; if (!isNew) { const result = await updateTag({ variables: { @@ -131,7 +136,12 @@ export const Tag: React.FC = () => { }); if (result.data?.tagUpdate) { setIsEditing(false); - return result.data.tagUpdate.id; + const updated = result.data.tagUpdate; + tagRelationHook(updated, oldRelations, { + parents: updated.parents, + children: updated.children, + }); + return updated.id; } } else { const result = await createTag({ @@ -141,7 +151,12 @@ export const Tag: React.FC = () => { }); if (result.data?.tagCreate?.id) { setIsEditing(false); - return result.data.tagCreate.id; + const created = result.data.tagCreate; + tagRelationHook(created, oldRelations, { + parents: created.parents, + children: created.children, + }); + return created.id; } } } catch (e) { @@ -161,7 +176,15 @@ export const Tag: React.FC = () => { async function onDelete() { try { + const oldRelations = { + parents: tag?.parents ?? [], + children: tag?.children ?? [], + }; await deleteTag(); + tagRelationHook(tag as GQL.TagDataFragment, oldRelations, { + parents: [], + children: [], + }); } catch (e) { Toast.error(e); } diff --git a/ui/v2.5/src/components/Tags/TagDetails/TagDetailsPanel.tsx b/ui/v2.5/src/components/Tags/TagDetails/TagDetailsPanel.tsx index 92d007ebb..c02570b85 100644 --- a/ui/v2.5/src/components/Tags/TagDetails/TagDetailsPanel.tsx +++ b/ui/v2.5/src/components/Tags/TagDetails/TagDetailsPanel.tsx @@ -1,6 +1,7 @@ import React from "react"; import { Badge } from "react-bootstrap"; import { FormattedMessage } from "react-intl"; +import { Link } from "react-router-dom"; import * as GQL from "src/core/generated-graphql"; interface ITagDetails { @@ -29,5 +30,53 @@ export const TagDetailsPanel: React.FC = ({ tag }) => { ); } - return <>{renderAliasesField()}; + function renderParentsField() { + if (!tag.parents?.length) { + return; + } + + return ( +
+
+ +
+
+ {tag.parents.map((p) => ( + + {p.name} + + ))} +
+
+ ); + } + + function renderChildrenField() { + if (!tag.children?.length) { + return; + } + + return ( +
+
+ +
+
+ {tag.children.map((c) => ( + + {c.name} + + ))} +
+
+ ); + } + + return ( + <> + {renderAliasesField()} + {renderParentsField()} + {renderChildrenField()} + + ); }; diff --git a/ui/v2.5/src/components/Tags/TagDetails/TagEditPanel.tsx b/ui/v2.5/src/components/Tags/TagDetails/TagEditPanel.tsx index 9eb5321c0..86b53558b 100644 --- a/ui/v2.5/src/components/Tags/TagDetails/TagEditPanel.tsx +++ b/ui/v2.5/src/components/Tags/TagDetails/TagEditPanel.tsx @@ -2,9 +2,9 @@ import React, { useEffect } from "react"; import { FormattedMessage, useIntl } from "react-intl"; import * as GQL from "src/core/generated-graphql"; import * as yup from "yup"; -import { DetailsEditNavbar } from "src/components/Shared"; +import { DetailsEditNavbar, TagSelect } from "src/components/Shared"; import { Form, Col, Row } from "react-bootstrap"; -import { ImageUtils } from "src/utils"; +import { FormUtils, ImageUtils } from "src/utils"; import { useFormik } from "formik"; import { Prompt, useHistory } from "react-router-dom"; import Mousetrap from "mousetrap"; @@ -51,11 +51,15 @@ export const TagEditPanel: React.FC = ({ }, message: "aliases must be unique", }), + parent_ids: yup.array(yup.string().required()).optional().nullable(), + child_ids: yup.array(yup.string().required()).optional().nullable(), }); const initialValues = { name: tag?.name, aliases: tag?.aliases, + parent_ids: (tag?.parents ?? []).map((t) => t.id), + child_ids: (tag?.children ?? []).map((t) => t.id), }; type InputValues = typeof initialValues; @@ -153,6 +157,60 @@ export const TagEditPanel: React.FC = ({ /> + + + {FormUtils.renderLabel({ + title: intl.formatMessage({ id: "parent_tags" }), + labelProps: { + column: true, + sm: 3, + xl: 12, + }, + })} + + + formik.setFieldValue( + "parent_ids", + items.map((item) => item.id) + ) + } + ids={formik.values.parent_ids} + excludeIds={(tag?.id ? [tag.id] : []).concat( + ...formik.values.child_ids + )} + creatable={false} + /> + + + + + {FormUtils.renderLabel({ + title: intl.formatMessage({ id: "sub_tags" }), + labelProps: { + column: true, + sm: 3, + xl: 12, + }, + })} + + + formik.setFieldValue( + "child_ids", + items.map((item) => item.id) + ) + } + ids={formik.values.child_ids} + excludeIds={(tag?.id ? [tag.id] : []).concat( + ...formik.values.parent_ids + )} + creatable={false} + /> + + = ({ tag }) => { ) { // add the tag if not present if ( - !tagCriterion.value.find((p) => { + !tagCriterion.value.items.find((p) => { return p.id === tag.id; }) ) { - tagCriterion.value.push(tagValue); + tagCriterion.value.items.push(tagValue); } tagCriterion.modifier = GQL.CriterionModifier.IncludesAll; } else { // overwrite tagCriterion = new TagsCriterion(TagsCriterionOption); - tagCriterion.value = [tagValue]; + tagCriterion.value = { + items: [tagValue], + depth: 0, + }; filter.criteria.push(tagCriterion); } diff --git a/ui/v2.5/src/components/Tags/TagList.tsx b/ui/v2.5/src/components/Tags/TagList.tsx index d96320e85..88e804a82 100644 --- a/ui/v2.5/src/components/Tags/TagList.tsx +++ b/ui/v2.5/src/components/Tags/TagList.tsx @@ -24,6 +24,7 @@ import { NavUtils } from "src/utils"; import { Icon, Modal, DeleteEntityDialog } from "src/components/Shared"; import { TagCard } from "./TagCard"; import { ExportDialog } from "../Shared/ExportDialog"; +import { tagRelationHook } from "../../core/tags"; interface ITagList { filterHook?: (filter: ListFilterModel) => ListFilterModel; @@ -138,6 +139,15 @@ export const TagList: React.FC = ({ filterHook }) => { singularEntity={intl.formatMessage({ id: "tag" })} pluralEntity={intl.formatMessage({ id: "tags" })} destroyMutation={useTagsDestroy} + onDeleted={() => { + selectedTags.forEach((t) => + tagRelationHook( + t, + { parents: t.parents ?? [], children: t.children ?? [] }, + { parents: [], children: [] } + ) + ); + }} /> ); @@ -175,7 +185,15 @@ export const TagList: React.FC = ({ filterHook }) => { async function onDelete() { try { + const oldRelations = { + parents: deletingTag?.parents ?? [], + children: deletingTag?.children ?? [], + }; await deleteTag(); + tagRelationHook(deletingTag as GQL.TagDataFragment, oldRelations, { + parents: [], + children: [], + }); Toast.success({ content: intl.formatMessage( { id: "toast.delete_past_tense" }, diff --git a/ui/v2.5/src/core/createClient.ts b/ui/v2.5/src/core/createClient.ts index ce976be77..ec4766f6c 100644 --- a/ui/v2.5/src/core/createClient.ts +++ b/ui/v2.5/src/core/createClient.ts @@ -68,6 +68,17 @@ const typePolicies: TypePolicies = { }, }, }, + + Tag: { + fields: { + parents: { + merge: false, + }, + children: { + merge: false, + }, + }, + }, }; export const getPlatformURL = (ws?: boolean) => { diff --git a/ui/v2.5/src/core/tags.ts b/ui/v2.5/src/core/tags.ts index 0c0afed61..d7365b8ac 100644 --- a/ui/v2.5/src/core/tags.ts +++ b/ui/v2.5/src/core/tags.ts @@ -1,4 +1,6 @@ +import { gql } from "@apollo/client"; import * as GQL from "src/core/generated-graphql"; +import { getClient } from "src/core/StashService"; import { TagsCriterion, TagsCriterionOption, @@ -20,21 +22,83 @@ export const tagFilterHook = (tag: GQL.TagDataFragment) => { ) { // add the tag if not present if ( - !tagCriterion.value.find((p) => { + !tagCriterion.value.items.find((p) => { return p.id === tag.id; }) ) { - tagCriterion.value.push(tagValue); + tagCriterion.value.items.push(tagValue); } tagCriterion.modifier = GQL.CriterionModifier.IncludesAll; } else { // overwrite tagCriterion = new TagsCriterion(TagsCriterionOption); - tagCriterion.value = [tagValue]; + tagCriterion.value = { + items: [tagValue], + depth: 0, + }; filter.criteria.push(tagCriterion); } return filter; }; }; + +interface ITagRelationTuple { + parents: GQL.SlimTagDataFragment[]; + children: GQL.SlimTagDataFragment[]; +} + +export const tagRelationHook = ( + tag: GQL.SlimTagDataFragment | GQL.TagDataFragment, + old: ITagRelationTuple, + updated: ITagRelationTuple +) => { + const { cache } = getClient(); + + const tagRef = cache.writeFragment({ + data: tag, + fragment: gql` + fragment Tag on Tag { + id + } + `, + }); + + function updater( + property: "parents" | "children", + oldTags: GQL.SlimTagDataFragment[], + updatedTags: GQL.SlimTagDataFragment[] + ) { + oldTags.forEach((o) => { + if (!updatedTags.some((u) => u.id === o.id)) { + cache.modify({ + id: cache.identify(o), + fields: { + [property](value, { readField }) { + return value.filter( + (t: GQL.SlimTagDataFragment) => readField("id", t) !== tag.id + ); + }, + }, + }); + } + }); + + updatedTags.forEach((u) => { + if (!oldTags.some((o) => o.id === u.id)) { + cache.modify({ + id: cache.identify(u), + fields: { + [property](value) { + return [...value, tagRef]; + }, + }, + }); + } + }); + } + + updater("children", old.parents, updated.parents); + updater("parents", old.children, updated.children); +}; diff --git a/ui/v2.5/src/locales/de-DE.json b/ui/v2.5/src/locales/de-DE.json index 32bd74623..312206a08 100644 --- a/ui/v2.5/src/locales/de-DE.json +++ b/ui/v2.5/src/locales/de-DE.json @@ -91,7 +91,7 @@ "birthdate": "Geburtsdatum", "bitrate": "Bitrate", "career_length": "Länge der Karriere", - "child_studios": "Tochterstudios", + "subsidiary_studios": "Tochterstudios", "component_tagger": { "config": { "active_instance": "Aktive stash-box Instanz:", @@ -501,7 +501,6 @@ "image_count": "Bilderanzahl", "images": "Bilder", "images-size": "Bildgröße", - "include_child_studios": "Tochterstudios einbeziehen", "instagram": "Instagram", "interactive": "Interaktiv", "isMissing": "Wird vermisst", diff --git a/ui/v2.5/src/locales/en-GB.json b/ui/v2.5/src/locales/en-GB.json index 060032cdc..0251ff1be 100644 --- a/ui/v2.5/src/locales/en-GB.json +++ b/ui/v2.5/src/locales/en-GB.json @@ -95,7 +95,8 @@ "birthdate": "Birthdate", "bitrate": "Bit Rate", "career_length": "Career Length", - "child_studios": "Child Studios", + "subsidiary_studios": "Subsidiary Studios", + "sub_tags": "Sub-Tags", "component_tagger": { "config": { "active_instance": "Active stash-box instance:", @@ -529,7 +530,8 @@ "image_count": "Image Count", "images": "Images", "images-size": "Images size", - "include_child_studios": "Include child studios", + "include_sub_studios": "Include subsidiary studios", + "include_sub_tags": "Include sub-tags", "instagram": "Instagram", "interactive": "Interactive", "isMissing": "Is Missing", @@ -571,6 +573,7 @@ "previous": "Previous" }, "parent_studios": "Parent Studios", + "parent_tags": "Parent Tags", "path": "Path", "performer": "Performer", "performer_count": "Performer Count", diff --git a/ui/v2.5/src/locales/pt-BR.json b/ui/v2.5/src/locales/pt-BR.json index 64ce8e204..75d3af9ea 100644 --- a/ui/v2.5/src/locales/pt-BR.json +++ b/ui/v2.5/src/locales/pt-BR.json @@ -91,7 +91,7 @@ "birthdate": "Data de nascimento", "bitrate": "Taxa de bits", "career_length": "Duração da carreira", - "child_studios": "Estúdios filhos", + "subsidiary_studios": "Estúdios filhos", "component_tagger": { "config": { "active_instance": "Ativar stash-box:", @@ -501,7 +501,6 @@ "image_count": "Contagem de imagem", "images": "Imagens", "images-size": "Tamanho das imagens", - "include_child_studios": "Incluem estúdios filho", "instagram": "Instagram", "interactive": "Interativo", "isMissing": "Está faltando", diff --git a/ui/v2.5/src/locales/zh-CN.json b/ui/v2.5/src/locales/zh-CN.json index 3e12dd8d9..4f39bdc1f 100644 --- a/ui/v2.5/src/locales/zh-CN.json +++ b/ui/v2.5/src/locales/zh-CN.json @@ -95,7 +95,7 @@ "birthdate": "出生日期", "bitrate": "比特率", "career_length": "活跃年代", - "child_studios": "子工作室", + "subsidiary_studios": "子工作室", "component_tagger": { "config": { "active_instance": "目前使用的 Stash-box:", @@ -529,7 +529,6 @@ "image_count": "图片数量", "images": "图片", "images-size": "图片大小", - "include_child_studios": "包含子工作室", "instagram": "Instagram", "interactive": "互动", "isMissing": "缺失", diff --git a/ui/v2.5/src/locales/zh-TW.json b/ui/v2.5/src/locales/zh-TW.json index 07f443c52..999ddbaa3 100644 --- a/ui/v2.5/src/locales/zh-TW.json +++ b/ui/v2.5/src/locales/zh-TW.json @@ -95,7 +95,7 @@ "birthdate": "出生日期", "bitrate": "位元率", "career_length": "活躍年代", - "child_studios": "子工作室", + "subsidiary_studios": "子工作室", "component_tagger": { "config": { "active_instance": "目前使用的 Stash-box:", @@ -614,7 +614,6 @@ "filter": "過濾", "filter_name": "過濾條件名稱", "detail": "詳情", - "include_child_studios": "包含子工作室", "criterion": { "greater_than": "大於", "less_than": "小於", diff --git a/ui/v2.5/src/models/list-filter/criteria/tags.ts b/ui/v2.5/src/models/list-filter/criteria/tags.ts index 04545440e..50646a52a 100644 --- a/ui/v2.5/src/models/list-filter/criteria/tags.ts +++ b/ui/v2.5/src/models/list-filter/criteria/tags.ts @@ -1,6 +1,9 @@ -import { ILabeledIdCriterion, ILabeledIdCriterionOption } from "./criterion"; +import { + IHierarchicalLabeledIdCriterion, + ILabeledIdCriterionOption, +} from "./criterion"; -export class TagsCriterion extends ILabeledIdCriterion {} +export class TagsCriterion extends IHierarchicalLabeledIdCriterion {} export const TagsCriterionOption = new ILabeledIdCriterionOption( "tags", diff --git a/ui/v2.5/src/utils/navigation.ts b/ui/v2.5/src/utils/navigation.ts index f4af524cb..929b76654 100644 --- a/ui/v2.5/src/utils/navigation.ts +++ b/ui/v2.5/src/utils/navigation.ts @@ -143,7 +143,10 @@ const makeTagScenesUrl = (tag: Partial) => { if (!tag.id) return "#"; const filter = new ListFilterModel(GQL.FilterMode.Scenes); const criterion = new TagsCriterion(TagsCriterionOption); - criterion.value = [{ id: tag.id, label: tag.name || `Tag ${tag.id}` }]; + criterion.value = { + items: [{ id: tag.id, label: tag.name || `Tag ${tag.id}` }], + depth: 0, + }; filter.criteria.push(criterion); return `/scenes?${filter.makeQueryParameters()}`; }; @@ -152,7 +155,10 @@ const makeTagPerformersUrl = (tag: Partial) => { if (!tag.id) return "#"; const filter = new ListFilterModel(GQL.FilterMode.Performers); const criterion = new TagsCriterion(TagsCriterionOption); - criterion.value = [{ id: tag.id, label: tag.name || `Tag ${tag.id}` }]; + criterion.value = { + items: [{ id: tag.id, label: tag.name || `Tag ${tag.id}` }], + depth: 0, + }; filter.criteria.push(criterion); return `/performers?${filter.makeQueryParameters()}`; }; @@ -161,7 +167,10 @@ const makeTagSceneMarkersUrl = (tag: Partial) => { if (!tag.id) return "#"; const filter = new ListFilterModel(GQL.FilterMode.SceneMarkers); const criterion = new TagsCriterion(TagsCriterionOption); - criterion.value = [{ id: tag.id, label: tag.name || `Tag ${tag.id}` }]; + criterion.value = { + items: [{ id: tag.id, label: tag.name || `Tag ${tag.id}` }], + depth: 0, + }; filter.criteria.push(criterion); return `/scenes/markers?${filter.makeQueryParameters()}`; }; @@ -170,7 +179,10 @@ const makeTagGalleriesUrl = (tag: Partial) => { if (!tag.id) return "#"; const filter = new ListFilterModel(GQL.FilterMode.Galleries); const criterion = new TagsCriterion(TagsCriterionOption); - criterion.value = [{ id: tag.id, label: tag.name || `Tag ${tag.id}` }]; + criterion.value = { + items: [{ id: tag.id, label: tag.name || `Tag ${tag.id}` }], + depth: 0, + }; filter.criteria.push(criterion); return `/galleries?${filter.makeQueryParameters()}`; }; @@ -179,7 +191,10 @@ const makeTagImagesUrl = (tag: Partial) => { if (!tag.id) return "#"; const filter = new ListFilterModel(GQL.FilterMode.Images); const criterion = new TagsCriterion(TagsCriterionOption); - criterion.value = [{ id: tag.id, label: tag.name || `Tag ${tag.id}` }]; + criterion.value = { + items: [{ id: tag.id, label: tag.name || `Tag ${tag.id}` }], + depth: 0, + }; filter.criteria.push(criterion); return `/images?${filter.makeQueryParameters()}`; };