Containing Group/Sub-Group relationships (#5105)

* Add UI support for setting containing groups
* Show containing groups in group details panel
* Move tag hierarchical filter code into separate type
* Add depth to scene_count and add sub_group_count
* Add sub-groups tab to groups page
* Add containing groups to edit groups dialog
* Show containing group description in sub-group view
* Show group scene number in group scenes view
* Add ability to drag move grid cards
* Add sub group order option
* Add reorder sub-groups interface
* Separate page size selector component
* Add interfaces to add and remove sub-groups to a group
* Separate MultiSet components
* Allow setting description while setting containing groups
This commit is contained in:
WithoutPants
2024-08-30 11:43:44 +10:00
committed by GitHub
parent 96fdd94a01
commit bcf0fda7ac
99 changed files with 5388 additions and 935 deletions

View File

@@ -30,7 +30,7 @@ const (
dbConnTimeout = 30
)
var appSchemaVersion uint = 66
var appSchemaVersion uint = 67
//go:embed migrations/*.sql
var migrationsBox embed.FS

View File

@@ -0,0 +1,222 @@
package sqlite
import (
"context"
"fmt"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/utils"
)
// hierarchicalRelationshipHandler provides handlers for parent, children, parent count, and child count criteria.
type hierarchicalRelationshipHandler struct {
primaryTable string
relationTable string
aliasPrefix string
parentIDCol string
childIDCol string
}
func (h hierarchicalRelationshipHandler) validateModifier(m models.CriterionModifier) error {
switch m {
case models.CriterionModifierIncludesAll, models.CriterionModifierIncludes, models.CriterionModifierExcludes, models.CriterionModifierIsNull, models.CriterionModifierNotNull:
// valid
return nil
default:
return fmt.Errorf("invalid modifier %s", m)
}
}
func (h hierarchicalRelationshipHandler) handleNullNotNull(f *filterBuilder, m models.CriterionModifier, isParents bool) {
var notClause string
if m == models.CriterionModifierNotNull {
notClause = "NOT"
}
as := h.aliasPrefix + "_parents"
col := h.childIDCol
if !isParents {
as = h.aliasPrefix + "_children"
col = h.parentIDCol
}
// Based on:
// f.addLeftJoin("tags_relations", "parent_relations", "tags.id = parent_relations.child_id")
// f.addWhere(fmt.Sprintf("parent_relations.parent_id IS %s NULL", notClause))
f.addLeftJoin(h.relationTable, as, fmt.Sprintf("%s.id = %s.%s", h.primaryTable, as, col))
f.addWhere(fmt.Sprintf("%s.%s IS %s NULL", as, col, notClause))
}
func (h hierarchicalRelationshipHandler) parentsAlias() string {
return h.aliasPrefix + "_parents"
}
func (h hierarchicalRelationshipHandler) childrenAlias() string {
return h.aliasPrefix + "_children"
}
func (h hierarchicalRelationshipHandler) valueQuery(value []string, depth int, alias string, isParents bool) string {
var depthCondition string
if depth != -1 {
depthCondition = fmt.Sprintf("WHERE depth < %d", depth)
}
queryTempl := `{alias} AS (
SELECT {root_id_col} AS root_id, {item_id_col} AS item_id, 0 AS depth FROM {relation_table} WHERE {root_id_col} IN` + getInBinding(len(value)) + `
UNION
SELECT root_id, {item_id_col}, depth + 1 FROM {relation_table} INNER JOIN {alias} ON item_id = {root_id_col} ` + depthCondition + `
)`
var queryMap utils.StrFormatMap
if isParents {
queryMap = utils.StrFormatMap{
"root_id_col": h.parentIDCol,
"item_id_col": h.childIDCol,
}
} else {
queryMap = utils.StrFormatMap{
"root_id_col": h.childIDCol,
"item_id_col": h.parentIDCol,
}
}
queryMap["alias"] = alias
queryMap["relation_table"] = h.relationTable
return utils.StrFormat(queryTempl, queryMap)
}
func (h hierarchicalRelationshipHandler) handleValues(f *filterBuilder, c models.HierarchicalMultiCriterionInput, isParents bool, aliasSuffix string) {
if len(c.Value) == 0 {
return
}
var args []interface{}
for _, val := range c.Value {
args = append(args, val)
}
depthVal := 0
if c.Depth != nil {
depthVal = *c.Depth
}
tableAlias := h.parentsAlias()
if !isParents {
tableAlias = h.childrenAlias()
}
tableAlias += aliasSuffix
query := h.valueQuery(c.Value, depthVal, tableAlias, isParents)
f.addRecursiveWith(query, args...)
f.addLeftJoin(tableAlias, "", fmt.Sprintf("%s.item_id = %s.id", tableAlias, h.primaryTable))
addHierarchicalConditionClauses(f, c, tableAlias, "root_id")
}
func (h hierarchicalRelationshipHandler) handleValuesSimple(f *filterBuilder, value string, isParents bool) {
joinCol := h.childIDCol
valueCol := h.parentIDCol
if !isParents {
joinCol = h.parentIDCol
valueCol = h.childIDCol
}
tableAlias := h.parentsAlias()
if !isParents {
tableAlias = h.childrenAlias()
}
f.addInnerJoin(h.relationTable, tableAlias, fmt.Sprintf("%s.%s = %s.id", tableAlias, joinCol, h.primaryTable))
f.addWhere(fmt.Sprintf("%s.%s = ?", tableAlias, valueCol), value)
}
func (h hierarchicalRelationshipHandler) hierarchicalCriterionHandler(criterion *models.HierarchicalMultiCriterionInput, isParents bool) criterionHandlerFunc {
return func(ctx context.Context, f *filterBuilder) {
if criterion != nil {
c := criterion.CombineExcludes()
// validate the modifier
if err := h.validateModifier(c.Modifier); err != nil {
f.setError(err)
return
}
if c.Modifier == models.CriterionModifierIsNull || c.Modifier == models.CriterionModifierNotNull {
h.handleNullNotNull(f, c.Modifier, isParents)
return
}
if len(c.Value) == 0 && len(c.Excludes) == 0 {
return
}
depth := 0
if c.Depth != nil {
depth = *c.Depth
}
// if we have a single include, no excludes, and no depth, we can use a simple join and where clause
if (c.Modifier == models.CriterionModifierIncludes || c.Modifier == models.CriterionModifierIncludesAll) && len(c.Value) == 1 && len(c.Excludes) == 0 && depth == 0 {
h.handleValuesSimple(f, c.Value[0], isParents)
return
}
aliasSuffix := ""
h.handleValues(f, c, isParents, aliasSuffix)
if len(c.Excludes) > 0 {
exCriterion := models.HierarchicalMultiCriterionInput{
Value: c.Excludes,
Depth: c.Depth,
Modifier: models.CriterionModifierExcludes,
}
aliasSuffix := "2"
h.handleValues(f, exCriterion, isParents, aliasSuffix)
}
}
}
}
func (h hierarchicalRelationshipHandler) ParentsCriterionHandler(criterion *models.HierarchicalMultiCriterionInput) criterionHandlerFunc {
const isParents = true
return h.hierarchicalCriterionHandler(criterion, isParents)
}
func (h hierarchicalRelationshipHandler) ChildrenCriterionHandler(criterion *models.HierarchicalMultiCriterionInput) criterionHandlerFunc {
const isParents = false
return h.hierarchicalCriterionHandler(criterion, isParents)
}
func (h hierarchicalRelationshipHandler) countCriterionHandler(c *models.IntCriterionInput, isParents bool) criterionHandlerFunc {
tableAlias := h.parentsAlias()
col := h.childIDCol
otherCol := h.parentIDCol
if !isParents {
tableAlias = h.childrenAlias()
col = h.parentIDCol
otherCol = h.childIDCol
}
tableAlias += "_count"
return func(ctx context.Context, f *filterBuilder) {
if c != nil {
f.addLeftJoin(h.relationTable, tableAlias, fmt.Sprintf("%s.%s = %s.id", tableAlias, col, h.primaryTable))
clause, args := getIntCriterionWhereClause(fmt.Sprintf("count(distinct %s.%s)", tableAlias, otherCol), *c)
f.addHaving(clause, args...)
}
}
}
func (h hierarchicalRelationshipHandler) ParentCountCriterionHandler(parentCount *models.IntCriterionInput) criterionHandlerFunc {
const isParents = true
return h.countCriterionHandler(parentCount, isParents)
}
func (h hierarchicalRelationshipHandler) ChildCountCriterionHandler(childCount *models.IntCriterionInput) criterionHandlerFunc {
const isParents = false
return h.countCriterionHandler(childCount, isParents)
}

View File

@@ -27,6 +27,8 @@ const (
groupURLsTable = "group_urls"
groupURLColumn = "url"
groupRelationsTable = "groups_relations"
)
type groupRow struct {
@@ -128,6 +130,7 @@ var (
type GroupStore struct {
blobJoinQueryBuilder
tagRelationshipStore
groupRelationshipStore
tableMgr *table
}
@@ -143,6 +146,9 @@ func NewGroupStore(blobStore *BlobStore) *GroupStore {
joinTable: groupsTagsTableMgr,
},
},
groupRelationshipStore: groupRelationshipStore{
table: groupRelationshipTableMgr,
},
tableMgr: groupTableMgr,
}
@@ -176,6 +182,14 @@ func (qb *GroupStore) Create(ctx context.Context, newObject *models.Group) error
return err
}
if err := qb.groupRelationshipStore.createContainingRelationships(ctx, id, newObject.ContainingGroups); err != nil {
return err
}
if err := qb.groupRelationshipStore.createSubRelationships(ctx, id, newObject.SubGroups); err != nil {
return err
}
updated, err := qb.find(ctx, id)
if err != nil {
return fmt.Errorf("finding after create: %w", err)
@@ -211,6 +225,14 @@ func (qb *GroupStore) UpdatePartial(ctx context.Context, id int, partial models.
return nil, err
}
if err := qb.groupRelationshipStore.modifyContainingRelationships(ctx, id, partial.ContainingGroups); err != nil {
return nil, err
}
if err := qb.groupRelationshipStore.modifySubRelationships(ctx, id, partial.SubGroups); err != nil {
return nil, err
}
return qb.find(ctx, id)
}
@@ -232,6 +254,14 @@ func (qb *GroupStore) Update(ctx context.Context, updatedObject *models.Group) e
return err
}
if err := qb.groupRelationshipStore.replaceContainingRelationships(ctx, updatedObject.ID, updatedObject.ContainingGroups); err != nil {
return err
}
if err := qb.groupRelationshipStore.replaceSubRelationships(ctx, updatedObject.ID, updatedObject.SubGroups); err != nil {
return err
}
return nil
}
@@ -412,9 +442,7 @@ func (qb *GroupStore) makeQuery(ctx context.Context, groupFilter *models.GroupFi
return nil, err
}
var err error
query.sortAndPagination, err = qb.getGroupSort(findFilter)
if err != nil {
if err := qb.setGroupSort(&query, findFilter); err != nil {
return nil, err
}
@@ -460,11 +488,12 @@ var groupSortOptions = sortOptions{
"random",
"rating",
"scenes_count",
"sub_group_order",
"tag_count",
"updated_at",
}
func (qb *GroupStore) getGroupSort(findFilter *models.FindFilterType) (string, error) {
func (qb *GroupStore) setGroupSort(query *queryBuilder, findFilter *models.FindFilterType) error {
var sort string
var direction string
if findFilter == nil {
@@ -477,22 +506,31 @@ func (qb *GroupStore) getGroupSort(findFilter *models.FindFilterType) (string, e
// CVE-2024-32231 - ensure sort is in the list of allowed sorts
if err := groupSortOptions.validateSort(sort); err != nil {
return "", err
return err
}
sortQuery := ""
switch sort {
case "sub_group_order":
// sub_group_order is a special sort that sorts by the order_index of the subgroups
if query.hasJoin("groups_parents") {
query.sortAndPagination += getSort("order_index", direction, "groups_parents")
} else {
// this will give unexpected results if the query is not filtered by a parent group and
// the group has multiple parents and order indexes
query.join(groupRelationsTable, "", "groups.id = groups_relations.sub_id")
query.sortAndPagination += getSort("order_index", direction, groupRelationsTable)
}
case "tag_count":
sortQuery += getCountSort(groupTable, groupsTagsTable, groupIDColumn, direction)
query.sortAndPagination += getCountSort(groupTable, groupsTagsTable, groupIDColumn, direction)
case "scenes_count": // generic getSort won't work for this
sortQuery += getCountSort(groupTable, groupsScenesTable, groupIDColumn, direction)
query.sortAndPagination += getCountSort(groupTable, groupsScenesTable, groupIDColumn, direction)
default:
sortQuery += getSort(sort, direction, "groups")
query.sortAndPagination += getSort(sort, direction, "groups")
}
// Whatever the sorting, always use name/id as a final sort
sortQuery += ", COALESCE(groups.name, groups.id) COLLATE NATURAL_CI ASC"
return sortQuery, nil
query.sortAndPagination += ", COALESCE(groups.name, groups.id) COLLATE NATURAL_CI ASC"
return nil
}
func (qb *GroupStore) queryGroups(ctx context.Context, query string, args []interface{}) ([]*models.Group, error) {
@@ -592,3 +630,74 @@ WHERE groups.studio_id = ?
func (qb *GroupStore) GetURLs(ctx context.Context, groupID int) ([]string, error) {
return groupsURLsTableMgr.get(ctx, groupID)
}
// FindSubGroupIDs returns a list of group IDs where a group in the ids list is a sub-group of the parent group
func (qb *GroupStore) FindSubGroupIDs(ctx context.Context, containingID int, ids []int) ([]int, error) {
/*
SELECT gr.sub_id FROM groups_relations gr
WHERE gr.containing_id = :parentID AND gr.sub_id IN (:ids);
*/
table := groupRelationshipTableMgr.table
q := dialect.From(table).Prepared(true).
Select(table.Col("sub_id")).Where(
table.Col("containing_id").Eq(containingID),
table.Col("sub_id").In(ids),
)
const single = false
var ret []int
if err := queryFunc(ctx, q, single, func(r *sqlx.Rows) error {
var id int
if err := r.Scan(&id); err != nil {
return err
}
ret = append(ret, id)
return nil
}); err != nil {
return nil, err
}
return ret, nil
}
// FindInAscestors returns a list of group IDs where a group in the ids list is an ascestor of the ancestor group IDs
func (qb *GroupStore) FindInAncestors(ctx context.Context, ascestorIDs []int, ids []int) ([]int, error) {
/*
WITH RECURSIVE ascestors AS (
SELECT g.id AS parent_id FROM groups g WHERE g.id IN (:ascestorIDs)
UNION
SELECT gr.containing_id FROM groups_relations gr INNER JOIN ascestors a ON a.parent_id = gr.sub_id
)
SELECT p.parent_id FROM ascestors p WHERE p.parent_id IN (:ids);
*/
table := qb.table()
const ascestors = "ancestors"
const parentID = "parent_id"
q := dialect.From(ascestors).Prepared(true).
WithRecursive(ascestors,
dialect.From(qb.table()).Select(table.Col(idColumn).As(parentID)).
Where(table.Col(idColumn).In(ascestorIDs)).
Union(
dialect.From(groupRelationsJoinTable).InnerJoin(
goqu.I(ascestors), goqu.On(goqu.I("parent_id").Eq(goqu.I("sub_id"))),
).Select("containing_id"),
),
).Select(parentID).Where(goqu.I(parentID).In(ids))
const single = false
var ret []int
if err := queryFunc(ctx, q, single, func(r *sqlx.Rows) error {
var id int
if err := r.Scan(&id); err != nil {
return err
}
ret = append(ret, id)
return nil
}); err != nil {
return nil, err
}
return ret, nil
}

View File

@@ -51,6 +51,14 @@ func (qb *groupFilterHandler) handle(ctx context.Context, f *filterBuilder) {
f.handleCriterion(ctx, qb.criterionHandler())
}
var groupHierarchyHandler = hierarchicalRelationshipHandler{
primaryTable: groupTable,
relationTable: groupRelationsTable,
aliasPrefix: groupTable,
parentIDCol: "containing_id",
childIDCol: "sub_id",
}
func (qb *groupFilterHandler) criterionHandler() criterionHandler {
groupFilter := qb.groupFilter
return compoundHandler{
@@ -66,6 +74,10 @@ func (qb *groupFilterHandler) criterionHandler() criterionHandler {
qb.tagsCriterionHandler(groupFilter.Tags),
qb.tagCountCriterionHandler(groupFilter.TagCount),
&dateCriterionHandler{groupFilter.Date, "groups.date", nil},
groupHierarchyHandler.ParentsCriterionHandler(groupFilter.ContainingGroups),
groupHierarchyHandler.ChildrenCriterionHandler(groupFilter.SubGroups),
groupHierarchyHandler.ParentCountCriterionHandler(groupFilter.ContainingGroupCount),
groupHierarchyHandler.ChildCountCriterionHandler(groupFilter.SubGroupCount),
&timestampCriterionHandler{groupFilter.CreatedAt, "groups.created_at", nil},
&timestampCriterionHandler{groupFilter.UpdatedAt, "groups.updated_at", nil},

View File

@@ -0,0 +1,457 @@
package sqlite
import (
"context"
"fmt"
"github.com/doug-martin/goqu/v9"
"github.com/doug-martin/goqu/v9/exp"
"github.com/jmoiron/sqlx"
"github.com/stashapp/stash/pkg/models"
"gopkg.in/guregu/null.v4"
"gopkg.in/guregu/null.v4/zero"
)
type groupRelationshipRow struct {
ContainingID int `db:"containing_id"`
SubID int `db:"sub_id"`
OrderIndex int `db:"order_index"`
Description zero.String `db:"description"`
}
func (r groupRelationshipRow) resolve(useContainingID bool) models.GroupIDDescription {
id := r.ContainingID
if !useContainingID {
id = r.SubID
}
return models.GroupIDDescription{
GroupID: id,
Description: r.Description.String,
}
}
type groupRelationshipStore struct {
table *table
}
func (s *groupRelationshipStore) GetContainingGroupDescriptions(ctx context.Context, id int) ([]models.GroupIDDescription, error) {
const idIsContaining = false
return s.getGroupRelationships(ctx, id, idIsContaining)
}
func (s *groupRelationshipStore) GetSubGroupDescriptions(ctx context.Context, id int) ([]models.GroupIDDescription, error) {
const idIsContaining = true
return s.getGroupRelationships(ctx, id, idIsContaining)
}
func (s *groupRelationshipStore) getGroupRelationships(ctx context.Context, id int, idIsContaining bool) ([]models.GroupIDDescription, error) {
col := "containing_id"
if !idIsContaining {
col = "sub_id"
}
table := s.table.table
q := dialect.Select(table.All()).
From(table).
Where(table.Col(col).Eq(id)).
Order(table.Col("order_index").Asc())
const single = false
var ret []models.GroupIDDescription
if err := queryFunc(ctx, q, single, func(rows *sqlx.Rows) error {
var row groupRelationshipRow
if err := rows.StructScan(&row); err != nil {
return err
}
ret = append(ret, row.resolve(!idIsContaining))
return nil
}); err != nil {
return nil, fmt.Errorf("getting group relationships from %s: %w", table.GetTable(), err)
}
return ret, nil
}
// getMaxOrderIndex gets the maximum order index for the containing group with the given id
func (s *groupRelationshipStore) getMaxOrderIndex(ctx context.Context, containingID int) (int, error) {
idColumn := s.table.table.Col("containing_id")
q := dialect.Select(goqu.MAX("order_index")).
From(s.table.table).
Where(idColumn.Eq(containingID))
var maxOrderIndex zero.Int
if err := querySimple(ctx, q, &maxOrderIndex); err != nil {
return 0, fmt.Errorf("getting max order index: %w", err)
}
return int(maxOrderIndex.Int64), nil
}
// createRelationships creates relationships between a group and other groups.
// If idIsContaining is true, the provided id is the containing group.
func (s *groupRelationshipStore) createRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions, idIsContaining bool) error {
if d.Loaded() {
for i, v := range d.List() {
orderIndex := i + 1
r := groupRelationshipRow{
ContainingID: id,
SubID: v.GroupID,
OrderIndex: orderIndex,
Description: zero.StringFrom(v.Description),
}
if !idIsContaining {
// get the max order index of the containing groups sub groups
containingID := v.GroupID
maxOrderIndex, err := s.getMaxOrderIndex(ctx, containingID)
if err != nil {
return err
}
r.ContainingID = v.GroupID
r.SubID = id
r.OrderIndex = maxOrderIndex + 1
}
_, err := s.table.insert(ctx, r)
if err != nil {
return fmt.Errorf("inserting into %s: %w", s.table.table.GetTable(), err)
}
}
return nil
}
return nil
}
// createRelationships creates relationships between a group and other groups.
// If idIsContaining is true, the provided id is the containing group.
func (s *groupRelationshipStore) createContainingRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions) error {
const idIsContaining = false
return s.createRelationships(ctx, id, d, idIsContaining)
}
// createRelationships creates relationships between a group and other groups.
// If idIsContaining is true, the provided id is the containing group.
func (s *groupRelationshipStore) createSubRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions) error {
const idIsContaining = true
return s.createRelationships(ctx, id, d, idIsContaining)
}
func (s *groupRelationshipStore) replaceRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions, idIsContaining bool) error {
// always destroy the existing relationships even if the new list is empty
if err := s.destroyAllJoins(ctx, id, idIsContaining); err != nil {
return err
}
return s.createRelationships(ctx, id, d, idIsContaining)
}
func (s *groupRelationshipStore) replaceContainingRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions) error {
const idIsContaining = false
return s.replaceRelationships(ctx, id, d, idIsContaining)
}
func (s *groupRelationshipStore) replaceSubRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions) error {
const idIsContaining = true
return s.replaceRelationships(ctx, id, d, idIsContaining)
}
func (s *groupRelationshipStore) modifyRelationships(ctx context.Context, id int, v *models.UpdateGroupDescriptions, idIsContaining bool) error {
if v == nil {
return nil
}
switch v.Mode {
case models.RelationshipUpdateModeSet:
return s.replaceJoins(ctx, id, *v, idIsContaining)
case models.RelationshipUpdateModeAdd:
return s.addJoins(ctx, id, v.Groups, idIsContaining)
case models.RelationshipUpdateModeRemove:
toRemove := make([]int, len(v.Groups))
for i, vv := range v.Groups {
toRemove[i] = vv.GroupID
}
return s.destroyJoins(ctx, id, toRemove, idIsContaining)
}
return nil
}
func (s *groupRelationshipStore) modifyContainingRelationships(ctx context.Context, id int, v *models.UpdateGroupDescriptions) error {
const idIsContaining = false
return s.modifyRelationships(ctx, id, v, idIsContaining)
}
func (s *groupRelationshipStore) modifySubRelationships(ctx context.Context, id int, v *models.UpdateGroupDescriptions) error {
const idIsContaining = true
return s.modifyRelationships(ctx, id, v, idIsContaining)
}
func (s *groupRelationshipStore) addJoins(ctx context.Context, id int, groups []models.GroupIDDescription, idIsContaining bool) error {
// if we're adding to a containing group, get the max order index first
var maxOrderIndex int
if idIsContaining {
var err error
maxOrderIndex, err = s.getMaxOrderIndex(ctx, id)
if err != nil {
return err
}
}
for i, vv := range groups {
r := groupRelationshipRow{
Description: zero.StringFrom(vv.Description),
}
if idIsContaining {
r.ContainingID = id
r.SubID = vv.GroupID
r.OrderIndex = maxOrderIndex + (i + 1)
} else {
// get the max order index of the containing groups sub groups
containingMaxOrderIndex, err := s.getMaxOrderIndex(ctx, vv.GroupID)
if err != nil {
return err
}
r.ContainingID = vv.GroupID
r.SubID = id
r.OrderIndex = containingMaxOrderIndex + 1
}
_, err := s.table.insert(ctx, r)
if err != nil {
return fmt.Errorf("inserting into %s: %w", s.table.table.GetTable(), err)
}
}
return nil
}
func (s *groupRelationshipStore) destroyAllJoins(ctx context.Context, id int, idIsContaining bool) error {
table := s.table.table
idColumn := table.Col("containing_id")
if !idIsContaining {
idColumn = table.Col("sub_id")
}
q := dialect.Delete(table).Where(idColumn.Eq(id))
if _, err := exec(ctx, q); err != nil {
return fmt.Errorf("destroying %s: %w", table.GetTable(), err)
}
return nil
}
func (s *groupRelationshipStore) replaceJoins(ctx context.Context, id int, v models.UpdateGroupDescriptions, idIsContaining bool) error {
if err := s.destroyAllJoins(ctx, id, idIsContaining); err != nil {
return err
}
// convert to RelatedGroupDescriptions
rgd := models.NewRelatedGroupDescriptions(v.Groups)
return s.createRelationships(ctx, id, rgd, idIsContaining)
}
func (s *groupRelationshipStore) destroyJoins(ctx context.Context, id int, toRemove []int, idIsContaining bool) error {
table := s.table.table
idColumn := table.Col("containing_id")
fkColumn := table.Col("sub_id")
if !idIsContaining {
idColumn = table.Col("sub_id")
fkColumn = table.Col("containing_id")
}
q := dialect.Delete(table).Where(idColumn.Eq(id), fkColumn.In(toRemove))
if _, err := exec(ctx, q); err != nil {
return fmt.Errorf("destroying %s: %w", table.GetTable(), err)
}
return nil
}
func (s *groupRelationshipStore) getOrderIndexOfSubGroup(ctx context.Context, containingGroupID int, subGroupID int) (int, error) {
table := s.table.table
q := dialect.Select("order_index").
From(table).
Where(
table.Col("containing_id").Eq(containingGroupID),
table.Col("sub_id").Eq(subGroupID),
)
var orderIndex null.Int
if err := querySimple(ctx, q, &orderIndex); err != nil {
return 0, fmt.Errorf("getting order index: %w", err)
}
if !orderIndex.Valid {
return 0, fmt.Errorf("sub-group %d not found in containing group %d", subGroupID, containingGroupID)
}
return int(orderIndex.Int64), nil
}
func (s *groupRelationshipStore) getGroupIDAtOrderIndex(ctx context.Context, containingGroupID int, orderIndex int) (*int, error) {
table := s.table.table
q := dialect.Select(table.Col("sub_id")).From(table).Where(
table.Col("containing_id").Eq(containingGroupID),
table.Col("order_index").Eq(orderIndex),
)
var ret null.Int
if err := querySimple(ctx, q, &ret); err != nil {
return nil, fmt.Errorf("getting sub id for order index: %w", err)
}
if !ret.Valid {
return nil, nil
}
intRet := int(ret.Int64)
return &intRet, nil
}
func (s *groupRelationshipStore) getOrderIndexAfterOrderIndex(ctx context.Context, containingGroupID int, orderIndex int) (int, error) {
table := s.table.table
q := dialect.Select(goqu.MIN("order_index")).From(table).Where(
table.Col("containing_id").Eq(containingGroupID),
table.Col("order_index").Gt(orderIndex),
)
var ret null.Int
if err := querySimple(ctx, q, &ret); err != nil {
return 0, fmt.Errorf("getting order index: %w", err)
}
if !ret.Valid {
return orderIndex + 1, nil
}
return int(ret.Int64), nil
}
// incrementOrderIndexes increments the order_index value of all sub-groups in the containing group at or after the given index
func (s *groupRelationshipStore) incrementOrderIndexes(ctx context.Context, groupID int, indexBefore int) error {
table := s.table.table
// WORKAROUND - sqlite won't allow incrementing the value directly since it causes a
// unique constraint violation.
// Instead, we first set the order index to a negative value temporarily
// see https://stackoverflow.com/a/7703239/695786
q := dialect.Update(table).Set(exp.Record{
"order_index": goqu.L("-order_index"),
}).Where(
table.Col("containing_id").Eq(groupID),
table.Col("order_index").Gte(indexBefore),
)
if _, err := exec(ctx, q); err != nil {
return fmt.Errorf("updating %s: %w", table.GetTable(), err)
}
q = dialect.Update(table).Set(exp.Record{
"order_index": goqu.L("1-order_index"),
}).Where(
table.Col("containing_id").Eq(groupID),
table.Col("order_index").Lt(0),
)
if _, err := exec(ctx, q); err != nil {
return fmt.Errorf("updating %s: %w", table.GetTable(), err)
}
return nil
}
func (s *groupRelationshipStore) reorderSubGroup(ctx context.Context, groupID int, subGroupID int, insertPointID int, insertAfter bool) error {
insertPointIndex, err := s.getOrderIndexOfSubGroup(ctx, groupID, insertPointID)
if err != nil {
return err
}
// if we're setting before
if insertAfter {
insertPointIndex, err = s.getOrderIndexAfterOrderIndex(ctx, groupID, insertPointIndex)
if err != nil {
return err
}
}
// increment the order index of all sub-groups after and including the insertion point
if err := s.incrementOrderIndexes(ctx, groupID, int(insertPointIndex)); err != nil {
return err
}
// set the order index of the sub-group to the insertion point
table := s.table.table
q := dialect.Update(table).Set(exp.Record{
"order_index": insertPointIndex,
}).Where(
table.Col("containing_id").Eq(groupID),
table.Col("sub_id").Eq(subGroupID),
)
if _, err := exec(ctx, q); err != nil {
return fmt.Errorf("updating %s: %w", table.GetTable(), err)
}
return nil
}
func (s *groupRelationshipStore) AddSubGroups(ctx context.Context, groupID int, subGroups []models.GroupIDDescription, insertIndex *int) error {
const idIsContaining = true
if err := s.addJoins(ctx, groupID, subGroups, idIsContaining); err != nil {
return err
}
ids := make([]int, len(subGroups))
for i, v := range subGroups {
ids[i] = v.GroupID
}
if insertIndex != nil {
// get the id of the sub-group at the insert index
insertPointID, err := s.getGroupIDAtOrderIndex(ctx, groupID, *insertIndex)
if err != nil {
return err
}
if insertPointID == nil {
// if the insert index is out of bounds, just assume adding to the end
return nil
}
// reorder the sub-groups
const insertAfter = false
if err := s.ReorderSubGroups(ctx, groupID, ids, *insertPointID, insertAfter); err != nil {
return err
}
}
return nil
}
func (s *groupRelationshipStore) RemoveSubGroups(ctx context.Context, groupID int, subGroupIDs []int) error {
const idIsContaining = true
return s.destroyJoins(ctx, groupID, subGroupIDs, idIsContaining)
}
func (s *groupRelationshipStore) ReorderSubGroups(ctx context.Context, groupID int, subGroupIDs []int, insertPointID int, insertAfter bool) error {
for _, id := range subGroupIDs {
if err := s.reorderSubGroup(ctx, groupID, id, insertPointID, insertAfter); err != nil {
return err
}
}
return nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,13 @@
CREATE TABLE `groups_relations` (
`containing_id` integer not null,
`sub_id` integer not null,
`order_index` integer not null,
`description` varchar(255),
primary key (`containing_id`, `sub_id`),
foreign key (`containing_id`) references `groups`(`id`) on delete cascade,
foreign key (`sub_id`) references `groups`(`id`) on delete cascade,
check (`containing_id` != `sub_id`)
);
CREATE INDEX `index_groups_relations_sub_id` ON `groups_relations` (`sub_id`);
CREATE UNIQUE INDEX `index_groups_relations_order_index_unique` ON `groups_relations` (`containing_id`, `order_index`);

View File

@@ -110,6 +110,16 @@ func (qb *queryBuilder) addArg(args ...interface{}) {
qb.args = append(qb.args, args...)
}
func (qb *queryBuilder) hasJoin(alias string) bool {
for _, j := range qb.joins {
if j.alias() == alias {
return true
}
}
return false
}
func (qb *queryBuilder) join(table, as, onClause string) {
newJoin := join{
table: table,

View File

@@ -791,13 +791,6 @@ func (qb *SceneStore) FindByGroupID(ctx context.Context, groupID int) ([]*models
return ret, nil
}
func (qb *SceneStore) CountByGroupID(ctx context.Context, groupID int) (int, error) {
joinTable := scenesGroupsJoinTable
q := dialect.Select(goqu.COUNT("*")).From(joinTable).Where(joinTable.Col(groupIDColumn).Eq(groupID))
return count(ctx, q)
}
func (qb *SceneStore) Count(ctx context.Context) (int, error) {
q := dialect.Select(goqu.COUNT("*")).From(qb.table())
return count(ctx, q)
@@ -858,6 +851,7 @@ func (qb *SceneStore) PlayDuration(ctx context.Context) (float64, error) {
return ret, nil
}
// TODO - currently only used by unit test
func (qb *SceneStore) CountByStudioID(ctx context.Context, studioID int) (int, error) {
table := qb.table()
@@ -865,13 +859,6 @@ func (qb *SceneStore) CountByStudioID(ctx context.Context, studioID int) (int, e
return count(ctx, q)
}
func (qb *SceneStore) CountByTagID(ctx context.Context, tagID int) (int, error) {
joinTable := scenesTagsJoinTable
q := dialect.Select(goqu.COUNT("*")).From(joinTable).Where(joinTable.Col(tagIDColumn).Eq(tagID))
return count(ctx, q)
}
func (qb *SceneStore) countMissingFingerprints(ctx context.Context, fpType string) (int, error) {
fpTable := fingerprintTableMgr.table.As("fingerprints_temp")

View File

@@ -149,7 +149,7 @@ func (qb *sceneFilterHandler) criterionHandler() criterionHandler {
studioCriterionHandler(sceneTable, sceneFilter.Studios),
qb.groupsCriterionHandler(sceneFilter.Groups),
qb.groupsCriterionHandler(sceneFilter.Movies),
qb.moviesCriterionHandler(sceneFilter.Movies),
qb.galleriesCriterionHandler(sceneFilter.Galleries),
qb.performerTagsCriterionHandler(sceneFilter.PerformerTags),
@@ -483,7 +483,8 @@ func (qb *sceneFilterHandler) performerAgeCriterionHandler(performerAge *models.
}
}
func (qb *sceneFilterHandler) groupsCriterionHandler(movies *models.MultiCriterionInput) criterionHandlerFunc {
// legacy handler
func (qb *sceneFilterHandler) moviesCriterionHandler(movies *models.MultiCriterionInput) criterionHandlerFunc {
addJoinsFunc := func(f *filterBuilder) {
sceneRepository.groups.join(f, "", "scenes.id")
f.addLeftJoin("groups", "", "groups_scenes.group_id = groups.id")
@@ -492,6 +493,23 @@ func (qb *sceneFilterHandler) groupsCriterionHandler(movies *models.MultiCriteri
return h.handler(movies)
}
func (qb *sceneFilterHandler) groupsCriterionHandler(groups *models.HierarchicalMultiCriterionInput) criterionHandlerFunc {
h := joinedHierarchicalMultiCriterionHandlerBuilder{
primaryTable: sceneTable,
foreignTable: groupTable,
foreignFK: "group_id",
relationsTable: groupRelationsTable,
parentFK: "containing_id",
childFK: "sub_id",
joinAs: "scene_group",
joinTable: groupsScenesTable,
primaryFK: sceneIDColumn,
}
return h.handler(groups)
}
func (qb *sceneFilterHandler) galleriesCriterionHandler(galleries *models.MultiCriterionInput) criterionHandlerFunc {
addJoinsFunc := func(f *filterBuilder) {
sceneRepository.galleries.join(f, "", "scenes.id")

View File

@@ -16,6 +16,7 @@ import (
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/sliceutil"
"github.com/stashapp/stash/pkg/sliceutil/intslice"
"github.com/stretchr/testify/assert"
)
@@ -2217,7 +2218,7 @@ func TestSceneQuery(t *testing.T) {
},
})
if (err != nil) != tt.wantErr {
t.Errorf("PerformerStore.Query() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("SceneStore.Query() error = %v, wantErr %v", err, tt.wantErr)
return
}
@@ -3873,6 +3874,100 @@ func TestSceneQueryStudioDepth(t *testing.T) {
})
}
func TestSceneGroups(t *testing.T) {
type criterion struct {
valueIdxs []int
modifier models.CriterionModifier
depth int
}
tests := []struct {
name string
c criterion
q string
includeIdxs []int
excludeIdxs []int
}{
{
"includes",
criterion{
[]int{groupIdxWithScene},
models.CriterionModifierIncludes,
0,
},
"",
[]int{sceneIdxWithGroup},
nil,
},
{
"excludes",
criterion{
[]int{groupIdxWithScene},
models.CriterionModifierExcludes,
0,
},
getSceneStringValue(sceneIdxWithGroup, titleField),
nil,
[]int{sceneIdxWithGroup},
},
{
"includes (depth = 1)",
criterion{
[]int{groupIdxWithChildWithScene},
models.CriterionModifierIncludes,
1,
},
"",
[]int{sceneIdxWithGroupWithParent},
nil,
},
}
for _, tt := range tests {
valueIDs := indexesToIDs(groupIDs, tt.c.valueIdxs)
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
assert := assert.New(t)
sceneFilter := &models.SceneFilterType{
Groups: &models.HierarchicalMultiCriterionInput{
Value: intslice.IntSliceToStringSlice(valueIDs),
Modifier: tt.c.modifier,
},
}
if tt.c.depth != 0 {
sceneFilter.Groups.Depth = &tt.c.depth
}
findFilter := &models.FindFilterType{}
if tt.q != "" {
findFilter.Q = &tt.q
}
results, err := db.Scene.Query(ctx, models.SceneQueryOptions{
SceneFilter: sceneFilter,
QueryOptions: models.QueryOptions{
FindFilter: findFilter,
},
})
if err != nil {
t.Errorf("SceneStore.Query() error = %v", err)
return
}
include := indexesToIDs(sceneIDs, tt.includeIdxs)
exclude := indexesToIDs(sceneIDs, tt.excludeIdxs)
assert.Subset(results.IDs, include)
for _, e := range exclude {
assert.NotContains(results.IDs, e)
}
})
}
}
func TestSceneQueryMovies(t *testing.T) {
withTxn(func(ctx context.Context) error {
sqb := db.Scene
@@ -4188,78 +4283,6 @@ func verifyScenesPerformerCount(t *testing.T, performerCountCriterion models.Int
})
}
func TestSceneCountByTagID(t *testing.T) {
withTxn(func(ctx context.Context) error {
sqb := db.Scene
sceneCount, err := sqb.CountByTagID(ctx, tagIDs[tagIdxWithScene])
if err != nil {
t.Errorf("error calling CountByTagID: %s", err.Error())
}
assert.Equal(t, 1, sceneCount)
sceneCount, err = sqb.CountByTagID(ctx, 0)
if err != nil {
t.Errorf("error calling CountByTagID: %s", err.Error())
}
assert.Equal(t, 0, sceneCount)
return nil
})
}
func TestSceneCountByGroupID(t *testing.T) {
withTxn(func(ctx context.Context) error {
sqb := db.Scene
sceneCount, err := sqb.CountByGroupID(ctx, groupIDs[groupIdxWithScene])
if err != nil {
t.Errorf("error calling CountByGroupID: %s", err.Error())
}
assert.Equal(t, 1, sceneCount)
sceneCount, err = sqb.CountByGroupID(ctx, 0)
if err != nil {
t.Errorf("error calling CountByGroupID: %s", err.Error())
}
assert.Equal(t, 0, sceneCount)
return nil
})
}
func TestSceneCountByStudioID(t *testing.T) {
withTxn(func(ctx context.Context) error {
sqb := db.Scene
sceneCount, err := sqb.CountByStudioID(ctx, studioIDs[studioIdxWithScene])
if err != nil {
t.Errorf("error calling CountByStudioID: %s", err.Error())
}
assert.Equal(t, 1, sceneCount)
sceneCount, err = sqb.CountByStudioID(ctx, 0)
if err != nil {
t.Errorf("error calling CountByStudioID: %s", err.Error())
}
assert.Equal(t, 0, sceneCount)
return nil
})
}
func TestFindByMovieID(t *testing.T) {
withTxn(func(ctx context.Context) error {
sqb := db.Scene

View File

@@ -78,6 +78,7 @@ const (
sceneIdxWithGrandChildStudio
sceneIdxMissingPhash
sceneIdxWithPerformerParentTag
sceneIdxWithGroupWithParent
// new indexes above
lastSceneIdx
@@ -153,9 +154,15 @@ const (
groupIdxWithTag
groupIdxWithTwoTags
groupIdxWithThreeTags
groupIdxWithGrandChild
groupIdxWithChild
groupIdxWithParentAndChild
groupIdxWithParent
groupIdxWithGrandParent
groupIdxWithParentAndScene
groupIdxWithChildWithScene
// groups with dup names start from the end
// create 7 more basic groups (can remove this if we add more indexes)
groupIdxWithDupName = groupIdxWithStudio + 7
groupIdxWithDupName
groupsNameCase = groupIdxWithDupName
groupsNameNoCase = 1
@@ -390,7 +397,8 @@ var (
}
sceneGroups = linkMap{
sceneIdxWithGroup: {groupIdxWithScene},
sceneIdxWithGroup: {groupIdxWithScene},
sceneIdxWithGroupWithParent: {groupIdxWithParentAndScene},
}
sceneStudios = map[int]int{
@@ -541,15 +549,31 @@ var (
}
)
var (
groupParentLinks = [][2]int{
{groupIdxWithChild, groupIdxWithParent},
{groupIdxWithGrandChild, groupIdxWithParentAndChild},
{groupIdxWithParentAndChild, groupIdxWithGrandParent},
{groupIdxWithChildWithScene, groupIdxWithParentAndScene},
}
)
func indexesToIDs(ids []int, indexes []int) []int {
ret := make([]int, len(indexes))
for i, idx := range indexes {
ret[i] = ids[idx]
ret[i] = indexToID(ids, idx)
}
return ret
}
func indexToID(ids []int, idx int) int {
if idx < 0 {
return invalidID
}
return ids[idx]
}
func indexFromID(ids []int, id int) int {
for i, v := range ids {
if v == id {
@@ -697,6 +721,10 @@ func populateDB() error {
return fmt.Errorf("error linking tags parent: %s", err.Error())
}
if err := linkGroupsParent(ctx, db.Group); err != nil {
return fmt.Errorf("error linking tags parent: %s", err.Error())
}
for _, ms := range markerSpecs {
if err := createMarker(ctx, db.SceneMarker, ms); err != nil {
return fmt.Errorf("error creating scene marker: %s", err.Error())
@@ -1885,6 +1913,24 @@ func linkTagsParent(ctx context.Context, qb models.TagReaderWriter) error {
})
}
func linkGroupsParent(ctx context.Context, qb models.GroupReaderWriter) error {
return doLinks(groupParentLinks, func(parentIndex, childIndex int) error {
groupID := groupIDs[childIndex]
p := models.GroupPartial{
ContainingGroups: &models.UpdateGroupDescriptions{
Groups: []models.GroupIDDescription{
{GroupID: groupIDs[parentIndex]},
},
Mode: models.RelationshipUpdateModeAdd,
},
}
_, err := qb.UpdatePartial(ctx, groupID, p)
return err
})
}
func addTagImage(ctx context.Context, qb models.TagWriter, tagIndex int) error {
return qb.UpdateImage(ctx, tagIDs[tagIndex], []byte("image"))
}

View File

@@ -37,8 +37,9 @@ var (
studiosTagsJoinTable = goqu.T(studiosTagsTable)
studiosStashIDsJoinTable = goqu.T("studio_stash_ids")
groupsURLsJoinTable = goqu.T(groupURLsTable)
groupsTagsJoinTable = goqu.T(groupsTagsTable)
groupsURLsJoinTable = goqu.T(groupURLsTable)
groupsTagsJoinTable = goqu.T(groupsTagsTable)
groupRelationsJoinTable = goqu.T(groupRelationsTable)
tagsAliasesJoinTable = goqu.T(tagAliasesTable)
tagRelationsJoinTable = goqu.T(tagRelationsTable)
@@ -361,6 +362,10 @@ var (
foreignTable: tagTableMgr,
orderBy: tagTableMgr.table.Col("name").Asc(),
}
groupRelationshipTableMgr = &table{
table: groupRelationsJoinTable,
}
)
var (

View File

@@ -2,7 +2,6 @@ package sqlite
import (
"context"
"fmt"
"github.com/stashapp/stash/pkg/models"
)
@@ -51,6 +50,14 @@ func (qb *tagFilterHandler) handle(ctx context.Context, f *filterBuilder) {
f.handleCriterion(ctx, qb.criterionHandler())
}
var tagHierarchyHandler = hierarchicalRelationshipHandler{
primaryTable: tagTable,
relationTable: tagRelationsTable,
aliasPrefix: tagTable,
parentIDCol: "parent_id",
childIDCol: "child_id",
}
func (qb *tagFilterHandler) criterionHandler() criterionHandler {
tagFilter := qb.tagFilter
return compoundHandler{
@@ -72,10 +79,10 @@ func (qb *tagFilterHandler) criterionHandler() criterionHandler {
qb.groupCountCriterionHandler(tagFilter.MovieCount),
qb.markerCountCriterionHandler(tagFilter.MarkerCount),
qb.parentsCriterionHandler(tagFilter.Parents),
qb.childrenCriterionHandler(tagFilter.Children),
qb.parentCountCriterionHandler(tagFilter.ParentCount),
qb.childCountCriterionHandler(tagFilter.ChildCount),
tagHierarchyHandler.ParentsCriterionHandler(tagFilter.Parents),
tagHierarchyHandler.ChildrenCriterionHandler(tagFilter.Children),
tagHierarchyHandler.ParentCountCriterionHandler(tagFilter.ParentCount),
tagHierarchyHandler.ChildCountCriterionHandler(tagFilter.ChildCount),
&timestampCriterionHandler{tagFilter.CreatedAt, "tags.created_at", nil},
&timestampCriterionHandler{tagFilter.UpdatedAt, "tags.updated_at", nil},
@@ -212,213 +219,3 @@ func (qb *tagFilterHandler) markerCountCriterionHandler(markerCount *models.IntC
}
}
}
func (qb *tagFilterHandler) parentsCriterionHandler(criterion *models.HierarchicalMultiCriterionInput) criterionHandlerFunc {
return func(ctx context.Context, f *filterBuilder) {
if criterion != nil {
tags := criterion.CombineExcludes()
// validate the modifier
switch tags.Modifier {
case models.CriterionModifierIncludesAll, models.CriterionModifierIncludes, models.CriterionModifierExcludes, models.CriterionModifierIsNull, models.CriterionModifierNotNull:
// valid
default:
f.setError(fmt.Errorf("invalid modifier %s for tag parent/children", criterion.Modifier))
}
if tags.Modifier == models.CriterionModifierIsNull || tags.Modifier == models.CriterionModifierNotNull {
var notClause string
if tags.Modifier == models.CriterionModifierNotNull {
notClause = "NOT"
}
f.addLeftJoin("tags_relations", "parent_relations", "tags.id = parent_relations.child_id")
f.addWhere(fmt.Sprintf("parent_relations.parent_id IS %s NULL", notClause))
return
}
if len(tags.Value) == 0 && len(tags.Excludes) == 0 {
return
}
if len(tags.Value) > 0 {
var args []interface{}
for _, val := range tags.Value {
args = append(args, val)
}
depthVal := 0
if tags.Depth != nil {
depthVal = *tags.Depth
}
var depthCondition string
if depthVal != -1 {
depthCondition = fmt.Sprintf("WHERE depth < %d", depthVal)
}
query := `parents AS (
SELECT parent_id AS root_id, child_id AS item_id, 0 AS depth FROM tags_relations WHERE parent_id IN` + getInBinding(len(tags.Value)) + `
UNION
SELECT root_id, child_id, depth + 1 FROM tags_relations INNER JOIN parents ON item_id = parent_id ` + depthCondition + `
)`
f.addRecursiveWith(query, args...)
f.addLeftJoin("parents", "", "parents.item_id = tags.id")
addHierarchicalConditionClauses(f, tags, "parents", "root_id")
}
if len(tags.Excludes) > 0 {
var args []interface{}
for _, val := range tags.Excludes {
args = append(args, val)
}
depthVal := 0
if tags.Depth != nil {
depthVal = *tags.Depth
}
var depthCondition string
if depthVal != -1 {
depthCondition = fmt.Sprintf("WHERE depth < %d", depthVal)
}
query := `parents2 AS (
SELECT parent_id AS root_id, child_id AS item_id, 0 AS depth FROM tags_relations WHERE parent_id IN` + getInBinding(len(tags.Excludes)) + `
UNION
SELECT root_id, child_id, depth + 1 FROM tags_relations INNER JOIN parents2 ON item_id = parent_id ` + depthCondition + `
)`
f.addRecursiveWith(query, args...)
f.addLeftJoin("parents2", "", "parents2.item_id = tags.id")
addHierarchicalConditionClauses(f, models.HierarchicalMultiCriterionInput{
Value: tags.Excludes,
Depth: tags.Depth,
Modifier: models.CriterionModifierExcludes,
}, "parents2", "root_id")
}
}
}
}
func (qb *tagFilterHandler) childrenCriterionHandler(criterion *models.HierarchicalMultiCriterionInput) criterionHandlerFunc {
return func(ctx context.Context, f *filterBuilder) {
if criterion != nil {
tags := criterion.CombineExcludes()
// validate the modifier
switch tags.Modifier {
case models.CriterionModifierIncludesAll, models.CriterionModifierIncludes, models.CriterionModifierExcludes, models.CriterionModifierIsNull, models.CriterionModifierNotNull:
// valid
default:
f.setError(fmt.Errorf("invalid modifier %s for tag parent/children", criterion.Modifier))
}
if tags.Modifier == models.CriterionModifierIsNull || tags.Modifier == models.CriterionModifierNotNull {
var notClause string
if tags.Modifier == models.CriterionModifierNotNull {
notClause = "NOT"
}
f.addLeftJoin("tags_relations", "child_relations", "tags.id = child_relations.parent_id")
f.addWhere(fmt.Sprintf("child_relations.child_id IS %s NULL", notClause))
return
}
if len(tags.Value) == 0 && len(tags.Excludes) == 0 {
return
}
if len(tags.Value) > 0 {
var args []interface{}
for _, val := range tags.Value {
args = append(args, val)
}
depthVal := 0
if tags.Depth != nil {
depthVal = *tags.Depth
}
var depthCondition string
if depthVal != -1 {
depthCondition = fmt.Sprintf("WHERE depth < %d", depthVal)
}
query := `children AS (
SELECT child_id AS root_id, parent_id AS item_id, 0 AS depth FROM tags_relations WHERE child_id IN` + getInBinding(len(tags.Value)) + `
UNION
SELECT root_id, parent_id, depth + 1 FROM tags_relations INNER JOIN children ON item_id = child_id ` + depthCondition + `
)`
f.addRecursiveWith(query, args...)
f.addLeftJoin("children", "", "children.item_id = tags.id")
addHierarchicalConditionClauses(f, tags, "children", "root_id")
}
if len(tags.Excludes) > 0 {
var args []interface{}
for _, val := range tags.Excludes {
args = append(args, val)
}
depthVal := 0
if tags.Depth != nil {
depthVal = *tags.Depth
}
var depthCondition string
if depthVal != -1 {
depthCondition = fmt.Sprintf("WHERE depth < %d", depthVal)
}
query := `children2 AS (
SELECT child_id AS root_id, parent_id AS item_id, 0 AS depth FROM tags_relations WHERE child_id IN` + getInBinding(len(tags.Excludes)) + `
UNION
SELECT root_id, parent_id, depth + 1 FROM tags_relations INNER JOIN children2 ON item_id = child_id ` + depthCondition + `
)`
f.addRecursiveWith(query, args...)
f.addLeftJoin("children2", "", "children2.item_id = tags.id")
addHierarchicalConditionClauses(f, models.HierarchicalMultiCriterionInput{
Value: tags.Excludes,
Depth: tags.Depth,
Modifier: models.CriterionModifierExcludes,
}, "children2", "root_id")
}
}
}
}
func (qb *tagFilterHandler) parentCountCriterionHandler(parentCount *models.IntCriterionInput) criterionHandlerFunc {
return func(ctx context.Context, f *filterBuilder) {
if parentCount != nil {
f.addLeftJoin("tags_relations", "parents_count", "parents_count.child_id = tags.id")
clause, args := getIntCriterionWhereClause("count(distinct parents_count.parent_id)", *parentCount)
f.addHaving(clause, args...)
}
}
}
func (qb *tagFilterHandler) childCountCriterionHandler(childCount *models.IntCriterionInput) criterionHandlerFunc {
return func(ctx context.Context, f *filterBuilder) {
if childCount != nil {
f.addLeftJoin("tags_relations", "children_count", "children_count.parent_id = tags.id")
clause, args := getIntCriterionWhereClause("count(distinct children_count.child_id)", *childCount)
f.addHaving(clause, args...)
}
}
}

View File

@@ -712,7 +712,7 @@ func TestTagQueryParent(t *testing.T) {
assert.Len(t, tags, 1)
// ensure id is correct
assert.Equal(t, sceneIDs[tagIdxWithParentTag], tags[0].ID)
assert.Equal(t, tagIDs[tagIdxWithParentTag], tags[0].ID)
tagCriterion.Modifier = models.CriterionModifierExcludes