Containing Group/Sub-Group relationships (#5105)

* Add UI support for setting containing groups
* Show containing groups in group details panel
* Move tag hierarchical filter code into separate type
* Add depth to scene_count and add sub_group_count
* Add sub-groups tab to groups page
* Add containing groups to edit groups dialog
* Show containing group description in sub-group view
* Show group scene number in group scenes view
* Add ability to drag move grid cards
* Add sub group order option
* Add reorder sub-groups interface
* Separate page size selector component
* Add interfaces to add and remove sub-groups to a group
* Separate MultiSet components
* Allow setting description while setting containing groups
This commit is contained in:
WithoutPants
2024-08-30 11:43:44 +10:00
committed by GitHub
parent 96fdd94a01
commit bcf0fda7ac
99 changed files with 5388 additions and 935 deletions

41
pkg/group/create.go Normal file
View File

@@ -0,0 +1,41 @@
package group
import (
"context"
"errors"
"github.com/stashapp/stash/pkg/models"
)
var (
ErrEmptyName = errors.New("name cannot be empty")
ErrHierarchyLoop = errors.New("a group cannot be contained by one of its subgroups")
)
func (s *Service) Create(ctx context.Context, group *models.Group, frontimageData []byte, backimageData []byte) error {
r := s.Repository
if err := s.validateCreate(ctx, group); err != nil {
return err
}
err := r.Create(ctx, group)
if err != nil {
return err
}
// update image table
if len(frontimageData) > 0 {
if err := r.UpdateFrontImage(ctx, group.ID, frontimageData); err != nil {
return err
}
}
if len(backimageData) > 0 {
if err := r.UpdateBackImage(ctx, group.ID, backimageData); err != nil {
return err
}
}
return nil
}

View File

@@ -16,6 +16,18 @@ type ImporterReaderWriter interface {
FindByName(ctx context.Context, name string, nocase bool) (*models.Group, error)
}
type SubGroupNotExistError struct {
missingSubGroup string
}
func (e SubGroupNotExistError) Error() string {
return fmt.Sprintf("sub group <%s> does not exist", e.missingSubGroup)
}
func (e SubGroupNotExistError) MissingSubGroup() string {
return e.missingSubGroup
}
type Importer struct {
ReaderWriter ImporterReaderWriter
StudioWriter models.StudioFinderCreator
@@ -202,6 +214,22 @@ func (i *Importer) createStudio(ctx context.Context, name string) (int, error) {
}
func (i *Importer) PostImport(ctx context.Context, id int) error {
subGroups, err := i.getSubGroups(ctx)
if err != nil {
return err
}
if len(subGroups) > 0 {
if _, err := i.ReaderWriter.UpdatePartial(ctx, id, models.GroupPartial{
SubGroups: &models.UpdateGroupDescriptions{
Groups: subGroups,
Mode: models.RelationshipUpdateModeSet,
},
}); err != nil {
return fmt.Errorf("error setting parents: %v", err)
}
}
if len(i.frontImageData) > 0 {
if err := i.ReaderWriter.UpdateFrontImage(ctx, id, i.frontImageData); err != nil {
return fmt.Errorf("error setting group front image: %v", err)
@@ -256,3 +284,53 @@ func (i *Importer) Update(ctx context.Context, id int) error {
return nil
}
func (i *Importer) getSubGroups(ctx context.Context) ([]models.GroupIDDescription, error) {
var subGroups []models.GroupIDDescription
for _, subGroup := range i.Input.SubGroups {
group, err := i.ReaderWriter.FindByName(ctx, subGroup.Group, false)
if err != nil {
return nil, fmt.Errorf("error finding parent by name: %v", err)
}
if group == nil {
if i.MissingRefBehaviour == models.ImportMissingRefEnumFail {
return nil, SubGroupNotExistError{missingSubGroup: subGroup.Group}
}
if i.MissingRefBehaviour == models.ImportMissingRefEnumIgnore {
continue
}
if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate {
parentID, err := i.createSubGroup(ctx, subGroup.Group)
if err != nil {
return nil, err
}
subGroups = append(subGroups, models.GroupIDDescription{
GroupID: parentID,
Description: subGroup.Description,
})
}
} else {
subGroups = append(subGroups, models.GroupIDDescription{
GroupID: group.ID,
Description: subGroup.Description,
})
}
}
return subGroups, nil
}
func (i *Importer) createSubGroup(ctx context.Context, name string) (int, error) {
newGroup := models.NewGroup()
newGroup.Name = name
err := i.ReaderWriter.Create(ctx, &newGroup)
if err != nil {
return 0, err
}
return newGroup.ID, nil
}

View File

@@ -30,3 +30,15 @@ func CountByTagID(ctx context.Context, r models.GroupQueryer, id int, depth *int
return r.QueryCount(ctx, filter, nil)
}
func CountByContainingGroupID(ctx context.Context, r models.GroupQueryer, id int, depth *int) (int, error) {
filter := &models.GroupFilterType{
ContainingGroups: &models.HierarchicalMultiCriterionInput{
Value: []string{strconv.Itoa(id)},
Modifier: models.CriterionModifierIncludes,
Depth: depth,
},
}
return r.QueryCount(ctx, filter, nil)
}

33
pkg/group/reorder.go Normal file
View File

@@ -0,0 +1,33 @@
package group
import (
"context"
"errors"
"github.com/stashapp/stash/pkg/models"
)
var ErrInvalidInsertIndex = errors.New("invalid insert index")
func (s *Service) ReorderSubGroups(ctx context.Context, groupID int, subGroupIDs []int, insertPointID int, insertAfter bool) error {
// get the group
existing, err := s.Repository.Find(ctx, groupID)
if err != nil {
return err
}
// ensure it exists
if existing == nil {
return models.ErrNotFound
}
// TODO - ensure the subgroups exist in the group
// ensure the insert index is valid
if insertPointID < 0 {
return ErrInvalidInsertIndex
}
// reorder the subgroups
return s.Repository.ReorderSubGroups(ctx, groupID, subGroupIDs, insertPointID, insertAfter)
}

46
pkg/group/service.go Normal file
View File

@@ -0,0 +1,46 @@
package group
import (
"context"
"github.com/stashapp/stash/pkg/models"
)
type CreatorUpdater interface {
models.GroupGetter
models.GroupCreator
models.GroupUpdater
models.ContainingGroupLoader
models.SubGroupLoader
AnscestorFinder
SubGroupIDFinder
SubGroupAdder
SubGroupRemover
SubGroupReorderer
}
type AnscestorFinder interface {
FindInAncestors(ctx context.Context, ascestorIDs []int, ids []int) ([]int, error)
}
type SubGroupIDFinder interface {
FindSubGroupIDs(ctx context.Context, containingID int, ids []int) ([]int, error)
}
type SubGroupAdder interface {
AddSubGroups(ctx context.Context, groupID int, subGroups []models.GroupIDDescription, insertIndex *int) error
}
type SubGroupRemover interface {
RemoveSubGroups(ctx context.Context, groupID int, subGroupIDs []int) error
}
type SubGroupReorderer interface {
ReorderSubGroups(ctx context.Context, groupID int, subGroupIDs []int, insertID int, insertAfter bool) error
}
type Service struct {
Repository CreatorUpdater
}

112
pkg/group/update.go Normal file
View File

@@ -0,0 +1,112 @@
package group
import (
"context"
"fmt"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/sliceutil"
)
type SubGroupAlreadyInGroupError struct {
GroupIDs []int
}
func (e *SubGroupAlreadyInGroupError) Error() string {
return fmt.Sprintf("subgroups with IDs %v already in group", e.GroupIDs)
}
type ImageInput struct {
Image []byte
Set bool
}
func (s *Service) UpdatePartial(ctx context.Context, id int, updatedGroup models.GroupPartial, frontImage ImageInput, backImage ImageInput) (*models.Group, error) {
if err := s.validateUpdate(ctx, id, updatedGroup); err != nil {
return nil, err
}
r := s.Repository
group, err := r.UpdatePartial(ctx, id, updatedGroup)
if err != nil {
return nil, err
}
// update image table
if frontImage.Set {
if err := r.UpdateFrontImage(ctx, id, frontImage.Image); err != nil {
return nil, err
}
}
if backImage.Set {
if err := r.UpdateBackImage(ctx, id, backImage.Image); err != nil {
return nil, err
}
}
return group, nil
}
func (s *Service) AddSubGroups(ctx context.Context, groupID int, subGroups []models.GroupIDDescription, insertIndex *int) error {
// get the group
existing, err := s.Repository.Find(ctx, groupID)
if err != nil {
return err
}
// ensure it exists
if existing == nil {
return models.ErrNotFound
}
// ensure the subgroups aren't already sub-groups of the group
subGroupIDs := sliceutil.Map(subGroups, func(sg models.GroupIDDescription) int {
return sg.GroupID
})
existingSubGroupIDs, err := s.Repository.FindSubGroupIDs(ctx, groupID, subGroupIDs)
if err != nil {
return err
}
if len(existingSubGroupIDs) > 0 {
return &SubGroupAlreadyInGroupError{
GroupIDs: existingSubGroupIDs,
}
}
// validate the hierarchy
d := &models.UpdateGroupDescriptions{
Groups: subGroups,
Mode: models.RelationshipUpdateModeAdd,
}
if err := s.validateUpdateGroupHierarchy(ctx, existing, nil, d); err != nil {
return err
}
// validate insert index
if insertIndex != nil && *insertIndex < 0 {
return ErrInvalidInsertIndex
}
// add the subgroups
return s.Repository.AddSubGroups(ctx, groupID, subGroups, insertIndex)
}
func (s *Service) RemoveSubGroups(ctx context.Context, groupID int, subGroupIDs []int) error {
// get the group
existing, err := s.Repository.Find(ctx, groupID)
if err != nil {
return err
}
// ensure it exists
if existing == nil {
return models.ErrNotFound
}
// add the subgroups
return s.Repository.RemoveSubGroups(ctx, groupID, subGroupIDs)
}

117
pkg/group/validate.go Normal file
View File

@@ -0,0 +1,117 @@
package group
import (
"context"
"strings"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/sliceutil"
)
func (s *Service) validateCreate(ctx context.Context, group *models.Group) error {
if err := validateName(group.Name); err != nil {
return err
}
containingIDs := group.ContainingGroups.IDs()
subIDs := group.SubGroups.IDs()
if err := s.validateGroupHierarchy(ctx, containingIDs, subIDs); err != nil {
return err
}
return nil
}
func (s *Service) validateUpdate(ctx context.Context, id int, partial models.GroupPartial) error {
// get the existing group - ensure it exists
existing, err := s.Repository.Find(ctx, id)
if err != nil {
return err
}
if existing == nil {
return models.ErrNotFound
}
if partial.Name.Set {
if err := validateName(partial.Name.Value); err != nil {
return err
}
}
if err := s.validateUpdateGroupHierarchy(ctx, existing, partial.ContainingGroups, partial.SubGroups); err != nil {
return err
}
return nil
}
func validateName(n string) error {
// ensure name is not empty
if strings.TrimSpace(n) == "" {
return ErrEmptyName
}
return nil
}
func (s *Service) validateGroupHierarchy(ctx context.Context, containingIDs []int, subIDs []int) error {
// only need to validate if both are non-empty
if len(containingIDs) == 0 || len(subIDs) == 0 {
return nil
}
// ensure none of the containing groups are in the sub groups
found, err := s.Repository.FindInAncestors(ctx, containingIDs, subIDs)
if err != nil {
return err
}
if len(found) > 0 {
return ErrHierarchyLoop
}
return nil
}
func (s *Service) validateUpdateGroupHierarchy(ctx context.Context, existing *models.Group, containingGroups *models.UpdateGroupDescriptions, subGroups *models.UpdateGroupDescriptions) error {
// no need to validate if there are no changes
if containingGroups == nil && subGroups == nil {
return nil
}
if err := existing.LoadContainingGroupIDs(ctx, s.Repository); err != nil {
return err
}
existingContainingGroups := existing.ContainingGroups.List()
if err := existing.LoadSubGroupIDs(ctx, s.Repository); err != nil {
return err
}
existingSubGroups := existing.SubGroups.List()
effectiveContainingGroups := existingContainingGroups
if containingGroups != nil {
effectiveContainingGroups = containingGroups.Apply(existingContainingGroups)
}
effectiveSubGroups := existingSubGroups
if subGroups != nil {
effectiveSubGroups = subGroups.Apply(existingSubGroups)
}
containingIDs := idsFromGroupDescriptions(effectiveContainingGroups)
subIDs := idsFromGroupDescriptions(effectiveSubGroups)
// ensure we haven't set the group as a subgroup of itself
if sliceutil.Contains(containingIDs, existing.ID) || sliceutil.Contains(subIDs, existing.ID) {
return ErrHierarchyLoop
}
return s.validateGroupHierarchy(ctx, containingIDs, subIDs)
}
func idsFromGroupDescriptions(v []models.GroupIDDescription) []int {
return sliceutil.Map(v, func(g models.GroupIDDescription) int { return g.GroupID })
}

View File

@@ -23,6 +23,14 @@ type GroupFilterType struct {
TagCount *IntCriterionInput `json:"tag_count"`
// Filter by date
Date *DateCriterionInput `json:"date"`
// Filter by containing groups
ContainingGroups *HierarchicalMultiCriterionInput `json:"containing_groups"`
// Filter by sub groups
SubGroups *HierarchicalMultiCriterionInput `json:"sub_groups"`
// Filter by number of containing groups the group has
ContainingGroupCount *IntCriterionInput `json:"containing_group_count"`
// Filter by number of sub-groups the group has
SubGroupCount *IntCriterionInput `json:"sub_group_count"`
// Filter by related scenes that meet this criteria
ScenesFilter *SceneFilterType `json:"scenes_filter"`
// Filter by related studios that meet this criteria

View File

@@ -11,21 +11,27 @@ import (
"github.com/stashapp/stash/pkg/models/json"
)
type SubGroupDescription struct {
Group string `json:"name,omitempty"`
Description string `json:"description,omitempty"`
}
type Group struct {
Name string `json:"name,omitempty"`
Aliases string `json:"aliases,omitempty"`
Duration int `json:"duration,omitempty"`
Date string `json:"date,omitempty"`
Rating int `json:"rating,omitempty"`
Director string `json:"director,omitempty"`
Synopsis string `json:"synopsis,omitempty"`
FrontImage string `json:"front_image,omitempty"`
BackImage string `json:"back_image,omitempty"`
URLs []string `json:"urls,omitempty"`
Studio string `json:"studio,omitempty"`
Tags []string `json:"tags,omitempty"`
CreatedAt json.JSONTime `json:"created_at,omitempty"`
UpdatedAt json.JSONTime `json:"updated_at,omitempty"`
Name string `json:"name,omitempty"`
Aliases string `json:"aliases,omitempty"`
Duration int `json:"duration,omitempty"`
Date string `json:"date,omitempty"`
Rating int `json:"rating,omitempty"`
Director string `json:"director,omitempty"`
Synopsis string `json:"synopsis,omitempty"`
FrontImage string `json:"front_image,omitempty"`
BackImage string `json:"back_image,omitempty"`
URLs []string `json:"urls,omitempty"`
Studio string `json:"studio,omitempty"`
Tags []string `json:"tags,omitempty"`
SubGroups []SubGroupDescription `json:"sub_groups,omitempty"`
CreatedAt json.JSONTime `json:"created_at,omitempty"`
UpdatedAt json.JSONTime `json:"updated_at,omitempty"`
// deprecated - for import only
URL string `json:"url,omitempty"`

View File

@@ -289,6 +289,29 @@ func (_m *GroupReaderWriter) GetBackImage(ctx context.Context, groupID int) ([]b
return r0, r1
}
// GetContainingGroupDescriptions provides a mock function with given fields: ctx, id
func (_m *GroupReaderWriter) GetContainingGroupDescriptions(ctx context.Context, id int) ([]models.GroupIDDescription, error) {
ret := _m.Called(ctx, id)
var r0 []models.GroupIDDescription
if rf, ok := ret.Get(0).(func(context.Context, int) []models.GroupIDDescription); ok {
r0 = rf(ctx, id)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]models.GroupIDDescription)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
r1 = rf(ctx, id)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetFrontImage provides a mock function with given fields: ctx, groupID
func (_m *GroupReaderWriter) GetFrontImage(ctx context.Context, groupID int) ([]byte, error) {
ret := _m.Called(ctx, groupID)
@@ -312,6 +335,29 @@ func (_m *GroupReaderWriter) GetFrontImage(ctx context.Context, groupID int) ([]
return r0, r1
}
// GetSubGroupDescriptions provides a mock function with given fields: ctx, id
func (_m *GroupReaderWriter) GetSubGroupDescriptions(ctx context.Context, id int) ([]models.GroupIDDescription, error) {
ret := _m.Called(ctx, id)
var r0 []models.GroupIDDescription
if rf, ok := ret.Get(0).(func(context.Context, int) []models.GroupIDDescription); ok {
r0 = rf(ctx, id)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]models.GroupIDDescription)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
r1 = rf(ctx, id)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetTagIDs provides a mock function with given fields: ctx, relatedID
func (_m *GroupReaderWriter) GetTagIDs(ctx context.Context, relatedID int) ([]int, error) {
ret := _m.Called(ctx, relatedID)

View File

@@ -21,6 +21,9 @@ type Group struct {
URLs RelatedStrings `json:"urls"`
TagIDs RelatedIDs `json:"tag_ids"`
ContainingGroups RelatedGroupDescriptions `json:"containing_groups"`
SubGroups RelatedGroupDescriptions `json:"sub_groups"`
}
func NewGroup() Group {
@@ -43,20 +46,34 @@ func (m *Group) LoadTagIDs(ctx context.Context, l TagIDLoader) error {
})
}
func (m *Group) LoadContainingGroupIDs(ctx context.Context, l ContainingGroupLoader) error {
return m.ContainingGroups.load(func() ([]GroupIDDescription, error) {
return l.GetContainingGroupDescriptions(ctx, m.ID)
})
}
func (m *Group) LoadSubGroupIDs(ctx context.Context, l SubGroupLoader) error {
return m.SubGroups.load(func() ([]GroupIDDescription, error) {
return l.GetSubGroupDescriptions(ctx, m.ID)
})
}
type GroupPartial struct {
Name OptionalString
Aliases OptionalString
Duration OptionalInt
Date OptionalDate
// Rating expressed in 1-100 scale
Rating OptionalInt
StudioID OptionalInt
Director OptionalString
Synopsis OptionalString
URLs *UpdateStrings
TagIDs *UpdateIDs
CreatedAt OptionalTime
UpdatedAt OptionalTime
Rating OptionalInt
StudioID OptionalInt
Director OptionalString
Synopsis OptionalString
URLs *UpdateStrings
TagIDs *UpdateIDs
ContainingGroups *UpdateGroupDescriptions
SubGroups *UpdateGroupDescriptions
CreatedAt OptionalTime
UpdatedAt OptionalTime
}
func NewGroupPartial() GroupPartial {

View File

@@ -68,3 +68,8 @@ func GroupsScenesFromInput(input []SceneMovieInput) ([]GroupsScenes, error) {
return ret, nil
}
type GroupIDDescription struct {
GroupID int `json:"group_id"`
Description string `json:"description"`
}

View File

@@ -2,6 +2,8 @@ package models
import (
"context"
"github.com/stashapp/stash/pkg/sliceutil"
)
type SceneIDLoader interface {
@@ -37,6 +39,14 @@ type SceneGroupLoader interface {
GetGroups(ctx context.Context, id int) ([]GroupsScenes, error)
}
type ContainingGroupLoader interface {
GetContainingGroupDescriptions(ctx context.Context, id int) ([]GroupIDDescription, error)
}
type SubGroupLoader interface {
GetSubGroupDescriptions(ctx context.Context, id int) ([]GroupIDDescription, error)
}
type StashIDLoader interface {
GetStashIDs(ctx context.Context, relatedID int) ([]StashID, error)
}
@@ -185,6 +195,82 @@ func (r *RelatedGroups) load(fn func() ([]GroupsScenes, error)) error {
return nil
}
type RelatedGroupDescriptions struct {
list []GroupIDDescription
}
// NewRelatedGroups returns a loaded RelateGroups object with the provided groups.
// Loaded will return true when called on the returned object if the provided slice is not nil.
func NewRelatedGroupDescriptions(list []GroupIDDescription) RelatedGroupDescriptions {
return RelatedGroupDescriptions{
list: list,
}
}
// Loaded returns true if the relationship has been loaded.
func (r RelatedGroupDescriptions) Loaded() bool {
return r.list != nil
}
func (r RelatedGroupDescriptions) mustLoaded() {
if !r.Loaded() {
panic("list has not been loaded")
}
}
// List returns the related Groups. Panics if the relationship has not been loaded.
func (r RelatedGroupDescriptions) List() []GroupIDDescription {
r.mustLoaded()
return r.list
}
// List returns the related Groups. Panics if the relationship has not been loaded.
func (r RelatedGroupDescriptions) IDs() []int {
r.mustLoaded()
return sliceutil.Map(r.list, func(d GroupIDDescription) int { return d.GroupID })
}
// Add adds the provided ids to the list. Panics if the relationship has not been loaded.
func (r *RelatedGroupDescriptions) Add(groups ...GroupIDDescription) {
r.mustLoaded()
r.list = append(r.list, groups...)
}
// ForID returns the GroupsScenes object for the given group ID. Returns nil if not found.
func (r *RelatedGroupDescriptions) ForID(id int) *GroupIDDescription {
r.mustLoaded()
for _, v := range r.list {
if v.GroupID == id {
return &v
}
}
return nil
}
func (r *RelatedGroupDescriptions) load(fn func() ([]GroupIDDescription, error)) error {
if r.Loaded() {
return nil
}
ids, err := fn()
if err != nil {
return err
}
if ids == nil {
ids = []GroupIDDescription{}
}
r.list = ids
return nil
}
type RelatedStashIDs struct {
list []StashID
}

View File

@@ -66,6 +66,8 @@ type GroupReader interface {
GroupCounter
URLLoader
TagIDLoader
ContainingGroupLoader
SubGroupLoader
All(ctx context.Context) ([]*Group, error)
GetFrontImage(ctx context.Context, groupID int) ([]byte, error)

View File

@@ -37,10 +37,7 @@ type SceneQueryer interface {
type SceneCounter interface {
Count(ctx context.Context) (int, error)
CountByPerformerID(ctx context.Context, performerID int) (int, error)
CountByGroupID(ctx context.Context, groupID int) (int, error)
CountByFileID(ctx context.Context, fileID FileID) (int, error)
CountByStudioID(ctx context.Context, studioID int) (int, error)
CountByTagID(ctx context.Context, tagID int) (int, error)
CountMissingChecksum(ctx context.Context) (int, error)
CountMissingOSHash(ctx context.Context) (int, error)
OCountByPerformerID(ctx context.Context, performerID int) (int, error)

View File

@@ -56,7 +56,7 @@ type SceneFilterType struct {
// Filter to only include scenes with this studio
Studios *HierarchicalMultiCriterionInput `json:"studios"`
// Filter to only include scenes with this group
Groups *MultiCriterionInput `json:"groups"`
Groups *HierarchicalMultiCriterionInput `json:"groups"`
// Filter to only include scenes with this movie
Movies *MultiCriterionInput `json:"movies"`
// Filter to only include scenes with this gallery

View File

@@ -133,3 +133,68 @@ func applyUpdate[T comparable](values []T, mode RelationshipUpdateMode, existing
return nil
}
type UpdateGroupDescriptions struct {
Groups []GroupIDDescription `json:"groups"`
Mode RelationshipUpdateMode `json:"mode"`
}
// Apply applies the update to a list of existing ids, returning the result.
func (u *UpdateGroupDescriptions) Apply(existing []GroupIDDescription) []GroupIDDescription {
if u == nil {
return existing
}
switch u.Mode {
case RelationshipUpdateModeAdd:
return u.applyAdd(existing)
case RelationshipUpdateModeRemove:
return u.applyRemove(existing)
case RelationshipUpdateModeSet:
return u.Groups
}
return nil
}
func (u *UpdateGroupDescriptions) applyAdd(existing []GroupIDDescription) []GroupIDDescription {
// overwrite any existing values with the same id
ret := append([]GroupIDDescription{}, existing...)
for _, v := range u.Groups {
found := false
for i, vv := range ret {
if vv.GroupID == v.GroupID {
ret[i] = v
found = true
break
}
}
if !found {
ret = append(ret, v)
}
}
return ret
}
func (u *UpdateGroupDescriptions) applyRemove(existing []GroupIDDescription) []GroupIDDescription {
// remove any existing values with the same id
var ret []GroupIDDescription
for _, v := range existing {
found := false
for _, vv := range u.Groups {
if vv.GroupID == v.GroupID {
found = true
break
}
}
// if not found in the remove list, keep it
if !found {
ret = append(ret, v)
}
}
return ret
}

View File

@@ -144,3 +144,15 @@ func CountByTagID(ctx context.Context, r models.SceneQueryer, id int, depth *int
return r.QueryCount(ctx, filter, nil)
}
func CountByGroupID(ctx context.Context, r models.SceneQueryer, id int, depth *int) (int, error) {
filter := &models.SceneFilterType{
Groups: &models.HierarchicalMultiCriterionInput{
Value: []string{strconv.Itoa(id)},
Modifier: models.CriterionModifierIncludes,
Depth: depth,
},
}
return r.QueryCount(ctx, filter, nil)
}

View File

@@ -146,7 +146,7 @@ func Filter[T any](vs []T, f func(T) bool) []T {
return ret
}
// Filter returns the result of applying f to each element of the vs slice.
// Map returns the result of applying f to each element of the vs slice.
func Map[T any, V any](vs []T, f func(T) V) []V {
ret := make([]V, len(vs))
for i, v := range vs {

View File

@@ -30,7 +30,7 @@ const (
dbConnTimeout = 30
)
var appSchemaVersion uint = 66
var appSchemaVersion uint = 67
//go:embed migrations/*.sql
var migrationsBox embed.FS

View File

@@ -0,0 +1,222 @@
package sqlite
import (
"context"
"fmt"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/utils"
)
// hierarchicalRelationshipHandler provides handlers for parent, children, parent count, and child count criteria.
type hierarchicalRelationshipHandler struct {
primaryTable string
relationTable string
aliasPrefix string
parentIDCol string
childIDCol string
}
func (h hierarchicalRelationshipHandler) validateModifier(m models.CriterionModifier) error {
switch m {
case models.CriterionModifierIncludesAll, models.CriterionModifierIncludes, models.CriterionModifierExcludes, models.CriterionModifierIsNull, models.CriterionModifierNotNull:
// valid
return nil
default:
return fmt.Errorf("invalid modifier %s", m)
}
}
func (h hierarchicalRelationshipHandler) handleNullNotNull(f *filterBuilder, m models.CriterionModifier, isParents bool) {
var notClause string
if m == models.CriterionModifierNotNull {
notClause = "NOT"
}
as := h.aliasPrefix + "_parents"
col := h.childIDCol
if !isParents {
as = h.aliasPrefix + "_children"
col = h.parentIDCol
}
// Based on:
// f.addLeftJoin("tags_relations", "parent_relations", "tags.id = parent_relations.child_id")
// f.addWhere(fmt.Sprintf("parent_relations.parent_id IS %s NULL", notClause))
f.addLeftJoin(h.relationTable, as, fmt.Sprintf("%s.id = %s.%s", h.primaryTable, as, col))
f.addWhere(fmt.Sprintf("%s.%s IS %s NULL", as, col, notClause))
}
func (h hierarchicalRelationshipHandler) parentsAlias() string {
return h.aliasPrefix + "_parents"
}
func (h hierarchicalRelationshipHandler) childrenAlias() string {
return h.aliasPrefix + "_children"
}
func (h hierarchicalRelationshipHandler) valueQuery(value []string, depth int, alias string, isParents bool) string {
var depthCondition string
if depth != -1 {
depthCondition = fmt.Sprintf("WHERE depth < %d", depth)
}
queryTempl := `{alias} AS (
SELECT {root_id_col} AS root_id, {item_id_col} AS item_id, 0 AS depth FROM {relation_table} WHERE {root_id_col} IN` + getInBinding(len(value)) + `
UNION
SELECT root_id, {item_id_col}, depth + 1 FROM {relation_table} INNER JOIN {alias} ON item_id = {root_id_col} ` + depthCondition + `
)`
var queryMap utils.StrFormatMap
if isParents {
queryMap = utils.StrFormatMap{
"root_id_col": h.parentIDCol,
"item_id_col": h.childIDCol,
}
} else {
queryMap = utils.StrFormatMap{
"root_id_col": h.childIDCol,
"item_id_col": h.parentIDCol,
}
}
queryMap["alias"] = alias
queryMap["relation_table"] = h.relationTable
return utils.StrFormat(queryTempl, queryMap)
}
func (h hierarchicalRelationshipHandler) handleValues(f *filterBuilder, c models.HierarchicalMultiCriterionInput, isParents bool, aliasSuffix string) {
if len(c.Value) == 0 {
return
}
var args []interface{}
for _, val := range c.Value {
args = append(args, val)
}
depthVal := 0
if c.Depth != nil {
depthVal = *c.Depth
}
tableAlias := h.parentsAlias()
if !isParents {
tableAlias = h.childrenAlias()
}
tableAlias += aliasSuffix
query := h.valueQuery(c.Value, depthVal, tableAlias, isParents)
f.addRecursiveWith(query, args...)
f.addLeftJoin(tableAlias, "", fmt.Sprintf("%s.item_id = %s.id", tableAlias, h.primaryTable))
addHierarchicalConditionClauses(f, c, tableAlias, "root_id")
}
func (h hierarchicalRelationshipHandler) handleValuesSimple(f *filterBuilder, value string, isParents bool) {
joinCol := h.childIDCol
valueCol := h.parentIDCol
if !isParents {
joinCol = h.parentIDCol
valueCol = h.childIDCol
}
tableAlias := h.parentsAlias()
if !isParents {
tableAlias = h.childrenAlias()
}
f.addInnerJoin(h.relationTable, tableAlias, fmt.Sprintf("%s.%s = %s.id", tableAlias, joinCol, h.primaryTable))
f.addWhere(fmt.Sprintf("%s.%s = ?", tableAlias, valueCol), value)
}
func (h hierarchicalRelationshipHandler) hierarchicalCriterionHandler(criterion *models.HierarchicalMultiCriterionInput, isParents bool) criterionHandlerFunc {
return func(ctx context.Context, f *filterBuilder) {
if criterion != nil {
c := criterion.CombineExcludes()
// validate the modifier
if err := h.validateModifier(c.Modifier); err != nil {
f.setError(err)
return
}
if c.Modifier == models.CriterionModifierIsNull || c.Modifier == models.CriterionModifierNotNull {
h.handleNullNotNull(f, c.Modifier, isParents)
return
}
if len(c.Value) == 0 && len(c.Excludes) == 0 {
return
}
depth := 0
if c.Depth != nil {
depth = *c.Depth
}
// if we have a single include, no excludes, and no depth, we can use a simple join and where clause
if (c.Modifier == models.CriterionModifierIncludes || c.Modifier == models.CriterionModifierIncludesAll) && len(c.Value) == 1 && len(c.Excludes) == 0 && depth == 0 {
h.handleValuesSimple(f, c.Value[0], isParents)
return
}
aliasSuffix := ""
h.handleValues(f, c, isParents, aliasSuffix)
if len(c.Excludes) > 0 {
exCriterion := models.HierarchicalMultiCriterionInput{
Value: c.Excludes,
Depth: c.Depth,
Modifier: models.CriterionModifierExcludes,
}
aliasSuffix := "2"
h.handleValues(f, exCriterion, isParents, aliasSuffix)
}
}
}
}
func (h hierarchicalRelationshipHandler) ParentsCriterionHandler(criterion *models.HierarchicalMultiCriterionInput) criterionHandlerFunc {
const isParents = true
return h.hierarchicalCriterionHandler(criterion, isParents)
}
func (h hierarchicalRelationshipHandler) ChildrenCriterionHandler(criterion *models.HierarchicalMultiCriterionInput) criterionHandlerFunc {
const isParents = false
return h.hierarchicalCriterionHandler(criterion, isParents)
}
func (h hierarchicalRelationshipHandler) countCriterionHandler(c *models.IntCriterionInput, isParents bool) criterionHandlerFunc {
tableAlias := h.parentsAlias()
col := h.childIDCol
otherCol := h.parentIDCol
if !isParents {
tableAlias = h.childrenAlias()
col = h.parentIDCol
otherCol = h.childIDCol
}
tableAlias += "_count"
return func(ctx context.Context, f *filterBuilder) {
if c != nil {
f.addLeftJoin(h.relationTable, tableAlias, fmt.Sprintf("%s.%s = %s.id", tableAlias, col, h.primaryTable))
clause, args := getIntCriterionWhereClause(fmt.Sprintf("count(distinct %s.%s)", tableAlias, otherCol), *c)
f.addHaving(clause, args...)
}
}
}
func (h hierarchicalRelationshipHandler) ParentCountCriterionHandler(parentCount *models.IntCriterionInput) criterionHandlerFunc {
const isParents = true
return h.countCriterionHandler(parentCount, isParents)
}
func (h hierarchicalRelationshipHandler) ChildCountCriterionHandler(childCount *models.IntCriterionInput) criterionHandlerFunc {
const isParents = false
return h.countCriterionHandler(childCount, isParents)
}

View File

@@ -27,6 +27,8 @@ const (
groupURLsTable = "group_urls"
groupURLColumn = "url"
groupRelationsTable = "groups_relations"
)
type groupRow struct {
@@ -128,6 +130,7 @@ var (
type GroupStore struct {
blobJoinQueryBuilder
tagRelationshipStore
groupRelationshipStore
tableMgr *table
}
@@ -143,6 +146,9 @@ func NewGroupStore(blobStore *BlobStore) *GroupStore {
joinTable: groupsTagsTableMgr,
},
},
groupRelationshipStore: groupRelationshipStore{
table: groupRelationshipTableMgr,
},
tableMgr: groupTableMgr,
}
@@ -176,6 +182,14 @@ func (qb *GroupStore) Create(ctx context.Context, newObject *models.Group) error
return err
}
if err := qb.groupRelationshipStore.createContainingRelationships(ctx, id, newObject.ContainingGroups); err != nil {
return err
}
if err := qb.groupRelationshipStore.createSubRelationships(ctx, id, newObject.SubGroups); err != nil {
return err
}
updated, err := qb.find(ctx, id)
if err != nil {
return fmt.Errorf("finding after create: %w", err)
@@ -211,6 +225,14 @@ func (qb *GroupStore) UpdatePartial(ctx context.Context, id int, partial models.
return nil, err
}
if err := qb.groupRelationshipStore.modifyContainingRelationships(ctx, id, partial.ContainingGroups); err != nil {
return nil, err
}
if err := qb.groupRelationshipStore.modifySubRelationships(ctx, id, partial.SubGroups); err != nil {
return nil, err
}
return qb.find(ctx, id)
}
@@ -232,6 +254,14 @@ func (qb *GroupStore) Update(ctx context.Context, updatedObject *models.Group) e
return err
}
if err := qb.groupRelationshipStore.replaceContainingRelationships(ctx, updatedObject.ID, updatedObject.ContainingGroups); err != nil {
return err
}
if err := qb.groupRelationshipStore.replaceSubRelationships(ctx, updatedObject.ID, updatedObject.SubGroups); err != nil {
return err
}
return nil
}
@@ -412,9 +442,7 @@ func (qb *GroupStore) makeQuery(ctx context.Context, groupFilter *models.GroupFi
return nil, err
}
var err error
query.sortAndPagination, err = qb.getGroupSort(findFilter)
if err != nil {
if err := qb.setGroupSort(&query, findFilter); err != nil {
return nil, err
}
@@ -460,11 +488,12 @@ var groupSortOptions = sortOptions{
"random",
"rating",
"scenes_count",
"sub_group_order",
"tag_count",
"updated_at",
}
func (qb *GroupStore) getGroupSort(findFilter *models.FindFilterType) (string, error) {
func (qb *GroupStore) setGroupSort(query *queryBuilder, findFilter *models.FindFilterType) error {
var sort string
var direction string
if findFilter == nil {
@@ -477,22 +506,31 @@ func (qb *GroupStore) getGroupSort(findFilter *models.FindFilterType) (string, e
// CVE-2024-32231 - ensure sort is in the list of allowed sorts
if err := groupSortOptions.validateSort(sort); err != nil {
return "", err
return err
}
sortQuery := ""
switch sort {
case "sub_group_order":
// sub_group_order is a special sort that sorts by the order_index of the subgroups
if query.hasJoin("groups_parents") {
query.sortAndPagination += getSort("order_index", direction, "groups_parents")
} else {
// this will give unexpected results if the query is not filtered by a parent group and
// the group has multiple parents and order indexes
query.join(groupRelationsTable, "", "groups.id = groups_relations.sub_id")
query.sortAndPagination += getSort("order_index", direction, groupRelationsTable)
}
case "tag_count":
sortQuery += getCountSort(groupTable, groupsTagsTable, groupIDColumn, direction)
query.sortAndPagination += getCountSort(groupTable, groupsTagsTable, groupIDColumn, direction)
case "scenes_count": // generic getSort won't work for this
sortQuery += getCountSort(groupTable, groupsScenesTable, groupIDColumn, direction)
query.sortAndPagination += getCountSort(groupTable, groupsScenesTable, groupIDColumn, direction)
default:
sortQuery += getSort(sort, direction, "groups")
query.sortAndPagination += getSort(sort, direction, "groups")
}
// Whatever the sorting, always use name/id as a final sort
sortQuery += ", COALESCE(groups.name, groups.id) COLLATE NATURAL_CI ASC"
return sortQuery, nil
query.sortAndPagination += ", COALESCE(groups.name, groups.id) COLLATE NATURAL_CI ASC"
return nil
}
func (qb *GroupStore) queryGroups(ctx context.Context, query string, args []interface{}) ([]*models.Group, error) {
@@ -592,3 +630,74 @@ WHERE groups.studio_id = ?
func (qb *GroupStore) GetURLs(ctx context.Context, groupID int) ([]string, error) {
return groupsURLsTableMgr.get(ctx, groupID)
}
// FindSubGroupIDs returns a list of group IDs where a group in the ids list is a sub-group of the parent group
func (qb *GroupStore) FindSubGroupIDs(ctx context.Context, containingID int, ids []int) ([]int, error) {
/*
SELECT gr.sub_id FROM groups_relations gr
WHERE gr.containing_id = :parentID AND gr.sub_id IN (:ids);
*/
table := groupRelationshipTableMgr.table
q := dialect.From(table).Prepared(true).
Select(table.Col("sub_id")).Where(
table.Col("containing_id").Eq(containingID),
table.Col("sub_id").In(ids),
)
const single = false
var ret []int
if err := queryFunc(ctx, q, single, func(r *sqlx.Rows) error {
var id int
if err := r.Scan(&id); err != nil {
return err
}
ret = append(ret, id)
return nil
}); err != nil {
return nil, err
}
return ret, nil
}
// FindInAscestors returns a list of group IDs where a group in the ids list is an ascestor of the ancestor group IDs
func (qb *GroupStore) FindInAncestors(ctx context.Context, ascestorIDs []int, ids []int) ([]int, error) {
/*
WITH RECURSIVE ascestors AS (
SELECT g.id AS parent_id FROM groups g WHERE g.id IN (:ascestorIDs)
UNION
SELECT gr.containing_id FROM groups_relations gr INNER JOIN ascestors a ON a.parent_id = gr.sub_id
)
SELECT p.parent_id FROM ascestors p WHERE p.parent_id IN (:ids);
*/
table := qb.table()
const ascestors = "ancestors"
const parentID = "parent_id"
q := dialect.From(ascestors).Prepared(true).
WithRecursive(ascestors,
dialect.From(qb.table()).Select(table.Col(idColumn).As(parentID)).
Where(table.Col(idColumn).In(ascestorIDs)).
Union(
dialect.From(groupRelationsJoinTable).InnerJoin(
goqu.I(ascestors), goqu.On(goqu.I("parent_id").Eq(goqu.I("sub_id"))),
).Select("containing_id"),
),
).Select(parentID).Where(goqu.I(parentID).In(ids))
const single = false
var ret []int
if err := queryFunc(ctx, q, single, func(r *sqlx.Rows) error {
var id int
if err := r.Scan(&id); err != nil {
return err
}
ret = append(ret, id)
return nil
}); err != nil {
return nil, err
}
return ret, nil
}

View File

@@ -51,6 +51,14 @@ func (qb *groupFilterHandler) handle(ctx context.Context, f *filterBuilder) {
f.handleCriterion(ctx, qb.criterionHandler())
}
var groupHierarchyHandler = hierarchicalRelationshipHandler{
primaryTable: groupTable,
relationTable: groupRelationsTable,
aliasPrefix: groupTable,
parentIDCol: "containing_id",
childIDCol: "sub_id",
}
func (qb *groupFilterHandler) criterionHandler() criterionHandler {
groupFilter := qb.groupFilter
return compoundHandler{
@@ -66,6 +74,10 @@ func (qb *groupFilterHandler) criterionHandler() criterionHandler {
qb.tagsCriterionHandler(groupFilter.Tags),
qb.tagCountCriterionHandler(groupFilter.TagCount),
&dateCriterionHandler{groupFilter.Date, "groups.date", nil},
groupHierarchyHandler.ParentsCriterionHandler(groupFilter.ContainingGroups),
groupHierarchyHandler.ChildrenCriterionHandler(groupFilter.SubGroups),
groupHierarchyHandler.ParentCountCriterionHandler(groupFilter.ContainingGroupCount),
groupHierarchyHandler.ChildCountCriterionHandler(groupFilter.SubGroupCount),
&timestampCriterionHandler{groupFilter.CreatedAt, "groups.created_at", nil},
&timestampCriterionHandler{groupFilter.UpdatedAt, "groups.updated_at", nil},

View File

@@ -0,0 +1,457 @@
package sqlite
import (
"context"
"fmt"
"github.com/doug-martin/goqu/v9"
"github.com/doug-martin/goqu/v9/exp"
"github.com/jmoiron/sqlx"
"github.com/stashapp/stash/pkg/models"
"gopkg.in/guregu/null.v4"
"gopkg.in/guregu/null.v4/zero"
)
type groupRelationshipRow struct {
ContainingID int `db:"containing_id"`
SubID int `db:"sub_id"`
OrderIndex int `db:"order_index"`
Description zero.String `db:"description"`
}
func (r groupRelationshipRow) resolve(useContainingID bool) models.GroupIDDescription {
id := r.ContainingID
if !useContainingID {
id = r.SubID
}
return models.GroupIDDescription{
GroupID: id,
Description: r.Description.String,
}
}
type groupRelationshipStore struct {
table *table
}
func (s *groupRelationshipStore) GetContainingGroupDescriptions(ctx context.Context, id int) ([]models.GroupIDDescription, error) {
const idIsContaining = false
return s.getGroupRelationships(ctx, id, idIsContaining)
}
func (s *groupRelationshipStore) GetSubGroupDescriptions(ctx context.Context, id int) ([]models.GroupIDDescription, error) {
const idIsContaining = true
return s.getGroupRelationships(ctx, id, idIsContaining)
}
func (s *groupRelationshipStore) getGroupRelationships(ctx context.Context, id int, idIsContaining bool) ([]models.GroupIDDescription, error) {
col := "containing_id"
if !idIsContaining {
col = "sub_id"
}
table := s.table.table
q := dialect.Select(table.All()).
From(table).
Where(table.Col(col).Eq(id)).
Order(table.Col("order_index").Asc())
const single = false
var ret []models.GroupIDDescription
if err := queryFunc(ctx, q, single, func(rows *sqlx.Rows) error {
var row groupRelationshipRow
if err := rows.StructScan(&row); err != nil {
return err
}
ret = append(ret, row.resolve(!idIsContaining))
return nil
}); err != nil {
return nil, fmt.Errorf("getting group relationships from %s: %w", table.GetTable(), err)
}
return ret, nil
}
// getMaxOrderIndex gets the maximum order index for the containing group with the given id
func (s *groupRelationshipStore) getMaxOrderIndex(ctx context.Context, containingID int) (int, error) {
idColumn := s.table.table.Col("containing_id")
q := dialect.Select(goqu.MAX("order_index")).
From(s.table.table).
Where(idColumn.Eq(containingID))
var maxOrderIndex zero.Int
if err := querySimple(ctx, q, &maxOrderIndex); err != nil {
return 0, fmt.Errorf("getting max order index: %w", err)
}
return int(maxOrderIndex.Int64), nil
}
// createRelationships creates relationships between a group and other groups.
// If idIsContaining is true, the provided id is the containing group.
func (s *groupRelationshipStore) createRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions, idIsContaining bool) error {
if d.Loaded() {
for i, v := range d.List() {
orderIndex := i + 1
r := groupRelationshipRow{
ContainingID: id,
SubID: v.GroupID,
OrderIndex: orderIndex,
Description: zero.StringFrom(v.Description),
}
if !idIsContaining {
// get the max order index of the containing groups sub groups
containingID := v.GroupID
maxOrderIndex, err := s.getMaxOrderIndex(ctx, containingID)
if err != nil {
return err
}
r.ContainingID = v.GroupID
r.SubID = id
r.OrderIndex = maxOrderIndex + 1
}
_, err := s.table.insert(ctx, r)
if err != nil {
return fmt.Errorf("inserting into %s: %w", s.table.table.GetTable(), err)
}
}
return nil
}
return nil
}
// createRelationships creates relationships between a group and other groups.
// If idIsContaining is true, the provided id is the containing group.
func (s *groupRelationshipStore) createContainingRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions) error {
const idIsContaining = false
return s.createRelationships(ctx, id, d, idIsContaining)
}
// createRelationships creates relationships between a group and other groups.
// If idIsContaining is true, the provided id is the containing group.
func (s *groupRelationshipStore) createSubRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions) error {
const idIsContaining = true
return s.createRelationships(ctx, id, d, idIsContaining)
}
func (s *groupRelationshipStore) replaceRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions, idIsContaining bool) error {
// always destroy the existing relationships even if the new list is empty
if err := s.destroyAllJoins(ctx, id, idIsContaining); err != nil {
return err
}
return s.createRelationships(ctx, id, d, idIsContaining)
}
func (s *groupRelationshipStore) replaceContainingRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions) error {
const idIsContaining = false
return s.replaceRelationships(ctx, id, d, idIsContaining)
}
func (s *groupRelationshipStore) replaceSubRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions) error {
const idIsContaining = true
return s.replaceRelationships(ctx, id, d, idIsContaining)
}
func (s *groupRelationshipStore) modifyRelationships(ctx context.Context, id int, v *models.UpdateGroupDescriptions, idIsContaining bool) error {
if v == nil {
return nil
}
switch v.Mode {
case models.RelationshipUpdateModeSet:
return s.replaceJoins(ctx, id, *v, idIsContaining)
case models.RelationshipUpdateModeAdd:
return s.addJoins(ctx, id, v.Groups, idIsContaining)
case models.RelationshipUpdateModeRemove:
toRemove := make([]int, len(v.Groups))
for i, vv := range v.Groups {
toRemove[i] = vv.GroupID
}
return s.destroyJoins(ctx, id, toRemove, idIsContaining)
}
return nil
}
func (s *groupRelationshipStore) modifyContainingRelationships(ctx context.Context, id int, v *models.UpdateGroupDescriptions) error {
const idIsContaining = false
return s.modifyRelationships(ctx, id, v, idIsContaining)
}
func (s *groupRelationshipStore) modifySubRelationships(ctx context.Context, id int, v *models.UpdateGroupDescriptions) error {
const idIsContaining = true
return s.modifyRelationships(ctx, id, v, idIsContaining)
}
func (s *groupRelationshipStore) addJoins(ctx context.Context, id int, groups []models.GroupIDDescription, idIsContaining bool) error {
// if we're adding to a containing group, get the max order index first
var maxOrderIndex int
if idIsContaining {
var err error
maxOrderIndex, err = s.getMaxOrderIndex(ctx, id)
if err != nil {
return err
}
}
for i, vv := range groups {
r := groupRelationshipRow{
Description: zero.StringFrom(vv.Description),
}
if idIsContaining {
r.ContainingID = id
r.SubID = vv.GroupID
r.OrderIndex = maxOrderIndex + (i + 1)
} else {
// get the max order index of the containing groups sub groups
containingMaxOrderIndex, err := s.getMaxOrderIndex(ctx, vv.GroupID)
if err != nil {
return err
}
r.ContainingID = vv.GroupID
r.SubID = id
r.OrderIndex = containingMaxOrderIndex + 1
}
_, err := s.table.insert(ctx, r)
if err != nil {
return fmt.Errorf("inserting into %s: %w", s.table.table.GetTable(), err)
}
}
return nil
}
func (s *groupRelationshipStore) destroyAllJoins(ctx context.Context, id int, idIsContaining bool) error {
table := s.table.table
idColumn := table.Col("containing_id")
if !idIsContaining {
idColumn = table.Col("sub_id")
}
q := dialect.Delete(table).Where(idColumn.Eq(id))
if _, err := exec(ctx, q); err != nil {
return fmt.Errorf("destroying %s: %w", table.GetTable(), err)
}
return nil
}
func (s *groupRelationshipStore) replaceJoins(ctx context.Context, id int, v models.UpdateGroupDescriptions, idIsContaining bool) error {
if err := s.destroyAllJoins(ctx, id, idIsContaining); err != nil {
return err
}
// convert to RelatedGroupDescriptions
rgd := models.NewRelatedGroupDescriptions(v.Groups)
return s.createRelationships(ctx, id, rgd, idIsContaining)
}
func (s *groupRelationshipStore) destroyJoins(ctx context.Context, id int, toRemove []int, idIsContaining bool) error {
table := s.table.table
idColumn := table.Col("containing_id")
fkColumn := table.Col("sub_id")
if !idIsContaining {
idColumn = table.Col("sub_id")
fkColumn = table.Col("containing_id")
}
q := dialect.Delete(table).Where(idColumn.Eq(id), fkColumn.In(toRemove))
if _, err := exec(ctx, q); err != nil {
return fmt.Errorf("destroying %s: %w", table.GetTable(), err)
}
return nil
}
func (s *groupRelationshipStore) getOrderIndexOfSubGroup(ctx context.Context, containingGroupID int, subGroupID int) (int, error) {
table := s.table.table
q := dialect.Select("order_index").
From(table).
Where(
table.Col("containing_id").Eq(containingGroupID),
table.Col("sub_id").Eq(subGroupID),
)
var orderIndex null.Int
if err := querySimple(ctx, q, &orderIndex); err != nil {
return 0, fmt.Errorf("getting order index: %w", err)
}
if !orderIndex.Valid {
return 0, fmt.Errorf("sub-group %d not found in containing group %d", subGroupID, containingGroupID)
}
return int(orderIndex.Int64), nil
}
func (s *groupRelationshipStore) getGroupIDAtOrderIndex(ctx context.Context, containingGroupID int, orderIndex int) (*int, error) {
table := s.table.table
q := dialect.Select(table.Col("sub_id")).From(table).Where(
table.Col("containing_id").Eq(containingGroupID),
table.Col("order_index").Eq(orderIndex),
)
var ret null.Int
if err := querySimple(ctx, q, &ret); err != nil {
return nil, fmt.Errorf("getting sub id for order index: %w", err)
}
if !ret.Valid {
return nil, nil
}
intRet := int(ret.Int64)
return &intRet, nil
}
func (s *groupRelationshipStore) getOrderIndexAfterOrderIndex(ctx context.Context, containingGroupID int, orderIndex int) (int, error) {
table := s.table.table
q := dialect.Select(goqu.MIN("order_index")).From(table).Where(
table.Col("containing_id").Eq(containingGroupID),
table.Col("order_index").Gt(orderIndex),
)
var ret null.Int
if err := querySimple(ctx, q, &ret); err != nil {
return 0, fmt.Errorf("getting order index: %w", err)
}
if !ret.Valid {
return orderIndex + 1, nil
}
return int(ret.Int64), nil
}
// incrementOrderIndexes increments the order_index value of all sub-groups in the containing group at or after the given index
func (s *groupRelationshipStore) incrementOrderIndexes(ctx context.Context, groupID int, indexBefore int) error {
table := s.table.table
// WORKAROUND - sqlite won't allow incrementing the value directly since it causes a
// unique constraint violation.
// Instead, we first set the order index to a negative value temporarily
// see https://stackoverflow.com/a/7703239/695786
q := dialect.Update(table).Set(exp.Record{
"order_index": goqu.L("-order_index"),
}).Where(
table.Col("containing_id").Eq(groupID),
table.Col("order_index").Gte(indexBefore),
)
if _, err := exec(ctx, q); err != nil {
return fmt.Errorf("updating %s: %w", table.GetTable(), err)
}
q = dialect.Update(table).Set(exp.Record{
"order_index": goqu.L("1-order_index"),
}).Where(
table.Col("containing_id").Eq(groupID),
table.Col("order_index").Lt(0),
)
if _, err := exec(ctx, q); err != nil {
return fmt.Errorf("updating %s: %w", table.GetTable(), err)
}
return nil
}
func (s *groupRelationshipStore) reorderSubGroup(ctx context.Context, groupID int, subGroupID int, insertPointID int, insertAfter bool) error {
insertPointIndex, err := s.getOrderIndexOfSubGroup(ctx, groupID, insertPointID)
if err != nil {
return err
}
// if we're setting before
if insertAfter {
insertPointIndex, err = s.getOrderIndexAfterOrderIndex(ctx, groupID, insertPointIndex)
if err != nil {
return err
}
}
// increment the order index of all sub-groups after and including the insertion point
if err := s.incrementOrderIndexes(ctx, groupID, int(insertPointIndex)); err != nil {
return err
}
// set the order index of the sub-group to the insertion point
table := s.table.table
q := dialect.Update(table).Set(exp.Record{
"order_index": insertPointIndex,
}).Where(
table.Col("containing_id").Eq(groupID),
table.Col("sub_id").Eq(subGroupID),
)
if _, err := exec(ctx, q); err != nil {
return fmt.Errorf("updating %s: %w", table.GetTable(), err)
}
return nil
}
func (s *groupRelationshipStore) AddSubGroups(ctx context.Context, groupID int, subGroups []models.GroupIDDescription, insertIndex *int) error {
const idIsContaining = true
if err := s.addJoins(ctx, groupID, subGroups, idIsContaining); err != nil {
return err
}
ids := make([]int, len(subGroups))
for i, v := range subGroups {
ids[i] = v.GroupID
}
if insertIndex != nil {
// get the id of the sub-group at the insert index
insertPointID, err := s.getGroupIDAtOrderIndex(ctx, groupID, *insertIndex)
if err != nil {
return err
}
if insertPointID == nil {
// if the insert index is out of bounds, just assume adding to the end
return nil
}
// reorder the sub-groups
const insertAfter = false
if err := s.ReorderSubGroups(ctx, groupID, ids, *insertPointID, insertAfter); err != nil {
return err
}
}
return nil
}
func (s *groupRelationshipStore) RemoveSubGroups(ctx context.Context, groupID int, subGroupIDs []int) error {
const idIsContaining = true
return s.destroyJoins(ctx, groupID, subGroupIDs, idIsContaining)
}
func (s *groupRelationshipStore) ReorderSubGroups(ctx context.Context, groupID int, subGroupIDs []int, insertPointID int, insertAfter bool) error {
for _, id := range subGroupIDs {
if err := s.reorderSubGroup(ctx, groupID, id, insertPointID, insertAfter); err != nil {
return err
}
}
return nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,13 @@
CREATE TABLE `groups_relations` (
`containing_id` integer not null,
`sub_id` integer not null,
`order_index` integer not null,
`description` varchar(255),
primary key (`containing_id`, `sub_id`),
foreign key (`containing_id`) references `groups`(`id`) on delete cascade,
foreign key (`sub_id`) references `groups`(`id`) on delete cascade,
check (`containing_id` != `sub_id`)
);
CREATE INDEX `index_groups_relations_sub_id` ON `groups_relations` (`sub_id`);
CREATE UNIQUE INDEX `index_groups_relations_order_index_unique` ON `groups_relations` (`containing_id`, `order_index`);

View File

@@ -110,6 +110,16 @@ func (qb *queryBuilder) addArg(args ...interface{}) {
qb.args = append(qb.args, args...)
}
func (qb *queryBuilder) hasJoin(alias string) bool {
for _, j := range qb.joins {
if j.alias() == alias {
return true
}
}
return false
}
func (qb *queryBuilder) join(table, as, onClause string) {
newJoin := join{
table: table,

View File

@@ -791,13 +791,6 @@ func (qb *SceneStore) FindByGroupID(ctx context.Context, groupID int) ([]*models
return ret, nil
}
func (qb *SceneStore) CountByGroupID(ctx context.Context, groupID int) (int, error) {
joinTable := scenesGroupsJoinTable
q := dialect.Select(goqu.COUNT("*")).From(joinTable).Where(joinTable.Col(groupIDColumn).Eq(groupID))
return count(ctx, q)
}
func (qb *SceneStore) Count(ctx context.Context) (int, error) {
q := dialect.Select(goqu.COUNT("*")).From(qb.table())
return count(ctx, q)
@@ -858,6 +851,7 @@ func (qb *SceneStore) PlayDuration(ctx context.Context) (float64, error) {
return ret, nil
}
// TODO - currently only used by unit test
func (qb *SceneStore) CountByStudioID(ctx context.Context, studioID int) (int, error) {
table := qb.table()
@@ -865,13 +859,6 @@ func (qb *SceneStore) CountByStudioID(ctx context.Context, studioID int) (int, e
return count(ctx, q)
}
func (qb *SceneStore) CountByTagID(ctx context.Context, tagID int) (int, error) {
joinTable := scenesTagsJoinTable
q := dialect.Select(goqu.COUNT("*")).From(joinTable).Where(joinTable.Col(tagIDColumn).Eq(tagID))
return count(ctx, q)
}
func (qb *SceneStore) countMissingFingerprints(ctx context.Context, fpType string) (int, error) {
fpTable := fingerprintTableMgr.table.As("fingerprints_temp")

View File

@@ -149,7 +149,7 @@ func (qb *sceneFilterHandler) criterionHandler() criterionHandler {
studioCriterionHandler(sceneTable, sceneFilter.Studios),
qb.groupsCriterionHandler(sceneFilter.Groups),
qb.groupsCriterionHandler(sceneFilter.Movies),
qb.moviesCriterionHandler(sceneFilter.Movies),
qb.galleriesCriterionHandler(sceneFilter.Galleries),
qb.performerTagsCriterionHandler(sceneFilter.PerformerTags),
@@ -483,7 +483,8 @@ func (qb *sceneFilterHandler) performerAgeCriterionHandler(performerAge *models.
}
}
func (qb *sceneFilterHandler) groupsCriterionHandler(movies *models.MultiCriterionInput) criterionHandlerFunc {
// legacy handler
func (qb *sceneFilterHandler) moviesCriterionHandler(movies *models.MultiCriterionInput) criterionHandlerFunc {
addJoinsFunc := func(f *filterBuilder) {
sceneRepository.groups.join(f, "", "scenes.id")
f.addLeftJoin("groups", "", "groups_scenes.group_id = groups.id")
@@ -492,6 +493,23 @@ func (qb *sceneFilterHandler) groupsCriterionHandler(movies *models.MultiCriteri
return h.handler(movies)
}
func (qb *sceneFilterHandler) groupsCriterionHandler(groups *models.HierarchicalMultiCriterionInput) criterionHandlerFunc {
h := joinedHierarchicalMultiCriterionHandlerBuilder{
primaryTable: sceneTable,
foreignTable: groupTable,
foreignFK: "group_id",
relationsTable: groupRelationsTable,
parentFK: "containing_id",
childFK: "sub_id",
joinAs: "scene_group",
joinTable: groupsScenesTable,
primaryFK: sceneIDColumn,
}
return h.handler(groups)
}
func (qb *sceneFilterHandler) galleriesCriterionHandler(galleries *models.MultiCriterionInput) criterionHandlerFunc {
addJoinsFunc := func(f *filterBuilder) {
sceneRepository.galleries.join(f, "", "scenes.id")

View File

@@ -16,6 +16,7 @@ import (
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/sliceutil"
"github.com/stashapp/stash/pkg/sliceutil/intslice"
"github.com/stretchr/testify/assert"
)
@@ -2217,7 +2218,7 @@ func TestSceneQuery(t *testing.T) {
},
})
if (err != nil) != tt.wantErr {
t.Errorf("PerformerStore.Query() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("SceneStore.Query() error = %v, wantErr %v", err, tt.wantErr)
return
}
@@ -3873,6 +3874,100 @@ func TestSceneQueryStudioDepth(t *testing.T) {
})
}
func TestSceneGroups(t *testing.T) {
type criterion struct {
valueIdxs []int
modifier models.CriterionModifier
depth int
}
tests := []struct {
name string
c criterion
q string
includeIdxs []int
excludeIdxs []int
}{
{
"includes",
criterion{
[]int{groupIdxWithScene},
models.CriterionModifierIncludes,
0,
},
"",
[]int{sceneIdxWithGroup},
nil,
},
{
"excludes",
criterion{
[]int{groupIdxWithScene},
models.CriterionModifierExcludes,
0,
},
getSceneStringValue(sceneIdxWithGroup, titleField),
nil,
[]int{sceneIdxWithGroup},
},
{
"includes (depth = 1)",
criterion{
[]int{groupIdxWithChildWithScene},
models.CriterionModifierIncludes,
1,
},
"",
[]int{sceneIdxWithGroupWithParent},
nil,
},
}
for _, tt := range tests {
valueIDs := indexesToIDs(groupIDs, tt.c.valueIdxs)
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
assert := assert.New(t)
sceneFilter := &models.SceneFilterType{
Groups: &models.HierarchicalMultiCriterionInput{
Value: intslice.IntSliceToStringSlice(valueIDs),
Modifier: tt.c.modifier,
},
}
if tt.c.depth != 0 {
sceneFilter.Groups.Depth = &tt.c.depth
}
findFilter := &models.FindFilterType{}
if tt.q != "" {
findFilter.Q = &tt.q
}
results, err := db.Scene.Query(ctx, models.SceneQueryOptions{
SceneFilter: sceneFilter,
QueryOptions: models.QueryOptions{
FindFilter: findFilter,
},
})
if err != nil {
t.Errorf("SceneStore.Query() error = %v", err)
return
}
include := indexesToIDs(sceneIDs, tt.includeIdxs)
exclude := indexesToIDs(sceneIDs, tt.excludeIdxs)
assert.Subset(results.IDs, include)
for _, e := range exclude {
assert.NotContains(results.IDs, e)
}
})
}
}
func TestSceneQueryMovies(t *testing.T) {
withTxn(func(ctx context.Context) error {
sqb := db.Scene
@@ -4188,78 +4283,6 @@ func verifyScenesPerformerCount(t *testing.T, performerCountCriterion models.Int
})
}
func TestSceneCountByTagID(t *testing.T) {
withTxn(func(ctx context.Context) error {
sqb := db.Scene
sceneCount, err := sqb.CountByTagID(ctx, tagIDs[tagIdxWithScene])
if err != nil {
t.Errorf("error calling CountByTagID: %s", err.Error())
}
assert.Equal(t, 1, sceneCount)
sceneCount, err = sqb.CountByTagID(ctx, 0)
if err != nil {
t.Errorf("error calling CountByTagID: %s", err.Error())
}
assert.Equal(t, 0, sceneCount)
return nil
})
}
func TestSceneCountByGroupID(t *testing.T) {
withTxn(func(ctx context.Context) error {
sqb := db.Scene
sceneCount, err := sqb.CountByGroupID(ctx, groupIDs[groupIdxWithScene])
if err != nil {
t.Errorf("error calling CountByGroupID: %s", err.Error())
}
assert.Equal(t, 1, sceneCount)
sceneCount, err = sqb.CountByGroupID(ctx, 0)
if err != nil {
t.Errorf("error calling CountByGroupID: %s", err.Error())
}
assert.Equal(t, 0, sceneCount)
return nil
})
}
func TestSceneCountByStudioID(t *testing.T) {
withTxn(func(ctx context.Context) error {
sqb := db.Scene
sceneCount, err := sqb.CountByStudioID(ctx, studioIDs[studioIdxWithScene])
if err != nil {
t.Errorf("error calling CountByStudioID: %s", err.Error())
}
assert.Equal(t, 1, sceneCount)
sceneCount, err = sqb.CountByStudioID(ctx, 0)
if err != nil {
t.Errorf("error calling CountByStudioID: %s", err.Error())
}
assert.Equal(t, 0, sceneCount)
return nil
})
}
func TestFindByMovieID(t *testing.T) {
withTxn(func(ctx context.Context) error {
sqb := db.Scene

View File

@@ -78,6 +78,7 @@ const (
sceneIdxWithGrandChildStudio
sceneIdxMissingPhash
sceneIdxWithPerformerParentTag
sceneIdxWithGroupWithParent
// new indexes above
lastSceneIdx
@@ -153,9 +154,15 @@ const (
groupIdxWithTag
groupIdxWithTwoTags
groupIdxWithThreeTags
groupIdxWithGrandChild
groupIdxWithChild
groupIdxWithParentAndChild
groupIdxWithParent
groupIdxWithGrandParent
groupIdxWithParentAndScene
groupIdxWithChildWithScene
// groups with dup names start from the end
// create 7 more basic groups (can remove this if we add more indexes)
groupIdxWithDupName = groupIdxWithStudio + 7
groupIdxWithDupName
groupsNameCase = groupIdxWithDupName
groupsNameNoCase = 1
@@ -390,7 +397,8 @@ var (
}
sceneGroups = linkMap{
sceneIdxWithGroup: {groupIdxWithScene},
sceneIdxWithGroup: {groupIdxWithScene},
sceneIdxWithGroupWithParent: {groupIdxWithParentAndScene},
}
sceneStudios = map[int]int{
@@ -541,15 +549,31 @@ var (
}
)
var (
groupParentLinks = [][2]int{
{groupIdxWithChild, groupIdxWithParent},
{groupIdxWithGrandChild, groupIdxWithParentAndChild},
{groupIdxWithParentAndChild, groupIdxWithGrandParent},
{groupIdxWithChildWithScene, groupIdxWithParentAndScene},
}
)
func indexesToIDs(ids []int, indexes []int) []int {
ret := make([]int, len(indexes))
for i, idx := range indexes {
ret[i] = ids[idx]
ret[i] = indexToID(ids, idx)
}
return ret
}
func indexToID(ids []int, idx int) int {
if idx < 0 {
return invalidID
}
return ids[idx]
}
func indexFromID(ids []int, id int) int {
for i, v := range ids {
if v == id {
@@ -697,6 +721,10 @@ func populateDB() error {
return fmt.Errorf("error linking tags parent: %s", err.Error())
}
if err := linkGroupsParent(ctx, db.Group); err != nil {
return fmt.Errorf("error linking tags parent: %s", err.Error())
}
for _, ms := range markerSpecs {
if err := createMarker(ctx, db.SceneMarker, ms); err != nil {
return fmt.Errorf("error creating scene marker: %s", err.Error())
@@ -1885,6 +1913,24 @@ func linkTagsParent(ctx context.Context, qb models.TagReaderWriter) error {
})
}
func linkGroupsParent(ctx context.Context, qb models.GroupReaderWriter) error {
return doLinks(groupParentLinks, func(parentIndex, childIndex int) error {
groupID := groupIDs[childIndex]
p := models.GroupPartial{
ContainingGroups: &models.UpdateGroupDescriptions{
Groups: []models.GroupIDDescription{
{GroupID: groupIDs[parentIndex]},
},
Mode: models.RelationshipUpdateModeAdd,
},
}
_, err := qb.UpdatePartial(ctx, groupID, p)
return err
})
}
func addTagImage(ctx context.Context, qb models.TagWriter, tagIndex int) error {
return qb.UpdateImage(ctx, tagIDs[tagIndex], []byte("image"))
}

View File

@@ -37,8 +37,9 @@ var (
studiosTagsJoinTable = goqu.T(studiosTagsTable)
studiosStashIDsJoinTable = goqu.T("studio_stash_ids")
groupsURLsJoinTable = goqu.T(groupURLsTable)
groupsTagsJoinTable = goqu.T(groupsTagsTable)
groupsURLsJoinTable = goqu.T(groupURLsTable)
groupsTagsJoinTable = goqu.T(groupsTagsTable)
groupRelationsJoinTable = goqu.T(groupRelationsTable)
tagsAliasesJoinTable = goqu.T(tagAliasesTable)
tagRelationsJoinTable = goqu.T(tagRelationsTable)
@@ -361,6 +362,10 @@ var (
foreignTable: tagTableMgr,
orderBy: tagTableMgr.table.Col("name").Asc(),
}
groupRelationshipTableMgr = &table{
table: groupRelationsJoinTable,
}
)
var (

View File

@@ -2,7 +2,6 @@ package sqlite
import (
"context"
"fmt"
"github.com/stashapp/stash/pkg/models"
)
@@ -51,6 +50,14 @@ func (qb *tagFilterHandler) handle(ctx context.Context, f *filterBuilder) {
f.handleCriterion(ctx, qb.criterionHandler())
}
var tagHierarchyHandler = hierarchicalRelationshipHandler{
primaryTable: tagTable,
relationTable: tagRelationsTable,
aliasPrefix: tagTable,
parentIDCol: "parent_id",
childIDCol: "child_id",
}
func (qb *tagFilterHandler) criterionHandler() criterionHandler {
tagFilter := qb.tagFilter
return compoundHandler{
@@ -72,10 +79,10 @@ func (qb *tagFilterHandler) criterionHandler() criterionHandler {
qb.groupCountCriterionHandler(tagFilter.MovieCount),
qb.markerCountCriterionHandler(tagFilter.MarkerCount),
qb.parentsCriterionHandler(tagFilter.Parents),
qb.childrenCriterionHandler(tagFilter.Children),
qb.parentCountCriterionHandler(tagFilter.ParentCount),
qb.childCountCriterionHandler(tagFilter.ChildCount),
tagHierarchyHandler.ParentsCriterionHandler(tagFilter.Parents),
tagHierarchyHandler.ChildrenCriterionHandler(tagFilter.Children),
tagHierarchyHandler.ParentCountCriterionHandler(tagFilter.ParentCount),
tagHierarchyHandler.ChildCountCriterionHandler(tagFilter.ChildCount),
&timestampCriterionHandler{tagFilter.CreatedAt, "tags.created_at", nil},
&timestampCriterionHandler{tagFilter.UpdatedAt, "tags.updated_at", nil},
@@ -212,213 +219,3 @@ func (qb *tagFilterHandler) markerCountCriterionHandler(markerCount *models.IntC
}
}
}
func (qb *tagFilterHandler) parentsCriterionHandler(criterion *models.HierarchicalMultiCriterionInput) criterionHandlerFunc {
return func(ctx context.Context, f *filterBuilder) {
if criterion != nil {
tags := criterion.CombineExcludes()
// validate the modifier
switch tags.Modifier {
case models.CriterionModifierIncludesAll, models.CriterionModifierIncludes, models.CriterionModifierExcludes, models.CriterionModifierIsNull, models.CriterionModifierNotNull:
// valid
default:
f.setError(fmt.Errorf("invalid modifier %s for tag parent/children", criterion.Modifier))
}
if tags.Modifier == models.CriterionModifierIsNull || tags.Modifier == models.CriterionModifierNotNull {
var notClause string
if tags.Modifier == models.CriterionModifierNotNull {
notClause = "NOT"
}
f.addLeftJoin("tags_relations", "parent_relations", "tags.id = parent_relations.child_id")
f.addWhere(fmt.Sprintf("parent_relations.parent_id IS %s NULL", notClause))
return
}
if len(tags.Value) == 0 && len(tags.Excludes) == 0 {
return
}
if len(tags.Value) > 0 {
var args []interface{}
for _, val := range tags.Value {
args = append(args, val)
}
depthVal := 0
if tags.Depth != nil {
depthVal = *tags.Depth
}
var depthCondition string
if depthVal != -1 {
depthCondition = fmt.Sprintf("WHERE depth < %d", depthVal)
}
query := `parents AS (
SELECT parent_id AS root_id, child_id AS item_id, 0 AS depth FROM tags_relations WHERE parent_id IN` + getInBinding(len(tags.Value)) + `
UNION
SELECT root_id, child_id, depth + 1 FROM tags_relations INNER JOIN parents ON item_id = parent_id ` + depthCondition + `
)`
f.addRecursiveWith(query, args...)
f.addLeftJoin("parents", "", "parents.item_id = tags.id")
addHierarchicalConditionClauses(f, tags, "parents", "root_id")
}
if len(tags.Excludes) > 0 {
var args []interface{}
for _, val := range tags.Excludes {
args = append(args, val)
}
depthVal := 0
if tags.Depth != nil {
depthVal = *tags.Depth
}
var depthCondition string
if depthVal != -1 {
depthCondition = fmt.Sprintf("WHERE depth < %d", depthVal)
}
query := `parents2 AS (
SELECT parent_id AS root_id, child_id AS item_id, 0 AS depth FROM tags_relations WHERE parent_id IN` + getInBinding(len(tags.Excludes)) + `
UNION
SELECT root_id, child_id, depth + 1 FROM tags_relations INNER JOIN parents2 ON item_id = parent_id ` + depthCondition + `
)`
f.addRecursiveWith(query, args...)
f.addLeftJoin("parents2", "", "parents2.item_id = tags.id")
addHierarchicalConditionClauses(f, models.HierarchicalMultiCriterionInput{
Value: tags.Excludes,
Depth: tags.Depth,
Modifier: models.CriterionModifierExcludes,
}, "parents2", "root_id")
}
}
}
}
func (qb *tagFilterHandler) childrenCriterionHandler(criterion *models.HierarchicalMultiCriterionInput) criterionHandlerFunc {
return func(ctx context.Context, f *filterBuilder) {
if criterion != nil {
tags := criterion.CombineExcludes()
// validate the modifier
switch tags.Modifier {
case models.CriterionModifierIncludesAll, models.CriterionModifierIncludes, models.CriterionModifierExcludes, models.CriterionModifierIsNull, models.CriterionModifierNotNull:
// valid
default:
f.setError(fmt.Errorf("invalid modifier %s for tag parent/children", criterion.Modifier))
}
if tags.Modifier == models.CriterionModifierIsNull || tags.Modifier == models.CriterionModifierNotNull {
var notClause string
if tags.Modifier == models.CriterionModifierNotNull {
notClause = "NOT"
}
f.addLeftJoin("tags_relations", "child_relations", "tags.id = child_relations.parent_id")
f.addWhere(fmt.Sprintf("child_relations.child_id IS %s NULL", notClause))
return
}
if len(tags.Value) == 0 && len(tags.Excludes) == 0 {
return
}
if len(tags.Value) > 0 {
var args []interface{}
for _, val := range tags.Value {
args = append(args, val)
}
depthVal := 0
if tags.Depth != nil {
depthVal = *tags.Depth
}
var depthCondition string
if depthVal != -1 {
depthCondition = fmt.Sprintf("WHERE depth < %d", depthVal)
}
query := `children AS (
SELECT child_id AS root_id, parent_id AS item_id, 0 AS depth FROM tags_relations WHERE child_id IN` + getInBinding(len(tags.Value)) + `
UNION
SELECT root_id, parent_id, depth + 1 FROM tags_relations INNER JOIN children ON item_id = child_id ` + depthCondition + `
)`
f.addRecursiveWith(query, args...)
f.addLeftJoin("children", "", "children.item_id = tags.id")
addHierarchicalConditionClauses(f, tags, "children", "root_id")
}
if len(tags.Excludes) > 0 {
var args []interface{}
for _, val := range tags.Excludes {
args = append(args, val)
}
depthVal := 0
if tags.Depth != nil {
depthVal = *tags.Depth
}
var depthCondition string
if depthVal != -1 {
depthCondition = fmt.Sprintf("WHERE depth < %d", depthVal)
}
query := `children2 AS (
SELECT child_id AS root_id, parent_id AS item_id, 0 AS depth FROM tags_relations WHERE child_id IN` + getInBinding(len(tags.Excludes)) + `
UNION
SELECT root_id, parent_id, depth + 1 FROM tags_relations INNER JOIN children2 ON item_id = child_id ` + depthCondition + `
)`
f.addRecursiveWith(query, args...)
f.addLeftJoin("children2", "", "children2.item_id = tags.id")
addHierarchicalConditionClauses(f, models.HierarchicalMultiCriterionInput{
Value: tags.Excludes,
Depth: tags.Depth,
Modifier: models.CriterionModifierExcludes,
}, "children2", "root_id")
}
}
}
}
func (qb *tagFilterHandler) parentCountCriterionHandler(parentCount *models.IntCriterionInput) criterionHandlerFunc {
return func(ctx context.Context, f *filterBuilder) {
if parentCount != nil {
f.addLeftJoin("tags_relations", "parents_count", "parents_count.child_id = tags.id")
clause, args := getIntCriterionWhereClause("count(distinct parents_count.parent_id)", *parentCount)
f.addHaving(clause, args...)
}
}
}
func (qb *tagFilterHandler) childCountCriterionHandler(childCount *models.IntCriterionInput) criterionHandlerFunc {
return func(ctx context.Context, f *filterBuilder) {
if childCount != nil {
f.addLeftJoin("tags_relations", "children_count", "children_count.parent_id = tags.id")
clause, args := getIntCriterionWhereClause("count(distinct children_count.child_id)", *childCount)
f.addHaving(clause, args...)
}
}
}

View File

@@ -712,7 +712,7 @@ func TestTagQueryParent(t *testing.T) {
assert.Len(t, tags, 1)
// ensure id is correct
assert.Equal(t, sceneIDs[tagIdxWithParentTag], tags[0].ID)
assert.Equal(t, tagIDs[tagIdxWithParentTag], tags[0].ID)
tagCriterion.Modifier = models.CriterionModifierExcludes