From c53799c25bf2833eec27fd2689452bb24df1589d Mon Sep 17 00:00:00 2001 From: WithoutPants <53250216+WithoutPants@users.noreply.github.com> Date: Sun, 6 Jun 2021 15:05:05 +1000 Subject: [PATCH] Fix Performer Studio filtering (#1483) * Fix performer studio filtering * Fix studio filter hook --- graphql/schema/types/filters.graphql | 2 +- pkg/sqlite/filter.go | 57 ++++++++++++---------- pkg/sqlite/performer.go | 73 ++++++++++++++++++++++------ pkg/sqlite/performer_test.go | 4 +- pkg/utils/strings.go | 17 +++++++ ui/v2.5/src/core/studios.ts | 14 +++--- 6 files changed, 115 insertions(+), 52 deletions(-) diff --git a/graphql/schema/types/filters.graphql b/graphql/schema/types/filters.graphql index 10d614008..d6466ec83 100644 --- a/graphql/schema/types/filters.graphql +++ b/graphql/schema/types/filters.graphql @@ -86,7 +86,7 @@ input PerformerFilterType { """Filter by death year""" death_year: IntCriterionInput """Filter by studios where performer appears in scene/image/gallery""" - studios: MultiCriterionInput + studios: HierarchicalMultiCriterionInput } input SceneMarkerFilterType { diff --git a/pkg/sqlite/filter.go b/pkg/sqlite/filter.go index 5c28d9146..ebe10fd8d 100644 --- a/pkg/sqlite/filter.go +++ b/pkg/sqlite/filter.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/utils" ) type sqlClause struct { @@ -540,41 +541,47 @@ type hierarchicalMultiCriterionHandlerBuilder struct { parentFK string } +func addHierarchicalWithClause(f *filterBuilder, value []string, derivedTable, table, parentFK string, depth int) { + var args []interface{} + + for _, value := range value { + args = append(args, value) + } + inCount := len(args) + + var depthCondition string + if depth != -1 { + depthCondition = fmt.Sprintf("WHERE depth < %d", depth) + } + + withClause := utils.StrFormat(`RECURSIVE {derivedTable} AS ( +SELECT id as id, id as child_id, 0 as depth FROM {table} +WHERE id in {inBinding} +UNION SELECT p.id, c.id, depth + 1 FROM {table} as c +INNER JOIN {derivedTable} as p ON c.{parentFK} = p.child_id {depthCondition}) +`, utils.StrFormatMap{ + "derivedTable": derivedTable, + "table": table, + "inBinding": getInBinding(inCount), + "parentFK": parentFK, + "depthCondition": depthCondition, + }) + + f.addWith(withClause, args...) +} + func (m *hierarchicalMultiCriterionHandlerBuilder) handler(criterion *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { return func(f *filterBuilder) { if criterion != nil && len(criterion.Value) > 0 { - var args []interface{} - for _, value := range criterion.Value { - args = append(args, value) - } - inCount := len(args) + addHierarchicalWithClause(f, criterion.Value, m.derivedTable, m.foreignTable, m.parentFK, criterion.Depth) f.addJoin(m.derivedTable, "", fmt.Sprintf("%s.child_id = %s.%s", m.derivedTable, m.primaryTable, m.foreignFK)) - var depthCondition string - if criterion.Depth != -1 { - depthCondition = "WHERE depth < ?" - args = append(args, criterion.Depth) - } - - withClause := fmt.Sprintf( - "RECURSIVE %s AS (SELECT id as id, id as child_id, 0 as depth FROM %s WHERE id in %s UNION SELECT p.id, c.id, depth + 1 FROM %s as c INNER JOIN %s as p ON c.%s = p.child_id %s)", - m.derivedTable, - m.foreignTable, - getInBinding(inCount), - m.foreignTable, - m.derivedTable, - m.parentFK, - depthCondition, - ) - - f.addWith(withClause, args...) - if criterion.Modifier == models.CriterionModifierIncludes { f.addWhere(fmt.Sprintf("%s.id IS NOT NULL", m.derivedTable)) } else if criterion.Modifier == models.CriterionModifierIncludesAll { f.addWhere(fmt.Sprintf("%s.id IS NOT NULL", m.derivedTable)) - f.addHaving(fmt.Sprintf("count(distinct %s.id) IS %d", m.derivedTable, inCount)) + f.addHaving(fmt.Sprintf("count(distinct %s.id) IS %d", m.derivedTable, len(criterion.Value))) } else if criterion.Modifier == models.CriterionModifierExcludes { f.addWhere(fmt.Sprintf("%s.id IS NULL", m.derivedTable)) } diff --git a/pkg/sqlite/performer.go b/pkg/sqlite/performer.go index f15a120e9..2b6c8d4dd 100644 --- a/pkg/sqlite/performer.go +++ b/pkg/sqlite/performer.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/utils" ) const performerTable = "performers" @@ -454,7 +455,7 @@ func performerGalleryCountCriterionHandler(qb *performerQueryBuilder, count *mod return h.handler(count) } -func performerStudiosCriterionHandler(studios *models.MultiCriterionInput) criterionHandlerFunc { +func performerStudiosCriterionHandler(studios *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { return func(f *filterBuilder) { if studios != nil { var countCondition string @@ -465,32 +466,72 @@ func performerStudiosCriterionHandler(studios *models.MultiCriterionInput) crite countCondition = " > 0" clauseJoin = " OR " } else if studios.Modifier == models.CriterionModifierExcludes { - // exclude performers who appear in scenes/images/galleries with any of the given studios + // exclude performers who appear in scenes/images/galleries with any of the given studios countCondition = " = 0" clauseJoin = " AND " } else { return } - templStr := "(SELECT COUNT(DISTINCT %[1]s.id) FROM %[1]s LEFT JOIN %[2]s ON %[1]s.id = %[2]s.%[3]s WHERE %[2]s.performer_id = performers.id AND %[1]s.studio_id IN %[4]s)" + countCondition - inBinding := getInBinding(len(studios.Value)) - clauses := []string{ - fmt.Sprintf(templStr, sceneTable, performersScenesTable, sceneIDColumn, inBinding), - fmt.Sprintf(templStr, imageTable, performersImagesTable, imageIDColumn, inBinding), - fmt.Sprintf(templStr, galleryTable, performersGalleriesTable, galleryIDColumn, inBinding), + formatMaps := []utils.StrFormatMap{ + { + "primaryTable": sceneTable, + "joinTable": performersScenesTable, + "primaryFK": sceneIDColumn, + "inBinding": inBinding, + }, + { + "primaryTable": imageTable, + "joinTable": performersImagesTable, + "primaryFK": imageIDColumn, + "inBinding": inBinding, + }, + { + "primaryTable": galleryTable, + "joinTable": performersGalleriesTable, + "primaryFK": galleryIDColumn, + "inBinding": inBinding, + }, } - var args []interface{} - for _, tagID := range studios.Value { - args = append(args, tagID) + // keeping existing behaviour for depth = 0 since it seems slightly better performing + if studios.Depth == 0 { + templStr := `(SELECT COUNT(DISTINCT {primaryTable}.id) FROM {primaryTable} +LEFT JOIN {joinTable} ON {primaryTable}.id = {joinTable}.{primaryFK} +WHERE {joinTable}.performer_id = performers.id AND {primaryTable}.studio_id IN {inBinding})` + countCondition + + var clauses []string + for _, c := range formatMaps { + clauses = append(clauses, utils.StrFormat(templStr, c)) + } + + var args []interface{} + for _, tagID := range studios.Value { + args = append(args, tagID) + } + + // this is a bit gross. We need the args three times + combinedArgs := append(args, append(args, args...)...) + + f.addWhere(fmt.Sprintf("(%s)", strings.Join(clauses, clauseJoin)), combinedArgs...) + } else { + const derivedTable = "studio" + addHierarchicalWithClause(f, studios.Value, derivedTable, studioTable, "parent_id", studios.Depth) + + templStr := `(SELECT COUNT(DISTINCT {primaryTable}.id) FROM {primaryTable} + LEFT JOIN {joinTable} ON {primaryTable}.id = {joinTable}.{primaryFK} + INNER JOIN studio ON {primaryTable}.studio_id = studio.child_id + WHERE {joinTable}.performer_id = performers.id)` + countCondition + + var clauses []string + for _, c := range formatMaps { + clauses = append(clauses, utils.StrFormat(templStr, c)) + } + + f.addWhere(fmt.Sprintf("(%s)", strings.Join(clauses, clauseJoin))) } - - // this is a bit gross. We need the args three times - combinedArgs := append(args, append(args, args...)...) - - f.addWhere(fmt.Sprintf("(%s)", strings.Join(clauses, clauseJoin)), combinedArgs...) } } } diff --git a/pkg/sqlite/performer_test.go b/pkg/sqlite/performer_test.go index 128bcdfab..c5292f384 100644 --- a/pkg/sqlite/performer_test.go +++ b/pkg/sqlite/performer_test.go @@ -746,7 +746,7 @@ func TestPerformerQueryStudio(t *testing.T) { sqb := r.Performer() for _, tc := range testCases { - studioCriterion := models.MultiCriterionInput{ + studioCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(studioIDs[tc.studioIndex]), }, @@ -764,7 +764,7 @@ func TestPerformerQueryStudio(t *testing.T) { // ensure id is correct assert.Equal(t, performerIDs[tc.performerIndex], performers[0].ID) - studioCriterion = models.MultiCriterionInput{ + studioCriterion = models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(studioIDs[tc.studioIndex]), }, diff --git a/pkg/utils/strings.go b/pkg/utils/strings.go index f0b8d93d4..ba6a655fa 100644 --- a/pkg/utils/strings.go +++ b/pkg/utils/strings.go @@ -1,7 +1,9 @@ package utils import ( + "fmt" "math/rand" + "strings" "time" ) @@ -15,3 +17,18 @@ func RandomSequence(n int) string { } return string(b) } + +type StrFormatMap map[string]interface{} + +func StrFormat(format string, m StrFormatMap) string { + args := make([]string, len(m)*2) + i := 0 + + for k, v := range m { + args[i] = fmt.Sprintf("{%s}", k) + args[i+1] = fmt.Sprint(v) + i += 2 + } + + return strings.NewReplacer(args...).Replace(format) +} diff --git a/ui/v2.5/src/core/studios.ts b/ui/v2.5/src/core/studios.ts index 29a66dc4e..48d7b1b2a 100644 --- a/ui/v2.5/src/core/studios.ts +++ b/ui/v2.5/src/core/studios.ts @@ -15,16 +15,14 @@ export const studioFilterHook = (studio: Partial) => { (studioCriterion.modifier === GQL.CriterionModifier.IncludesAll || studioCriterion.modifier === GQL.CriterionModifier.Includes) ) { - // add the studio if not present - if ( - !studioCriterion.value.items.find((p) => { - return p.id === studio.id; - }) - ) { + // we should be showing studio only. Remove other values + studioCriterion.value.items = studioCriterion.value.items.filter( + (v) => v.id === studio.id + ); + + if (studioCriterion.value.items.length === 0) { studioCriterion.value.items.push(studioValue); } - - studioCriterion.modifier = GQL.CriterionModifier.IncludesAll; } else { // overwrite studioCriterion = new StudiosCriterion();