Fix Performer Studio filtering (#1483)

* Fix performer studio filtering

* Fix studio filter hook
This commit is contained in:
WithoutPants
2021-06-06 15:05:05 +10:00
committed by GitHub
parent 732cc57149
commit c53799c25b
6 changed files with 115 additions and 52 deletions

View File

@@ -86,7 +86,7 @@ input PerformerFilterType {
"""Filter by death year""" """Filter by death year"""
death_year: IntCriterionInput death_year: IntCriterionInput
"""Filter by studios where performer appears in scene/image/gallery""" """Filter by studios where performer appears in scene/image/gallery"""
studios: MultiCriterionInput studios: HierarchicalMultiCriterionInput
} }
input SceneMarkerFilterType { input SceneMarkerFilterType {

View File

@@ -7,6 +7,7 @@ import (
"strings" "strings"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/utils"
) )
type sqlClause struct { type sqlClause struct {
@@ -540,41 +541,47 @@ type hierarchicalMultiCriterionHandlerBuilder struct {
parentFK string parentFK string
} }
func addHierarchicalWithClause(f *filterBuilder, value []string, derivedTable, table, parentFK string, depth int) {
var args []interface{}
for _, value := range value {
args = append(args, value)
}
inCount := len(args)
var depthCondition string
if depth != -1 {
depthCondition = fmt.Sprintf("WHERE depth < %d", depth)
}
withClause := utils.StrFormat(`RECURSIVE {derivedTable} AS (
SELECT id as id, id as child_id, 0 as depth FROM {table}
WHERE id in {inBinding}
UNION SELECT p.id, c.id, depth + 1 FROM {table} as c
INNER JOIN {derivedTable} as p ON c.{parentFK} = p.child_id {depthCondition})
`, utils.StrFormatMap{
"derivedTable": derivedTable,
"table": table,
"inBinding": getInBinding(inCount),
"parentFK": parentFK,
"depthCondition": depthCondition,
})
f.addWith(withClause, args...)
}
func (m *hierarchicalMultiCriterionHandlerBuilder) handler(criterion *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { func (m *hierarchicalMultiCriterionHandlerBuilder) handler(criterion *models.HierarchicalMultiCriterionInput) criterionHandlerFunc {
return func(f *filterBuilder) { return func(f *filterBuilder) {
if criterion != nil && len(criterion.Value) > 0 { if criterion != nil && len(criterion.Value) > 0 {
var args []interface{} addHierarchicalWithClause(f, criterion.Value, m.derivedTable, m.foreignTable, m.parentFK, criterion.Depth)
for _, value := range criterion.Value {
args = append(args, value)
}
inCount := len(args)
f.addJoin(m.derivedTable, "", fmt.Sprintf("%s.child_id = %s.%s", m.derivedTable, m.primaryTable, m.foreignFK)) f.addJoin(m.derivedTable, "", fmt.Sprintf("%s.child_id = %s.%s", m.derivedTable, m.primaryTable, m.foreignFK))
var depthCondition string
if criterion.Depth != -1 {
depthCondition = "WHERE depth < ?"
args = append(args, criterion.Depth)
}
withClause := fmt.Sprintf(
"RECURSIVE %s AS (SELECT id as id, id as child_id, 0 as depth FROM %s WHERE id in %s UNION SELECT p.id, c.id, depth + 1 FROM %s as c INNER JOIN %s as p ON c.%s = p.child_id %s)",
m.derivedTable,
m.foreignTable,
getInBinding(inCount),
m.foreignTable,
m.derivedTable,
m.parentFK,
depthCondition,
)
f.addWith(withClause, args...)
if criterion.Modifier == models.CriterionModifierIncludes { if criterion.Modifier == models.CriterionModifierIncludes {
f.addWhere(fmt.Sprintf("%s.id IS NOT NULL", m.derivedTable)) f.addWhere(fmt.Sprintf("%s.id IS NOT NULL", m.derivedTable))
} else if criterion.Modifier == models.CriterionModifierIncludesAll { } else if criterion.Modifier == models.CriterionModifierIncludesAll {
f.addWhere(fmt.Sprintf("%s.id IS NOT NULL", m.derivedTable)) f.addWhere(fmt.Sprintf("%s.id IS NOT NULL", m.derivedTable))
f.addHaving(fmt.Sprintf("count(distinct %s.id) IS %d", m.derivedTable, inCount)) f.addHaving(fmt.Sprintf("count(distinct %s.id) IS %d", m.derivedTable, len(criterion.Value)))
} else if criterion.Modifier == models.CriterionModifierExcludes { } else if criterion.Modifier == models.CriterionModifierExcludes {
f.addWhere(fmt.Sprintf("%s.id IS NULL", m.derivedTable)) f.addWhere(fmt.Sprintf("%s.id IS NULL", m.derivedTable))
} }

View File

@@ -7,6 +7,7 @@ import (
"strings" "strings"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/utils"
) )
const performerTable = "performers" const performerTable = "performers"
@@ -454,7 +455,7 @@ func performerGalleryCountCriterionHandler(qb *performerQueryBuilder, count *mod
return h.handler(count) return h.handler(count)
} }
func performerStudiosCriterionHandler(studios *models.MultiCriterionInput) criterionHandlerFunc { func performerStudiosCriterionHandler(studios *models.HierarchicalMultiCriterionInput) criterionHandlerFunc {
return func(f *filterBuilder) { return func(f *filterBuilder) {
if studios != nil { if studios != nil {
var countCondition string var countCondition string
@@ -465,32 +466,72 @@ func performerStudiosCriterionHandler(studios *models.MultiCriterionInput) crite
countCondition = " > 0" countCondition = " > 0"
clauseJoin = " OR " clauseJoin = " OR "
} else if studios.Modifier == models.CriterionModifierExcludes { } else if studios.Modifier == models.CriterionModifierExcludes {
// exclude performers who appear in scenes/images/galleries with any of the given studios // exclude performers who appear in scenes/images/galleries with any of the given studios
countCondition = " = 0" countCondition = " = 0"
clauseJoin = " AND " clauseJoin = " AND "
} else { } else {
return return
} }
templStr := "(SELECT COUNT(DISTINCT %[1]s.id) FROM %[1]s LEFT JOIN %[2]s ON %[1]s.id = %[2]s.%[3]s WHERE %[2]s.performer_id = performers.id AND %[1]s.studio_id IN %[4]s)" + countCondition
inBinding := getInBinding(len(studios.Value)) inBinding := getInBinding(len(studios.Value))
clauses := []string{ formatMaps := []utils.StrFormatMap{
fmt.Sprintf(templStr, sceneTable, performersScenesTable, sceneIDColumn, inBinding), {
fmt.Sprintf(templStr, imageTable, performersImagesTable, imageIDColumn, inBinding), "primaryTable": sceneTable,
fmt.Sprintf(templStr, galleryTable, performersGalleriesTable, galleryIDColumn, inBinding), "joinTable": performersScenesTable,
"primaryFK": sceneIDColumn,
"inBinding": inBinding,
},
{
"primaryTable": imageTable,
"joinTable": performersImagesTable,
"primaryFK": imageIDColumn,
"inBinding": inBinding,
},
{
"primaryTable": galleryTable,
"joinTable": performersGalleriesTable,
"primaryFK": galleryIDColumn,
"inBinding": inBinding,
},
} }
var args []interface{} // keeping existing behaviour for depth = 0 since it seems slightly better performing
for _, tagID := range studios.Value { if studios.Depth == 0 {
args = append(args, tagID) templStr := `(SELECT COUNT(DISTINCT {primaryTable}.id) FROM {primaryTable}
LEFT JOIN {joinTable} ON {primaryTable}.id = {joinTable}.{primaryFK}
WHERE {joinTable}.performer_id = performers.id AND {primaryTable}.studio_id IN {inBinding})` + countCondition
var clauses []string
for _, c := range formatMaps {
clauses = append(clauses, utils.StrFormat(templStr, c))
}
var args []interface{}
for _, tagID := range studios.Value {
args = append(args, tagID)
}
// this is a bit gross. We need the args three times
combinedArgs := append(args, append(args, args...)...)
f.addWhere(fmt.Sprintf("(%s)", strings.Join(clauses, clauseJoin)), combinedArgs...)
} else {
const derivedTable = "studio"
addHierarchicalWithClause(f, studios.Value, derivedTable, studioTable, "parent_id", studios.Depth)
templStr := `(SELECT COUNT(DISTINCT {primaryTable}.id) FROM {primaryTable}
LEFT JOIN {joinTable} ON {primaryTable}.id = {joinTable}.{primaryFK}
INNER JOIN studio ON {primaryTable}.studio_id = studio.child_id
WHERE {joinTable}.performer_id = performers.id)` + countCondition
var clauses []string
for _, c := range formatMaps {
clauses = append(clauses, utils.StrFormat(templStr, c))
}
f.addWhere(fmt.Sprintf("(%s)", strings.Join(clauses, clauseJoin)))
} }
// this is a bit gross. We need the args three times
combinedArgs := append(args, append(args, args...)...)
f.addWhere(fmt.Sprintf("(%s)", strings.Join(clauses, clauseJoin)), combinedArgs...)
} }
} }
} }

View File

@@ -746,7 +746,7 @@ func TestPerformerQueryStudio(t *testing.T) {
sqb := r.Performer() sqb := r.Performer()
for _, tc := range testCases { for _, tc := range testCases {
studioCriterion := models.MultiCriterionInput{ studioCriterion := models.HierarchicalMultiCriterionInput{
Value: []string{ Value: []string{
strconv.Itoa(studioIDs[tc.studioIndex]), strconv.Itoa(studioIDs[tc.studioIndex]),
}, },
@@ -764,7 +764,7 @@ func TestPerformerQueryStudio(t *testing.T) {
// ensure id is correct // ensure id is correct
assert.Equal(t, performerIDs[tc.performerIndex], performers[0].ID) assert.Equal(t, performerIDs[tc.performerIndex], performers[0].ID)
studioCriterion = models.MultiCriterionInput{ studioCriterion = models.HierarchicalMultiCriterionInput{
Value: []string{ Value: []string{
strconv.Itoa(studioIDs[tc.studioIndex]), strconv.Itoa(studioIDs[tc.studioIndex]),
}, },

View File

@@ -1,7 +1,9 @@
package utils package utils
import ( import (
"fmt"
"math/rand" "math/rand"
"strings"
"time" "time"
) )
@@ -15,3 +17,18 @@ func RandomSequence(n int) string {
} }
return string(b) return string(b)
} }
type StrFormatMap map[string]interface{}
func StrFormat(format string, m StrFormatMap) string {
args := make([]string, len(m)*2)
i := 0
for k, v := range m {
args[i] = fmt.Sprintf("{%s}", k)
args[i+1] = fmt.Sprint(v)
i += 2
}
return strings.NewReplacer(args...).Replace(format)
}

View File

@@ -15,16 +15,14 @@ export const studioFilterHook = (studio: Partial<GQL.StudioDataFragment>) => {
(studioCriterion.modifier === GQL.CriterionModifier.IncludesAll || (studioCriterion.modifier === GQL.CriterionModifier.IncludesAll ||
studioCriterion.modifier === GQL.CriterionModifier.Includes) studioCriterion.modifier === GQL.CriterionModifier.Includes)
) { ) {
// add the studio if not present // we should be showing studio only. Remove other values
if ( studioCriterion.value.items = studioCriterion.value.items.filter(
!studioCriterion.value.items.find((p) => { (v) => v.id === studio.id
return p.id === studio.id; );
})
) { if (studioCriterion.value.items.length === 0) {
studioCriterion.value.items.push(studioValue); studioCriterion.value.items.push(studioValue);
} }
studioCriterion.modifier = GQL.CriterionModifier.IncludesAll;
} else { } else {
// overwrite // overwrite
studioCriterion = new StudiosCriterion(); studioCriterion = new StudiosCriterion();