Feature: Tag StashID support (#6255)

This commit is contained in:
Gykes
2025-11-12 19:24:09 -08:00
committed by GitHub
parent a08d2e258a
commit c99825a453
30 changed files with 387 additions and 34 deletions

View File

@@ -45,7 +45,7 @@ func (r SceneRelationships) MatchRelationships(ctx context.Context, s *models.Sc
}
for _, t := range s.Tags {
err := ScrapedTag(ctx, r.TagFinder, t)
err := ScrapedTag(ctx, r.TagFinder, t, endpoint)
if err != nil {
return err
}
@@ -190,11 +190,29 @@ func ScrapedGroup(ctx context.Context, qb GroupNamesFinder, storedID *string, na
// ScrapedTag matches the provided tag with the tags
// in the database and sets the ID field if one is found.
func ScrapedTag(ctx context.Context, qb models.TagQueryer, s *models.ScrapedTag) error {
func ScrapedTag(ctx context.Context, qb models.TagQueryer, s *models.ScrapedTag, stashBoxEndpoint string) error {
if s.StoredID != nil {
return nil
}
// Check if a tag with the StashID already exists
if stashBoxEndpoint != "" && s.RemoteSiteID != nil {
if finder, ok := qb.(models.TagFinder); ok {
tags, err := finder.FindByStashID(ctx, models.StashID{
StashID: *s.RemoteSiteID,
Endpoint: stashBoxEndpoint,
})
if err != nil {
return err
}
if len(tags) > 0 {
id := strconv.Itoa(tags[0].ID)
s.StoredID = &id
return nil
}
}
}
t, err := tag.ByName(ctx, qb, s.Name)
if err != nil {

View File

@@ -6,20 +6,22 @@ import (
jsoniter "github.com/json-iterator/go"
"github.com/stashapp/stash/pkg/fsutil"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/json"
)
type Tag struct {
Name string `json:"name,omitempty"`
SortName string `json:"sort_name,omitempty"`
Description string `json:"description,omitempty"`
Favorite bool `json:"favorite,omitempty"`
Aliases []string `json:"aliases,omitempty"`
Image string `json:"image,omitempty"`
Parents []string `json:"parents,omitempty"`
IgnoreAutoTag bool `json:"ignore_auto_tag,omitempty"`
CreatedAt json.JSONTime `json:"created_at,omitempty"`
UpdatedAt json.JSONTime `json:"updated_at,omitempty"`
Name string `json:"name,omitempty"`
SortName string `json:"sort_name,omitempty"`
Description string `json:"description,omitempty"`
Favorite bool `json:"favorite,omitempty"`
Aliases []string `json:"aliases,omitempty"`
Image string `json:"image,omitempty"`
Parents []string `json:"parents,omitempty"`
IgnoreAutoTag bool `json:"ignore_auto_tag,omitempty"`
StashIDs []models.StashID `json:"stash_ids,omitempty"`
CreatedAt json.JSONTime `json:"created_at,omitempty"`
UpdatedAt json.JSONTime `json:"updated_at,omitempty"`
}
func (s Tag) Filename() string {

View File

@@ -427,6 +427,29 @@ func (_m *TagReaderWriter) FindBySceneMarkerID(ctx context.Context, sceneMarkerI
return r0, r1
}
// FindByStashID provides a mock function with given fields: ctx, stashID
//
// The call is recorded through _m.Called and the configured return values are
// unpacked. Each return slot may hold either a literal value or a function
// taking the same arguments, which is then invoked to compute the value.
func (_m *TagReaderWriter) FindByStashID(ctx context.Context, stashID models.StashID) ([]*models.Tag, error) {
ret := _m.Called(ctx, stashID)
var r0 []*models.Tag
// Slot 0: a provider function is called; otherwise a non-nil value is
// asserted to []*models.Tag, and a nil value leaves r0 nil.
if rf, ok := ret.Get(0).(func(context.Context, models.StashID) []*models.Tag); ok {
r0 = rf(ctx, stashID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Tag)
}
}
var r1 error
// Slot 1: a provider function is called; otherwise the slot is read as an error.
if rf, ok := ret.Get(1).(func(context.Context, models.StashID) error); ok {
r1 = rf(ctx, stashID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// FindByStudioID provides a mock function with given fields: ctx, studioID
func (_m *TagReaderWriter) FindByStudioID(ctx context.Context, studioID int) ([]*models.Tag, error) {
ret := _m.Called(ctx, studioID)
@@ -565,6 +588,29 @@ func (_m *TagReaderWriter) GetParentIDs(ctx context.Context, relatedID int) ([]i
return r0, r1
}
// GetStashIDs provides a mock function with given fields: ctx, relatedID
//
// The call is recorded through _m.Called and the configured return values are
// unpacked. Each return slot may hold either a literal value or a function
// taking the same arguments, which is then invoked to compute the value.
func (_m *TagReaderWriter) GetStashIDs(ctx context.Context, relatedID int) ([]models.StashID, error) {
ret := _m.Called(ctx, relatedID)
var r0 []models.StashID
// Slot 0: a provider function is called; otherwise a non-nil value is
// asserted to []models.StashID, and a nil value leaves r0 nil.
if rf, ok := ret.Get(0).(func(context.Context, int) []models.StashID); ok {
r0 = rf(ctx, relatedID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]models.StashID)
}
}
var r1 error
// Slot 1: a provider function is called; otherwise the slot is read as an error.
if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
r1 = rf(ctx, relatedID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// HasImage provides a mock function with given fields: ctx, tagID
func (_m *TagReaderWriter) HasImage(ctx context.Context, tagID int) (bool, error) {
ret := _m.Called(ctx, tagID)

View File

@@ -447,12 +447,31 @@ func (p *ScrapedPerformer) ToPartial(endpoint string, excluded map[string]bool,
type ScrapedTag struct {
// Set if tag matched
StoredID *string `json:"stored_id"`
Name string `json:"name"`
StoredID *string `json:"stored_id"`
Name string `json:"name"`
RemoteSiteID *string `json:"remote_site_id"`
}
func (ScrapedTag) IsScrapedContent() {}
// ToTag builds a new Tag from the scraped tag, copying its name.
//
// When the scraped tag carries a RemoteSiteID and endpoint is non-empty, the
// returned tag's StashIDs hold a single entry for that endpoint, stamped with
// the time captured at the start of the call. Otherwise StashIDs is left as
// NewTag produced it. The excluded argument is not consulted here.
func (t *ScrapedTag) ToTag(endpoint string, excluded map[string]bool) *Tag {
	now := time.Now()

	tag := NewTag()
	tag.Name = t.Name

	// Without both a remote ID and an endpoint there is nothing to link.
	if t.RemoteSiteID == nil || endpoint == "" {
		return &tag
	}

	link := StashID{
		Endpoint:  endpoint,
		StashID:   *t.RemoteSiteID,
		UpdatedAt: now,
	}
	tag.StashIDs = NewRelatedStashIDs([]StashID{link})

	return &tag
}
// ScrapedTagSortFunction orders two scraped tags by their lower-cased names.
// It returns a negative number, zero, or a positive number when a sorts
// before, equal to, or after b respectively.
func ScrapedTagSortFunction(a, b *ScrapedTag) int {
	left := strings.ToLower(a.Name)
	right := strings.ToLower(b.Name)
	return strings.Compare(left, right)
}

View File

@@ -15,9 +15,10 @@ type Tag struct {
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
Aliases RelatedStrings `json:"aliases"`
ParentIDs RelatedIDs `json:"parent_ids"`
ChildIDs RelatedIDs `json:"tag_ids"`
Aliases RelatedStrings `json:"aliases"`
ParentIDs RelatedIDs `json:"parent_ids"`
ChildIDs RelatedIDs `json:"tag_ids"`
StashIDs RelatedStashIDs `json:"stash_ids"`
}
func NewTag() Tag {
@@ -46,6 +47,12 @@ func (s *Tag) LoadChildIDs(ctx context.Context, l TagRelationLoader) error {
})
}
// LoadStashIDs populates s.StashIDs by handing its load method a callback
// that asks l for the stash IDs of s.ID. It returns whatever error load
// reports.
func (s *Tag) LoadStashIDs(ctx context.Context, l StashIDLoader) error {
	fetch := func() ([]StashID, error) {
		return l.GetStashIDs(ctx, s.ID)
	}
	return s.StashIDs.load(fetch)
}
type TagPartial struct {
Name OptionalString
SortName OptionalString
@@ -58,6 +65,7 @@ type TagPartial struct {
Aliases *UpdateStrings
ParentIDs *UpdateIDs
ChildIDs *UpdateIDs
StashIDs *UpdateStashIDs
}
func NewTagPartial() TagPartial {

View File

@@ -25,6 +25,7 @@ type TagFinder interface {
FindByStudioID(ctx context.Context, studioID int) ([]*Tag, error)
FindByName(ctx context.Context, name string, nocase bool) (*Tag, error)
FindByNames(ctx context.Context, names []string, nocase bool) ([]*Tag, error)
FindByStashID(ctx context.Context, stashID StashID) ([]*Tag, error)
}
// TagQueryer provides methods to query tags.
@@ -87,6 +88,7 @@ type TagReader interface {
AliasLoader
TagRelationLoader
StashIDLoader
All(ctx context.Context) ([]*Tag, error)
GetImage(ctx context.Context, tagID int) ([]byte, error)

View File

@@ -15,7 +15,8 @@ func postProcessTags(ctx context.Context, tqb models.TagQueryer, scrapedTags []*
ret = make([]*models.ScrapedTag, 0, len(scrapedTags))
for _, t := range scrapedTags {
err := match.ScrapedTag(ctx, tqb, t)
// Pass empty string for endpoint since this is used by general scrapers, not just stash-box
err := match.ScrapedTag(ctx, tqb, t, "")
if err != nil {
return nil, err
}

View File

@@ -102,6 +102,7 @@ func (db *Anonymiser) deleteStashIDs() error {
func() error { return db.truncateTable("scene_stash_ids") },
func() error { return db.truncateTable("studio_stash_ids") },
func() error { return db.truncateTable("performer_stash_ids") },
func() error { return db.truncateTable("tag_stash_ids") },
})
}

View File

@@ -34,7 +34,7 @@ const (
cacheSizeEnv = "STASH_SQLITE_CACHE_SIZE"
)
var appSchemaVersion uint = 73
var appSchemaVersion uint = 74
//go:embed migrations/*.sql
var migrationsBox embed.FS

View File

@@ -0,0 +1,7 @@
-- Links tags to stash IDs: each row pairs a tag with an endpoint and the
-- stash ID it has on that endpoint.
-- No primary key or unique constraint is declared here, so the schema itself
-- does not reject duplicate (tag_id, endpoint, stash_id) rows.
CREATE TABLE `tag_stash_ids` (
`tag_id` integer,
`endpoint` varchar(255),
`stash_id` varchar(36),
-- Rows inserted without an explicit timestamp default to the Unix epoch.
`updated_at` datetime not null default '1970-01-01T00:00:00Z',
-- ON DELETE CASCADE: removing the referenced tag row also removes its rows
-- here (when foreign key enforcement is enabled).
foreign key(`tag_id`) references `tags`(`id`) on delete CASCADE
);

View File

@@ -27,8 +27,9 @@ func testStashIDReaderWriter(ctx context.Context, t *testing.T, r stashIDReaderW
const stashIDStr = "stashID"
const endpoint = "endpoint"
stashID := models.StashID{
StashID: stashIDStr,
Endpoint: endpoint,
StashID: stashIDStr,
Endpoint: endpoint,
UpdatedAt: epochTime,
}
// update stash ids and ensure was updated

View File

@@ -47,6 +47,7 @@ var (
tagsAliasesJoinTable = goqu.T(tagAliasesTable)
tagRelationsJoinTable = goqu.T(tagRelationsTable)
tagsStashIDsJoinTable = goqu.T("tag_stash_ids")
)
var (
@@ -375,6 +376,13 @@ var (
}
tagsChildTagsTableMgr = *tagsParentTagsTableMgr.invert()
tagsStashIDsTableMgr = &stashIDTable{
table: table{
table: tagsStashIDsJoinTable,
idColumn: tagsStashIDsJoinTable.Col(tagIDColumn),
},
}
)
var (

View File

@@ -101,7 +101,8 @@ func (r *tagRowRecord) fromPartial(o models.TagPartial) {
type tagRepositoryType struct {
repository
aliases stringRepository
aliases stringRepository
stashIDs stashIDRepository
scenes joinRepository
images joinRepository
@@ -121,6 +122,12 @@ var (
},
stringColumn: tagAliasColumn,
},
stashIDs: stashIDRepository{
repository{
tableName: "tag_stash_ids",
idColumn: tagIDColumn,
},
},
scenes: joinRepository{
repository: repository{
tableName: scenesTagsTable,
@@ -199,6 +206,12 @@ func (qb *TagStore) Create(ctx context.Context, newObject *models.Tag) error {
}
}
if newObject.StashIDs.Loaded() {
if err := tagsStashIDsTableMgr.insertJoins(ctx, id, newObject.StashIDs.List()); err != nil {
return err
}
}
updated, err := qb.find(ctx, id)
if err != nil {
return fmt.Errorf("finding after create: %w", err)
@@ -242,6 +255,12 @@ func (qb *TagStore) UpdatePartial(ctx context.Context, id int, partial models.Ta
}
}
if partial.StashIDs != nil {
if err := tagsStashIDsTableMgr.modifyJoins(ctx, id, partial.StashIDs.StashIDs, partial.StashIDs.Mode); err != nil {
return nil, err
}
}
return qb.find(ctx, id)
}
@@ -271,6 +290,12 @@ func (qb *TagStore) Update(ctx context.Context, updatedObject *models.Tag) error
}
}
if updatedObject.StashIDs.Loaded() {
if err := tagsStashIDsTableMgr.replaceJoins(ctx, updatedObject.ID, updatedObject.StashIDs.List()); err != nil {
return err
}
}
return nil
}
@@ -509,6 +534,24 @@ func (qb *TagStore) FindByNames(ctx context.Context, names []string, nocase bool
return ret, nil
}
// FindByStashID returns the tags that have a stash ID row matching both the
// StashID and the Endpoint of the given value. The match is expressed as a
// subquery over the tag stash ID join table that selects tag IDs, which is
// then used to filter the main tag dataset.
func (qb *TagStore) FindByStashID(ctx context.Context, stashID models.StashID) ([]*models.Tag, error) {
	joinTable := tagsStashIDsJoinTable

	// Subquery: IDs of tags carrying this stash ID on this endpoint.
	matchingTagIDs := dialect.From(joinTable).
		Select(joinTable.Col(tagIDColumn)).
		Where(
			joinTable.Col("stash_id").Eq(stashID.StashID),
			joinTable.Col("endpoint").Eq(stashID.Endpoint),
		)

	tagIDCol := qb.table().Col(idColumn)
	query := qb.selectDataset().Where(tagIDCol.In(matchingTagIDs))

	tags, err := qb.getMany(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("getting tags for stash ID %s: %w", stashID.StashID, err)
	}

	return tags, nil
}
func (qb *TagStore) GetParentIDs(ctx context.Context, relatedID int) ([]int, error) {
return tagsParentTagsTableMgr.get(ctx, relatedID)
}
@@ -779,6 +822,14 @@ func (qb *TagStore) UpdateAliases(ctx context.Context, tagID int, aliases []stri
return tagRepository.aliases.replace(ctx, tagID, aliases)
}
// GetStashIDs returns the stash IDs stored for the tag with the given ID,
// delegating to the tag stash ID table manager.
func (qb *TagStore) GetStashIDs(ctx context.Context, tagID int) ([]models.StashID, error) {
return tagsStashIDsTableMgr.get(ctx, tagID)
}
// UpdateStashIDs replaces the stash IDs stored for the tag with the given ID
// by the provided list, delegating to the table manager's replaceJoins.
func (qb *TagStore) UpdateStashIDs(ctx context.Context, tagID int, stashIDs []models.StashID) error {
return tagsStashIDsTableMgr.replaceJoins(ctx, tagID, stashIDs)
}
func (qb *TagStore) Merge(ctx context.Context, source []int, destination int) error {
if len(source) == 0 {
return nil
@@ -840,6 +891,19 @@ AND NOT EXISTS(SELECT 1 FROM `+table+` o WHERE o.`+idColumn+` = `+table+`.`+idCo
return err
}
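// Merge StashIDs - move all source StashIDs to the destination tag.
// NOTE(review): OR IGNORE only skips rows that would violate a uniqueness constraint; the tag_stash_ids migration declares none, so duplicate endpoint/stash_id pairs are moved rather than ignored - confirm this is intended.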
_, err = dbWrapper.Exec(ctx, `UPDATE OR IGNORE `+"tag_stash_ids"+`
SET tag_id = ?
WHERE tag_id IN `+inBinding, args...)
if err != nil {
return err
}
// Delete remaining source StashIDs that couldn't be moved (duplicates)
if _, err := dbWrapper.Exec(ctx, `DELETE FROM tag_stash_ids WHERE tag_id IN `+inBinding, srcArgs...); err != nil {
return err
}
for _, id := range source {
err = qb.Destroy(ctx, id)
if err != nil {

View File

@@ -900,6 +900,66 @@ func TestTagUpdateAlias(t *testing.T) {
}
}
func TestTagStashIDs(t *testing.T) {
if err := withTxn(func(ctx context.Context) error {
qb := db.Tag
// create tag to test against
const name = "TestTagStashIDs"
tag := models.Tag{
Name: name,
}
err := qb.Create(ctx, &tag)
if err != nil {
return fmt.Errorf("Error creating tag: %s", err.Error())
}
testStashIDReaderWriter(ctx, t, qb, tag.ID)
return nil
}); err != nil {
t.Error(err.Error())
}
}
// TestTagFindByStashID verifies that FindByStashID returns the tag created
// with a given stash ID/endpoint pair, and returns no tags for a stash ID
// that was never stored.
func TestTagFindByStashID(t *testing.T) {
	// The result of withTxn must be checked: the callback reports setup and
	// query failures by returning an error, and dropping that value would let
	// the test pass without having exercised FindByStashID at all.
	if err := withTxn(func(ctx context.Context) error {
		qb := db.Tag

		// create tag to test against
		const name = "TestTagFindByStashID"
		const stashID = "stashid"
		const endpoint = "endpoint"
		tag := models.Tag{
			Name:     name,
			StashIDs: models.NewRelatedStashIDs([]models.StashID{{StashID: stashID, Endpoint: endpoint}}),
		}
		err := qb.Create(ctx, &tag)
		if err != nil {
			return fmt.Errorf("Error creating tag: %s", err.Error())
		}

		// find by stash ID
		tags, err := qb.FindByStashID(ctx, models.StashID{StashID: stashID, Endpoint: endpoint})
		if err != nil {
			return fmt.Errorf("Error finding by stash ID: %s", err.Error())
		}

		// assert.Len records a failure but does not stop the test, so only
		// index into the slice when the length check succeeded.
		if assert.Len(t, tags, 1) {
			assert.Equal(t, tag.ID, tags[0].ID)
		}

		// find by non-existent stash ID
		tags, err = qb.FindByStashID(ctx, models.StashID{StashID: "nonexistent", Endpoint: endpoint})
		if err != nil {
			return fmt.Errorf("Error finding by stash ID: %s", err.Error())
		}

		assert.Len(t, tags, 0)

		return nil
	}); err != nil {
		t.Error(err.Error())
	}
}
func TestTagMerge(t *testing.T) {
assert := assert.New(t)

View File

@@ -205,7 +205,8 @@ func (c Client) sceneFragmentToScrapedScene(ctx context.Context, s *graphql.Scen
for _, t := range s.Tags {
st := &models.ScrapedTag{
Name: t.Name,
Name: t.Name,
RemoteSiteID: &t.ID,
}
ss.Tags = append(ss.Tags, st)
}
@@ -242,8 +243,9 @@ type SceneDraft struct {
Performers []*models.Performer
// StashIDs must be loaded
Studio *models.Studio
Tags []*models.Tag
Cover []byte
// StashIDs must be loaded
Tags []*models.Tag
Cover []byte
}
func (c Client) SubmitSceneDraft(ctx context.Context, d SceneDraft) (*string, error) {
@@ -347,7 +349,17 @@ func newSceneDraftInput(d SceneDraft, endpoint string) graphql.SceneDraftInput {
var tags []*graphql.DraftEntityInput
sceneTags := d.Tags
for _, tag := range sceneTags {
tags = append(tags, &graphql.DraftEntityInput{Name: tag.Name})
tagDraft := graphql.DraftEntityInput{Name: tag.Name}
stashIDs := tag.StashIDs.List()
for _, stashID := range stashIDs {
if stashID.Endpoint == endpoint {
tagDraft.ID = &stashID.StashID
break
}
}
tags = append(tags, &tagDraft)
}
draft.Tags = tags

View File

@@ -16,6 +16,7 @@ type FinderAliasImageGetter interface {
GetAliases(ctx context.Context, studioID int) ([]string, error)
GetImage(ctx context.Context, tagID int) ([]byte, error)
FindByChildTagID(ctx context.Context, childID int) ([]*models.Tag, error)
models.StashIDLoader
}
// ToJSON converts a Tag object into its JSON equivalent.
@@ -37,6 +38,15 @@ func ToJSON(ctx context.Context, reader FinderAliasImageGetter, tag *models.Tag)
newTagJSON.Aliases = aliases
if err := tag.LoadStashIDs(ctx, reader); err != nil {
return nil, fmt.Errorf("loading tag stash ids: %w", err)
}
stashIDs := tag.StashIDs.List()
if len(stashIDs) > 0 {
newTagJSON.StashIDs = stashIDs
}
image, err := reader.GetImage(ctx, tag.ID)
if err != nil {
logger.Errorf("Error getting tag image: %v", err)

View File

@@ -126,6 +126,13 @@ func TestToJSON(t *testing.T) {
db.Tag.On("GetAliases", testCtx, withParentsID).Return(nil, nil).Once()
db.Tag.On("GetAliases", testCtx, errParentsID).Return(nil, nil).Once()
db.Tag.On("GetStashIDs", testCtx, tagID).Return(nil, nil).Once()
db.Tag.On("GetStashIDs", testCtx, noImageID).Return(nil, nil).Once()
db.Tag.On("GetStashIDs", testCtx, errImageID).Return(nil, nil).Once()
// errAliasID test fails before GetStashIDs is called, so no mock needed
db.Tag.On("GetStashIDs", testCtx, withParentsID).Return(nil, nil).Once()
db.Tag.On("GetStashIDs", testCtx, errParentsID).Return(nil, nil).Once()
db.Tag.On("GetImage", testCtx, tagID).Return(imageBytes, nil).Once()
db.Tag.On("GetImage", testCtx, noImageID).Return(nil, nil).Once()
db.Tag.On("GetImage", testCtx, errImageID).Return(nil, imageErr).Once()

View File

@@ -42,6 +42,7 @@ func (i *Importer) PreImport(ctx context.Context) error {
Description: i.Input.Description,
Favorite: i.Input.Favorite,
IgnoreAutoTag: i.Input.IgnoreAutoTag,
StashIDs: models.NewRelatedStashIDs(i.Input.StashIDs),
CreatedAt: i.Input.CreatedAt.GetTime(),
UpdatedAt: i.Input.UpdatedAt.GetTime(),
}