mirror of
https://github.com/stashapp/stash.git
synced 2025-12-18 04:44:37 +03:00
Studio Tagger (#3510)
* Studio image and parent studio support in scene tagger * Refactor studio backend and add studio tagger --------- Co-authored-by: WithoutPants <53250216+WithoutPants@users.noreply.github.com>
This commit is contained in:
@@ -438,21 +438,6 @@ func (r *stashIDRepository) get(ctx context.Context, id int) ([]models.StashID,
|
||||
return []models.StashID(ret), err
|
||||
}
|
||||
|
||||
func (r *stashIDRepository) replace(ctx context.Context, id int, newIDs []models.StashID) error {
|
||||
if err := r.destroy(ctx, []int{id}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("INSERT INTO %s (%s, endpoint, stash_id) VALUES (?, ?, ?)", r.tableName, r.idColumn)
|
||||
for _, stashID := range newIDs {
|
||||
_, err := r.tx.Exec(ctx, query, id, stashID.Endpoint, stashID.StashID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type filesRepository struct {
|
||||
repository
|
||||
}
|
||||
|
||||
@@ -631,7 +631,7 @@ func populateDB() error {
|
||||
return fmt.Errorf("error creating performers: %s", err.Error())
|
||||
}
|
||||
|
||||
if err := createStudios(ctx, db.Studio, studiosNameCase, studiosNameNoCase); err != nil {
|
||||
if err := createStudios(ctx, studiosNameCase, studiosNameNoCase); err != nil {
|
||||
return fmt.Errorf("error creating studios: %s", err.Error())
|
||||
}
|
||||
|
||||
@@ -659,7 +659,7 @@ func populateDB() error {
|
||||
return fmt.Errorf("error linking movie studios: %s", err.Error())
|
||||
}
|
||||
|
||||
if err := linkStudiosParent(ctx, db.Studio); err != nil {
|
||||
if err := linkStudiosParent(ctx); err != nil {
|
||||
return fmt.Errorf("error linking studios parent: %s", err.Error())
|
||||
}
|
||||
|
||||
@@ -1310,8 +1310,8 @@ func createMovies(ctx context.Context, mqb models.MovieReaderWriter, n int, o in
|
||||
|
||||
name = getMovieStringValue(index, name)
|
||||
movie := models.Movie{
|
||||
Name: name,
|
||||
URL: getMovieNullStringValue(index, urlField),
|
||||
Name: name,
|
||||
URL: getMovieNullStringValue(index, urlField),
|
||||
}
|
||||
|
||||
err := mqb.Create(ctx, &movie)
|
||||
@@ -1573,9 +1573,9 @@ func getStudioNullStringValue(index int, field string) string {
|
||||
return ret.String
|
||||
}
|
||||
|
||||
func createStudio(ctx context.Context, sqb models.StudioReaderWriter, name string, parentID *int) (*models.Studio, error) {
|
||||
func createStudio(ctx context.Context, sqb *sqlite.StudioStore, name string, parentID *int) (*models.Studio, error) {
|
||||
studio := models.Studio{
|
||||
Name: name,
|
||||
Name: name,
|
||||
}
|
||||
|
||||
if parentID != nil {
|
||||
@@ -1590,7 +1590,7 @@ func createStudio(ctx context.Context, sqb models.StudioReaderWriter, name strin
|
||||
return &studio, nil
|
||||
}
|
||||
|
||||
func createStudioFromModel(ctx context.Context, sqb models.StudioReaderWriter, studio *models.Studio) error {
|
||||
func createStudioFromModel(ctx context.Context, sqb *sqlite.StudioStore, studio *models.Studio) error {
|
||||
err := sqb.Create(ctx, studio)
|
||||
|
||||
if err != nil {
|
||||
@@ -1601,7 +1601,8 @@ func createStudioFromModel(ctx context.Context, sqb models.StudioReaderWriter, s
|
||||
}
|
||||
|
||||
// createStudios creates n studios with plain Name and o studios with camel cased NaMe included
|
||||
func createStudios(ctx context.Context, sqb models.StudioReaderWriter, n int, o int) error {
|
||||
func createStudios(ctx context.Context, n int, o int) error {
|
||||
sqb := db.Studio
|
||||
const namePlain = "Name"
|
||||
const nameNoCase = "NaMe"
|
||||
|
||||
@@ -1618,22 +1619,18 @@ func createStudios(ctx context.Context, sqb models.StudioReaderWriter, n int, o
|
||||
name = getStudioStringValue(index, name)
|
||||
studio := models.Studio{
|
||||
Name: name,
|
||||
URL: getStudioNullStringValue(index, urlField),
|
||||
URL: getStudioStringValue(index, urlField),
|
||||
IgnoreAutoTag: getIgnoreAutoTag(i),
|
||||
}
|
||||
|
||||
err := createStudioFromModel(ctx, sqb, &studio)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// add alias
|
||||
// only add aliases for some scenes
|
||||
if i == studioIdxWithMovie || i%5 == 0 {
|
||||
alias := getStudioStringValue(i, "Alias")
|
||||
if err := sqb.UpdateAliases(ctx, studio.ID, []string{alias}); err != nil {
|
||||
return fmt.Errorf("error setting studio alias: %s", err.Error())
|
||||
}
|
||||
studio.Aliases = models.NewRelatedStrings([]string{alias})
|
||||
}
|
||||
err := createStudioFromModel(ctx, sqb, &studio)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
studioIDs = append(studioIDs, studio.ID)
|
||||
@@ -1756,12 +1753,14 @@ func linkMovieStudios(ctx context.Context, mqb models.MovieWriter) error {
|
||||
})
|
||||
}
|
||||
|
||||
func linkStudiosParent(ctx context.Context, qb models.StudioWriter) error {
|
||||
func linkStudiosParent(ctx context.Context) error {
|
||||
qb := db.Studio
|
||||
return doLinks(studioParentLinks, func(parentIndex, childIndex int) error {
|
||||
studio := models.StudioPartial{
|
||||
input := &models.StudioPartial{
|
||||
ID: studioIDs[childIndex],
|
||||
ParentID: models.NewOptionalInt(studioIDs[parentIndex]),
|
||||
}
|
||||
_, err := qb.UpdatePartial(ctx, studioIDs[childIndex], studio)
|
||||
_, err := qb.UpdatePartial(ctx, *input)
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/doug-martin/goqu/v9"
|
||||
"github.com/doug-martin/goqu/v9/exp"
|
||||
@@ -15,14 +14,16 @@ import (
|
||||
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/sliceutil/intslice"
|
||||
"github.com/stashapp/stash/pkg/studio"
|
||||
)
|
||||
|
||||
const (
|
||||
studioTable = "studios"
|
||||
studioIDColumn = "studio_id"
|
||||
studioAliasesTable = "studio_aliases"
|
||||
studioAliasColumn = "alias"
|
||||
|
||||
studioTable = "studios"
|
||||
studioIDColumn = "studio_id"
|
||||
studioAliasesTable = "studio_aliases"
|
||||
studioAliasColumn = "alias"
|
||||
studioParentIDColumn = "parent_id"
|
||||
studioNameColumn = "name"
|
||||
studioImageBlobColumn = "image_blob"
|
||||
)
|
||||
|
||||
@@ -39,7 +40,7 @@ type studioRow struct {
|
||||
IgnoreAutoTag bool `db:"ignore_auto_tag"`
|
||||
|
||||
// not used in resolutions or updates
|
||||
CoverBlob zero.String `db:"image_blob"`
|
||||
ImageBlob zero.String `db:"image_blob"`
|
||||
}
|
||||
|
||||
func (r *studioRow) fromStudio(o models.Studio) {
|
||||
@@ -116,6 +117,8 @@ func (qb *StudioStore) selectDataset() *goqu.SelectDataset {
|
||||
}
|
||||
|
||||
func (qb *StudioStore) Create(ctx context.Context, newObject *models.Studio) error {
|
||||
var err error
|
||||
|
||||
var r studioRow
|
||||
r.fromStudio(*newObject)
|
||||
|
||||
@@ -124,34 +127,66 @@ func (qb *StudioStore) Create(ctx context.Context, newObject *models.Studio) err
|
||||
return err
|
||||
}
|
||||
|
||||
updated, err := qb.find(ctx, id)
|
||||
if newObject.Aliases.Loaded() {
|
||||
if err := studio.EnsureAliasesUnique(ctx, id, newObject.Aliases.List(), qb); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := studiosAliasesTableMgr.insertJoins(ctx, id, newObject.Aliases.List()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if newObject.StashIDs.Loaded() {
|
||||
if err := studiosStashIDsTableMgr.insertJoins(ctx, id, newObject.StashIDs.List()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
updated, _ := qb.find(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("finding after create: %w", err)
|
||||
}
|
||||
|
||||
*newObject = *updated
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (qb *StudioStore) UpdatePartial(ctx context.Context, id int, partial models.StudioPartial) (*models.Studio, error) {
|
||||
func (qb *StudioStore) UpdatePartial(ctx context.Context, input models.StudioPartial) (*models.Studio, error) {
|
||||
r := studioRowRecord{
|
||||
updateRecord{
|
||||
Record: make(exp.Record),
|
||||
},
|
||||
}
|
||||
|
||||
r.fromPartial(partial)
|
||||
r.fromPartial(input)
|
||||
|
||||
if len(r.Record) > 0 {
|
||||
if err := qb.tableMgr.updateByID(ctx, id, r.Record); err != nil {
|
||||
if err := qb.tableMgr.updateByID(ctx, input.ID, r.Record); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return qb.find(ctx, id)
|
||||
if input.Aliases != nil {
|
||||
if err := studio.EnsureAliasesUnique(ctx, input.ID, input.Aliases.Values, qb); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := studiosAliasesTableMgr.modifyJoins(ctx, input.ID, input.Aliases.Values, input.Aliases.Mode); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if input.StashIDs != nil {
|
||||
if err := studiosStashIDsTableMgr.modifyJoins(ctx, input.ID, input.StashIDs.StashIDs, input.StashIDs.Mode); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return qb.Find(ctx, input.ID)
|
||||
}
|
||||
|
||||
// This is only used by the Import/Export functionality
|
||||
func (qb *StudioStore) Update(ctx context.Context, updatedObject *models.Studio) error {
|
||||
var r studioRow
|
||||
r.fromStudio(*updatedObject)
|
||||
@@ -160,6 +195,18 @@ func (qb *StudioStore) Update(ctx context.Context, updatedObject *models.Studio)
|
||||
return err
|
||||
}
|
||||
|
||||
if updatedObject.Aliases.Loaded() {
|
||||
if err := studiosAliasesTableMgr.replaceJoins(ctx, updatedObject.ID, updatedObject.Aliases.List()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if updatedObject.StashIDs.Loaded() {
|
||||
if err := studiosStashIDsTableMgr.replaceJoins(ctx, updatedObject.ID, updatedObject.StashIDs.List()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -257,10 +304,22 @@ func (qb *StudioStore) getMany(ctx context.Context, q *goqu.SelectDataset) ([]*m
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (qb *StudioStore) findBySubquery(ctx context.Context, sq *goqu.SelectDataset) ([]*models.Studio, error) {
|
||||
table := qb.table()
|
||||
|
||||
q := qb.selectDataset().Where(
|
||||
table.Col(idColumn).Eq(
|
||||
sq,
|
||||
),
|
||||
)
|
||||
|
||||
return qb.getMany(ctx, q)
|
||||
}
|
||||
|
||||
func (qb *StudioStore) FindChildren(ctx context.Context, id int) ([]*models.Studio, error) {
|
||||
// SELECT studios.* FROM studios WHERE studios.parent_id = ?
|
||||
table := qb.table()
|
||||
sq := qb.selectDataset().Where(table.Col("parent_id").Eq(id))
|
||||
sq := qb.selectDataset().Where(table.Col(studioParentIDColumn).Eq(id))
|
||||
ret, err := qb.getMany(ctx, sq)
|
||||
|
||||
if err != nil {
|
||||
@@ -309,13 +368,44 @@ func (qb *StudioStore) FindByName(ctx context.Context, name string, nocase bool)
|
||||
}
|
||||
|
||||
func (qb *StudioStore) FindByStashID(ctx context.Context, stashID models.StashID) ([]*models.Studio, error) {
|
||||
query := selectAll("studios") + `
|
||||
LEFT JOIN studio_stash_ids on studio_stash_ids.studio_id = studios.id
|
||||
WHERE studio_stash_ids.stash_id = ?
|
||||
AND studio_stash_ids.endpoint = ?
|
||||
`
|
||||
args := []interface{}{stashID.StashID, stashID.Endpoint}
|
||||
return qb.queryStudios(ctx, query, args)
|
||||
sq := dialect.From(studiosStashIDsJoinTable).Select(studiosStashIDsJoinTable.Col(studioIDColumn)).Where(
|
||||
studiosStashIDsJoinTable.Col("stash_id").Eq(stashID.StashID),
|
||||
studiosStashIDsJoinTable.Col("endpoint").Eq(stashID.Endpoint),
|
||||
)
|
||||
ret, err := qb.findBySubquery(ctx, sq)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting studios for stash ID %s: %w", stashID.StashID, err)
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (qb *StudioStore) FindByStashIDStatus(ctx context.Context, hasStashID bool, stashboxEndpoint string) ([]*models.Studio, error) {
|
||||
table := qb.table()
|
||||
sq := dialect.From(table).LeftJoin(
|
||||
studiosStashIDsJoinTable,
|
||||
goqu.On(table.Col(idColumn).Eq(studiosStashIDsJoinTable.Col(studioIDColumn))),
|
||||
).Select(table.Col(idColumn))
|
||||
|
||||
if hasStashID {
|
||||
sq = sq.Where(
|
||||
studiosStashIDsJoinTable.Col("stash_id").IsNotNull(),
|
||||
studiosStashIDsJoinTable.Col("endpoint").Eq(stashboxEndpoint),
|
||||
)
|
||||
} else {
|
||||
sq = sq.Where(
|
||||
studiosStashIDsJoinTable.Col("stash_id").IsNull(),
|
||||
)
|
||||
}
|
||||
|
||||
ret, err := qb.findBySubquery(ctx, sq)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting studios for stash-box endpoint %s: %w", stashboxEndpoint, err)
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (qb *StudioStore) Count(ctx context.Context) (int, error) {
|
||||
@@ -325,38 +415,37 @@ func (qb *StudioStore) Count(ctx context.Context) (int, error) {
|
||||
|
||||
func (qb *StudioStore) All(ctx context.Context) ([]*models.Studio, error) {
|
||||
table := qb.table()
|
||||
|
||||
return qb.getMany(ctx, qb.selectDataset().Order(
|
||||
table.Col("name").Asc(),
|
||||
table.Col(idColumn).Asc(),
|
||||
))
|
||||
return qb.getMany(ctx, qb.selectDataset().Order(table.Col(studioNameColumn).Asc()))
|
||||
}
|
||||
|
||||
func (qb *StudioStore) QueryForAutoTag(ctx context.Context, words []string) ([]*models.Studio, error) {
|
||||
// TODO - Query needs to be changed to support queries of this type, and
|
||||
// this method should be removed
|
||||
query := selectAll(studioTable)
|
||||
query += " LEFT JOIN studio_aliases ON studio_aliases.studio_id = studios.id"
|
||||
table := qb.table()
|
||||
sq := dialect.From(table).Select(table.Col(idColumn)).LeftJoin(
|
||||
studiosAliasesJoinTable,
|
||||
goqu.On(studiosAliasesJoinTable.Col(studioIDColumn).Eq(table.Col(idColumn))),
|
||||
)
|
||||
|
||||
var whereClauses []string
|
||||
var args []interface{}
|
||||
var whereClauses []exp.Expression
|
||||
|
||||
for _, w := range words {
|
||||
ww := w + "%"
|
||||
whereClauses = append(whereClauses, "studios.name like ?")
|
||||
args = append(args, ww)
|
||||
|
||||
// include aliases
|
||||
whereClauses = append(whereClauses, "studio_aliases.alias like ?")
|
||||
args = append(args, ww)
|
||||
whereClauses = append(whereClauses, table.Col(studioNameColumn).Like(w+"%"))
|
||||
whereClauses = append(whereClauses, studiosAliasesJoinTable.Col("alias").Like(w+"%"))
|
||||
}
|
||||
|
||||
whereOr := "(" + strings.Join(whereClauses, " OR ") + ")"
|
||||
where := strings.Join([]string{
|
||||
"studios.ignore_auto_tag = 0",
|
||||
whereOr,
|
||||
}, " AND ")
|
||||
return qb.queryStudios(ctx, query+" WHERE "+where, args)
|
||||
sq = sq.Where(
|
||||
goqu.Or(whereClauses...),
|
||||
table.Col("ignore_auto_tag").Eq(0),
|
||||
)
|
||||
|
||||
ret, err := qb.findBySubquery(ctx, sq)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting performers for autotag: %w", err)
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (qb *StudioStore) validateFilter(filter *models.StudioFilterType) error {
|
||||
@@ -430,13 +519,13 @@ func (qb *StudioStore) makeFilter(ctx context.Context, studioFilter *models.Stud
|
||||
query.handleCriterion(ctx, studioGalleryCountCriterionHandler(qb, studioFilter.GalleryCount))
|
||||
query.handleCriterion(ctx, studioParentCriterionHandler(qb, studioFilter.Parents))
|
||||
query.handleCriterion(ctx, studioAliasCriterionHandler(qb, studioFilter.Aliases))
|
||||
query.handleCriterion(ctx, timestampCriterionHandler(studioFilter.CreatedAt, "studios.created_at"))
|
||||
query.handleCriterion(ctx, timestampCriterionHandler(studioFilter.UpdatedAt, "studios.updated_at"))
|
||||
query.handleCriterion(ctx, timestampCriterionHandler(studioFilter.CreatedAt, studioTable+".created_at"))
|
||||
query.handleCriterion(ctx, timestampCriterionHandler(studioFilter.UpdatedAt, studioTable+".updated_at"))
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
func (qb *StudioStore) Query(ctx context.Context, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) ([]*models.Studio, int, error) {
|
||||
func (qb *StudioStore) makeQuery(ctx context.Context, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) (*queryBuilder, error) {
|
||||
if studioFilter == nil {
|
||||
studioFilter = &models.StudioFilterType{}
|
||||
}
|
||||
@@ -450,20 +539,29 @@ func (qb *StudioStore) Query(ctx context.Context, studioFilter *models.StudioFil
|
||||
if q := findFilter.Q; q != nil && *q != "" {
|
||||
query.join(studioAliasesTable, "", "studio_aliases.studio_id = studios.id")
|
||||
searchColumns := []string{"studios.name", "studio_aliases.alias"}
|
||||
|
||||
query.parseQueryString(searchColumns, *q)
|
||||
}
|
||||
|
||||
if err := qb.validateFilter(studioFilter); err != nil {
|
||||
return nil, 0, err
|
||||
return nil, err
|
||||
}
|
||||
filter := qb.makeFilter(ctx, studioFilter)
|
||||
|
||||
if err := query.addFilter(filter); err != nil {
|
||||
return nil, 0, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
query.sortAndPagination = qb.getStudioSort(findFilter) + getPagination(findFilter)
|
||||
|
||||
return &query, nil
|
||||
}
|
||||
|
||||
func (qb *StudioStore) Query(ctx context.Context, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) ([]*models.Studio, int, error) {
|
||||
query, err := qb.makeQuery(ctx, studioFilter, findFilter)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
idsResult, countResult, err := query.executeFind(ctx)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@@ -546,7 +644,7 @@ func studioAliasCriterionHandler(qb *StudioStore, alias *models.StringCriterionI
|
||||
joinTable: studioAliasesTable,
|
||||
stringColumn: studioAliasColumn,
|
||||
addJoinTable: func(f *filterBuilder) {
|
||||
qb.aliasRepository().join(f, "", "studios.id")
|
||||
studiosAliasesTableMgr.join(f, "", "studios.id")
|
||||
},
|
||||
}
|
||||
|
||||
@@ -581,26 +679,6 @@ func (qb *StudioStore) getStudioSort(findFilter *models.FindFilterType) string {
|
||||
return sortQuery
|
||||
}
|
||||
|
||||
func (qb *StudioStore) queryStudios(ctx context.Context, query string, args []interface{}) ([]*models.Studio, error) {
|
||||
const single = false
|
||||
var ret []*models.Studio
|
||||
if err := qb.queryFunc(ctx, query, args, single, func(r *sqlx.Rows) error {
|
||||
var f studioRow
|
||||
if err := r.StructScan(&f); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s := f.resolve()
|
||||
|
||||
ret = append(ret, s)
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (qb *StudioStore) GetImage(ctx context.Context, studioID int) ([]byte, error) {
|
||||
return qb.blobJoinQueryBuilder.GetImage(ctx, studioID, studioImageBlobColumn)
|
||||
}
|
||||
@@ -628,28 +706,9 @@ func (qb *StudioStore) stashIDRepository() *stashIDRepository {
|
||||
}
|
||||
|
||||
func (qb *StudioStore) GetStashIDs(ctx context.Context, studioID int) ([]models.StashID, error) {
|
||||
return qb.stashIDRepository().get(ctx, studioID)
|
||||
}
|
||||
|
||||
func (qb *StudioStore) UpdateStashIDs(ctx context.Context, studioID int, stashIDs []models.StashID) error {
|
||||
return qb.stashIDRepository().replace(ctx, studioID, stashIDs)
|
||||
}
|
||||
|
||||
func (qb *StudioStore) aliasRepository() *stringRepository {
|
||||
return &stringRepository{
|
||||
repository: repository{
|
||||
tx: qb.tx,
|
||||
tableName: studioAliasesTable,
|
||||
idColumn: studioIDColumn,
|
||||
},
|
||||
stringColumn: studioAliasColumn,
|
||||
}
|
||||
return studiosStashIDsTableMgr.get(ctx, studioID)
|
||||
}
|
||||
|
||||
func (qb *StudioStore) GetAliases(ctx context.Context, studioID int) ([]string, error) {
|
||||
return qb.aliasRepository().get(ctx, studioID)
|
||||
}
|
||||
|
||||
func (qb *StudioStore) UpdateAliases(ctx context.Context, studioID int, aliases []string) error {
|
||||
return qb.aliasRepository().replace(ctx, studioID, aliases)
|
||||
return studiosAliasesTableMgr.get(ctx, studioID)
|
||||
}
|
||||
|
||||
@@ -219,18 +219,15 @@ func TestStudioQueryForAutoTag(t *testing.T) {
|
||||
assert.Len(t, studios, 1)
|
||||
assert.Equal(t, strings.ToLower(studioNames[studioIdxWithMovie]), strings.ToLower(studios[0].Name))
|
||||
|
||||
// find by alias
|
||||
name = getStudioStringValue(studioIdxWithMovie, "Alias")
|
||||
studios, err = tqb.QueryForAutoTag(ctx, []string{name})
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Error finding studios: %s", err.Error())
|
||||
}
|
||||
|
||||
if assert.Len(t, studios, 1) {
|
||||
assert.Equal(t, studioIDs[studioIdxWithMovie], studios[0].ID)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
@@ -363,11 +360,12 @@ func TestStudioUpdateClearParent(t *testing.T) {
|
||||
sqb := db.Studio
|
||||
|
||||
// clear the parent id from the child
|
||||
updatePartial := models.StudioPartial{
|
||||
input := models.StudioPartial{
|
||||
ID: createdChild.ID,
|
||||
ParentID: models.NewOptionalIntPtr(nil),
|
||||
}
|
||||
|
||||
updatedStudio, err := sqb.UpdatePartial(ctx, createdChild.ID, updatePartial)
|
||||
updatedStudio, err := sqb.UpdatePartial(ctx, input)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error updated studio: %s", err.Error())
|
||||
@@ -548,7 +546,7 @@ func verifyStudiosGalleryCount(t *testing.T, galleryCountCriterion models.IntCri
|
||||
}
|
||||
|
||||
func TestStudioStashIDs(t *testing.T) {
|
||||
if err := withTxn(func(ctx context.Context) error {
|
||||
if err := withRollbackTxn(func(ctx context.Context) error {
|
||||
qb := db.Studio
|
||||
|
||||
// create studio to test against
|
||||
@@ -558,13 +556,83 @@ func TestStudioStashIDs(t *testing.T) {
|
||||
return fmt.Errorf("Error creating studio: %s", err.Error())
|
||||
}
|
||||
|
||||
testStashIDReaderWriter(ctx, t, qb, created.ID)
|
||||
studio, err := qb.Find(ctx, created.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error getting studio: %s", err.Error())
|
||||
}
|
||||
|
||||
if err := studio.LoadStashIDs(ctx, qb); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
testStudioStashIDs(ctx, t, studio)
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Error(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func testStudioStashIDs(ctx context.Context, t *testing.T, s *models.Studio) {
|
||||
qb := db.Studio
|
||||
|
||||
if err := s.LoadStashIDs(ctx, qb); err != nil {
|
||||
t.Error(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// ensure no stash IDs to begin with
|
||||
assert.Len(t, s.StashIDs.List(), 0)
|
||||
|
||||
// add stash ids
|
||||
const stashIDStr = "stashID"
|
||||
const endpoint = "endpoint"
|
||||
stashID := models.StashID{
|
||||
StashID: stashIDStr,
|
||||
Endpoint: endpoint,
|
||||
}
|
||||
|
||||
// update stash ids and ensure was updated
|
||||
input := models.StudioPartial{
|
||||
ID: s.ID,
|
||||
StashIDs: &models.UpdateStashIDs{
|
||||
StashIDs: []models.StashID{stashID},
|
||||
Mode: models.RelationshipUpdateModeSet,
|
||||
},
|
||||
}
|
||||
var err error
|
||||
s, err = qb.UpdatePartial(ctx, input)
|
||||
if err != nil {
|
||||
t.Error(err.Error())
|
||||
}
|
||||
|
||||
if err := s.LoadStashIDs(ctx, qb); err != nil {
|
||||
t.Error(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, []models.StashID{stashID}, s.StashIDs.List())
|
||||
|
||||
// remove stash ids and ensure was updated
|
||||
input = models.StudioPartial{
|
||||
ID: s.ID,
|
||||
StashIDs: &models.UpdateStashIDs{
|
||||
StashIDs: []models.StashID{stashID},
|
||||
Mode: models.RelationshipUpdateModeRemove,
|
||||
},
|
||||
}
|
||||
s, err = qb.UpdatePartial(ctx, input)
|
||||
if err != nil {
|
||||
t.Error(err.Error())
|
||||
}
|
||||
|
||||
if err := s.LoadStashIDs(ctx, qb); err != nil {
|
||||
t.Error(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
assert.Len(t, s.StashIDs.List(), 0)
|
||||
}
|
||||
|
||||
func TestStudioQueryURL(t *testing.T) {
|
||||
const sceneIdx = 1
|
||||
studioURL := getStudioStringValue(sceneIdx, urlField)
|
||||
@@ -684,7 +752,7 @@ func TestStudioQueryIsMissingRating(t *testing.T) {
|
||||
assert.True(t, len(studios) > 0)
|
||||
|
||||
for _, studio := range studios {
|
||||
assert.True(t, studio.Rating == nil)
|
||||
assert.Nil(t, studio.Rating)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -778,36 +846,87 @@ func TestStudioQueryAlias(t *testing.T) {
|
||||
verifyStudioQuery(t, studioFilter, verifyFn)
|
||||
}
|
||||
|
||||
func TestStudioUpdateAlias(t *testing.T) {
|
||||
if err := withTxn(func(ctx context.Context) error {
|
||||
func TestStudioAlias(t *testing.T) {
|
||||
if err := withRollbackTxn(func(ctx context.Context) error {
|
||||
qb := db.Studio
|
||||
|
||||
// create studio to test against
|
||||
const name = "TestStudioUpdateAlias"
|
||||
created, err := createStudio(ctx, qb, name, nil)
|
||||
const name = "TestStudioAlias"
|
||||
created, err := createStudio(ctx, db.Studio, name, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error creating studio: %s", err.Error())
|
||||
}
|
||||
|
||||
aliases := []string{"alias1", "alias2"}
|
||||
err = qb.UpdateAliases(ctx, created.ID, aliases)
|
||||
studio, err := qb.Find(ctx, created.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error updating studio aliases: %s", err.Error())
|
||||
return fmt.Errorf("Error getting studio: %s", err.Error())
|
||||
}
|
||||
|
||||
// ensure aliases set
|
||||
storedAliases, err := qb.GetAliases(ctx, created.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error getting aliases: %s", err.Error())
|
||||
if err := studio.LoadStashIDs(ctx, qb); err != nil {
|
||||
return err
|
||||
}
|
||||
assert.Equal(t, aliases, storedAliases)
|
||||
|
||||
testStudioAlias(ctx, t, studio)
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Error(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func testStudioAlias(ctx context.Context, t *testing.T, s *models.Studio) {
|
||||
qb := db.Studio
|
||||
if err := s.LoadAliases(ctx, qb); err != nil {
|
||||
t.Error(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// ensure no alias to begin with
|
||||
assert.Len(t, s.Aliases.List(), 0)
|
||||
|
||||
aliases := []string{"alias1", "alias2"}
|
||||
|
||||
// update alias and ensure was updated
|
||||
input := models.StudioPartial{
|
||||
ID: s.ID,
|
||||
Aliases: &models.UpdateStrings{
|
||||
Values: aliases,
|
||||
Mode: models.RelationshipUpdateModeSet,
|
||||
},
|
||||
}
|
||||
var err error
|
||||
s, err = qb.UpdatePartial(ctx, input)
|
||||
if err != nil {
|
||||
t.Error(err.Error())
|
||||
}
|
||||
|
||||
if err := s.LoadAliases(ctx, qb); err != nil {
|
||||
t.Error(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, aliases, s.Aliases.List())
|
||||
|
||||
// remove alias and ensure was updated
|
||||
input = models.StudioPartial{
|
||||
ID: s.ID,
|
||||
Aliases: &models.UpdateStrings{
|
||||
Values: aliases,
|
||||
Mode: models.RelationshipUpdateModeRemove,
|
||||
},
|
||||
}
|
||||
s, err = qb.UpdatePartial(ctx, input)
|
||||
if err != nil {
|
||||
t.Error(err.Error())
|
||||
}
|
||||
|
||||
if err := s.LoadAliases(ctx, qb); err != nil {
|
||||
t.Error(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
assert.Len(t, s.Aliases.List(), 0)
|
||||
}
|
||||
|
||||
// TestStudioQueryFast does a quick test for major errors, no result verification
|
||||
func TestStudioQueryFast(t *testing.T) {
|
||||
|
||||
|
||||
@@ -29,6 +29,9 @@ var (
|
||||
performersAliasesJoinTable = goqu.T(performersAliasesTable)
|
||||
performersTagsJoinTable = goqu.T(performersTagsTable)
|
||||
performersStashIDsJoinTable = goqu.T("performer_stash_ids")
|
||||
|
||||
studiosAliasesJoinTable = goqu.T(studioAliasesTable)
|
||||
studiosStashIDsJoinTable = goqu.T("studio_stash_ids")
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -233,6 +236,21 @@ var (
|
||||
table: goqu.T(studioTable),
|
||||
idColumn: goqu.T(studioTable).Col(idColumn),
|
||||
}
|
||||
|
||||
studiosAliasesTableMgr = &stringTable{
|
||||
table: table{
|
||||
table: studiosAliasesJoinTable,
|
||||
idColumn: studiosAliasesJoinTable.Col(studioIDColumn),
|
||||
},
|
||||
stringColumn: studiosAliasesJoinTable.Col(studioAliasColumn),
|
||||
}
|
||||
|
||||
studiosStashIDsTableMgr = &stashIDTable{
|
||||
table: table{
|
||||
table: studiosStashIDsJoinTable,
|
||||
idColumn: studiosStashIDsJoinTable.Col(studioIDColumn),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
|
||||
Reference in New Issue
Block a user