mirror of
https://github.com/stashapp/stash.git
synced 2025-12-16 20:07:05 +03:00
Containing Group/Sub-Group relationships (#5105)
* Add UI support for setting containing groups * Show containing groups in group details panel * Move tag hierarchical filter code into separate type * Add depth to scene_count and add sub_group_count * Add sub-groups tab to groups page * Add containing groups to edit groups dialog * Show containing group description in sub-group view * Show group scene number in group scenes view * Add ability to drag move grid cards * Add sub group order option * Add reorder sub-groups interface * Separate page size selector component * Add interfaces to add and remove sub-groups to a group * Separate MultiSet components * Allow setting description while setting containing groups
This commit is contained in:
41
pkg/group/create.go
Normal file
41
pkg/group/create.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package group
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrEmptyName = errors.New("name cannot be empty")
|
||||
ErrHierarchyLoop = errors.New("a group cannot be contained by one of its subgroups")
|
||||
)
|
||||
|
||||
func (s *Service) Create(ctx context.Context, group *models.Group, frontimageData []byte, backimageData []byte) error {
|
||||
r := s.Repository
|
||||
|
||||
if err := s.validateCreate(ctx, group); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err := r.Create(ctx, group)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// update image table
|
||||
if len(frontimageData) > 0 {
|
||||
if err := r.UpdateFrontImage(ctx, group.ID, frontimageData); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(backimageData) > 0 {
|
||||
if err := r.UpdateBackImage(ctx, group.ID, backimageData); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -16,6 +16,18 @@ type ImporterReaderWriter interface {
|
||||
FindByName(ctx context.Context, name string, nocase bool) (*models.Group, error)
|
||||
}
|
||||
|
||||
type SubGroupNotExistError struct {
|
||||
missingSubGroup string
|
||||
}
|
||||
|
||||
func (e SubGroupNotExistError) Error() string {
|
||||
return fmt.Sprintf("sub group <%s> does not exist", e.missingSubGroup)
|
||||
}
|
||||
|
||||
func (e SubGroupNotExistError) MissingSubGroup() string {
|
||||
return e.missingSubGroup
|
||||
}
|
||||
|
||||
type Importer struct {
|
||||
ReaderWriter ImporterReaderWriter
|
||||
StudioWriter models.StudioFinderCreator
|
||||
@@ -202,6 +214,22 @@ func (i *Importer) createStudio(ctx context.Context, name string) (int, error) {
|
||||
}
|
||||
|
||||
func (i *Importer) PostImport(ctx context.Context, id int) error {
|
||||
subGroups, err := i.getSubGroups(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(subGroups) > 0 {
|
||||
if _, err := i.ReaderWriter.UpdatePartial(ctx, id, models.GroupPartial{
|
||||
SubGroups: &models.UpdateGroupDescriptions{
|
||||
Groups: subGroups,
|
||||
Mode: models.RelationshipUpdateModeSet,
|
||||
},
|
||||
}); err != nil {
|
||||
return fmt.Errorf("error setting parents: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(i.frontImageData) > 0 {
|
||||
if err := i.ReaderWriter.UpdateFrontImage(ctx, id, i.frontImageData); err != nil {
|
||||
return fmt.Errorf("error setting group front image: %v", err)
|
||||
@@ -256,3 +284,53 @@ func (i *Importer) Update(ctx context.Context, id int) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *Importer) getSubGroups(ctx context.Context) ([]models.GroupIDDescription, error) {
|
||||
var subGroups []models.GroupIDDescription
|
||||
for _, subGroup := range i.Input.SubGroups {
|
||||
group, err := i.ReaderWriter.FindByName(ctx, subGroup.Group, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error finding parent by name: %v", err)
|
||||
}
|
||||
|
||||
if group == nil {
|
||||
if i.MissingRefBehaviour == models.ImportMissingRefEnumFail {
|
||||
return nil, SubGroupNotExistError{missingSubGroup: subGroup.Group}
|
||||
}
|
||||
|
||||
if i.MissingRefBehaviour == models.ImportMissingRefEnumIgnore {
|
||||
continue
|
||||
}
|
||||
|
||||
if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate {
|
||||
parentID, err := i.createSubGroup(ctx, subGroup.Group)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
subGroups = append(subGroups, models.GroupIDDescription{
|
||||
GroupID: parentID,
|
||||
Description: subGroup.Description,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
subGroups = append(subGroups, models.GroupIDDescription{
|
||||
GroupID: group.ID,
|
||||
Description: subGroup.Description,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return subGroups, nil
|
||||
}
|
||||
|
||||
func (i *Importer) createSubGroup(ctx context.Context, name string) (int, error) {
|
||||
newGroup := models.NewGroup()
|
||||
newGroup.Name = name
|
||||
|
||||
err := i.ReaderWriter.Create(ctx, &newGroup)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return newGroup.ID, nil
|
||||
}
|
||||
|
||||
@@ -30,3 +30,15 @@ func CountByTagID(ctx context.Context, r models.GroupQueryer, id int, depth *int
|
||||
|
||||
return r.QueryCount(ctx, filter, nil)
|
||||
}
|
||||
|
||||
func CountByContainingGroupID(ctx context.Context, r models.GroupQueryer, id int, depth *int) (int, error) {
|
||||
filter := &models.GroupFilterType{
|
||||
ContainingGroups: &models.HierarchicalMultiCriterionInput{
|
||||
Value: []string{strconv.Itoa(id)},
|
||||
Modifier: models.CriterionModifierIncludes,
|
||||
Depth: depth,
|
||||
},
|
||||
}
|
||||
|
||||
return r.QueryCount(ctx, filter, nil)
|
||||
}
|
||||
|
||||
33
pkg/group/reorder.go
Normal file
33
pkg/group/reorder.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package group
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
)
|
||||
|
||||
var ErrInvalidInsertIndex = errors.New("invalid insert index")
|
||||
|
||||
func (s *Service) ReorderSubGroups(ctx context.Context, groupID int, subGroupIDs []int, insertPointID int, insertAfter bool) error {
|
||||
// get the group
|
||||
existing, err := s.Repository.Find(ctx, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// ensure it exists
|
||||
if existing == nil {
|
||||
return models.ErrNotFound
|
||||
}
|
||||
|
||||
// TODO - ensure the subgroups exist in the group
|
||||
|
||||
// ensure the insert index is valid
|
||||
if insertPointID < 0 {
|
||||
return ErrInvalidInsertIndex
|
||||
}
|
||||
|
||||
// reorder the subgroups
|
||||
return s.Repository.ReorderSubGroups(ctx, groupID, subGroupIDs, insertPointID, insertAfter)
|
||||
}
|
||||
46
pkg/group/service.go
Normal file
46
pkg/group/service.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package group
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
)
|
||||
|
||||
type CreatorUpdater interface {
|
||||
models.GroupGetter
|
||||
models.GroupCreator
|
||||
models.GroupUpdater
|
||||
|
||||
models.ContainingGroupLoader
|
||||
models.SubGroupLoader
|
||||
|
||||
AnscestorFinder
|
||||
SubGroupIDFinder
|
||||
SubGroupAdder
|
||||
SubGroupRemover
|
||||
SubGroupReorderer
|
||||
}
|
||||
|
||||
type AnscestorFinder interface {
|
||||
FindInAncestors(ctx context.Context, ascestorIDs []int, ids []int) ([]int, error)
|
||||
}
|
||||
|
||||
type SubGroupIDFinder interface {
|
||||
FindSubGroupIDs(ctx context.Context, containingID int, ids []int) ([]int, error)
|
||||
}
|
||||
|
||||
type SubGroupAdder interface {
|
||||
AddSubGroups(ctx context.Context, groupID int, subGroups []models.GroupIDDescription, insertIndex *int) error
|
||||
}
|
||||
|
||||
type SubGroupRemover interface {
|
||||
RemoveSubGroups(ctx context.Context, groupID int, subGroupIDs []int) error
|
||||
}
|
||||
|
||||
type SubGroupReorderer interface {
|
||||
ReorderSubGroups(ctx context.Context, groupID int, subGroupIDs []int, insertID int, insertAfter bool) error
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
Repository CreatorUpdater
|
||||
}
|
||||
112
pkg/group/update.go
Normal file
112
pkg/group/update.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package group
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/sliceutil"
|
||||
)
|
||||
|
||||
type SubGroupAlreadyInGroupError struct {
|
||||
GroupIDs []int
|
||||
}
|
||||
|
||||
func (e *SubGroupAlreadyInGroupError) Error() string {
|
||||
return fmt.Sprintf("subgroups with IDs %v already in group", e.GroupIDs)
|
||||
}
|
||||
|
||||
type ImageInput struct {
|
||||
Image []byte
|
||||
Set bool
|
||||
}
|
||||
|
||||
func (s *Service) UpdatePartial(ctx context.Context, id int, updatedGroup models.GroupPartial, frontImage ImageInput, backImage ImageInput) (*models.Group, error) {
|
||||
if err := s.validateUpdate(ctx, id, updatedGroup); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r := s.Repository
|
||||
|
||||
group, err := r.UpdatePartial(ctx, id, updatedGroup)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// update image table
|
||||
if frontImage.Set {
|
||||
if err := r.UpdateFrontImage(ctx, id, frontImage.Image); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if backImage.Set {
|
||||
if err := r.UpdateBackImage(ctx, id, backImage.Image); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return group, nil
|
||||
}
|
||||
|
||||
func (s *Service) AddSubGroups(ctx context.Context, groupID int, subGroups []models.GroupIDDescription, insertIndex *int) error {
|
||||
// get the group
|
||||
existing, err := s.Repository.Find(ctx, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// ensure it exists
|
||||
if existing == nil {
|
||||
return models.ErrNotFound
|
||||
}
|
||||
|
||||
// ensure the subgroups aren't already sub-groups of the group
|
||||
subGroupIDs := sliceutil.Map(subGroups, func(sg models.GroupIDDescription) int {
|
||||
return sg.GroupID
|
||||
})
|
||||
|
||||
existingSubGroupIDs, err := s.Repository.FindSubGroupIDs(ctx, groupID, subGroupIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(existingSubGroupIDs) > 0 {
|
||||
return &SubGroupAlreadyInGroupError{
|
||||
GroupIDs: existingSubGroupIDs,
|
||||
}
|
||||
}
|
||||
|
||||
// validate the hierarchy
|
||||
d := &models.UpdateGroupDescriptions{
|
||||
Groups: subGroups,
|
||||
Mode: models.RelationshipUpdateModeAdd,
|
||||
}
|
||||
if err := s.validateUpdateGroupHierarchy(ctx, existing, nil, d); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// validate insert index
|
||||
if insertIndex != nil && *insertIndex < 0 {
|
||||
return ErrInvalidInsertIndex
|
||||
}
|
||||
|
||||
// add the subgroups
|
||||
return s.Repository.AddSubGroups(ctx, groupID, subGroups, insertIndex)
|
||||
}
|
||||
|
||||
func (s *Service) RemoveSubGroups(ctx context.Context, groupID int, subGroupIDs []int) error {
|
||||
// get the group
|
||||
existing, err := s.Repository.Find(ctx, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// ensure it exists
|
||||
if existing == nil {
|
||||
return models.ErrNotFound
|
||||
}
|
||||
|
||||
// add the subgroups
|
||||
return s.Repository.RemoveSubGroups(ctx, groupID, subGroupIDs)
|
||||
}
|
||||
117
pkg/group/validate.go
Normal file
117
pkg/group/validate.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package group
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/sliceutil"
|
||||
)
|
||||
|
||||
func (s *Service) validateCreate(ctx context.Context, group *models.Group) error {
|
||||
if err := validateName(group.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
containingIDs := group.ContainingGroups.IDs()
|
||||
subIDs := group.SubGroups.IDs()
|
||||
|
||||
if err := s.validateGroupHierarchy(ctx, containingIDs, subIDs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) validateUpdate(ctx context.Context, id int, partial models.GroupPartial) error {
|
||||
// get the existing group - ensure it exists
|
||||
existing, err := s.Repository.Find(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if existing == nil {
|
||||
return models.ErrNotFound
|
||||
}
|
||||
|
||||
if partial.Name.Set {
|
||||
if err := validateName(partial.Name.Value); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.validateUpdateGroupHierarchy(ctx, existing, partial.ContainingGroups, partial.SubGroups); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateName(n string) error {
|
||||
// ensure name is not empty
|
||||
if strings.TrimSpace(n) == "" {
|
||||
return ErrEmptyName
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) validateGroupHierarchy(ctx context.Context, containingIDs []int, subIDs []int) error {
|
||||
// only need to validate if both are non-empty
|
||||
if len(containingIDs) == 0 || len(subIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensure none of the containing groups are in the sub groups
|
||||
found, err := s.Repository.FindInAncestors(ctx, containingIDs, subIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(found) > 0 {
|
||||
return ErrHierarchyLoop
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) validateUpdateGroupHierarchy(ctx context.Context, existing *models.Group, containingGroups *models.UpdateGroupDescriptions, subGroups *models.UpdateGroupDescriptions) error {
|
||||
// no need to validate if there are no changes
|
||||
if containingGroups == nil && subGroups == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := existing.LoadContainingGroupIDs(ctx, s.Repository); err != nil {
|
||||
return err
|
||||
}
|
||||
existingContainingGroups := existing.ContainingGroups.List()
|
||||
|
||||
if err := existing.LoadSubGroupIDs(ctx, s.Repository); err != nil {
|
||||
return err
|
||||
}
|
||||
existingSubGroups := existing.SubGroups.List()
|
||||
|
||||
effectiveContainingGroups := existingContainingGroups
|
||||
if containingGroups != nil {
|
||||
effectiveContainingGroups = containingGroups.Apply(existingContainingGroups)
|
||||
}
|
||||
|
||||
effectiveSubGroups := existingSubGroups
|
||||
if subGroups != nil {
|
||||
effectiveSubGroups = subGroups.Apply(existingSubGroups)
|
||||
}
|
||||
|
||||
containingIDs := idsFromGroupDescriptions(effectiveContainingGroups)
|
||||
subIDs := idsFromGroupDescriptions(effectiveSubGroups)
|
||||
|
||||
// ensure we haven't set the group as a subgroup of itself
|
||||
if sliceutil.Contains(containingIDs, existing.ID) || sliceutil.Contains(subIDs, existing.ID) {
|
||||
return ErrHierarchyLoop
|
||||
}
|
||||
|
||||
return s.validateGroupHierarchy(ctx, containingIDs, subIDs)
|
||||
}
|
||||
|
||||
func idsFromGroupDescriptions(v []models.GroupIDDescription) []int {
|
||||
return sliceutil.Map(v, func(g models.GroupIDDescription) int { return g.GroupID })
|
||||
}
|
||||
@@ -23,6 +23,14 @@ type GroupFilterType struct {
|
||||
TagCount *IntCriterionInput `json:"tag_count"`
|
||||
// Filter by date
|
||||
Date *DateCriterionInput `json:"date"`
|
||||
// Filter by containing groups
|
||||
ContainingGroups *HierarchicalMultiCriterionInput `json:"containing_groups"`
|
||||
// Filter by sub groups
|
||||
SubGroups *HierarchicalMultiCriterionInput `json:"sub_groups"`
|
||||
// Filter by number of containing groups the group has
|
||||
ContainingGroupCount *IntCriterionInput `json:"containing_group_count"`
|
||||
// Filter by number of sub-groups the group has
|
||||
SubGroupCount *IntCriterionInput `json:"sub_group_count"`
|
||||
// Filter by related scenes that meet this criteria
|
||||
ScenesFilter *SceneFilterType `json:"scenes_filter"`
|
||||
// Filter by related studios that meet this criteria
|
||||
|
||||
@@ -11,21 +11,27 @@ import (
|
||||
"github.com/stashapp/stash/pkg/models/json"
|
||||
)
|
||||
|
||||
type SubGroupDescription struct {
|
||||
Group string `json:"name,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
}
|
||||
|
||||
type Group struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Aliases string `json:"aliases,omitempty"`
|
||||
Duration int `json:"duration,omitempty"`
|
||||
Date string `json:"date,omitempty"`
|
||||
Rating int `json:"rating,omitempty"`
|
||||
Director string `json:"director,omitempty"`
|
||||
Synopsis string `json:"synopsis,omitempty"`
|
||||
FrontImage string `json:"front_image,omitempty"`
|
||||
BackImage string `json:"back_image,omitempty"`
|
||||
URLs []string `json:"urls,omitempty"`
|
||||
Studio string `json:"studio,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
CreatedAt json.JSONTime `json:"created_at,omitempty"`
|
||||
UpdatedAt json.JSONTime `json:"updated_at,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Aliases string `json:"aliases,omitempty"`
|
||||
Duration int `json:"duration,omitempty"`
|
||||
Date string `json:"date,omitempty"`
|
||||
Rating int `json:"rating,omitempty"`
|
||||
Director string `json:"director,omitempty"`
|
||||
Synopsis string `json:"synopsis,omitempty"`
|
||||
FrontImage string `json:"front_image,omitempty"`
|
||||
BackImage string `json:"back_image,omitempty"`
|
||||
URLs []string `json:"urls,omitempty"`
|
||||
Studio string `json:"studio,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
SubGroups []SubGroupDescription `json:"sub_groups,omitempty"`
|
||||
CreatedAt json.JSONTime `json:"created_at,omitempty"`
|
||||
UpdatedAt json.JSONTime `json:"updated_at,omitempty"`
|
||||
|
||||
// deprecated - for import only
|
||||
URL string `json:"url,omitempty"`
|
||||
|
||||
@@ -289,6 +289,29 @@ func (_m *GroupReaderWriter) GetBackImage(ctx context.Context, groupID int) ([]b
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// GetContainingGroupDescriptions provides a mock function with given fields: ctx, id
|
||||
func (_m *GroupReaderWriter) GetContainingGroupDescriptions(ctx context.Context, id int) ([]models.GroupIDDescription, error) {
|
||||
ret := _m.Called(ctx, id)
|
||||
|
||||
var r0 []models.GroupIDDescription
|
||||
if rf, ok := ret.Get(0).(func(context.Context, int) []models.GroupIDDescription); ok {
|
||||
r0 = rf(ctx, id)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]models.GroupIDDescription)
|
||||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
|
||||
r1 = rf(ctx, id)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// GetFrontImage provides a mock function with given fields: ctx, groupID
|
||||
func (_m *GroupReaderWriter) GetFrontImage(ctx context.Context, groupID int) ([]byte, error) {
|
||||
ret := _m.Called(ctx, groupID)
|
||||
@@ -312,6 +335,29 @@ func (_m *GroupReaderWriter) GetFrontImage(ctx context.Context, groupID int) ([]
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// GetSubGroupDescriptions provides a mock function with given fields: ctx, id
|
||||
func (_m *GroupReaderWriter) GetSubGroupDescriptions(ctx context.Context, id int) ([]models.GroupIDDescription, error) {
|
||||
ret := _m.Called(ctx, id)
|
||||
|
||||
var r0 []models.GroupIDDescription
|
||||
if rf, ok := ret.Get(0).(func(context.Context, int) []models.GroupIDDescription); ok {
|
||||
r0 = rf(ctx, id)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]models.GroupIDDescription)
|
||||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
|
||||
r1 = rf(ctx, id)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// GetTagIDs provides a mock function with given fields: ctx, relatedID
|
||||
func (_m *GroupReaderWriter) GetTagIDs(ctx context.Context, relatedID int) ([]int, error) {
|
||||
ret := _m.Called(ctx, relatedID)
|
||||
|
||||
@@ -21,6 +21,9 @@ type Group struct {
|
||||
|
||||
URLs RelatedStrings `json:"urls"`
|
||||
TagIDs RelatedIDs `json:"tag_ids"`
|
||||
|
||||
ContainingGroups RelatedGroupDescriptions `json:"containing_groups"`
|
||||
SubGroups RelatedGroupDescriptions `json:"sub_groups"`
|
||||
}
|
||||
|
||||
func NewGroup() Group {
|
||||
@@ -43,20 +46,34 @@ func (m *Group) LoadTagIDs(ctx context.Context, l TagIDLoader) error {
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Group) LoadContainingGroupIDs(ctx context.Context, l ContainingGroupLoader) error {
|
||||
return m.ContainingGroups.load(func() ([]GroupIDDescription, error) {
|
||||
return l.GetContainingGroupDescriptions(ctx, m.ID)
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Group) LoadSubGroupIDs(ctx context.Context, l SubGroupLoader) error {
|
||||
return m.SubGroups.load(func() ([]GroupIDDescription, error) {
|
||||
return l.GetSubGroupDescriptions(ctx, m.ID)
|
||||
})
|
||||
}
|
||||
|
||||
type GroupPartial struct {
|
||||
Name OptionalString
|
||||
Aliases OptionalString
|
||||
Duration OptionalInt
|
||||
Date OptionalDate
|
||||
// Rating expressed in 1-100 scale
|
||||
Rating OptionalInt
|
||||
StudioID OptionalInt
|
||||
Director OptionalString
|
||||
Synopsis OptionalString
|
||||
URLs *UpdateStrings
|
||||
TagIDs *UpdateIDs
|
||||
CreatedAt OptionalTime
|
||||
UpdatedAt OptionalTime
|
||||
Rating OptionalInt
|
||||
StudioID OptionalInt
|
||||
Director OptionalString
|
||||
Synopsis OptionalString
|
||||
URLs *UpdateStrings
|
||||
TagIDs *UpdateIDs
|
||||
ContainingGroups *UpdateGroupDescriptions
|
||||
SubGroups *UpdateGroupDescriptions
|
||||
CreatedAt OptionalTime
|
||||
UpdatedAt OptionalTime
|
||||
}
|
||||
|
||||
func NewGroupPartial() GroupPartial {
|
||||
|
||||
@@ -68,3 +68,8 @@ func GroupsScenesFromInput(input []SceneMovieInput) ([]GroupsScenes, error) {
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
type GroupIDDescription struct {
|
||||
GroupID int `json:"group_id"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ package models
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/stashapp/stash/pkg/sliceutil"
|
||||
)
|
||||
|
||||
type SceneIDLoader interface {
|
||||
@@ -37,6 +39,14 @@ type SceneGroupLoader interface {
|
||||
GetGroups(ctx context.Context, id int) ([]GroupsScenes, error)
|
||||
}
|
||||
|
||||
type ContainingGroupLoader interface {
|
||||
GetContainingGroupDescriptions(ctx context.Context, id int) ([]GroupIDDescription, error)
|
||||
}
|
||||
|
||||
type SubGroupLoader interface {
|
||||
GetSubGroupDescriptions(ctx context.Context, id int) ([]GroupIDDescription, error)
|
||||
}
|
||||
|
||||
type StashIDLoader interface {
|
||||
GetStashIDs(ctx context.Context, relatedID int) ([]StashID, error)
|
||||
}
|
||||
@@ -185,6 +195,82 @@ func (r *RelatedGroups) load(fn func() ([]GroupsScenes, error)) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type RelatedGroupDescriptions struct {
|
||||
list []GroupIDDescription
|
||||
}
|
||||
|
||||
// NewRelatedGroups returns a loaded RelateGroups object with the provided groups.
|
||||
// Loaded will return true when called on the returned object if the provided slice is not nil.
|
||||
func NewRelatedGroupDescriptions(list []GroupIDDescription) RelatedGroupDescriptions {
|
||||
return RelatedGroupDescriptions{
|
||||
list: list,
|
||||
}
|
||||
}
|
||||
|
||||
// Loaded returns true if the relationship has been loaded.
|
||||
func (r RelatedGroupDescriptions) Loaded() bool {
|
||||
return r.list != nil
|
||||
}
|
||||
|
||||
func (r RelatedGroupDescriptions) mustLoaded() {
|
||||
if !r.Loaded() {
|
||||
panic("list has not been loaded")
|
||||
}
|
||||
}
|
||||
|
||||
// List returns the related Groups. Panics if the relationship has not been loaded.
|
||||
func (r RelatedGroupDescriptions) List() []GroupIDDescription {
|
||||
r.mustLoaded()
|
||||
|
||||
return r.list
|
||||
}
|
||||
|
||||
// List returns the related Groups. Panics if the relationship has not been loaded.
|
||||
func (r RelatedGroupDescriptions) IDs() []int {
|
||||
r.mustLoaded()
|
||||
|
||||
return sliceutil.Map(r.list, func(d GroupIDDescription) int { return d.GroupID })
|
||||
}
|
||||
|
||||
// Add adds the provided ids to the list. Panics if the relationship has not been loaded.
|
||||
func (r *RelatedGroupDescriptions) Add(groups ...GroupIDDescription) {
|
||||
r.mustLoaded()
|
||||
|
||||
r.list = append(r.list, groups...)
|
||||
}
|
||||
|
||||
// ForID returns the GroupsScenes object for the given group ID. Returns nil if not found.
|
||||
func (r *RelatedGroupDescriptions) ForID(id int) *GroupIDDescription {
|
||||
r.mustLoaded()
|
||||
|
||||
for _, v := range r.list {
|
||||
if v.GroupID == id {
|
||||
return &v
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RelatedGroupDescriptions) load(fn func() ([]GroupIDDescription, error)) error {
|
||||
if r.Loaded() {
|
||||
return nil
|
||||
}
|
||||
|
||||
ids, err := fn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if ids == nil {
|
||||
ids = []GroupIDDescription{}
|
||||
}
|
||||
|
||||
r.list = ids
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type RelatedStashIDs struct {
|
||||
list []StashID
|
||||
}
|
||||
|
||||
@@ -66,6 +66,8 @@ type GroupReader interface {
|
||||
GroupCounter
|
||||
URLLoader
|
||||
TagIDLoader
|
||||
ContainingGroupLoader
|
||||
SubGroupLoader
|
||||
|
||||
All(ctx context.Context) ([]*Group, error)
|
||||
GetFrontImage(ctx context.Context, groupID int) ([]byte, error)
|
||||
|
||||
@@ -37,10 +37,7 @@ type SceneQueryer interface {
|
||||
type SceneCounter interface {
|
||||
Count(ctx context.Context) (int, error)
|
||||
CountByPerformerID(ctx context.Context, performerID int) (int, error)
|
||||
CountByGroupID(ctx context.Context, groupID int) (int, error)
|
||||
CountByFileID(ctx context.Context, fileID FileID) (int, error)
|
||||
CountByStudioID(ctx context.Context, studioID int) (int, error)
|
||||
CountByTagID(ctx context.Context, tagID int) (int, error)
|
||||
CountMissingChecksum(ctx context.Context) (int, error)
|
||||
CountMissingOSHash(ctx context.Context) (int, error)
|
||||
OCountByPerformerID(ctx context.Context, performerID int) (int, error)
|
||||
|
||||
@@ -56,7 +56,7 @@ type SceneFilterType struct {
|
||||
// Filter to only include scenes with this studio
|
||||
Studios *HierarchicalMultiCriterionInput `json:"studios"`
|
||||
// Filter to only include scenes with this group
|
||||
Groups *MultiCriterionInput `json:"groups"`
|
||||
Groups *HierarchicalMultiCriterionInput `json:"groups"`
|
||||
// Filter to only include scenes with this movie
|
||||
Movies *MultiCriterionInput `json:"movies"`
|
||||
// Filter to only include scenes with this gallery
|
||||
|
||||
@@ -133,3 +133,68 @@ func applyUpdate[T comparable](values []T, mode RelationshipUpdateMode, existing
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type UpdateGroupDescriptions struct {
|
||||
Groups []GroupIDDescription `json:"groups"`
|
||||
Mode RelationshipUpdateMode `json:"mode"`
|
||||
}
|
||||
|
||||
// Apply applies the update to a list of existing ids, returning the result.
|
||||
func (u *UpdateGroupDescriptions) Apply(existing []GroupIDDescription) []GroupIDDescription {
|
||||
if u == nil {
|
||||
return existing
|
||||
}
|
||||
|
||||
switch u.Mode {
|
||||
case RelationshipUpdateModeAdd:
|
||||
return u.applyAdd(existing)
|
||||
case RelationshipUpdateModeRemove:
|
||||
return u.applyRemove(existing)
|
||||
case RelationshipUpdateModeSet:
|
||||
return u.Groups
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *UpdateGroupDescriptions) applyAdd(existing []GroupIDDescription) []GroupIDDescription {
|
||||
// overwrite any existing values with the same id
|
||||
ret := append([]GroupIDDescription{}, existing...)
|
||||
for _, v := range u.Groups {
|
||||
found := false
|
||||
for i, vv := range ret {
|
||||
if vv.GroupID == v.GroupID {
|
||||
ret[i] = v
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
ret = append(ret, v)
|
||||
}
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func (u *UpdateGroupDescriptions) applyRemove(existing []GroupIDDescription) []GroupIDDescription {
|
||||
// remove any existing values with the same id
|
||||
var ret []GroupIDDescription
|
||||
for _, v := range existing {
|
||||
found := false
|
||||
for _, vv := range u.Groups {
|
||||
if vv.GroupID == v.GroupID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// if not found in the remove list, keep it
|
||||
if !found {
|
||||
ret = append(ret, v)
|
||||
}
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
@@ -144,3 +144,15 @@ func CountByTagID(ctx context.Context, r models.SceneQueryer, id int, depth *int
|
||||
|
||||
return r.QueryCount(ctx, filter, nil)
|
||||
}
|
||||
|
||||
func CountByGroupID(ctx context.Context, r models.SceneQueryer, id int, depth *int) (int, error) {
|
||||
filter := &models.SceneFilterType{
|
||||
Groups: &models.HierarchicalMultiCriterionInput{
|
||||
Value: []string{strconv.Itoa(id)},
|
||||
Modifier: models.CriterionModifierIncludes,
|
||||
Depth: depth,
|
||||
},
|
||||
}
|
||||
|
||||
return r.QueryCount(ctx, filter, nil)
|
||||
}
|
||||
|
||||
@@ -146,7 +146,7 @@ func Filter[T any](vs []T, f func(T) bool) []T {
|
||||
return ret
|
||||
}
|
||||
|
||||
// Filter returns the result of applying f to each element of the vs slice.
|
||||
// Map returns the result of applying f to each element of the vs slice.
|
||||
func Map[T any, V any](vs []T, f func(T) V) []V {
|
||||
ret := make([]V, len(vs))
|
||||
for i, v := range vs {
|
||||
|
||||
@@ -30,7 +30,7 @@ const (
|
||||
dbConnTimeout = 30
|
||||
)
|
||||
|
||||
var appSchemaVersion uint = 66
|
||||
var appSchemaVersion uint = 67
|
||||
|
||||
//go:embed migrations/*.sql
|
||||
var migrationsBox embed.FS
|
||||
|
||||
222
pkg/sqlite/filter_hierarchical.go
Normal file
222
pkg/sqlite/filter_hierarchical.go
Normal file
@@ -0,0 +1,222 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/utils"
|
||||
)
|
||||
|
||||
// hierarchicalRelationshipHandler provides handlers for parent, children, parent count, and child count criteria.
|
||||
type hierarchicalRelationshipHandler struct {
|
||||
primaryTable string
|
||||
relationTable string
|
||||
aliasPrefix string
|
||||
parentIDCol string
|
||||
childIDCol string
|
||||
}
|
||||
|
||||
func (h hierarchicalRelationshipHandler) validateModifier(m models.CriterionModifier) error {
|
||||
switch m {
|
||||
case models.CriterionModifierIncludesAll, models.CriterionModifierIncludes, models.CriterionModifierExcludes, models.CriterionModifierIsNull, models.CriterionModifierNotNull:
|
||||
// valid
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("invalid modifier %s", m)
|
||||
}
|
||||
}
|
||||
|
||||
func (h hierarchicalRelationshipHandler) handleNullNotNull(f *filterBuilder, m models.CriterionModifier, isParents bool) {
|
||||
var notClause string
|
||||
if m == models.CriterionModifierNotNull {
|
||||
notClause = "NOT"
|
||||
}
|
||||
|
||||
as := h.aliasPrefix + "_parents"
|
||||
col := h.childIDCol
|
||||
if !isParents {
|
||||
as = h.aliasPrefix + "_children"
|
||||
col = h.parentIDCol
|
||||
}
|
||||
|
||||
// Based on:
|
||||
// f.addLeftJoin("tags_relations", "parent_relations", "tags.id = parent_relations.child_id")
|
||||
// f.addWhere(fmt.Sprintf("parent_relations.parent_id IS %s NULL", notClause))
|
||||
|
||||
f.addLeftJoin(h.relationTable, as, fmt.Sprintf("%s.id = %s.%s", h.primaryTable, as, col))
|
||||
f.addWhere(fmt.Sprintf("%s.%s IS %s NULL", as, col, notClause))
|
||||
}
|
||||
|
||||
func (h hierarchicalRelationshipHandler) parentsAlias() string {
|
||||
return h.aliasPrefix + "_parents"
|
||||
}
|
||||
|
||||
func (h hierarchicalRelationshipHandler) childrenAlias() string {
|
||||
return h.aliasPrefix + "_children"
|
||||
}
|
||||
|
||||
func (h hierarchicalRelationshipHandler) valueQuery(value []string, depth int, alias string, isParents bool) string {
|
||||
var depthCondition string
|
||||
if depth != -1 {
|
||||
depthCondition = fmt.Sprintf("WHERE depth < %d", depth)
|
||||
}
|
||||
|
||||
queryTempl := `{alias} AS (
|
||||
SELECT {root_id_col} AS root_id, {item_id_col} AS item_id, 0 AS depth FROM {relation_table} WHERE {root_id_col} IN` + getInBinding(len(value)) + `
|
||||
UNION
|
||||
SELECT root_id, {item_id_col}, depth + 1 FROM {relation_table} INNER JOIN {alias} ON item_id = {root_id_col} ` + depthCondition + `
|
||||
)`
|
||||
|
||||
var queryMap utils.StrFormatMap
|
||||
if isParents {
|
||||
queryMap = utils.StrFormatMap{
|
||||
"root_id_col": h.parentIDCol,
|
||||
"item_id_col": h.childIDCol,
|
||||
}
|
||||
} else {
|
||||
queryMap = utils.StrFormatMap{
|
||||
"root_id_col": h.childIDCol,
|
||||
"item_id_col": h.parentIDCol,
|
||||
}
|
||||
}
|
||||
|
||||
queryMap["alias"] = alias
|
||||
queryMap["relation_table"] = h.relationTable
|
||||
|
||||
return utils.StrFormat(queryTempl, queryMap)
|
||||
}
|
||||
|
||||
func (h hierarchicalRelationshipHandler) handleValues(f *filterBuilder, c models.HierarchicalMultiCriterionInput, isParents bool, aliasSuffix string) {
|
||||
if len(c.Value) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var args []interface{}
|
||||
for _, val := range c.Value {
|
||||
args = append(args, val)
|
||||
}
|
||||
|
||||
depthVal := 0
|
||||
if c.Depth != nil {
|
||||
depthVal = *c.Depth
|
||||
}
|
||||
|
||||
tableAlias := h.parentsAlias()
|
||||
if !isParents {
|
||||
tableAlias = h.childrenAlias()
|
||||
}
|
||||
tableAlias += aliasSuffix
|
||||
|
||||
query := h.valueQuery(c.Value, depthVal, tableAlias, isParents)
|
||||
f.addRecursiveWith(query, args...)
|
||||
|
||||
f.addLeftJoin(tableAlias, "", fmt.Sprintf("%s.item_id = %s.id", tableAlias, h.primaryTable))
|
||||
addHierarchicalConditionClauses(f, c, tableAlias, "root_id")
|
||||
}
|
||||
|
||||
func (h hierarchicalRelationshipHandler) handleValuesSimple(f *filterBuilder, value string, isParents bool) {
|
||||
joinCol := h.childIDCol
|
||||
valueCol := h.parentIDCol
|
||||
if !isParents {
|
||||
joinCol = h.parentIDCol
|
||||
valueCol = h.childIDCol
|
||||
}
|
||||
|
||||
tableAlias := h.parentsAlias()
|
||||
if !isParents {
|
||||
tableAlias = h.childrenAlias()
|
||||
}
|
||||
|
||||
f.addInnerJoin(h.relationTable, tableAlias, fmt.Sprintf("%s.%s = %s.id", tableAlias, joinCol, h.primaryTable))
|
||||
f.addWhere(fmt.Sprintf("%s.%s = ?", tableAlias, valueCol), value)
|
||||
}
|
||||
|
||||
func (h hierarchicalRelationshipHandler) hierarchicalCriterionHandler(criterion *models.HierarchicalMultiCriterionInput, isParents bool) criterionHandlerFunc {
|
||||
return func(ctx context.Context, f *filterBuilder) {
|
||||
if criterion != nil {
|
||||
c := criterion.CombineExcludes()
|
||||
|
||||
// validate the modifier
|
||||
if err := h.validateModifier(c.Modifier); err != nil {
|
||||
f.setError(err)
|
||||
return
|
||||
}
|
||||
|
||||
if c.Modifier == models.CriterionModifierIsNull || c.Modifier == models.CriterionModifierNotNull {
|
||||
h.handleNullNotNull(f, c.Modifier, isParents)
|
||||
return
|
||||
}
|
||||
|
||||
if len(c.Value) == 0 && len(c.Excludes) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
depth := 0
|
||||
if c.Depth != nil {
|
||||
depth = *c.Depth
|
||||
}
|
||||
|
||||
// if we have a single include, no excludes, and no depth, we can use a simple join and where clause
|
||||
if (c.Modifier == models.CriterionModifierIncludes || c.Modifier == models.CriterionModifierIncludesAll) && len(c.Value) == 1 && len(c.Excludes) == 0 && depth == 0 {
|
||||
h.handleValuesSimple(f, c.Value[0], isParents)
|
||||
return
|
||||
}
|
||||
|
||||
aliasSuffix := ""
|
||||
h.handleValues(f, c, isParents, aliasSuffix)
|
||||
|
||||
if len(c.Excludes) > 0 {
|
||||
exCriterion := models.HierarchicalMultiCriterionInput{
|
||||
Value: c.Excludes,
|
||||
Depth: c.Depth,
|
||||
Modifier: models.CriterionModifierExcludes,
|
||||
}
|
||||
|
||||
aliasSuffix := "2"
|
||||
h.handleValues(f, exCriterion, isParents, aliasSuffix)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h hierarchicalRelationshipHandler) ParentsCriterionHandler(criterion *models.HierarchicalMultiCriterionInput) criterionHandlerFunc {
|
||||
const isParents = true
|
||||
return h.hierarchicalCriterionHandler(criterion, isParents)
|
||||
}
|
||||
|
||||
func (h hierarchicalRelationshipHandler) ChildrenCriterionHandler(criterion *models.HierarchicalMultiCriterionInput) criterionHandlerFunc {
|
||||
const isParents = false
|
||||
return h.hierarchicalCriterionHandler(criterion, isParents)
|
||||
}
|
||||
|
||||
func (h hierarchicalRelationshipHandler) countCriterionHandler(c *models.IntCriterionInput, isParents bool) criterionHandlerFunc {
|
||||
tableAlias := h.parentsAlias()
|
||||
col := h.childIDCol
|
||||
otherCol := h.parentIDCol
|
||||
if !isParents {
|
||||
tableAlias = h.childrenAlias()
|
||||
col = h.parentIDCol
|
||||
otherCol = h.childIDCol
|
||||
}
|
||||
tableAlias += "_count"
|
||||
|
||||
return func(ctx context.Context, f *filterBuilder) {
|
||||
if c != nil {
|
||||
f.addLeftJoin(h.relationTable, tableAlias, fmt.Sprintf("%s.%s = %s.id", tableAlias, col, h.primaryTable))
|
||||
clause, args := getIntCriterionWhereClause(fmt.Sprintf("count(distinct %s.%s)", tableAlias, otherCol), *c)
|
||||
|
||||
f.addHaving(clause, args...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h hierarchicalRelationshipHandler) ParentCountCriterionHandler(parentCount *models.IntCriterionInput) criterionHandlerFunc {
|
||||
const isParents = true
|
||||
return h.countCriterionHandler(parentCount, isParents)
|
||||
}
|
||||
|
||||
func (h hierarchicalRelationshipHandler) ChildCountCriterionHandler(childCount *models.IntCriterionInput) criterionHandlerFunc {
|
||||
const isParents = false
|
||||
return h.countCriterionHandler(childCount, isParents)
|
||||
}
|
||||
@@ -27,6 +27,8 @@ const (
|
||||
|
||||
groupURLsTable = "group_urls"
|
||||
groupURLColumn = "url"
|
||||
|
||||
groupRelationsTable = "groups_relations"
|
||||
)
|
||||
|
||||
type groupRow struct {
|
||||
@@ -128,6 +130,7 @@ var (
|
||||
type GroupStore struct {
|
||||
blobJoinQueryBuilder
|
||||
tagRelationshipStore
|
||||
groupRelationshipStore
|
||||
|
||||
tableMgr *table
|
||||
}
|
||||
@@ -143,6 +146,9 @@ func NewGroupStore(blobStore *BlobStore) *GroupStore {
|
||||
joinTable: groupsTagsTableMgr,
|
||||
},
|
||||
},
|
||||
groupRelationshipStore: groupRelationshipStore{
|
||||
table: groupRelationshipTableMgr,
|
||||
},
|
||||
|
||||
tableMgr: groupTableMgr,
|
||||
}
|
||||
@@ -176,6 +182,14 @@ func (qb *GroupStore) Create(ctx context.Context, newObject *models.Group) error
|
||||
return err
|
||||
}
|
||||
|
||||
if err := qb.groupRelationshipStore.createContainingRelationships(ctx, id, newObject.ContainingGroups); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := qb.groupRelationshipStore.createSubRelationships(ctx, id, newObject.SubGroups); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updated, err := qb.find(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("finding after create: %w", err)
|
||||
@@ -211,6 +225,14 @@ func (qb *GroupStore) UpdatePartial(ctx context.Context, id int, partial models.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := qb.groupRelationshipStore.modifyContainingRelationships(ctx, id, partial.ContainingGroups); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := qb.groupRelationshipStore.modifySubRelationships(ctx, id, partial.SubGroups); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return qb.find(ctx, id)
|
||||
}
|
||||
|
||||
@@ -232,6 +254,14 @@ func (qb *GroupStore) Update(ctx context.Context, updatedObject *models.Group) e
|
||||
return err
|
||||
}
|
||||
|
||||
if err := qb.groupRelationshipStore.replaceContainingRelationships(ctx, updatedObject.ID, updatedObject.ContainingGroups); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := qb.groupRelationshipStore.replaceSubRelationships(ctx, updatedObject.ID, updatedObject.SubGroups); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -412,9 +442,7 @@ func (qb *GroupStore) makeQuery(ctx context.Context, groupFilter *models.GroupFi
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var err error
|
||||
query.sortAndPagination, err = qb.getGroupSort(findFilter)
|
||||
if err != nil {
|
||||
if err := qb.setGroupSort(&query, findFilter); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -460,11 +488,12 @@ var groupSortOptions = sortOptions{
|
||||
"random",
|
||||
"rating",
|
||||
"scenes_count",
|
||||
"sub_group_order",
|
||||
"tag_count",
|
||||
"updated_at",
|
||||
}
|
||||
|
||||
func (qb *GroupStore) getGroupSort(findFilter *models.FindFilterType) (string, error) {
|
||||
func (qb *GroupStore) setGroupSort(query *queryBuilder, findFilter *models.FindFilterType) error {
|
||||
var sort string
|
||||
var direction string
|
||||
if findFilter == nil {
|
||||
@@ -477,22 +506,31 @@ func (qb *GroupStore) getGroupSort(findFilter *models.FindFilterType) (string, e
|
||||
|
||||
// CVE-2024-32231 - ensure sort is in the list of allowed sorts
|
||||
if err := groupSortOptions.validateSort(sort); err != nil {
|
||||
return "", err
|
||||
return err
|
||||
}
|
||||
|
||||
sortQuery := ""
|
||||
switch sort {
|
||||
case "sub_group_order":
|
||||
// sub_group_order is a special sort that sorts by the order_index of the subgroups
|
||||
if query.hasJoin("groups_parents") {
|
||||
query.sortAndPagination += getSort("order_index", direction, "groups_parents")
|
||||
} else {
|
||||
// this will give unexpected results if the query is not filtered by a parent group and
|
||||
// the group has multiple parents and order indexes
|
||||
query.join(groupRelationsTable, "", "groups.id = groups_relations.sub_id")
|
||||
query.sortAndPagination += getSort("order_index", direction, groupRelationsTable)
|
||||
}
|
||||
case "tag_count":
|
||||
sortQuery += getCountSort(groupTable, groupsTagsTable, groupIDColumn, direction)
|
||||
query.sortAndPagination += getCountSort(groupTable, groupsTagsTable, groupIDColumn, direction)
|
||||
case "scenes_count": // generic getSort won't work for this
|
||||
sortQuery += getCountSort(groupTable, groupsScenesTable, groupIDColumn, direction)
|
||||
query.sortAndPagination += getCountSort(groupTable, groupsScenesTable, groupIDColumn, direction)
|
||||
default:
|
||||
sortQuery += getSort(sort, direction, "groups")
|
||||
query.sortAndPagination += getSort(sort, direction, "groups")
|
||||
}
|
||||
|
||||
// Whatever the sorting, always use name/id as a final sort
|
||||
sortQuery += ", COALESCE(groups.name, groups.id) COLLATE NATURAL_CI ASC"
|
||||
return sortQuery, nil
|
||||
query.sortAndPagination += ", COALESCE(groups.name, groups.id) COLLATE NATURAL_CI ASC"
|
||||
return nil
|
||||
}
|
||||
|
||||
func (qb *GroupStore) queryGroups(ctx context.Context, query string, args []interface{}) ([]*models.Group, error) {
|
||||
@@ -592,3 +630,74 @@ WHERE groups.studio_id = ?
|
||||
func (qb *GroupStore) GetURLs(ctx context.Context, groupID int) ([]string, error) {
|
||||
return groupsURLsTableMgr.get(ctx, groupID)
|
||||
}
|
||||
|
||||
// FindSubGroupIDs returns a list of group IDs where a group in the ids list is a sub-group of the parent group
|
||||
func (qb *GroupStore) FindSubGroupIDs(ctx context.Context, containingID int, ids []int) ([]int, error) {
|
||||
/*
|
||||
SELECT gr.sub_id FROM groups_relations gr
|
||||
WHERE gr.containing_id = :parentID AND gr.sub_id IN (:ids);
|
||||
*/
|
||||
table := groupRelationshipTableMgr.table
|
||||
q := dialect.From(table).Prepared(true).
|
||||
Select(table.Col("sub_id")).Where(
|
||||
table.Col("containing_id").Eq(containingID),
|
||||
table.Col("sub_id").In(ids),
|
||||
)
|
||||
|
||||
const single = false
|
||||
var ret []int
|
||||
if err := queryFunc(ctx, q, single, func(r *sqlx.Rows) error {
|
||||
var id int
|
||||
if err := r.Scan(&id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ret = append(ret, id)
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// FindInAscestors returns a list of group IDs where a group in the ids list is an ascestor of the ancestor group IDs
|
||||
func (qb *GroupStore) FindInAncestors(ctx context.Context, ascestorIDs []int, ids []int) ([]int, error) {
|
||||
/*
|
||||
WITH RECURSIVE ascestors AS (
|
||||
SELECT g.id AS parent_id FROM groups g WHERE g.id IN (:ascestorIDs)
|
||||
UNION
|
||||
SELECT gr.containing_id FROM groups_relations gr INNER JOIN ascestors a ON a.parent_id = gr.sub_id
|
||||
)
|
||||
SELECT p.parent_id FROM ascestors p WHERE p.parent_id IN (:ids);
|
||||
*/
|
||||
table := qb.table()
|
||||
const ascestors = "ancestors"
|
||||
const parentID = "parent_id"
|
||||
q := dialect.From(ascestors).Prepared(true).
|
||||
WithRecursive(ascestors,
|
||||
dialect.From(qb.table()).Select(table.Col(idColumn).As(parentID)).
|
||||
Where(table.Col(idColumn).In(ascestorIDs)).
|
||||
Union(
|
||||
dialect.From(groupRelationsJoinTable).InnerJoin(
|
||||
goqu.I(ascestors), goqu.On(goqu.I("parent_id").Eq(goqu.I("sub_id"))),
|
||||
).Select("containing_id"),
|
||||
),
|
||||
).Select(parentID).Where(goqu.I(parentID).In(ids))
|
||||
|
||||
const single = false
|
||||
var ret []int
|
||||
if err := queryFunc(ctx, q, single, func(r *sqlx.Rows) error {
|
||||
var id int
|
||||
if err := r.Scan(&id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ret = append(ret, id)
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
@@ -51,6 +51,14 @@ func (qb *groupFilterHandler) handle(ctx context.Context, f *filterBuilder) {
|
||||
f.handleCriterion(ctx, qb.criterionHandler())
|
||||
}
|
||||
|
||||
var groupHierarchyHandler = hierarchicalRelationshipHandler{
|
||||
primaryTable: groupTable,
|
||||
relationTable: groupRelationsTable,
|
||||
aliasPrefix: groupTable,
|
||||
parentIDCol: "containing_id",
|
||||
childIDCol: "sub_id",
|
||||
}
|
||||
|
||||
func (qb *groupFilterHandler) criterionHandler() criterionHandler {
|
||||
groupFilter := qb.groupFilter
|
||||
return compoundHandler{
|
||||
@@ -66,6 +74,10 @@ func (qb *groupFilterHandler) criterionHandler() criterionHandler {
|
||||
qb.tagsCriterionHandler(groupFilter.Tags),
|
||||
qb.tagCountCriterionHandler(groupFilter.TagCount),
|
||||
&dateCriterionHandler{groupFilter.Date, "groups.date", nil},
|
||||
groupHierarchyHandler.ParentsCriterionHandler(groupFilter.ContainingGroups),
|
||||
groupHierarchyHandler.ChildrenCriterionHandler(groupFilter.SubGroups),
|
||||
groupHierarchyHandler.ParentCountCriterionHandler(groupFilter.ContainingGroupCount),
|
||||
groupHierarchyHandler.ChildCountCriterionHandler(groupFilter.SubGroupCount),
|
||||
×tampCriterionHandler{groupFilter.CreatedAt, "groups.created_at", nil},
|
||||
×tampCriterionHandler{groupFilter.UpdatedAt, "groups.updated_at", nil},
|
||||
|
||||
|
||||
457
pkg/sqlite/group_relationships.go
Normal file
457
pkg/sqlite/group_relationships.go
Normal file
@@ -0,0 +1,457 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/doug-martin/goqu/v9"
|
||||
"github.com/doug-martin/goqu/v9/exp"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"gopkg.in/guregu/null.v4"
|
||||
"gopkg.in/guregu/null.v4/zero"
|
||||
)
|
||||
|
||||
type groupRelationshipRow struct {
|
||||
ContainingID int `db:"containing_id"`
|
||||
SubID int `db:"sub_id"`
|
||||
OrderIndex int `db:"order_index"`
|
||||
Description zero.String `db:"description"`
|
||||
}
|
||||
|
||||
func (r groupRelationshipRow) resolve(useContainingID bool) models.GroupIDDescription {
|
||||
id := r.ContainingID
|
||||
if !useContainingID {
|
||||
id = r.SubID
|
||||
}
|
||||
|
||||
return models.GroupIDDescription{
|
||||
GroupID: id,
|
||||
Description: r.Description.String,
|
||||
}
|
||||
}
|
||||
|
||||
type groupRelationshipStore struct {
|
||||
table *table
|
||||
}
|
||||
|
||||
func (s *groupRelationshipStore) GetContainingGroupDescriptions(ctx context.Context, id int) ([]models.GroupIDDescription, error) {
|
||||
const idIsContaining = false
|
||||
return s.getGroupRelationships(ctx, id, idIsContaining)
|
||||
}
|
||||
|
||||
func (s *groupRelationshipStore) GetSubGroupDescriptions(ctx context.Context, id int) ([]models.GroupIDDescription, error) {
|
||||
const idIsContaining = true
|
||||
return s.getGroupRelationships(ctx, id, idIsContaining)
|
||||
}
|
||||
|
||||
func (s *groupRelationshipStore) getGroupRelationships(ctx context.Context, id int, idIsContaining bool) ([]models.GroupIDDescription, error) {
|
||||
col := "containing_id"
|
||||
if !idIsContaining {
|
||||
col = "sub_id"
|
||||
}
|
||||
|
||||
table := s.table.table
|
||||
q := dialect.Select(table.All()).
|
||||
From(table).
|
||||
Where(table.Col(col).Eq(id)).
|
||||
Order(table.Col("order_index").Asc())
|
||||
|
||||
const single = false
|
||||
var ret []models.GroupIDDescription
|
||||
if err := queryFunc(ctx, q, single, func(rows *sqlx.Rows) error {
|
||||
var row groupRelationshipRow
|
||||
if err := rows.StructScan(&row); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ret = append(ret, row.resolve(!idIsContaining))
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, fmt.Errorf("getting group relationships from %s: %w", table.GetTable(), err)
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// getMaxOrderIndex gets the maximum order index for the containing group with the given id
|
||||
func (s *groupRelationshipStore) getMaxOrderIndex(ctx context.Context, containingID int) (int, error) {
|
||||
idColumn := s.table.table.Col("containing_id")
|
||||
|
||||
q := dialect.Select(goqu.MAX("order_index")).
|
||||
From(s.table.table).
|
||||
Where(idColumn.Eq(containingID))
|
||||
|
||||
var maxOrderIndex zero.Int
|
||||
if err := querySimple(ctx, q, &maxOrderIndex); err != nil {
|
||||
return 0, fmt.Errorf("getting max order index: %w", err)
|
||||
}
|
||||
|
||||
return int(maxOrderIndex.Int64), nil
|
||||
}
|
||||
|
||||
// createRelationships creates relationships between a group and other groups.
|
||||
// If idIsContaining is true, the provided id is the containing group.
|
||||
func (s *groupRelationshipStore) createRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions, idIsContaining bool) error {
|
||||
if d.Loaded() {
|
||||
for i, v := range d.List() {
|
||||
orderIndex := i + 1
|
||||
|
||||
r := groupRelationshipRow{
|
||||
ContainingID: id,
|
||||
SubID: v.GroupID,
|
||||
OrderIndex: orderIndex,
|
||||
Description: zero.StringFrom(v.Description),
|
||||
}
|
||||
|
||||
if !idIsContaining {
|
||||
// get the max order index of the containing groups sub groups
|
||||
containingID := v.GroupID
|
||||
maxOrderIndex, err := s.getMaxOrderIndex(ctx, containingID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.ContainingID = v.GroupID
|
||||
r.SubID = id
|
||||
r.OrderIndex = maxOrderIndex + 1
|
||||
}
|
||||
|
||||
_, err := s.table.insert(ctx, r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("inserting into %s: %w", s.table.table.GetTable(), err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createRelationships creates relationships between a group and other groups.
|
||||
// If idIsContaining is true, the provided id is the containing group.
|
||||
func (s *groupRelationshipStore) createContainingRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions) error {
|
||||
const idIsContaining = false
|
||||
return s.createRelationships(ctx, id, d, idIsContaining)
|
||||
}
|
||||
|
||||
// createRelationships creates relationships between a group and other groups.
|
||||
// If idIsContaining is true, the provided id is the containing group.
|
||||
func (s *groupRelationshipStore) createSubRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions) error {
|
||||
const idIsContaining = true
|
||||
return s.createRelationships(ctx, id, d, idIsContaining)
|
||||
}
|
||||
|
||||
func (s *groupRelationshipStore) replaceRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions, idIsContaining bool) error {
|
||||
// always destroy the existing relationships even if the new list is empty
|
||||
if err := s.destroyAllJoins(ctx, id, idIsContaining); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.createRelationships(ctx, id, d, idIsContaining)
|
||||
}
|
||||
|
||||
func (s *groupRelationshipStore) replaceContainingRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions) error {
|
||||
const idIsContaining = false
|
||||
return s.replaceRelationships(ctx, id, d, idIsContaining)
|
||||
}
|
||||
|
||||
func (s *groupRelationshipStore) replaceSubRelationships(ctx context.Context, id int, d models.RelatedGroupDescriptions) error {
|
||||
const idIsContaining = true
|
||||
return s.replaceRelationships(ctx, id, d, idIsContaining)
|
||||
}
|
||||
|
||||
func (s *groupRelationshipStore) modifyRelationships(ctx context.Context, id int, v *models.UpdateGroupDescriptions, idIsContaining bool) error {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch v.Mode {
|
||||
case models.RelationshipUpdateModeSet:
|
||||
return s.replaceJoins(ctx, id, *v, idIsContaining)
|
||||
case models.RelationshipUpdateModeAdd:
|
||||
return s.addJoins(ctx, id, v.Groups, idIsContaining)
|
||||
case models.RelationshipUpdateModeRemove:
|
||||
toRemove := make([]int, len(v.Groups))
|
||||
for i, vv := range v.Groups {
|
||||
toRemove[i] = vv.GroupID
|
||||
}
|
||||
return s.destroyJoins(ctx, id, toRemove, idIsContaining)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *groupRelationshipStore) modifyContainingRelationships(ctx context.Context, id int, v *models.UpdateGroupDescriptions) error {
|
||||
const idIsContaining = false
|
||||
return s.modifyRelationships(ctx, id, v, idIsContaining)
|
||||
}
|
||||
|
||||
func (s *groupRelationshipStore) modifySubRelationships(ctx context.Context, id int, v *models.UpdateGroupDescriptions) error {
|
||||
const idIsContaining = true
|
||||
return s.modifyRelationships(ctx, id, v, idIsContaining)
|
||||
}
|
||||
|
||||
func (s *groupRelationshipStore) addJoins(ctx context.Context, id int, groups []models.GroupIDDescription, idIsContaining bool) error {
|
||||
// if we're adding to a containing group, get the max order index first
|
||||
var maxOrderIndex int
|
||||
if idIsContaining {
|
||||
var err error
|
||||
maxOrderIndex, err = s.getMaxOrderIndex(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for i, vv := range groups {
|
||||
r := groupRelationshipRow{
|
||||
Description: zero.StringFrom(vv.Description),
|
||||
}
|
||||
|
||||
if idIsContaining {
|
||||
r.ContainingID = id
|
||||
r.SubID = vv.GroupID
|
||||
r.OrderIndex = maxOrderIndex + (i + 1)
|
||||
} else {
|
||||
// get the max order index of the containing groups sub groups
|
||||
containingMaxOrderIndex, err := s.getMaxOrderIndex(ctx, vv.GroupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.ContainingID = vv.GroupID
|
||||
r.SubID = id
|
||||
r.OrderIndex = containingMaxOrderIndex + 1
|
||||
}
|
||||
|
||||
_, err := s.table.insert(ctx, r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("inserting into %s: %w", s.table.table.GetTable(), err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *groupRelationshipStore) destroyAllJoins(ctx context.Context, id int, idIsContaining bool) error {
|
||||
table := s.table.table
|
||||
idColumn := table.Col("containing_id")
|
||||
if !idIsContaining {
|
||||
idColumn = table.Col("sub_id")
|
||||
}
|
||||
|
||||
q := dialect.Delete(table).Where(idColumn.Eq(id))
|
||||
|
||||
if _, err := exec(ctx, q); err != nil {
|
||||
return fmt.Errorf("destroying %s: %w", table.GetTable(), err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *groupRelationshipStore) replaceJoins(ctx context.Context, id int, v models.UpdateGroupDescriptions, idIsContaining bool) error {
|
||||
if err := s.destroyAllJoins(ctx, id, idIsContaining); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// convert to RelatedGroupDescriptions
|
||||
rgd := models.NewRelatedGroupDescriptions(v.Groups)
|
||||
return s.createRelationships(ctx, id, rgd, idIsContaining)
|
||||
}
|
||||
|
||||
func (s *groupRelationshipStore) destroyJoins(ctx context.Context, id int, toRemove []int, idIsContaining bool) error {
|
||||
table := s.table.table
|
||||
idColumn := table.Col("containing_id")
|
||||
fkColumn := table.Col("sub_id")
|
||||
if !idIsContaining {
|
||||
idColumn = table.Col("sub_id")
|
||||
fkColumn = table.Col("containing_id")
|
||||
}
|
||||
|
||||
q := dialect.Delete(table).Where(idColumn.Eq(id), fkColumn.In(toRemove))
|
||||
|
||||
if _, err := exec(ctx, q); err != nil {
|
||||
return fmt.Errorf("destroying %s: %w", table.GetTable(), err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *groupRelationshipStore) getOrderIndexOfSubGroup(ctx context.Context, containingGroupID int, subGroupID int) (int, error) {
|
||||
table := s.table.table
|
||||
q := dialect.Select("order_index").
|
||||
From(table).
|
||||
Where(
|
||||
table.Col("containing_id").Eq(containingGroupID),
|
||||
table.Col("sub_id").Eq(subGroupID),
|
||||
)
|
||||
|
||||
var orderIndex null.Int
|
||||
if err := querySimple(ctx, q, &orderIndex); err != nil {
|
||||
return 0, fmt.Errorf("getting order index: %w", err)
|
||||
}
|
||||
|
||||
if !orderIndex.Valid {
|
||||
return 0, fmt.Errorf("sub-group %d not found in containing group %d", subGroupID, containingGroupID)
|
||||
}
|
||||
|
||||
return int(orderIndex.Int64), nil
|
||||
}
|
||||
|
||||
func (s *groupRelationshipStore) getGroupIDAtOrderIndex(ctx context.Context, containingGroupID int, orderIndex int) (*int, error) {
|
||||
table := s.table.table
|
||||
q := dialect.Select(table.Col("sub_id")).From(table).Where(
|
||||
table.Col("containing_id").Eq(containingGroupID),
|
||||
table.Col("order_index").Eq(orderIndex),
|
||||
)
|
||||
|
||||
var ret null.Int
|
||||
if err := querySimple(ctx, q, &ret); err != nil {
|
||||
return nil, fmt.Errorf("getting sub id for order index: %w", err)
|
||||
}
|
||||
|
||||
if !ret.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
intRet := int(ret.Int64)
|
||||
return &intRet, nil
|
||||
}
|
||||
|
||||
func (s *groupRelationshipStore) getOrderIndexAfterOrderIndex(ctx context.Context, containingGroupID int, orderIndex int) (int, error) {
|
||||
table := s.table.table
|
||||
q := dialect.Select(goqu.MIN("order_index")).From(table).Where(
|
||||
table.Col("containing_id").Eq(containingGroupID),
|
||||
table.Col("order_index").Gt(orderIndex),
|
||||
)
|
||||
|
||||
var ret null.Int
|
||||
if err := querySimple(ctx, q, &ret); err != nil {
|
||||
return 0, fmt.Errorf("getting order index: %w", err)
|
||||
}
|
||||
|
||||
if !ret.Valid {
|
||||
return orderIndex + 1, nil
|
||||
}
|
||||
|
||||
return int(ret.Int64), nil
|
||||
}
|
||||
|
||||
// incrementOrderIndexes increments the order_index value of all sub-groups in the containing group at or after the given index
|
||||
func (s *groupRelationshipStore) incrementOrderIndexes(ctx context.Context, groupID int, indexBefore int) error {
|
||||
table := s.table.table
|
||||
|
||||
// WORKAROUND - sqlite won't allow incrementing the value directly since it causes a
|
||||
// unique constraint violation.
|
||||
// Instead, we first set the order index to a negative value temporarily
|
||||
// see https://stackoverflow.com/a/7703239/695786
|
||||
q := dialect.Update(table).Set(exp.Record{
|
||||
"order_index": goqu.L("-order_index"),
|
||||
}).Where(
|
||||
table.Col("containing_id").Eq(groupID),
|
||||
table.Col("order_index").Gte(indexBefore),
|
||||
)
|
||||
|
||||
if _, err := exec(ctx, q); err != nil {
|
||||
return fmt.Errorf("updating %s: %w", table.GetTable(), err)
|
||||
}
|
||||
|
||||
q = dialect.Update(table).Set(exp.Record{
|
||||
"order_index": goqu.L("1-order_index"),
|
||||
}).Where(
|
||||
table.Col("containing_id").Eq(groupID),
|
||||
table.Col("order_index").Lt(0),
|
||||
)
|
||||
|
||||
if _, err := exec(ctx, q); err != nil {
|
||||
return fmt.Errorf("updating %s: %w", table.GetTable(), err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *groupRelationshipStore) reorderSubGroup(ctx context.Context, groupID int, subGroupID int, insertPointID int, insertAfter bool) error {
|
||||
insertPointIndex, err := s.getOrderIndexOfSubGroup(ctx, groupID, insertPointID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// if we're setting before
|
||||
if insertAfter {
|
||||
insertPointIndex, err = s.getOrderIndexAfterOrderIndex(ctx, groupID, insertPointIndex)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// increment the order index of all sub-groups after and including the insertion point
|
||||
if err := s.incrementOrderIndexes(ctx, groupID, int(insertPointIndex)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// set the order index of the sub-group to the insertion point
|
||||
table := s.table.table
|
||||
q := dialect.Update(table).Set(exp.Record{
|
||||
"order_index": insertPointIndex,
|
||||
}).Where(
|
||||
table.Col("containing_id").Eq(groupID),
|
||||
table.Col("sub_id").Eq(subGroupID),
|
||||
)
|
||||
|
||||
if _, err := exec(ctx, q); err != nil {
|
||||
return fmt.Errorf("updating %s: %w", table.GetTable(), err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *groupRelationshipStore) AddSubGroups(ctx context.Context, groupID int, subGroups []models.GroupIDDescription, insertIndex *int) error {
|
||||
const idIsContaining = true
|
||||
|
||||
if err := s.addJoins(ctx, groupID, subGroups, idIsContaining); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ids := make([]int, len(subGroups))
|
||||
for i, v := range subGroups {
|
||||
ids[i] = v.GroupID
|
||||
}
|
||||
|
||||
if insertIndex != nil {
|
||||
// get the id of the sub-group at the insert index
|
||||
insertPointID, err := s.getGroupIDAtOrderIndex(ctx, groupID, *insertIndex)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if insertPointID == nil {
|
||||
// if the insert index is out of bounds, just assume adding to the end
|
||||
return nil
|
||||
}
|
||||
|
||||
// reorder the sub-groups
|
||||
const insertAfter = false
|
||||
if err := s.ReorderSubGroups(ctx, groupID, ids, *insertPointID, insertAfter); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *groupRelationshipStore) RemoveSubGroups(ctx context.Context, groupID int, subGroupIDs []int) error {
|
||||
const idIsContaining = true
|
||||
return s.destroyJoins(ctx, groupID, subGroupIDs, idIsContaining)
|
||||
}
|
||||
|
||||
func (s *groupRelationshipStore) ReorderSubGroups(ctx context.Context, groupID int, subGroupIDs []int, insertPointID int, insertAfter bool) error {
|
||||
for _, id := range subGroupIDs {
|
||||
if err := s.reorderSubGroup(ctx, groupID, id, insertPointID, insertAfter); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
13
pkg/sqlite/migrations/67_group_relationships.up.sql
Normal file
13
pkg/sqlite/migrations/67_group_relationships.up.sql
Normal file
@@ -0,0 +1,13 @@
|
||||
CREATE TABLE `groups_relations` (
|
||||
`containing_id` integer not null,
|
||||
`sub_id` integer not null,
|
||||
`order_index` integer not null,
|
||||
`description` varchar(255),
|
||||
primary key (`containing_id`, `sub_id`),
|
||||
foreign key (`containing_id`) references `groups`(`id`) on delete cascade,
|
||||
foreign key (`sub_id`) references `groups`(`id`) on delete cascade,
|
||||
check (`containing_id` != `sub_id`)
|
||||
);
|
||||
|
||||
CREATE INDEX `index_groups_relations_sub_id` ON `groups_relations` (`sub_id`);
|
||||
CREATE UNIQUE INDEX `index_groups_relations_order_index_unique` ON `groups_relations` (`containing_id`, `order_index`);
|
||||
@@ -110,6 +110,16 @@ func (qb *queryBuilder) addArg(args ...interface{}) {
|
||||
qb.args = append(qb.args, args...)
|
||||
}
|
||||
|
||||
func (qb *queryBuilder) hasJoin(alias string) bool {
|
||||
for _, j := range qb.joins {
|
||||
if j.alias() == alias {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (qb *queryBuilder) join(table, as, onClause string) {
|
||||
newJoin := join{
|
||||
table: table,
|
||||
|
||||
@@ -791,13 +791,6 @@ func (qb *SceneStore) FindByGroupID(ctx context.Context, groupID int) ([]*models
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (qb *SceneStore) CountByGroupID(ctx context.Context, groupID int) (int, error) {
|
||||
joinTable := scenesGroupsJoinTable
|
||||
|
||||
q := dialect.Select(goqu.COUNT("*")).From(joinTable).Where(joinTable.Col(groupIDColumn).Eq(groupID))
|
||||
return count(ctx, q)
|
||||
}
|
||||
|
||||
func (qb *SceneStore) Count(ctx context.Context) (int, error) {
|
||||
q := dialect.Select(goqu.COUNT("*")).From(qb.table())
|
||||
return count(ctx, q)
|
||||
@@ -858,6 +851,7 @@ func (qb *SceneStore) PlayDuration(ctx context.Context) (float64, error) {
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// TODO - currently only used by unit test
|
||||
func (qb *SceneStore) CountByStudioID(ctx context.Context, studioID int) (int, error) {
|
||||
table := qb.table()
|
||||
|
||||
@@ -865,13 +859,6 @@ func (qb *SceneStore) CountByStudioID(ctx context.Context, studioID int) (int, e
|
||||
return count(ctx, q)
|
||||
}
|
||||
|
||||
func (qb *SceneStore) CountByTagID(ctx context.Context, tagID int) (int, error) {
|
||||
joinTable := scenesTagsJoinTable
|
||||
|
||||
q := dialect.Select(goqu.COUNT("*")).From(joinTable).Where(joinTable.Col(tagIDColumn).Eq(tagID))
|
||||
return count(ctx, q)
|
||||
}
|
||||
|
||||
func (qb *SceneStore) countMissingFingerprints(ctx context.Context, fpType string) (int, error) {
|
||||
fpTable := fingerprintTableMgr.table.As("fingerprints_temp")
|
||||
|
||||
|
||||
@@ -149,7 +149,7 @@ func (qb *sceneFilterHandler) criterionHandler() criterionHandler {
|
||||
studioCriterionHandler(sceneTable, sceneFilter.Studios),
|
||||
|
||||
qb.groupsCriterionHandler(sceneFilter.Groups),
|
||||
qb.groupsCriterionHandler(sceneFilter.Movies),
|
||||
qb.moviesCriterionHandler(sceneFilter.Movies),
|
||||
|
||||
qb.galleriesCriterionHandler(sceneFilter.Galleries),
|
||||
qb.performerTagsCriterionHandler(sceneFilter.PerformerTags),
|
||||
@@ -483,7 +483,8 @@ func (qb *sceneFilterHandler) performerAgeCriterionHandler(performerAge *models.
|
||||
}
|
||||
}
|
||||
|
||||
func (qb *sceneFilterHandler) groupsCriterionHandler(movies *models.MultiCriterionInput) criterionHandlerFunc {
|
||||
// legacy handler
|
||||
func (qb *sceneFilterHandler) moviesCriterionHandler(movies *models.MultiCriterionInput) criterionHandlerFunc {
|
||||
addJoinsFunc := func(f *filterBuilder) {
|
||||
sceneRepository.groups.join(f, "", "scenes.id")
|
||||
f.addLeftJoin("groups", "", "groups_scenes.group_id = groups.id")
|
||||
@@ -492,6 +493,23 @@ func (qb *sceneFilterHandler) groupsCriterionHandler(movies *models.MultiCriteri
|
||||
return h.handler(movies)
|
||||
}
|
||||
|
||||
func (qb *sceneFilterHandler) groupsCriterionHandler(groups *models.HierarchicalMultiCriterionInput) criterionHandlerFunc {
|
||||
h := joinedHierarchicalMultiCriterionHandlerBuilder{
|
||||
primaryTable: sceneTable,
|
||||
foreignTable: groupTable,
|
||||
foreignFK: "group_id",
|
||||
|
||||
relationsTable: groupRelationsTable,
|
||||
parentFK: "containing_id",
|
||||
childFK: "sub_id",
|
||||
joinAs: "scene_group",
|
||||
joinTable: groupsScenesTable,
|
||||
primaryFK: sceneIDColumn,
|
||||
}
|
||||
|
||||
return h.handler(groups)
|
||||
}
|
||||
|
||||
func (qb *sceneFilterHandler) galleriesCriterionHandler(galleries *models.MultiCriterionInput) criterionHandlerFunc {
|
||||
addJoinsFunc := func(f *filterBuilder) {
|
||||
sceneRepository.galleries.join(f, "", "scenes.id")
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/sliceutil"
|
||||
"github.com/stashapp/stash/pkg/sliceutil/intslice"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@@ -2217,7 +2218,7 @@ func TestSceneQuery(t *testing.T) {
|
||||
},
|
||||
})
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("PerformerStore.Query() error = %v, wantErr %v", err, tt.wantErr)
|
||||
t.Errorf("SceneStore.Query() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -3873,6 +3874,100 @@ func TestSceneQueryStudioDepth(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestSceneGroups(t *testing.T) {
|
||||
type criterion struct {
|
||||
valueIdxs []int
|
||||
modifier models.CriterionModifier
|
||||
depth int
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
c criterion
|
||||
q string
|
||||
includeIdxs []int
|
||||
excludeIdxs []int
|
||||
}{
|
||||
{
|
||||
"includes",
|
||||
criterion{
|
||||
[]int{groupIdxWithScene},
|
||||
models.CriterionModifierIncludes,
|
||||
0,
|
||||
},
|
||||
"",
|
||||
[]int{sceneIdxWithGroup},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"excludes",
|
||||
criterion{
|
||||
[]int{groupIdxWithScene},
|
||||
models.CriterionModifierExcludes,
|
||||
0,
|
||||
},
|
||||
getSceneStringValue(sceneIdxWithGroup, titleField),
|
||||
nil,
|
||||
[]int{sceneIdxWithGroup},
|
||||
},
|
||||
{
|
||||
"includes (depth = 1)",
|
||||
criterion{
|
||||
[]int{groupIdxWithChildWithScene},
|
||||
models.CriterionModifierIncludes,
|
||||
1,
|
||||
},
|
||||
"",
|
||||
[]int{sceneIdxWithGroupWithParent},
|
||||
nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
valueIDs := indexesToIDs(groupIDs, tt.c.valueIdxs)
|
||||
|
||||
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
|
||||
assert := assert.New(t)
|
||||
|
||||
sceneFilter := &models.SceneFilterType{
|
||||
Groups: &models.HierarchicalMultiCriterionInput{
|
||||
Value: intslice.IntSliceToStringSlice(valueIDs),
|
||||
Modifier: tt.c.modifier,
|
||||
},
|
||||
}
|
||||
|
||||
if tt.c.depth != 0 {
|
||||
sceneFilter.Groups.Depth = &tt.c.depth
|
||||
}
|
||||
|
||||
findFilter := &models.FindFilterType{}
|
||||
if tt.q != "" {
|
||||
findFilter.Q = &tt.q
|
||||
}
|
||||
|
||||
results, err := db.Scene.Query(ctx, models.SceneQueryOptions{
|
||||
SceneFilter: sceneFilter,
|
||||
QueryOptions: models.QueryOptions{
|
||||
FindFilter: findFilter,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("SceneStore.Query() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
include := indexesToIDs(sceneIDs, tt.includeIdxs)
|
||||
exclude := indexesToIDs(sceneIDs, tt.excludeIdxs)
|
||||
|
||||
assert.Subset(results.IDs, include)
|
||||
|
||||
for _, e := range exclude {
|
||||
assert.NotContains(results.IDs, e)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSceneQueryMovies(t *testing.T) {
|
||||
withTxn(func(ctx context.Context) error {
|
||||
sqb := db.Scene
|
||||
@@ -4188,78 +4283,6 @@ func verifyScenesPerformerCount(t *testing.T, performerCountCriterion models.Int
|
||||
})
|
||||
}
|
||||
|
||||
func TestSceneCountByTagID(t *testing.T) {
|
||||
withTxn(func(ctx context.Context) error {
|
||||
sqb := db.Scene
|
||||
|
||||
sceneCount, err := sqb.CountByTagID(ctx, tagIDs[tagIdxWithScene])
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("error calling CountByTagID: %s", err.Error())
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, sceneCount)
|
||||
|
||||
sceneCount, err = sqb.CountByTagID(ctx, 0)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("error calling CountByTagID: %s", err.Error())
|
||||
}
|
||||
|
||||
assert.Equal(t, 0, sceneCount)
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func TestSceneCountByGroupID(t *testing.T) {
|
||||
withTxn(func(ctx context.Context) error {
|
||||
sqb := db.Scene
|
||||
|
||||
sceneCount, err := sqb.CountByGroupID(ctx, groupIDs[groupIdxWithScene])
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("error calling CountByGroupID: %s", err.Error())
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, sceneCount)
|
||||
|
||||
sceneCount, err = sqb.CountByGroupID(ctx, 0)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("error calling CountByGroupID: %s", err.Error())
|
||||
}
|
||||
|
||||
assert.Equal(t, 0, sceneCount)
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func TestSceneCountByStudioID(t *testing.T) {
|
||||
withTxn(func(ctx context.Context) error {
|
||||
sqb := db.Scene
|
||||
|
||||
sceneCount, err := sqb.CountByStudioID(ctx, studioIDs[studioIdxWithScene])
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("error calling CountByStudioID: %s", err.Error())
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, sceneCount)
|
||||
|
||||
sceneCount, err = sqb.CountByStudioID(ctx, 0)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("error calling CountByStudioID: %s", err.Error())
|
||||
}
|
||||
|
||||
assert.Equal(t, 0, sceneCount)
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func TestFindByMovieID(t *testing.T) {
|
||||
withTxn(func(ctx context.Context) error {
|
||||
sqb := db.Scene
|
||||
|
||||
@@ -78,6 +78,7 @@ const (
|
||||
sceneIdxWithGrandChildStudio
|
||||
sceneIdxMissingPhash
|
||||
sceneIdxWithPerformerParentTag
|
||||
sceneIdxWithGroupWithParent
|
||||
// new indexes above
|
||||
lastSceneIdx
|
||||
|
||||
@@ -153,9 +154,15 @@ const (
|
||||
groupIdxWithTag
|
||||
groupIdxWithTwoTags
|
||||
groupIdxWithThreeTags
|
||||
groupIdxWithGrandChild
|
||||
groupIdxWithChild
|
||||
groupIdxWithParentAndChild
|
||||
groupIdxWithParent
|
||||
groupIdxWithGrandParent
|
||||
groupIdxWithParentAndScene
|
||||
groupIdxWithChildWithScene
|
||||
// groups with dup names start from the end
|
||||
// create 7 more basic groups (can remove this if we add more indexes)
|
||||
groupIdxWithDupName = groupIdxWithStudio + 7
|
||||
groupIdxWithDupName
|
||||
|
||||
groupsNameCase = groupIdxWithDupName
|
||||
groupsNameNoCase = 1
|
||||
@@ -390,7 +397,8 @@ var (
|
||||
}
|
||||
|
||||
sceneGroups = linkMap{
|
||||
sceneIdxWithGroup: {groupIdxWithScene},
|
||||
sceneIdxWithGroup: {groupIdxWithScene},
|
||||
sceneIdxWithGroupWithParent: {groupIdxWithParentAndScene},
|
||||
}
|
||||
|
||||
sceneStudios = map[int]int{
|
||||
@@ -541,15 +549,31 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
groupParentLinks = [][2]int{
|
||||
{groupIdxWithChild, groupIdxWithParent},
|
||||
{groupIdxWithGrandChild, groupIdxWithParentAndChild},
|
||||
{groupIdxWithParentAndChild, groupIdxWithGrandParent},
|
||||
{groupIdxWithChildWithScene, groupIdxWithParentAndScene},
|
||||
}
|
||||
)
|
||||
|
||||
func indexesToIDs(ids []int, indexes []int) []int {
|
||||
ret := make([]int, len(indexes))
|
||||
for i, idx := range indexes {
|
||||
ret[i] = ids[idx]
|
||||
ret[i] = indexToID(ids, idx)
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func indexToID(ids []int, idx int) int {
|
||||
if idx < 0 {
|
||||
return invalidID
|
||||
}
|
||||
return ids[idx]
|
||||
}
|
||||
|
||||
func indexFromID(ids []int, id int) int {
|
||||
for i, v := range ids {
|
||||
if v == id {
|
||||
@@ -697,6 +721,10 @@ func populateDB() error {
|
||||
return fmt.Errorf("error linking tags parent: %s", err.Error())
|
||||
}
|
||||
|
||||
if err := linkGroupsParent(ctx, db.Group); err != nil {
|
||||
return fmt.Errorf("error linking tags parent: %s", err.Error())
|
||||
}
|
||||
|
||||
for _, ms := range markerSpecs {
|
||||
if err := createMarker(ctx, db.SceneMarker, ms); err != nil {
|
||||
return fmt.Errorf("error creating scene marker: %s", err.Error())
|
||||
@@ -1885,6 +1913,24 @@ func linkTagsParent(ctx context.Context, qb models.TagReaderWriter) error {
|
||||
})
|
||||
}
|
||||
|
||||
func linkGroupsParent(ctx context.Context, qb models.GroupReaderWriter) error {
|
||||
return doLinks(groupParentLinks, func(parentIndex, childIndex int) error {
|
||||
groupID := groupIDs[childIndex]
|
||||
|
||||
p := models.GroupPartial{
|
||||
ContainingGroups: &models.UpdateGroupDescriptions{
|
||||
Groups: []models.GroupIDDescription{
|
||||
{GroupID: groupIDs[parentIndex]},
|
||||
},
|
||||
Mode: models.RelationshipUpdateModeAdd,
|
||||
},
|
||||
}
|
||||
|
||||
_, err := qb.UpdatePartial(ctx, groupID, p)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func addTagImage(ctx context.Context, qb models.TagWriter, tagIndex int) error {
|
||||
return qb.UpdateImage(ctx, tagIDs[tagIndex], []byte("image"))
|
||||
}
|
||||
|
||||
@@ -37,8 +37,9 @@ var (
|
||||
studiosTagsJoinTable = goqu.T(studiosTagsTable)
|
||||
studiosStashIDsJoinTable = goqu.T("studio_stash_ids")
|
||||
|
||||
groupsURLsJoinTable = goqu.T(groupURLsTable)
|
||||
groupsTagsJoinTable = goqu.T(groupsTagsTable)
|
||||
groupsURLsJoinTable = goqu.T(groupURLsTable)
|
||||
groupsTagsJoinTable = goqu.T(groupsTagsTable)
|
||||
groupRelationsJoinTable = goqu.T(groupRelationsTable)
|
||||
|
||||
tagsAliasesJoinTable = goqu.T(tagAliasesTable)
|
||||
tagRelationsJoinTable = goqu.T(tagRelationsTable)
|
||||
@@ -361,6 +362,10 @@ var (
|
||||
foreignTable: tagTableMgr,
|
||||
orderBy: tagTableMgr.table.Col("name").Asc(),
|
||||
}
|
||||
|
||||
groupRelationshipTableMgr = &table{
|
||||
table: groupRelationsJoinTable,
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
|
||||
@@ -2,7 +2,6 @@ package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
)
|
||||
@@ -51,6 +50,14 @@ func (qb *tagFilterHandler) handle(ctx context.Context, f *filterBuilder) {
|
||||
f.handleCriterion(ctx, qb.criterionHandler())
|
||||
}
|
||||
|
||||
var tagHierarchyHandler = hierarchicalRelationshipHandler{
|
||||
primaryTable: tagTable,
|
||||
relationTable: tagRelationsTable,
|
||||
aliasPrefix: tagTable,
|
||||
parentIDCol: "parent_id",
|
||||
childIDCol: "child_id",
|
||||
}
|
||||
|
||||
func (qb *tagFilterHandler) criterionHandler() criterionHandler {
|
||||
tagFilter := qb.tagFilter
|
||||
return compoundHandler{
|
||||
@@ -72,10 +79,10 @@ func (qb *tagFilterHandler) criterionHandler() criterionHandler {
|
||||
qb.groupCountCriterionHandler(tagFilter.MovieCount),
|
||||
|
||||
qb.markerCountCriterionHandler(tagFilter.MarkerCount),
|
||||
qb.parentsCriterionHandler(tagFilter.Parents),
|
||||
qb.childrenCriterionHandler(tagFilter.Children),
|
||||
qb.parentCountCriterionHandler(tagFilter.ParentCount),
|
||||
qb.childCountCriterionHandler(tagFilter.ChildCount),
|
||||
tagHierarchyHandler.ParentsCriterionHandler(tagFilter.Parents),
|
||||
tagHierarchyHandler.ChildrenCriterionHandler(tagFilter.Children),
|
||||
tagHierarchyHandler.ParentCountCriterionHandler(tagFilter.ParentCount),
|
||||
tagHierarchyHandler.ChildCountCriterionHandler(tagFilter.ChildCount),
|
||||
×tampCriterionHandler{tagFilter.CreatedAt, "tags.created_at", nil},
|
||||
×tampCriterionHandler{tagFilter.UpdatedAt, "tags.updated_at", nil},
|
||||
|
||||
@@ -212,213 +219,3 @@ func (qb *tagFilterHandler) markerCountCriterionHandler(markerCount *models.IntC
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (qb *tagFilterHandler) parentsCriterionHandler(criterion *models.HierarchicalMultiCriterionInput) criterionHandlerFunc {
|
||||
return func(ctx context.Context, f *filterBuilder) {
|
||||
if criterion != nil {
|
||||
tags := criterion.CombineExcludes()
|
||||
|
||||
// validate the modifier
|
||||
switch tags.Modifier {
|
||||
case models.CriterionModifierIncludesAll, models.CriterionModifierIncludes, models.CriterionModifierExcludes, models.CriterionModifierIsNull, models.CriterionModifierNotNull:
|
||||
// valid
|
||||
default:
|
||||
f.setError(fmt.Errorf("invalid modifier %s for tag parent/children", criterion.Modifier))
|
||||
}
|
||||
|
||||
if tags.Modifier == models.CriterionModifierIsNull || tags.Modifier == models.CriterionModifierNotNull {
|
||||
var notClause string
|
||||
if tags.Modifier == models.CriterionModifierNotNull {
|
||||
notClause = "NOT"
|
||||
}
|
||||
|
||||
f.addLeftJoin("tags_relations", "parent_relations", "tags.id = parent_relations.child_id")
|
||||
|
||||
f.addWhere(fmt.Sprintf("parent_relations.parent_id IS %s NULL", notClause))
|
||||
return
|
||||
}
|
||||
|
||||
if len(tags.Value) == 0 && len(tags.Excludes) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if len(tags.Value) > 0 {
|
||||
var args []interface{}
|
||||
for _, val := range tags.Value {
|
||||
args = append(args, val)
|
||||
}
|
||||
|
||||
depthVal := 0
|
||||
if tags.Depth != nil {
|
||||
depthVal = *tags.Depth
|
||||
}
|
||||
|
||||
var depthCondition string
|
||||
if depthVal != -1 {
|
||||
depthCondition = fmt.Sprintf("WHERE depth < %d", depthVal)
|
||||
}
|
||||
|
||||
query := `parents AS (
|
||||
SELECT parent_id AS root_id, child_id AS item_id, 0 AS depth FROM tags_relations WHERE parent_id IN` + getInBinding(len(tags.Value)) + `
|
||||
UNION
|
||||
SELECT root_id, child_id, depth + 1 FROM tags_relations INNER JOIN parents ON item_id = parent_id ` + depthCondition + `
|
||||
)`
|
||||
|
||||
f.addRecursiveWith(query, args...)
|
||||
|
||||
f.addLeftJoin("parents", "", "parents.item_id = tags.id")
|
||||
|
||||
addHierarchicalConditionClauses(f, tags, "parents", "root_id")
|
||||
}
|
||||
|
||||
if len(tags.Excludes) > 0 {
|
||||
var args []interface{}
|
||||
for _, val := range tags.Excludes {
|
||||
args = append(args, val)
|
||||
}
|
||||
|
||||
depthVal := 0
|
||||
if tags.Depth != nil {
|
||||
depthVal = *tags.Depth
|
||||
}
|
||||
|
||||
var depthCondition string
|
||||
if depthVal != -1 {
|
||||
depthCondition = fmt.Sprintf("WHERE depth < %d", depthVal)
|
||||
}
|
||||
|
||||
query := `parents2 AS (
|
||||
SELECT parent_id AS root_id, child_id AS item_id, 0 AS depth FROM tags_relations WHERE parent_id IN` + getInBinding(len(tags.Excludes)) + `
|
||||
UNION
|
||||
SELECT root_id, child_id, depth + 1 FROM tags_relations INNER JOIN parents2 ON item_id = parent_id ` + depthCondition + `
|
||||
)`
|
||||
|
||||
f.addRecursiveWith(query, args...)
|
||||
|
||||
f.addLeftJoin("parents2", "", "parents2.item_id = tags.id")
|
||||
|
||||
addHierarchicalConditionClauses(f, models.HierarchicalMultiCriterionInput{
|
||||
Value: tags.Excludes,
|
||||
Depth: tags.Depth,
|
||||
Modifier: models.CriterionModifierExcludes,
|
||||
}, "parents2", "root_id")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (qb *tagFilterHandler) childrenCriterionHandler(criterion *models.HierarchicalMultiCriterionInput) criterionHandlerFunc {
|
||||
return func(ctx context.Context, f *filterBuilder) {
|
||||
if criterion != nil {
|
||||
tags := criterion.CombineExcludes()
|
||||
|
||||
// validate the modifier
|
||||
switch tags.Modifier {
|
||||
case models.CriterionModifierIncludesAll, models.CriterionModifierIncludes, models.CriterionModifierExcludes, models.CriterionModifierIsNull, models.CriterionModifierNotNull:
|
||||
// valid
|
||||
default:
|
||||
f.setError(fmt.Errorf("invalid modifier %s for tag parent/children", criterion.Modifier))
|
||||
}
|
||||
|
||||
if tags.Modifier == models.CriterionModifierIsNull || tags.Modifier == models.CriterionModifierNotNull {
|
||||
var notClause string
|
||||
if tags.Modifier == models.CriterionModifierNotNull {
|
||||
notClause = "NOT"
|
||||
}
|
||||
|
||||
f.addLeftJoin("tags_relations", "child_relations", "tags.id = child_relations.parent_id")
|
||||
|
||||
f.addWhere(fmt.Sprintf("child_relations.child_id IS %s NULL", notClause))
|
||||
return
|
||||
}
|
||||
|
||||
if len(tags.Value) == 0 && len(tags.Excludes) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if len(tags.Value) > 0 {
|
||||
var args []interface{}
|
||||
for _, val := range tags.Value {
|
||||
args = append(args, val)
|
||||
}
|
||||
|
||||
depthVal := 0
|
||||
if tags.Depth != nil {
|
||||
depthVal = *tags.Depth
|
||||
}
|
||||
|
||||
var depthCondition string
|
||||
if depthVal != -1 {
|
||||
depthCondition = fmt.Sprintf("WHERE depth < %d", depthVal)
|
||||
}
|
||||
|
||||
query := `children AS (
|
||||
SELECT child_id AS root_id, parent_id AS item_id, 0 AS depth FROM tags_relations WHERE child_id IN` + getInBinding(len(tags.Value)) + `
|
||||
UNION
|
||||
SELECT root_id, parent_id, depth + 1 FROM tags_relations INNER JOIN children ON item_id = child_id ` + depthCondition + `
|
||||
)`
|
||||
|
||||
f.addRecursiveWith(query, args...)
|
||||
|
||||
f.addLeftJoin("children", "", "children.item_id = tags.id")
|
||||
|
||||
addHierarchicalConditionClauses(f, tags, "children", "root_id")
|
||||
}
|
||||
|
||||
if len(tags.Excludes) > 0 {
|
||||
var args []interface{}
|
||||
for _, val := range tags.Excludes {
|
||||
args = append(args, val)
|
||||
}
|
||||
|
||||
depthVal := 0
|
||||
if tags.Depth != nil {
|
||||
depthVal = *tags.Depth
|
||||
}
|
||||
|
||||
var depthCondition string
|
||||
if depthVal != -1 {
|
||||
depthCondition = fmt.Sprintf("WHERE depth < %d", depthVal)
|
||||
}
|
||||
|
||||
query := `children2 AS (
|
||||
SELECT child_id AS root_id, parent_id AS item_id, 0 AS depth FROM tags_relations WHERE child_id IN` + getInBinding(len(tags.Excludes)) + `
|
||||
UNION
|
||||
SELECT root_id, parent_id, depth + 1 FROM tags_relations INNER JOIN children2 ON item_id = child_id ` + depthCondition + `
|
||||
)`
|
||||
|
||||
f.addRecursiveWith(query, args...)
|
||||
|
||||
f.addLeftJoin("children2", "", "children2.item_id = tags.id")
|
||||
|
||||
addHierarchicalConditionClauses(f, models.HierarchicalMultiCriterionInput{
|
||||
Value: tags.Excludes,
|
||||
Depth: tags.Depth,
|
||||
Modifier: models.CriterionModifierExcludes,
|
||||
}, "children2", "root_id")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (qb *tagFilterHandler) parentCountCriterionHandler(parentCount *models.IntCriterionInput) criterionHandlerFunc {
|
||||
return func(ctx context.Context, f *filterBuilder) {
|
||||
if parentCount != nil {
|
||||
f.addLeftJoin("tags_relations", "parents_count", "parents_count.child_id = tags.id")
|
||||
clause, args := getIntCriterionWhereClause("count(distinct parents_count.parent_id)", *parentCount)
|
||||
|
||||
f.addHaving(clause, args...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (qb *tagFilterHandler) childCountCriterionHandler(childCount *models.IntCriterionInput) criterionHandlerFunc {
|
||||
return func(ctx context.Context, f *filterBuilder) {
|
||||
if childCount != nil {
|
||||
f.addLeftJoin("tags_relations", "children_count", "children_count.parent_id = tags.id")
|
||||
clause, args := getIntCriterionWhereClause("count(distinct children_count.child_id)", *childCount)
|
||||
|
||||
f.addHaving(clause, args...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -712,7 +712,7 @@ func TestTagQueryParent(t *testing.T) {
|
||||
assert.Len(t, tags, 1)
|
||||
|
||||
// ensure id is correct
|
||||
assert.Equal(t, sceneIDs[tagIdxWithParentTag], tags[0].ID)
|
||||
assert.Equal(t, tagIDs[tagIdxWithParentTag], tags[0].ID)
|
||||
|
||||
tagCriterion.Modifier = models.CriterionModifierExcludes
|
||||
|
||||
|
||||
Reference in New Issue
Block a user