Studio Tagger (#3510)

* Studio image and parent studio support in scene tagger
* Refactor studio backend and add studio tagger
---------
Co-authored-by: WithoutPants <53250216+WithoutPants@users.noreply.github.com>
This commit is contained in:
Flashy78
2023-07-30 16:50:24 -07:00
committed by GitHub
parent d48dbeb864
commit a665a56ef0
79 changed files with 5224 additions and 1039 deletions

View File

@@ -438,21 +438,6 @@ func (r *stashIDRepository) get(ctx context.Context, id int) ([]models.StashID,
return []models.StashID(ret), err
}
func (r *stashIDRepository) replace(ctx context.Context, id int, newIDs []models.StashID) error {
if err := r.destroy(ctx, []int{id}); err != nil {
return err
}
query := fmt.Sprintf("INSERT INTO %s (%s, endpoint, stash_id) VALUES (?, ?, ?)", r.tableName, r.idColumn)
for _, stashID := range newIDs {
_, err := r.tx.Exec(ctx, query, id, stashID.Endpoint, stashID.StashID)
if err != nil {
return err
}
}
return nil
}
type filesRepository struct {
repository
}

View File

@@ -631,7 +631,7 @@ func populateDB() error {
return fmt.Errorf("error creating performers: %s", err.Error())
}
if err := createStudios(ctx, db.Studio, studiosNameCase, studiosNameNoCase); err != nil {
if err := createStudios(ctx, studiosNameCase, studiosNameNoCase); err != nil {
return fmt.Errorf("error creating studios: %s", err.Error())
}
@@ -659,7 +659,7 @@ func populateDB() error {
return fmt.Errorf("error linking movie studios: %s", err.Error())
}
if err := linkStudiosParent(ctx, db.Studio); err != nil {
if err := linkStudiosParent(ctx); err != nil {
return fmt.Errorf("error linking studios parent: %s", err.Error())
}
@@ -1310,8 +1310,8 @@ func createMovies(ctx context.Context, mqb models.MovieReaderWriter, n int, o in
name = getMovieStringValue(index, name)
movie := models.Movie{
Name: name,
URL: getMovieNullStringValue(index, urlField),
Name: name,
URL: getMovieNullStringValue(index, urlField),
}
err := mqb.Create(ctx, &movie)
@@ -1573,9 +1573,9 @@ func getStudioNullStringValue(index int, field string) string {
return ret.String
}
func createStudio(ctx context.Context, sqb models.StudioReaderWriter, name string, parentID *int) (*models.Studio, error) {
func createStudio(ctx context.Context, sqb *sqlite.StudioStore, name string, parentID *int) (*models.Studio, error) {
studio := models.Studio{
Name: name,
Name: name,
}
if parentID != nil {
@@ -1590,7 +1590,7 @@ func createStudio(ctx context.Context, sqb models.StudioReaderWriter, name strin
return &studio, nil
}
func createStudioFromModel(ctx context.Context, sqb models.StudioReaderWriter, studio *models.Studio) error {
func createStudioFromModel(ctx context.Context, sqb *sqlite.StudioStore, studio *models.Studio) error {
err := sqb.Create(ctx, studio)
if err != nil {
@@ -1601,7 +1601,8 @@ func createStudioFromModel(ctx context.Context, sqb models.StudioReaderWriter, s
}
// createStudios creates n studios with plain Name and o studios with camel cased NaMe included
func createStudios(ctx context.Context, sqb models.StudioReaderWriter, n int, o int) error {
func createStudios(ctx context.Context, n int, o int) error {
sqb := db.Studio
const namePlain = "Name"
const nameNoCase = "NaMe"
@@ -1618,22 +1619,18 @@ func createStudios(ctx context.Context, sqb models.StudioReaderWriter, n int, o
name = getStudioStringValue(index, name)
studio := models.Studio{
Name: name,
URL: getStudioNullStringValue(index, urlField),
URL: getStudioStringValue(index, urlField),
IgnoreAutoTag: getIgnoreAutoTag(i),
}
err := createStudioFromModel(ctx, sqb, &studio)
if err != nil {
return err
}
// add alias
// only add aliases for some scenes
if i == studioIdxWithMovie || i%5 == 0 {
alias := getStudioStringValue(i, "Alias")
if err := sqb.UpdateAliases(ctx, studio.ID, []string{alias}); err != nil {
return fmt.Errorf("error setting studio alias: %s", err.Error())
}
studio.Aliases = models.NewRelatedStrings([]string{alias})
}
err := createStudioFromModel(ctx, sqb, &studio)
if err != nil {
return err
}
studioIDs = append(studioIDs, studio.ID)
@@ -1756,12 +1753,14 @@ func linkMovieStudios(ctx context.Context, mqb models.MovieWriter) error {
})
}
func linkStudiosParent(ctx context.Context, qb models.StudioWriter) error {
func linkStudiosParent(ctx context.Context) error {
qb := db.Studio
return doLinks(studioParentLinks, func(parentIndex, childIndex int) error {
studio := models.StudioPartial{
input := &models.StudioPartial{
ID: studioIDs[childIndex],
ParentID: models.NewOptionalInt(studioIDs[parentIndex]),
}
_, err := qb.UpdatePartial(ctx, studioIDs[childIndex], studio)
_, err := qb.UpdatePartial(ctx, *input)
return err
})

View File

@@ -5,7 +5,6 @@ import (
"database/sql"
"errors"
"fmt"
"strings"
"github.com/doug-martin/goqu/v9"
"github.com/doug-martin/goqu/v9/exp"
@@ -15,14 +14,16 @@ import (
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/sliceutil/intslice"
"github.com/stashapp/stash/pkg/studio"
)
const (
studioTable = "studios"
studioIDColumn = "studio_id"
studioAliasesTable = "studio_aliases"
studioAliasColumn = "alias"
studioTable = "studios"
studioIDColumn = "studio_id"
studioAliasesTable = "studio_aliases"
studioAliasColumn = "alias"
studioParentIDColumn = "parent_id"
studioNameColumn = "name"
studioImageBlobColumn = "image_blob"
)
@@ -39,7 +40,7 @@ type studioRow struct {
IgnoreAutoTag bool `db:"ignore_auto_tag"`
// not used in resolutions or updates
CoverBlob zero.String `db:"image_blob"`
ImageBlob zero.String `db:"image_blob"`
}
func (r *studioRow) fromStudio(o models.Studio) {
@@ -116,6 +117,8 @@ func (qb *StudioStore) selectDataset() *goqu.SelectDataset {
}
func (qb *StudioStore) Create(ctx context.Context, newObject *models.Studio) error {
var err error
var r studioRow
r.fromStudio(*newObject)
@@ -124,34 +127,66 @@ func (qb *StudioStore) Create(ctx context.Context, newObject *models.Studio) err
return err
}
updated, err := qb.find(ctx, id)
if newObject.Aliases.Loaded() {
if err := studio.EnsureAliasesUnique(ctx, id, newObject.Aliases.List(), qb); err != nil {
return err
}
if err := studiosAliasesTableMgr.insertJoins(ctx, id, newObject.Aliases.List()); err != nil {
return err
}
}
if newObject.StashIDs.Loaded() {
if err := studiosStashIDsTableMgr.insertJoins(ctx, id, newObject.StashIDs.List()); err != nil {
return err
}
}
updated, _ := qb.find(ctx, id)
if err != nil {
return fmt.Errorf("finding after create: %w", err)
}
*newObject = *updated
return nil
}
func (qb *StudioStore) UpdatePartial(ctx context.Context, id int, partial models.StudioPartial) (*models.Studio, error) {
func (qb *StudioStore) UpdatePartial(ctx context.Context, input models.StudioPartial) (*models.Studio, error) {
r := studioRowRecord{
updateRecord{
Record: make(exp.Record),
},
}
r.fromPartial(partial)
r.fromPartial(input)
if len(r.Record) > 0 {
if err := qb.tableMgr.updateByID(ctx, id, r.Record); err != nil {
if err := qb.tableMgr.updateByID(ctx, input.ID, r.Record); err != nil {
return nil, err
}
}
return qb.find(ctx, id)
if input.Aliases != nil {
if err := studio.EnsureAliasesUnique(ctx, input.ID, input.Aliases.Values, qb); err != nil {
return nil, err
}
if err := studiosAliasesTableMgr.modifyJoins(ctx, input.ID, input.Aliases.Values, input.Aliases.Mode); err != nil {
return nil, err
}
}
if input.StashIDs != nil {
if err := studiosStashIDsTableMgr.modifyJoins(ctx, input.ID, input.StashIDs.StashIDs, input.StashIDs.Mode); err != nil {
return nil, err
}
}
return qb.Find(ctx, input.ID)
}
// This is only used by the Import/Export functionality
func (qb *StudioStore) Update(ctx context.Context, updatedObject *models.Studio) error {
var r studioRow
r.fromStudio(*updatedObject)
@@ -160,6 +195,18 @@ func (qb *StudioStore) Update(ctx context.Context, updatedObject *models.Studio)
return err
}
if updatedObject.Aliases.Loaded() {
if err := studiosAliasesTableMgr.replaceJoins(ctx, updatedObject.ID, updatedObject.Aliases.List()); err != nil {
return err
}
}
if updatedObject.StashIDs.Loaded() {
if err := studiosStashIDsTableMgr.replaceJoins(ctx, updatedObject.ID, updatedObject.StashIDs.List()); err != nil {
return err
}
}
return nil
}
@@ -257,10 +304,22 @@ func (qb *StudioStore) getMany(ctx context.Context, q *goqu.SelectDataset) ([]*m
return ret, nil
}
func (qb *StudioStore) findBySubquery(ctx context.Context, sq *goqu.SelectDataset) ([]*models.Studio, error) {
table := qb.table()
q := qb.selectDataset().Where(
table.Col(idColumn).Eq(
sq,
),
)
return qb.getMany(ctx, q)
}
func (qb *StudioStore) FindChildren(ctx context.Context, id int) ([]*models.Studio, error) {
// SELECT studios.* FROM studios WHERE studios.parent_id = ?
table := qb.table()
sq := qb.selectDataset().Where(table.Col("parent_id").Eq(id))
sq := qb.selectDataset().Where(table.Col(studioParentIDColumn).Eq(id))
ret, err := qb.getMany(ctx, sq)
if err != nil {
@@ -309,13 +368,44 @@ func (qb *StudioStore) FindByName(ctx context.Context, name string, nocase bool)
}
func (qb *StudioStore) FindByStashID(ctx context.Context, stashID models.StashID) ([]*models.Studio, error) {
query := selectAll("studios") + `
LEFT JOIN studio_stash_ids on studio_stash_ids.studio_id = studios.id
WHERE studio_stash_ids.stash_id = ?
AND studio_stash_ids.endpoint = ?
`
args := []interface{}{stashID.StashID, stashID.Endpoint}
return qb.queryStudios(ctx, query, args)
sq := dialect.From(studiosStashIDsJoinTable).Select(studiosStashIDsJoinTable.Col(studioIDColumn)).Where(
studiosStashIDsJoinTable.Col("stash_id").Eq(stashID.StashID),
studiosStashIDsJoinTable.Col("endpoint").Eq(stashID.Endpoint),
)
ret, err := qb.findBySubquery(ctx, sq)
if err != nil {
return nil, fmt.Errorf("getting studios for stash ID %s: %w", stashID.StashID, err)
}
return ret, nil
}
func (qb *StudioStore) FindByStashIDStatus(ctx context.Context, hasStashID bool, stashboxEndpoint string) ([]*models.Studio, error) {
table := qb.table()
sq := dialect.From(table).LeftJoin(
studiosStashIDsJoinTable,
goqu.On(table.Col(idColumn).Eq(studiosStashIDsJoinTable.Col(studioIDColumn))),
).Select(table.Col(idColumn))
if hasStashID {
sq = sq.Where(
studiosStashIDsJoinTable.Col("stash_id").IsNotNull(),
studiosStashIDsJoinTable.Col("endpoint").Eq(stashboxEndpoint),
)
} else {
sq = sq.Where(
studiosStashIDsJoinTable.Col("stash_id").IsNull(),
)
}
ret, err := qb.findBySubquery(ctx, sq)
if err != nil {
return nil, fmt.Errorf("getting studios for stash-box endpoint %s: %w", stashboxEndpoint, err)
}
return ret, nil
}
func (qb *StudioStore) Count(ctx context.Context) (int, error) {
@@ -325,38 +415,37 @@ func (qb *StudioStore) Count(ctx context.Context) (int, error) {
func (qb *StudioStore) All(ctx context.Context) ([]*models.Studio, error) {
table := qb.table()
return qb.getMany(ctx, qb.selectDataset().Order(
table.Col("name").Asc(),
table.Col(idColumn).Asc(),
))
return qb.getMany(ctx, qb.selectDataset().Order(table.Col(studioNameColumn).Asc()))
}
func (qb *StudioStore) QueryForAutoTag(ctx context.Context, words []string) ([]*models.Studio, error) {
// TODO - Query needs to be changed to support queries of this type, and
// this method should be removed
query := selectAll(studioTable)
query += " LEFT JOIN studio_aliases ON studio_aliases.studio_id = studios.id"
table := qb.table()
sq := dialect.From(table).Select(table.Col(idColumn)).LeftJoin(
studiosAliasesJoinTable,
goqu.On(studiosAliasesJoinTable.Col(studioIDColumn).Eq(table.Col(idColumn))),
)
var whereClauses []string
var args []interface{}
var whereClauses []exp.Expression
for _, w := range words {
ww := w + "%"
whereClauses = append(whereClauses, "studios.name like ?")
args = append(args, ww)
// include aliases
whereClauses = append(whereClauses, "studio_aliases.alias like ?")
args = append(args, ww)
whereClauses = append(whereClauses, table.Col(studioNameColumn).Like(w+"%"))
whereClauses = append(whereClauses, studiosAliasesJoinTable.Col("alias").Like(w+"%"))
}
whereOr := "(" + strings.Join(whereClauses, " OR ") + ")"
where := strings.Join([]string{
"studios.ignore_auto_tag = 0",
whereOr,
}, " AND ")
return qb.queryStudios(ctx, query+" WHERE "+where, args)
sq = sq.Where(
goqu.Or(whereClauses...),
table.Col("ignore_auto_tag").Eq(0),
)
ret, err := qb.findBySubquery(ctx, sq)
if err != nil {
return nil, fmt.Errorf("getting performers for autotag: %w", err)
}
return ret, nil
}
func (qb *StudioStore) validateFilter(filter *models.StudioFilterType) error {
@@ -430,13 +519,13 @@ func (qb *StudioStore) makeFilter(ctx context.Context, studioFilter *models.Stud
query.handleCriterion(ctx, studioGalleryCountCriterionHandler(qb, studioFilter.GalleryCount))
query.handleCriterion(ctx, studioParentCriterionHandler(qb, studioFilter.Parents))
query.handleCriterion(ctx, studioAliasCriterionHandler(qb, studioFilter.Aliases))
query.handleCriterion(ctx, timestampCriterionHandler(studioFilter.CreatedAt, "studios.created_at"))
query.handleCriterion(ctx, timestampCriterionHandler(studioFilter.UpdatedAt, "studios.updated_at"))
query.handleCriterion(ctx, timestampCriterionHandler(studioFilter.CreatedAt, studioTable+".created_at"))
query.handleCriterion(ctx, timestampCriterionHandler(studioFilter.UpdatedAt, studioTable+".updated_at"))
return query
}
func (qb *StudioStore) Query(ctx context.Context, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) ([]*models.Studio, int, error) {
func (qb *StudioStore) makeQuery(ctx context.Context, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) (*queryBuilder, error) {
if studioFilter == nil {
studioFilter = &models.StudioFilterType{}
}
@@ -450,20 +539,29 @@ func (qb *StudioStore) Query(ctx context.Context, studioFilter *models.StudioFil
if q := findFilter.Q; q != nil && *q != "" {
query.join(studioAliasesTable, "", "studio_aliases.studio_id = studios.id")
searchColumns := []string{"studios.name", "studio_aliases.alias"}
query.parseQueryString(searchColumns, *q)
}
if err := qb.validateFilter(studioFilter); err != nil {
return nil, 0, err
return nil, err
}
filter := qb.makeFilter(ctx, studioFilter)
if err := query.addFilter(filter); err != nil {
return nil, 0, err
return nil, err
}
query.sortAndPagination = qb.getStudioSort(findFilter) + getPagination(findFilter)
return &query, nil
}
func (qb *StudioStore) Query(ctx context.Context, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) ([]*models.Studio, int, error) {
query, err := qb.makeQuery(ctx, studioFilter, findFilter)
if err != nil {
return nil, 0, err
}
idsResult, countResult, err := query.executeFind(ctx)
if err != nil {
return nil, 0, err
@@ -546,7 +644,7 @@ func studioAliasCriterionHandler(qb *StudioStore, alias *models.StringCriterionI
joinTable: studioAliasesTable,
stringColumn: studioAliasColumn,
addJoinTable: func(f *filterBuilder) {
qb.aliasRepository().join(f, "", "studios.id")
studiosAliasesTableMgr.join(f, "", "studios.id")
},
}
@@ -581,26 +679,6 @@ func (qb *StudioStore) getStudioSort(findFilter *models.FindFilterType) string {
return sortQuery
}
func (qb *StudioStore) queryStudios(ctx context.Context, query string, args []interface{}) ([]*models.Studio, error) {
const single = false
var ret []*models.Studio
if err := qb.queryFunc(ctx, query, args, single, func(r *sqlx.Rows) error {
var f studioRow
if err := r.StructScan(&f); err != nil {
return err
}
s := f.resolve()
ret = append(ret, s)
return nil
}); err != nil {
return nil, err
}
return ret, nil
}
func (qb *StudioStore) GetImage(ctx context.Context, studioID int) ([]byte, error) {
return qb.blobJoinQueryBuilder.GetImage(ctx, studioID, studioImageBlobColumn)
}
@@ -628,28 +706,9 @@ func (qb *StudioStore) stashIDRepository() *stashIDRepository {
}
func (qb *StudioStore) GetStashIDs(ctx context.Context, studioID int) ([]models.StashID, error) {
return qb.stashIDRepository().get(ctx, studioID)
}
func (qb *StudioStore) UpdateStashIDs(ctx context.Context, studioID int, stashIDs []models.StashID) error {
return qb.stashIDRepository().replace(ctx, studioID, stashIDs)
}
func (qb *StudioStore) aliasRepository() *stringRepository {
return &stringRepository{
repository: repository{
tx: qb.tx,
tableName: studioAliasesTable,
idColumn: studioIDColumn,
},
stringColumn: studioAliasColumn,
}
return studiosStashIDsTableMgr.get(ctx, studioID)
}
func (qb *StudioStore) GetAliases(ctx context.Context, studioID int) ([]string, error) {
return qb.aliasRepository().get(ctx, studioID)
}
func (qb *StudioStore) UpdateAliases(ctx context.Context, studioID int, aliases []string) error {
return qb.aliasRepository().replace(ctx, studioID, aliases)
return studiosAliasesTableMgr.get(ctx, studioID)
}

View File

@@ -219,18 +219,15 @@ func TestStudioQueryForAutoTag(t *testing.T) {
assert.Len(t, studios, 1)
assert.Equal(t, strings.ToLower(studioNames[studioIdxWithMovie]), strings.ToLower(studios[0].Name))
// find by alias
name = getStudioStringValue(studioIdxWithMovie, "Alias")
studios, err = tqb.QueryForAutoTag(ctx, []string{name})
if err != nil {
t.Errorf("Error finding studios: %s", err.Error())
}
if assert.Len(t, studios, 1) {
assert.Equal(t, studioIDs[studioIdxWithMovie], studios[0].ID)
}
return nil
})
}
@@ -363,11 +360,12 @@ func TestStudioUpdateClearParent(t *testing.T) {
sqb := db.Studio
// clear the parent id from the child
updatePartial := models.StudioPartial{
input := models.StudioPartial{
ID: createdChild.ID,
ParentID: models.NewOptionalIntPtr(nil),
}
updatedStudio, err := sqb.UpdatePartial(ctx, createdChild.ID, updatePartial)
updatedStudio, err := sqb.UpdatePartial(ctx, input)
if err != nil {
return fmt.Errorf("Error updated studio: %s", err.Error())
@@ -548,7 +546,7 @@ func verifyStudiosGalleryCount(t *testing.T, galleryCountCriterion models.IntCri
}
func TestStudioStashIDs(t *testing.T) {
if err := withTxn(func(ctx context.Context) error {
if err := withRollbackTxn(func(ctx context.Context) error {
qb := db.Studio
// create studio to test against
@@ -558,13 +556,83 @@ func TestStudioStashIDs(t *testing.T) {
return fmt.Errorf("Error creating studio: %s", err.Error())
}
testStashIDReaderWriter(ctx, t, qb, created.ID)
studio, err := qb.Find(ctx, created.ID)
if err != nil {
return fmt.Errorf("Error getting studio: %s", err.Error())
}
if err := studio.LoadStashIDs(ctx, qb); err != nil {
return err
}
testStudioStashIDs(ctx, t, studio)
return nil
}); err != nil {
t.Error(err.Error())
}
}
func testStudioStashIDs(ctx context.Context, t *testing.T, s *models.Studio) {
qb := db.Studio
if err := s.LoadStashIDs(ctx, qb); err != nil {
t.Error(err.Error())
return
}
// ensure no stash IDs to begin with
assert.Len(t, s.StashIDs.List(), 0)
// add stash ids
const stashIDStr = "stashID"
const endpoint = "endpoint"
stashID := models.StashID{
StashID: stashIDStr,
Endpoint: endpoint,
}
// update stash ids and ensure was updated
input := models.StudioPartial{
ID: s.ID,
StashIDs: &models.UpdateStashIDs{
StashIDs: []models.StashID{stashID},
Mode: models.RelationshipUpdateModeSet,
},
}
var err error
s, err = qb.UpdatePartial(ctx, input)
if err != nil {
t.Error(err.Error())
}
if err := s.LoadStashIDs(ctx, qb); err != nil {
t.Error(err.Error())
return
}
assert.Equal(t, []models.StashID{stashID}, s.StashIDs.List())
// remove stash ids and ensure was updated
input = models.StudioPartial{
ID: s.ID,
StashIDs: &models.UpdateStashIDs{
StashIDs: []models.StashID{stashID},
Mode: models.RelationshipUpdateModeRemove,
},
}
s, err = qb.UpdatePartial(ctx, input)
if err != nil {
t.Error(err.Error())
}
if err := s.LoadStashIDs(ctx, qb); err != nil {
t.Error(err.Error())
return
}
assert.Len(t, s.StashIDs.List(), 0)
}
func TestStudioQueryURL(t *testing.T) {
const sceneIdx = 1
studioURL := getStudioStringValue(sceneIdx, urlField)
@@ -684,7 +752,7 @@ func TestStudioQueryIsMissingRating(t *testing.T) {
assert.True(t, len(studios) > 0)
for _, studio := range studios {
assert.True(t, studio.Rating == nil)
assert.Nil(t, studio.Rating)
}
return nil
@@ -778,36 +846,87 @@ func TestStudioQueryAlias(t *testing.T) {
verifyStudioQuery(t, studioFilter, verifyFn)
}
func TestStudioUpdateAlias(t *testing.T) {
if err := withTxn(func(ctx context.Context) error {
func TestStudioAlias(t *testing.T) {
if err := withRollbackTxn(func(ctx context.Context) error {
qb := db.Studio
// create studio to test against
const name = "TestStudioUpdateAlias"
created, err := createStudio(ctx, qb, name, nil)
const name = "TestStudioAlias"
created, err := createStudio(ctx, db.Studio, name, nil)
if err != nil {
return fmt.Errorf("Error creating studio: %s", err.Error())
}
aliases := []string{"alias1", "alias2"}
err = qb.UpdateAliases(ctx, created.ID, aliases)
studio, err := qb.Find(ctx, created.ID)
if err != nil {
return fmt.Errorf("Error updating studio aliases: %s", err.Error())
return fmt.Errorf("Error getting studio: %s", err.Error())
}
// ensure aliases set
storedAliases, err := qb.GetAliases(ctx, created.ID)
if err != nil {
return fmt.Errorf("Error getting aliases: %s", err.Error())
if err := studio.LoadStashIDs(ctx, qb); err != nil {
return err
}
assert.Equal(t, aliases, storedAliases)
testStudioAlias(ctx, t, studio)
return nil
}); err != nil {
t.Error(err.Error())
}
}
func testStudioAlias(ctx context.Context, t *testing.T, s *models.Studio) {
qb := db.Studio
if err := s.LoadAliases(ctx, qb); err != nil {
t.Error(err.Error())
return
}
// ensure no alias to begin with
assert.Len(t, s.Aliases.List(), 0)
aliases := []string{"alias1", "alias2"}
// update alias and ensure was updated
input := models.StudioPartial{
ID: s.ID,
Aliases: &models.UpdateStrings{
Values: aliases,
Mode: models.RelationshipUpdateModeSet,
},
}
var err error
s, err = qb.UpdatePartial(ctx, input)
if err != nil {
t.Error(err.Error())
}
if err := s.LoadAliases(ctx, qb); err != nil {
t.Error(err.Error())
return
}
assert.Equal(t, aliases, s.Aliases.List())
// remove alias and ensure was updated
input = models.StudioPartial{
ID: s.ID,
Aliases: &models.UpdateStrings{
Values: aliases,
Mode: models.RelationshipUpdateModeRemove,
},
}
s, err = qb.UpdatePartial(ctx, input)
if err != nil {
t.Error(err.Error())
}
if err := s.LoadAliases(ctx, qb); err != nil {
t.Error(err.Error())
return
}
assert.Len(t, s.Aliases.List(), 0)
}
// TestStudioQueryFast does a quick test for major errors, no result verification
func TestStudioQueryFast(t *testing.T) {

View File

@@ -29,6 +29,9 @@ var (
performersAliasesJoinTable = goqu.T(performersAliasesTable)
performersTagsJoinTable = goqu.T(performersTagsTable)
performersStashIDsJoinTable = goqu.T("performer_stash_ids")
studiosAliasesJoinTable = goqu.T(studioAliasesTable)
studiosStashIDsJoinTable = goqu.T("studio_stash_ids")
)
var (
@@ -233,6 +236,21 @@ var (
table: goqu.T(studioTable),
idColumn: goqu.T(studioTable).Col(idColumn),
}
studiosAliasesTableMgr = &stringTable{
table: table{
table: studiosAliasesJoinTable,
idColumn: studiosAliasesJoinTable.Col(studioIDColumn),
},
stringColumn: studiosAliasesJoinTable.Col(studioAliasColumn),
}
studiosStashIDsTableMgr = &stashIDTable{
table: table{
table: studiosStashIDsJoinTable,
idColumn: studiosStashIDsJoinTable.Col(studioIDColumn),
},
}
)
var (