Add tags to studios (#4858)

* Fix makeTagFilter mode

* Remove studio_tags filter criterion

This is handled by studios_filter. The support for this still needs to be added in the UI, so I have removed the criterion options in the short-term.
---------
Co-authored-by: WithoutPants <53250216+WithoutPants@users.noreply.github.com>
This commit is contained in:
bob123491234
2024-06-18 00:55:20 -05:00
committed by GitHub
parent f26766033e
commit b3d35dfae4
51 changed files with 844 additions and 13 deletions

View File

@@ -22,6 +22,7 @@ type Studio struct {
Details string `json:"details,omitempty"`
Aliases []string `json:"aliases,omitempty"`
StashIDs []models.StashID `json:"stash_ids,omitempty"`
Tags []string `json:"tags,omitempty"`
IgnoreAutoTag bool `json:"ignore_auto_tag,omitempty"`
}

View File

@@ -58,6 +58,27 @@ func (_m *StudioReaderWriter) Count(ctx context.Context) (int, error) {
return r0, r1
}
// CountByTagID provides a mock function with given fields: ctx, tagID
func (_m *StudioReaderWriter) CountByTagID(ctx context.Context, tagID int) (int, error) {
ret := _m.Called(ctx, tagID)
var r0 int
if rf, ok := ret.Get(0).(func(context.Context, int) int); ok {
r0 = rf(ctx, tagID)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
r1 = rf(ctx, tagID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Create provides a mock function with given fields: ctx, newStudio
func (_m *StudioReaderWriter) Create(ctx context.Context, newStudio *models.Studio) error {
ret := _m.Called(ctx, newStudio)
@@ -316,6 +337,29 @@ func (_m *StudioReaderWriter) GetStashIDs(ctx context.Context, relatedID int) ([
return r0, r1
}
// GetTagIDs provides a mock function with given fields: ctx, relatedID
func (_m *StudioReaderWriter) GetTagIDs(ctx context.Context, relatedID int) ([]int, error) {
ret := _m.Called(ctx, relatedID)
var r0 []int
if rf, ok := ret.Get(0).(func(context.Context, int) []int); ok {
r0 = rf(ctx, relatedID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]int)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
r1 = rf(ctx, relatedID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// HasImage provides a mock function with given fields: ctx, studioID
func (_m *StudioReaderWriter) HasImage(ctx context.Context, studioID int) (bool, error) {
ret := _m.Called(ctx, studioID)
@@ -367,6 +411,27 @@ func (_m *StudioReaderWriter) Query(ctx context.Context, studioFilter *models.St
return r0, r1, r2
}
// QueryCount provides a mock function with given fields: ctx, studioFilter, findFilter
func (_m *StudioReaderWriter) QueryCount(ctx context.Context, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) (int, error) {
ret := _m.Called(ctx, studioFilter, findFilter)
var r0 int
if rf, ok := ret.Get(0).(func(context.Context, *models.StudioFilterType, *models.FindFilterType) int); ok {
r0 = rf(ctx, studioFilter, findFilter)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *models.StudioFilterType, *models.FindFilterType) error); ok {
r1 = rf(ctx, studioFilter, findFilter)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// QueryForAutoTag provides a mock function with given fields: ctx, words
func (_m *StudioReaderWriter) QueryForAutoTag(ctx context.Context, words []string) ([]*models.Studio, error) {
ret := _m.Called(ctx, words)

View File

@@ -427,6 +427,29 @@ func (_m *TagReaderWriter) FindBySceneMarkerID(ctx context.Context, sceneMarkerI
return r0, r1
}
// FindByStudioID provides a mock function with given fields: ctx, studioID
func (_m *TagReaderWriter) FindByStudioID(ctx context.Context, studioID int) ([]*models.Tag, error) {
ret := _m.Called(ctx, studioID)
var r0 []*models.Tag
if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Tag); ok {
r0 = rf(ctx, studioID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Tag)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
r1 = rf(ctx, studioID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// FindMany provides a mock function with given fields: ctx, ids
func (_m *TagReaderWriter) FindMany(ctx context.Context, ids []int) ([]*models.Tag, error) {
ret := _m.Called(ctx, ids)

View File

@@ -19,6 +19,7 @@ type Studio struct {
IgnoreAutoTag bool `json:"ignore_auto_tag"`
Aliases RelatedStrings `json:"aliases"`
TagIDs RelatedIDs `json:"tag_ids"`
StashIDs RelatedStashIDs `json:"stash_ids"`
}
@@ -45,6 +46,7 @@ type StudioPartial struct {
IgnoreAutoTag OptionalBool
Aliases *UpdateStrings
TagIDs *UpdateIDs
StashIDs *UpdateStashIDs
}
@@ -61,6 +63,12 @@ func (s *Studio) LoadAliases(ctx context.Context, l AliasLoader) error {
})
}
func (s *Studio) LoadTagIDs(ctx context.Context, l TagIDLoader) error {
return s.TagIDs.load(func() ([]int, error) {
return l.GetTagIDs(ctx, s.ID)
})
}
func (s *Studio) LoadStashIDs(ctx context.Context, l StashIDLoader) error {
return s.StashIDs.load(func() ([]StashID, error) {
return l.GetStashIDs(ctx, s.ID)
@@ -72,6 +80,10 @@ func (s *Studio) LoadRelationships(ctx context.Context, l PerformerReader) error
return err
}
if err := s.LoadTagIDs(ctx, l); err != nil {
return err
}
if err := s.LoadStashIDs(ctx, l); err != nil {
return err
}

View File

@@ -22,6 +22,7 @@ type StudioFinder interface {
// StudioQueryer provides methods to query studios.
type StudioQueryer interface {
Query(ctx context.Context, studioFilter *StudioFilterType, findFilter *FindFilterType) ([]*Studio, int, error)
QueryCount(ctx context.Context, studioFilter *StudioFilterType, findFilter *FindFilterType) (int, error)
}
type StudioAutoTagQueryer interface {
@@ -36,6 +37,7 @@ type StudioAutoTagQueryer interface {
// StudioCounter provides methods to count studios.
type StudioCounter interface {
Count(ctx context.Context) (int, error)
CountByTagID(ctx context.Context, tagID int) (int, error)
}
// StudioCreator provides methods to create studios.
@@ -74,6 +76,7 @@ type StudioReader interface {
AliasLoader
StashIDLoader
TagIDLoader
All(ctx context.Context) ([]*Studio, error)
GetImage(ctx context.Context, studioID int) ([]byte, error)

View File

@@ -22,6 +22,7 @@ type TagFinder interface {
FindByPerformerID(ctx context.Context, performerID int) ([]*Tag, error)
FindByMovieID(ctx context.Context, movieID int) ([]*Tag, error)
FindBySceneMarkerID(ctx context.Context, sceneMarkerID int) ([]*Tag, error)
FindByStudioID(ctx context.Context, studioID int) ([]*Tag, error)
FindByName(ctx context.Context, name string, nocase bool) (*Tag, error)
FindByNames(ctx context.Context, names []string, nocase bool) ([]*Tag, error)
}

View File

@@ -14,6 +14,10 @@ type StudioFilterType struct {
IsMissing *string `json:"is_missing"`
// Filter by rating expressed as 1-100
Rating100 *IntCriterionInput `json:"rating100"`
// Filter to only include studios with these tags
Tags *HierarchicalMultiCriterionInput `json:"tags"`
// Filter by tag count
TagCount *IntCriterionInput `json:"tag_count"`
// Filter by favorite
Favorite *bool `json:"favorite"`
// Filter by scene count
@@ -53,6 +57,7 @@ type StudioCreateInput struct {
Favorite *bool `json:"favorite"`
Details *string `json:"details"`
Aliases []string `json:"aliases"`
TagIds []string `json:"tag_ids"`
IgnoreAutoTag *bool `json:"ignore_auto_tag"`
}
@@ -68,5 +73,6 @@ type StudioUpdateInput struct {
Favorite *bool `json:"favorite"`
Details *string `json:"details"`
Aliases []string `json:"aliases"`
TagIds []string `json:"tag_ids"`
IgnoreAutoTag *bool `json:"ignore_auto_tag"`
}

View File

@@ -20,6 +20,8 @@ type TagFilterType struct {
GalleryCount *IntCriterionInput `json:"gallery_count"`
// Filter by number of performers with this tag
PerformerCount *IntCriterionInput `json:"performer_count"`
// Filter by number of studios with this tag
StudioCount *IntCriterionInput `json:"studio_count"`
// Filter by number of movies with this tag
MovieCount *IntCriterionInput `json:"movie_count"`
// Filter by number of markers with this tag

View File

@@ -30,7 +30,7 @@ const (
dbConnTimeout = 30
)
var appSchemaVersion uint = 62
var appSchemaVersion uint = 63
//go:embed migrations/*.sql
var migrationsBox embed.FS

View File

@@ -0,0 +1,9 @@
CREATE TABLE `studios_tags` (
`studio_id` integer NOT NULL,
`tag_id` integer NOT NULL,
foreign key(`studio_id`) references `studios`(`id`) on delete CASCADE,
foreign key(`tag_id`) references `tags`(`id`) on delete CASCADE,
PRIMARY KEY(`studio_id`, `tag_id`)
);
CREATE INDEX `index_studios_tags_on_tag_id` on `studios_tags` (`tag_id`);

View File

@@ -207,6 +207,9 @@ const (
tagIdxWithPerformer
tagIdx1WithPerformer
tagIdx2WithPerformer
tagIdxWithStudio
tagIdx1WithStudio
tagIdx2WithStudio
tagIdxWithGallery
tagIdx1WithGallery
tagIdx2WithGallery
@@ -245,6 +248,10 @@ const (
studioIdxWithScenePerformer
studioIdxWithImagePerformer
studioIdxWithGalleryPerformer
studioIdxWithTag
studioIdx2WithTag
studioIdxWithTwoTags
studioIdxWithParentTag
studioIdxWithGrandChild
studioIdxWithParentAndChild
studioIdxWithGrandParent
@@ -510,6 +517,15 @@ var (
}
)
var (
studioTags = linkMap{
studioIdxWithTag: {tagIdxWithStudio},
studioIdx2WithTag: {tagIdx2WithStudio},
studioIdxWithTwoTags: {tagIdx1WithStudio, tagIdx2WithStudio},
studioIdxWithParentTag: {tagIdxWithParentAndChild},
}
)
var (
performerTags = linkMap{
performerIdxWithTag: {tagIdxWithPerformer},
@@ -1566,6 +1582,11 @@ func getTagPerformerCount(id int) int {
return len(performerTags.reverseLookup(idx))
}
func getTagStudioCount(id int) int {
idx := indexFromID(tagIDs, id)
return len(studioTags.reverseLookup(idx))
}
func getTagParentCount(id int) int {
if id == tagIDs[tagIdxWithParentTag] || id == tagIDs[tagIdxWithGrandParent] || id == tagIDs[tagIdxWithParentAndChild] {
return 1
@@ -1681,11 +1702,13 @@ func createStudios(ctx context.Context, n int, o int) error {
// studios [ i ] and [ n + o - i - 1 ] should have similar names with only the Name!=NaMe part different
name = getStudioStringValue(index, name)
tids := indexesToIDs(tagIDs, studioTags[i])
studio := models.Studio{
Name: name,
URL: getStudioStringValue(index, urlField),
Favorite: getStudioBoolValue(index),
IgnoreAutoTag: getIgnoreAutoTag(i),
TagIDs: models.NewRelatedIDs(tids),
}
// only add aliases for some scenes
if i == studioIdxWithMovie || i%5 == 0 {

View File

@@ -25,6 +25,7 @@ const (
studioParentIDColumn = "parent_id"
studioNameColumn = "name"
studioImageBlobColumn = "image_blob"
studiosTagsTable = "studios_tags"
)
type studioRow struct {
@@ -94,6 +95,7 @@ type studioRepositoryType struct {
repository
stashIDs stashIDRepository
tags joinRepository
scenes repository
images repository
@@ -124,11 +126,21 @@ var (
tableName: galleryTable,
idColumn: studioIDColumn,
},
tags: joinRepository{
repository: repository{
tableName: studiosTagsTable,
idColumn: studioIDColumn,
},
fkColumn: tagIDColumn,
foreignTable: tagTable,
orderBy: "tags.name ASC",
},
}
)
type StudioStore struct {
blobJoinQueryBuilder
tagRelationshipStore
tableMgr *table
}
@@ -139,6 +151,11 @@ func NewStudioStore(blobStore *BlobStore) *StudioStore {
blobStore: blobStore,
joinTable: studioTable,
},
tagRelationshipStore: tagRelationshipStore{
idRelationshipStore: idRelationshipStore{
joinTable: studiosTagsTableMgr,
},
},
tableMgr: studioTableMgr,
}
@@ -173,6 +190,10 @@ func (qb *StudioStore) Create(ctx context.Context, newObject *models.Studio) err
}
}
if err := qb.tagRelationshipStore.createRelationships(ctx, id, newObject.TagIDs); err != nil {
return err
}
if newObject.StashIDs.Loaded() {
if err := studiosStashIDsTableMgr.insertJoins(ctx, id, newObject.StashIDs.List()); err != nil {
return err
@@ -213,6 +234,10 @@ func (qb *StudioStore) UpdatePartial(ctx context.Context, input models.StudioPar
}
}
if err := qb.tagRelationshipStore.modifyRelationships(ctx, input.ID, input.TagIDs); err != nil {
return nil, err
}
if input.StashIDs != nil {
if err := studiosStashIDsTableMgr.modifyJoins(ctx, input.ID, input.StashIDs.StashIDs, input.StashIDs.Mode); err != nil {
return nil, err
@@ -237,6 +262,10 @@ func (qb *StudioStore) Update(ctx context.Context, updatedObject *models.Studio)
}
}
if err := qb.tagRelationshipStore.replaceRelationships(ctx, updatedObject.ID, updatedObject.TagIDs); err != nil {
return err
}
if updatedObject.StashIDs.Loaded() {
if err := studiosStashIDsTableMgr.replaceJoins(ctx, updatedObject.ID, updatedObject.StashIDs.List()); err != nil {
return err
@@ -538,6 +567,15 @@ func (qb *StudioStore) Query(ctx context.Context, studioFilter *models.StudioFil
return studios, countResult, nil
}
func (qb *StudioStore) QueryCount(ctx context.Context, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) (int, error) {
query, err := qb.makeQuery(ctx, studioFilter, findFilter)
if err != nil {
return 0, err
}
return query.executeCount(ctx)
}
var studioSortOptions = sortOptions{
"child_count",
"created_at",
@@ -569,6 +607,8 @@ func (qb *StudioStore) getStudioSort(findFilter *models.FindFilterType) (string,
sortQuery := ""
switch sort {
case "tag_count":
sortQuery += getCountSort(studioTable, studiosTagsTable, studioIDColumn, direction)
case "scenes_count":
sortQuery += getCountSort(studioTable, sceneTable, studioIDColumn, direction)
case "images_count":

View File

@@ -74,11 +74,13 @@ func (qb *studioFilterHandler) criterionHandler() criterionHandler {
},
qb.isMissingCriterionHandler(studioFilter.IsMissing),
qb.tagCountCriterionHandler(studioFilter.TagCount),
qb.sceneCountCriterionHandler(studioFilter.SceneCount),
qb.imageCountCriterionHandler(studioFilter.ImageCount),
qb.galleryCountCriterionHandler(studioFilter.GalleryCount),
qb.parentCriterionHandler(studioFilter.Parents),
qb.aliasCriterionHandler(studioFilter.Aliases),
qb.tagsCriterionHandler(studioFilter.Tags),
qb.childCountCriterionHandler(studioFilter.ChildCount),
&timestampCriterionHandler{studioFilter.CreatedAt, studioTable + ".created_at", nil},
&timestampCriterionHandler{studioFilter.UpdatedAt, studioTable + ".updated_at", nil},
@@ -161,6 +163,16 @@ func (qb *studioFilterHandler) galleryCountCriterionHandler(galleryCount *models
}
}
func (qb *studioFilterHandler) tagCountCriterionHandler(tagCount *models.IntCriterionInput) criterionHandlerFunc {
h := countCriterionHandlerBuilder{
primaryTable: studioTable,
joinTable: studiosTagsTable,
primaryFK: studioIDColumn,
}
return h.handler(tagCount)
}
func (qb *studioFilterHandler) parentCriterionHandler(parents *models.MultiCriterionInput) criterionHandlerFunc {
addJoinsFunc := func(f *filterBuilder) {
f.addLeftJoin("studios", "parent_studio", "parent_studio.id = studios.parent_id")
@@ -200,3 +212,18 @@ func (qb *studioFilterHandler) childCountCriterionHandler(childCount *models.Int
}
}
}
func (qb *studioFilterHandler) tagsCriterionHandler(tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc {
h := joinedHierarchicalMultiCriterionHandlerBuilder{
primaryTable: studioTable,
foreignTable: tagTable,
foreignFK: "tag_id",
relationsTable: "tags_relations",
joinTable: studiosTagsTable,
joinAs: "studio_tag",
primaryFK: studioIDColumn,
}
return h.handler(tags)
}

View File

@@ -704,6 +704,110 @@ func TestStudioQueryRating(t *testing.T) {
verifyStudiosRating(t, ratingCriterion)
}
func queryStudios(ctx context.Context, t *testing.T, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) []*models.Studio {
t.Helper()
studios, _, err := db.Studio.Query(ctx, studioFilter, findFilter)
if err != nil {
t.Errorf("Error querying studio: %s", err.Error())
}
return studios
}
func TestStudioQueryTags(t *testing.T) {
withTxn(func(ctx context.Context) error {
tagCriterion := models.HierarchicalMultiCriterionInput{
Value: []string{
strconv.Itoa(tagIDs[tagIdxWithStudio]),
strconv.Itoa(tagIDs[tagIdx1WithStudio]),
},
Modifier: models.CriterionModifierIncludes,
}
studioFilter := models.StudioFilterType{
Tags: &tagCriterion,
}
// ensure ids are correct
studios := queryStudios(ctx, t, &studioFilter, nil)
assert.Len(t, studios, 2)
for _, studio := range studios {
assert.True(t, studio.ID == studioIDs[studioIdxWithTag] || studio.ID == studioIDs[studioIdxWithTwoTags])
}
tagCriterion = models.HierarchicalMultiCriterionInput{
Value: []string{
strconv.Itoa(tagIDs[tagIdx1WithStudio]),
strconv.Itoa(tagIDs[tagIdx2WithStudio]),
},
Modifier: models.CriterionModifierIncludesAll,
}
studios = queryStudios(ctx, t, &studioFilter, nil)
assert.Len(t, studios, 1)
assert.Equal(t, sceneIDs[studioIdxWithTwoTags], studios[0].ID)
tagCriterion = models.HierarchicalMultiCriterionInput{
Value: []string{
strconv.Itoa(tagIDs[tagIdx1WithStudio]),
},
Modifier: models.CriterionModifierExcludes,
}
q := getSceneStringValue(studioIdxWithTwoTags, titleField)
findFilter := models.FindFilterType{
Q: &q,
}
studios = queryStudios(ctx, t, &studioFilter, &findFilter)
assert.Len(t, studios, 0)
return nil
})
}
func TestStudioQueryTagCount(t *testing.T) {
const tagCount = 1
tagCountCriterion := models.IntCriterionInput{
Value: tagCount,
Modifier: models.CriterionModifierEquals,
}
verifyStudiosTagCount(t, tagCountCriterion)
tagCountCriterion.Modifier = models.CriterionModifierNotEquals
verifyStudiosTagCount(t, tagCountCriterion)
tagCountCriterion.Modifier = models.CriterionModifierGreaterThan
verifyStudiosTagCount(t, tagCountCriterion)
tagCountCriterion.Modifier = models.CriterionModifierLessThan
verifyStudiosTagCount(t, tagCountCriterion)
}
func verifyStudiosTagCount(t *testing.T, tagCountCriterion models.IntCriterionInput) {
withTxn(func(ctx context.Context) error {
sqb := db.Studio
studioFilter := models.StudioFilterType{
TagCount: &tagCountCriterion,
}
studios := queryStudios(ctx, t, &studioFilter, nil)
assert.Greater(t, len(studios), 0)
for _, studio := range studios {
ids, err := sqb.GetTagIDs(ctx, studio.ID)
if err != nil {
return err
}
verifyInt(t, len(ids), tagCountCriterion)
}
return nil
})
}
func verifyStudioQuery(t *testing.T, filter models.StudioFilterType, verifyFn func(ctx context.Context, s *models.Studio)) {
withTxn(func(ctx context.Context) error {
t.Helper()

View File

@@ -34,6 +34,7 @@ var (
performersStashIDsJoinTable = goqu.T("performer_stash_ids")
studiosAliasesJoinTable = goqu.T(studioAliasesTable)
studiosTagsJoinTable = goqu.T(studiosTagsTable)
studiosStashIDsJoinTable = goqu.T("studio_stash_ids")
moviesURLsJoinTable = goqu.T(movieURLsTable)
@@ -294,6 +295,14 @@ var (
stringColumn: studiosAliasesJoinTable.Col(studioAliasColumn),
}
studiosTagsTableMgr = &joinTable{
table: table{
table: studiosTagsJoinTable,
idColumn: studiosTagsJoinTable.Col(studioIDColumn),
},
fkColumn: studiosTagsJoinTable.Col(tagIDColumn),
}
studiosStashIDsTableMgr = &stashIDTable{
table: table{
table: studiosStashIDsJoinTable,

View File

@@ -448,6 +448,18 @@ func (qb *TagStore) FindBySceneMarkerID(ctx context.Context, sceneMarkerID int)
return qb.queryTags(ctx, query, args)
}
func (qb *TagStore) FindByStudioID(ctx context.Context, studioID int) ([]*models.Tag, error) {
query := `
SELECT tags.* FROM tags
LEFT JOIN studios_tags as studios_join on studios_join.tag_id = tags.id
WHERE studios_join.studio_id = ?
GROUP BY tags.id
`
query += qb.getDefaultTagSort()
args := []interface{}{studioID}
return qb.queryTags(ctx, query, args)
}
func (qb *TagStore) FindByName(ctx context.Context, name string, nocase bool) (*models.Tag, error) {
// query := "SELECT * FROM tags WHERE name = ?"
// if nocase {
@@ -628,6 +640,7 @@ var tagSortOptions = sortOptions{
"id",
"images_count",
"movies_count",
"studios_count",
"name",
"performers_count",
"random",
@@ -668,6 +681,8 @@ func (qb *TagStore) getTagSort(query *queryBuilder, findFilter *models.FindFilte
sortQuery += getCountSort(tagTable, galleriesTagsTable, tagIDColumn, direction)
case "performers_count":
sortQuery += getCountSort(tagTable, performersTagsTable, tagIDColumn, direction)
case "studios_count":
sortQuery += getCountSort(tagTable, studiosTagsTable, tagIDColumn, direction)
case "movies_count":
sortQuery += getCountSort(tagTable, moviesTagsTable, tagIDColumn, direction)
default:
@@ -767,6 +782,7 @@ func (qb *TagStore) Merge(ctx context.Context, source []int, destination int) er
galleriesTagsTable: galleryIDColumn,
imagesTagsTable: imageIDColumn,
"performers_tags": "performer_id",
"studios_tags": "studio_id",
}
args = append(args, destination)

View File

@@ -66,6 +66,7 @@ func (qb *tagFilterHandler) criterionHandler() criterionHandler {
qb.imageCountCriterionHandler(tagFilter.ImageCount),
qb.galleryCountCriterionHandler(tagFilter.GalleryCount),
qb.performerCountCriterionHandler(tagFilter.PerformerCount),
qb.studioCountCriterionHandler(tagFilter.StudioCount),
qb.movieCountCriterionHandler(tagFilter.MovieCount),
qb.markerCountCriterionHandler(tagFilter.MarkerCount),
qb.parentsCriterionHandler(tagFilter.Parents),
@@ -175,6 +176,17 @@ func (qb *tagFilterHandler) performerCountCriterionHandler(performerCount *model
}
}
func (qb *tagFilterHandler) studioCountCriterionHandler(studioCount *models.IntCriterionInput) criterionHandlerFunc {
return func(ctx context.Context, f *filterBuilder) {
if studioCount != nil {
f.addLeftJoin("studios_tags", "", "studios_tags.tag_id = tags.id")
clause, args := getIntCriterionWhereClause("count(distinct studios_tags.studio_id)", *studioCount)
f.addHaving(clause, args...)
}
}
}
func (qb *tagFilterHandler) movieCountCriterionHandler(movieCount *models.IntCriterionInput) criterionHandlerFunc {
return func(ctx context.Context, f *filterBuilder) {
if movieCount != nil {

View File

@@ -230,6 +230,10 @@ func TestTagQuerySort(t *testing.T) {
tags = queryTags(ctx, t, sqb, nil, findFilter)
assert.Equal(tagIDs[tagIdx2WithPerformer], tags[0].ID)
sortBy = "studios_count"
tags = queryTags(ctx, t, sqb, nil, findFilter)
assert.Equal(tagIDs[tagIdx2WithStudio], tags[0].ID)
sortBy = "movies_count"
tags = queryTags(ctx, t, sqb, nil, findFilter)
assert.Equal(tagIDs[tagIdx1WithMovie], tags[0].ID)
@@ -569,6 +573,45 @@ func verifyTagPerformerCount(t *testing.T, imageCountCriterion models.IntCriteri
})
}
func TestTagQueryStudioCount(t *testing.T) {
countCriterion := models.IntCriterionInput{
Value: 1,
Modifier: models.CriterionModifierEquals,
}
verifyTagStudioCount(t, countCriterion)
countCriterion.Modifier = models.CriterionModifierNotEquals
verifyTagStudioCount(t, countCriterion)
countCriterion.Modifier = models.CriterionModifierLessThan
verifyTagStudioCount(t, countCriterion)
countCriterion.Value = 0
countCriterion.Modifier = models.CriterionModifierGreaterThan
verifyTagStudioCount(t, countCriterion)
}
func verifyTagStudioCount(t *testing.T, imageCountCriterion models.IntCriterionInput) {
withTxn(func(ctx context.Context) error {
qb := db.Tag
tagFilter := models.TagFilterType{
StudioCount: &imageCountCriterion,
}
tags, _, err := qb.Query(ctx, &tagFilter, nil)
if err != nil {
t.Errorf("Error querying tag: %s", err.Error())
}
for _, tag := range tags {
verifyInt(t, getTagStudioCount(tag.ID), imageCountCriterion)
}
return nil
})
}
func TestTagQueryParentCount(t *testing.T) {
countCriterion := models.IntCriterionInput{
Value: 1,
@@ -882,6 +925,9 @@ func TestTagMerge(t *testing.T) {
tagIdxWithPerformer,
tagIdx1WithPerformer,
tagIdx2WithPerformer,
tagIdxWithStudio,
tagIdx1WithStudio,
tagIdx2WithStudio,
tagIdxWithGallery,
tagIdx1WithGallery,
tagIdx2WithGallery,
@@ -970,6 +1016,14 @@ func TestTagMerge(t *testing.T) {
assert.Contains(performerTagIDs, destID)
// ensure studio points to new tag
studioTagIDs, err := db.Studio.GetTagIDs(ctx, studioIDs[studioIdxWithTwoTags])
if err != nil {
return err
}
assert.Contains(studioTagIDs, destID)
return nil
}); err != nil {
t.Error(err.Error())

View File

@@ -68,6 +68,7 @@ func createFullStudio(id int, parentID int) models.Studio {
Rating: &rating,
IgnoreAutoTag: autoTagIgnored,
Aliases: models.NewRelatedStrings(aliases),
TagIDs: models.NewRelatedIDs([]int{}),
StashIDs: models.NewRelatedStashIDs(stashIDs),
}
@@ -84,6 +85,7 @@ func createEmptyStudio(id int) models.Studio {
CreatedAt: createTime,
UpdatedAt: updateTime,
Aliases: models.NewRelatedStrings([]string{}),
TagIDs: models.NewRelatedIDs([]int{}),
StashIDs: models.NewRelatedStashIDs([]models.StashID{}),
}
}

View File

@@ -4,9 +4,11 @@ import (
"context"
"errors"
"fmt"
"strings"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/jsonschema"
"github.com/stashapp/stash/pkg/sliceutil"
"github.com/stashapp/stash/pkg/utils"
)
@@ -19,6 +21,7 @@ var ErrParentStudioNotExist = errors.New("parent studio does not exist")
type Importer struct {
ReaderWriter ImporterReaderWriter
TagWriter models.TagFinderCreator
Input jsonschema.Studio
MissingRefBehaviour models.ImportMissingRefEnum
@@ -34,6 +37,10 @@ func (i *Importer) PreImport(ctx context.Context) error {
return err
}
if err := i.populateTags(ctx); err != nil {
return err
}
var err error
if len(i.Input.Image) > 0 {
i.imageData, err = utils.ProcessBase64Image(i.Input.Image)
@@ -45,6 +52,74 @@ func (i *Importer) PreImport(ctx context.Context) error {
return nil
}
func (i *Importer) populateTags(ctx context.Context) error {
if len(i.Input.Tags) > 0 {
tags, err := importTags(ctx, i.TagWriter, i.Input.Tags, i.MissingRefBehaviour)
if err != nil {
return err
}
for _, p := range tags {
i.studio.TagIDs.Add(p.ID)
}
}
return nil
}
func importTags(ctx context.Context, tagWriter models.TagFinderCreator, names []string, missingRefBehaviour models.ImportMissingRefEnum) ([]*models.Tag, error) {
tags, err := tagWriter.FindByNames(ctx, names, false)
if err != nil {
return nil, err
}
var pluckedNames []string
for _, tag := range tags {
pluckedNames = append(pluckedNames, tag.Name)
}
missingTags := sliceutil.Filter(names, func(name string) bool {
return !sliceutil.Contains(pluckedNames, name)
})
if len(missingTags) > 0 {
if missingRefBehaviour == models.ImportMissingRefEnumFail {
return nil, fmt.Errorf("tags [%s] not found", strings.Join(missingTags, ", "))
}
if missingRefBehaviour == models.ImportMissingRefEnumCreate {
createdTags, err := createTags(ctx, tagWriter, missingTags)
if err != nil {
return nil, fmt.Errorf("error creating tags: %v", err)
}
tags = append(tags, createdTags...)
}
// ignore if MissingRefBehaviour set to Ignore
}
return tags, nil
}
func createTags(ctx context.Context, tagWriter models.TagFinderCreator, names []string) ([]*models.Tag, error) {
var ret []*models.Tag
for _, name := range names {
newTag := models.NewTag()
newTag.Name = name
err := tagWriter.Create(ctx, &newTag)
if err != nil {
return nil, err
}
ret = append(ret, &newTag)
}
return ret, nil
}
func (i *Importer) populateParentStudio(ctx context.Context) error {
if i.Input.ParentStudio != "" {
studio, err := i.ReaderWriter.FindByName(ctx, i.Input.ParentStudio, false)
@@ -149,6 +224,7 @@ func studioJSONtoStudio(studioJSON jsonschema.Studio) models.Studio {
CreatedAt: studioJSON.CreatedAt.GetTime(),
UpdatedAt: studioJSON.UpdatedAt.GetTime(),
TagIDs: models.NewRelatedIDs([]int{}),
StashIDs: models.NewRelatedStashIDs(studioJSON.StashIDs),
}

View File

@@ -16,13 +16,19 @@ const invalidImage = "aW1hZ2VCeXRlcw&&"
const (
studioNameErr = "studioNameErr"
existingStudioName = "existingTagName"
existingStudioName = "existingStudioName"
existingStudioID = 100
existingTagID = 105
errTagsID = 106
existingParentStudioName = "existingParentStudioName"
existingParentStudioErr = "existingParentStudioErr"
missingParentStudioName = "existingParentStudioName"
existingTagName = "existingTagName"
existingTagErr = "existingTagErr"
missingTagName = "missingTagName"
)
var testCtx = context.Background()
@@ -67,6 +73,97 @@ func TestImporterPreImport(t *testing.T) {
assert.Equal(t, expectedStudio, i.studio)
}
func TestImporterPreImportWithTag(t *testing.T) {
db := mocks.NewDatabase()
i := Importer{
ReaderWriter: db.Studio,
TagWriter: db.Tag,
MissingRefBehaviour: models.ImportMissingRefEnumFail,
Input: jsonschema.Studio{
Tags: []string{
existingTagName,
},
},
}
db.Tag.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
{
ID: existingTagID,
Name: existingTagName,
},
}, nil).Once()
db.Tag.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
err := i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, existingTagID, i.studio.TagIDs.List()[0])
i.Input.Tags = []string{existingTagErr}
err = i.PreImport(testCtx)
assert.NotNil(t, err)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingTag(t *testing.T) {
db := mocks.NewDatabase()
i := Importer{
ReaderWriter: db.Studio,
TagWriter: db.Tag,
Input: jsonschema.Studio{
Tags: []string{
missingTagName,
},
},
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
t := args.Get(1).(*models.Tag)
t.ID = existingTagID
}).Return(nil)
err := i.PreImport(testCtx)
assert.NotNil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore
err = i.PreImport(testCtx)
assert.Nil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumCreate
err = i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, existingTagID, i.studio.TagIDs.List()[0])
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
db := mocks.NewDatabase()
i := Importer{
ReaderWriter: db.Studio,
TagWriter: db.Tag,
Input: jsonschema.Studio{
Tags: []string{
missingTagName,
},
},
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)
db.AssertExpectations(t)
}
func TestImporterPreImportWithParent(t *testing.T) {
db := mocks.NewDatabase()
@@ -156,6 +253,7 @@ func TestImporterPostImport(t *testing.T) {
i := Importer{
ReaderWriter: db.Studio,
TagWriter: db.Tag,
Input: jsonschema.Studio{
Aliases: []string{"alias"},
},
@@ -181,6 +279,7 @@ func TestImporterFindExistingID(t *testing.T) {
i := Importer{
ReaderWriter: db.Studio,
TagWriter: db.Tag,
Input: jsonschema.Studio{
Name: studioName,
},
@@ -223,6 +322,7 @@ func TestCreate(t *testing.T) {
i := Importer{
ReaderWriter: db.Studio,
TagWriter: db.Tag,
studio: studio,
}
@@ -258,6 +358,7 @@ func TestUpdate(t *testing.T) {
i := Importer{
ReaderWriter: db.Studio,
TagWriter: db.Tag,
studio: studio,
}

View File

@@ -2,6 +2,7 @@ package studio
import (
"context"
"strconv"
"github.com/stashapp/stash/pkg/models"
)
@@ -53,3 +54,15 @@ func ByAlias(ctx context.Context, qb models.StudioQueryer, alias string) (*model
return nil, nil
}
func CountByTagID(ctx context.Context, qb models.StudioQueryer, id int, depth *int) (int, error) {
filter := &models.StudioFilterType{
Tags: &models.HierarchicalMultiCriterionInput{
Value: []string{strconv.Itoa(id)},
Modifier: models.CriterionModifierIncludes,
Depth: depth,
},
}
return qb.QueryCount(ctx, filter, nil)
}