Bulk edit tags (#4925)

* Refactor tag relationships and add bulk edit
* Add bulk edit tags dialog
This commit is contained in:
WithoutPants
2024-06-11 13:41:20 +10:00
committed by GitHub
parent e18c050fb1
commit 2d483f2d11
17 changed files with 744 additions and 172 deletions

View File

@@ -450,6 +450,29 @@ func (_m *TagReaderWriter) GetAliases(ctx context.Context, relatedID int) ([]str
return r0, r1
}
// GetChildIDs provides a mock function with given fields: ctx, relatedID
func (_m *TagReaderWriter) GetChildIDs(ctx context.Context, relatedID int) ([]int, error) {
ret := _m.Called(ctx, relatedID)
var r0 []int
if rf, ok := ret.Get(0).(func(context.Context, int) []int); ok {
r0 = rf(ctx, relatedID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]int)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
r1 = rf(ctx, relatedID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetImage provides a mock function with given fields: ctx, tagID
func (_m *TagReaderWriter) GetImage(ctx context.Context, tagID int) ([]byte, error) {
ret := _m.Called(ctx, tagID)
@@ -473,6 +496,29 @@ func (_m *TagReaderWriter) GetImage(ctx context.Context, tagID int) ([]byte, err
return r0, r1
}
// GetParentIDs provides a mock function with given fields: ctx, relatedID
func (_m *TagReaderWriter) GetParentIDs(ctx context.Context, relatedID int) ([]int, error) {
ret := _m.Called(ctx, relatedID)
var r0 []int
if rf, ok := ret.Get(0).(func(context.Context, int) []int); ok {
r0 = rf(ctx, relatedID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]int)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
r1 = rf(ctx, relatedID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// HasImage provides a mock function with given fields: ctx, tagID
func (_m *TagReaderWriter) HasImage(ctx context.Context, tagID int) (bool, error) {
ret := _m.Called(ctx, tagID)

View File

@@ -1,6 +1,7 @@
package models
import (
"context"
"time"
)
@@ -12,6 +13,10 @@ type Tag struct {
IgnoreAutoTag bool `json:"ignore_auto_tag"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
Aliases RelatedStrings `json:"aliases"`
ParentIDs RelatedIDs `json:"parent_ids"`
ChildIDs RelatedIDs `json:"tag_ids"`
}
func NewTag() Tag {
@@ -22,6 +27,24 @@ func NewTag() Tag {
}
}
func (s *Tag) LoadAliases(ctx context.Context, l AliasLoader) error {
return s.Aliases.load(func() ([]string, error) {
return l.GetAliases(ctx, s.ID)
})
}
func (s *Tag) LoadParentIDs(ctx context.Context, l TagRelationLoader) error {
return s.ParentIDs.load(func() ([]int, error) {
return l.GetParentIDs(ctx, s.ID)
})
}
func (s *Tag) LoadChildIDs(ctx context.Context, l TagRelationLoader) error {
return s.ChildIDs.load(func() ([]int, error) {
return l.GetChildIDs(ctx, s.ID)
})
}
type TagPartial struct {
Name OptionalString
Description OptionalString
@@ -29,6 +52,10 @@ type TagPartial struct {
IgnoreAutoTag OptionalBool
CreatedAt OptionalTime
UpdatedAt OptionalTime
Aliases *UpdateStrings
ParentIDs *UpdateIDs
ChildIDs *UpdateIDs
}
func NewTagPartial() TagPartial {

View File

@@ -24,6 +24,11 @@ type TagIDLoader interface {
GetTagIDs(ctx context.Context, relatedID int) ([]int, error)
}
type TagRelationLoader interface {
GetParentIDs(ctx context.Context, relatedID int) ([]int, error)
GetChildIDs(ctx context.Context, relatedID int) ([]int, error)
}
type FileIDLoader interface {
GetManyFileIDs(ctx context.Context, ids []int) ([][]FileID, error)
}

View File

@@ -84,6 +84,7 @@ type TagReader interface {
TagCounter
AliasLoader
TagRelationLoader
All(ctx context.Context) ([]*Tag, error)
GetImage(ctx context.Context, tagID int) ([]byte, error)

View File

@@ -36,6 +36,9 @@ var (
studiosStashIDsJoinTable = goqu.T("studio_stash_ids")
moviesURLsJoinTable = goqu.T(movieURLsTable)
tagsAliasesJoinTable = goqu.T(tagAliasesTable)
tagRelationsJoinTable = goqu.T(tagRelationsTable)
)
var (
@@ -294,6 +297,24 @@ var (
table: goqu.T(tagTable),
idColumn: goqu.T(tagTable).Col(idColumn),
}
tagsAliasesTableMgr = &stringTable{
table: table{
table: tagsAliasesJoinTable,
idColumn: tagsAliasesJoinTable.Col(tagIDColumn),
},
stringColumn: tagsAliasesJoinTable.Col(tagAliasColumn),
}
tagsParentTagsTableMgr = &joinTable{
table: table{
table: tagRelationsJoinTable,
idColumn: tagRelationsJoinTable.Col(tagChildIDColumn),
},
fkColumn: tagRelationsJoinTable.Col(tagParentIDColumn),
}
tagsChildTagsTableMgr = *tagsParentTagsTableMgr.invert()
)
var (

View File

@@ -24,6 +24,10 @@ const (
tagAliasColumn = "alias"
tagImageBlobColumn = "image_blob"
tagRelationsTable = "tags_relations"
tagParentIDColumn = "parent_id"
tagChildIDColumn = "child_id"
)
type tagRow struct {
@@ -173,6 +177,24 @@ func (qb *TagStore) Create(ctx context.Context, newObject *models.Tag) error {
return err
}
if newObject.Aliases.Loaded() {
if err := tagsAliasesTableMgr.insertJoins(ctx, id, newObject.Aliases.List()); err != nil {
return err
}
}
if newObject.ParentIDs.Loaded() {
if err := tagsParentTagsTableMgr.insertJoins(ctx, id, newObject.ParentIDs.List()); err != nil {
return err
}
}
if newObject.ChildIDs.Loaded() {
if err := tagsChildTagsTableMgr.insertJoins(ctx, id, newObject.ChildIDs.List()); err != nil {
return err
}
}
updated, err := qb.find(ctx, id)
if err != nil {
return fmt.Errorf("finding after create: %w", err)
@@ -198,6 +220,24 @@ func (qb *TagStore) UpdatePartial(ctx context.Context, id int, partial models.Ta
}
}
if partial.Aliases != nil {
if err := tagsAliasesTableMgr.modifyJoins(ctx, id, partial.Aliases.Values, partial.Aliases.Mode); err != nil {
return nil, err
}
}
if partial.ParentIDs != nil {
if err := tagsParentTagsTableMgr.modifyJoins(ctx, id, partial.ParentIDs.IDs, partial.ParentIDs.Mode); err != nil {
return nil, err
}
}
if partial.ChildIDs != nil {
if err := tagsChildTagsTableMgr.modifyJoins(ctx, id, partial.ChildIDs.IDs, partial.ChildIDs.Mode); err != nil {
return nil, err
}
}
return qb.find(ctx, id)
}
@@ -209,6 +249,24 @@ func (qb *TagStore) Update(ctx context.Context, updatedObject *models.Tag) error
return err
}
if updatedObject.Aliases.Loaded() {
if err := tagsAliasesTableMgr.replaceJoins(ctx, updatedObject.ID, updatedObject.Aliases.List()); err != nil {
return err
}
}
if updatedObject.ParentIDs.Loaded() {
if err := tagsParentTagsTableMgr.replaceJoins(ctx, updatedObject.ID, updatedObject.ParentIDs.List()); err != nil {
return err
}
}
if updatedObject.ChildIDs.Loaded() {
if err := tagsChildTagsTableMgr.replaceJoins(ctx, updatedObject.ID, updatedObject.ChildIDs.List()); err != nil {
return err
}
}
return nil
}
@@ -423,6 +481,14 @@ func (qb *TagStore) FindByNames(ctx context.Context, names []string, nocase bool
return ret, nil
}
func (qb *TagStore) GetParentIDs(ctx context.Context, relatedID int) ([]int, error) {
return tagsParentTagsTableMgr.get(ctx, relatedID)
}
func (qb *TagStore) GetChildIDs(ctx context.Context, relatedID int) ([]int, error) {
return tagsChildTagsTableMgr.get(ctx, relatedID)
}
func (qb *TagStore) FindByParentTagID(ctx context.Context, parentID int) ([]*models.Tag, error) {
query := `
SELECT tags.* FROM tags

View File

@@ -33,6 +33,10 @@ type InvalidTagHierarchyError struct {
}
func (e *InvalidTagHierarchyError) Error() string {
if e.ApplyingTag == "" {
return fmt.Sprintf("cannot apply tag \"%s\" as a %s of tag as it is already %s", e.InvalidTag, e.Direction, e.CurrentRelation)
}
return fmt.Sprintf("cannot apply tag \"%s\" as a %s of \"%s\" as it is already %s (%s)", e.InvalidTag, e.Direction, e.ApplyingTag, e.CurrentRelation, e.TagPath)
}
@@ -80,16 +84,83 @@ func EnsureAliasesUnique(ctx context.Context, id int, aliases []string, qb model
type RelationshipFinder interface {
FindAllAncestors(ctx context.Context, tagID int, excludeIDs []int) ([]*models.TagPath, error)
FindAllDescendants(ctx context.Context, tagID int, excludeIDs []int) ([]*models.TagPath, error)
FindByChildTagID(ctx context.Context, childID int) ([]*models.Tag, error)
FindByParentTagID(ctx context.Context, parentID int) ([]*models.Tag, error)
models.TagRelationLoader
}
func ValidateHierarchy(ctx context.Context, tag *models.Tag, parentIDs, childIDs []int, qb RelationshipFinder) error {
id := tag.ID
func ValidateHierarchyNew(ctx context.Context, parentIDs, childIDs []int, qb RelationshipFinder) error {
allAncestors := make(map[int]*models.TagPath)
allDescendants := make(map[int]*models.TagPath)
parentsAncestors, err := qb.FindAllAncestors(ctx, id, nil)
for _, parentID := range parentIDs {
parentsAncestors, err := qb.FindAllAncestors(ctx, parentID, nil)
if err != nil {
return err
}
for _, ancestorTag := range parentsAncestors {
allAncestors[ancestorTag.ID] = ancestorTag
}
}
for _, childID := range childIDs {
childsDescendants, err := qb.FindAllDescendants(ctx, childID, nil)
if err != nil {
return err
}
for _, descendentTag := range childsDescendants {
allDescendants[descendentTag.ID] = descendentTag
}
}
// Validate that the tag is not a parent of any of its ancestors
validateParent := func(testID int) error {
if parentTag, exists := allDescendants[testID]; exists {
return &InvalidTagHierarchyError{
Direction: "parent",
CurrentRelation: "a descendant",
InvalidTag: parentTag.Name,
TagPath: parentTag.Path,
}
}
return nil
}
// Validate that the tag is not a child of any of its ancestors
validateChild := func(testID int) error {
if childTag, exists := allAncestors[testID]; exists {
return &InvalidTagHierarchyError{
Direction: "child",
CurrentRelation: "an ancestor",
InvalidTag: childTag.Name,
TagPath: childTag.Path,
}
}
return nil
}
for _, parentID := range parentIDs {
if err := validateParent(parentID); err != nil {
return err
}
}
for _, childID := range childIDs {
if err := validateChild(childID); err != nil {
return err
}
}
return nil
}
func ValidateHierarchyExisting(ctx context.Context, tag *models.Tag, parentIDs, childIDs []int, qb RelationshipFinder) error {
allAncestors := make(map[int]*models.TagPath)
allDescendants := make(map[int]*models.TagPath)
parentsAncestors, err := qb.FindAllAncestors(ctx, tag.ID, nil)
if err != nil {
return err
}
@@ -98,7 +169,7 @@ func ValidateHierarchy(ctx context.Context, tag *models.Tag, parentIDs, childIDs
allAncestors[ancestorTag.ID] = ancestorTag
}
childsDescendants, err := qb.FindAllDescendants(ctx, id, nil)
childsDescendants, err := qb.FindAllDescendants(ctx, tag.ID, nil)
if err != nil {
return err
}
@@ -135,28 +206,6 @@ func ValidateHierarchy(ctx context.Context, tag *models.Tag, parentIDs, childIDs
return nil
}
if parentIDs == nil {
parentTags, err := qb.FindByChildTagID(ctx, id)
if err != nil {
return err
}
for _, parentTag := range parentTags {
parentIDs = append(parentIDs, parentTag.ID)
}
}
if childIDs == nil {
childTags, err := qb.FindByParentTagID(ctx, id)
if err != nil {
return err
}
for _, childTag := range childTags {
childIDs = append(childIDs, childTag.ID)
}
}
for _, parentID := range parentIDs {
if err := validateParent(parentID); err != nil {
return err
@@ -176,38 +225,38 @@ func MergeHierarchy(ctx context.Context, destination int, sources []int, qb Rela
var mergedParents, mergedChildren []int
allIds := append([]int{destination}, sources...)
addTo := func(mergedItems []int, tags []*models.Tag) []int {
addTo := func(mergedItems []int, tagIDs []int) []int {
Tags:
for _, tag := range tags {
for _, tagID := range tagIDs {
// Ignore tags which are already set
for _, existingItem := range mergedItems {
if tag.ID == existingItem {
if tagID == existingItem {
continue Tags
}
}
// Ignore tags which are being merged, as these are rolled up anyway (if A is merged into B any direct link between them can be ignored)
for _, id := range allIds {
if tag.ID == id {
if tagID == id {
continue Tags
}
}
mergedItems = append(mergedItems, tag.ID)
mergedItems = append(mergedItems, tagID)
}
return mergedItems
}
for _, id := range allIds {
parents, err := qb.FindByChildTagID(ctx, id)
parents, err := qb.GetParentIDs(ctx, id)
if err != nil {
return nil, nil, err
}
mergedParents = addTo(mergedParents, parents)
children, err := qb.FindByParentTagID(ctx, id)
children, err := qb.GetChildIDs(ctx, id)
if err != nil {
return nil, nil, err
}

View File

@@ -211,14 +211,11 @@ var testUniqueHierarchyCases = []testUniqueHierarchyCase{
func TestEnsureHierarchy(t *testing.T) {
for _, tc := range testUniqueHierarchyCases {
testEnsureHierarchy(t, tc, false, false)
testEnsureHierarchy(t, tc, true, false)
testEnsureHierarchy(t, tc, false, true)
testEnsureHierarchy(t, tc, true, true)
testEnsureHierarchy(t, tc)
}
}
func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents, queryChildren bool) {
func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase) {
db := mocks.NewDatabase()
var parentIDs, childIDs []int
@@ -244,16 +241,6 @@ func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents,
}
}
if queryParents {
parentIDs = nil
db.Tag.On("FindByChildTagID", testCtx, tc.id).Return(tc.parents, nil).Once()
}
if queryChildren {
childIDs = nil
db.Tag.On("FindByParentTagID", testCtx, tc.id).Return(tc.children, nil).Once()
}
db.Tag.On("FindAllAncestors", testCtx, mock.AnythingOfType("int"), []int(nil)).Return(func(ctx context.Context, tagID int, excludeIDs []int) []*models.TagPath {
return tc.onFindAllAncestors
}, func(ctx context.Context, tagID int, excludeIDs []int) error {
@@ -272,7 +259,7 @@ func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents,
return fmt.Errorf("undefined descendants for: %d", tagID)
}).Maybe()
res := ValidateHierarchy(testCtx, testUniqueHierarchyTags[tc.id], parentIDs, childIDs, db.Tag)
res := ValidateHierarchyExisting(testCtx, testUniqueHierarchyTags[tc.id], parentIDs, childIDs, db.Tag)
assert := assert.New(t)

102
pkg/tag/validate.go Normal file
View File

@@ -0,0 +1,102 @@
package tag
import (
"context"
"errors"
"fmt"
"github.com/stashapp/stash/pkg/models"
)
var (
ErrNameMissing = errors.New("tag name must not be blank")
)
type NotFoundError struct {
id int
}
func (e *NotFoundError) Error() string {
return fmt.Sprintf("tag with id %d not found", e.id)
}
func ValidateCreate(ctx context.Context, tag models.Tag, qb models.TagReader) error {
if tag.Name == "" {
return ErrNameMissing
}
if err := EnsureTagNameUnique(ctx, 0, tag.Name, qb); err != nil {
return err
}
if tag.Aliases.Loaded() {
if err := EnsureAliasesUnique(ctx, tag.ID, tag.Aliases.List(), qb); err != nil {
return err
}
}
if len(tag.ParentIDs.List()) > 0 || len(tag.ChildIDs.List()) > 0 {
if err := ValidateHierarchyNew(ctx, tag.ParentIDs.List(), tag.ChildIDs.List(), qb); err != nil {
return err
}
}
return nil
}
func ValidateUpdate(ctx context.Context, id int, partial models.TagPartial, qb models.TagReader) error {
existing, err := qb.Find(ctx, id)
if err != nil {
return err
}
if existing == nil {
return &NotFoundError{id}
}
if partial.Name.Set {
if partial.Name.Value == "" {
return ErrNameMissing
}
if err := EnsureTagNameUnique(ctx, id, partial.Name.Value, qb); err != nil {
return err
}
}
if partial.Aliases != nil {
if err := existing.LoadAliases(ctx, qb); err != nil {
return err
}
if err := EnsureAliasesUnique(ctx, id, partial.Aliases.Apply(existing.Aliases.List()), qb); err != nil {
return err
}
}
if partial.ParentIDs != nil || partial.ChildIDs != nil {
if err := existing.LoadParentIDs(ctx, qb); err != nil {
return err
}
if err := existing.LoadChildIDs(ctx, qb); err != nil {
return err
}
parentIDs := partial.ParentIDs
if parentIDs == nil {
parentIDs = &models.UpdateIDs{IDs: existing.ParentIDs.List(), Mode: models.RelationshipUpdateModeSet}
}
childIDs := partial.ChildIDs
if childIDs == nil {
childIDs = &models.UpdateIDs{IDs: existing.ChildIDs.List(), Mode: models.RelationshipUpdateModeSet}
}
if err := ValidateHierarchyExisting(ctx, existing, parentIDs.Apply(existing.ParentIDs.List()), childIDs.Apply(existing.ChildIDs.List()), qb); err != nil {
return err
}
}
return nil
}