Refactor stashbox package (#5699)

* Move stashbox package under pkg
* Remove StashBox from method names
* Add fingerprint conversion methods to Fingerprint

Refactor Fingerprints methods

* Make FindSceneByFingerprints accept fingerprints not scene ids
* Refactor SubmitSceneDraft to not require readers
* Have SubmitFingerprints accept scenes

Remove SceneReader dependency

* Move ScrapedScene to models package
* Move ScrapedImage into models package
* Move ScrapedGallery into models package
* Move Scene relationship matching out of stashbox package

This is now expected to be done in the client code

* Remove TagFinder dependency from stashbox.Client
* Make stashbox scene find full hierarchy of studios
* Move studio resolution into separate method
* Move studio matching out of stashbox package

This is now client code responsibility

* Move performer matching out of FindPerformerByID and FindPerformerByName
* Refactor performer querying logic and remove unused stashbox models

Renames FindStashBoxPerformersByPerformerNames to QueryPerformers and accepts names instead of performer ids

* Refactor SubmitPerformerDraft to not load relationships

This will be the responsibility of the calling code

* Remove repository references
This commit is contained in:
WithoutPants
2025-03-25 10:30:51 +11:00
committed by GitHub
parent 5d3d02e1e7
commit db7d45792e
43 changed files with 1292 additions and 1163 deletions

View File

@@ -13,7 +13,6 @@ import (
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin/hook"
"github.com/stashapp/stash/pkg/scraper"
"github.com/stashapp/stash/pkg/scraper/stashbox"
)
var (
@@ -138,10 +137,6 @@ func (r *Resolver) withReadTxn(ctx context.Context, fn func(ctx context.Context)
return r.repository.WithReadTxn(ctx, fn)
}
func (r *Resolver) stashboxRepository() stashbox.Repository {
return stashbox.NewRepository(r.repository)
}
func (r *queryResolver) MarkerWall(ctx context.Context, q *string) (ret []*models.SceneMarker, err error) {
if err := r.withReadTxn(ctx, func(ctx context.Context) error {
ret, err = r.repository.SceneMarker.Wall(ctx, q)

View File

@@ -7,6 +7,10 @@ import (
"github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene"
"github.com/stashapp/stash/pkg/sliceutil/stringslice"
"github.com/stashapp/stash/pkg/stashbox"
)
func (r *mutationResolver) SubmitStashBoxFingerprints(ctx context.Context, input StashBoxFingerprintSubmissionInput) (bool, error) {
@@ -15,8 +19,23 @@ func (r *mutationResolver) SubmitStashBoxFingerprints(ctx context.Context, input
return false, err
}
ids, err := stringslice.StringSliceToIntSlice(input.SceneIds)
if err != nil {
return false, err
}
client := r.newStashBoxClient(*b)
return client.SubmitStashBoxFingerprints(ctx, input.SceneIds)
var scenes []*models.Scene
if err := r.withReadTxn(ctx, func(ctx context.Context) error {
scenes, err = r.sceneService.FindMany(ctx, ids, scene.LoadStashIDs, scene.LoadFiles)
return err
}); err != nil {
return false, err
}
return client.SubmitFingerprints(ctx, scenes)
}
func (r *mutationResolver) StashBoxBatchPerformerTag(ctx context.Context, input manager.StashBoxBatchTagInput) (string, error) {
@@ -69,17 +88,76 @@ func (r *mutationResolver) SubmitStashBoxSceneDraft(ctx context.Context, input S
logger.Errorf("Error getting scene cover: %v", err)
}
if err := scene.LoadURLs(ctx, r.repository.Scene); err != nil {
return fmt.Errorf("loading scene URLs: %w", err)
draft, err := r.makeSceneDraft(ctx, scene, cover)
if err != nil {
return err
}
res, err = client.SubmitSceneDraft(ctx, scene, cover)
res, err = client.SubmitSceneDraft(ctx, *draft)
return err
})
return res, err
}
func (r *mutationResolver) makeSceneDraft(ctx context.Context, s *models.Scene, cover []byte) (*stashbox.SceneDraft, error) {
if err := s.LoadURLs(ctx, r.repository.Scene); err != nil {
return nil, fmt.Errorf("loading scene URLs: %w", err)
}
if err := s.LoadStashIDs(ctx, r.repository.Scene); err != nil {
return nil, err
}
draft := &stashbox.SceneDraft{
Scene: s,
}
pqb := r.repository.Performer
sqb := r.repository.Studio
if s.StudioID != nil {
var err error
draft.Studio, err = sqb.Find(ctx, *s.StudioID)
if err != nil {
return nil, err
}
if draft.Studio == nil {
return nil, fmt.Errorf("studio with id %d not found", *s.StudioID)
}
if err := draft.Studio.LoadStashIDs(ctx, r.repository.Studio); err != nil {
return nil, err
}
}
// submit all file fingerprints
if err := s.LoadFiles(ctx, r.repository.Scene); err != nil {
return nil, err
}
scenePerformers, err := pqb.FindBySceneID(ctx, s.ID)
if err != nil {
return nil, err
}
for _, p := range scenePerformers {
if err := p.LoadStashIDs(ctx, pqb); err != nil {
return nil, err
}
}
draft.Performers = scenePerformers
draft.Tags, err = r.repository.Tag.FindBySceneID(ctx, s.ID)
if err != nil {
return nil, err
}
draft.Cover = cover
return draft, nil
}
func (r *mutationResolver) SubmitStashBoxPerformerDraft(ctx context.Context, input StashBoxDraftSubmissionInput) (*string, error) {
b, err := resolveStashBox(input.StashBoxIndex, input.StashBoxEndpoint)
if err != nil {
@@ -105,7 +183,22 @@ func (r *mutationResolver) SubmitStashBoxPerformerDraft(ctx context.Context, inp
return fmt.Errorf("performer with id %d not found", id)
}
res, err = client.SubmitPerformerDraft(ctx, performer)
pqb := r.repository.Performer
if err := performer.LoadAliases(ctx, pqb); err != nil {
return err
}
if err := performer.LoadURLs(ctx, pqb); err != nil {
return err
}
if err := performer.LoadStashIDs(ctx, pqb); err != nil {
return err
}
img, _ := pqb.GetImage(ctx, performer.ID)
res, err = client.SubmitPerformerDraft(ctx, performer, img)
return err
})

View File

@@ -6,9 +6,10 @@ import (
"fmt"
"strconv"
"github.com/stashapp/stash/pkg/match"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scraper"
"github.com/stashapp/stash/pkg/scraper/stashbox"
"github.com/stashapp/stash/pkg/sliceutil"
"github.com/stashapp/stash/pkg/sliceutil/stringslice"
)
@@ -29,7 +30,7 @@ func (r *queryResolver) ScrapePerformerURL(ctx context.Context, url string) (*mo
return marshalScrapedPerformer(content)
}
func (r *queryResolver) ScrapeSceneQuery(ctx context.Context, scraperID string, query string) ([]*scraper.ScrapedScene, error) {
func (r *queryResolver) ScrapeSceneQuery(ctx context.Context, scraperID string, query string) ([]*models.ScrapedScene, error) {
if query == "" {
return nil, nil
}
@@ -47,7 +48,7 @@ func (r *queryResolver) ScrapeSceneQuery(ctx context.Context, scraperID string,
return ret, nil
}
func (r *queryResolver) ScrapeSceneURL(ctx context.Context, url string) (*scraper.ScrapedScene, error) {
func (r *queryResolver) ScrapeSceneURL(ctx context.Context, url string) (*models.ScrapedScene, error) {
content, err := r.scraperCache().ScrapeURL(ctx, url, scraper.ScrapeContentTypeScene)
if err != nil {
return nil, err
@@ -61,7 +62,7 @@ func (r *queryResolver) ScrapeSceneURL(ctx context.Context, url string) (*scrape
return ret, nil
}
func (r *queryResolver) ScrapeGalleryURL(ctx context.Context, url string) (*scraper.ScrapedGallery, error) {
func (r *queryResolver) ScrapeGalleryURL(ctx context.Context, url string) (*models.ScrapedGallery, error) {
content, err := r.scraperCache().ScrapeURL(ctx, url, scraper.ScrapeContentTypeGallery)
if err != nil {
return nil, err
@@ -75,7 +76,7 @@ func (r *queryResolver) ScrapeGalleryURL(ctx context.Context, url string) (*scra
return ret, nil
}
func (r *queryResolver) ScrapeImageURL(ctx context.Context, url string) (*scraper.ScrapedImage, error) {
func (r *queryResolver) ScrapeImageURL(ctx context.Context, url string) (*models.ScrapedImage, error) {
content, err := r.scraperCache().ScrapeURL(ctx, url, scraper.ScrapeContentTypeImage)
if err != nil {
return nil, err
@@ -129,8 +130,8 @@ func (r *queryResolver) ScrapeGroupURL(ctx context.Context, url string) (*models
return group, nil
}
func (r *queryResolver) ScrapeSingleScene(ctx context.Context, source scraper.Source, input ScrapeSingleSceneInput) ([]*scraper.ScrapedScene, error) {
var ret []*scraper.ScrapedScene
func (r *queryResolver) ScrapeSingleScene(ctx context.Context, source scraper.Source, input ScrapeSingleSceneInput) ([]*models.ScrapedScene, error) {
var ret []*models.ScrapedScene
var sceneID int
if input.SceneID != nil {
@@ -182,9 +183,14 @@ func (r *queryResolver) ScrapeSingleScene(ctx context.Context, source scraper.So
switch {
case input.SceneID != nil:
ret, err = client.FindStashBoxSceneByFingerprints(ctx, sceneID)
var fps []models.Fingerprints
fps, err = r.getScenesFingerprints(ctx, []int{sceneID})
if err != nil {
return nil, err
}
ret, err = client.FindSceneByFingerprints(ctx, fps[0])
case input.Query != nil:
ret, err = client.QueryStashBoxScene(ctx, *input.Query)
ret, err = client.QueryScene(ctx, *input.Query)
default:
return nil, fmt.Errorf("%w: scene_id or query must be set", ErrInput)
}
@@ -192,6 +198,11 @@ func (r *queryResolver) ScrapeSingleScene(ctx context.Context, source scraper.So
if err != nil {
return nil, err
}
// TODO - this should happen after any scene is scraped
if err := r.matchScenesRelationships(ctx, ret, *source.StashBoxEndpoint); err != nil {
return nil, err
}
default:
return nil, fmt.Errorf("%w: scraper_id or stash_box_index must be set", ErrInput)
}
@@ -199,7 +210,7 @@ func (r *queryResolver) ScrapeSingleScene(ctx context.Context, source scraper.So
return ret, nil
}
func (r *queryResolver) ScrapeMultiScenes(ctx context.Context, source scraper.Source, input ScrapeMultiScenesInput) ([][]*scraper.ScrapedScene, error) {
func (r *queryResolver) ScrapeMultiScenes(ctx context.Context, source scraper.Source, input ScrapeMultiScenesInput) ([][]*models.ScrapedScene, error) {
if source.ScraperID != nil {
return nil, ErrNotImplemented
} else if source.StashBoxIndex != nil || source.StashBoxEndpoint != nil {
@@ -215,12 +226,89 @@ func (r *queryResolver) ScrapeMultiScenes(ctx context.Context, source scraper.So
return nil, err
}
return client.FindStashBoxScenesByFingerprints(ctx, sceneIDs)
fps, err := r.getScenesFingerprints(ctx, sceneIDs)
if err != nil {
return nil, err
}
ret, err := client.FindScenesByFingerprints(ctx, fps)
if err != nil {
return nil, err
}
// match relationships - this mutates the existing scenes so we can
// just flatten the slice and pass it in
flat := sliceutil.Flatten(ret)
if err := r.matchScenesRelationships(ctx, flat, *source.StashBoxEndpoint); err != nil {
return nil, err
}
return ret, nil
}
return nil, errors.New("scraper_id or stash_box_index must be set")
}
func (r *queryResolver) getScenesFingerprints(ctx context.Context, ids []int) ([]models.Fingerprints, error) {
fingerprints := make([]models.Fingerprints, len(ids))
if err := r.withReadTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Scene
for i, sceneID := range ids {
scene, err := qb.Find(ctx, sceneID)
if err != nil {
return err
}
if scene == nil {
return fmt.Errorf("scene with id %d not found", sceneID)
}
if err := scene.LoadFiles(ctx, qb); err != nil {
return err
}
var sceneFPs models.Fingerprints
for _, f := range scene.Files.List() {
sceneFPs = append(sceneFPs, f.Fingerprints...)
}
fingerprints[i] = sceneFPs
}
return nil
}); err != nil {
return nil, err
}
return fingerprints, nil
}
// matchSceneRelationships accepts scraped scenes and attempts to match its relationships to existing stash models.
func (r *queryResolver) matchScenesRelationships(ctx context.Context, ss []*models.ScrapedScene, endpoint string) error {
if err := r.withReadTxn(ctx, func(ctx context.Context) error {
matcher := match.SceneRelationships{
PerformerFinder: r.repository.Performer,
TagFinder: r.repository.Tag,
StudioFinder: r.repository.Studio,
}
for _, s := range ss {
if err := matcher.MatchRelationships(ctx, s, endpoint); err != nil {
return err
}
}
return nil
}); err != nil {
return err
}
return nil
}
func (r *queryResolver) ScrapeSingleStudio(ctx context.Context, source scraper.Source, input ScrapeSingleStudioInput) ([]*models.ScrapedStudio, error) {
if source.StashBoxIndex != nil || source.StashBoxEndpoint != nil {
b, err := resolveStashBox(source.StashBoxIndex, source.StashBoxEndpoint)
@@ -231,7 +319,7 @@ func (r *queryResolver) ScrapeSingleStudio(ctx context.Context, source scraper.S
client := r.newStashBoxClient(*b)
var ret []*models.ScrapedStudio
out, err := client.FindStashBoxStudio(ctx, *input.Query)
out, err := client.FindStudio(ctx, *input.Query)
if err != nil {
return nil, err
@@ -240,6 +328,17 @@ func (r *queryResolver) ScrapeSingleStudio(ctx context.Context, source scraper.S
}
if len(ret) > 0 {
if err := r.withReadTxn(ctx, func(ctx context.Context) error {
for _, studio := range ret {
if err := match.ScrapedStudioHierarchy(ctx, r.repository.Studio, studio, *source.StashBoxEndpoint); err != nil {
return err
}
}
return nil
}); err != nil {
return nil, err
}
return ret, nil
}
@@ -285,23 +384,29 @@ func (r *queryResolver) ScrapeSinglePerformer(ctx context.Context, source scrape
client := r.newStashBoxClient(*b)
var res []*stashbox.StashBoxPerformerQueryResult
var query string
switch {
case input.PerformerID != nil:
res, err = client.FindStashBoxPerformersByNames(ctx, []string{*input.PerformerID})
names, err := r.findPerformerNames(ctx, []string{*input.PerformerID})
if err != nil {
return nil, err
}
query = names[0]
case input.Query != nil:
res, err = client.QueryStashBoxPerformer(ctx, *input.Query)
query = *input.Query
default:
return nil, ErrNotImplemented
}
if query == "" {
return nil, nil
}
ret, err = client.QueryPerformer(ctx, query)
if err != nil {
return nil, err
}
if len(res) > 0 {
ret = res[0].Results
}
default:
return nil, errors.New("scraper_id or stash_box_index must be set")
}
@@ -313,6 +418,11 @@ func (r *queryResolver) ScrapeMultiPerformers(ctx context.Context, source scrape
if source.ScraperID != nil {
return nil, ErrNotImplemented
} else if source.StashBoxIndex != nil || source.StashBoxEndpoint != nil {
names, err := r.findPerformerNames(ctx, input.PerformerIds)
if err != nil {
return nil, err
}
b, err := resolveStashBox(source.StashBoxIndex, source.StashBoxEndpoint)
if err != nil {
return nil, err
@@ -320,14 +430,40 @@ func (r *queryResolver) ScrapeMultiPerformers(ctx context.Context, source scrape
client := r.newStashBoxClient(*b)
return client.FindStashBoxPerformersByPerformerNames(ctx, input.PerformerIds)
return client.QueryPerformers(ctx, names)
}
return nil, errors.New("scraper_id or stash_box_index must be set")
}
func (r *queryResolver) ScrapeSingleGallery(ctx context.Context, source scraper.Source, input ScrapeSingleGalleryInput) ([]*scraper.ScrapedGallery, error) {
var ret []*scraper.ScrapedGallery
func (r *queryResolver) findPerformerNames(ctx context.Context, performerIDs []string) ([]string, error) {
ids, err := stringslice.StringSliceToIntSlice(performerIDs)
if err != nil {
return nil, err
}
names := make([]string, len(ids))
if err := r.withReadTxn(ctx, func(ctx context.Context) error {
p, err := r.repository.Performer.FindMany(ctx, ids)
if err != nil {
return err
}
for i, pp := range p {
names[i] = pp.Name
}
return nil
}); err != nil {
return nil, err
}
return names, nil
}
func (r *queryResolver) ScrapeSingleGallery(ctx context.Context, source scraper.Source, input ScrapeSingleGalleryInput) ([]*models.ScrapedGallery, error) {
var ret []*models.ScrapedGallery
if source.StashBoxIndex != nil || source.StashBoxEndpoint != nil {
return nil, ErrNotSupported
@@ -369,7 +505,7 @@ func (r *queryResolver) ScrapeSingleGallery(ctx context.Context, source scraper.
return ret, nil
}
func (r *queryResolver) ScrapeSingleImage(ctx context.Context, source scraper.Source, input ScrapeSingleImageInput) ([]*scraper.ScrapedImage, error) {
func (r *queryResolver) ScrapeSingleImage(ctx context.Context, source scraper.Source, input ScrapeSingleImageInput) ([]*models.ScrapedImage, error) {
if source.StashBoxIndex != nil {
return nil, ErrNotSupported
}

View File

@@ -9,8 +9,8 @@ import (
// marshalScrapedScenes converts ScrapedContent into ScrapedScene. If conversion fails, an
// error is returned to the caller.
func marshalScrapedScenes(content []scraper.ScrapedContent) ([]*scraper.ScrapedScene, error) {
var ret []*scraper.ScrapedScene
func marshalScrapedScenes(content []scraper.ScrapedContent) ([]*models.ScrapedScene, error) {
var ret []*models.ScrapedScene
for _, c := range content {
if c == nil {
// graphql schema requires scenes to be non-nil
@@ -18,9 +18,9 @@ func marshalScrapedScenes(content []scraper.ScrapedContent) ([]*scraper.ScrapedS
}
switch s := c.(type) {
case *scraper.ScrapedScene:
case *models.ScrapedScene:
ret = append(ret, s)
case scraper.ScrapedScene:
case models.ScrapedScene:
ret = append(ret, &s)
default:
return nil, fmt.Errorf("%w: cannot turn ScrapedContent into ScrapedScene", models.ErrConversion)
@@ -55,8 +55,8 @@ func marshalScrapedPerformers(content []scraper.ScrapedContent) ([]*models.Scrap
// marshalScrapedGalleries converts ScrapedContent into ScrapedGallery. If
// conversion fails, an error is returned.
func marshalScrapedGalleries(content []scraper.ScrapedContent) ([]*scraper.ScrapedGallery, error) {
var ret []*scraper.ScrapedGallery
func marshalScrapedGalleries(content []scraper.ScrapedContent) ([]*models.ScrapedGallery, error) {
var ret []*models.ScrapedGallery
for _, c := range content {
if c == nil {
// graphql schema requires galleries to be non-nil
@@ -64,9 +64,9 @@ func marshalScrapedGalleries(content []scraper.ScrapedContent) ([]*scraper.Scrap
}
switch g := c.(type) {
case *scraper.ScrapedGallery:
case *models.ScrapedGallery:
ret = append(ret, g)
case scraper.ScrapedGallery:
case models.ScrapedGallery:
ret = append(ret, &g)
default:
return nil, fmt.Errorf("%w: cannot turn ScrapedContent into ScrapedGallery", models.ErrConversion)
@@ -76,8 +76,8 @@ func marshalScrapedGalleries(content []scraper.ScrapedContent) ([]*scraper.Scrap
return ret, nil
}
func marshalScrapedImages(content []scraper.ScrapedContent) ([]*scraper.ScrapedImage, error) {
var ret []*scraper.ScrapedImage
func marshalScrapedImages(content []scraper.ScrapedContent) ([]*models.ScrapedImage, error) {
var ret []*models.ScrapedImage
for _, c := range content {
if c == nil {
// graphql schema requires images to be non-nil
@@ -85,9 +85,9 @@ func marshalScrapedImages(content []scraper.ScrapedContent) ([]*scraper.ScrapedI
}
switch g := c.(type) {
case *scraper.ScrapedImage:
case *models.ScrapedImage:
ret = append(ret, g)
case scraper.ScrapedImage:
case models.ScrapedImage:
ret = append(ret, &g)
default:
return nil, fmt.Errorf("%w: cannot turn ScrapedContent into ScrapedImage", models.ErrConversion)
@@ -131,7 +131,7 @@ func marshalScrapedPerformer(content scraper.ScrapedContent) (*models.ScrapedPer
}
// marshalScrapedScene will marshal a single scraped scene
func marshalScrapedScene(content scraper.ScrapedContent) (*scraper.ScrapedScene, error) {
func marshalScrapedScene(content scraper.ScrapedContent) (*models.ScrapedScene, error) {
s, err := marshalScrapedScenes([]scraper.ScrapedContent{content})
if err != nil {
return nil, err
@@ -141,7 +141,7 @@ func marshalScrapedScene(content scraper.ScrapedContent) (*scraper.ScrapedScene,
}
// marshalScrapedGallery will marshal a single scraped gallery
func marshalScrapedGallery(content scraper.ScrapedContent) (*scraper.ScrapedGallery, error) {
func marshalScrapedGallery(content scraper.ScrapedContent) (*models.ScrapedGallery, error) {
g, err := marshalScrapedGalleries([]scraper.ScrapedContent{content})
if err != nil {
return nil, err
@@ -151,7 +151,7 @@ func marshalScrapedGallery(content scraper.ScrapedContent) (*scraper.ScrapedGall
}
// marshalScrapedImage will marshal a single scraped image
func marshalScrapedImage(content scraper.ScrapedContent) (*scraper.ScrapedImage, error) {
func marshalScrapedImage(content scraper.ScrapedContent) (*models.ScrapedImage, error) {
g, err := marshalScrapedImages([]scraper.ScrapedContent{content})
if err != nil {
return nil, err

View File

@@ -7,11 +7,11 @@ import (
"github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/internal/manager/config"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scraper/stashbox"
"github.com/stashapp/stash/pkg/stashbox"
)
func (r *Resolver) newStashBoxClient(box models.StashBox) *stashbox.Client {
return stashbox.NewClient(box, r.stashboxRepository(), manager.GetInstance().Config.GetScraperExcludeTagPatterns())
return stashbox.NewClient(box, manager.GetInstance().Config.GetScraperExcludeTagPatterns())
}
func resolveStashBoxFn(indexField, endpointField string) func(index *int, endpoint *string) (*models.StashBox, error) {