Restructure data layer (#2532)

* Add new txn manager interface
* Add txn management to sqlite
* Rename get to getByID
* Add contexts to repository methods
* Update query builders
* Add context to reader writer interfaces
* Use repository in resolver
* Tighten interfaces
* Tighten interfaces in dlna
* Tighten interfaces in match package
* Tighten interfaces in scraper package
* Tighten interfaces in scan code
* Tighten interfaces on autotag package
* Remove ReaderWriter usage
* Merge database package into sqlite
This commit is contained in:
WithoutPants
2022-05-19 17:49:32 +10:00
parent 7b5bd80515
commit 964b559309
244 changed files with 7377 additions and 6699 deletions

View File

@@ -11,6 +11,7 @@ import (
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin" "github.com/stashapp/stash/pkg/plugin"
"github.com/stashapp/stash/pkg/scraper" "github.com/stashapp/stash/pkg/scraper"
"github.com/stashapp/stash/pkg/txn"
) )
var ( var (
@@ -30,7 +31,9 @@ type hookExecutor interface {
} }
type Resolver struct { type Resolver struct {
txnManager models.TransactionManager txnManager txn.Manager
repository models.Repository
hookExecutor hookExecutor hookExecutor hookExecutor
} }
@@ -85,17 +88,13 @@ type studioResolver struct{ *Resolver }
type movieResolver struct{ *Resolver } type movieResolver struct{ *Resolver }
type tagResolver struct{ *Resolver } type tagResolver struct{ *Resolver }
func (r *Resolver) withTxn(ctx context.Context, fn func(r models.Repository) error) error { func (r *Resolver) withTxn(ctx context.Context, fn func(ctx context.Context) error) error {
return r.txnManager.WithTxn(ctx, fn) return txn.WithTxn(ctx, r.txnManager, fn)
}
func (r *Resolver) withReadTxn(ctx context.Context, fn func(r models.ReaderRepository) error) error {
return r.txnManager.WithReadTxn(ctx, fn)
} }
func (r *queryResolver) MarkerWall(ctx context.Context, q *string) (ret []*models.SceneMarker, err error) { func (r *queryResolver) MarkerWall(ctx context.Context, q *string) (ret []*models.SceneMarker, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.SceneMarker().Wall(q) ret, err = r.repository.SceneMarker.Wall(ctx, q)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -104,8 +103,8 @@ func (r *queryResolver) MarkerWall(ctx context.Context, q *string) (ret []*model
} }
func (r *queryResolver) SceneWall(ctx context.Context, q *string) (ret []*models.Scene, err error) { func (r *queryResolver) SceneWall(ctx context.Context, q *string) (ret []*models.Scene, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Scene().Wall(q) ret, err = r.repository.Scene.Wall(ctx, q)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -115,8 +114,8 @@ func (r *queryResolver) SceneWall(ctx context.Context, q *string) (ret []*models
} }
func (r *queryResolver) MarkerStrings(ctx context.Context, q *string, sort *string) (ret []*models.MarkerStringsResultType, err error) { func (r *queryResolver) MarkerStrings(ctx context.Context, q *string, sort *string) (ret []*models.MarkerStringsResultType, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.SceneMarker().GetMarkerStrings(q, sort) ret, err = r.repository.SceneMarker.GetMarkerStrings(ctx, q, sort)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -127,24 +126,25 @@ func (r *queryResolver) MarkerStrings(ctx context.Context, q *string, sort *stri
func (r *queryResolver) Stats(ctx context.Context) (*StatsResultType, error) { func (r *queryResolver) Stats(ctx context.Context) (*StatsResultType, error) {
var ret StatsResultType var ret StatsResultType
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
scenesQB := repo.Scene() repo := r.repository
imageQB := repo.Image() scenesQB := repo.Scene
galleryQB := repo.Gallery() imageQB := repo.Image
studiosQB := repo.Studio() galleryQB := repo.Gallery
performersQB := repo.Performer() studiosQB := repo.Studio
moviesQB := repo.Movie() performersQB := repo.Performer
tagsQB := repo.Tag() moviesQB := repo.Movie
scenesCount, _ := scenesQB.Count() tagsQB := repo.Tag
scenesSize, _ := scenesQB.Size() scenesCount, _ := scenesQB.Count(ctx)
scenesDuration, _ := scenesQB.Duration() scenesSize, _ := scenesQB.Size(ctx)
imageCount, _ := imageQB.Count() scenesDuration, _ := scenesQB.Duration(ctx)
imageSize, _ := imageQB.Size() imageCount, _ := imageQB.Count(ctx)
galleryCount, _ := galleryQB.Count() imageSize, _ := imageQB.Size(ctx)
performersCount, _ := performersQB.Count() galleryCount, _ := galleryQB.Count(ctx)
studiosCount, _ := studiosQB.Count() performersCount, _ := performersQB.Count(ctx)
moviesCount, _ := moviesQB.Count() studiosCount, _ := studiosQB.Count(ctx)
tagsCount, _ := tagsQB.Count() moviesCount, _ := moviesQB.Count(ctx)
tagsCount, _ := tagsQB.Count(ctx)
ret = StatsResultType{ ret = StatsResultType{
SceneCount: scenesCount, SceneCount: scenesCount,
@@ -202,15 +202,15 @@ func (r *queryResolver) SceneMarkerTags(ctx context.Context, scene_id string) ([
var keys []int var keys []int
tags := make(map[int]*SceneMarkerTag) tags := make(map[int]*SceneMarkerTag)
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
sceneMarkers, err := repo.SceneMarker().FindBySceneID(sceneID) sceneMarkers, err := r.repository.SceneMarker.FindBySceneID(ctx, sceneID)
if err != nil { if err != nil {
return err return err
} }
tqb := repo.Tag() tqb := r.repository.Tag
for _, sceneMarker := range sceneMarkers { for _, sceneMarker := range sceneMarkers {
markerPrimaryTag, err := tqb.Find(sceneMarker.PrimaryTagID) markerPrimaryTag, err := tqb.Find(ctx, sceneMarker.PrimaryTagID)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -24,12 +24,12 @@ func (r *galleryResolver) Title(ctx context.Context, obj *models.Gallery) (*stri
} }
func (r *galleryResolver) Images(ctx context.Context, obj *models.Gallery) (ret []*models.Image, err error) { func (r *galleryResolver) Images(ctx context.Context, obj *models.Gallery) (ret []*models.Image, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
var err error var err error
// #2376 - sort images by path // #2376 - sort images by path
// doing this via Query is really slow, so stick with FindByGalleryID // doing this via Query is really slow, so stick with FindByGalleryID
ret, err = repo.Image().FindByGalleryID(obj.ID) ret, err = r.repository.Image.FindByGalleryID(ctx, obj.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -43,9 +43,9 @@ func (r *galleryResolver) Images(ctx context.Context, obj *models.Gallery) (ret
} }
func (r *galleryResolver) Cover(ctx context.Context, obj *models.Gallery) (ret *models.Image, err error) { func (r *galleryResolver) Cover(ctx context.Context, obj *models.Gallery) (ret *models.Image, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
// doing this via Query is really slow, so stick with FindByGalleryID // doing this via Query is really slow, so stick with FindByGalleryID
imgs, err := repo.Image().FindByGalleryID(obj.ID) imgs, err := r.repository.Image.FindByGalleryID(ctx, obj.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -100,9 +100,9 @@ func (r *galleryResolver) Rating(ctx context.Context, obj *models.Gallery) (*int
} }
func (r *galleryResolver) Scenes(ctx context.Context, obj *models.Gallery) (ret []*models.Scene, err error) { func (r *galleryResolver) Scenes(ctx context.Context, obj *models.Gallery) (ret []*models.Scene, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
var err error var err error
ret, err = repo.Scene().FindByGalleryID(obj.ID) ret, err = r.repository.Scene.FindByGalleryID(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -116,9 +116,9 @@ func (r *galleryResolver) Studio(ctx context.Context, obj *models.Gallery) (ret
return nil, nil return nil, nil
} }
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
var err error var err error
ret, err = repo.Studio().Find(int(obj.StudioID.Int64)) ret, err = r.repository.Studio.Find(ctx, int(obj.StudioID.Int64))
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -128,9 +128,9 @@ func (r *galleryResolver) Studio(ctx context.Context, obj *models.Gallery) (ret
} }
func (r *galleryResolver) Tags(ctx context.Context, obj *models.Gallery) (ret []*models.Tag, err error) { func (r *galleryResolver) Tags(ctx context.Context, obj *models.Gallery) (ret []*models.Tag, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
var err error var err error
ret, err = repo.Tag().FindByGalleryID(obj.ID) ret, err = r.repository.Tag.FindByGalleryID(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -140,9 +140,9 @@ func (r *galleryResolver) Tags(ctx context.Context, obj *models.Gallery) (ret []
} }
func (r *galleryResolver) Performers(ctx context.Context, obj *models.Gallery) (ret []*models.Performer, err error) { func (r *galleryResolver) Performers(ctx context.Context, obj *models.Gallery) (ret []*models.Performer, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
var err error var err error
ret, err = repo.Performer().FindByGalleryID(obj.ID) ret, err = r.repository.Performer.FindByGalleryID(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -152,9 +152,9 @@ func (r *galleryResolver) Performers(ctx context.Context, obj *models.Gallery) (
} }
func (r *galleryResolver) ImageCount(ctx context.Context, obj *models.Gallery) (ret int, err error) { func (r *galleryResolver) ImageCount(ctx context.Context, obj *models.Gallery) (ret int, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
var err error var err error
ret, err = repo.Image().CountByGalleryID(obj.ID) ret, err = r.repository.Image.CountByGalleryID(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return 0, err return 0, err

View File

@@ -45,9 +45,9 @@ func (r *imageResolver) Paths(ctx context.Context, obj *models.Image) (*ImagePat
} }
func (r *imageResolver) Galleries(ctx context.Context, obj *models.Image) (ret []*models.Gallery, err error) { func (r *imageResolver) Galleries(ctx context.Context, obj *models.Image) (ret []*models.Gallery, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
var err error var err error
ret, err = repo.Gallery().FindByImageID(obj.ID) ret, err = r.repository.Gallery.FindByImageID(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -61,8 +61,8 @@ func (r *imageResolver) Studio(ctx context.Context, obj *models.Image) (ret *mod
return nil, nil return nil, nil
} }
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Studio().Find(int(obj.StudioID.Int64)) ret, err = r.repository.Studio.Find(ctx, int(obj.StudioID.Int64))
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -72,8 +72,8 @@ func (r *imageResolver) Studio(ctx context.Context, obj *models.Image) (ret *mod
} }
func (r *imageResolver) Tags(ctx context.Context, obj *models.Image) (ret []*models.Tag, err error) { func (r *imageResolver) Tags(ctx context.Context, obj *models.Image) (ret []*models.Tag, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Tag().FindByImageID(obj.ID) ret, err = r.repository.Tag.FindByImageID(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -83,8 +83,8 @@ func (r *imageResolver) Tags(ctx context.Context, obj *models.Image) (ret []*mod
} }
func (r *imageResolver) Performers(ctx context.Context, obj *models.Image) (ret []*models.Performer, err error) { func (r *imageResolver) Performers(ctx context.Context, obj *models.Image) (ret []*models.Performer, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Performer().FindByImageID(obj.ID) ret, err = r.repository.Performer.FindByImageID(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err

View File

@@ -56,8 +56,8 @@ func (r *movieResolver) Rating(ctx context.Context, obj *models.Movie) (*int, er
func (r *movieResolver) Studio(ctx context.Context, obj *models.Movie) (ret *models.Studio, err error) { func (r *movieResolver) Studio(ctx context.Context, obj *models.Movie) (ret *models.Studio, err error) {
if obj.StudioID.Valid { if obj.StudioID.Valid {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Studio().Find(int(obj.StudioID.Int64)) ret, err = r.repository.Studio.Find(ctx, int(obj.StudioID.Int64))
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -92,9 +92,9 @@ func (r *movieResolver) FrontImagePath(ctx context.Context, obj *models.Movie) (
func (r *movieResolver) BackImagePath(ctx context.Context, obj *models.Movie) (*string, error) { func (r *movieResolver) BackImagePath(ctx context.Context, obj *models.Movie) (*string, error) {
// don't return any thing if there is no back image // don't return any thing if there is no back image
var img []byte var img []byte
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
var err error var err error
img, err = repo.Movie().GetBackImage(obj.ID) img, err = r.repository.Movie.GetBackImage(ctx, obj.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -115,8 +115,8 @@ func (r *movieResolver) BackImagePath(ctx context.Context, obj *models.Movie) (*
func (r *movieResolver) SceneCount(ctx context.Context, obj *models.Movie) (ret *int, err error) { func (r *movieResolver) SceneCount(ctx context.Context, obj *models.Movie) (ret *int, err error) {
var res int var res int
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
res, err = repo.Scene().CountByMovieID(obj.ID) res, err = r.repository.Scene.CountByMovieID(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -126,9 +126,9 @@ func (r *movieResolver) SceneCount(ctx context.Context, obj *models.Movie) (ret
} }
func (r *movieResolver) Scenes(ctx context.Context, obj *models.Movie) (ret []*models.Scene, err error) { func (r *movieResolver) Scenes(ctx context.Context, obj *models.Movie) (ret []*models.Scene, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
var err error var err error
ret, err = repo.Scene().FindByMovieID(obj.ID) ret, err = r.repository.Scene.FindByMovieID(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err

View File

@@ -142,8 +142,8 @@ func (r *performerResolver) ImagePath(ctx context.Context, obj *models.Performer
} }
func (r *performerResolver) Tags(ctx context.Context, obj *models.Performer) (ret []*models.Tag, err error) { func (r *performerResolver) Tags(ctx context.Context, obj *models.Performer) (ret []*models.Tag, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Tag().FindByPerformerID(obj.ID) ret, err = r.repository.Tag.FindByPerformerID(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -154,8 +154,8 @@ func (r *performerResolver) Tags(ctx context.Context, obj *models.Performer) (re
func (r *performerResolver) SceneCount(ctx context.Context, obj *models.Performer) (ret *int, err error) { func (r *performerResolver) SceneCount(ctx context.Context, obj *models.Performer) (ret *int, err error) {
var res int var res int
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
res, err = repo.Scene().CountByPerformerID(obj.ID) res, err = r.repository.Scene.CountByPerformerID(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -166,8 +166,8 @@ func (r *performerResolver) SceneCount(ctx context.Context, obj *models.Performe
func (r *performerResolver) ImageCount(ctx context.Context, obj *models.Performer) (ret *int, err error) { func (r *performerResolver) ImageCount(ctx context.Context, obj *models.Performer) (ret *int, err error) {
var res int var res int
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
res, err = image.CountByPerformerID(repo.Image(), obj.ID) res, err = image.CountByPerformerID(ctx, r.repository.Image, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -178,8 +178,8 @@ func (r *performerResolver) ImageCount(ctx context.Context, obj *models.Performe
func (r *performerResolver) GalleryCount(ctx context.Context, obj *models.Performer) (ret *int, err error) { func (r *performerResolver) GalleryCount(ctx context.Context, obj *models.Performer) (ret *int, err error) {
var res int var res int
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
res, err = gallery.CountByPerformerID(repo.Gallery(), obj.ID) res, err = gallery.CountByPerformerID(ctx, r.repository.Gallery, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -189,8 +189,8 @@ func (r *performerResolver) GalleryCount(ctx context.Context, obj *models.Perfor
} }
func (r *performerResolver) Scenes(ctx context.Context, obj *models.Performer) (ret []*models.Scene, err error) { func (r *performerResolver) Scenes(ctx context.Context, obj *models.Performer) (ret []*models.Scene, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Scene().FindByPerformerID(obj.ID) ret, err = r.repository.Scene.FindByPerformerID(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -200,8 +200,8 @@ func (r *performerResolver) Scenes(ctx context.Context, obj *models.Performer) (
} }
func (r *performerResolver) StashIds(ctx context.Context, obj *models.Performer) (ret []*models.StashID, err error) { func (r *performerResolver) StashIds(ctx context.Context, obj *models.Performer) (ret []*models.StashID, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Performer().GetStashIDs(obj.ID) ret, err = r.repository.Performer.GetStashIDs(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -256,8 +256,8 @@ func (r *performerResolver) UpdatedAt(ctx context.Context, obj *models.Performer
} }
func (r *performerResolver) Movies(ctx context.Context, obj *models.Performer) (ret []*models.Movie, err error) { func (r *performerResolver) Movies(ctx context.Context, obj *models.Performer) (ret []*models.Movie, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Movie().FindByPerformerID(obj.ID) ret, err = r.repository.Movie.FindByPerformerID(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -268,8 +268,8 @@ func (r *performerResolver) Movies(ctx context.Context, obj *models.Performer) (
func (r *performerResolver) MovieCount(ctx context.Context, obj *models.Performer) (ret *int, err error) { func (r *performerResolver) MovieCount(ctx context.Context, obj *models.Performer) (ret *int, err error) {
var res int var res int
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
res, err = repo.Movie().CountByPerformerID(obj.ID) res, err = r.repository.Movie.CountByPerformerID(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err

View File

@@ -116,8 +116,8 @@ func (r *sceneResolver) Paths(ctx context.Context, obj *models.Scene) (*ScenePat
} }
func (r *sceneResolver) SceneMarkers(ctx context.Context, obj *models.Scene) (ret []*models.SceneMarker, err error) { func (r *sceneResolver) SceneMarkers(ctx context.Context, obj *models.Scene) (ret []*models.SceneMarker, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.SceneMarker().FindBySceneID(obj.ID) ret, err = r.repository.SceneMarker.FindBySceneID(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -127,8 +127,8 @@ func (r *sceneResolver) SceneMarkers(ctx context.Context, obj *models.Scene) (re
} }
func (r *sceneResolver) Captions(ctx context.Context, obj *models.Scene) (ret []*models.SceneCaption, err error) { func (r *sceneResolver) Captions(ctx context.Context, obj *models.Scene) (ret []*models.SceneCaption, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Scene().GetCaptions(obj.ID) ret, err = r.repository.Scene.GetCaptions(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -138,8 +138,8 @@ func (r *sceneResolver) Captions(ctx context.Context, obj *models.Scene) (ret []
} }
func (r *sceneResolver) Galleries(ctx context.Context, obj *models.Scene) (ret []*models.Gallery, err error) { func (r *sceneResolver) Galleries(ctx context.Context, obj *models.Scene) (ret []*models.Gallery, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Gallery().FindBySceneID(obj.ID) ret, err = r.repository.Gallery.FindBySceneID(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -153,8 +153,8 @@ func (r *sceneResolver) Studio(ctx context.Context, obj *models.Scene) (ret *mod
return nil, nil return nil, nil
} }
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Studio().Find(int(obj.StudioID.Int64)) ret, err = r.repository.Studio.Find(ctx, int(obj.StudioID.Int64))
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -164,17 +164,17 @@ func (r *sceneResolver) Studio(ctx context.Context, obj *models.Scene) (ret *mod
} }
func (r *sceneResolver) Movies(ctx context.Context, obj *models.Scene) (ret []*SceneMovie, err error) { func (r *sceneResolver) Movies(ctx context.Context, obj *models.Scene) (ret []*SceneMovie, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Scene() qb := r.repository.Scene
mqb := repo.Movie() mqb := r.repository.Movie
sceneMovies, err := qb.GetMovies(obj.ID) sceneMovies, err := qb.GetMovies(ctx, obj.ID)
if err != nil { if err != nil {
return err return err
} }
for _, sm := range sceneMovies { for _, sm := range sceneMovies {
movie, err := mqb.Find(sm.MovieID) movie, err := mqb.Find(ctx, sm.MovieID)
if err != nil { if err != nil {
return err return err
} }
@@ -200,8 +200,8 @@ func (r *sceneResolver) Movies(ctx context.Context, obj *models.Scene) (ret []*S
} }
func (r *sceneResolver) Tags(ctx context.Context, obj *models.Scene) (ret []*models.Tag, err error) { func (r *sceneResolver) Tags(ctx context.Context, obj *models.Scene) (ret []*models.Tag, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Tag().FindBySceneID(obj.ID) ret, err = r.repository.Tag.FindBySceneID(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -211,8 +211,8 @@ func (r *sceneResolver) Tags(ctx context.Context, obj *models.Scene) (ret []*mod
} }
func (r *sceneResolver) Performers(ctx context.Context, obj *models.Scene) (ret []*models.Performer, err error) { func (r *sceneResolver) Performers(ctx context.Context, obj *models.Scene) (ret []*models.Performer, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Performer().FindBySceneID(obj.ID) ret, err = r.repository.Performer.FindBySceneID(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -222,8 +222,8 @@ func (r *sceneResolver) Performers(ctx context.Context, obj *models.Scene) (ret
} }
func (r *sceneResolver) StashIds(ctx context.Context, obj *models.Scene) (ret []*models.StashID, err error) { func (r *sceneResolver) StashIds(ctx context.Context, obj *models.Scene) (ret []*models.StashID, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Scene().GetStashIDs(obj.ID) ret, err = r.repository.Scene.GetStashIDs(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err

View File

@@ -13,9 +13,9 @@ func (r *sceneMarkerResolver) Scene(ctx context.Context, obj *models.SceneMarker
panic("Invalid scene id") panic("Invalid scene id")
} }
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
sceneID := int(obj.SceneID.Int64) sceneID := int(obj.SceneID.Int64)
ret, err = repo.Scene().Find(sceneID) ret, err = r.repository.Scene.Find(ctx, sceneID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -25,8 +25,8 @@ func (r *sceneMarkerResolver) Scene(ctx context.Context, obj *models.SceneMarker
} }
func (r *sceneMarkerResolver) PrimaryTag(ctx context.Context, obj *models.SceneMarker) (ret *models.Tag, err error) { func (r *sceneMarkerResolver) PrimaryTag(ctx context.Context, obj *models.SceneMarker) (ret *models.Tag, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Tag().Find(obj.PrimaryTagID) ret, err = r.repository.Tag.Find(ctx, obj.PrimaryTagID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -36,8 +36,8 @@ func (r *sceneMarkerResolver) PrimaryTag(ctx context.Context, obj *models.SceneM
} }
func (r *sceneMarkerResolver) Tags(ctx context.Context, obj *models.SceneMarker) (ret []*models.Tag, err error) { func (r *sceneMarkerResolver) Tags(ctx context.Context, obj *models.SceneMarker) (ret []*models.Tag, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Tag().FindBySceneMarkerID(obj.ID) ret, err = r.repository.Tag.FindBySceneMarkerID(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err

View File

@@ -29,9 +29,9 @@ func (r *studioResolver) ImagePath(ctx context.Context, obj *models.Studio) (*st
imagePath := urlbuilders.NewStudioURLBuilder(baseURL, obj).GetStudioImageURL() imagePath := urlbuilders.NewStudioURLBuilder(baseURL, obj).GetStudioImageURL()
var hasImage bool var hasImage bool
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
var err error var err error
hasImage, err = repo.Studio().HasImage(obj.ID) hasImage, err = r.repository.Studio.HasImage(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -46,8 +46,8 @@ func (r *studioResolver) ImagePath(ctx context.Context, obj *models.Studio) (*st
} }
func (r *studioResolver) Aliases(ctx context.Context, obj *models.Studio) (ret []string, err error) { func (r *studioResolver) Aliases(ctx context.Context, obj *models.Studio) (ret []string, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Studio().GetAliases(obj.ID) ret, err = r.repository.Studio.GetAliases(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -58,8 +58,8 @@ func (r *studioResolver) Aliases(ctx context.Context, obj *models.Studio) (ret [
func (r *studioResolver) SceneCount(ctx context.Context, obj *models.Studio) (ret *int, err error) { func (r *studioResolver) SceneCount(ctx context.Context, obj *models.Studio) (ret *int, err error) {
var res int var res int
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
res, err = repo.Scene().CountByStudioID(obj.ID) res, err = r.repository.Scene.CountByStudioID(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -70,8 +70,8 @@ func (r *studioResolver) SceneCount(ctx context.Context, obj *models.Studio) (re
func (r *studioResolver) ImageCount(ctx context.Context, obj *models.Studio) (ret *int, err error) { func (r *studioResolver) ImageCount(ctx context.Context, obj *models.Studio) (ret *int, err error) {
var res int var res int
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
res, err = image.CountByStudioID(repo.Image(), obj.ID) res, err = image.CountByStudioID(ctx, r.repository.Image, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -82,8 +82,8 @@ func (r *studioResolver) ImageCount(ctx context.Context, obj *models.Studio) (re
func (r *studioResolver) GalleryCount(ctx context.Context, obj *models.Studio) (ret *int, err error) { func (r *studioResolver) GalleryCount(ctx context.Context, obj *models.Studio) (ret *int, err error) {
var res int var res int
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
res, err = gallery.CountByStudioID(repo.Gallery(), obj.ID) res, err = gallery.CountByStudioID(ctx, r.repository.Gallery, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -97,8 +97,8 @@ func (r *studioResolver) ParentStudio(ctx context.Context, obj *models.Studio) (
return nil, nil return nil, nil
} }
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Studio().Find(int(obj.ParentID.Int64)) ret, err = r.repository.Studio.Find(ctx, int(obj.ParentID.Int64))
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -108,8 +108,8 @@ func (r *studioResolver) ParentStudio(ctx context.Context, obj *models.Studio) (
} }
func (r *studioResolver) ChildStudios(ctx context.Context, obj *models.Studio) (ret []*models.Studio, err error) { func (r *studioResolver) ChildStudios(ctx context.Context, obj *models.Studio) (ret []*models.Studio, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Studio().FindChildren(obj.ID) ret, err = r.repository.Studio.FindChildren(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -119,8 +119,8 @@ func (r *studioResolver) ChildStudios(ctx context.Context, obj *models.Studio) (
} }
func (r *studioResolver) StashIds(ctx context.Context, obj *models.Studio) (ret []*models.StashID, err error) { func (r *studioResolver) StashIds(ctx context.Context, obj *models.Studio) (ret []*models.StashID, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Studio().GetStashIDs(obj.ID) ret, err = r.repository.Studio.GetStashIDs(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -153,8 +153,8 @@ func (r *studioResolver) UpdatedAt(ctx context.Context, obj *models.Studio) (*ti
} }
func (r *studioResolver) Movies(ctx context.Context, obj *models.Studio) (ret []*models.Movie, err error) { func (r *studioResolver) Movies(ctx context.Context, obj *models.Studio) (ret []*models.Movie, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Movie().FindByStudioID(obj.ID) ret, err = r.repository.Movie.FindByStudioID(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -165,8 +165,8 @@ func (r *studioResolver) Movies(ctx context.Context, obj *models.Studio) (ret []
func (r *studioResolver) MovieCount(ctx context.Context, obj *models.Studio) (ret *int, err error) { func (r *studioResolver) MovieCount(ctx context.Context, obj *models.Studio) (ret *int, err error) {
var res int var res int
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
res, err = repo.Movie().CountByStudioID(obj.ID) res, err = r.repository.Movie.CountByStudioID(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err

View File

@@ -11,8 +11,8 @@ import (
) )
func (r *tagResolver) Parents(ctx context.Context, obj *models.Tag) (ret []*models.Tag, err error) { func (r *tagResolver) Parents(ctx context.Context, obj *models.Tag) (ret []*models.Tag, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Tag().FindByChildTagID(obj.ID) ret, err = r.repository.Tag.FindByChildTagID(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -22,8 +22,8 @@ func (r *tagResolver) Parents(ctx context.Context, obj *models.Tag) (ret []*mode
} }
func (r *tagResolver) Children(ctx context.Context, obj *models.Tag) (ret []*models.Tag, err error) { func (r *tagResolver) Children(ctx context.Context, obj *models.Tag) (ret []*models.Tag, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Tag().FindByParentTagID(obj.ID) ret, err = r.repository.Tag.FindByParentTagID(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -33,8 +33,8 @@ func (r *tagResolver) Children(ctx context.Context, obj *models.Tag) (ret []*mod
} }
func (r *tagResolver) Aliases(ctx context.Context, obj *models.Tag) (ret []string, err error) { func (r *tagResolver) Aliases(ctx context.Context, obj *models.Tag) (ret []string, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Tag().GetAliases(obj.ID) ret, err = r.repository.Tag.GetAliases(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -45,8 +45,8 @@ func (r *tagResolver) Aliases(ctx context.Context, obj *models.Tag) (ret []strin
func (r *tagResolver) SceneCount(ctx context.Context, obj *models.Tag) (ret *int, err error) { func (r *tagResolver) SceneCount(ctx context.Context, obj *models.Tag) (ret *int, err error) {
var count int var count int
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
count, err = repo.Scene().CountByTagID(obj.ID) count, err = r.repository.Scene.CountByTagID(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -57,8 +57,8 @@ func (r *tagResolver) SceneCount(ctx context.Context, obj *models.Tag) (ret *int
func (r *tagResolver) SceneMarkerCount(ctx context.Context, obj *models.Tag) (ret *int, err error) { func (r *tagResolver) SceneMarkerCount(ctx context.Context, obj *models.Tag) (ret *int, err error) {
var count int var count int
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
count, err = repo.SceneMarker().CountByTagID(obj.ID) count, err = r.repository.SceneMarker.CountByTagID(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -69,8 +69,8 @@ func (r *tagResolver) SceneMarkerCount(ctx context.Context, obj *models.Tag) (re
func (r *tagResolver) ImageCount(ctx context.Context, obj *models.Tag) (ret *int, err error) { func (r *tagResolver) ImageCount(ctx context.Context, obj *models.Tag) (ret *int, err error) {
var res int var res int
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
res, err = image.CountByTagID(repo.Image(), obj.ID) res, err = image.CountByTagID(ctx, r.repository.Image, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -81,8 +81,8 @@ func (r *tagResolver) ImageCount(ctx context.Context, obj *models.Tag) (ret *int
func (r *tagResolver) GalleryCount(ctx context.Context, obj *models.Tag) (ret *int, err error) { func (r *tagResolver) GalleryCount(ctx context.Context, obj *models.Tag) (ret *int, err error) {
var res int var res int
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
res, err = gallery.CountByTagID(repo.Gallery(), obj.ID) res, err = gallery.CountByTagID(ctx, r.repository.Gallery, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -93,8 +93,8 @@ func (r *tagResolver) GalleryCount(ctx context.Context, obj *models.Tag) (ret *i
func (r *tagResolver) PerformerCount(ctx context.Context, obj *models.Tag) (ret *int, err error) { func (r *tagResolver) PerformerCount(ctx context.Context, obj *models.Tag) (ret *int, err error) {
var count int var count int
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
count, err = repo.Performer().CountByTagID(obj.ID) count, err = r.repository.Performer.CountByTagID(ctx, obj.ID)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err

View File

@@ -132,7 +132,9 @@ func (r *mutationResolver) ConfigureGeneral(ctx context.Context, input ConfigGen
} }
// validate changing VideoFileNamingAlgorithm // validate changing VideoFileNamingAlgorithm
if err := manager.ValidateVideoFileNamingAlgorithm(r.txnManager, *input.VideoFileNamingAlgorithm); err != nil { if err := r.withTxn(context.TODO(), func(ctx context.Context) error {
return manager.ValidateVideoFileNamingAlgorithm(ctx, r.repository.Scene, *input.VideoFileNamingAlgorithm)
}); err != nil {
return makeConfigGeneralResult(), err return makeConfigGeneralResult(), err
} }

View File

@@ -11,6 +11,7 @@ import (
"github.com/stashapp/stash/internal/manager" "github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/pkg/file" "github.com/stashapp/stash/pkg/file"
"github.com/stashapp/stash/pkg/gallery"
"github.com/stashapp/stash/pkg/hash/md5" "github.com/stashapp/stash/pkg/hash/md5"
"github.com/stashapp/stash/pkg/image" "github.com/stashapp/stash/pkg/image"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
@@ -21,8 +22,8 @@ import (
) )
func (r *mutationResolver) getGallery(ctx context.Context, id int) (ret *models.Gallery, err error) { func (r *mutationResolver) getGallery(ctx context.Context, id int) (ret *models.Gallery, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Gallery().Find(id) ret, err = r.repository.Gallery.Find(ctx, id)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -80,26 +81,26 @@ func (r *mutationResolver) GalleryCreate(ctx context.Context, input GalleryCreat
// Start the transaction and save the gallery // Start the transaction and save the gallery
var gallery *models.Gallery var gallery *models.Gallery
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Gallery() qb := r.repository.Gallery
var err error var err error
gallery, err = qb.Create(newGallery) gallery, err = qb.Create(ctx, newGallery)
if err != nil { if err != nil {
return err return err
} }
// Save the performers // Save the performers
if err := r.updateGalleryPerformers(qb, gallery.ID, input.PerformerIds); err != nil { if err := r.updateGalleryPerformers(ctx, qb, gallery.ID, input.PerformerIds); err != nil {
return err return err
} }
// Save the tags // Save the tags
if err := r.updateGalleryTags(qb, gallery.ID, input.TagIds); err != nil { if err := r.updateGalleryTags(ctx, qb, gallery.ID, input.TagIds); err != nil {
return err return err
} }
// Save the scenes // Save the scenes
if err := r.updateGalleryScenes(qb, gallery.ID, input.SceneIds); err != nil { if err := r.updateGalleryScenes(ctx, qb, gallery.ID, input.SceneIds); err != nil {
return err return err
} }
@@ -112,28 +113,32 @@ func (r *mutationResolver) GalleryCreate(ctx context.Context, input GalleryCreat
return r.getGallery(ctx, gallery.ID) return r.getGallery(ctx, gallery.ID)
} }
func (r *mutationResolver) updateGalleryPerformers(qb models.GalleryReaderWriter, galleryID int, performerIDs []string) error { func (r *mutationResolver) updateGalleryPerformers(ctx context.Context, qb gallery.PerformerUpdater, galleryID int, performerIDs []string) error {
ids, err := stringslice.StringSliceToIntSlice(performerIDs) ids, err := stringslice.StringSliceToIntSlice(performerIDs)
if err != nil { if err != nil {
return err return err
} }
return qb.UpdatePerformers(galleryID, ids) return qb.UpdatePerformers(ctx, galleryID, ids)
} }
func (r *mutationResolver) updateGalleryTags(qb models.GalleryReaderWriter, galleryID int, tagIDs []string) error { func (r *mutationResolver) updateGalleryTags(ctx context.Context, qb gallery.TagUpdater, galleryID int, tagIDs []string) error {
ids, err := stringslice.StringSliceToIntSlice(tagIDs) ids, err := stringslice.StringSliceToIntSlice(tagIDs)
if err != nil { if err != nil {
return err return err
} }
return qb.UpdateTags(galleryID, ids) return qb.UpdateTags(ctx, galleryID, ids)
} }
func (r *mutationResolver) updateGalleryScenes(qb models.GalleryReaderWriter, galleryID int, sceneIDs []string) error { type GallerySceneUpdater interface {
UpdateScenes(ctx context.Context, galleryID int, sceneIDs []int) error
}
func (r *mutationResolver) updateGalleryScenes(ctx context.Context, qb GallerySceneUpdater, galleryID int, sceneIDs []string) error {
ids, err := stringslice.StringSliceToIntSlice(sceneIDs) ids, err := stringslice.StringSliceToIntSlice(sceneIDs)
if err != nil { if err != nil {
return err return err
} }
return qb.UpdateScenes(galleryID, ids) return qb.UpdateScenes(ctx, galleryID, ids)
} }
func (r *mutationResolver) GalleryUpdate(ctx context.Context, input models.GalleryUpdateInput) (ret *models.Gallery, err error) { func (r *mutationResolver) GalleryUpdate(ctx context.Context, input models.GalleryUpdateInput) (ret *models.Gallery, err error) {
@@ -142,8 +147,8 @@ func (r *mutationResolver) GalleryUpdate(ctx context.Context, input models.Galle
} }
// Start the transaction and save the gallery // Start the transaction and save the gallery
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.galleryUpdate(input, translator, repo) ret, err = r.galleryUpdate(ctx, input, translator)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -158,13 +163,13 @@ func (r *mutationResolver) GalleriesUpdate(ctx context.Context, input []*models.
inputMaps := getUpdateInputMaps(ctx) inputMaps := getUpdateInputMaps(ctx)
// Start the transaction and save the gallery // Start the transaction and save the gallery
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
for i, gallery := range input { for i, gallery := range input {
translator := changesetTranslator{ translator := changesetTranslator{
inputMap: inputMaps[i], inputMap: inputMaps[i],
} }
thisGallery, err := r.galleryUpdate(*gallery, translator, repo) thisGallery, err := r.galleryUpdate(ctx, *gallery, translator)
if err != nil { if err != nil {
return err return err
} }
@@ -196,8 +201,8 @@ func (r *mutationResolver) GalleriesUpdate(ctx context.Context, input []*models.
return newRet, nil return newRet, nil
} }
func (r *mutationResolver) galleryUpdate(input models.GalleryUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Gallery, error) { func (r *mutationResolver) galleryUpdate(ctx context.Context, input models.GalleryUpdateInput, translator changesetTranslator) (*models.Gallery, error) {
qb := repo.Gallery() qb := r.repository.Gallery
// Populate gallery from the input // Populate gallery from the input
galleryID, err := strconv.Atoi(input.ID) galleryID, err := strconv.Atoi(input.ID)
@@ -205,7 +210,7 @@ func (r *mutationResolver) galleryUpdate(input models.GalleryUpdateInput, transl
return nil, err return nil, err
} }
originalGallery, err := qb.Find(galleryID) originalGallery, err := qb.Find(ctx, galleryID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -244,28 +249,28 @@ func (r *mutationResolver) galleryUpdate(input models.GalleryUpdateInput, transl
// gallery scene is set from the scene only // gallery scene is set from the scene only
gallery, err := qb.UpdatePartial(updatedGallery) gallery, err := qb.UpdatePartial(ctx, updatedGallery)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Save the performers // Save the performers
if translator.hasField("performer_ids") { if translator.hasField("performer_ids") {
if err := r.updateGalleryPerformers(qb, galleryID, input.PerformerIds); err != nil { if err := r.updateGalleryPerformers(ctx, qb, galleryID, input.PerformerIds); err != nil {
return nil, err return nil, err
} }
} }
// Save the tags // Save the tags
if translator.hasField("tag_ids") { if translator.hasField("tag_ids") {
if err := r.updateGalleryTags(qb, galleryID, input.TagIds); err != nil { if err := r.updateGalleryTags(ctx, qb, galleryID, input.TagIds); err != nil {
return nil, err return nil, err
} }
} }
// Save the scenes // Save the scenes
if translator.hasField("scene_ids") { if translator.hasField("scene_ids") {
if err := r.updateGalleryScenes(qb, galleryID, input.SceneIds); err != nil { if err := r.updateGalleryScenes(ctx, qb, galleryID, input.SceneIds); err != nil {
return nil, err return nil, err
} }
} }
@@ -295,14 +300,14 @@ func (r *mutationResolver) BulkGalleryUpdate(ctx context.Context, input BulkGall
ret := []*models.Gallery{} ret := []*models.Gallery{}
// Start the transaction and save the galleries // Start the transaction and save the galleries
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Gallery() qb := r.repository.Gallery
for _, galleryIDStr := range input.Ids { for _, galleryIDStr := range input.Ids {
galleryID, _ := strconv.Atoi(galleryIDStr) galleryID, _ := strconv.Atoi(galleryIDStr)
updatedGallery.ID = galleryID updatedGallery.ID = galleryID
gallery, err := qb.UpdatePartial(updatedGallery) gallery, err := qb.UpdatePartial(ctx, updatedGallery)
if err != nil { if err != nil {
return err return err
} }
@@ -311,36 +316,36 @@ func (r *mutationResolver) BulkGalleryUpdate(ctx context.Context, input BulkGall
// Save the performers // Save the performers
if translator.hasField("performer_ids") { if translator.hasField("performer_ids") {
performerIDs, err := adjustGalleryPerformerIDs(qb, galleryID, *input.PerformerIds) performerIDs, err := adjustGalleryPerformerIDs(ctx, qb, galleryID, *input.PerformerIds)
if err != nil { if err != nil {
return err return err
} }
if err := qb.UpdatePerformers(galleryID, performerIDs); err != nil { if err := qb.UpdatePerformers(ctx, galleryID, performerIDs); err != nil {
return err return err
} }
} }
// Save the tags // Save the tags
if translator.hasField("tag_ids") { if translator.hasField("tag_ids") {
tagIDs, err := adjustGalleryTagIDs(qb, galleryID, *input.TagIds) tagIDs, err := adjustGalleryTagIDs(ctx, qb, galleryID, *input.TagIds)
if err != nil { if err != nil {
return err return err
} }
if err := qb.UpdateTags(galleryID, tagIDs); err != nil { if err := qb.UpdateTags(ctx, galleryID, tagIDs); err != nil {
return err return err
} }
} }
// Save the scenes // Save the scenes
if translator.hasField("scene_ids") { if translator.hasField("scene_ids") {
sceneIDs, err := adjustGallerySceneIDs(qb, galleryID, *input.SceneIds) sceneIDs, err := adjustGallerySceneIDs(ctx, qb, galleryID, *input.SceneIds)
if err != nil { if err != nil {
return err return err
} }
if err := qb.UpdateScenes(galleryID, sceneIDs); err != nil { if err := qb.UpdateScenes(ctx, galleryID, sceneIDs); err != nil {
return err return err
} }
} }
@@ -367,8 +372,20 @@ func (r *mutationResolver) BulkGalleryUpdate(ctx context.Context, input BulkGall
return newRet, nil return newRet, nil
} }
func adjustGalleryPerformerIDs(qb models.GalleryReader, galleryID int, ids BulkUpdateIds) (ret []int, err error) { type GalleryPerformerGetter interface {
ret, err = qb.GetPerformerIDs(galleryID) GetPerformerIDs(ctx context.Context, galleryID int) ([]int, error)
}
type GalleryTagGetter interface {
GetTagIDs(ctx context.Context, galleryID int) ([]int, error)
}
type GallerySceneGetter interface {
GetSceneIDs(ctx context.Context, galleryID int) ([]int, error)
}
func adjustGalleryPerformerIDs(ctx context.Context, qb GalleryPerformerGetter, galleryID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = qb.GetPerformerIDs(ctx, galleryID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -376,8 +393,8 @@ func adjustGalleryPerformerIDs(qb models.GalleryReader, galleryID int, ids BulkU
return adjustIDs(ret, ids), nil return adjustIDs(ret, ids), nil
} }
func adjustGalleryTagIDs(qb models.GalleryReader, galleryID int, ids BulkUpdateIds) (ret []int, err error) { func adjustGalleryTagIDs(ctx context.Context, qb GalleryTagGetter, galleryID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = qb.GetTagIDs(galleryID) ret, err = qb.GetTagIDs(ctx, galleryID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -385,8 +402,8 @@ func adjustGalleryTagIDs(qb models.GalleryReader, galleryID int, ids BulkUpdateI
return adjustIDs(ret, ids), nil return adjustIDs(ret, ids), nil
} }
func adjustGallerySceneIDs(qb models.GalleryReader, galleryID int, ids BulkUpdateIds) (ret []int, err error) { func adjustGallerySceneIDs(ctx context.Context, qb GallerySceneGetter, galleryID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = qb.GetSceneIDs(galleryID) ret, err = qb.GetSceneIDs(ctx, galleryID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -410,12 +427,12 @@ func (r *mutationResolver) GalleryDestroy(ctx context.Context, input models.Gall
deleteGenerated := utils.IsTrue(input.DeleteGenerated) deleteGenerated := utils.IsTrue(input.DeleteGenerated)
deleteFile := utils.IsTrue(input.DeleteFile) deleteFile := utils.IsTrue(input.DeleteFile)
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Gallery() qb := r.repository.Gallery
iqb := repo.Image() iqb := r.repository.Image
for _, id := range galleryIDs { for _, id := range galleryIDs {
gallery, err := qb.Find(id) gallery, err := qb.Find(ctx, id)
if err != nil { if err != nil {
return err return err
} }
@@ -428,13 +445,13 @@ func (r *mutationResolver) GalleryDestroy(ctx context.Context, input models.Gall
// if this is a zip-based gallery, delete the images as well first // if this is a zip-based gallery, delete the images as well first
if gallery.Zip { if gallery.Zip {
imgs, err := iqb.FindByGalleryID(id) imgs, err := iqb.FindByGalleryID(ctx, id)
if err != nil { if err != nil {
return err return err
} }
for _, img := range imgs { for _, img := range imgs {
if err := image.Destroy(img, iqb, fileDeleter, deleteGenerated, false); err != nil { if err := image.Destroy(ctx, img, iqb, fileDeleter, deleteGenerated, false); err != nil {
return err return err
} }
@@ -448,19 +465,19 @@ func (r *mutationResolver) GalleryDestroy(ctx context.Context, input models.Gall
} }
} else if deleteFile { } else if deleteFile {
// Delete image if it is only attached to this gallery // Delete image if it is only attached to this gallery
imgs, err := iqb.FindByGalleryID(id) imgs, err := iqb.FindByGalleryID(ctx, id)
if err != nil { if err != nil {
return err return err
} }
for _, img := range imgs { for _, img := range imgs {
imgGalleries, err := qb.FindByImageID(img.ID) imgGalleries, err := qb.FindByImageID(ctx, img.ID)
if err != nil { if err != nil {
return err return err
} }
if len(imgGalleries) == 1 { if len(imgGalleries) == 1 {
if err := image.Destroy(img, iqb, fileDeleter, deleteGenerated, deleteFile); err != nil { if err := image.Destroy(ctx, img, iqb, fileDeleter, deleteGenerated, deleteFile); err != nil {
return err return err
} }
@@ -472,7 +489,7 @@ func (r *mutationResolver) GalleryDestroy(ctx context.Context, input models.Gall
// don't do this with the file deleter // don't do this with the file deleter
} }
if err := qb.Destroy(id); err != nil { if err := qb.Destroy(ctx, id); err != nil {
return err return err
} }
} }
@@ -537,9 +554,9 @@ func (r *mutationResolver) AddGalleryImages(ctx context.Context, input GalleryAd
return false, err return false, err
} }
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Gallery() qb := r.repository.Gallery
gallery, err := qb.Find(galleryID) gallery, err := qb.Find(ctx, galleryID)
if err != nil { if err != nil {
return err return err
} }
@@ -552,13 +569,13 @@ func (r *mutationResolver) AddGalleryImages(ctx context.Context, input GalleryAd
return errors.New("cannot modify zip gallery images") return errors.New("cannot modify zip gallery images")
} }
newIDs, err := qb.GetImageIDs(galleryID) newIDs, err := qb.GetImageIDs(ctx, galleryID)
if err != nil { if err != nil {
return err return err
} }
newIDs = intslice.IntAppendUniques(newIDs, imageIDs) newIDs = intslice.IntAppendUniques(newIDs, imageIDs)
return qb.UpdateImages(galleryID, newIDs) return qb.UpdateImages(ctx, galleryID, newIDs)
}); err != nil { }); err != nil {
return false, err return false, err
} }
@@ -577,9 +594,9 @@ func (r *mutationResolver) RemoveGalleryImages(ctx context.Context, input Galler
return false, err return false, err
} }
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Gallery() qb := r.repository.Gallery
gallery, err := qb.Find(galleryID) gallery, err := qb.Find(ctx, galleryID)
if err != nil { if err != nil {
return err return err
} }
@@ -592,13 +609,13 @@ func (r *mutationResolver) RemoveGalleryImages(ctx context.Context, input Galler
return errors.New("cannot modify zip gallery images") return errors.New("cannot modify zip gallery images")
} }
newIDs, err := qb.GetImageIDs(galleryID) newIDs, err := qb.GetImageIDs(ctx, galleryID)
if err != nil { if err != nil {
return err return err
} }
newIDs = intslice.IntExclude(newIDs, imageIDs) newIDs = intslice.IntExclude(newIDs, imageIDs)
return qb.UpdateImages(galleryID, newIDs) return qb.UpdateImages(ctx, galleryID, newIDs)
}); err != nil { }); err != nil {
return false, err return false, err
} }

View File

@@ -16,8 +16,8 @@ import (
) )
func (r *mutationResolver) getImage(ctx context.Context, id int) (ret *models.Image, err error) { func (r *mutationResolver) getImage(ctx context.Context, id int) (ret *models.Image, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Image().Find(id) ret, err = r.repository.Image.Find(ctx, id)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -32,8 +32,8 @@ func (r *mutationResolver) ImageUpdate(ctx context.Context, input ImageUpdateInp
} }
// Start the transaction and save the image // Start the transaction and save the image
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.imageUpdate(input, translator, repo) ret, err = r.imageUpdate(ctx, input, translator)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -48,13 +48,13 @@ func (r *mutationResolver) ImagesUpdate(ctx context.Context, input []*ImageUpdat
inputMaps := getUpdateInputMaps(ctx) inputMaps := getUpdateInputMaps(ctx)
// Start the transaction and save the image // Start the transaction and save the image
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
for i, image := range input { for i, image := range input {
translator := changesetTranslator{ translator := changesetTranslator{
inputMap: inputMaps[i], inputMap: inputMaps[i],
} }
thisImage, err := r.imageUpdate(*image, translator, repo) thisImage, err := r.imageUpdate(ctx, *image, translator)
if err != nil { if err != nil {
return err return err
} }
@@ -86,7 +86,7 @@ func (r *mutationResolver) ImagesUpdate(ctx context.Context, input []*ImageUpdat
return newRet, nil return newRet, nil
} }
func (r *mutationResolver) imageUpdate(input ImageUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Image, error) { func (r *mutationResolver) imageUpdate(ctx context.Context, input ImageUpdateInput, translator changesetTranslator) (*models.Image, error) {
// Populate image from the input // Populate image from the input
imageID, err := strconv.Atoi(input.ID) imageID, err := strconv.Atoi(input.ID)
if err != nil { if err != nil {
@@ -104,28 +104,28 @@ func (r *mutationResolver) imageUpdate(input ImageUpdateInput, translator change
updatedImage.StudioID = translator.nullInt64FromString(input.StudioID, "studio_id") updatedImage.StudioID = translator.nullInt64FromString(input.StudioID, "studio_id")
updatedImage.Organized = input.Organized updatedImage.Organized = input.Organized
qb := repo.Image() qb := r.repository.Image
image, err := qb.Update(updatedImage) image, err := qb.Update(ctx, updatedImage)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if translator.hasField("gallery_ids") { if translator.hasField("gallery_ids") {
if err := r.updateImageGalleries(qb, imageID, input.GalleryIds); err != nil { if err := r.updateImageGalleries(ctx, imageID, input.GalleryIds); err != nil {
return nil, err return nil, err
} }
} }
// Save the performers // Save the performers
if translator.hasField("performer_ids") { if translator.hasField("performer_ids") {
if err := r.updateImagePerformers(qb, imageID, input.PerformerIds); err != nil { if err := r.updateImagePerformers(ctx, imageID, input.PerformerIds); err != nil {
return nil, err return nil, err
} }
} }
// Save the tags // Save the tags
if translator.hasField("tag_ids") { if translator.hasField("tag_ids") {
if err := r.updateImageTags(qb, imageID, input.TagIds); err != nil { if err := r.updateImageTags(ctx, imageID, input.TagIds); err != nil {
return nil, err return nil, err
} }
} }
@@ -133,28 +133,28 @@ func (r *mutationResolver) imageUpdate(input ImageUpdateInput, translator change
return image, nil return image, nil
} }
func (r *mutationResolver) updateImageGalleries(qb models.ImageReaderWriter, imageID int, galleryIDs []string) error { func (r *mutationResolver) updateImageGalleries(ctx context.Context, imageID int, galleryIDs []string) error {
ids, err := stringslice.StringSliceToIntSlice(galleryIDs) ids, err := stringslice.StringSliceToIntSlice(galleryIDs)
if err != nil { if err != nil {
return err return err
} }
return qb.UpdateGalleries(imageID, ids) return r.repository.Image.UpdateGalleries(ctx, imageID, ids)
} }
func (r *mutationResolver) updateImagePerformers(qb models.ImageReaderWriter, imageID int, performerIDs []string) error { func (r *mutationResolver) updateImagePerformers(ctx context.Context, imageID int, performerIDs []string) error {
ids, err := stringslice.StringSliceToIntSlice(performerIDs) ids, err := stringslice.StringSliceToIntSlice(performerIDs)
if err != nil { if err != nil {
return err return err
} }
return qb.UpdatePerformers(imageID, ids) return r.repository.Image.UpdatePerformers(ctx, imageID, ids)
} }
func (r *mutationResolver) updateImageTags(qb models.ImageReaderWriter, imageID int, tagsIDs []string) error { func (r *mutationResolver) updateImageTags(ctx context.Context, imageID int, tagsIDs []string) error {
ids, err := stringslice.StringSliceToIntSlice(tagsIDs) ids, err := stringslice.StringSliceToIntSlice(tagsIDs)
if err != nil { if err != nil {
return err return err
} }
return qb.UpdateTags(imageID, ids) return r.repository.Image.UpdateTags(ctx, imageID, ids)
} }
func (r *mutationResolver) BulkImageUpdate(ctx context.Context, input BulkImageUpdateInput) (ret []*models.Image, err error) { func (r *mutationResolver) BulkImageUpdate(ctx context.Context, input BulkImageUpdateInput) (ret []*models.Image, err error) {
@@ -180,13 +180,13 @@ func (r *mutationResolver) BulkImageUpdate(ctx context.Context, input BulkImageU
updatedImage.Organized = input.Organized updatedImage.Organized = input.Organized
// Start the transaction and save the image marker // Start the transaction and save the image marker
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Image() qb := r.repository.Image
for _, imageID := range imageIDs { for _, imageID := range imageIDs {
updatedImage.ID = imageID updatedImage.ID = imageID
image, err := qb.Update(updatedImage) image, err := qb.Update(ctx, updatedImage)
if err != nil { if err != nil {
return err return err
} }
@@ -195,36 +195,36 @@ func (r *mutationResolver) BulkImageUpdate(ctx context.Context, input BulkImageU
// Save the galleries // Save the galleries
if translator.hasField("gallery_ids") { if translator.hasField("gallery_ids") {
galleryIDs, err := adjustImageGalleryIDs(qb, imageID, *input.GalleryIds) galleryIDs, err := r.adjustImageGalleryIDs(ctx, imageID, *input.GalleryIds)
if err != nil { if err != nil {
return err return err
} }
if err := qb.UpdateGalleries(imageID, galleryIDs); err != nil { if err := qb.UpdateGalleries(ctx, imageID, galleryIDs); err != nil {
return err return err
} }
} }
// Save the performers // Save the performers
if translator.hasField("performer_ids") { if translator.hasField("performer_ids") {
performerIDs, err := adjustImagePerformerIDs(qb, imageID, *input.PerformerIds) performerIDs, err := r.adjustImagePerformerIDs(ctx, imageID, *input.PerformerIds)
if err != nil { if err != nil {
return err return err
} }
if err := qb.UpdatePerformers(imageID, performerIDs); err != nil { if err := qb.UpdatePerformers(ctx, imageID, performerIDs); err != nil {
return err return err
} }
} }
// Save the tags // Save the tags
if translator.hasField("tag_ids") { if translator.hasField("tag_ids") {
tagIDs, err := adjustImageTagIDs(qb, imageID, *input.TagIds) tagIDs, err := r.adjustImageTagIDs(ctx, imageID, *input.TagIds)
if err != nil { if err != nil {
return err return err
} }
if err := qb.UpdateTags(imageID, tagIDs); err != nil { if err := qb.UpdateTags(ctx, imageID, tagIDs); err != nil {
return err return err
} }
} }
@@ -251,8 +251,8 @@ func (r *mutationResolver) BulkImageUpdate(ctx context.Context, input BulkImageU
return newRet, nil return newRet, nil
} }
func adjustImageGalleryIDs(qb models.ImageReader, imageID int, ids BulkUpdateIds) (ret []int, err error) { func (r *mutationResolver) adjustImageGalleryIDs(ctx context.Context, imageID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = qb.GetGalleryIDs(imageID) ret, err = r.repository.Image.GetGalleryIDs(ctx, imageID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -260,8 +260,8 @@ func adjustImageGalleryIDs(qb models.ImageReader, imageID int, ids BulkUpdateIds
return adjustIDs(ret, ids), nil return adjustIDs(ret, ids), nil
} }
func adjustImagePerformerIDs(qb models.ImageReader, imageID int, ids BulkUpdateIds) (ret []int, err error) { func (r *mutationResolver) adjustImagePerformerIDs(ctx context.Context, imageID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = qb.GetPerformerIDs(imageID) ret, err = r.repository.Image.GetPerformerIDs(ctx, imageID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -269,8 +269,8 @@ func adjustImagePerformerIDs(qb models.ImageReader, imageID int, ids BulkUpdateI
return adjustIDs(ret, ids), nil return adjustIDs(ret, ids), nil
} }
func adjustImageTagIDs(qb models.ImageReader, imageID int, ids BulkUpdateIds) (ret []int, err error) { func (r *mutationResolver) adjustImageTagIDs(ctx context.Context, imageID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = qb.GetTagIDs(imageID) ret, err = r.repository.Image.GetTagIDs(ctx, imageID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -289,10 +289,10 @@ func (r *mutationResolver) ImageDestroy(ctx context.Context, input models.ImageD
Deleter: *file.NewDeleter(), Deleter: *file.NewDeleter(),
Paths: manager.GetInstance().Paths, Paths: manager.GetInstance().Paths,
} }
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Image() qb := r.repository.Image
i, err = qb.Find(imageID) i, err = r.repository.Image.Find(ctx, imageID)
if err != nil { if err != nil {
return err return err
} }
@@ -301,7 +301,7 @@ func (r *mutationResolver) ImageDestroy(ctx context.Context, input models.ImageD
return fmt.Errorf("image with id %d not found", imageID) return fmt.Errorf("image with id %d not found", imageID)
} }
return image.Destroy(i, qb, fileDeleter, utils.IsTrue(input.DeleteGenerated), utils.IsTrue(input.DeleteFile)) return image.Destroy(ctx, i, qb, fileDeleter, utils.IsTrue(input.DeleteGenerated), utils.IsTrue(input.DeleteFile))
}); err != nil { }); err != nil {
fileDeleter.Rollback() fileDeleter.Rollback()
return false, err return false, err
@@ -331,12 +331,12 @@ func (r *mutationResolver) ImagesDestroy(ctx context.Context, input models.Image
Deleter: *file.NewDeleter(), Deleter: *file.NewDeleter(),
Paths: manager.GetInstance().Paths, Paths: manager.GetInstance().Paths,
} }
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Image() qb := r.repository.Image
for _, imageID := range imageIDs { for _, imageID := range imageIDs {
i, err := qb.Find(imageID) i, err := qb.Find(ctx, imageID)
if err != nil { if err != nil {
return err return err
} }
@@ -347,7 +347,7 @@ func (r *mutationResolver) ImagesDestroy(ctx context.Context, input models.Image
images = append(images, i) images = append(images, i)
if err := image.Destroy(i, qb, fileDeleter, utils.IsTrue(input.DeleteGenerated), utils.IsTrue(input.DeleteFile)); err != nil { if err := image.Destroy(ctx, i, qb, fileDeleter, utils.IsTrue(input.DeleteGenerated), utils.IsTrue(input.DeleteFile)); err != nil {
return err return err
} }
} }
@@ -379,10 +379,10 @@ func (r *mutationResolver) ImageIncrementO(ctx context.Context, id string) (ret
return 0, err return 0, err
} }
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Image() qb := r.repository.Image
ret, err = qb.IncrementOCounter(imageID) ret, err = qb.IncrementOCounter(ctx, imageID)
return err return err
}); err != nil { }); err != nil {
return 0, err return 0, err
@@ -397,10 +397,10 @@ func (r *mutationResolver) ImageDecrementO(ctx context.Context, id string) (ret
return 0, err return 0, err
} }
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Image() qb := r.repository.Image
ret, err = qb.DecrementOCounter(imageID) ret, err = qb.DecrementOCounter(ctx, imageID)
return err return err
}); err != nil { }); err != nil {
return 0, err return 0, err
@@ -415,10 +415,10 @@ func (r *mutationResolver) ImageResetO(ctx context.Context, id string) (ret int,
return 0, err return 0, err
} }
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Image() qb := r.repository.Image
ret, err = qb.ResetOCounter(imageID) ret, err = qb.ResetOCounter(ctx, imageID)
return err return err
}); err != nil { }); err != nil {
return 0, err return 0, err

View File

@@ -12,7 +12,6 @@ import (
"github.com/stashapp/stash/internal/identify" "github.com/stashapp/stash/internal/identify"
"github.com/stashapp/stash/internal/manager" "github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/internal/manager/config" "github.com/stashapp/stash/internal/manager/config"
"github.com/stashapp/stash/pkg/database"
"github.com/stashapp/stash/pkg/fsutil" "github.com/stashapp/stash/pkg/fsutil"
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
) )
@@ -111,6 +110,7 @@ func (r *mutationResolver) BackupDatabase(ctx context.Context, input BackupDatab
// if download is true, then backup to temporary file and return a link // if download is true, then backup to temporary file and return a link
download := input.Download != nil && *input.Download download := input.Download != nil && *input.Download
mgr := manager.GetInstance() mgr := manager.GetInstance()
database := mgr.Database
var backupPath string var backupPath string
if download { if download {
if err := fsutil.EnsureDir(mgr.Paths.Generated.Downloads); err != nil { if err := fsutil.EnsureDir(mgr.Paths.Generated.Downloads); err != nil {
@@ -127,7 +127,7 @@ func (r *mutationResolver) BackupDatabase(ctx context.Context, input BackupDatab
backupPath = database.DatabaseBackupPath() backupPath = database.DatabaseBackupPath()
} }
err := database.Backup(database.DB, backupPath) err := database.Backup(backupPath)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -15,8 +15,8 @@ import (
) )
func (r *mutationResolver) getMovie(ctx context.Context, id int) (ret *models.Movie, err error) { func (r *mutationResolver) getMovie(ctx context.Context, id int) (ret *models.Movie, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Movie().Find(id) ret, err = r.repository.Movie.Find(ctx, id)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -100,16 +100,16 @@ func (r *mutationResolver) MovieCreate(ctx context.Context, input MovieCreateInp
// Start the transaction and save the movie // Start the transaction and save the movie
var movie *models.Movie var movie *models.Movie
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Movie() qb := r.repository.Movie
movie, err = qb.Create(newMovie) movie, err = qb.Create(ctx, newMovie)
if err != nil { if err != nil {
return err return err
} }
// update image table // update image table
if len(frontimageData) > 0 { if len(frontimageData) > 0 {
if err := qb.UpdateImages(movie.ID, frontimageData, backimageData); err != nil { if err := qb.UpdateImages(ctx, movie.ID, frontimageData, backimageData); err != nil {
return err return err
} }
} }
@@ -174,9 +174,9 @@ func (r *mutationResolver) MovieUpdate(ctx context.Context, input MovieUpdateInp
// Start the transaction and save the movie // Start the transaction and save the movie
var movie *models.Movie var movie *models.Movie
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Movie() qb := r.repository.Movie
movie, err = qb.Update(updatedMovie) movie, err = qb.Update(ctx, updatedMovie)
if err != nil { if err != nil {
return err return err
} }
@@ -184,13 +184,13 @@ func (r *mutationResolver) MovieUpdate(ctx context.Context, input MovieUpdateInp
// update image table // update image table
if frontImageIncluded || backImageIncluded { if frontImageIncluded || backImageIncluded {
if !frontImageIncluded { if !frontImageIncluded {
frontimageData, err = qb.GetFrontImage(updatedMovie.ID) frontimageData, err = qb.GetFrontImage(ctx, updatedMovie.ID)
if err != nil { if err != nil {
return err return err
} }
} }
if !backImageIncluded { if !backImageIncluded {
backimageData, err = qb.GetBackImage(updatedMovie.ID) backimageData, err = qb.GetBackImage(ctx, updatedMovie.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -198,7 +198,7 @@ func (r *mutationResolver) MovieUpdate(ctx context.Context, input MovieUpdateInp
if len(frontimageData) == 0 && len(backimageData) == 0 { if len(frontimageData) == 0 && len(backimageData) == 0 {
// both images are being nulled. Destroy them. // both images are being nulled. Destroy them.
if err := qb.DestroyImages(movie.ID); err != nil { if err := qb.DestroyImages(ctx, movie.ID); err != nil {
return err return err
} }
} else { } else {
@@ -208,7 +208,7 @@ func (r *mutationResolver) MovieUpdate(ctx context.Context, input MovieUpdateInp
frontimageData, _ = utils.ProcessImageInput(ctx, models.DefaultMovieImage) frontimageData, _ = utils.ProcessImageInput(ctx, models.DefaultMovieImage)
} }
if err := qb.UpdateImages(movie.ID, frontimageData, backimageData); err != nil { if err := qb.UpdateImages(ctx, movie.ID, frontimageData, backimageData); err != nil {
return err return err
} }
} }
@@ -245,13 +245,13 @@ func (r *mutationResolver) BulkMovieUpdate(ctx context.Context, input BulkMovieU
ret := []*models.Movie{} ret := []*models.Movie{}
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Movie() qb := r.repository.Movie
for _, movieID := range movieIDs { for _, movieID := range movieIDs {
updatedMovie.ID = movieID updatedMovie.ID = movieID
existing, err := qb.Find(movieID) existing, err := qb.Find(ctx, movieID)
if err != nil { if err != nil {
return err return err
} }
@@ -260,7 +260,7 @@ func (r *mutationResolver) BulkMovieUpdate(ctx context.Context, input BulkMovieU
return fmt.Errorf("movie with id %d not found", movieID) return fmt.Errorf("movie with id %d not found", movieID)
} }
movie, err := qb.Update(updatedMovie) movie, err := qb.Update(ctx, updatedMovie)
if err != nil { if err != nil {
return err return err
} }
@@ -294,8 +294,8 @@ func (r *mutationResolver) MovieDestroy(ctx context.Context, input MovieDestroyI
return false, err return false, err
} }
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
return repo.Movie().Destroy(id) return r.repository.Movie.Destroy(ctx, id)
}); err != nil { }); err != nil {
return false, err return false, err
} }
@@ -311,10 +311,10 @@ func (r *mutationResolver) MoviesDestroy(ctx context.Context, movieIDs []string)
return false, err return false, err
} }
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Movie() qb := r.repository.Movie
for _, id := range ids { for _, id := range ids {
if err := qb.Destroy(id); err != nil { if err := qb.Destroy(ctx, id); err != nil {
return err return err
} }
} }

View File

@@ -16,8 +16,8 @@ import (
) )
func (r *mutationResolver) getPerformer(ctx context.Context, id int) (ret *models.Performer, err error) { func (r *mutationResolver) getPerformer(ctx context.Context, id int) (ret *models.Performer, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Performer().Find(id) ret, err = r.repository.Performer.Find(ctx, id)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -129,23 +129,23 @@ func (r *mutationResolver) PerformerCreate(ctx context.Context, input PerformerC
// Start the transaction and save the performer // Start the transaction and save the performer
var performer *models.Performer var performer *models.Performer
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Performer() qb := r.repository.Performer
performer, err = qb.Create(newPerformer) performer, err = qb.Create(ctx, newPerformer)
if err != nil { if err != nil {
return err return err
} }
if len(input.TagIds) > 0 { if len(input.TagIds) > 0 {
if err := r.updatePerformerTags(qb, performer.ID, input.TagIds); err != nil { if err := r.updatePerformerTags(ctx, performer.ID, input.TagIds); err != nil {
return err return err
} }
} }
// update image table // update image table
if len(imageData) > 0 { if len(imageData) > 0 {
if err := qb.UpdateImage(performer.ID, imageData); err != nil { if err := qb.UpdateImage(ctx, performer.ID, imageData); err != nil {
return err return err
} }
} }
@@ -153,7 +153,7 @@ func (r *mutationResolver) PerformerCreate(ctx context.Context, input PerformerC
// Save the stash_ids // Save the stash_ids
if input.StashIds != nil { if input.StashIds != nil {
stashIDJoins := models.StashIDsFromInput(input.StashIds) stashIDJoins := models.StashIDsFromInput(input.StashIds)
if err := qb.UpdateStashIDs(performer.ID, stashIDJoins); err != nil { if err := qb.UpdateStashIDs(ctx, performer.ID, stashIDJoins); err != nil {
return err return err
} }
} }
@@ -230,11 +230,11 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input PerformerU
// Start the transaction and save the p // Start the transaction and save the p
var p *models.Performer var p *models.Performer
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Performer() qb := r.repository.Performer
// need to get existing performer // need to get existing performer
existing, err := qb.Find(updatedPerformer.ID) existing, err := qb.Find(ctx, updatedPerformer.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -249,26 +249,26 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input PerformerU
} }
} }
p, err = qb.Update(updatedPerformer) p, err = qb.Update(ctx, updatedPerformer)
if err != nil { if err != nil {
return err return err
} }
// Save the tags // Save the tags
if translator.hasField("tag_ids") { if translator.hasField("tag_ids") {
if err := r.updatePerformerTags(qb, p.ID, input.TagIds); err != nil { if err := r.updatePerformerTags(ctx, p.ID, input.TagIds); err != nil {
return err return err
} }
} }
// update image table // update image table
if len(imageData) > 0 { if len(imageData) > 0 {
if err := qb.UpdateImage(p.ID, imageData); err != nil { if err := qb.UpdateImage(ctx, p.ID, imageData); err != nil {
return err return err
} }
} else if imageIncluded { } else if imageIncluded {
// must be unsetting // must be unsetting
if err := qb.DestroyImage(p.ID); err != nil { if err := qb.DestroyImage(ctx, p.ID); err != nil {
return err return err
} }
} }
@@ -276,7 +276,7 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input PerformerU
// Save the stash_ids // Save the stash_ids
if translator.hasField("stash_ids") { if translator.hasField("stash_ids") {
stashIDJoins := models.StashIDsFromInput(input.StashIds) stashIDJoins := models.StashIDsFromInput(input.StashIds)
if err := qb.UpdateStashIDs(performerID, stashIDJoins); err != nil { if err := qb.UpdateStashIDs(ctx, performerID, stashIDJoins); err != nil {
return err return err
} }
} }
@@ -290,12 +290,12 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input PerformerU
return r.getPerformer(ctx, p.ID) return r.getPerformer(ctx, p.ID)
} }
func (r *mutationResolver) updatePerformerTags(qb models.PerformerReaderWriter, performerID int, tagsIDs []string) error { func (r *mutationResolver) updatePerformerTags(ctx context.Context, performerID int, tagsIDs []string) error {
ids, err := stringslice.StringSliceToIntSlice(tagsIDs) ids, err := stringslice.StringSliceToIntSlice(tagsIDs)
if err != nil { if err != nil {
return err return err
} }
return qb.UpdateTags(performerID, ids) return r.repository.Performer.UpdateTags(ctx, performerID, ids)
} }
func (r *mutationResolver) BulkPerformerUpdate(ctx context.Context, input BulkPerformerUpdateInput) ([]*models.Performer, error) { func (r *mutationResolver) BulkPerformerUpdate(ctx context.Context, input BulkPerformerUpdateInput) ([]*models.Performer, error) {
@@ -348,14 +348,14 @@ func (r *mutationResolver) BulkPerformerUpdate(ctx context.Context, input BulkPe
ret := []*models.Performer{} ret := []*models.Performer{}
// Start the transaction and save the scene marker // Start the transaction and save the scene marker
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Performer() qb := r.repository.Performer
for _, performerID := range performerIDs { for _, performerID := range performerIDs {
updatedPerformer.ID = performerID updatedPerformer.ID = performerID
// need to get existing performer // need to get existing performer
existing, err := qb.Find(performerID) existing, err := qb.Find(ctx, performerID)
if err != nil { if err != nil {
return err return err
} }
@@ -368,7 +368,7 @@ func (r *mutationResolver) BulkPerformerUpdate(ctx context.Context, input BulkPe
return err return err
} }
performer, err := qb.Update(updatedPerformer) performer, err := qb.Update(ctx, updatedPerformer)
if err != nil { if err != nil {
return err return err
} }
@@ -377,12 +377,12 @@ func (r *mutationResolver) BulkPerformerUpdate(ctx context.Context, input BulkPe
// Save the tags // Save the tags
if translator.hasField("tag_ids") { if translator.hasField("tag_ids") {
tagIDs, err := adjustTagIDs(qb, performerID, *input.TagIds) tagIDs, err := adjustTagIDs(ctx, qb, performerID, *input.TagIds)
if err != nil { if err != nil {
return err return err
} }
if err := qb.UpdateTags(performerID, tagIDs); err != nil { if err := qb.UpdateTags(ctx, performerID, tagIDs); err != nil {
return err return err
} }
} }
@@ -415,8 +415,8 @@ func (r *mutationResolver) PerformerDestroy(ctx context.Context, input Performer
return false, err return false, err
} }
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
return repo.Performer().Destroy(id) return r.repository.Performer.Destroy(ctx, id)
}); err != nil { }); err != nil {
return false, err return false, err
} }
@@ -432,10 +432,10 @@ func (r *mutationResolver) PerformersDestroy(ctx context.Context, performerIDs [
return false, err return false, err
} }
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Performer() qb := r.repository.Performer
for _, id := range ids { for _, id := range ids {
if err := qb.Destroy(id); err != nil { if err := qb.Destroy(ctx, id); err != nil {
return err return err
} }
} }

View File

@@ -23,17 +23,17 @@ func (r *mutationResolver) SaveFilter(ctx context.Context, input SaveFilterInput
id = &idv id = &idv
} }
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
f := models.SavedFilter{ f := models.SavedFilter{
Mode: input.Mode, Mode: input.Mode,
Name: input.Name, Name: input.Name,
Filter: input.Filter, Filter: input.Filter,
} }
if id == nil { if id == nil {
ret, err = repo.SavedFilter().Create(f) ret, err = r.repository.SavedFilter.Create(ctx, f)
} else { } else {
f.ID = *id f.ID = *id
ret, err = repo.SavedFilter().Update(f) ret, err = r.repository.SavedFilter.Update(ctx, f)
} }
return err return err
}); err != nil { }); err != nil {
@@ -48,8 +48,8 @@ func (r *mutationResolver) DestroySavedFilter(ctx context.Context, input Destroy
return false, err return false, err
} }
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
return repo.SavedFilter().Destroy(id) return r.repository.SavedFilter.Destroy(ctx, id)
}); err != nil { }); err != nil {
return false, err return false, err
} }
@@ -58,24 +58,24 @@ func (r *mutationResolver) DestroySavedFilter(ctx context.Context, input Destroy
} }
func (r *mutationResolver) SetDefaultFilter(ctx context.Context, input SetDefaultFilterInput) (bool, error) { func (r *mutationResolver) SetDefaultFilter(ctx context.Context, input SetDefaultFilterInput) (bool, error) {
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.SavedFilter() qb := r.repository.SavedFilter
if input.Filter == nil { if input.Filter == nil {
// clearing // clearing
def, err := qb.FindDefault(input.Mode) def, err := qb.FindDefault(ctx, input.Mode)
if err != nil { if err != nil {
return err return err
} }
if def != nil { if def != nil {
return qb.Destroy(def.ID) return qb.Destroy(ctx, def.ID)
} }
return nil return nil
} }
_, err := qb.SetDefault(models.SavedFilter{ _, err := qb.SetDefault(ctx, models.SavedFilter{
Mode: input.Mode, Mode: input.Mode,
Filter: *input.Filter, Filter: *input.Filter,
}) })

View File

@@ -19,8 +19,8 @@ import (
) )
func (r *mutationResolver) getScene(ctx context.Context, id int) (ret *models.Scene, err error) { func (r *mutationResolver) getScene(ctx context.Context, id int) (ret *models.Scene, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Scene().Find(id) ret, err = r.repository.Scene.Find(ctx, id)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -35,8 +35,8 @@ func (r *mutationResolver) SceneUpdate(ctx context.Context, input models.SceneUp
} }
// Start the transaction and save the scene // Start the transaction and save the scene
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = r.sceneUpdate(ctx, input, translator, repo) ret, err = r.sceneUpdate(ctx, input, translator)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -50,13 +50,13 @@ func (r *mutationResolver) ScenesUpdate(ctx context.Context, input []*models.Sce
inputMaps := getUpdateInputMaps(ctx) inputMaps := getUpdateInputMaps(ctx)
// Start the transaction and save the scene // Start the transaction and save the scene
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
for i, scene := range input { for i, scene := range input {
translator := changesetTranslator{ translator := changesetTranslator{
inputMap: inputMaps[i], inputMap: inputMaps[i],
} }
thisScene, err := r.sceneUpdate(ctx, *scene, translator, repo) thisScene, err := r.sceneUpdate(ctx, *scene, translator)
ret = append(ret, thisScene) ret = append(ret, thisScene)
if err != nil { if err != nil {
@@ -89,7 +89,7 @@ func (r *mutationResolver) ScenesUpdate(ctx context.Context, input []*models.Sce
return newRet, nil return newRet, nil
} }
func (r *mutationResolver) sceneUpdate(ctx context.Context, input models.SceneUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Scene, error) { func (r *mutationResolver) sceneUpdate(ctx context.Context, input models.SceneUpdateInput, translator changesetTranslator) (*models.Scene, error) {
// Populate scene from the input // Populate scene from the input
sceneID, err := strconv.Atoi(input.ID) sceneID, err := strconv.Atoi(input.ID)
if err != nil { if err != nil {
@@ -122,43 +122,43 @@ func (r *mutationResolver) sceneUpdate(ctx context.Context, input models.SceneUp
// update the cover after updating the scene // update the cover after updating the scene
} }
qb := repo.Scene() qb := r.repository.Scene
s, err := qb.Update(updatedScene) s, err := qb.Update(ctx, updatedScene)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// update cover table // update cover table
if len(coverImageData) > 0 { if len(coverImageData) > 0 {
if err := qb.UpdateCover(sceneID, coverImageData); err != nil { if err := qb.UpdateCover(ctx, sceneID, coverImageData); err != nil {
return nil, err return nil, err
} }
} }
// Save the performers // Save the performers
if translator.hasField("performer_ids") { if translator.hasField("performer_ids") {
if err := r.updateScenePerformers(qb, sceneID, input.PerformerIds); err != nil { if err := r.updateScenePerformers(ctx, sceneID, input.PerformerIds); err != nil {
return nil, err return nil, err
} }
} }
// Save the movies // Save the movies
if translator.hasField("movies") { if translator.hasField("movies") {
if err := r.updateSceneMovies(qb, sceneID, input.Movies); err != nil { if err := r.updateSceneMovies(ctx, sceneID, input.Movies); err != nil {
return nil, err return nil, err
} }
} }
// Save the tags // Save the tags
if translator.hasField("tag_ids") { if translator.hasField("tag_ids") {
if err := r.updateSceneTags(qb, sceneID, input.TagIds); err != nil { if err := r.updateSceneTags(ctx, sceneID, input.TagIds); err != nil {
return nil, err return nil, err
} }
} }
// Save the galleries // Save the galleries
if translator.hasField("gallery_ids") { if translator.hasField("gallery_ids") {
if err := r.updateSceneGalleries(qb, sceneID, input.GalleryIds); err != nil { if err := r.updateSceneGalleries(ctx, sceneID, input.GalleryIds); err != nil {
return nil, err return nil, err
} }
} }
@@ -166,7 +166,7 @@ func (r *mutationResolver) sceneUpdate(ctx context.Context, input models.SceneUp
// Save the stash_ids // Save the stash_ids
if translator.hasField("stash_ids") { if translator.hasField("stash_ids") {
stashIDJoins := models.StashIDsFromInput(input.StashIds) stashIDJoins := models.StashIDsFromInput(input.StashIds)
if err := qb.UpdateStashIDs(sceneID, stashIDJoins); err != nil { if err := qb.UpdateStashIDs(ctx, sceneID, stashIDJoins); err != nil {
return nil, err return nil, err
} }
} }
@@ -182,15 +182,15 @@ func (r *mutationResolver) sceneUpdate(ctx context.Context, input models.SceneUp
return s, nil return s, nil
} }
func (r *mutationResolver) updateScenePerformers(qb models.SceneReaderWriter, sceneID int, performerIDs []string) error { func (r *mutationResolver) updateScenePerformers(ctx context.Context, sceneID int, performerIDs []string) error {
ids, err := stringslice.StringSliceToIntSlice(performerIDs) ids, err := stringslice.StringSliceToIntSlice(performerIDs)
if err != nil { if err != nil {
return err return err
} }
return qb.UpdatePerformers(sceneID, ids) return r.repository.Scene.UpdatePerformers(ctx, sceneID, ids)
} }
func (r *mutationResolver) updateSceneMovies(qb models.SceneReaderWriter, sceneID int, movies []*models.SceneMovieInput) error { func (r *mutationResolver) updateSceneMovies(ctx context.Context, sceneID int, movies []*models.SceneMovieInput) error {
var movieJoins []models.MoviesScenes var movieJoins []models.MoviesScenes
for _, movie := range movies { for _, movie := range movies {
@@ -213,23 +213,23 @@ func (r *mutationResolver) updateSceneMovies(qb models.SceneReaderWriter, sceneI
movieJoins = append(movieJoins, movieJoin) movieJoins = append(movieJoins, movieJoin)
} }
return qb.UpdateMovies(sceneID, movieJoins) return r.repository.Scene.UpdateMovies(ctx, sceneID, movieJoins)
} }
func (r *mutationResolver) updateSceneTags(qb models.SceneReaderWriter, sceneID int, tagsIDs []string) error { func (r *mutationResolver) updateSceneTags(ctx context.Context, sceneID int, tagsIDs []string) error {
ids, err := stringslice.StringSliceToIntSlice(tagsIDs) ids, err := stringslice.StringSliceToIntSlice(tagsIDs)
if err != nil { if err != nil {
return err return err
} }
return qb.UpdateTags(sceneID, ids) return r.repository.Scene.UpdateTags(ctx, sceneID, ids)
} }
func (r *mutationResolver) updateSceneGalleries(qb models.SceneReaderWriter, sceneID int, galleryIDs []string) error { func (r *mutationResolver) updateSceneGalleries(ctx context.Context, sceneID int, galleryIDs []string) error {
ids, err := stringslice.StringSliceToIntSlice(galleryIDs) ids, err := stringslice.StringSliceToIntSlice(galleryIDs)
if err != nil { if err != nil {
return err return err
} }
return qb.UpdateGalleries(sceneID, ids) return r.repository.Scene.UpdateGalleries(ctx, sceneID, ids)
} }
func (r *mutationResolver) BulkSceneUpdate(ctx context.Context, input BulkSceneUpdateInput) ([]*models.Scene, error) { func (r *mutationResolver) BulkSceneUpdate(ctx context.Context, input BulkSceneUpdateInput) ([]*models.Scene, error) {
@@ -260,13 +260,13 @@ func (r *mutationResolver) BulkSceneUpdate(ctx context.Context, input BulkSceneU
ret := []*models.Scene{} ret := []*models.Scene{}
// Start the transaction and save the scene marker // Start the transaction and save the scene marker
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Scene() qb := r.repository.Scene
for _, sceneID := range sceneIDs { for _, sceneID := range sceneIDs {
updatedScene.ID = sceneID updatedScene.ID = sceneID
scene, err := qb.Update(updatedScene) scene, err := qb.Update(ctx, updatedScene)
if err != nil { if err != nil {
return err return err
} }
@@ -275,48 +275,48 @@ func (r *mutationResolver) BulkSceneUpdate(ctx context.Context, input BulkSceneU
// Save the performers // Save the performers
if translator.hasField("performer_ids") { if translator.hasField("performer_ids") {
performerIDs, err := adjustScenePerformerIDs(qb, sceneID, *input.PerformerIds) performerIDs, err := r.adjustScenePerformerIDs(ctx, sceneID, *input.PerformerIds)
if err != nil { if err != nil {
return err return err
} }
if err := qb.UpdatePerformers(sceneID, performerIDs); err != nil { if err := qb.UpdatePerformers(ctx, sceneID, performerIDs); err != nil {
return err return err
} }
} }
// Save the tags // Save the tags
if translator.hasField("tag_ids") { if translator.hasField("tag_ids") {
tagIDs, err := adjustTagIDs(qb, sceneID, *input.TagIds) tagIDs, err := adjustTagIDs(ctx, qb, sceneID, *input.TagIds)
if err != nil { if err != nil {
return err return err
} }
if err := qb.UpdateTags(sceneID, tagIDs); err != nil { if err := qb.UpdateTags(ctx, sceneID, tagIDs); err != nil {
return err return err
} }
} }
// Save the galleries // Save the galleries
if translator.hasField("gallery_ids") { if translator.hasField("gallery_ids") {
galleryIDs, err := adjustSceneGalleryIDs(qb, sceneID, *input.GalleryIds) galleryIDs, err := r.adjustSceneGalleryIDs(ctx, sceneID, *input.GalleryIds)
if err != nil { if err != nil {
return err return err
} }
if err := qb.UpdateGalleries(sceneID, galleryIDs); err != nil { if err := qb.UpdateGalleries(ctx, sceneID, galleryIDs); err != nil {
return err return err
} }
} }
// Save the movies // Save the movies
if translator.hasField("movie_ids") { if translator.hasField("movie_ids") {
movies, err := adjustSceneMovieIDs(qb, sceneID, *input.MovieIds) movies, err := r.adjustSceneMovieIDs(ctx, sceneID, *input.MovieIds)
if err != nil { if err != nil {
return err return err
} }
if err := qb.UpdateMovies(sceneID, movies); err != nil { if err := qb.UpdateMovies(ctx, sceneID, movies); err != nil {
return err return err
} }
} }
@@ -380,8 +380,8 @@ func adjustIDs(existingIDs []int, updateIDs BulkUpdateIds) []int {
return existingIDs return existingIDs
} }
func adjustScenePerformerIDs(qb models.SceneReader, sceneID int, ids BulkUpdateIds) (ret []int, err error) { func (r *mutationResolver) adjustScenePerformerIDs(ctx context.Context, sceneID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = qb.GetPerformerIDs(sceneID) ret, err = r.repository.Scene.GetPerformerIDs(ctx, sceneID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -390,11 +390,11 @@ func adjustScenePerformerIDs(qb models.SceneReader, sceneID int, ids BulkUpdateI
} }
type tagIDsGetter interface { type tagIDsGetter interface {
GetTagIDs(id int) ([]int, error) GetTagIDs(ctx context.Context, id int) ([]int, error)
} }
func adjustTagIDs(qb tagIDsGetter, sceneID int, ids BulkUpdateIds) (ret []int, err error) { func adjustTagIDs(ctx context.Context, qb tagIDsGetter, sceneID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = qb.GetTagIDs(sceneID) ret, err = qb.GetTagIDs(ctx, sceneID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -402,8 +402,8 @@ func adjustTagIDs(qb tagIDsGetter, sceneID int, ids BulkUpdateIds) (ret []int, e
return adjustIDs(ret, ids), nil return adjustIDs(ret, ids), nil
} }
func adjustSceneGalleryIDs(qb models.SceneReader, sceneID int, ids BulkUpdateIds) (ret []int, err error) { func (r *mutationResolver) adjustSceneGalleryIDs(ctx context.Context, sceneID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = qb.GetGalleryIDs(sceneID) ret, err = r.repository.Scene.GetGalleryIDs(ctx, sceneID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -411,8 +411,8 @@ func adjustSceneGalleryIDs(qb models.SceneReader, sceneID int, ids BulkUpdateIds
return adjustIDs(ret, ids), nil return adjustIDs(ret, ids), nil
} }
func adjustSceneMovieIDs(qb models.SceneReader, sceneID int, updateIDs BulkUpdateIds) ([]models.MoviesScenes, error) { func (r *mutationResolver) adjustSceneMovieIDs(ctx context.Context, sceneID int, updateIDs BulkUpdateIds) ([]models.MoviesScenes, error) {
existingMovies, err := qb.GetMovies(sceneID) existingMovies, err := r.repository.Scene.GetMovies(ctx, sceneID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -471,10 +471,10 @@ func (r *mutationResolver) SceneDestroy(ctx context.Context, input models.SceneD
deleteGenerated := utils.IsTrue(input.DeleteGenerated) deleteGenerated := utils.IsTrue(input.DeleteGenerated)
deleteFile := utils.IsTrue(input.DeleteFile) deleteFile := utils.IsTrue(input.DeleteFile)
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Scene() qb := r.repository.Scene
var err error var err error
s, err = qb.Find(sceneID) s, err = qb.Find(ctx, sceneID)
if err != nil { if err != nil {
return err return err
} }
@@ -486,7 +486,7 @@ func (r *mutationResolver) SceneDestroy(ctx context.Context, input models.SceneD
// kill any running encoders // kill any running encoders
manager.KillRunningStreams(s, fileNamingAlgo) manager.KillRunningStreams(s, fileNamingAlgo)
return scene.Destroy(s, repo, fileDeleter, deleteGenerated, deleteFile) return scene.Destroy(ctx, s, r.repository.Scene, r.repository.SceneMarker, fileDeleter, deleteGenerated, deleteFile)
}); err != nil { }); err != nil {
fileDeleter.Rollback() fileDeleter.Rollback()
return false, err return false, err
@@ -519,13 +519,13 @@ func (r *mutationResolver) ScenesDestroy(ctx context.Context, input models.Scene
deleteGenerated := utils.IsTrue(input.DeleteGenerated) deleteGenerated := utils.IsTrue(input.DeleteGenerated)
deleteFile := utils.IsTrue(input.DeleteFile) deleteFile := utils.IsTrue(input.DeleteFile)
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Scene() qb := r.repository.Scene
for _, id := range input.Ids { for _, id := range input.Ids {
sceneID, _ := strconv.Atoi(id) sceneID, _ := strconv.Atoi(id)
s, err := qb.Find(sceneID) s, err := qb.Find(ctx, sceneID)
if err != nil { if err != nil {
return err return err
} }
@@ -536,7 +536,7 @@ func (r *mutationResolver) ScenesDestroy(ctx context.Context, input models.Scene
// kill any running encoders // kill any running encoders
manager.KillRunningStreams(s, fileNamingAlgo) manager.KillRunningStreams(s, fileNamingAlgo)
if err := scene.Destroy(s, repo, fileDeleter, deleteGenerated, deleteFile); err != nil { if err := scene.Destroy(ctx, s, r.repository.Scene, r.repository.SceneMarker, fileDeleter, deleteGenerated, deleteFile); err != nil {
return err return err
} }
} }
@@ -564,8 +564,8 @@ func (r *mutationResolver) ScenesDestroy(ctx context.Context, input models.Scene
} }
func (r *mutationResolver) getSceneMarker(ctx context.Context, id int) (ret *models.SceneMarker, err error) { func (r *mutationResolver) getSceneMarker(ctx context.Context, id int) (ret *models.SceneMarker, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.SceneMarker().Find(id) ret, err = r.repository.SceneMarker.Find(ctx, id)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -666,11 +666,11 @@ func (r *mutationResolver) SceneMarkerDestroy(ctx context.Context, id string) (b
Paths: manager.GetInstance().Paths, Paths: manager.GetInstance().Paths,
} }
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.SceneMarker() qb := r.repository.SceneMarker
sqb := repo.Scene() sqb := r.repository.Scene
marker, err := qb.Find(markerID) marker, err := qb.Find(ctx, markerID)
if err != nil { if err != nil {
return err return err
@@ -680,12 +680,12 @@ func (r *mutationResolver) SceneMarkerDestroy(ctx context.Context, id string) (b
return fmt.Errorf("scene marker with id %d not found", markerID) return fmt.Errorf("scene marker with id %d not found", markerID)
} }
s, err := sqb.Find(int(marker.SceneID.Int64)) s, err := sqb.Find(ctx, int(marker.SceneID.Int64))
if err != nil { if err != nil {
return err return err
} }
return scene.DestroyMarker(s, marker, qb, fileDeleter) return scene.DestroyMarker(ctx, s, marker, qb, fileDeleter)
}); err != nil { }); err != nil {
fileDeleter.Rollback() fileDeleter.Rollback()
return false, err return false, err
@@ -713,26 +713,26 @@ func (r *mutationResolver) changeMarker(ctx context.Context, changeType int, cha
} }
// Start the transaction and save the scene marker // Start the transaction and save the scene marker
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.SceneMarker() qb := r.repository.SceneMarker
sqb := repo.Scene() sqb := r.repository.Scene
var err error var err error
switch changeType { switch changeType {
case create: case create:
sceneMarker, err = qb.Create(changedMarker) sceneMarker, err = qb.Create(ctx, changedMarker)
case update: case update:
// check to see if timestamp was changed // check to see if timestamp was changed
existingMarker, err = qb.Find(changedMarker.ID) existingMarker, err = qb.Find(ctx, changedMarker.ID)
if err != nil { if err != nil {
return err return err
} }
sceneMarker, err = qb.Update(changedMarker) sceneMarker, err = qb.Update(ctx, changedMarker)
if err != nil { if err != nil {
return err return err
} }
s, err = sqb.Find(int(existingMarker.SceneID.Int64)) s, err = sqb.Find(ctx, int(existingMarker.SceneID.Int64))
} }
if err != nil { if err != nil {
return err return err
@@ -749,7 +749,7 @@ func (r *mutationResolver) changeMarker(ctx context.Context, changeType int, cha
// Save the marker tags // Save the marker tags
// If this tag is the primary tag, then let's not add it. // If this tag is the primary tag, then let's not add it.
tagIDs = intslice.IntExclude(tagIDs, []int{changedMarker.PrimaryTagID}) tagIDs = intslice.IntExclude(tagIDs, []int{changedMarker.PrimaryTagID})
return qb.UpdateTags(sceneMarker.ID, tagIDs) return qb.UpdateTags(ctx, sceneMarker.ID, tagIDs)
}); err != nil { }); err != nil {
fileDeleter.Rollback() fileDeleter.Rollback()
return nil, err return nil, err
@@ -766,10 +766,10 @@ func (r *mutationResolver) SceneIncrementO(ctx context.Context, id string) (ret
return 0, err return 0, err
} }
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Scene() qb := r.repository.Scene
ret, err = qb.IncrementOCounter(sceneID) ret, err = qb.IncrementOCounter(ctx, sceneID)
return err return err
}); err != nil { }); err != nil {
return 0, err return 0, err
@@ -784,10 +784,10 @@ func (r *mutationResolver) SceneDecrementO(ctx context.Context, id string) (ret
return 0, err return 0, err
} }
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Scene() qb := r.repository.Scene
ret, err = qb.DecrementOCounter(sceneID) ret, err = qb.DecrementOCounter(ctx, sceneID)
return err return err
}); err != nil { }); err != nil {
return 0, err return 0, err
@@ -802,10 +802,10 @@ func (r *mutationResolver) SceneResetO(ctx context.Context, id string) (ret int,
return 0, err return 0, err
} }
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Scene() qb := r.repository.Scene
ret, err = qb.ResetOCounter(sceneID) ret, err = qb.ResetOCounter(ctx, sceneID)
return err return err
}); err != nil { }); err != nil {
return 0, err return 0, err

View File

@@ -7,10 +7,18 @@ import (
"github.com/stashapp/stash/internal/manager" "github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/internal/manager/config" "github.com/stashapp/stash/internal/manager/config"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scraper/stashbox" "github.com/stashapp/stash/pkg/scraper/stashbox"
) )
func (r *Resolver) stashboxRepository() stashbox.Repository {
return stashbox.Repository{
Scene: r.repository.Scene,
Performer: r.repository.Performer,
Tag: r.repository.Tag,
Studio: r.repository.Studio,
}
}
func (r *mutationResolver) SubmitStashBoxFingerprints(ctx context.Context, input StashBoxFingerprintSubmissionInput) (bool, error) { func (r *mutationResolver) SubmitStashBoxFingerprints(ctx context.Context, input StashBoxFingerprintSubmissionInput) (bool, error) {
boxes := config.GetInstance().GetStashBoxes() boxes := config.GetInstance().GetStashBoxes()
@@ -18,7 +26,7 @@ func (r *mutationResolver) SubmitStashBoxFingerprints(ctx context.Context, input
return false, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex) return false, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex)
} }
client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager) client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager, r.stashboxRepository())
return client.SubmitStashBoxFingerprints(ctx, input.SceneIds, boxes[input.StashBoxIndex].Endpoint) return client.SubmitStashBoxFingerprints(ctx, input.SceneIds, boxes[input.StashBoxIndex].Endpoint)
} }
@@ -35,7 +43,7 @@ func (r *mutationResolver) SubmitStashBoxSceneDraft(ctx context.Context, input S
return nil, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex) return nil, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex)
} }
client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager) client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager, r.stashboxRepository())
id, err := strconv.Atoi(input.ID) id, err := strconv.Atoi(input.ID)
if err != nil { if err != nil {
@@ -43,9 +51,9 @@ func (r *mutationResolver) SubmitStashBoxSceneDraft(ctx context.Context, input S
} }
var res *string var res *string
err = r.withReadTxn(ctx, func(repo models.ReaderRepository) error { err = r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Scene() qb := r.repository.Scene
scene, err := qb.Find(id) scene, err := qb.Find(ctx, id)
if err != nil { if err != nil {
return err return err
} }
@@ -65,7 +73,7 @@ func (r *mutationResolver) SubmitStashBoxPerformerDraft(ctx context.Context, inp
return nil, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex) return nil, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex)
} }
client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager) client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager, r.stashboxRepository())
id, err := strconv.Atoi(input.ID) id, err := strconv.Atoi(input.ID)
if err != nil { if err != nil {
@@ -73,9 +81,9 @@ func (r *mutationResolver) SubmitStashBoxPerformerDraft(ctx context.Context, inp
} }
var res *string var res *string
err = r.withReadTxn(ctx, func(repo models.ReaderRepository) error { err = r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Performer() qb := r.repository.Performer
performer, err := qb.Find(id) performer, err := qb.Find(ctx, id)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -17,8 +17,8 @@ import (
) )
func (r *mutationResolver) getStudio(ctx context.Context, id int) (ret *models.Studio, err error) { func (r *mutationResolver) getStudio(ctx context.Context, id int) (ret *models.Studio, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Studio().Find(id) ret, err = r.repository.Studio.Find(ctx, id)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -72,18 +72,18 @@ func (r *mutationResolver) StudioCreate(ctx context.Context, input StudioCreateI
// Start the transaction and save the studio // Start the transaction and save the studio
var s *models.Studio var s *models.Studio
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Studio() qb := r.repository.Studio
var err error var err error
s, err = qb.Create(newStudio) s, err = qb.Create(ctx, newStudio)
if err != nil { if err != nil {
return err return err
} }
// update image table // update image table
if len(imageData) > 0 { if len(imageData) > 0 {
if err := qb.UpdateImage(s.ID, imageData); err != nil { if err := qb.UpdateImage(ctx, s.ID, imageData); err != nil {
return err return err
} }
} }
@@ -91,17 +91,17 @@ func (r *mutationResolver) StudioCreate(ctx context.Context, input StudioCreateI
// Save the stash_ids // Save the stash_ids
if input.StashIds != nil { if input.StashIds != nil {
stashIDJoins := models.StashIDsFromInput(input.StashIds) stashIDJoins := models.StashIDsFromInput(input.StashIds)
if err := qb.UpdateStashIDs(s.ID, stashIDJoins); err != nil { if err := qb.UpdateStashIDs(ctx, s.ID, stashIDJoins); err != nil {
return err return err
} }
} }
if len(input.Aliases) > 0 { if len(input.Aliases) > 0 {
if err := studio.EnsureAliasesUnique(s.ID, input.Aliases, qb); err != nil { if err := studio.EnsureAliasesUnique(ctx, s.ID, input.Aliases, qb); err != nil {
return err return err
} }
if err := qb.UpdateAliases(s.ID, input.Aliases); err != nil { if err := qb.UpdateAliases(ctx, s.ID, input.Aliases); err != nil {
return err return err
} }
} }
@@ -155,27 +155,27 @@ func (r *mutationResolver) StudioUpdate(ctx context.Context, input StudioUpdateI
// Start the transaction and save the studio // Start the transaction and save the studio
var s *models.Studio var s *models.Studio
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Studio() qb := r.repository.Studio
if err := manager.ValidateModifyStudio(updatedStudio, qb); err != nil { if err := manager.ValidateModifyStudio(ctx, updatedStudio, qb); err != nil {
return err return err
} }
var err error var err error
s, err = qb.Update(updatedStudio) s, err = qb.Update(ctx, updatedStudio)
if err != nil { if err != nil {
return err return err
} }
// update image table // update image table
if len(imageData) > 0 { if len(imageData) > 0 {
if err := qb.UpdateImage(s.ID, imageData); err != nil { if err := qb.UpdateImage(ctx, s.ID, imageData); err != nil {
return err return err
} }
} else if imageIncluded { } else if imageIncluded {
// must be unsetting // must be unsetting
if err := qb.DestroyImage(s.ID); err != nil { if err := qb.DestroyImage(ctx, s.ID); err != nil {
return err return err
} }
} }
@@ -183,17 +183,17 @@ func (r *mutationResolver) StudioUpdate(ctx context.Context, input StudioUpdateI
// Save the stash_ids // Save the stash_ids
if translator.hasField("stash_ids") { if translator.hasField("stash_ids") {
stashIDJoins := models.StashIDsFromInput(input.StashIds) stashIDJoins := models.StashIDsFromInput(input.StashIds)
if err := qb.UpdateStashIDs(studioID, stashIDJoins); err != nil { if err := qb.UpdateStashIDs(ctx, studioID, stashIDJoins); err != nil {
return err return err
} }
} }
if translator.hasField("aliases") { if translator.hasField("aliases") {
if err := studio.EnsureAliasesUnique(studioID, input.Aliases, qb); err != nil { if err := studio.EnsureAliasesUnique(ctx, studioID, input.Aliases, qb); err != nil {
return err return err
} }
if err := qb.UpdateAliases(studioID, input.Aliases); err != nil { if err := qb.UpdateAliases(ctx, studioID, input.Aliases); err != nil {
return err return err
} }
} }
@@ -213,8 +213,8 @@ func (r *mutationResolver) StudioDestroy(ctx context.Context, input StudioDestro
return false, err return false, err
} }
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
return repo.Studio().Destroy(id) return r.repository.Studio.Destroy(ctx, id)
}); err != nil { }); err != nil {
return false, err return false, err
} }
@@ -230,10 +230,10 @@ func (r *mutationResolver) StudiosDestroy(ctx context.Context, studioIDs []strin
return false, err return false, err
} }
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Studio() qb := r.repository.Studio
for _, id := range ids { for _, id := range ids {
if err := qb.Destroy(id); err != nil { if err := qb.Destroy(ctx, id); err != nil {
return err return err
} }
} }

View File

@@ -15,8 +15,8 @@ import (
) )
func (r *mutationResolver) getTag(ctx context.Context, id int) (ret *models.Tag, err error) { func (r *mutationResolver) getTag(ctx context.Context, id int) (ret *models.Tag, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Tag().Find(id) ret, err = r.repository.Tag.Find(ctx, id)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -68,44 +68,44 @@ func (r *mutationResolver) TagCreate(ctx context.Context, input TagCreateInput)
// Start the transaction and save the tag // Start the transaction and save the tag
var t *models.Tag var t *models.Tag
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Tag() qb := r.repository.Tag
// ensure name is unique // ensure name is unique
if err := tag.EnsureTagNameUnique(0, newTag.Name, qb); err != nil { if err := tag.EnsureTagNameUnique(ctx, 0, newTag.Name, qb); err != nil {
return err return err
} }
t, err = qb.Create(newTag) t, err = qb.Create(ctx, newTag)
if err != nil { if err != nil {
return err return err
} }
// update image table // update image table
if len(imageData) > 0 { if len(imageData) > 0 {
if err := qb.UpdateImage(t.ID, imageData); err != nil { if err := qb.UpdateImage(ctx, t.ID, imageData); err != nil {
return err return err
} }
} }
if len(input.Aliases) > 0 { if len(input.Aliases) > 0 {
if err := tag.EnsureAliasesUnique(t.ID, input.Aliases, qb); err != nil { if err := tag.EnsureAliasesUnique(ctx, t.ID, input.Aliases, qb); err != nil {
return err return err
} }
if err := qb.UpdateAliases(t.ID, input.Aliases); err != nil { if err := qb.UpdateAliases(ctx, t.ID, input.Aliases); err != nil {
return err return err
} }
} }
if len(parentIDs) > 0 { if len(parentIDs) > 0 {
if err := qb.UpdateParentTags(t.ID, parentIDs); err != nil { if err := qb.UpdateParentTags(ctx, t.ID, parentIDs); err != nil {
return err return err
} }
} }
if len(childIDs) > 0 { if len(childIDs) > 0 {
if err := qb.UpdateChildTags(t.ID, childIDs); err != nil { if err := qb.UpdateChildTags(ctx, t.ID, childIDs); err != nil {
return err return err
} }
} }
@@ -113,7 +113,7 @@ func (r *mutationResolver) TagCreate(ctx context.Context, input TagCreateInput)
// FIXME: This should be called before any changes are made, but // FIXME: This should be called before any changes are made, but
// requires a rewrite of ValidateHierarchy. // requires a rewrite of ValidateHierarchy.
if len(parentIDs) > 0 || len(childIDs) > 0 { if len(parentIDs) > 0 || len(childIDs) > 0 {
if err := tag.ValidateHierarchy(t, parentIDs, childIDs, qb); err != nil { if err := tag.ValidateHierarchy(ctx, t, parentIDs, childIDs, qb); err != nil {
return err return err
} }
} }
@@ -168,11 +168,11 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input TagUpdateInput)
// Start the transaction and save the tag // Start the transaction and save the tag
var t *models.Tag var t *models.Tag
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Tag() qb := r.repository.Tag
// ensure name is unique // ensure name is unique
t, err = qb.Find(tagID) t, err = qb.Find(ctx, tagID)
if err != nil { if err != nil {
return err return err
} }
@@ -188,48 +188,48 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input TagUpdateInput)
} }
if input.Name != nil && t.Name != *input.Name { if input.Name != nil && t.Name != *input.Name {
if err := tag.EnsureTagNameUnique(tagID, *input.Name, qb); err != nil { if err := tag.EnsureTagNameUnique(ctx, tagID, *input.Name, qb); err != nil {
return err return err
} }
updatedTag.Name = input.Name updatedTag.Name = input.Name
} }
t, err = qb.Update(updatedTag) t, err = qb.Update(ctx, updatedTag)
if err != nil { if err != nil {
return err return err
} }
// update image table // update image table
if len(imageData) > 0 { if len(imageData) > 0 {
if err := qb.UpdateImage(tagID, imageData); err != nil { if err := qb.UpdateImage(ctx, tagID, imageData); err != nil {
return err return err
} }
} else if imageIncluded { } else if imageIncluded {
// must be unsetting // must be unsetting
if err := qb.DestroyImage(tagID); err != nil { if err := qb.DestroyImage(ctx, tagID); err != nil {
return err return err
} }
} }
if translator.hasField("aliases") { if translator.hasField("aliases") {
if err := tag.EnsureAliasesUnique(tagID, input.Aliases, qb); err != nil { if err := tag.EnsureAliasesUnique(ctx, tagID, input.Aliases, qb); err != nil {
return err return err
} }
if err := qb.UpdateAliases(tagID, input.Aliases); err != nil { if err := qb.UpdateAliases(ctx, tagID, input.Aliases); err != nil {
return err return err
} }
} }
if parentIDs != nil { if parentIDs != nil {
if err := qb.UpdateParentTags(tagID, parentIDs); err != nil { if err := qb.UpdateParentTags(ctx, tagID, parentIDs); err != nil {
return err return err
} }
} }
if childIDs != nil { if childIDs != nil {
if err := qb.UpdateChildTags(tagID, childIDs); err != nil { if err := qb.UpdateChildTags(ctx, tagID, childIDs); err != nil {
return err return err
} }
} }
@@ -237,7 +237,7 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input TagUpdateInput)
// FIXME: This should be called before any changes are made, but // FIXME: This should be called before any changes are made, but
// requires a rewrite of ValidateHierarchy. // requires a rewrite of ValidateHierarchy.
if parentIDs != nil || childIDs != nil { if parentIDs != nil || childIDs != nil {
if err := tag.ValidateHierarchy(t, parentIDs, childIDs, qb); err != nil { if err := tag.ValidateHierarchy(ctx, t, parentIDs, childIDs, qb); err != nil {
logger.Errorf("Error saving tag: %s", err) logger.Errorf("Error saving tag: %s", err)
return err return err
} }
@@ -258,8 +258,8 @@ func (r *mutationResolver) TagDestroy(ctx context.Context, input TagDestroyInput
return false, err return false, err
} }
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
return repo.Tag().Destroy(tagID) return r.repository.Tag.Destroy(ctx, tagID)
}); err != nil { }); err != nil {
return false, err return false, err
} }
@@ -275,10 +275,10 @@ func (r *mutationResolver) TagsDestroy(ctx context.Context, tagIDs []string) (bo
return false, err return false, err
} }
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Tag() qb := r.repository.Tag
for _, id := range ids { for _, id := range ids {
if err := qb.Destroy(id); err != nil { if err := qb.Destroy(ctx, id); err != nil {
return err return err
} }
} }
@@ -311,11 +311,11 @@ func (r *mutationResolver) TagsMerge(ctx context.Context, input TagsMergeInput)
} }
var t *models.Tag var t *models.Tag
if err := r.withTxn(ctx, func(repo models.Repository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Tag() qb := r.repository.Tag
var err error var err error
t, err = qb.Find(destination) t, err = qb.Find(ctx, destination)
if err != nil { if err != nil {
return err return err
} }
@@ -324,25 +324,25 @@ func (r *mutationResolver) TagsMerge(ctx context.Context, input TagsMergeInput)
return fmt.Errorf("Tag with ID %d not found", destination) return fmt.Errorf("Tag with ID %d not found", destination)
} }
parents, children, err := tag.MergeHierarchy(destination, source, qb) parents, children, err := tag.MergeHierarchy(ctx, destination, source, qb)
if err != nil { if err != nil {
return err return err
} }
if err = qb.Merge(source, destination); err != nil { if err = qb.Merge(ctx, source, destination); err != nil {
return err return err
} }
err = qb.UpdateParentTags(destination, parents) err = qb.UpdateParentTags(ctx, destination, parents)
if err != nil { if err != nil {
return err return err
} }
err = qb.UpdateChildTags(destination, children) err = qb.UpdateChildTags(ctx, destination, children)
if err != nil { if err != nil {
return err return err
} }
err = tag.ValidateHierarchy(t, parents, children, qb) err = tag.ValidateHierarchy(ctx, t, parents, children, qb)
if err != nil { if err != nil {
logger.Errorf("Error merging tag: %s", err) logger.Errorf("Error merging tag: %s", err)
return err return err

View File

@@ -16,17 +16,23 @@ import (
// TODO - move this into a common area // TODO - move this into a common area
func newResolver() *Resolver { func newResolver() *Resolver {
return &Resolver{ return &Resolver{
txnManager: mocks.NewTransactionManager(), txnManager: &mocks.TxnManager{},
repository: mocks.NewTxnRepository(),
hookExecutor: &mockHookExecutor{}, hookExecutor: &mockHookExecutor{},
} }
} }
const tagName = "tagName" const (
const errTagName = "errTagName" tagName = "tagName"
errTagName = "errTagName"
const existingTagID = 1 existingTagID = 1
const existingTagName = "existingTagName" existingTagName = "existingTagName"
const newTagID = 2
newTagID = 2
)
var testCtx = context.Background()
type mockHookExecutor struct{} type mockHookExecutor struct{}
@@ -36,7 +42,7 @@ func (*mockHookExecutor) ExecutePostHooks(ctx context.Context, id int, hookType
func TestTagCreate(t *testing.T) { func TestTagCreate(t *testing.T) {
r := newResolver() r := newResolver()
tagRW := r.txnManager.(*mocks.TransactionManager).Tag().(*mocks.TagReaderWriter) tagRW := r.repository.Tag.(*mocks.TagReaderWriter)
pp := 1 pp := 1
findFilter := &models.FindFilterType{ findFilter := &models.FindFilterType{
@@ -61,25 +67,25 @@ func TestTagCreate(t *testing.T) {
} }
} }
tagRW.On("Query", tagFilterForName(existingTagName), findFilter).Return([]*models.Tag{ tagRW.On("Query", testCtx, tagFilterForName(existingTagName), findFilter).Return([]*models.Tag{
{ {
ID: existingTagID, ID: existingTagID,
Name: existingTagName, Name: existingTagName,
}, },
}, 1, nil).Once() }, 1, nil).Once()
tagRW.On("Query", tagFilterForName(errTagName), findFilter).Return(nil, 0, nil).Once() tagRW.On("Query", testCtx, tagFilterForName(errTagName), findFilter).Return(nil, 0, nil).Once()
tagRW.On("Query", tagFilterForAlias(errTagName), findFilter).Return(nil, 0, nil).Once() tagRW.On("Query", testCtx, tagFilterForAlias(errTagName), findFilter).Return(nil, 0, nil).Once()
expectedErr := errors.New("TagCreate error") expectedErr := errors.New("TagCreate error")
tagRW.On("Create", mock.AnythingOfType("models.Tag")).Return(nil, expectedErr) tagRW.On("Create", testCtx, mock.AnythingOfType("models.Tag")).Return(nil, expectedErr)
_, err := r.Mutation().TagCreate(context.TODO(), TagCreateInput{ _, err := r.Mutation().TagCreate(testCtx, TagCreateInput{
Name: existingTagName, Name: existingTagName,
}) })
assert.NotNil(t, err) assert.NotNil(t, err)
_, err = r.Mutation().TagCreate(context.TODO(), TagCreateInput{ _, err = r.Mutation().TagCreate(testCtx, TagCreateInput{
Name: errTagName, Name: errTagName,
}) })
@@ -87,18 +93,18 @@ func TestTagCreate(t *testing.T) {
tagRW.AssertExpectations(t) tagRW.AssertExpectations(t)
r = newResolver() r = newResolver()
tagRW = r.txnManager.(*mocks.TransactionManager).Tag().(*mocks.TagReaderWriter) tagRW = r.repository.Tag.(*mocks.TagReaderWriter)
tagRW.On("Query", tagFilterForName(tagName), findFilter).Return(nil, 0, nil).Once() tagRW.On("Query", testCtx, tagFilterForName(tagName), findFilter).Return(nil, 0, nil).Once()
tagRW.On("Query", tagFilterForAlias(tagName), findFilter).Return(nil, 0, nil).Once() tagRW.On("Query", testCtx, tagFilterForAlias(tagName), findFilter).Return(nil, 0, nil).Once()
newTag := &models.Tag{ newTag := &models.Tag{
ID: newTagID, ID: newTagID,
Name: tagName, Name: tagName,
} }
tagRW.On("Create", mock.AnythingOfType("models.Tag")).Return(newTag, nil) tagRW.On("Create", testCtx, mock.AnythingOfType("models.Tag")).Return(newTag, nil)
tagRW.On("Find", newTagID).Return(newTag, nil) tagRW.On("Find", testCtx, newTagID).Return(newTag, nil)
tag, err := r.Mutation().TagCreate(context.TODO(), TagCreateInput{ tag, err := r.Mutation().TagCreate(testCtx, TagCreateInput{
Name: tagName, Name: tagName,
}) })

View File

@@ -222,7 +222,7 @@ func makeConfigUIResult() map[string]interface{} {
} }
func (r *queryResolver) ValidateStashBoxCredentials(ctx context.Context, input config.StashBoxInput) (*StashBoxValidationResult, error) { func (r *queryResolver) ValidateStashBoxCredentials(ctx context.Context, input config.StashBoxInput) (*StashBoxValidationResult, error) {
client := stashbox.NewClient(models.StashBox{Endpoint: input.Endpoint, APIKey: input.APIKey}, r.txnManager) client := stashbox.NewClient(models.StashBox{Endpoint: input.Endpoint, APIKey: input.APIKey}, r.txnManager, r.stashboxRepository())
user, err := client.GetUser(ctx) user, err := client.GetUser(ctx)
valid := user != nil && user.Me != nil valid := user != nil && user.Me != nil

View File

@@ -13,8 +13,8 @@ func (r *queryResolver) FindGallery(ctx context.Context, id string) (ret *models
return nil, err return nil, err
} }
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Gallery().Find(idInt) ret, err = r.repository.Gallery.Find(ctx, idInt)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -24,8 +24,8 @@ func (r *queryResolver) FindGallery(ctx context.Context, id string) (ret *models
} }
func (r *queryResolver) FindGalleries(ctx context.Context, galleryFilter *models.GalleryFilterType, filter *models.FindFilterType) (ret *FindGalleriesResultType, err error) { func (r *queryResolver) FindGalleries(ctx context.Context, galleryFilter *models.GalleryFilterType, filter *models.FindFilterType) (ret *FindGalleriesResultType, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
galleries, total, err := repo.Gallery().Query(galleryFilter, filter) galleries, total, err := r.repository.Gallery.Query(ctx, galleryFilter, filter)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -12,8 +12,8 @@ import (
func (r *queryResolver) FindImage(ctx context.Context, id *string, checksum *string) (*models.Image, error) { func (r *queryResolver) FindImage(ctx context.Context, id *string, checksum *string) (*models.Image, error) {
var image *models.Image var image *models.Image
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Image() qb := r.repository.Image
var err error var err error
if id != nil { if id != nil {
@@ -22,12 +22,12 @@ func (r *queryResolver) FindImage(ctx context.Context, id *string, checksum *str
return err return err
} }
image, err = qb.Find(idInt) image, err = qb.Find(ctx, idInt)
if err != nil { if err != nil {
return err return err
} }
} else if checksum != nil { } else if checksum != nil {
image, err = qb.FindByChecksum(*checksum) image, err = qb.FindByChecksum(ctx, *checksum)
} }
return err return err
@@ -39,12 +39,12 @@ func (r *queryResolver) FindImage(ctx context.Context, id *string, checksum *str
} }
func (r *queryResolver) FindImages(ctx context.Context, imageFilter *models.ImageFilterType, imageIds []int, filter *models.FindFilterType) (ret *FindImagesResultType, err error) { func (r *queryResolver) FindImages(ctx context.Context, imageFilter *models.ImageFilterType, imageIds []int, filter *models.FindFilterType) (ret *FindImagesResultType, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Image() qb := r.repository.Image
fields := graphql.CollectAllFields(ctx) fields := graphql.CollectAllFields(ctx)
result, err := qb.Query(models.ImageQueryOptions{ result, err := qb.Query(ctx, models.ImageQueryOptions{
QueryOptions: models.QueryOptions{ QueryOptions: models.QueryOptions{
FindFilter: filter, FindFilter: filter,
Count: stringslice.StrInclude(fields, "count"), Count: stringslice.StrInclude(fields, "count"),
@@ -57,7 +57,7 @@ func (r *queryResolver) FindImages(ctx context.Context, imageFilter *models.Imag
return err return err
} }
images, err := result.Resolve() images, err := result.Resolve(ctx)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -13,8 +13,8 @@ func (r *queryResolver) FindMovie(ctx context.Context, id string) (ret *models.M
return nil, err return nil, err
} }
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Movie().Find(idInt) ret, err = r.repository.Movie.Find(ctx, idInt)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -24,8 +24,8 @@ func (r *queryResolver) FindMovie(ctx context.Context, id string) (ret *models.M
} }
func (r *queryResolver) FindMovies(ctx context.Context, movieFilter *models.MovieFilterType, filter *models.FindFilterType) (ret *FindMoviesResultType, err error) { func (r *queryResolver) FindMovies(ctx context.Context, movieFilter *models.MovieFilterType, filter *models.FindFilterType) (ret *FindMoviesResultType, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
movies, total, err := repo.Movie().Query(movieFilter, filter) movies, total, err := r.repository.Movie.Query(ctx, movieFilter, filter)
if err != nil { if err != nil {
return err return err
} }
@@ -44,8 +44,8 @@ func (r *queryResolver) FindMovies(ctx context.Context, movieFilter *models.Movi
} }
func (r *queryResolver) AllMovies(ctx context.Context) (ret []*models.Movie, err error) { func (r *queryResolver) AllMovies(ctx context.Context) (ret []*models.Movie, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Movie().All() ret, err = r.repository.Movie.All(ctx)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err

View File

@@ -13,8 +13,8 @@ func (r *queryResolver) FindPerformer(ctx context.Context, id string) (ret *mode
return nil, err return nil, err
} }
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Performer().Find(idInt) ret, err = r.repository.Performer.Find(ctx, idInt)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -24,8 +24,8 @@ func (r *queryResolver) FindPerformer(ctx context.Context, id string) (ret *mode
} }
func (r *queryResolver) FindPerformers(ctx context.Context, performerFilter *models.PerformerFilterType, filter *models.FindFilterType) (ret *FindPerformersResultType, err error) { func (r *queryResolver) FindPerformers(ctx context.Context, performerFilter *models.PerformerFilterType, filter *models.FindFilterType) (ret *FindPerformersResultType, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
performers, total, err := repo.Performer().Query(performerFilter, filter) performers, total, err := r.repository.Performer.Query(ctx, performerFilter, filter)
if err != nil { if err != nil {
return err return err
} }
@@ -43,8 +43,8 @@ func (r *queryResolver) FindPerformers(ctx context.Context, performerFilter *mod
} }
func (r *queryResolver) AllPerformers(ctx context.Context) (ret []*models.Performer, err error) { func (r *queryResolver) AllPerformers(ctx context.Context) (ret []*models.Performer, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Performer().All() ret, err = r.repository.Performer.All(ctx)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err

View File

@@ -13,8 +13,8 @@ func (r *queryResolver) FindSavedFilter(ctx context.Context, id string) (ret *mo
return nil, err return nil, err
} }
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.SavedFilter().Find(idInt) ret, err = r.repository.SavedFilter.Find(ctx, idInt)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -23,11 +23,11 @@ func (r *queryResolver) FindSavedFilter(ctx context.Context, id string) (ret *mo
} }
func (r *queryResolver) FindSavedFilters(ctx context.Context, mode *models.FilterMode) (ret []*models.SavedFilter, err error) { func (r *queryResolver) FindSavedFilters(ctx context.Context, mode *models.FilterMode) (ret []*models.SavedFilter, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
if mode != nil { if mode != nil {
ret, err = repo.SavedFilter().FindByMode(*mode) ret, err = r.repository.SavedFilter.FindByMode(ctx, *mode)
} else { } else {
ret, err = repo.SavedFilter().All() ret, err = r.repository.SavedFilter.All(ctx)
} }
return err return err
}); err != nil { }); err != nil {
@@ -37,8 +37,8 @@ func (r *queryResolver) FindSavedFilters(ctx context.Context, mode *models.Filte
} }
func (r *queryResolver) FindDefaultFilter(ctx context.Context, mode models.FilterMode) (ret *models.SavedFilter, err error) { func (r *queryResolver) FindDefaultFilter(ctx context.Context, mode models.FilterMode) (ret *models.SavedFilter, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.SavedFilter().FindDefault(mode) ret, err = r.repository.SavedFilter.FindDefault(ctx, mode)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err

View File

@@ -12,20 +12,20 @@ import (
func (r *queryResolver) FindScene(ctx context.Context, id *string, checksum *string) (*models.Scene, error) { func (r *queryResolver) FindScene(ctx context.Context, id *string, checksum *string) (*models.Scene, error) {
var scene *models.Scene var scene *models.Scene
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Scene() qb := r.repository.Scene
var err error var err error
if id != nil { if id != nil {
idInt, err := strconv.Atoi(*id) idInt, err := strconv.Atoi(*id)
if err != nil { if err != nil {
return err return err
} }
scene, err = qb.Find(idInt) scene, err = qb.Find(ctx, idInt)
if err != nil { if err != nil {
return err return err
} }
} else if checksum != nil { } else if checksum != nil {
scene, err = qb.FindByChecksum(*checksum) scene, err = qb.FindByChecksum(ctx, *checksum)
} }
return err return err
@@ -39,18 +39,18 @@ func (r *queryResolver) FindScene(ctx context.Context, id *string, checksum *str
func (r *queryResolver) FindSceneByHash(ctx context.Context, input SceneHashInput) (*models.Scene, error) { func (r *queryResolver) FindSceneByHash(ctx context.Context, input SceneHashInput) (*models.Scene, error) {
var scene *models.Scene var scene *models.Scene
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := repo.Scene() qb := r.repository.Scene
var err error var err error
if input.Checksum != nil { if input.Checksum != nil {
scene, err = qb.FindByChecksum(*input.Checksum) scene, err = qb.FindByChecksum(ctx, *input.Checksum)
if err != nil { if err != nil {
return err return err
} }
} }
if scene == nil && input.Oshash != nil { if scene == nil && input.Oshash != nil {
scene, err = qb.FindByOSHash(*input.Oshash) scene, err = qb.FindByOSHash(ctx, *input.Oshash)
if err != nil { if err != nil {
return err return err
} }
@@ -65,7 +65,7 @@ func (r *queryResolver) FindSceneByHash(ctx context.Context, input SceneHashInpu
} }
func (r *queryResolver) FindScenes(ctx context.Context, sceneFilter *models.SceneFilterType, sceneIDs []int, filter *models.FindFilterType) (ret *FindScenesResultType, err error) { func (r *queryResolver) FindScenes(ctx context.Context, sceneFilter *models.SceneFilterType, sceneIDs []int, filter *models.FindFilterType) (ret *FindScenesResultType, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
var scenes []*models.Scene var scenes []*models.Scene
var err error var err error
@@ -73,7 +73,7 @@ func (r *queryResolver) FindScenes(ctx context.Context, sceneFilter *models.Scen
result := &models.SceneQueryResult{} result := &models.SceneQueryResult{}
if len(sceneIDs) > 0 { if len(sceneIDs) > 0 {
scenes, err = repo.Scene().FindMany(sceneIDs) scenes, err = r.repository.Scene.FindMany(ctx, sceneIDs)
if err == nil { if err == nil {
result.Count = len(scenes) result.Count = len(scenes)
for _, s := range scenes { for _, s := range scenes {
@@ -83,7 +83,7 @@ func (r *queryResolver) FindScenes(ctx context.Context, sceneFilter *models.Scen
} }
} }
} else { } else {
result, err = repo.Scene().Query(models.SceneQueryOptions{ result, err = r.repository.Scene.Query(ctx, models.SceneQueryOptions{
QueryOptions: models.QueryOptions{ QueryOptions: models.QueryOptions{
FindFilter: filter, FindFilter: filter,
Count: stringslice.StrInclude(fields, "count"), Count: stringslice.StrInclude(fields, "count"),
@@ -93,7 +93,7 @@ func (r *queryResolver) FindScenes(ctx context.Context, sceneFilter *models.Scen
TotalSize: stringslice.StrInclude(fields, "filesize"), TotalSize: stringslice.StrInclude(fields, "filesize"),
}) })
if err == nil { if err == nil {
scenes, err = result.Resolve() scenes, err = result.Resolve(ctx)
} }
} }
@@ -117,7 +117,7 @@ func (r *queryResolver) FindScenes(ctx context.Context, sceneFilter *models.Scen
} }
func (r *queryResolver) FindScenesByPathRegex(ctx context.Context, filter *models.FindFilterType) (ret *FindScenesResultType, err error) { func (r *queryResolver) FindScenesByPathRegex(ctx context.Context, filter *models.FindFilterType) (ret *FindScenesResultType, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
sceneFilter := &models.SceneFilterType{} sceneFilter := &models.SceneFilterType{}
@@ -138,7 +138,7 @@ func (r *queryResolver) FindScenesByPathRegex(ctx context.Context, filter *model
fields := graphql.CollectAllFields(ctx) fields := graphql.CollectAllFields(ctx)
result, err := repo.Scene().Query(models.SceneQueryOptions{ result, err := r.repository.Scene.Query(ctx, models.SceneQueryOptions{
QueryOptions: models.QueryOptions{ QueryOptions: models.QueryOptions{
FindFilter: queryFilter, FindFilter: queryFilter,
Count: stringslice.StrInclude(fields, "count"), Count: stringslice.StrInclude(fields, "count"),
@@ -151,7 +151,7 @@ func (r *queryResolver) FindScenesByPathRegex(ctx context.Context, filter *model
return err return err
} }
scenes, err := result.Resolve() scenes, err := result.Resolve(ctx)
if err != nil { if err != nil {
return err return err
} }
@@ -174,8 +174,14 @@ func (r *queryResolver) FindScenesByPathRegex(ctx context.Context, filter *model
func (r *queryResolver) ParseSceneFilenames(ctx context.Context, filter *models.FindFilterType, config manager.SceneParserInput) (ret *SceneParserResultType, err error) { func (r *queryResolver) ParseSceneFilenames(ctx context.Context, filter *models.FindFilterType, config manager.SceneParserInput) (ret *SceneParserResultType, err error) {
parser := manager.NewSceneFilenameParser(filter, config) parser := manager.NewSceneFilenameParser(filter, config)
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
result, count, err := parser.Parse(repo) result, count, err := parser.Parse(ctx, manager.SceneFilenameParserRepository{
Scene: r.repository.Scene,
Performer: r.repository.Performer,
Studio: r.repository.Studio,
Movie: r.repository.Movie,
Tag: r.repository.Tag,
})
if err != nil { if err != nil {
return err return err
@@ -199,8 +205,8 @@ func (r *queryResolver) FindDuplicateScenes(ctx context.Context, distance *int)
if distance != nil { if distance != nil {
dist = *distance dist = *distance
} }
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Scene().FindDuplicates(dist) ret, err = r.repository.Scene.FindDuplicates(ctx, dist)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err

View File

@@ -7,8 +7,8 @@ import (
) )
func (r *queryResolver) FindSceneMarkers(ctx context.Context, sceneMarkerFilter *models.SceneMarkerFilterType, filter *models.FindFilterType) (ret *FindSceneMarkersResultType, err error) { func (r *queryResolver) FindSceneMarkers(ctx context.Context, sceneMarkerFilter *models.SceneMarkerFilterType, filter *models.FindFilterType) (ret *FindSceneMarkersResultType, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
sceneMarkers, total, err := repo.SceneMarker().Query(sceneMarkerFilter, filter) sceneMarkers, total, err := r.repository.SceneMarker.Query(ctx, sceneMarkerFilter, filter)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -13,9 +13,9 @@ func (r *queryResolver) FindStudio(ctx context.Context, id string) (ret *models.
return nil, err return nil, err
} }
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
var err error var err error
ret, err = repo.Studio().Find(idInt) ret, err = r.repository.Studio.Find(ctx, idInt)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -25,8 +25,8 @@ func (r *queryResolver) FindStudio(ctx context.Context, id string) (ret *models.
} }
func (r *queryResolver) FindStudios(ctx context.Context, studioFilter *models.StudioFilterType, filter *models.FindFilterType) (ret *FindStudiosResultType, err error) { func (r *queryResolver) FindStudios(ctx context.Context, studioFilter *models.StudioFilterType, filter *models.FindFilterType) (ret *FindStudiosResultType, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
studios, total, err := repo.Studio().Query(studioFilter, filter) studios, total, err := r.repository.Studio.Query(ctx, studioFilter, filter)
if err != nil { if err != nil {
return err return err
} }
@@ -45,8 +45,8 @@ func (r *queryResolver) FindStudios(ctx context.Context, studioFilter *models.St
} }
func (r *queryResolver) AllStudios(ctx context.Context) (ret []*models.Studio, err error) { func (r *queryResolver) AllStudios(ctx context.Context) (ret []*models.Studio, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Studio().All() ret, err = r.repository.Studio.All(ctx)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err

View File

@@ -13,8 +13,8 @@ func (r *queryResolver) FindTag(ctx context.Context, id string) (ret *models.Tag
return nil, err return nil, err
} }
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Tag().Find(idInt) ret, err = r.repository.Tag.Find(ctx, idInt)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -24,8 +24,8 @@ func (r *queryResolver) FindTag(ctx context.Context, id string) (ret *models.Tag
} }
func (r *queryResolver) FindTags(ctx context.Context, tagFilter *models.TagFilterType, filter *models.FindFilterType) (ret *FindTagsResultType, err error) { func (r *queryResolver) FindTags(ctx context.Context, tagFilter *models.TagFilterType, filter *models.FindFilterType) (ret *FindTagsResultType, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
tags, total, err := repo.Tag().Query(tagFilter, filter) tags, total, err := r.repository.Tag.Query(ctx, tagFilter, filter)
if err != nil { if err != nil {
return err return err
} }
@@ -44,8 +44,8 @@ func (r *queryResolver) FindTags(ctx context.Context, tagFilter *models.TagFilte
} }
func (r *queryResolver) AllTags(ctx context.Context) (ret []*models.Tag, err error) { func (r *queryResolver) AllTags(ctx context.Context) (ret []*models.Tag, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
ret, err = repo.Tag().All() ret, err = r.repository.Tag.All(ctx)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err

View File

@@ -14,10 +14,10 @@ import (
func (r *queryResolver) SceneStreams(ctx context.Context, id *string) ([]*manager.SceneStreamEndpoint, error) { func (r *queryResolver) SceneStreams(ctx context.Context, id *string) ([]*manager.SceneStreamEndpoint, error) {
// find the scene // find the scene
var scene *models.Scene var scene *models.Scene
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
idInt, _ := strconv.Atoi(*id) idInt, _ := strconv.Atoi(*id)
var err error var err error
scene, err = repo.Scene().Find(idInt) scene, err = r.repository.Scene.Find(ctx, idInt)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err

View File

@@ -234,7 +234,7 @@ func (r *queryResolver) getStashBoxClient(index int) (*stashbox.Client, error) {
return nil, fmt.Errorf("%w: invalid stash_box_index %d", ErrInput, index) return nil, fmt.Errorf("%w: invalid stash_box_index %d", ErrInput, index)
} }
return stashbox.NewClient(*boxes[index], r.txnManager), nil return stashbox.NewClient(*boxes[index], r.txnManager, r.stashboxRepository()), nil
} }
func (r *queryResolver) ScrapeSingleScene(ctx context.Context, source scraper.Source, input ScrapeSingleSceneInput) ([]*scraper.ScrapedScene, error) { func (r *queryResolver) ScrapeSingleScene(ctx context.Context, source scraper.Source, input ScrapeSingleSceneInput) ([]*scraper.ScrapedScene, error) {

View File

@@ -13,17 +13,24 @@ import (
"github.com/stashapp/stash/pkg/image" "github.com/stashapp/stash/pkg/image"
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
) )
type ImageFinder interface {
Find(ctx context.Context, id int) (*models.Image, error)
FindByChecksum(ctx context.Context, checksum string) (*models.Image, error)
}
type imageRoutes struct { type imageRoutes struct {
txnManager models.TransactionManager txnManager txn.Manager
imageFinder ImageFinder
} }
func (rs imageRoutes) Routes() chi.Router { func (rs imageRoutes) Routes() chi.Router {
r := chi.NewRouter() r := chi.NewRouter()
r.Route("/{imageId}", func(r chi.Router) { r.Route("/{imageId}", func(r chi.Router) {
r.Use(ImageCtx) r.Use(rs.ImageCtx)
r.Get("/image", rs.Image) r.Get("/image", rs.Image)
r.Get("/thumbnail", rs.Thumbnail) r.Get("/thumbnail", rs.Thumbnail)
@@ -85,18 +92,18 @@ func (rs imageRoutes) Image(w http.ResponseWriter, r *http.Request) {
// endregion // endregion
func ImageCtx(next http.Handler) http.Handler { func (rs imageRoutes) ImageCtx(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
imageIdentifierQueryParam := chi.URLParam(r, "imageId") imageIdentifierQueryParam := chi.URLParam(r, "imageId")
imageID, _ := strconv.Atoi(imageIdentifierQueryParam) imageID, _ := strconv.Atoi(imageIdentifierQueryParam)
var image *models.Image var image *models.Image
readTxnErr := manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { readTxnErr := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
qb := repo.Image() qb := rs.imageFinder
if imageID == 0 { if imageID == 0 {
image, _ = qb.FindByChecksum(imageIdentifierQueryParam) image, _ = qb.FindByChecksum(ctx, imageIdentifierQueryParam)
} else { } else {
image, _ = qb.Find(imageID) image, _ = qb.Find(ctx, imageID)
} }
return nil return nil

View File

@@ -6,21 +6,28 @@ import (
"strconv" "strconv"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
type MovieFinder interface {
GetFrontImage(ctx context.Context, movieID int) ([]byte, error)
GetBackImage(ctx context.Context, movieID int) ([]byte, error)
Find(ctx context.Context, id int) (*models.Movie, error)
}
type movieRoutes struct { type movieRoutes struct {
txnManager models.TransactionManager txnManager txn.Manager
movieFinder MovieFinder
} }
func (rs movieRoutes) Routes() chi.Router { func (rs movieRoutes) Routes() chi.Router {
r := chi.NewRouter() r := chi.NewRouter()
r.Route("/{movieId}", func(r chi.Router) { r.Route("/{movieId}", func(r chi.Router) {
r.Use(MovieCtx) r.Use(rs.MovieCtx)
r.Get("/frontimage", rs.FrontImage) r.Get("/frontimage", rs.FrontImage)
r.Get("/backimage", rs.BackImage) r.Get("/backimage", rs.BackImage)
}) })
@@ -33,8 +40,8 @@ func (rs movieRoutes) FrontImage(w http.ResponseWriter, r *http.Request) {
defaultParam := r.URL.Query().Get("default") defaultParam := r.URL.Query().Get("default")
var image []byte var image []byte
if defaultParam != "true" { if defaultParam != "true" {
err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
image, _ = repo.Movie().GetFrontImage(movie.ID) image, _ = rs.movieFinder.GetFrontImage(ctx, movie.ID)
return nil return nil
}) })
if err != nil { if err != nil {
@@ -56,8 +63,8 @@ func (rs movieRoutes) BackImage(w http.ResponseWriter, r *http.Request) {
defaultParam := r.URL.Query().Get("default") defaultParam := r.URL.Query().Get("default")
var image []byte var image []byte
if defaultParam != "true" { if defaultParam != "true" {
err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
image, _ = repo.Movie().GetBackImage(movie.ID) image, _ = rs.movieFinder.GetBackImage(ctx, movie.ID)
return nil return nil
}) })
if err != nil { if err != nil {
@@ -74,7 +81,7 @@ func (rs movieRoutes) BackImage(w http.ResponseWriter, r *http.Request) {
} }
} }
func MovieCtx(next http.Handler) http.Handler { func (rs movieRoutes) MovieCtx(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
movieID, err := strconv.Atoi(chi.URLParam(r, "movieId")) movieID, err := strconv.Atoi(chi.URLParam(r, "movieId"))
if err != nil { if err != nil {
@@ -83,9 +90,9 @@ func MovieCtx(next http.Handler) http.Handler {
} }
var movie *models.Movie var movie *models.Movie
if err := manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
var err error var err error
movie, err = repo.Movie().Find(movieID) movie, err = rs.movieFinder.Find(ctx, movieID)
return err return err
}); err != nil { }); err != nil {
http.Error(w, http.StatusText(404), 404) http.Error(w, http.StatusText(404), 404)

View File

@@ -6,22 +6,28 @@ import (
"strconv" "strconv"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/internal/manager/config" "github.com/stashapp/stash/internal/manager/config"
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
type PerformerFinder interface {
Find(ctx context.Context, id int) (*models.Performer, error)
GetImage(ctx context.Context, performerID int) ([]byte, error)
}
type performerRoutes struct { type performerRoutes struct {
txnManager models.TransactionManager txnManager txn.Manager
performerFinder PerformerFinder
} }
func (rs performerRoutes) Routes() chi.Router { func (rs performerRoutes) Routes() chi.Router {
r := chi.NewRouter() r := chi.NewRouter()
r.Route("/{performerId}", func(r chi.Router) { r.Route("/{performerId}", func(r chi.Router) {
r.Use(PerformerCtx) r.Use(rs.PerformerCtx)
r.Get("/image", rs.Image) r.Get("/image", rs.Image)
}) })
@@ -34,8 +40,8 @@ func (rs performerRoutes) Image(w http.ResponseWriter, r *http.Request) {
var image []byte var image []byte
if defaultParam != "true" { if defaultParam != "true" {
readTxnErr := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { readTxnErr := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
image, _ = repo.Performer().GetImage(performer.ID) image, _ = rs.performerFinder.GetImage(ctx, performer.ID)
return nil return nil
}) })
if readTxnErr != nil { if readTxnErr != nil {
@@ -52,7 +58,7 @@ func (rs performerRoutes) Image(w http.ResponseWriter, r *http.Request) {
} }
} }
func PerformerCtx(next http.Handler) http.Handler { func (rs performerRoutes) PerformerCtx(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
performerID, err := strconv.Atoi(chi.URLParam(r, "performerId")) performerID, err := strconv.Atoi(chi.URLParam(r, "performerId"))
if err != nil { if err != nil {
@@ -61,9 +67,9 @@ func PerformerCtx(next http.Handler) http.Handler {
} }
var performer *models.Performer var performer *models.Performer
if err := manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
var err error var err error
performer, err = repo.Performer().Find(performerID) performer, err = rs.performerFinder.Find(ctx, performerID)
return err return err
}); err != nil { }); err != nil {
http.Error(w, http.StatusText(404), 404) http.Error(w, http.StatusText(404), 404)

View File

@@ -15,18 +15,36 @@ import (
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene" "github.com/stashapp/stash/pkg/scene"
"github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
type SceneFinder interface {
manager.SceneCoverGetter
scene.IDFinder
FindByChecksum(ctx context.Context, checksum string) (*models.Scene, error)
FindByOSHash(ctx context.Context, oshash string) (*models.Scene, error)
GetCaptions(ctx context.Context, sceneID int) ([]*models.SceneCaption, error)
}
type SceneMarkerFinder interface {
Find(ctx context.Context, id int) (*models.SceneMarker, error)
FindBySceneID(ctx context.Context, sceneID int) ([]*models.SceneMarker, error)
}
type sceneRoutes struct { type sceneRoutes struct {
txnManager models.TransactionManager txnManager txn.Manager
sceneFinder SceneFinder
sceneMarkerFinder SceneMarkerFinder
tagFinder scene.MarkerTagFinder
} }
func (rs sceneRoutes) Routes() chi.Router { func (rs sceneRoutes) Routes() chi.Router {
r := chi.NewRouter() r := chi.NewRouter()
r.Route("/{sceneId}", func(r chi.Router) { r.Route("/{sceneId}", func(r chi.Router) {
r.Use(SceneCtx) r.Use(rs.SceneCtx)
// streaming endpoints // streaming endpoints
r.Get("/stream", rs.StreamDirect) r.Get("/stream", rs.StreamDirect)
@@ -48,8 +66,8 @@ func (rs sceneRoutes) Routes() chi.Router {
r.Get("/scene_marker/{sceneMarkerId}/preview", rs.SceneMarkerPreview) r.Get("/scene_marker/{sceneMarkerId}/preview", rs.SceneMarkerPreview)
r.Get("/scene_marker/{sceneMarkerId}/screenshot", rs.SceneMarkerScreenshot) r.Get("/scene_marker/{sceneMarkerId}/screenshot", rs.SceneMarkerScreenshot)
}) })
r.With(SceneCtx).Get("/{sceneId}_thumbs.vtt", rs.VttThumbs) r.With(rs.SceneCtx).Get("/{sceneId}_thumbs.vtt", rs.VttThumbs)
r.With(SceneCtx).Get("/{sceneId}_sprite.jpg", rs.VttSprite) r.With(rs.SceneCtx).Get("/{sceneId}_sprite.jpg", rs.VttSprite)
return r return r
} }
@@ -60,7 +78,8 @@ func (rs sceneRoutes) StreamDirect(w http.ResponseWriter, r *http.Request) {
scene := r.Context().Value(sceneKey).(*models.Scene) scene := r.Context().Value(sceneKey).(*models.Scene)
ss := manager.SceneServer{ ss := manager.SceneServer{
TXNManager: rs.txnManager, TxnManager: rs.txnManager,
SceneCoverGetter: rs.sceneFinder,
} }
ss.StreamSceneDirect(scene, w, r) ss.StreamSceneDirect(scene, w, r)
} }
@@ -190,7 +209,8 @@ func (rs sceneRoutes) Screenshot(w http.ResponseWriter, r *http.Request) {
scene := r.Context().Value(sceneKey).(*models.Scene) scene := r.Context().Value(sceneKey).(*models.Scene)
ss := manager.SceneServer{ ss := manager.SceneServer{
TXNManager: rs.txnManager, TxnManager: rs.txnManager,
SceneCoverGetter: rs.sceneFinder,
} }
ss.ServeScreenshot(scene, w, r) ss.ServeScreenshot(scene, w, r)
} }
@@ -221,16 +241,16 @@ func (rs sceneRoutes) getChapterVttTitle(ctx context.Context, marker *models.Sce
} }
var ret string var ret string
if err := rs.txnManager.WithReadTxn(ctx, func(repo models.ReaderRepository) error { if err := txn.WithTxn(ctx, rs.txnManager, func(ctx context.Context) error {
qb := repo.Tag() qb := rs.tagFinder
primaryTag, err := qb.Find(marker.PrimaryTagID) primaryTag, err := qb.Find(ctx, marker.PrimaryTagID)
if err != nil { if err != nil {
return err return err
} }
ret = primaryTag.Name ret = primaryTag.Name
tags, err := qb.FindBySceneMarkerID(marker.ID) tags, err := qb.FindBySceneMarkerID(ctx, marker.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -250,9 +270,9 @@ func (rs sceneRoutes) getChapterVttTitle(ctx context.Context, marker *models.Sce
func (rs sceneRoutes) ChapterVtt(w http.ResponseWriter, r *http.Request) { func (rs sceneRoutes) ChapterVtt(w http.ResponseWriter, r *http.Request) {
scene := r.Context().Value(sceneKey).(*models.Scene) scene := r.Context().Value(sceneKey).(*models.Scene)
var sceneMarkers []*models.SceneMarker var sceneMarkers []*models.SceneMarker
if err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
var err error var err error
sceneMarkers, err = repo.SceneMarker().FindBySceneID(scene.ID) sceneMarkers, err = rs.sceneMarkerFinder.FindBySceneID(ctx, scene.ID)
return err return err
}); err != nil { }); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
@@ -289,9 +309,9 @@ func (rs sceneRoutes) InteractiveHeatmap(w http.ResponseWriter, r *http.Request)
func (rs sceneRoutes) Caption(w http.ResponseWriter, r *http.Request, lang string, ext string) { func (rs sceneRoutes) Caption(w http.ResponseWriter, r *http.Request, lang string, ext string) {
s := r.Context().Value(sceneKey).(*models.Scene) s := r.Context().Value(sceneKey).(*models.Scene)
if err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
var err error var err error
captions, err := repo.Scene().GetCaptions(s.ID) captions, err := rs.sceneFinder.GetCaptions(ctx, s.ID)
for _, caption := range captions { for _, caption := range captions {
if lang == caption.LanguageCode && ext == caption.CaptionType { if lang == caption.LanguageCode && ext == caption.CaptionType {
sub, err := scene.ReadSubs(caption.Path(s.Path)) sub, err := scene.ReadSubs(caption.Path(s.Path))
@@ -344,9 +364,9 @@ func (rs sceneRoutes) SceneMarkerStream(w http.ResponseWriter, r *http.Request)
scene := r.Context().Value(sceneKey).(*models.Scene) scene := r.Context().Value(sceneKey).(*models.Scene)
sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId")) sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId"))
var sceneMarker *models.SceneMarker var sceneMarker *models.SceneMarker
if err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
var err error var err error
sceneMarker, err = repo.SceneMarker().Find(sceneMarkerID) sceneMarker, err = rs.sceneMarkerFinder.Find(ctx, sceneMarkerID)
return err return err
}); err != nil { }); err != nil {
logger.Warnf("Error when getting scene marker for stream: %s", err.Error()) logger.Warnf("Error when getting scene marker for stream: %s", err.Error())
@@ -367,9 +387,9 @@ func (rs sceneRoutes) SceneMarkerPreview(w http.ResponseWriter, r *http.Request)
scene := r.Context().Value(sceneKey).(*models.Scene) scene := r.Context().Value(sceneKey).(*models.Scene)
sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId")) sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId"))
var sceneMarker *models.SceneMarker var sceneMarker *models.SceneMarker
if err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
var err error var err error
sceneMarker, err = repo.SceneMarker().Find(sceneMarkerID) sceneMarker, err = rs.sceneMarkerFinder.Find(ctx, sceneMarkerID)
return err return err
}); err != nil { }); err != nil {
logger.Warnf("Error when getting scene marker for stream: %s", err.Error()) logger.Warnf("Error when getting scene marker for stream: %s", err.Error())
@@ -400,9 +420,9 @@ func (rs sceneRoutes) SceneMarkerScreenshot(w http.ResponseWriter, r *http.Reque
scene := r.Context().Value(sceneKey).(*models.Scene) scene := r.Context().Value(sceneKey).(*models.Scene)
sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId")) sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId"))
var sceneMarker *models.SceneMarker var sceneMarker *models.SceneMarker
if err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
var err error var err error
sceneMarker, err = repo.SceneMarker().Find(sceneMarkerID) sceneMarker, err = rs.sceneMarkerFinder.Find(ctx, sceneMarkerID)
return err return err
}); err != nil { }); err != nil {
logger.Warnf("Error when getting scene marker for stream: %s", err.Error()) logger.Warnf("Error when getting scene marker for stream: %s", err.Error())
@@ -431,23 +451,23 @@ func (rs sceneRoutes) SceneMarkerScreenshot(w http.ResponseWriter, r *http.Reque
// endregion // endregion
func SceneCtx(next http.Handler) http.Handler { func (rs sceneRoutes) SceneCtx(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sceneIdentifierQueryParam := chi.URLParam(r, "sceneId") sceneIdentifierQueryParam := chi.URLParam(r, "sceneId")
sceneID, _ := strconv.Atoi(sceneIdentifierQueryParam) sceneID, _ := strconv.Atoi(sceneIdentifierQueryParam)
var scene *models.Scene var scene *models.Scene
readTxnErr := manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { readTxnErr := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
qb := repo.Scene() qb := rs.sceneFinder
if sceneID == 0 { if sceneID == 0 {
// determine checksum/os by the length of the query param // determine checksum/os by the length of the query param
if len(sceneIdentifierQueryParam) == 32 { if len(sceneIdentifierQueryParam) == 32 {
scene, _ = qb.FindByChecksum(sceneIdentifierQueryParam) scene, _ = qb.FindByChecksum(ctx, sceneIdentifierQueryParam)
} else { } else {
scene, _ = qb.FindByOSHash(sceneIdentifierQueryParam) scene, _ = qb.FindByOSHash(ctx, sceneIdentifierQueryParam)
} }
} else { } else {
scene, _ = qb.Find(sceneID) scene, _ = qb.Find(ctx, sceneID)
} }
return nil return nil

View File

@@ -8,21 +8,28 @@ import (
"syscall" "syscall"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/studio"
"github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
type StudioFinder interface {
studio.Finder
GetImage(ctx context.Context, studioID int) ([]byte, error)
}
type studioRoutes struct { type studioRoutes struct {
txnManager models.TransactionManager txnManager txn.Manager
studioFinder StudioFinder
} }
func (rs studioRoutes) Routes() chi.Router { func (rs studioRoutes) Routes() chi.Router {
r := chi.NewRouter() r := chi.NewRouter()
r.Route("/{studioId}", func(r chi.Router) { r.Route("/{studioId}", func(r chi.Router) {
r.Use(StudioCtx) r.Use(rs.StudioCtx)
r.Get("/image", rs.Image) r.Get("/image", rs.Image)
}) })
@@ -35,8 +42,8 @@ func (rs studioRoutes) Image(w http.ResponseWriter, r *http.Request) {
var image []byte var image []byte
if defaultParam != "true" { if defaultParam != "true" {
err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
image, _ = repo.Studio().GetImage(studio.ID) image, _ = rs.studioFinder.GetImage(ctx, studio.ID)
return nil return nil
}) })
if err != nil { if err != nil {
@@ -58,7 +65,7 @@ func (rs studioRoutes) Image(w http.ResponseWriter, r *http.Request) {
} }
} }
func StudioCtx(next http.Handler) http.Handler { func (rs studioRoutes) StudioCtx(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
studioID, err := strconv.Atoi(chi.URLParam(r, "studioId")) studioID, err := strconv.Atoi(chi.URLParam(r, "studioId"))
if err != nil { if err != nil {
@@ -67,9 +74,9 @@ func StudioCtx(next http.Handler) http.Handler {
} }
var studio *models.Studio var studio *models.Studio
if err := manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
var err error var err error
studio, err = repo.Studio().Find(studioID) studio, err = rs.studioFinder.Find(ctx, studioID)
return err return err
}); err != nil { }); err != nil {
http.Error(w, http.StatusText(404), 404) http.Error(w, http.StatusText(404), 404)

View File

@@ -6,21 +6,28 @@ import (
"strconv" "strconv"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/tag"
"github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
type TagFinder interface {
tag.Finder
GetImage(ctx context.Context, tagID int) ([]byte, error)
}
type tagRoutes struct { type tagRoutes struct {
txnManager models.TransactionManager txnManager txn.Manager
tagFinder TagFinder
} }
func (rs tagRoutes) Routes() chi.Router { func (rs tagRoutes) Routes() chi.Router {
r := chi.NewRouter() r := chi.NewRouter()
r.Route("/{tagId}", func(r chi.Router) { r.Route("/{tagId}", func(r chi.Router) {
r.Use(TagCtx) r.Use(rs.TagCtx)
r.Get("/image", rs.Image) r.Get("/image", rs.Image)
}) })
@@ -33,8 +40,8 @@ func (rs tagRoutes) Image(w http.ResponseWriter, r *http.Request) {
var image []byte var image []byte
if defaultParam != "true" { if defaultParam != "true" {
err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
image, _ = repo.Tag().GetImage(tag.ID) image, _ = rs.tagFinder.GetImage(ctx, tag.ID)
return nil return nil
}) })
if err != nil { if err != nil {
@@ -51,7 +58,7 @@ func (rs tagRoutes) Image(w http.ResponseWriter, r *http.Request) {
} }
} }
func TagCtx(next http.Handler) http.Handler { func (rs tagRoutes) TagCtx(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tagID, err := strconv.Atoi(chi.URLParam(r, "tagId")) tagID, err := strconv.Atoi(chi.URLParam(r, "tagId"))
if err != nil { if err != nil {
@@ -60,9 +67,9 @@ func TagCtx(next http.Handler) http.Handler {
} }
var tag *models.Tag var tag *models.Tag
if err := manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
var err error var err error
tag, err = repo.Tag().Find(tagID) tag, err = rs.tagFinder.Find(ctx, tagID)
return err return err
}); err != nil { }); err != nil {
http.Error(w, http.StatusText(404), 404) http.Error(w, http.StatusText(404), 404)

View File

@@ -73,10 +73,11 @@ func Start() error {
return errors.New(message) return errors.New(message)
} }
txnManager := manager.GetInstance().TxnManager txnManager := manager.GetInstance().Repository
pluginCache := manager.GetInstance().PluginCache pluginCache := manager.GetInstance().PluginCache
resolver := &Resolver{ resolver := &Resolver{
txnManager: txnManager, txnManager: txnManager,
repository: txnManager,
hookExecutor: pluginCache, hookExecutor: pluginCache,
} }
@@ -119,21 +120,29 @@ func Start() error {
r.Mount("/performer", performerRoutes{ r.Mount("/performer", performerRoutes{
txnManager: txnManager, txnManager: txnManager,
performerFinder: txnManager.Performer,
}.Routes()) }.Routes())
r.Mount("/scene", sceneRoutes{ r.Mount("/scene", sceneRoutes{
txnManager: txnManager, txnManager: txnManager,
sceneFinder: txnManager.Scene,
sceneMarkerFinder: txnManager.SceneMarker,
tagFinder: txnManager.Tag,
}.Routes()) }.Routes())
r.Mount("/image", imageRoutes{ r.Mount("/image", imageRoutes{
txnManager: txnManager, txnManager: txnManager,
imageFinder: txnManager.Image,
}.Routes()) }.Routes())
r.Mount("/studio", studioRoutes{ r.Mount("/studio", studioRoutes{
txnManager: txnManager, txnManager: txnManager,
studioFinder: txnManager.Studio,
}.Routes()) }.Routes())
r.Mount("/movie", movieRoutes{ r.Mount("/movie", movieRoutes{
txnManager: txnManager, txnManager: txnManager,
movieFinder: txnManager.Movie,
}.Routes()) }.Routes())
r.Mount("/tag", tagRoutes{ r.Mount("/tag", tagRoutes{
txnManager: txnManager, txnManager: txnManager,
tagFinder: txnManager.Tag,
}.Routes()) }.Routes())
r.Mount("/downloads", downloadsRoutes{}.Routes()) r.Mount("/downloads", downloadsRoutes{}.Routes())

View File

@@ -1,6 +1,8 @@
package autotag package autotag
import ( import (
"context"
"github.com/stashapp/stash/pkg/gallery" "github.com/stashapp/stash/pkg/gallery"
"github.com/stashapp/stash/pkg/match" "github.com/stashapp/stash/pkg/match"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
@@ -21,18 +23,18 @@ func getGalleryFileTagger(s *models.Gallery, cache *match.Cache) tagger {
} }
// GalleryPerformers tags the provided gallery with performers whose name matches the gallery's path. // GalleryPerformers tags the provided gallery with performers whose name matches the gallery's path.
func GalleryPerformers(s *models.Gallery, rw models.GalleryReaderWriter, performerReader models.PerformerReader, cache *match.Cache) error { func GalleryPerformers(ctx context.Context, s *models.Gallery, rw gallery.PerformerUpdater, performerReader match.PerformerAutoTagQueryer, cache *match.Cache) error {
t := getGalleryFileTagger(s, cache) t := getGalleryFileTagger(s, cache)
return t.tagPerformers(performerReader, func(subjectID, otherID int) (bool, error) { return t.tagPerformers(ctx, performerReader, func(subjectID, otherID int) (bool, error) {
return gallery.AddPerformer(rw, subjectID, otherID) return gallery.AddPerformer(ctx, rw, subjectID, otherID)
}) })
} }
// GalleryStudios tags the provided gallery with the first studio whose name matches the gallery's path. // GalleryStudios tags the provided gallery with the first studio whose name matches the gallery's path.
// //
// Gallerys will not be tagged if studio is already set. // Gallerys will not be tagged if studio is already set.
func GalleryStudios(s *models.Gallery, rw models.GalleryReaderWriter, studioReader models.StudioReader, cache *match.Cache) error { func GalleryStudios(ctx context.Context, s *models.Gallery, rw GalleryFinderUpdater, studioReader match.StudioAutoTagQueryer, cache *match.Cache) error {
if s.StudioID.Valid { if s.StudioID.Valid {
// don't modify // don't modify
return nil return nil
@@ -40,16 +42,16 @@ func GalleryStudios(s *models.Gallery, rw models.GalleryReaderWriter, studioRead
t := getGalleryFileTagger(s, cache) t := getGalleryFileTagger(s, cache)
return t.tagStudios(studioReader, func(subjectID, otherID int) (bool, error) { return t.tagStudios(ctx, studioReader, func(subjectID, otherID int) (bool, error) {
return addGalleryStudio(rw, subjectID, otherID) return addGalleryStudio(ctx, rw, subjectID, otherID)
}) })
} }
// GalleryTags tags the provided gallery with tags whose name matches the gallery's path. // GalleryTags tags the provided gallery with tags whose name matches the gallery's path.
func GalleryTags(s *models.Gallery, rw models.GalleryReaderWriter, tagReader models.TagReader, cache *match.Cache) error { func GalleryTags(ctx context.Context, s *models.Gallery, rw gallery.TagUpdater, tagReader match.TagAutoTagQueryer, cache *match.Cache) error {
t := getGalleryFileTagger(s, cache) t := getGalleryFileTagger(s, cache)
return t.tagTags(tagReader, func(subjectID, otherID int) (bool, error) { return t.tagTags(ctx, tagReader, func(subjectID, otherID int) (bool, error) {
return gallery.AddTag(rw, subjectID, otherID) return gallery.AddTag(ctx, rw, subjectID, otherID)
}) })
} }

View File

@@ -1,6 +1,7 @@
package autotag package autotag
import ( import (
"context"
"testing" "testing"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
@@ -11,6 +12,8 @@ import (
const galleryExt = "zip" const galleryExt = "zip"
var testCtx = context.Background()
func TestGalleryPerformers(t *testing.T) { func TestGalleryPerformers(t *testing.T) {
t.Parallel() t.Parallel()
@@ -37,19 +40,19 @@ func TestGalleryPerformers(t *testing.T) {
mockPerformerReader := &mocks.PerformerReaderWriter{} mockPerformerReader := &mocks.PerformerReaderWriter{}
mockGalleryReader := &mocks.GalleryReaderWriter{} mockGalleryReader := &mocks.GalleryReaderWriter{}
mockPerformerReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil) mockPerformerReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockPerformerReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once() mockPerformerReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once()
if test.Matches { if test.Matches {
mockGalleryReader.On("GetPerformerIDs", galleryID).Return(nil, nil).Once() mockGalleryReader.On("GetPerformerIDs", testCtx, galleryID).Return(nil, nil).Once()
mockGalleryReader.On("UpdatePerformers", galleryID, []int{performerID}).Return(nil).Once() mockGalleryReader.On("UpdatePerformers", testCtx, galleryID, []int{performerID}).Return(nil).Once()
} }
gallery := models.Gallery{ gallery := models.Gallery{
ID: galleryID, ID: galleryID,
Path: models.NullString(test.Path), Path: models.NullString(test.Path),
} }
err := GalleryPerformers(&gallery, mockGalleryReader, mockPerformerReader, nil) err := GalleryPerformers(testCtx, &gallery, mockGalleryReader, mockPerformerReader, nil)
assert.Nil(err) assert.Nil(err)
mockPerformerReader.AssertExpectations(t) mockPerformerReader.AssertExpectations(t)
@@ -81,9 +84,9 @@ func TestGalleryStudios(t *testing.T) {
doTest := func(mockStudioReader *mocks.StudioReaderWriter, mockGalleryReader *mocks.GalleryReaderWriter, test pathTestTable) { doTest := func(mockStudioReader *mocks.StudioReaderWriter, mockGalleryReader *mocks.GalleryReaderWriter, test pathTestTable) {
if test.Matches { if test.Matches {
mockGalleryReader.On("Find", galleryID).Return(&models.Gallery{}, nil).Once() mockGalleryReader.On("Find", testCtx, galleryID).Return(&models.Gallery{}, nil).Once()
expectedStudioID := models.NullInt64(studioID) expectedStudioID := models.NullInt64(studioID)
mockGalleryReader.On("UpdatePartial", models.GalleryPartial{ mockGalleryReader.On("UpdatePartial", testCtx, models.GalleryPartial{
ID: galleryID, ID: galleryID,
StudioID: &expectedStudioID, StudioID: &expectedStudioID,
}).Return(nil, nil).Once() }).Return(nil, nil).Once()
@@ -93,7 +96,7 @@ func TestGalleryStudios(t *testing.T) {
ID: galleryID, ID: galleryID,
Path: models.NullString(test.Path), Path: models.NullString(test.Path),
} }
err := GalleryStudios(&gallery, mockGalleryReader, mockStudioReader, nil) err := GalleryStudios(testCtx, &gallery, mockGalleryReader, mockStudioReader, nil)
assert.Nil(err) assert.Nil(err)
mockStudioReader.AssertExpectations(t) mockStudioReader.AssertExpectations(t)
@@ -104,9 +107,9 @@ func TestGalleryStudios(t *testing.T) {
mockStudioReader := &mocks.StudioReaderWriter{} mockStudioReader := &mocks.StudioReaderWriter{}
mockGalleryReader := &mocks.GalleryReaderWriter{} mockGalleryReader := &mocks.GalleryReaderWriter{}
mockStudioReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil) mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockStudioReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
mockStudioReader.On("GetAliases", mock.Anything).Return([]string{}, nil).Maybe() mockStudioReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
doTest(mockStudioReader, mockGalleryReader, test) doTest(mockStudioReader, mockGalleryReader, test)
} }
@@ -119,12 +122,12 @@ func TestGalleryStudios(t *testing.T) {
mockStudioReader := &mocks.StudioReaderWriter{} mockStudioReader := &mocks.StudioReaderWriter{}
mockGalleryReader := &mocks.GalleryReaderWriter{} mockGalleryReader := &mocks.GalleryReaderWriter{}
mockStudioReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil) mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockStudioReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
mockStudioReader.On("GetAliases", studioID).Return([]string{ mockStudioReader.On("GetAliases", testCtx, studioID).Return([]string{
studioName, studioName,
}, nil).Once() }, nil).Once()
mockStudioReader.On("GetAliases", reversedStudioID).Return([]string{}, nil).Once() mockStudioReader.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once()
doTest(mockStudioReader, mockGalleryReader, test) doTest(mockStudioReader, mockGalleryReader, test)
} }
@@ -154,15 +157,15 @@ func TestGalleryTags(t *testing.T) {
doTest := func(mockTagReader *mocks.TagReaderWriter, mockGalleryReader *mocks.GalleryReaderWriter, test pathTestTable) { doTest := func(mockTagReader *mocks.TagReaderWriter, mockGalleryReader *mocks.GalleryReaderWriter, test pathTestTable) {
if test.Matches { if test.Matches {
mockGalleryReader.On("GetTagIDs", galleryID).Return(nil, nil).Once() mockGalleryReader.On("GetTagIDs", testCtx, galleryID).Return(nil, nil).Once()
mockGalleryReader.On("UpdateTags", galleryID, []int{tagID}).Return(nil).Once() mockGalleryReader.On("UpdateTags", testCtx, galleryID, []int{tagID}).Return(nil).Once()
} }
gallery := models.Gallery{ gallery := models.Gallery{
ID: galleryID, ID: galleryID,
Path: models.NullString(test.Path), Path: models.NullString(test.Path),
} }
err := GalleryTags(&gallery, mockGalleryReader, mockTagReader, nil) err := GalleryTags(testCtx, &gallery, mockGalleryReader, mockTagReader, nil)
assert.Nil(err) assert.Nil(err)
mockTagReader.AssertExpectations(t) mockTagReader.AssertExpectations(t)
@@ -173,9 +176,9 @@ func TestGalleryTags(t *testing.T) {
mockTagReader := &mocks.TagReaderWriter{} mockTagReader := &mocks.TagReaderWriter{}
mockGalleryReader := &mocks.GalleryReaderWriter{} mockGalleryReader := &mocks.GalleryReaderWriter{}
mockTagReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil) mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockTagReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
mockTagReader.On("GetAliases", mock.Anything).Return([]string{}, nil).Maybe() mockTagReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
doTest(mockTagReader, mockGalleryReader, test) doTest(mockTagReader, mockGalleryReader, test)
} }
@@ -187,12 +190,12 @@ func TestGalleryTags(t *testing.T) {
mockTagReader := &mocks.TagReaderWriter{} mockTagReader := &mocks.TagReaderWriter{}
mockGalleryReader := &mocks.GalleryReaderWriter{} mockGalleryReader := &mocks.GalleryReaderWriter{}
mockTagReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil) mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockTagReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
mockTagReader.On("GetAliases", tagID).Return([]string{ mockTagReader.On("GetAliases", testCtx, tagID).Return([]string{
tagName, tagName,
}, nil).Once() }, nil).Once()
mockTagReader.On("GetAliases", reversedTagID).Return([]string{}, nil).Once() mockTagReader.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once()
doTest(mockTagReader, mockGalleryReader, test) doTest(mockTagReader, mockGalleryReader, test)
} }

View File

@@ -1,6 +1,8 @@
package autotag package autotag
import ( import (
"context"
"github.com/stashapp/stash/pkg/image" "github.com/stashapp/stash/pkg/image"
"github.com/stashapp/stash/pkg/match" "github.com/stashapp/stash/pkg/match"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
@@ -17,18 +19,18 @@ func getImageFileTagger(s *models.Image, cache *match.Cache) tagger {
} }
// ImagePerformers tags the provided image with performers whose name matches the image's path. // ImagePerformers tags the provided image with performers whose name matches the image's path.
func ImagePerformers(s *models.Image, rw models.ImageReaderWriter, performerReader models.PerformerReader, cache *match.Cache) error { func ImagePerformers(ctx context.Context, s *models.Image, rw image.PerformerUpdater, performerReader match.PerformerAutoTagQueryer, cache *match.Cache) error {
t := getImageFileTagger(s, cache) t := getImageFileTagger(s, cache)
return t.tagPerformers(performerReader, func(subjectID, otherID int) (bool, error) { return t.tagPerformers(ctx, performerReader, func(subjectID, otherID int) (bool, error) {
return image.AddPerformer(rw, subjectID, otherID) return image.AddPerformer(ctx, rw, subjectID, otherID)
}) })
} }
// ImageStudios tags the provided image with the first studio whose name matches the image's path. // ImageStudios tags the provided image with the first studio whose name matches the image's path.
// //
// Images will not be tagged if studio is already set. // Images will not be tagged if studio is already set.
func ImageStudios(s *models.Image, rw models.ImageReaderWriter, studioReader models.StudioReader, cache *match.Cache) error { func ImageStudios(ctx context.Context, s *models.Image, rw ImageFinderUpdater, studioReader match.StudioAutoTagQueryer, cache *match.Cache) error {
if s.StudioID.Valid { if s.StudioID.Valid {
// don't modify // don't modify
return nil return nil
@@ -36,16 +38,16 @@ func ImageStudios(s *models.Image, rw models.ImageReaderWriter, studioReader mod
t := getImageFileTagger(s, cache) t := getImageFileTagger(s, cache)
return t.tagStudios(studioReader, func(subjectID, otherID int) (bool, error) { return t.tagStudios(ctx, studioReader, func(subjectID, otherID int) (bool, error) {
return addImageStudio(rw, subjectID, otherID) return addImageStudio(ctx, rw, subjectID, otherID)
}) })
} }
// ImageTags tags the provided image with tags whose name matches the image's path. // ImageTags tags the provided image with tags whose name matches the image's path.
func ImageTags(s *models.Image, rw models.ImageReaderWriter, tagReader models.TagReader, cache *match.Cache) error { func ImageTags(ctx context.Context, s *models.Image, rw image.TagUpdater, tagReader match.TagAutoTagQueryer, cache *match.Cache) error {
t := getImageFileTagger(s, cache) t := getImageFileTagger(s, cache)
return t.tagTags(tagReader, func(subjectID, otherID int) (bool, error) { return t.tagTags(ctx, tagReader, func(subjectID, otherID int) (bool, error) {
return image.AddTag(rw, subjectID, otherID) return image.AddTag(ctx, rw, subjectID, otherID)
}) })
} }

View File

@@ -37,19 +37,19 @@ func TestImagePerformers(t *testing.T) {
mockPerformerReader := &mocks.PerformerReaderWriter{} mockPerformerReader := &mocks.PerformerReaderWriter{}
mockImageReader := &mocks.ImageReaderWriter{} mockImageReader := &mocks.ImageReaderWriter{}
mockPerformerReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil) mockPerformerReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockPerformerReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once() mockPerformerReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once()
if test.Matches { if test.Matches {
mockImageReader.On("GetPerformerIDs", imageID).Return(nil, nil).Once() mockImageReader.On("GetPerformerIDs", testCtx, imageID).Return(nil, nil).Once()
mockImageReader.On("UpdatePerformers", imageID, []int{performerID}).Return(nil).Once() mockImageReader.On("UpdatePerformers", testCtx, imageID, []int{performerID}).Return(nil).Once()
} }
image := models.Image{ image := models.Image{
ID: imageID, ID: imageID,
Path: test.Path, Path: test.Path,
} }
err := ImagePerformers(&image, mockImageReader, mockPerformerReader, nil) err := ImagePerformers(testCtx, &image, mockImageReader, mockPerformerReader, nil)
assert.Nil(err) assert.Nil(err)
mockPerformerReader.AssertExpectations(t) mockPerformerReader.AssertExpectations(t)
@@ -81,9 +81,9 @@ func TestImageStudios(t *testing.T) {
doTest := func(mockStudioReader *mocks.StudioReaderWriter, mockImageReader *mocks.ImageReaderWriter, test pathTestTable) { doTest := func(mockStudioReader *mocks.StudioReaderWriter, mockImageReader *mocks.ImageReaderWriter, test pathTestTable) {
if test.Matches { if test.Matches {
mockImageReader.On("Find", imageID).Return(&models.Image{}, nil).Once() mockImageReader.On("Find", testCtx, imageID).Return(&models.Image{}, nil).Once()
expectedStudioID := models.NullInt64(studioID) expectedStudioID := models.NullInt64(studioID)
mockImageReader.On("Update", models.ImagePartial{ mockImageReader.On("Update", testCtx, models.ImagePartial{
ID: imageID, ID: imageID,
StudioID: &expectedStudioID, StudioID: &expectedStudioID,
}).Return(nil, nil).Once() }).Return(nil, nil).Once()
@@ -93,7 +93,7 @@ func TestImageStudios(t *testing.T) {
ID: imageID, ID: imageID,
Path: test.Path, Path: test.Path,
} }
err := ImageStudios(&image, mockImageReader, mockStudioReader, nil) err := ImageStudios(testCtx, &image, mockImageReader, mockStudioReader, nil)
assert.Nil(err) assert.Nil(err)
mockStudioReader.AssertExpectations(t) mockStudioReader.AssertExpectations(t)
@@ -104,9 +104,9 @@ func TestImageStudios(t *testing.T) {
mockStudioReader := &mocks.StudioReaderWriter{} mockStudioReader := &mocks.StudioReaderWriter{}
mockImageReader := &mocks.ImageReaderWriter{} mockImageReader := &mocks.ImageReaderWriter{}
mockStudioReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil) mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockStudioReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
mockStudioReader.On("GetAliases", mock.Anything).Return([]string{}, nil).Maybe() mockStudioReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
doTest(mockStudioReader, mockImageReader, test) doTest(mockStudioReader, mockImageReader, test)
} }
@@ -119,12 +119,12 @@ func TestImageStudios(t *testing.T) {
mockStudioReader := &mocks.StudioReaderWriter{} mockStudioReader := &mocks.StudioReaderWriter{}
mockImageReader := &mocks.ImageReaderWriter{} mockImageReader := &mocks.ImageReaderWriter{}
mockStudioReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil) mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockStudioReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
mockStudioReader.On("GetAliases", studioID).Return([]string{ mockStudioReader.On("GetAliases", testCtx, studioID).Return([]string{
studioName, studioName,
}, nil).Once() }, nil).Once()
mockStudioReader.On("GetAliases", reversedStudioID).Return([]string{}, nil).Once() mockStudioReader.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once()
doTest(mockStudioReader, mockImageReader, test) doTest(mockStudioReader, mockImageReader, test)
} }
@@ -154,15 +154,15 @@ func TestImageTags(t *testing.T) {
doTest := func(mockTagReader *mocks.TagReaderWriter, mockImageReader *mocks.ImageReaderWriter, test pathTestTable) { doTest := func(mockTagReader *mocks.TagReaderWriter, mockImageReader *mocks.ImageReaderWriter, test pathTestTable) {
if test.Matches { if test.Matches {
mockImageReader.On("GetTagIDs", imageID).Return(nil, nil).Once() mockImageReader.On("GetTagIDs", testCtx, imageID).Return(nil, nil).Once()
mockImageReader.On("UpdateTags", imageID, []int{tagID}).Return(nil).Once() mockImageReader.On("UpdateTags", testCtx, imageID, []int{tagID}).Return(nil).Once()
} }
image := models.Image{ image := models.Image{
ID: imageID, ID: imageID,
Path: test.Path, Path: test.Path,
} }
err := ImageTags(&image, mockImageReader, mockTagReader, nil) err := ImageTags(testCtx, &image, mockImageReader, mockTagReader, nil)
assert.Nil(err) assert.Nil(err)
mockTagReader.AssertExpectations(t) mockTagReader.AssertExpectations(t)
@@ -173,9 +173,9 @@ func TestImageTags(t *testing.T) {
mockTagReader := &mocks.TagReaderWriter{} mockTagReader := &mocks.TagReaderWriter{}
mockImageReader := &mocks.ImageReaderWriter{} mockImageReader := &mocks.ImageReaderWriter{}
mockTagReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil) mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockTagReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
mockTagReader.On("GetAliases", mock.Anything).Return([]string{}, nil).Maybe() mockTagReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
doTest(mockTagReader, mockImageReader, test) doTest(mockTagReader, mockImageReader, test)
} }
@@ -188,12 +188,12 @@ func TestImageTags(t *testing.T) {
mockTagReader := &mocks.TagReaderWriter{} mockTagReader := &mocks.TagReaderWriter{}
mockImageReader := &mocks.ImageReaderWriter{} mockImageReader := &mocks.ImageReaderWriter{}
mockTagReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil) mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockTagReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
mockTagReader.On("GetAliases", tagID).Return([]string{ mockTagReader.On("GetAliases", testCtx, tagID).Return([]string{
tagName, tagName,
}, nil).Once() }, nil).Once()
mockTagReader.On("GetAliases", reversedTagID).Return([]string{}, nil).Once() mockTagReader.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once()
doTest(mockTagReader, mockImageReader, test) doTest(mockTagReader, mockImageReader, test)
} }

View File

@@ -10,10 +10,10 @@ import (
"os" "os"
"testing" "testing"
"github.com/stashapp/stash/pkg/database"
"github.com/stashapp/stash/pkg/hash/md5" "github.com/stashapp/stash/pkg/hash/md5"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/sqlite" "github.com/stashapp/stash/pkg/sqlite"
"github.com/stashapp/stash/pkg/txn"
_ "github.com/golang-migrate/migrate/v4/database/sqlite3" _ "github.com/golang-migrate/migrate/v4/database/sqlite3"
_ "github.com/golang-migrate/migrate/v4/source/file" _ "github.com/golang-migrate/migrate/v4/source/file"
@@ -28,8 +28,11 @@ const existingStudioGalleryName = testName + ".dontChangeStudio.mp4"
var existingStudioID int var existingStudioID int
var db *sqlite.Database
var r models.Repository
func testTeardown(databaseFile string) { func testTeardown(databaseFile string) {
err := database.DB.Close() err := db.Close()
if err != nil { if err != nil {
panic(err) panic(err)
@@ -50,10 +53,13 @@ func runTests(m *testing.M) int {
f.Close() f.Close()
databaseFile := f.Name() databaseFile := f.Name()
if err := database.Initialize(databaseFile); err != nil { db = &sqlite.Database{}
if err := db.Open(databaseFile); err != nil {
panic(fmt.Sprintf("Could not initialize database: %s", err.Error())) panic(fmt.Sprintf("Could not initialize database: %s", err.Error()))
} }
r = db.TxnRepository()
// defer close and delete the database // defer close and delete the database
defer testTeardown(databaseFile) defer testTeardown(databaseFile)
@@ -71,7 +77,7 @@ func TestMain(m *testing.M) {
os.Exit(ret) os.Exit(ret)
} }
func createPerformer(pqb models.PerformerWriter) error { func createPerformer(ctx context.Context, pqb models.PerformerWriter) error {
// create the performer // create the performer
performer := models.Performer{ performer := models.Performer{
Checksum: testName, Checksum: testName,
@@ -79,7 +85,7 @@ func createPerformer(pqb models.PerformerWriter) error {
Favorite: sql.NullBool{Valid: true, Bool: false}, Favorite: sql.NullBool{Valid: true, Bool: false},
} }
_, err := pqb.Create(performer) _, err := pqb.Create(ctx, performer)
if err != nil { if err != nil {
return err return err
} }
@@ -87,23 +93,23 @@ func createPerformer(pqb models.PerformerWriter) error {
return nil return nil
} }
func createStudio(qb models.StudioWriter, name string) (*models.Studio, error) { func createStudio(ctx context.Context, qb models.StudioWriter, name string) (*models.Studio, error) {
// create the studio // create the studio
studio := models.Studio{ studio := models.Studio{
Checksum: name, Checksum: name,
Name: sql.NullString{Valid: true, String: name}, Name: sql.NullString{Valid: true, String: name},
} }
return qb.Create(studio) return qb.Create(ctx, studio)
} }
func createTag(qb models.TagWriter) error { func createTag(ctx context.Context, qb models.TagWriter) error {
// create the studio // create the studio
tag := models.Tag{ tag := models.Tag{
Name: testName, Name: testName,
} }
_, err := qb.Create(tag) _, err := qb.Create(ctx, tag)
if err != nil { if err != nil {
return err return err
} }
@@ -111,18 +117,18 @@ func createTag(qb models.TagWriter) error {
return nil return nil
} }
func createScenes(sqb models.SceneReaderWriter) error { func createScenes(ctx context.Context, sqb models.SceneReaderWriter) error {
// create the scenes // create the scenes
scenePatterns, falseScenePatterns := generateTestPaths(testName, sceneExt) scenePatterns, falseScenePatterns := generateTestPaths(testName, sceneExt)
for _, fn := range scenePatterns { for _, fn := range scenePatterns {
err := createScene(sqb, makeScene(fn, true)) err := createScene(ctx, sqb, makeScene(fn, true))
if err != nil { if err != nil {
return err return err
} }
} }
for _, fn := range falseScenePatterns { for _, fn := range falseScenePatterns {
err := createScene(sqb, makeScene(fn, false)) err := createScene(ctx, sqb, makeScene(fn, false))
if err != nil { if err != nil {
return err return err
} }
@@ -132,7 +138,7 @@ func createScenes(sqb models.SceneReaderWriter) error {
for _, fn := range scenePatterns { for _, fn := range scenePatterns {
s := makeScene("organized"+fn, false) s := makeScene("organized"+fn, false)
s.Organized = true s.Organized = true
err := createScene(sqb, s) err := createScene(ctx, sqb, s)
if err != nil { if err != nil {
return err return err
} }
@@ -141,7 +147,7 @@ func createScenes(sqb models.SceneReaderWriter) error {
// create scene with existing studio io // create scene with existing studio io
studioScene := makeScene(existingStudioSceneName, true) studioScene := makeScene(existingStudioSceneName, true)
studioScene.StudioID = sql.NullInt64{Valid: true, Int64: int64(existingStudioID)} studioScene.StudioID = sql.NullInt64{Valid: true, Int64: int64(existingStudioID)}
err := createScene(sqb, studioScene) err := createScene(ctx, sqb, studioScene)
if err != nil { if err != nil {
return err return err
} }
@@ -163,8 +169,8 @@ func makeScene(name string, expectedResult bool) *models.Scene {
return scene return scene
} }
func createScene(sqb models.SceneWriter, scene *models.Scene) error { func createScene(ctx context.Context, sqb models.SceneWriter, scene *models.Scene) error {
_, err := sqb.Create(*scene) _, err := sqb.Create(ctx, *scene)
if err != nil { if err != nil {
return fmt.Errorf("Failed to create scene with name '%s': %s", scene.Path, err.Error()) return fmt.Errorf("Failed to create scene with name '%s': %s", scene.Path, err.Error())
@@ -173,18 +179,18 @@ func createScene(sqb models.SceneWriter, scene *models.Scene) error {
return nil return nil
} }
func createImages(sqb models.ImageReaderWriter) error { func createImages(ctx context.Context, sqb models.ImageReaderWriter) error {
// create the images // create the images
imagePatterns, falseImagePatterns := generateTestPaths(testName, imageExt) imagePatterns, falseImagePatterns := generateTestPaths(testName, imageExt)
for _, fn := range imagePatterns { for _, fn := range imagePatterns {
err := createImage(sqb, makeImage(fn, true)) err := createImage(ctx, sqb, makeImage(fn, true))
if err != nil { if err != nil {
return err return err
} }
} }
for _, fn := range falseImagePatterns { for _, fn := range falseImagePatterns {
err := createImage(sqb, makeImage(fn, false)) err := createImage(ctx, sqb, makeImage(fn, false))
if err != nil { if err != nil {
return err return err
} }
@@ -194,7 +200,7 @@ func createImages(sqb models.ImageReaderWriter) error {
for _, fn := range imagePatterns { for _, fn := range imagePatterns {
s := makeImage("organized"+fn, false) s := makeImage("organized"+fn, false)
s.Organized = true s.Organized = true
err := createImage(sqb, s) err := createImage(ctx, sqb, s)
if err != nil { if err != nil {
return err return err
} }
@@ -203,7 +209,7 @@ func createImages(sqb models.ImageReaderWriter) error {
// create image with existing studio io // create image with existing studio io
studioImage := makeImage(existingStudioImageName, true) studioImage := makeImage(existingStudioImageName, true)
studioImage.StudioID = sql.NullInt64{Valid: true, Int64: int64(existingStudioID)} studioImage.StudioID = sql.NullInt64{Valid: true, Int64: int64(existingStudioID)}
err := createImage(sqb, studioImage) err := createImage(ctx, sqb, studioImage)
if err != nil { if err != nil {
return err return err
} }
@@ -225,8 +231,8 @@ func makeImage(name string, expectedResult bool) *models.Image {
return image return image
} }
func createImage(sqb models.ImageWriter, image *models.Image) error { func createImage(ctx context.Context, sqb models.ImageWriter, image *models.Image) error {
_, err := sqb.Create(*image) _, err := sqb.Create(ctx, *image)
if err != nil { if err != nil {
return fmt.Errorf("Failed to create image with name '%s': %s", image.Path, err.Error()) return fmt.Errorf("Failed to create image with name '%s': %s", image.Path, err.Error())
@@ -235,18 +241,18 @@ func createImage(sqb models.ImageWriter, image *models.Image) error {
return nil return nil
} }
func createGalleries(sqb models.GalleryReaderWriter) error { func createGalleries(ctx context.Context, sqb models.GalleryReaderWriter) error {
// create the galleries // create the galleries
galleryPatterns, falseGalleryPatterns := generateTestPaths(testName, galleryExt) galleryPatterns, falseGalleryPatterns := generateTestPaths(testName, galleryExt)
for _, fn := range galleryPatterns { for _, fn := range galleryPatterns {
err := createGallery(sqb, makeGallery(fn, true)) err := createGallery(ctx, sqb, makeGallery(fn, true))
if err != nil { if err != nil {
return err return err
} }
} }
for _, fn := range falseGalleryPatterns { for _, fn := range falseGalleryPatterns {
err := createGallery(sqb, makeGallery(fn, false)) err := createGallery(ctx, sqb, makeGallery(fn, false))
if err != nil { if err != nil {
return err return err
} }
@@ -256,7 +262,7 @@ func createGalleries(sqb models.GalleryReaderWriter) error {
for _, fn := range galleryPatterns { for _, fn := range galleryPatterns {
s := makeGallery("organized"+fn, false) s := makeGallery("organized"+fn, false)
s.Organized = true s.Organized = true
err := createGallery(sqb, s) err := createGallery(ctx, sqb, s)
if err != nil { if err != nil {
return err return err
} }
@@ -265,7 +271,7 @@ func createGalleries(sqb models.GalleryReaderWriter) error {
// create gallery with existing studio io // create gallery with existing studio io
studioGallery := makeGallery(existingStudioGalleryName, true) studioGallery := makeGallery(existingStudioGalleryName, true)
studioGallery.StudioID = sql.NullInt64{Valid: true, Int64: int64(existingStudioID)} studioGallery.StudioID = sql.NullInt64{Valid: true, Int64: int64(existingStudioID)}
err := createGallery(sqb, studioGallery) err := createGallery(ctx, sqb, studioGallery)
if err != nil { if err != nil {
return err return err
} }
@@ -287,8 +293,8 @@ func makeGallery(name string, expectedResult bool) *models.Gallery {
return gallery return gallery
} }
func createGallery(sqb models.GalleryWriter, gallery *models.Gallery) error { func createGallery(ctx context.Context, sqb models.GalleryWriter, gallery *models.Gallery) error {
_, err := sqb.Create(*gallery) _, err := sqb.Create(ctx, *gallery)
if err != nil { if err != nil {
return fmt.Errorf("Failed to create gallery with name '%s': %s", gallery.Path.String, err.Error()) return fmt.Errorf("Failed to create gallery with name '%s': %s", gallery.Path.String, err.Error())
@@ -297,47 +303,46 @@ func createGallery(sqb models.GalleryWriter, gallery *models.Gallery) error {
return nil return nil
} }
func withTxn(f func(r models.Repository) error) error { func withTxn(f func(ctx context.Context) error) error {
t := sqlite.NewTransactionManager() return txn.WithTxn(context.TODO(), db, f)
return t.WithTxn(context.TODO(), f)
} }
func populateDB() error { func populateDB() error {
if err := withTxn(func(r models.Repository) error { if err := withTxn(func(ctx context.Context) error {
err := createPerformer(r.Performer()) err := createPerformer(ctx, r.Performer)
if err != nil { if err != nil {
return err return err
} }
_, err = createStudio(r.Studio(), testName) _, err = createStudio(ctx, r.Studio, testName)
if err != nil { if err != nil {
return err return err
} }
// create existing studio // create existing studio
existingStudio, err := createStudio(r.Studio(), existingStudioName) existingStudio, err := createStudio(ctx, r.Studio, existingStudioName)
if err != nil { if err != nil {
return err return err
} }
existingStudioID = existingStudio.ID existingStudioID = existingStudio.ID
err = createTag(r.Tag()) err = createTag(ctx, r.Tag)
if err != nil { if err != nil {
return err return err
} }
err = createScenes(r.Scene()) err = createScenes(ctx, r.Scene)
if err != nil { if err != nil {
return err return err
} }
err = createImages(r.Image()) err = createImages(ctx, r.Image)
if err != nil { if err != nil {
return err return err
} }
err = createGalleries(r.Gallery()) err = createGalleries(ctx, r.Gallery)
if err != nil { if err != nil {
return err return err
} }
@@ -352,9 +357,9 @@ func populateDB() error {
func TestParsePerformerScenes(t *testing.T) { func TestParsePerformerScenes(t *testing.T) {
var performers []*models.Performer var performers []*models.Performer
if err := withTxn(func(r models.Repository) error { if err := withTxn(func(ctx context.Context) error {
var err error var err error
performers, err = r.Performer().All() performers, err = r.Performer.All(ctx)
return err return err
}); err != nil { }); err != nil {
t.Errorf("Error getting performer: %s", err) t.Errorf("Error getting performer: %s", err)
@@ -362,24 +367,24 @@ func TestParsePerformerScenes(t *testing.T) {
} }
for _, p := range performers { for _, p := range performers {
if err := withTxn(func(r models.Repository) error { if err := withTxn(func(ctx context.Context) error {
return PerformerScenes(p, nil, r.Scene(), nil) return PerformerScenes(ctx, p, nil, r.Scene, nil)
}); err != nil { }); err != nil {
t.Errorf("Error auto-tagging performers: %s", err) t.Errorf("Error auto-tagging performers: %s", err)
} }
} }
// verify that scenes were tagged correctly // verify that scenes were tagged correctly
withTxn(func(r models.Repository) error { withTxn(func(ctx context.Context) error {
pqb := r.Performer() pqb := r.Performer
scenes, err := r.Scene().All() scenes, err := r.Scene.All(ctx)
if err != nil { if err != nil {
t.Error(err.Error()) t.Error(err.Error())
} }
for _, scene := range scenes { for _, scene := range scenes {
performers, err := pqb.FindBySceneID(scene.ID) performers, err := pqb.FindBySceneID(ctx, scene.ID)
if err != nil { if err != nil {
t.Errorf("Error getting scene performers: %s", err.Error()) t.Errorf("Error getting scene performers: %s", err.Error())
@@ -399,9 +404,9 @@ func TestParsePerformerScenes(t *testing.T) {
func TestParseStudioScenes(t *testing.T) { func TestParseStudioScenes(t *testing.T) {
var studios []*models.Studio var studios []*models.Studio
if err := withTxn(func(r models.Repository) error { if err := withTxn(func(ctx context.Context) error {
var err error var err error
studios, err = r.Studio().All() studios, err = r.Studio.All(ctx)
return err return err
}); err != nil { }); err != nil {
t.Errorf("Error getting studio: %s", err) t.Errorf("Error getting studio: %s", err)
@@ -409,21 +414,21 @@ func TestParseStudioScenes(t *testing.T) {
} }
for _, s := range studios { for _, s := range studios {
if err := withTxn(func(r models.Repository) error { if err := withTxn(func(ctx context.Context) error {
aliases, err := r.Studio().GetAliases(s.ID) aliases, err := r.Studio.GetAliases(ctx, s.ID)
if err != nil { if err != nil {
return err return err
} }
return StudioScenes(s, nil, aliases, r.Scene(), nil) return StudioScenes(ctx, s, nil, aliases, r.Scene, nil)
}); err != nil { }); err != nil {
t.Errorf("Error auto-tagging performers: %s", err) t.Errorf("Error auto-tagging performers: %s", err)
} }
} }
// verify that scenes were tagged correctly // verify that scenes were tagged correctly
withTxn(func(r models.Repository) error { withTxn(func(ctx context.Context) error {
scenes, err := r.Scene().All() scenes, err := r.Scene.All(ctx)
if err != nil { if err != nil {
t.Error(err.Error()) t.Error(err.Error())
} }
@@ -455,9 +460,9 @@ func TestParseStudioScenes(t *testing.T) {
func TestParseTagScenes(t *testing.T) { func TestParseTagScenes(t *testing.T) {
var tags []*models.Tag var tags []*models.Tag
if err := withTxn(func(r models.Repository) error { if err := withTxn(func(ctx context.Context) error {
var err error var err error
tags, err = r.Tag().All() tags, err = r.Tag.All(ctx)
return err return err
}); err != nil { }); err != nil {
t.Errorf("Error getting performer: %s", err) t.Errorf("Error getting performer: %s", err)
@@ -465,29 +470,29 @@ func TestParseTagScenes(t *testing.T) {
} }
for _, s := range tags { for _, s := range tags {
if err := withTxn(func(r models.Repository) error { if err := withTxn(func(ctx context.Context) error {
aliases, err := r.Tag().GetAliases(s.ID) aliases, err := r.Tag.GetAliases(ctx, s.ID)
if err != nil { if err != nil {
return err return err
} }
return TagScenes(s, nil, aliases, r.Scene(), nil) return TagScenes(ctx, s, nil, aliases, r.Scene, nil)
}); err != nil { }); err != nil {
t.Errorf("Error auto-tagging performers: %s", err) t.Errorf("Error auto-tagging performers: %s", err)
} }
} }
// verify that scenes were tagged correctly // verify that scenes were tagged correctly
withTxn(func(r models.Repository) error { withTxn(func(ctx context.Context) error {
scenes, err := r.Scene().All() scenes, err := r.Scene.All(ctx)
if err != nil { if err != nil {
t.Error(err.Error()) t.Error(err.Error())
} }
tqb := r.Tag() tqb := r.Tag
for _, scene := range scenes { for _, scene := range scenes {
tags, err := tqb.FindBySceneID(scene.ID) tags, err := tqb.FindBySceneID(ctx, scene.ID)
if err != nil { if err != nil {
t.Errorf("Error getting scene tags: %s", err.Error()) t.Errorf("Error getting scene tags: %s", err.Error())
@@ -507,9 +512,9 @@ func TestParseTagScenes(t *testing.T) {
func TestParsePerformerImages(t *testing.T) { func TestParsePerformerImages(t *testing.T) {
var performers []*models.Performer var performers []*models.Performer
if err := withTxn(func(r models.Repository) error { if err := withTxn(func(ctx context.Context) error {
var err error var err error
performers, err = r.Performer().All() performers, err = r.Performer.All(ctx)
return err return err
}); err != nil { }); err != nil {
t.Errorf("Error getting performer: %s", err) t.Errorf("Error getting performer: %s", err)
@@ -517,24 +522,24 @@ func TestParsePerformerImages(t *testing.T) {
} }
for _, p := range performers { for _, p := range performers {
if err := withTxn(func(r models.Repository) error { if err := withTxn(func(ctx context.Context) error {
return PerformerImages(p, nil, r.Image(), nil) return PerformerImages(ctx, p, nil, r.Image, nil)
}); err != nil { }); err != nil {
t.Errorf("Error auto-tagging performers: %s", err) t.Errorf("Error auto-tagging performers: %s", err)
} }
} }
// verify that images were tagged correctly // verify that images were tagged correctly
withTxn(func(r models.Repository) error { withTxn(func(ctx context.Context) error {
pqb := r.Performer() pqb := r.Performer
images, err := r.Image().All() images, err := r.Image.All(ctx)
if err != nil { if err != nil {
t.Error(err.Error()) t.Error(err.Error())
} }
for _, image := range images { for _, image := range images {
performers, err := pqb.FindByImageID(image.ID) performers, err := pqb.FindByImageID(ctx, image.ID)
if err != nil { if err != nil {
t.Errorf("Error getting image performers: %s", err.Error()) t.Errorf("Error getting image performers: %s", err.Error())
@@ -554,9 +559,9 @@ func TestParsePerformerImages(t *testing.T) {
func TestParseStudioImages(t *testing.T) { func TestParseStudioImages(t *testing.T) {
var studios []*models.Studio var studios []*models.Studio
if err := withTxn(func(r models.Repository) error { if err := withTxn(func(ctx context.Context) error {
var err error var err error
studios, err = r.Studio().All() studios, err = r.Studio.All(ctx)
return err return err
}); err != nil { }); err != nil {
t.Errorf("Error getting studio: %s", err) t.Errorf("Error getting studio: %s", err)
@@ -564,21 +569,21 @@ func TestParseStudioImages(t *testing.T) {
} }
for _, s := range studios { for _, s := range studios {
if err := withTxn(func(r models.Repository) error { if err := withTxn(func(ctx context.Context) error {
aliases, err := r.Studio().GetAliases(s.ID) aliases, err := r.Studio.GetAliases(ctx, s.ID)
if err != nil { if err != nil {
return err return err
} }
return StudioImages(s, nil, aliases, r.Image(), nil) return StudioImages(ctx, s, nil, aliases, r.Image, nil)
}); err != nil { }); err != nil {
t.Errorf("Error auto-tagging performers: %s", err) t.Errorf("Error auto-tagging performers: %s", err)
} }
} }
// verify that images were tagged correctly // verify that images were tagged correctly
withTxn(func(r models.Repository) error { withTxn(func(ctx context.Context) error {
images, err := r.Image().All() images, err := r.Image.All(ctx)
if err != nil { if err != nil {
t.Error(err.Error()) t.Error(err.Error())
} }
@@ -610,9 +615,9 @@ func TestParseStudioImages(t *testing.T) {
func TestParseTagImages(t *testing.T) { func TestParseTagImages(t *testing.T) {
var tags []*models.Tag var tags []*models.Tag
if err := withTxn(func(r models.Repository) error { if err := withTxn(func(ctx context.Context) error {
var err error var err error
tags, err = r.Tag().All() tags, err = r.Tag.All(ctx)
return err return err
}); err != nil { }); err != nil {
t.Errorf("Error getting performer: %s", err) t.Errorf("Error getting performer: %s", err)
@@ -620,29 +625,29 @@ func TestParseTagImages(t *testing.T) {
} }
for _, s := range tags { for _, s := range tags {
if err := withTxn(func(r models.Repository) error { if err := withTxn(func(ctx context.Context) error {
aliases, err := r.Tag().GetAliases(s.ID) aliases, err := r.Tag.GetAliases(ctx, s.ID)
if err != nil { if err != nil {
return err return err
} }
return TagImages(s, nil, aliases, r.Image(), nil) return TagImages(ctx, s, nil, aliases, r.Image, nil)
}); err != nil { }); err != nil {
t.Errorf("Error auto-tagging performers: %s", err) t.Errorf("Error auto-tagging performers: %s", err)
} }
} }
// verify that images were tagged correctly // verify that images were tagged correctly
withTxn(func(r models.Repository) error { withTxn(func(ctx context.Context) error {
images, err := r.Image().All() images, err := r.Image.All(ctx)
if err != nil { if err != nil {
t.Error(err.Error()) t.Error(err.Error())
} }
tqb := r.Tag() tqb := r.Tag
for _, image := range images { for _, image := range images {
tags, err := tqb.FindByImageID(image.ID) tags, err := tqb.FindByImageID(ctx, image.ID)
if err != nil { if err != nil {
t.Errorf("Error getting image tags: %s", err.Error()) t.Errorf("Error getting image tags: %s", err.Error())
@@ -662,9 +667,9 @@ func TestParseTagImages(t *testing.T) {
func TestParsePerformerGalleries(t *testing.T) { func TestParsePerformerGalleries(t *testing.T) {
var performers []*models.Performer var performers []*models.Performer
if err := withTxn(func(r models.Repository) error { if err := withTxn(func(ctx context.Context) error {
var err error var err error
performers, err = r.Performer().All() performers, err = r.Performer.All(ctx)
return err return err
}); err != nil { }); err != nil {
t.Errorf("Error getting performer: %s", err) t.Errorf("Error getting performer: %s", err)
@@ -672,24 +677,24 @@ func TestParsePerformerGalleries(t *testing.T) {
} }
for _, p := range performers { for _, p := range performers {
if err := withTxn(func(r models.Repository) error { if err := withTxn(func(ctx context.Context) error {
return PerformerGalleries(p, nil, r.Gallery(), nil) return PerformerGalleries(ctx, p, nil, r.Gallery, nil)
}); err != nil { }); err != nil {
t.Errorf("Error auto-tagging performers: %s", err) t.Errorf("Error auto-tagging performers: %s", err)
} }
} }
// verify that galleries were tagged correctly // verify that galleries were tagged correctly
withTxn(func(r models.Repository) error { withTxn(func(ctx context.Context) error {
pqb := r.Performer() pqb := r.Performer
galleries, err := r.Gallery().All() galleries, err := r.Gallery.All(ctx)
if err != nil { if err != nil {
t.Error(err.Error()) t.Error(err.Error())
} }
for _, gallery := range galleries { for _, gallery := range galleries {
performers, err := pqb.FindByGalleryID(gallery.ID) performers, err := pqb.FindByGalleryID(ctx, gallery.ID)
if err != nil { if err != nil {
t.Errorf("Error getting gallery performers: %s", err.Error()) t.Errorf("Error getting gallery performers: %s", err.Error())
@@ -709,9 +714,9 @@ func TestParsePerformerGalleries(t *testing.T) {
func TestParseStudioGalleries(t *testing.T) { func TestParseStudioGalleries(t *testing.T) {
var studios []*models.Studio var studios []*models.Studio
if err := withTxn(func(r models.Repository) error { if err := withTxn(func(ctx context.Context) error {
var err error var err error
studios, err = r.Studio().All() studios, err = r.Studio.All(ctx)
return err return err
}); err != nil { }); err != nil {
t.Errorf("Error getting studio: %s", err) t.Errorf("Error getting studio: %s", err)
@@ -719,21 +724,21 @@ func TestParseStudioGalleries(t *testing.T) {
} }
for _, s := range studios { for _, s := range studios {
if err := withTxn(func(r models.Repository) error { if err := withTxn(func(ctx context.Context) error {
aliases, err := r.Studio().GetAliases(s.ID) aliases, err := r.Studio.GetAliases(ctx, s.ID)
if err != nil { if err != nil {
return err return err
} }
return StudioGalleries(s, nil, aliases, r.Gallery(), nil) return StudioGalleries(ctx, s, nil, aliases, r.Gallery, nil)
}); err != nil { }); err != nil {
t.Errorf("Error auto-tagging performers: %s", err) t.Errorf("Error auto-tagging performers: %s", err)
} }
} }
// verify that galleries were tagged correctly // verify that galleries were tagged correctly
withTxn(func(r models.Repository) error { withTxn(func(ctx context.Context) error {
galleries, err := r.Gallery().All() galleries, err := r.Gallery.All(ctx)
if err != nil { if err != nil {
t.Error(err.Error()) t.Error(err.Error())
} }
@@ -765,9 +770,9 @@ func TestParseStudioGalleries(t *testing.T) {
func TestParseTagGalleries(t *testing.T) { func TestParseTagGalleries(t *testing.T) {
var tags []*models.Tag var tags []*models.Tag
if err := withTxn(func(r models.Repository) error { if err := withTxn(func(ctx context.Context) error {
var err error var err error
tags, err = r.Tag().All() tags, err = r.Tag.All(ctx)
return err return err
}); err != nil { }); err != nil {
t.Errorf("Error getting performer: %s", err) t.Errorf("Error getting performer: %s", err)
@@ -775,29 +780,29 @@ func TestParseTagGalleries(t *testing.T) {
} }
for _, s := range tags { for _, s := range tags {
if err := withTxn(func(r models.Repository) error { if err := withTxn(func(ctx context.Context) error {
aliases, err := r.Tag().GetAliases(s.ID) aliases, err := r.Tag.GetAliases(ctx, s.ID)
if err != nil { if err != nil {
return err return err
} }
return TagGalleries(s, nil, aliases, r.Gallery(), nil) return TagGalleries(ctx, s, nil, aliases, r.Gallery, nil)
}); err != nil { }); err != nil {
t.Errorf("Error auto-tagging performers: %s", err) t.Errorf("Error auto-tagging performers: %s", err)
} }
} }
// verify that galleries were tagged correctly // verify that galleries were tagged correctly
withTxn(func(r models.Repository) error { withTxn(func(ctx context.Context) error {
galleries, err := r.Gallery().All() galleries, err := r.Gallery.All(ctx)
if err != nil { if err != nil {
t.Error(err.Error()) t.Error(err.Error())
} }
tqb := r.Tag() tqb := r.Tag
for _, gallery := range galleries { for _, gallery := range galleries {
tags, err := tqb.FindByGalleryID(gallery.ID) tags, err := tqb.FindByGalleryID(ctx, gallery.ID)
if err != nil { if err != nil {
t.Errorf("Error getting gallery tags: %s", err.Error()) t.Errorf("Error getting gallery tags: %s", err.Error())

View File

@@ -1,6 +1,8 @@
package autotag package autotag
import ( import (
"context"
"github.com/stashapp/stash/pkg/gallery" "github.com/stashapp/stash/pkg/gallery"
"github.com/stashapp/stash/pkg/image" "github.com/stashapp/stash/pkg/image"
"github.com/stashapp/stash/pkg/match" "github.com/stashapp/stash/pkg/match"
@@ -8,6 +10,21 @@ import (
"github.com/stashapp/stash/pkg/scene" "github.com/stashapp/stash/pkg/scene"
) )
type SceneQueryPerformerUpdater interface {
scene.Queryer
scene.PerformerUpdater
}
type ImageQueryPerformerUpdater interface {
image.Queryer
image.PerformerUpdater
}
type GalleryQueryPerformerUpdater interface {
gallery.Queryer
gallery.PerformerUpdater
}
func getPerformerTagger(p *models.Performer, cache *match.Cache) tagger { func getPerformerTagger(p *models.Performer, cache *match.Cache) tagger {
return tagger{ return tagger{
ID: p.ID, ID: p.ID,
@@ -18,28 +35,28 @@ func getPerformerTagger(p *models.Performer, cache *match.Cache) tagger {
} }
// PerformerScenes searches for scenes whose path matches the provided performer name and tags the scene with the performer. // PerformerScenes searches for scenes whose path matches the provided performer name and tags the scene with the performer.
func PerformerScenes(p *models.Performer, paths []string, rw models.SceneReaderWriter, cache *match.Cache) error { func PerformerScenes(ctx context.Context, p *models.Performer, paths []string, rw SceneQueryPerformerUpdater, cache *match.Cache) error {
t := getPerformerTagger(p, cache) t := getPerformerTagger(p, cache)
return t.tagScenes(paths, rw, func(subjectID, otherID int) (bool, error) { return t.tagScenes(ctx, paths, rw, func(subjectID, otherID int) (bool, error) {
return scene.AddPerformer(rw, otherID, subjectID) return scene.AddPerformer(ctx, rw, otherID, subjectID)
}) })
} }
// PerformerImages searches for images whose path matches the provided performer name and tags the image with the performer. // PerformerImages searches for images whose path matches the provided performer name and tags the image with the performer.
func PerformerImages(p *models.Performer, paths []string, rw models.ImageReaderWriter, cache *match.Cache) error { func PerformerImages(ctx context.Context, p *models.Performer, paths []string, rw ImageQueryPerformerUpdater, cache *match.Cache) error {
t := getPerformerTagger(p, cache) t := getPerformerTagger(p, cache)
return t.tagImages(paths, rw, func(subjectID, otherID int) (bool, error) { return t.tagImages(ctx, paths, rw, func(subjectID, otherID int) (bool, error) {
return image.AddPerformer(rw, otherID, subjectID) return image.AddPerformer(ctx, rw, otherID, subjectID)
}) })
} }
// PerformerGalleries searches for galleries whose path matches the provided performer name and tags the gallery with the performer. // PerformerGalleries searches for galleries whose path matches the provided performer name and tags the gallery with the performer.
func PerformerGalleries(p *models.Performer, paths []string, rw models.GalleryReaderWriter, cache *match.Cache) error { func PerformerGalleries(ctx context.Context, p *models.Performer, paths []string, rw GalleryQueryPerformerUpdater, cache *match.Cache) error {
t := getPerformerTagger(p, cache) t := getPerformerTagger(p, cache)
return t.tagGalleries(paths, rw, func(subjectID, otherID int) (bool, error) { return t.tagGalleries(ctx, paths, rw, func(subjectID, otherID int) (bool, error) {
return gallery.AddPerformer(rw, otherID, subjectID) return gallery.AddPerformer(ctx, rw, otherID, subjectID)
}) })
} }

View File

@@ -72,16 +72,16 @@ func testPerformerScenes(t *testing.T, performerName, expectedRegex string) {
PerPage: &perPage, PerPage: &perPage,
} }
mockSceneReader.On("Query", scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false)). mockSceneReader.On("Query", testCtx, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false)).
Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once() Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once()
for i := range matchingPaths { for i := range matchingPaths {
sceneID := i + 1 sceneID := i + 1
mockSceneReader.On("GetPerformerIDs", sceneID).Return(nil, nil).Once() mockSceneReader.On("GetPerformerIDs", testCtx, sceneID).Return(nil, nil).Once()
mockSceneReader.On("UpdatePerformers", sceneID, []int{performerID}).Return(nil).Once() mockSceneReader.On("UpdatePerformers", testCtx, sceneID, []int{performerID}).Return(nil).Once()
} }
err := PerformerScenes(&performer, nil, mockSceneReader, nil) err := PerformerScenes(testCtx, &performer, nil, mockSceneReader, nil)
assert := assert.New(t) assert := assert.New(t)
@@ -147,16 +147,16 @@ func testPerformerImages(t *testing.T, performerName, expectedRegex string) {
PerPage: &perPage, PerPage: &perPage,
} }
mockImageReader.On("Query", image.QueryOptions(expectedImageFilter, expectedFindFilter, false)). mockImageReader.On("Query", testCtx, image.QueryOptions(expectedImageFilter, expectedFindFilter, false)).
Return(mocks.ImageQueryResult(images, len(images)), nil).Once() Return(mocks.ImageQueryResult(images, len(images)), nil).Once()
for i := range matchingPaths { for i := range matchingPaths {
imageID := i + 1 imageID := i + 1
mockImageReader.On("GetPerformerIDs", imageID).Return(nil, nil).Once() mockImageReader.On("GetPerformerIDs", testCtx, imageID).Return(nil, nil).Once()
mockImageReader.On("UpdatePerformers", imageID, []int{performerID}).Return(nil).Once() mockImageReader.On("UpdatePerformers", testCtx, imageID, []int{performerID}).Return(nil).Once()
} }
err := PerformerImages(&performer, nil, mockImageReader, nil) err := PerformerImages(testCtx, &performer, nil, mockImageReader, nil)
assert := assert.New(t) assert := assert.New(t)
@@ -222,15 +222,15 @@ func testPerformerGalleries(t *testing.T, performerName, expectedRegex string) {
PerPage: &perPage, PerPage: &perPage,
} }
mockGalleryReader.On("Query", expectedGalleryFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once() mockGalleryReader.On("Query", testCtx, expectedGalleryFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once()
for i := range matchingPaths { for i := range matchingPaths {
galleryID := i + 1 galleryID := i + 1
mockGalleryReader.On("GetPerformerIDs", galleryID).Return(nil, nil).Once() mockGalleryReader.On("GetPerformerIDs", testCtx, galleryID).Return(nil, nil).Once()
mockGalleryReader.On("UpdatePerformers", galleryID, []int{performerID}).Return(nil).Once() mockGalleryReader.On("UpdatePerformers", testCtx, galleryID, []int{performerID}).Return(nil).Once()
} }
err := PerformerGalleries(&performer, nil, mockGalleryReader, nil) err := PerformerGalleries(testCtx, &performer, nil, mockGalleryReader, nil)
assert := assert.New(t) assert := assert.New(t)

View File

@@ -1,6 +1,8 @@
package autotag package autotag
import ( import (
"context"
"github.com/stashapp/stash/pkg/match" "github.com/stashapp/stash/pkg/match"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene" "github.com/stashapp/stash/pkg/scene"
@@ -17,18 +19,18 @@ func getSceneFileTagger(s *models.Scene, cache *match.Cache) tagger {
} }
// ScenePerformers tags the provided scene with performers whose name matches the scene's path. // ScenePerformers tags the provided scene with performers whose name matches the scene's path.
func ScenePerformers(s *models.Scene, rw models.SceneReaderWriter, performerReader models.PerformerReader, cache *match.Cache) error { func ScenePerformers(ctx context.Context, s *models.Scene, rw scene.PerformerUpdater, performerReader match.PerformerAutoTagQueryer, cache *match.Cache) error {
t := getSceneFileTagger(s, cache) t := getSceneFileTagger(s, cache)
return t.tagPerformers(performerReader, func(subjectID, otherID int) (bool, error) { return t.tagPerformers(ctx, performerReader, func(subjectID, otherID int) (bool, error) {
return scene.AddPerformer(rw, subjectID, otherID) return scene.AddPerformer(ctx, rw, subjectID, otherID)
}) })
} }
// SceneStudios tags the provided scene with the first studio whose name matches the scene's path. // SceneStudios tags the provided scene with the first studio whose name matches the scene's path.
// //
// Scenes will not be tagged if studio is already set. // Scenes will not be tagged if studio is already set.
func SceneStudios(s *models.Scene, rw models.SceneReaderWriter, studioReader models.StudioReader, cache *match.Cache) error { func SceneStudios(ctx context.Context, s *models.Scene, rw SceneFinderUpdater, studioReader match.StudioAutoTagQueryer, cache *match.Cache) error {
if s.StudioID.Valid { if s.StudioID.Valid {
// don't modify // don't modify
return nil return nil
@@ -36,16 +38,16 @@ func SceneStudios(s *models.Scene, rw models.SceneReaderWriter, studioReader mod
t := getSceneFileTagger(s, cache) t := getSceneFileTagger(s, cache)
return t.tagStudios(studioReader, func(subjectID, otherID int) (bool, error) { return t.tagStudios(ctx, studioReader, func(subjectID, otherID int) (bool, error) {
return addSceneStudio(rw, subjectID, otherID) return addSceneStudio(ctx, rw, subjectID, otherID)
}) })
} }
// SceneTags tags the provided scene with tags whose name matches the scene's path. // SceneTags tags the provided scene with tags whose name matches the scene's path.
func SceneTags(s *models.Scene, rw models.SceneReaderWriter, tagReader models.TagReader, cache *match.Cache) error { func SceneTags(ctx context.Context, s *models.Scene, rw scene.TagUpdater, tagReader match.TagAutoTagQueryer, cache *match.Cache) error {
t := getSceneFileTagger(s, cache) t := getSceneFileTagger(s, cache)
return t.tagTags(tagReader, func(subjectID, otherID int) (bool, error) { return t.tagTags(ctx, tagReader, func(subjectID, otherID int) (bool, error) {
return scene.AddTag(rw, subjectID, otherID) return scene.AddTag(ctx, rw, subjectID, otherID)
}) })
} }

View File

@@ -173,19 +173,19 @@ func TestScenePerformers(t *testing.T) {
mockPerformerReader := &mocks.PerformerReaderWriter{} mockPerformerReader := &mocks.PerformerReaderWriter{}
mockSceneReader := &mocks.SceneReaderWriter{} mockSceneReader := &mocks.SceneReaderWriter{}
mockPerformerReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil) mockPerformerReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockPerformerReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once() mockPerformerReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once()
if test.Matches { if test.Matches {
mockSceneReader.On("GetPerformerIDs", sceneID).Return(nil, nil).Once() mockSceneReader.On("GetPerformerIDs", testCtx, sceneID).Return(nil, nil).Once()
mockSceneReader.On("UpdatePerformers", sceneID, []int{performerID}).Return(nil).Once() mockSceneReader.On("UpdatePerformers", testCtx, sceneID, []int{performerID}).Return(nil).Once()
} }
scene := models.Scene{ scene := models.Scene{
ID: sceneID, ID: sceneID,
Path: test.Path, Path: test.Path,
} }
err := ScenePerformers(&scene, mockSceneReader, mockPerformerReader, nil) err := ScenePerformers(testCtx, &scene, mockSceneReader, mockPerformerReader, nil)
assert.Nil(err) assert.Nil(err)
mockPerformerReader.AssertExpectations(t) mockPerformerReader.AssertExpectations(t)
@@ -217,9 +217,9 @@ func TestSceneStudios(t *testing.T) {
doTest := func(mockStudioReader *mocks.StudioReaderWriter, mockSceneReader *mocks.SceneReaderWriter, test pathTestTable) { doTest := func(mockStudioReader *mocks.StudioReaderWriter, mockSceneReader *mocks.SceneReaderWriter, test pathTestTable) {
if test.Matches { if test.Matches {
mockSceneReader.On("Find", sceneID).Return(&models.Scene{}, nil).Once() mockSceneReader.On("Find", testCtx, sceneID).Return(&models.Scene{}, nil).Once()
expectedStudioID := models.NullInt64(studioID) expectedStudioID := models.NullInt64(studioID)
mockSceneReader.On("Update", models.ScenePartial{ mockSceneReader.On("Update", testCtx, models.ScenePartial{
ID: sceneID, ID: sceneID,
StudioID: &expectedStudioID, StudioID: &expectedStudioID,
}).Return(nil, nil).Once() }).Return(nil, nil).Once()
@@ -229,7 +229,7 @@ func TestSceneStudios(t *testing.T) {
ID: sceneID, ID: sceneID,
Path: test.Path, Path: test.Path,
} }
err := SceneStudios(&scene, mockSceneReader, mockStudioReader, nil) err := SceneStudios(testCtx, &scene, mockSceneReader, mockStudioReader, nil)
assert.Nil(err) assert.Nil(err)
mockStudioReader.AssertExpectations(t) mockStudioReader.AssertExpectations(t)
@@ -240,9 +240,9 @@ func TestSceneStudios(t *testing.T) {
mockStudioReader := &mocks.StudioReaderWriter{} mockStudioReader := &mocks.StudioReaderWriter{}
mockSceneReader := &mocks.SceneReaderWriter{} mockSceneReader := &mocks.SceneReaderWriter{}
mockStudioReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil) mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockStudioReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
mockStudioReader.On("GetAliases", mock.Anything).Return([]string{}, nil).Maybe() mockStudioReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
doTest(mockStudioReader, mockSceneReader, test) doTest(mockStudioReader, mockSceneReader, test)
} }
@@ -255,12 +255,12 @@ func TestSceneStudios(t *testing.T) {
mockStudioReader := &mocks.StudioReaderWriter{} mockStudioReader := &mocks.StudioReaderWriter{}
mockSceneReader := &mocks.SceneReaderWriter{} mockSceneReader := &mocks.SceneReaderWriter{}
mockStudioReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil) mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockStudioReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
mockStudioReader.On("GetAliases", studioID).Return([]string{ mockStudioReader.On("GetAliases", testCtx, studioID).Return([]string{
studioName, studioName,
}, nil).Once() }, nil).Once()
mockStudioReader.On("GetAliases", reversedStudioID).Return([]string{}, nil).Once() mockStudioReader.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once()
doTest(mockStudioReader, mockSceneReader, test) doTest(mockStudioReader, mockSceneReader, test)
} }
@@ -290,15 +290,15 @@ func TestSceneTags(t *testing.T) {
doTest := func(mockTagReader *mocks.TagReaderWriter, mockSceneReader *mocks.SceneReaderWriter, test pathTestTable) { doTest := func(mockTagReader *mocks.TagReaderWriter, mockSceneReader *mocks.SceneReaderWriter, test pathTestTable) {
if test.Matches { if test.Matches {
mockSceneReader.On("GetTagIDs", sceneID).Return(nil, nil).Once() mockSceneReader.On("GetTagIDs", testCtx, sceneID).Return(nil, nil).Once()
mockSceneReader.On("UpdateTags", sceneID, []int{tagID}).Return(nil).Once() mockSceneReader.On("UpdateTags", testCtx, sceneID, []int{tagID}).Return(nil).Once()
} }
scene := models.Scene{ scene := models.Scene{
ID: sceneID, ID: sceneID,
Path: test.Path, Path: test.Path,
} }
err := SceneTags(&scene, mockSceneReader, mockTagReader, nil) err := SceneTags(testCtx, &scene, mockSceneReader, mockTagReader, nil)
assert.Nil(err) assert.Nil(err)
mockTagReader.AssertExpectations(t) mockTagReader.AssertExpectations(t)
@@ -309,9 +309,9 @@ func TestSceneTags(t *testing.T) {
mockTagReader := &mocks.TagReaderWriter{} mockTagReader := &mocks.TagReaderWriter{}
mockSceneReader := &mocks.SceneReaderWriter{} mockSceneReader := &mocks.SceneReaderWriter{}
mockTagReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil) mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockTagReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
mockTagReader.On("GetAliases", mock.Anything).Return([]string{}, nil).Maybe() mockTagReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
doTest(mockTagReader, mockSceneReader, test) doTest(mockTagReader, mockSceneReader, test)
} }
@@ -324,12 +324,12 @@ func TestSceneTags(t *testing.T) {
mockTagReader := &mocks.TagReaderWriter{} mockTagReader := &mocks.TagReaderWriter{}
mockSceneReader := &mocks.SceneReaderWriter{} mockSceneReader := &mocks.SceneReaderWriter{}
mockTagReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil) mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockTagReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
mockTagReader.On("GetAliases", tagID).Return([]string{ mockTagReader.On("GetAliases", testCtx, tagID).Return([]string{
tagName, tagName,
}, nil).Once() }, nil).Once()
mockTagReader.On("GetAliases", reversedTagID).Return([]string{}, nil).Once() mockTagReader.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once()
doTest(mockTagReader, mockSceneReader, test) doTest(mockTagReader, mockSceneReader, test)
} }

View File

@@ -1,15 +1,19 @@
package autotag package autotag
import ( import (
"context"
"database/sql" "database/sql"
"github.com/stashapp/stash/pkg/gallery"
"github.com/stashapp/stash/pkg/image"
"github.com/stashapp/stash/pkg/match" "github.com/stashapp/stash/pkg/match"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene"
) )
func addSceneStudio(sceneWriter models.SceneReaderWriter, sceneID, studioID int) (bool, error) { func addSceneStudio(ctx context.Context, sceneWriter SceneFinderUpdater, sceneID, studioID int) (bool, error) {
// don't set if already set // don't set if already set
scene, err := sceneWriter.Find(sceneID) scene, err := sceneWriter.Find(ctx, sceneID)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -25,15 +29,15 @@ func addSceneStudio(sceneWriter models.SceneReaderWriter, sceneID, studioID int)
StudioID: &s, StudioID: &s,
} }
if _, err := sceneWriter.Update(scenePartial); err != nil { if _, err := sceneWriter.Update(ctx, scenePartial); err != nil {
return false, err return false, err
} }
return true, nil return true, nil
} }
func addImageStudio(imageWriter models.ImageReaderWriter, imageID, studioID int) (bool, error) { func addImageStudio(ctx context.Context, imageWriter ImageFinderUpdater, imageID, studioID int) (bool, error) {
// don't set if already set // don't set if already set
image, err := imageWriter.Find(imageID) image, err := imageWriter.Find(ctx, imageID)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -49,15 +53,15 @@ func addImageStudio(imageWriter models.ImageReaderWriter, imageID, studioID int)
StudioID: &s, StudioID: &s,
} }
if _, err := imageWriter.Update(imagePartial); err != nil { if _, err := imageWriter.Update(ctx, imagePartial); err != nil {
return false, err return false, err
} }
return true, nil return true, nil
} }
func addGalleryStudio(galleryWriter models.GalleryReaderWriter, galleryID, studioID int) (bool, error) { func addGalleryStudio(ctx context.Context, galleryWriter GalleryFinderUpdater, galleryID, studioID int) (bool, error) {
// don't set if already set // don't set if already set
gallery, err := galleryWriter.Find(galleryID) gallery, err := galleryWriter.Find(ctx, galleryID)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -73,7 +77,7 @@ func addGalleryStudio(galleryWriter models.GalleryReaderWriter, galleryID, studi
StudioID: &s, StudioID: &s,
} }
if _, err := galleryWriter.UpdatePartial(galleryPartial); err != nil { if _, err := galleryWriter.UpdatePartial(ctx, galleryPartial); err != nil {
return false, err return false, err
} }
return true, nil return true, nil
@@ -98,13 +102,19 @@ func getStudioTagger(p *models.Studio, aliases []string, cache *match.Cache) []t
return ret return ret
} }
type SceneFinderUpdater interface {
scene.Queryer
Find(ctx context.Context, id int) (*models.Scene, error)
Update(ctx context.Context, updatedScene models.ScenePartial) (*models.Scene, error)
}
// StudioScenes searches for scenes whose path matches the provided studio name and tags the scene with the studio, if studio is not already set on the scene. // StudioScenes searches for scenes whose path matches the provided studio name and tags the scene with the studio, if studio is not already set on the scene.
func StudioScenes(p *models.Studio, paths []string, aliases []string, rw models.SceneReaderWriter, cache *match.Cache) error { func StudioScenes(ctx context.Context, p *models.Studio, paths []string, aliases []string, rw SceneFinderUpdater, cache *match.Cache) error {
t := getStudioTagger(p, aliases, cache) t := getStudioTagger(p, aliases, cache)
for _, tt := range t { for _, tt := range t {
if err := tt.tagScenes(paths, rw, func(subjectID, otherID int) (bool, error) { if err := tt.tagScenes(ctx, paths, rw, func(subjectID, otherID int) (bool, error) {
return addSceneStudio(rw, otherID, subjectID) return addSceneStudio(ctx, rw, otherID, subjectID)
}); err != nil { }); err != nil {
return err return err
} }
@@ -113,13 +123,19 @@ func StudioScenes(p *models.Studio, paths []string, aliases []string, rw models.
return nil return nil
} }
type ImageFinderUpdater interface {
image.Queryer
Find(ctx context.Context, id int) (*models.Image, error)
Update(ctx context.Context, updatedImage models.ImagePartial) (*models.Image, error)
}
// StudioImages searches for images whose path matches the provided studio name and tags the image with the studio, if studio is not already set on the image. // StudioImages searches for images whose path matches the provided studio name and tags the image with the studio, if studio is not already set on the image.
func StudioImages(p *models.Studio, paths []string, aliases []string, rw models.ImageReaderWriter, cache *match.Cache) error { func StudioImages(ctx context.Context, p *models.Studio, paths []string, aliases []string, rw ImageFinderUpdater, cache *match.Cache) error {
t := getStudioTagger(p, aliases, cache) t := getStudioTagger(p, aliases, cache)
for _, tt := range t { for _, tt := range t {
if err := tt.tagImages(paths, rw, func(subjectID, otherID int) (bool, error) { if err := tt.tagImages(ctx, paths, rw, func(subjectID, otherID int) (bool, error) {
return addImageStudio(rw, otherID, subjectID) return addImageStudio(ctx, rw, otherID, subjectID)
}); err != nil { }); err != nil {
return err return err
} }
@@ -128,13 +144,19 @@ func StudioImages(p *models.Studio, paths []string, aliases []string, rw models.
return nil return nil
} }
type GalleryFinderUpdater interface {
gallery.Queryer
Find(ctx context.Context, id int) (*models.Gallery, error)
UpdatePartial(ctx context.Context, updatedGallery models.GalleryPartial) (*models.Gallery, error)
}
// StudioGalleries searches for galleries whose path matches the provided studio name and tags the gallery with the studio, if studio is not already set on the gallery. // StudioGalleries searches for galleries whose path matches the provided studio name and tags the gallery with the studio, if studio is not already set on the gallery.
func StudioGalleries(p *models.Studio, paths []string, aliases []string, rw models.GalleryReaderWriter, cache *match.Cache) error { func StudioGalleries(ctx context.Context, p *models.Studio, paths []string, aliases []string, rw GalleryFinderUpdater, cache *match.Cache) error {
t := getStudioTagger(p, aliases, cache) t := getStudioTagger(p, aliases, cache)
for _, tt := range t { for _, tt := range t {
if err := tt.tagGalleries(paths, rw, func(subjectID, otherID int) (bool, error) { if err := tt.tagGalleries(ctx, paths, rw, func(subjectID, otherID int) (bool, error) {
return addGalleryStudio(rw, otherID, subjectID) return addGalleryStudio(ctx, rw, otherID, subjectID)
}); err != nil { }); err != nil {
return err return err
} }

View File

@@ -113,7 +113,7 @@ func testStudioScenes(t *testing.T, tc testStudioCase) {
} }
// if alias provided, then don't find by name // if alias provided, then don't find by name
onNameQuery := mockSceneReader.On("Query", scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false)) onNameQuery := mockSceneReader.On("Query", testCtx, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false))
if aliasName == "" { if aliasName == "" {
onNameQuery.Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once() onNameQuery.Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once()
@@ -128,21 +128,21 @@ func testStudioScenes(t *testing.T, tc testStudioCase) {
}, },
} }
mockSceneReader.On("Query", scene.QueryOptions(expectedAliasFilter, expectedFindFilter, false)). mockSceneReader.On("Query", testCtx, scene.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once() Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once()
} }
for i := range matchingPaths { for i := range matchingPaths {
sceneID := i + 1 sceneID := i + 1
mockSceneReader.On("Find", sceneID).Return(&models.Scene{}, nil).Once() mockSceneReader.On("Find", testCtx, sceneID).Return(&models.Scene{}, nil).Once()
expectedStudioID := models.NullInt64(studioID) expectedStudioID := models.NullInt64(studioID)
mockSceneReader.On("Update", models.ScenePartial{ mockSceneReader.On("Update", testCtx, models.ScenePartial{
ID: sceneID, ID: sceneID,
StudioID: &expectedStudioID, StudioID: &expectedStudioID,
}).Return(nil, nil).Once() }).Return(nil, nil).Once()
} }
err := StudioScenes(&studio, nil, aliases, mockSceneReader, nil) err := StudioScenes(testCtx, &studio, nil, aliases, mockSceneReader, nil)
assert := assert.New(t) assert := assert.New(t)
@@ -206,7 +206,7 @@ func testStudioImages(t *testing.T, tc testStudioCase) {
} }
// if alias provided, then don't find by name // if alias provided, then don't find by name
onNameQuery := mockImageReader.On("Query", image.QueryOptions(expectedImageFilter, expectedFindFilter, false)) onNameQuery := mockImageReader.On("Query", testCtx, image.QueryOptions(expectedImageFilter, expectedFindFilter, false))
if aliasName == "" { if aliasName == "" {
onNameQuery.Return(mocks.ImageQueryResult(images, len(images)), nil).Once() onNameQuery.Return(mocks.ImageQueryResult(images, len(images)), nil).Once()
} else { } else {
@@ -220,21 +220,21 @@ func testStudioImages(t *testing.T, tc testStudioCase) {
}, },
} }
mockImageReader.On("Query", image.QueryOptions(expectedAliasFilter, expectedFindFilter, false)). mockImageReader.On("Query", testCtx, image.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
Return(mocks.ImageQueryResult(images, len(images)), nil).Once() Return(mocks.ImageQueryResult(images, len(images)), nil).Once()
} }
for i := range matchingPaths { for i := range matchingPaths {
imageID := i + 1 imageID := i + 1
mockImageReader.On("Find", imageID).Return(&models.Image{}, nil).Once() mockImageReader.On("Find", testCtx, imageID).Return(&models.Image{}, nil).Once()
expectedStudioID := models.NullInt64(studioID) expectedStudioID := models.NullInt64(studioID)
mockImageReader.On("Update", models.ImagePartial{ mockImageReader.On("Update", testCtx, models.ImagePartial{
ID: imageID, ID: imageID,
StudioID: &expectedStudioID, StudioID: &expectedStudioID,
}).Return(nil, nil).Once() }).Return(nil, nil).Once()
} }
err := StudioImages(&studio, nil, aliases, mockImageReader, nil) err := StudioImages(testCtx, &studio, nil, aliases, mockImageReader, nil)
assert := assert.New(t) assert := assert.New(t)
@@ -297,7 +297,7 @@ func testStudioGalleries(t *testing.T, tc testStudioCase) {
} }
// if alias provided, then don't find by name // if alias provided, then don't find by name
onNameQuery := mockGalleryReader.On("Query", expectedGalleryFilter, expectedFindFilter) onNameQuery := mockGalleryReader.On("Query", testCtx, expectedGalleryFilter, expectedFindFilter)
if aliasName == "" { if aliasName == "" {
onNameQuery.Return(galleries, len(galleries), nil).Once() onNameQuery.Return(galleries, len(galleries), nil).Once()
} else { } else {
@@ -311,20 +311,20 @@ func testStudioGalleries(t *testing.T, tc testStudioCase) {
}, },
} }
mockGalleryReader.On("Query", expectedAliasFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once() mockGalleryReader.On("Query", testCtx, expectedAliasFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once()
} }
for i := range matchingPaths { for i := range matchingPaths {
galleryID := i + 1 galleryID := i + 1
mockGalleryReader.On("Find", galleryID).Return(&models.Gallery{}, nil).Once() mockGalleryReader.On("Find", testCtx, galleryID).Return(&models.Gallery{}, nil).Once()
expectedStudioID := models.NullInt64(studioID) expectedStudioID := models.NullInt64(studioID)
mockGalleryReader.On("UpdatePartial", models.GalleryPartial{ mockGalleryReader.On("UpdatePartial", testCtx, models.GalleryPartial{
ID: galleryID, ID: galleryID,
StudioID: &expectedStudioID, StudioID: &expectedStudioID,
}).Return(nil, nil).Once() }).Return(nil, nil).Once()
} }
err := StudioGalleries(&studio, nil, aliases, mockGalleryReader, nil) err := StudioGalleries(testCtx, &studio, nil, aliases, mockGalleryReader, nil)
assert := assert.New(t) assert := assert.New(t)

View File

@@ -1,6 +1,8 @@
package autotag package autotag
import ( import (
"context"
"github.com/stashapp/stash/pkg/gallery" "github.com/stashapp/stash/pkg/gallery"
"github.com/stashapp/stash/pkg/image" "github.com/stashapp/stash/pkg/image"
"github.com/stashapp/stash/pkg/match" "github.com/stashapp/stash/pkg/match"
@@ -8,6 +10,21 @@ import (
"github.com/stashapp/stash/pkg/scene" "github.com/stashapp/stash/pkg/scene"
) )
type SceneQueryTagUpdater interface {
scene.Queryer
scene.TagUpdater
}
type ImageQueryTagUpdater interface {
image.Queryer
image.TagUpdater
}
type GalleryQueryTagUpdater interface {
gallery.Queryer
gallery.TagUpdater
}
func getTagTaggers(p *models.Tag, aliases []string, cache *match.Cache) []tagger { func getTagTaggers(p *models.Tag, aliases []string, cache *match.Cache) []tagger {
ret := []tagger{{ ret := []tagger{{
ID: p.ID, ID: p.ID,
@@ -29,12 +46,12 @@ func getTagTaggers(p *models.Tag, aliases []string, cache *match.Cache) []tagger
} }
// TagScenes searches for scenes whose path matches the provided tag name and tags the scene with the tag. // TagScenes searches for scenes whose path matches the provided tag name and tags the scene with the tag.
func TagScenes(p *models.Tag, paths []string, aliases []string, rw models.SceneReaderWriter, cache *match.Cache) error { func TagScenes(ctx context.Context, p *models.Tag, paths []string, aliases []string, rw SceneQueryTagUpdater, cache *match.Cache) error {
t := getTagTaggers(p, aliases, cache) t := getTagTaggers(p, aliases, cache)
for _, tt := range t { for _, tt := range t {
if err := tt.tagScenes(paths, rw, func(subjectID, otherID int) (bool, error) { if err := tt.tagScenes(ctx, paths, rw, func(subjectID, otherID int) (bool, error) {
return scene.AddTag(rw, otherID, subjectID) return scene.AddTag(ctx, rw, otherID, subjectID)
}); err != nil { }); err != nil {
return err return err
} }
@@ -43,12 +60,12 @@ func TagScenes(p *models.Tag, paths []string, aliases []string, rw models.SceneR
} }
// TagImages searches for images whose path matches the provided tag name and tags the image with the tag. // TagImages searches for images whose path matches the provided tag name and tags the image with the tag.
func TagImages(p *models.Tag, paths []string, aliases []string, rw models.ImageReaderWriter, cache *match.Cache) error { func TagImages(ctx context.Context, p *models.Tag, paths []string, aliases []string, rw ImageQueryTagUpdater, cache *match.Cache) error {
t := getTagTaggers(p, aliases, cache) t := getTagTaggers(p, aliases, cache)
for _, tt := range t { for _, tt := range t {
if err := tt.tagImages(paths, rw, func(subjectID, otherID int) (bool, error) { if err := tt.tagImages(ctx, paths, rw, func(subjectID, otherID int) (bool, error) {
return image.AddTag(rw, otherID, subjectID) return image.AddTag(ctx, rw, otherID, subjectID)
}); err != nil { }); err != nil {
return err return err
} }
@@ -57,12 +74,12 @@ func TagImages(p *models.Tag, paths []string, aliases []string, rw models.ImageR
} }
// TagGalleries searches for galleries whose path matches the provided tag name and tags the gallery with the tag. // TagGalleries searches for galleries whose path matches the provided tag name and tags the gallery with the tag.
func TagGalleries(p *models.Tag, paths []string, aliases []string, rw models.GalleryReaderWriter, cache *match.Cache) error { func TagGalleries(ctx context.Context, p *models.Tag, paths []string, aliases []string, rw GalleryQueryTagUpdater, cache *match.Cache) error {
t := getTagTaggers(p, aliases, cache) t := getTagTaggers(p, aliases, cache)
for _, tt := range t { for _, tt := range t {
if err := tt.tagGalleries(paths, rw, func(subjectID, otherID int) (bool, error) { if err := tt.tagGalleries(ctx, paths, rw, func(subjectID, otherID int) (bool, error) {
return gallery.AddTag(rw, otherID, subjectID) return gallery.AddTag(ctx, rw, otherID, subjectID)
}); err != nil { }); err != nil {
return err return err
} }

View File

@@ -113,7 +113,7 @@ func testTagScenes(t *testing.T, tc testTagCase) {
} }
// if alias provided, then don't find by name // if alias provided, then don't find by name
onNameQuery := mockSceneReader.On("Query", scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false)) onNameQuery := mockSceneReader.On("Query", testCtx, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false))
if aliasName == "" { if aliasName == "" {
onNameQuery.Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once() onNameQuery.Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once()
} else { } else {
@@ -127,17 +127,17 @@ func testTagScenes(t *testing.T, tc testTagCase) {
}, },
} }
mockSceneReader.On("Query", scene.QueryOptions(expectedAliasFilter, expectedFindFilter, false)). mockSceneReader.On("Query", testCtx, scene.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once() Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once()
} }
for i := range matchingPaths { for i := range matchingPaths {
sceneID := i + 1 sceneID := i + 1
mockSceneReader.On("GetTagIDs", sceneID).Return(nil, nil).Once() mockSceneReader.On("GetTagIDs", testCtx, sceneID).Return(nil, nil).Once()
mockSceneReader.On("UpdateTags", sceneID, []int{tagID}).Return(nil).Once() mockSceneReader.On("UpdateTags", testCtx, sceneID, []int{tagID}).Return(nil).Once()
} }
err := TagScenes(&tag, nil, aliases, mockSceneReader, nil) err := TagScenes(testCtx, &tag, nil, aliases, mockSceneReader, nil)
assert := assert.New(t) assert := assert.New(t)
@@ -201,7 +201,7 @@ func testTagImages(t *testing.T, tc testTagCase) {
} }
// if alias provided, then don't find by name // if alias provided, then don't find by name
onNameQuery := mockImageReader.On("Query", image.QueryOptions(expectedImageFilter, expectedFindFilter, false)) onNameQuery := mockImageReader.On("Query", testCtx, image.QueryOptions(expectedImageFilter, expectedFindFilter, false))
if aliasName == "" { if aliasName == "" {
onNameQuery.Return(mocks.ImageQueryResult(images, len(images)), nil).Once() onNameQuery.Return(mocks.ImageQueryResult(images, len(images)), nil).Once()
} else { } else {
@@ -215,17 +215,17 @@ func testTagImages(t *testing.T, tc testTagCase) {
}, },
} }
mockImageReader.On("Query", image.QueryOptions(expectedAliasFilter, expectedFindFilter, false)). mockImageReader.On("Query", testCtx, image.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
Return(mocks.ImageQueryResult(images, len(images)), nil).Once() Return(mocks.ImageQueryResult(images, len(images)), nil).Once()
} }
for i := range matchingPaths { for i := range matchingPaths {
imageID := i + 1 imageID := i + 1
mockImageReader.On("GetTagIDs", imageID).Return(nil, nil).Once() mockImageReader.On("GetTagIDs", testCtx, imageID).Return(nil, nil).Once()
mockImageReader.On("UpdateTags", imageID, []int{tagID}).Return(nil).Once() mockImageReader.On("UpdateTags", testCtx, imageID, []int{tagID}).Return(nil).Once()
} }
err := TagImages(&tag, nil, aliases, mockImageReader, nil) err := TagImages(testCtx, &tag, nil, aliases, mockImageReader, nil)
assert := assert.New(t) assert := assert.New(t)
@@ -289,7 +289,7 @@ func testTagGalleries(t *testing.T, tc testTagCase) {
} }
// if alias provided, then don't find by name // if alias provided, then don't find by name
onNameQuery := mockGalleryReader.On("Query", expectedGalleryFilter, expectedFindFilter) onNameQuery := mockGalleryReader.On("Query", testCtx, expectedGalleryFilter, expectedFindFilter)
if aliasName == "" { if aliasName == "" {
onNameQuery.Return(galleries, len(galleries), nil).Once() onNameQuery.Return(galleries, len(galleries), nil).Once()
} else { } else {
@@ -303,16 +303,16 @@ func testTagGalleries(t *testing.T, tc testTagCase) {
}, },
} }
mockGalleryReader.On("Query", expectedAliasFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once() mockGalleryReader.On("Query", testCtx, expectedAliasFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once()
} }
for i := range matchingPaths { for i := range matchingPaths {
galleryID := i + 1 galleryID := i + 1
mockGalleryReader.On("GetTagIDs", galleryID).Return(nil, nil).Once() mockGalleryReader.On("GetTagIDs", testCtx, galleryID).Return(nil, nil).Once()
mockGalleryReader.On("UpdateTags", galleryID, []int{tagID}).Return(nil).Once() mockGalleryReader.On("UpdateTags", testCtx, galleryID, []int{tagID}).Return(nil).Once()
} }
err := TagGalleries(&tag, nil, aliases, mockGalleryReader, nil) err := TagGalleries(testCtx, &tag, nil, aliases, mockGalleryReader, nil)
assert := assert.New(t) assert := assert.New(t)

View File

@@ -14,11 +14,14 @@
package autotag package autotag
import ( import (
"context"
"fmt" "fmt"
"github.com/stashapp/stash/pkg/gallery"
"github.com/stashapp/stash/pkg/image"
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/match" "github.com/stashapp/stash/pkg/match"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/scene"
) )
type tagger struct { type tagger struct {
@@ -41,8 +44,8 @@ func (t *tagger) addLog(otherType, otherName string) {
logger.Infof("Added %s '%s' to %s '%s'", otherType, otherName, t.Type, t.Name) logger.Infof("Added %s '%s' to %s '%s'", otherType, otherName, t.Type, t.Name)
} }
func (t *tagger) tagPerformers(performerReader models.PerformerReader, addFunc addLinkFunc) error { func (t *tagger) tagPerformers(ctx context.Context, performerReader match.PerformerAutoTagQueryer, addFunc addLinkFunc) error {
others, err := match.PathToPerformers(t.Path, performerReader, t.cache, t.trimExt) others, err := match.PathToPerformers(ctx, t.Path, performerReader, t.cache, t.trimExt)
if err != nil { if err != nil {
return err return err
} }
@@ -62,8 +65,8 @@ func (t *tagger) tagPerformers(performerReader models.PerformerReader, addFunc a
return nil return nil
} }
func (t *tagger) tagStudios(studioReader models.StudioReader, addFunc addLinkFunc) error { func (t *tagger) tagStudios(ctx context.Context, studioReader match.StudioAutoTagQueryer, addFunc addLinkFunc) error {
studio, err := match.PathToStudio(t.Path, studioReader, t.cache, t.trimExt) studio, err := match.PathToStudio(ctx, t.Path, studioReader, t.cache, t.trimExt)
if err != nil { if err != nil {
return err return err
} }
@@ -83,8 +86,8 @@ func (t *tagger) tagStudios(studioReader models.StudioReader, addFunc addLinkFun
return nil return nil
} }
func (t *tagger) tagTags(tagReader models.TagReader, addFunc addLinkFunc) error { func (t *tagger) tagTags(ctx context.Context, tagReader match.TagAutoTagQueryer, addFunc addLinkFunc) error {
others, err := match.PathToTags(t.Path, tagReader, t.cache, t.trimExt) others, err := match.PathToTags(ctx, t.Path, tagReader, t.cache, t.trimExt)
if err != nil { if err != nil {
return err return err
} }
@@ -104,8 +107,8 @@ func (t *tagger) tagTags(tagReader models.TagReader, addFunc addLinkFunc) error
return nil return nil
} }
func (t *tagger) tagScenes(paths []string, sceneReader models.SceneReader, addFunc addLinkFunc) error { func (t *tagger) tagScenes(ctx context.Context, paths []string, sceneReader scene.Queryer, addFunc addLinkFunc) error {
others, err := match.PathToScenes(t.Name, paths, sceneReader) others, err := match.PathToScenes(ctx, t.Name, paths, sceneReader)
if err != nil { if err != nil {
return err return err
} }
@@ -125,8 +128,8 @@ func (t *tagger) tagScenes(paths []string, sceneReader models.SceneReader, addFu
return nil return nil
} }
func (t *tagger) tagImages(paths []string, imageReader models.ImageReader, addFunc addLinkFunc) error { func (t *tagger) tagImages(ctx context.Context, paths []string, imageReader image.Queryer, addFunc addLinkFunc) error {
others, err := match.PathToImages(t.Name, paths, imageReader) others, err := match.PathToImages(ctx, t.Name, paths, imageReader)
if err != nil { if err != nil {
return err return err
} }
@@ -146,8 +149,8 @@ func (t *tagger) tagImages(paths []string, imageReader models.ImageReader, addFu
return nil return nil
} }
func (t *tagger) tagGalleries(paths []string, galleryReader models.GalleryReader, addFunc addLinkFunc) error { func (t *tagger) tagGalleries(ctx context.Context, paths []string, galleryReader gallery.Queryer, addFunc addLinkFunc) error {
others, err := match.PathToGalleries(t.Name, paths, galleryReader) others, err := match.PathToGalleries(ctx, t.Name, paths, galleryReader)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -41,6 +41,7 @@ import (
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene" "github.com/stashapp/stash/pkg/scene"
"github.com/stashapp/stash/pkg/sliceutil/stringslice" "github.com/stashapp/stash/pkg/sliceutil/stringslice"
"github.com/stashapp/stash/pkg/txn"
) )
var pageSize = 100 var pageSize = 100
@@ -56,7 +57,6 @@ type browse struct {
type contentDirectoryService struct { type contentDirectoryService struct {
*Server *Server
upnp.Eventing upnp.Eventing
txnManager models.TransactionManager
} }
func formatDurationSexagesimal(d time.Duration) string { func formatDurationSexagesimal(d time.Duration) string {
@@ -352,8 +352,8 @@ func (me *contentDirectoryService) handleBrowseMetadata(obj object, host string)
} else { } else {
var scene *models.Scene var scene *models.Scene
if err := me.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { if err := txn.WithTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
scene, err = r.Scene().Find(sceneID) scene, err = me.repository.SceneFinder.Find(ctx, sceneID)
if err != nil { if err != nil {
return err return err
} }
@@ -431,14 +431,14 @@ func getRootObjects() []interface{} {
func (me *contentDirectoryService) getVideos(sceneFilter *models.SceneFilterType, parentID string, host string) []interface{} { func (me *contentDirectoryService) getVideos(sceneFilter *models.SceneFilterType, parentID string, host string) []interface{} {
var objs []interface{} var objs []interface{}
if err := me.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { if err := txn.WithTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
sort := "title" sort := "title"
findFilter := &models.FindFilterType{ findFilter := &models.FindFilterType{
PerPage: &pageSize, PerPage: &pageSize,
Sort: &sort, Sort: &sort,
} }
scenes, total, err := scene.QueryWithCount(r.Scene(), sceneFilter, findFilter) scenes, total, err := scene.QueryWithCount(ctx, me.repository.SceneFinder, sceneFilter, findFilter)
if err != nil { if err != nil {
return err return err
} }
@@ -449,7 +449,7 @@ func (me *contentDirectoryService) getVideos(sceneFilter *models.SceneFilterType
parentID: parentID, parentID: parentID,
} }
objs, err = pager.getPages(r, total) objs, err = pager.getPages(ctx, me.repository.SceneFinder, total)
if err != nil { if err != nil {
return err return err
} }
@@ -470,14 +470,14 @@ func (me *contentDirectoryService) getVideos(sceneFilter *models.SceneFilterType
func (me *contentDirectoryService) getPageVideos(sceneFilter *models.SceneFilterType, parentID string, page int, host string) []interface{} { func (me *contentDirectoryService) getPageVideos(sceneFilter *models.SceneFilterType, parentID string, page int, host string) []interface{} {
var objs []interface{} var objs []interface{}
if err := me.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { if err := txn.WithTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
pager := scenePager{ pager := scenePager{
sceneFilter: sceneFilter, sceneFilter: sceneFilter,
parentID: parentID, parentID: parentID,
} }
var err error var err error
objs, err = pager.getPageVideos(r, page, host) objs, err = pager.getPageVideos(ctx, me.repository.SceneFinder, page, host)
if err != nil { if err != nil {
return err return err
} }
@@ -511,8 +511,8 @@ func (me *contentDirectoryService) getAllScenes(host string) []interface{} {
func (me *contentDirectoryService) getStudios() []interface{} { func (me *contentDirectoryService) getStudios() []interface{} {
var objs []interface{} var objs []interface{}
if err := me.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { if err := txn.WithTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
studios, err := r.Studio().All() studios, err := me.repository.StudioFinder.All(ctx)
if err != nil { if err != nil {
return err return err
} }
@@ -550,8 +550,8 @@ func (me *contentDirectoryService) getStudioScenes(paths []string, host string)
func (me *contentDirectoryService) getTags() []interface{} { func (me *contentDirectoryService) getTags() []interface{} {
var objs []interface{} var objs []interface{}
if err := me.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { if err := txn.WithTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
tags, err := r.Tag().All() tags, err := me.repository.TagFinder.All(ctx)
if err != nil { if err != nil {
return err return err
} }
@@ -589,8 +589,8 @@ func (me *contentDirectoryService) getTagScenes(paths []string, host string) []i
func (me *contentDirectoryService) getPerformers() []interface{} { func (me *contentDirectoryService) getPerformers() []interface{} {
var objs []interface{} var objs []interface{}
if err := me.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { if err := txn.WithTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
performers, err := r.Performer().All() performers, err := me.repository.PerformerFinder.All(ctx)
if err != nil { if err != nil {
return err return err
} }
@@ -628,8 +628,8 @@ func (me *contentDirectoryService) getPerformerScenes(paths []string, host strin
func (me *contentDirectoryService) getMovies() []interface{} { func (me *contentDirectoryService) getMovies() []interface{} {
var objs []interface{} var objs []interface{}
if err := me.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { if err := txn.WithTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
movies, err := r.Movie().All() movies, err := me.repository.MovieFinder.All(ctx)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -31,7 +31,6 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/stashapp/stash/pkg/models/mocks"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@@ -60,7 +59,6 @@ func TestRootParentObjectID(t *testing.T) {
func testHandleBrowse(argsXML string) (map[string]string, error) { func testHandleBrowse(argsXML string) (map[string]string, error) {
cds := contentDirectoryService{ cds := contentDirectoryService{
Server: &Server{}, Server: &Server{},
txnManager: mocks.NewTransactionManager(),
} }
r := &http.Request{} r := &http.Request{}

View File

@@ -48,8 +48,31 @@ import (
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene"
"github.com/stashapp/stash/pkg/txn"
) )
type SceneFinder interface {
scene.Queryer
scene.IDFinder
}
type StudioFinder interface {
All(ctx context.Context) ([]*models.Studio, error)
}
type TagFinder interface {
All(ctx context.Context) ([]*models.Tag, error)
}
type PerformerFinder interface {
All(ctx context.Context) ([]*models.Performer, error)
}
type MovieFinder interface {
All(ctx context.Context) ([]*models.Movie, error)
}
const ( const (
serverField = "Linux/3.4 DLNADOC/1.50 UPnP/1.0 DMS/1.0" serverField = "Linux/3.4 DLNADOC/1.50 UPnP/1.0 DMS/1.0"
rootDeviceType = "urn:schemas-upnp-org:device:MediaServer:1" rootDeviceType = "urn:schemas-upnp-org:device:MediaServer:1"
@@ -249,7 +272,8 @@ type Server struct {
// Time interval between SSPD announces // Time interval between SSPD announces
NotifyInterval time.Duration NotifyInterval time.Duration
txnManager models.TransactionManager txnManager txn.Manager
repository Repository
sceneServer sceneServer sceneServer sceneServer
ipWhitelistManager *ipWhitelistManager ipWhitelistManager *ipWhitelistManager
} }
@@ -415,12 +439,12 @@ func (me *Server) serveIcon(w http.ResponseWriter, r *http.Request) {
} }
var scene *models.Scene var scene *models.Scene
err := me.txnManager.WithReadTxn(r.Context(), func(r models.ReaderRepository) error { err := txn.WithTxn(r.Context(), me.txnManager, func(ctx context.Context) error {
idInt, err := strconv.Atoi(sceneId) idInt, err := strconv.Atoi(sceneId)
if err != nil { if err != nil {
return nil return nil
} }
scene, _ = r.Scene().Find(idInt) scene, _ = me.repository.SceneFinder.Find(ctx, idInt)
return nil return nil
}) })
if err != nil { if err != nil {
@@ -555,12 +579,12 @@ func (me *Server) initMux(mux *http.ServeMux) {
mux.HandleFunc(resPath, func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc(resPath, func(w http.ResponseWriter, r *http.Request) {
sceneId := r.URL.Query().Get("scene") sceneId := r.URL.Query().Get("scene")
var scene *models.Scene var scene *models.Scene
err := me.txnManager.WithReadTxn(r.Context(), func(r models.ReaderRepository) error { err := txn.WithTxn(r.Context(), me.txnManager, func(ctx context.Context) error {
sceneIdInt, err := strconv.Atoi(sceneId) sceneIdInt, err := strconv.Atoi(sceneId)
if err != nil { if err != nil {
return nil return nil
} }
scene, _ = r.Scene().Find(sceneIdInt) scene, _ = me.repository.SceneFinder.Find(ctx, sceneIdInt)
return nil return nil
}) })
if err != nil { if err != nil {
@@ -596,7 +620,6 @@ func (me *Server) initServices() {
me.services = map[string]UPnPService{ me.services = map[string]UPnPService{
"ContentDirectory": &contentDirectoryService{ "ContentDirectory": &contentDirectoryService{
Server: me, Server: me,
txnManager: me.txnManager,
}, },
"ConnectionManager": &connectionManagerService{ "ConnectionManager": &connectionManagerService{
Server: me, Server: me,

View File

@@ -1,6 +1,7 @@
package dlna package dlna
import ( import (
"context"
"fmt" "fmt"
"math" "math"
"strconv" "strconv"
@@ -18,7 +19,7 @@ func (p *scenePager) getPageID(page int) string {
return p.parentID + "/page/" + strconv.Itoa(page) return p.parentID + "/page/" + strconv.Itoa(page)
} }
func (p *scenePager) getPages(r models.ReaderRepository, total int) ([]interface{}, error) { func (p *scenePager) getPages(ctx context.Context, r scene.Queryer, total int) ([]interface{}, error) {
var objs []interface{} var objs []interface{}
// get the first scene of each page to set an appropriate title // get the first scene of each page to set an appropriate title
@@ -37,7 +38,7 @@ func (p *scenePager) getPages(r models.ReaderRepository, total int) ([]interface
if pages <= 10 || (page-1)%(pages/10) == 0 { if pages <= 10 || (page-1)%(pages/10) == 0 {
thisPage := ((page - 1) * pageSize) + 1 thisPage := ((page - 1) * pageSize) + 1
findFilter.Page = &thisPage findFilter.Page = &thisPage
scenes, err := scene.Query(r.Scene(), p.sceneFilter, findFilter) scenes, err := scene.Query(ctx, r, p.sceneFilter, findFilter)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -58,7 +59,7 @@ func (p *scenePager) getPages(r models.ReaderRepository, total int) ([]interface
return objs, nil return objs, nil
} }
func (p *scenePager) getPageVideos(r models.ReaderRepository, page int, host string) ([]interface{}, error) { func (p *scenePager) getPageVideos(ctx context.Context, r SceneFinder, page int, host string) ([]interface{}, error) {
var objs []interface{} var objs []interface{}
sort := "title" sort := "title"
@@ -68,7 +69,7 @@ func (p *scenePager) getPageVideos(r models.ReaderRepository, page int, host str
Sort: &sort, Sort: &sort,
} }
scenes, err := scene.Query(r.Scene(), p.sceneFilter, findFilter) scenes, err := scene.Query(ctx, r, p.sceneFilter, findFilter)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -10,8 +10,17 @@ import (
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
) )
type Repository struct {
SceneFinder SceneFinder
StudioFinder StudioFinder
TagFinder TagFinder
PerformerFinder PerformerFinder
MovieFinder MovieFinder
}
type Status struct { type Status struct {
Running bool `json:"running"` Running bool `json:"running"`
// If not currently running, time until it will be started. If running, time until it will be stopped // If not currently running, time until it will be started. If running, time until it will be stopped
@@ -48,7 +57,8 @@ type Config interface {
} }
type Service struct { type Service struct {
txnManager models.TransactionManager txnManager txn.Manager
repository Repository
config Config config Config
sceneServer sceneServer sceneServer sceneServer
ipWhitelistMgr *ipWhitelistManager ipWhitelistMgr *ipWhitelistManager
@@ -121,6 +131,7 @@ func (s *Service) init() error {
s.server = &Server{ s.server = &Server{
txnManager: s.txnManager, txnManager: s.txnManager,
sceneServer: s.sceneServer, sceneServer: s.sceneServer,
repository: s.repository,
ipWhitelistManager: s.ipWhitelistMgr, ipWhitelistManager: s.ipWhitelistMgr,
Interfaces: interfaces, Interfaces: interfaces,
HTTPConn: func() net.Listener { HTTPConn: func() net.Listener {
@@ -181,9 +192,10 @@ func (s *Service) init() error {
// } // }
// NewService initialises and returns a new DLNA service. // NewService initialises and returns a new DLNA service.
func NewService(txnManager models.TransactionManager, cfg Config, sceneServer sceneServer) *Service { func NewService(txnManager txn.Manager, repo Repository, cfg Config, sceneServer sceneServer) *Service {
ret := &Service{ ret := &Service{
txnManager: txnManager, txnManager: txnManager,
repository: repo,
sceneServer: sceneServer, sceneServer: sceneServer,
config: cfg, config: cfg,
ipWhitelistMgr: &ipWhitelistManager{ ipWhitelistMgr: &ipWhitelistManager{

View File

@@ -9,6 +9,7 @@ import (
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene" "github.com/stashapp/stash/pkg/scene"
"github.com/stashapp/stash/pkg/scraper" "github.com/stashapp/stash/pkg/scraper"
"github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
@@ -28,13 +29,18 @@ type ScraperSource struct {
} }
type SceneIdentifier struct { type SceneIdentifier struct {
SceneReaderUpdater SceneReaderUpdater
StudioCreator StudioCreator
PerformerCreator PerformerCreator
TagCreator TagCreator
DefaultOptions *MetadataOptions DefaultOptions *MetadataOptions
Sources []ScraperSource Sources []ScraperSource
ScreenshotSetter scene.ScreenshotSetter ScreenshotSetter scene.ScreenshotSetter
SceneUpdatePostHookExecutor SceneUpdatePostHookExecutor SceneUpdatePostHookExecutor SceneUpdatePostHookExecutor
} }
func (t *SceneIdentifier) Identify(ctx context.Context, txnManager models.TransactionManager, scene *models.Scene) error { func (t *SceneIdentifier) Identify(ctx context.Context, txnManager txn.Manager, scene *models.Scene) error {
result, err := t.scrapeScene(ctx, scene) result, err := t.scrapeScene(ctx, scene)
if err != nil { if err != nil {
return err return err
@@ -80,7 +86,7 @@ func (t *SceneIdentifier) scrapeScene(ctx context.Context, scene *models.Scene)
return nil, nil return nil, nil
} }
func (t *SceneIdentifier) getSceneUpdater(ctx context.Context, s *models.Scene, result *scrapeResult, repo models.Repository) (*scene.UpdateSet, error) { func (t *SceneIdentifier) getSceneUpdater(ctx context.Context, s *models.Scene, result *scrapeResult) (*scene.UpdateSet, error) {
ret := &scene.UpdateSet{ ret := &scene.UpdateSet{
ID: s.ID, ID: s.ID,
} }
@@ -106,7 +112,10 @@ func (t *SceneIdentifier) getSceneUpdater(ctx context.Context, s *models.Scene,
scraped := result.result scraped := result.result
rel := sceneRelationships{ rel := sceneRelationships{
repo: repo, sceneReader: t.SceneReaderUpdater,
studioCreator: t.StudioCreator,
performerCreator: t.PerformerCreator,
tagCreator: t.TagCreator,
scene: s, scene: s,
result: result, result: result,
fieldOptions: fieldOptions, fieldOptions: fieldOptions,
@@ -114,7 +123,7 @@ func (t *SceneIdentifier) getSceneUpdater(ctx context.Context, s *models.Scene,
ret.Partial = getScenePartial(s, scraped, fieldOptions, setOrganized) ret.Partial = getScenePartial(s, scraped, fieldOptions, setOrganized)
studioID, err := rel.studio() studioID, err := rel.studio(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting studio: %w", err) return nil, fmt.Errorf("error getting studio: %w", err)
} }
@@ -134,17 +143,17 @@ func (t *SceneIdentifier) getSceneUpdater(ctx context.Context, s *models.Scene,
} }
} }
ret.PerformerIDs, err = rel.performers(ignoreMale) ret.PerformerIDs, err = rel.performers(ctx, ignoreMale)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ret.TagIDs, err = rel.tags() ret.TagIDs, err = rel.tags(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ret.StashIDs, err = rel.stashIDs() ret.StashIDs, err = rel.stashIDs(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -167,11 +176,11 @@ func (t *SceneIdentifier) getSceneUpdater(ctx context.Context, s *models.Scene,
return ret, nil return ret, nil
} }
func (t *SceneIdentifier) modifyScene(ctx context.Context, txnManager models.TransactionManager, s *models.Scene, result *scrapeResult) error { func (t *SceneIdentifier) modifyScene(ctx context.Context, txnManager txn.Manager, s *models.Scene, result *scrapeResult) error {
var updater *scene.UpdateSet var updater *scene.UpdateSet
if err := txnManager.WithTxn(ctx, func(repo models.Repository) error { if err := txn.WithTxn(ctx, txnManager, func(ctx context.Context) error {
var err error var err error
updater, err = t.getSceneUpdater(ctx, s, result, repo) updater, err = t.getSceneUpdater(ctx, s, result)
if err != nil { if err != nil {
return err return err
} }
@@ -182,7 +191,7 @@ func (t *SceneIdentifier) modifyScene(ctx context.Context, txnManager models.Tra
return nil return nil
} }
_, err = updater.Update(repo.Scene(), t.ScreenshotSetter) _, err = updater.Update(ctx, t.SceneReaderUpdater, t.ScreenshotSetter)
if err != nil { if err != nil {
return fmt.Errorf("error updating scene: %w", err) return fmt.Errorf("error updating scene: %w", err)
} }

View File

@@ -13,6 +13,8 @@ import (
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
) )
var testCtx = context.Background()
type mockSceneScraper struct { type mockSceneScraper struct {
errIDs []int errIDs []int
results map[int]*scraper.ScrapedScene results map[int]*scraper.ScrapedScene
@@ -70,11 +72,12 @@ func TestSceneIdentifier_Identify(t *testing.T) {
}, },
} }
repo := mocks.NewTransactionManager() mockSceneReaderWriter := &mocks.SceneReaderWriter{}
repo.Scene().(*mocks.SceneReaderWriter).On("Update", mock.MatchedBy(func(partial models.ScenePartial) bool {
mockSceneReaderWriter.On("Update", testCtx, mock.MatchedBy(func(partial models.ScenePartial) bool {
return partial.ID != errUpdateID return partial.ID != errUpdateID
})).Return(nil, nil) })).Return(nil, nil)
repo.Scene().(*mocks.SceneReaderWriter).On("Update", mock.MatchedBy(func(partial models.ScenePartial) bool { mockSceneReaderWriter.On("Update", testCtx, mock.MatchedBy(func(partial models.ScenePartial) bool {
return partial.ID == errUpdateID return partial.ID == errUpdateID
})).Return(nil, errors.New("update error")) })).Return(nil, errors.New("update error"))
@@ -116,6 +119,7 @@ func TestSceneIdentifier_Identify(t *testing.T) {
} }
identifier := SceneIdentifier{ identifier := SceneIdentifier{
SceneReaderUpdater: mockSceneReaderWriter,
DefaultOptions: defaultOptions, DefaultOptions: defaultOptions,
Sources: sources, Sources: sources,
SceneUpdatePostHookExecutor: mockHookExecutor{}, SceneUpdatePostHookExecutor: mockHookExecutor{},
@@ -126,7 +130,7 @@ func TestSceneIdentifier_Identify(t *testing.T) {
scene := &models.Scene{ scene := &models.Scene{
ID: tt.sceneID, ID: tt.sceneID,
} }
if err := identifier.Identify(context.TODO(), repo, scene); (err != nil) != tt.wantErr { if err := identifier.Identify(testCtx, &mocks.TxnManager{}, scene); (err != nil) != tt.wantErr {
t.Errorf("SceneIdentifier.Identify() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("SceneIdentifier.Identify() error = %v, wantErr %v", err, tt.wantErr)
} }
}) })
@@ -134,7 +138,9 @@ func TestSceneIdentifier_Identify(t *testing.T) {
} }
func TestSceneIdentifier_modifyScene(t *testing.T) { func TestSceneIdentifier_modifyScene(t *testing.T) {
repo := mocks.NewTransactionManager() repo := models.Repository{
TxnManager: &mocks.TxnManager{},
}
tr := &SceneIdentifier{} tr := &SceneIdentifier{}
type args struct { type args struct {
@@ -159,7 +165,7 @@ func TestSceneIdentifier_modifyScene(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := tr.modifyScene(context.TODO(), repo, tt.args.scene, tt.args.result); (err != nil) != tt.wantErr { if err := tr.modifyScene(testCtx, repo, tt.args.scene, tt.args.result); (err != nil) != tt.wantErr {
t.Errorf("SceneIdentifier.modifyScene() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("SceneIdentifier.modifyScene() error = %v, wantErr %v", err, tt.wantErr)
} }
}) })

View File

@@ -1,6 +1,7 @@
package identify package identify
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"strconv" "strconv"
@@ -10,7 +11,12 @@ import (
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
) )
func getPerformerID(endpoint string, r models.Repository, p *models.ScrapedPerformer, createMissing bool) (*int, error) { type PerformerCreator interface {
Create(ctx context.Context, newPerformer models.Performer) (*models.Performer, error)
UpdateStashIDs(ctx context.Context, performerID int, stashIDs []models.StashID) error
}
func getPerformerID(ctx context.Context, endpoint string, w PerformerCreator, p *models.ScrapedPerformer, createMissing bool) (*int, error) {
if p.StoredID != nil { if p.StoredID != nil {
// existing performer, just add it // existing performer, just add it
performerID, err := strconv.Atoi(*p.StoredID) performerID, err := strconv.Atoi(*p.StoredID)
@@ -20,20 +26,20 @@ func getPerformerID(endpoint string, r models.Repository, p *models.ScrapedPerfo
return &performerID, nil return &performerID, nil
} else if createMissing && p.Name != nil { // name is mandatory } else if createMissing && p.Name != nil { // name is mandatory
return createMissingPerformer(endpoint, r, p) return createMissingPerformer(ctx, endpoint, w, p)
} }
return nil, nil return nil, nil
} }
func createMissingPerformer(endpoint string, r models.Repository, p *models.ScrapedPerformer) (*int, error) { func createMissingPerformer(ctx context.Context, endpoint string, w PerformerCreator, p *models.ScrapedPerformer) (*int, error) {
created, err := r.Performer().Create(scrapedToPerformerInput(p)) created, err := w.Create(ctx, scrapedToPerformerInput(p))
if err != nil { if err != nil {
return nil, fmt.Errorf("error creating performer: %w", err) return nil, fmt.Errorf("error creating performer: %w", err)
} }
if endpoint != "" && p.RemoteSiteID != nil { if endpoint != "" && p.RemoteSiteID != nil {
if err := r.Performer().UpdateStashIDs(created.ID, []models.StashID{ if err := w.UpdateStashIDs(ctx, created.ID, []models.StashID{
{ {
Endpoint: endpoint, Endpoint: endpoint,
StashID: *p.RemoteSiteID, StashID: *p.RemoteSiteID,

View File

@@ -23,8 +23,8 @@ func Test_getPerformerID(t *testing.T) {
validStoredID := 1 validStoredID := 1
name := "name" name := "name"
repo := mocks.NewTransactionManager() mockPerformerReaderWriter := mocks.PerformerReaderWriter{}
repo.PerformerMock().On("Create", mock.Anything).Return(&models.Performer{ mockPerformerReaderWriter.On("Create", testCtx, mock.Anything).Return(&models.Performer{
ID: validStoredID, ID: validStoredID,
}, nil) }, nil)
@@ -110,7 +110,7 @@ func Test_getPerformerID(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := getPerformerID(tt.args.endpoint, repo, tt.args.p, tt.args.createMissing) got, err := getPerformerID(testCtx, tt.args.endpoint, &mockPerformerReaderWriter, tt.args.p, tt.args.createMissing)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("getPerformerID() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("getPerformerID() error = %v, wantErr %v", err, tt.wantErr)
return return
@@ -131,23 +131,23 @@ func Test_createMissingPerformer(t *testing.T) {
invalidName := "invalidName" invalidName := "invalidName"
performerID := 1 performerID := 1
repo := mocks.NewTransactionManager() mockPerformerReaderWriter := mocks.PerformerReaderWriter{}
repo.PerformerMock().On("Create", mock.MatchedBy(func(p models.Performer) bool { mockPerformerReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p models.Performer) bool {
return p.Name.String == validName return p.Name.String == validName
})).Return(&models.Performer{ })).Return(&models.Performer{
ID: performerID, ID: performerID,
}, nil) }, nil)
repo.PerformerMock().On("Create", mock.MatchedBy(func(p models.Performer) bool { mockPerformerReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p models.Performer) bool {
return p.Name.String == invalidName return p.Name.String == invalidName
})).Return(nil, errors.New("error creating performer")) })).Return(nil, errors.New("error creating performer"))
repo.PerformerMock().On("UpdateStashIDs", performerID, []models.StashID{ mockPerformerReaderWriter.On("UpdateStashIDs", testCtx, performerID, []models.StashID{
{ {
Endpoint: invalidEndpoint, Endpoint: invalidEndpoint,
StashID: remoteSiteID, StashID: remoteSiteID,
}, },
}).Return(errors.New("error updating stash ids")) }).Return(errors.New("error updating stash ids"))
repo.PerformerMock().On("UpdateStashIDs", performerID, []models.StashID{ mockPerformerReaderWriter.On("UpdateStashIDs", testCtx, performerID, []models.StashID{
{ {
Endpoint: validEndpoint, Endpoint: validEndpoint,
StashID: remoteSiteID, StashID: remoteSiteID,
@@ -213,7 +213,7 @@ func Test_createMissingPerformer(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := createMissingPerformer(tt.args.endpoint, repo, tt.args.p) got, err := createMissingPerformer(testCtx, tt.args.endpoint, &mockPerformerReaderWriter, tt.args.p)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("createMissingPerformer() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("createMissingPerformer() error = %v, wantErr %v", err, tt.wantErr)
return return

View File

@@ -9,19 +9,35 @@ import (
"time" "time"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene"
"github.com/stashapp/stash/pkg/sliceutil" "github.com/stashapp/stash/pkg/sliceutil"
"github.com/stashapp/stash/pkg/sliceutil/intslice" "github.com/stashapp/stash/pkg/sliceutil/intslice"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
type SceneReaderUpdater interface {
GetPerformerIDs(ctx context.Context, sceneID int) ([]int, error)
GetTagIDs(ctx context.Context, sceneID int) ([]int, error)
GetStashIDs(ctx context.Context, sceneID int) ([]*models.StashID, error)
GetCover(ctx context.Context, sceneID int) ([]byte, error)
scene.Updater
}
type TagCreator interface {
Create(ctx context.Context, newTag models.Tag) (*models.Tag, error)
}
type sceneRelationships struct { type sceneRelationships struct {
repo models.Repository sceneReader SceneReaderUpdater
studioCreator StudioCreator
performerCreator PerformerCreator
tagCreator TagCreator
scene *models.Scene scene *models.Scene
result *scrapeResult result *scrapeResult
fieldOptions map[string]*FieldOptions fieldOptions map[string]*FieldOptions
} }
func (g sceneRelationships) studio() (*int64, error) { func (g sceneRelationships) studio(ctx context.Context) (*int64, error) {
existingID := g.scene.StudioID existingID := g.scene.StudioID
fieldStrategy := g.fieldOptions["studio"] fieldStrategy := g.fieldOptions["studio"]
createMissing := fieldStrategy != nil && utils.IsTrue(fieldStrategy.CreateMissing) createMissing := fieldStrategy != nil && utils.IsTrue(fieldStrategy.CreateMissing)
@@ -45,13 +61,13 @@ func (g sceneRelationships) studio() (*int64, error) {
return &studioID, nil return &studioID, nil
} }
} else if createMissing { } else if createMissing {
return createMissingStudio(endpoint, g.repo, scraped) return createMissingStudio(ctx, endpoint, g.studioCreator, scraped)
} }
return nil, nil return nil, nil
} }
func (g sceneRelationships) performers(ignoreMale bool) ([]int, error) { func (g sceneRelationships) performers(ctx context.Context, ignoreMale bool) ([]int, error) {
fieldStrategy := g.fieldOptions["performers"] fieldStrategy := g.fieldOptions["performers"]
scraped := g.result.result.Performers scraped := g.result.result.Performers
@@ -66,11 +82,10 @@ func (g sceneRelationships) performers(ignoreMale bool) ([]int, error) {
strategy = fieldStrategy.Strategy strategy = fieldStrategy.Strategy
} }
repo := g.repo
endpoint := g.result.source.RemoteSite endpoint := g.result.source.RemoteSite
var performerIDs []int var performerIDs []int
originalPerformerIDs, err := repo.Scene().GetPerformerIDs(g.scene.ID) originalPerformerIDs, err := g.sceneReader.GetPerformerIDs(ctx, g.scene.ID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting scene performers: %w", err) return nil, fmt.Errorf("error getting scene performers: %w", err)
} }
@@ -85,7 +100,7 @@ func (g sceneRelationships) performers(ignoreMale bool) ([]int, error) {
continue continue
} }
performerID, err := getPerformerID(endpoint, repo, p, createMissing) performerID, err := getPerformerID(ctx, endpoint, g.performerCreator, p, createMissing)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -103,11 +118,10 @@ func (g sceneRelationships) performers(ignoreMale bool) ([]int, error) {
return performerIDs, nil return performerIDs, nil
} }
func (g sceneRelationships) tags() ([]int, error) { func (g sceneRelationships) tags(ctx context.Context) ([]int, error) {
fieldStrategy := g.fieldOptions["tags"] fieldStrategy := g.fieldOptions["tags"]
scraped := g.result.result.Tags scraped := g.result.result.Tags
target := g.scene target := g.scene
r := g.repo
// just check if ignored // just check if ignored
if len(scraped) == 0 || !shouldSetSingleValueField(fieldStrategy, false) { if len(scraped) == 0 || !shouldSetSingleValueField(fieldStrategy, false) {
@@ -121,7 +135,7 @@ func (g sceneRelationships) tags() ([]int, error) {
} }
var tagIDs []int var tagIDs []int
originalTagIDs, err := r.Scene().GetTagIDs(target.ID) originalTagIDs, err := g.sceneReader.GetTagIDs(ctx, target.ID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting scene tags: %w", err) return nil, fmt.Errorf("error getting scene tags: %w", err)
} }
@@ -142,7 +156,7 @@ func (g sceneRelationships) tags() ([]int, error) {
tagIDs = intslice.IntAppendUnique(tagIDs, int(tagID)) tagIDs = intslice.IntAppendUnique(tagIDs, int(tagID))
} else if createMissing { } else if createMissing {
now := time.Now() now := time.Now()
created, err := r.Tag().Create(models.Tag{ created, err := g.tagCreator.Create(ctx, models.Tag{
Name: t.Name, Name: t.Name,
CreatedAt: models.SQLiteTimestamp{Timestamp: now}, CreatedAt: models.SQLiteTimestamp{Timestamp: now},
UpdatedAt: models.SQLiteTimestamp{Timestamp: now}, UpdatedAt: models.SQLiteTimestamp{Timestamp: now},
@@ -163,11 +177,10 @@ func (g sceneRelationships) tags() ([]int, error) {
return tagIDs, nil return tagIDs, nil
} }
func (g sceneRelationships) stashIDs() ([]models.StashID, error) { func (g sceneRelationships) stashIDs(ctx context.Context) ([]models.StashID, error) {
remoteSiteID := g.result.result.RemoteSiteID remoteSiteID := g.result.result.RemoteSiteID
fieldStrategy := g.fieldOptions["stash_ids"] fieldStrategy := g.fieldOptions["stash_ids"]
target := g.scene target := g.scene
r := g.repo
endpoint := g.result.source.RemoteSite endpoint := g.result.source.RemoteSite
@@ -183,7 +196,7 @@ func (g sceneRelationships) stashIDs() ([]models.StashID, error) {
var originalStashIDs []models.StashID var originalStashIDs []models.StashID
var stashIDs []models.StashID var stashIDs []models.StashID
stashIDPtrs, err := r.Scene().GetStashIDs(target.ID) stashIDPtrs, err := g.sceneReader.GetStashIDs(ctx, target.ID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting scene tag: %w", err) return nil, fmt.Errorf("error getting scene tag: %w", err)
} }
@@ -227,14 +240,13 @@ func (g sceneRelationships) stashIDs() ([]models.StashID, error) {
func (g sceneRelationships) cover(ctx context.Context) ([]byte, error) { func (g sceneRelationships) cover(ctx context.Context) ([]byte, error) {
scraped := g.result.result.Image scraped := g.result.result.Image
r := g.repo
if scraped == nil { if scraped == nil {
return nil, nil return nil, nil
} }
// always overwrite if present // always overwrite if present
existingCover, err := r.Scene().GetCover(g.scene.ID) existingCover, err := g.sceneReader.GetCover(ctx, g.scene.ID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting scene cover: %w", err) return nil, fmt.Errorf("error getting scene cover: %w", err)
} }

View File

@@ -24,13 +24,13 @@ func Test_sceneRelationships_studio(t *testing.T) {
Strategy: FieldStrategyMerge, Strategy: FieldStrategyMerge,
} }
repo := mocks.NewTransactionManager() mockStudioReaderWriter := &mocks.StudioReaderWriter{}
repo.StudioMock().On("Create", mock.Anything).Return(&models.Studio{ mockStudioReaderWriter.On("Create", testCtx, mock.Anything).Return(&models.Studio{
ID: int(validStoredIDInt), ID: int(validStoredIDInt),
}, nil) }, nil)
tr := sceneRelationships{ tr := sceneRelationships{
repo: repo, studioCreator: mockStudioReaderWriter,
fieldOptions: make(map[string]*FieldOptions), fieldOptions: make(map[string]*FieldOptions),
} }
@@ -124,7 +124,7 @@ func Test_sceneRelationships_studio(t *testing.T) {
}, },
} }
got, err := tr.studio() got, err := tr.studio(testCtx)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("sceneRelationships.studio() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("sceneRelationships.studio() error = %v, wantErr %v", err, tt.wantErr)
return return
@@ -156,13 +156,13 @@ func Test_sceneRelationships_performers(t *testing.T) {
Strategy: FieldStrategyMerge, Strategy: FieldStrategyMerge,
} }
repo := mocks.NewTransactionManager() mockSceneReaderWriter := &mocks.SceneReaderWriter{}
repo.SceneMock().On("GetPerformerIDs", sceneID).Return(nil, nil) mockSceneReaderWriter.On("GetPerformerIDs", testCtx, sceneID).Return(nil, nil)
repo.SceneMock().On("GetPerformerIDs", sceneWithPerformerID).Return([]int{existingPerformerID}, nil) mockSceneReaderWriter.On("GetPerformerIDs", testCtx, sceneWithPerformerID).Return([]int{existingPerformerID}, nil)
repo.SceneMock().On("GetPerformerIDs", errSceneID).Return(nil, errors.New("error getting IDs")) mockSceneReaderWriter.On("GetPerformerIDs", testCtx, errSceneID).Return(nil, errors.New("error getting IDs"))
tr := sceneRelationships{ tr := sceneRelationships{
repo: repo, sceneReader: mockSceneReaderWriter,
fieldOptions: make(map[string]*FieldOptions), fieldOptions: make(map[string]*FieldOptions),
} }
@@ -316,7 +316,7 @@ func Test_sceneRelationships_performers(t *testing.T) {
}, },
} }
got, err := tr.performers(tt.ignoreMale) got, err := tr.performers(testCtx, tt.ignoreMale)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("sceneRelationships.performers() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("sceneRelationships.performers() error = %v, wantErr %v", err, tt.wantErr)
return return
@@ -347,22 +347,24 @@ func Test_sceneRelationships_tags(t *testing.T) {
Strategy: FieldStrategyMerge, Strategy: FieldStrategyMerge,
} }
repo := mocks.NewTransactionManager() mockSceneReaderWriter := &mocks.SceneReaderWriter{}
repo.SceneMock().On("GetTagIDs", sceneID).Return(nil, nil) mockTagReaderWriter := &mocks.TagReaderWriter{}
repo.SceneMock().On("GetTagIDs", sceneWithTagID).Return([]int{existingID}, nil) mockSceneReaderWriter.On("GetTagIDs", testCtx, sceneID).Return(nil, nil)
repo.SceneMock().On("GetTagIDs", errSceneID).Return(nil, errors.New("error getting IDs")) mockSceneReaderWriter.On("GetTagIDs", testCtx, sceneWithTagID).Return([]int{existingID}, nil)
mockSceneReaderWriter.On("GetTagIDs", testCtx, errSceneID).Return(nil, errors.New("error getting IDs"))
repo.TagMock().On("Create", mock.MatchedBy(func(p models.Tag) bool { mockTagReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p models.Tag) bool {
return p.Name == validName return p.Name == validName
})).Return(&models.Tag{ })).Return(&models.Tag{
ID: validStoredIDInt, ID: validStoredIDInt,
}, nil) }, nil)
repo.TagMock().On("Create", mock.MatchedBy(func(p models.Tag) bool { mockTagReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p models.Tag) bool {
return p.Name == invalidName return p.Name == invalidName
})).Return(nil, errors.New("error creating tag")) })).Return(nil, errors.New("error creating tag"))
tr := sceneRelationships{ tr := sceneRelationships{
repo: repo, sceneReader: mockSceneReaderWriter,
tagCreator: mockTagReaderWriter,
fieldOptions: make(map[string]*FieldOptions), fieldOptions: make(map[string]*FieldOptions),
} }
@@ -505,7 +507,7 @@ func Test_sceneRelationships_tags(t *testing.T) {
}, },
} }
got, err := tr.tags() got, err := tr.tags(testCtx)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("sceneRelationships.tags() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("sceneRelationships.tags() error = %v, wantErr %v", err, tt.wantErr)
return return
@@ -534,18 +536,18 @@ func Test_sceneRelationships_stashIDs(t *testing.T) {
Strategy: FieldStrategyMerge, Strategy: FieldStrategyMerge,
} }
repo := mocks.NewTransactionManager() mockSceneReaderWriter := &mocks.SceneReaderWriter{}
repo.SceneMock().On("GetStashIDs", sceneID).Return(nil, nil) mockSceneReaderWriter.On("GetStashIDs", testCtx, sceneID).Return(nil, nil)
repo.SceneMock().On("GetStashIDs", sceneWithStashID).Return([]*models.StashID{ mockSceneReaderWriter.On("GetStashIDs", testCtx, sceneWithStashID).Return([]*models.StashID{
{ {
StashID: remoteSiteID, StashID: remoteSiteID,
Endpoint: existingEndpoint, Endpoint: existingEndpoint,
}, },
}, nil) }, nil)
repo.SceneMock().On("GetStashIDs", errSceneID).Return(nil, errors.New("error getting IDs")) mockSceneReaderWriter.On("GetStashIDs", testCtx, errSceneID).Return(nil, errors.New("error getting IDs"))
tr := sceneRelationships{ tr := sceneRelationships{
repo: repo, sceneReader: mockSceneReaderWriter,
fieldOptions: make(map[string]*FieldOptions), fieldOptions: make(map[string]*FieldOptions),
} }
@@ -680,7 +682,7 @@ func Test_sceneRelationships_stashIDs(t *testing.T) {
}, },
} }
got, err := tr.stashIDs() got, err := tr.stashIDs(testCtx)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("sceneRelationships.stashIDs() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("sceneRelationships.stashIDs() error = %v, wantErr %v", err, tt.wantErr)
return return
@@ -707,12 +709,12 @@ func Test_sceneRelationships_cover(t *testing.T) {
newDataEncoded := base64Prefix + utils.GetBase64StringFromData(newData) newDataEncoded := base64Prefix + utils.GetBase64StringFromData(newData)
invalidData := newDataEncoded + "!!!" invalidData := newDataEncoded + "!!!"
repo := mocks.NewTransactionManager() mockSceneReaderWriter := &mocks.SceneReaderWriter{}
repo.SceneMock().On("GetCover", sceneID).Return(existingData, nil) mockSceneReaderWriter.On("GetCover", testCtx, sceneID).Return(existingData, nil)
repo.SceneMock().On("GetCover", errSceneID).Return(nil, errors.New("error getting cover")) mockSceneReaderWriter.On("GetCover", testCtx, errSceneID).Return(nil, errors.New("error getting cover"))
tr := sceneRelationships{ tr := sceneRelationships{
repo: repo, sceneReader: mockSceneReaderWriter,
fieldOptions: make(map[string]*FieldOptions), fieldOptions: make(map[string]*FieldOptions),
} }

View File

@@ -1,6 +1,7 @@
package identify package identify
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"time" "time"
@@ -9,14 +10,19 @@ import (
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
) )
func createMissingStudio(endpoint string, repo models.Repository, studio *models.ScrapedStudio) (*int64, error) { type StudioCreator interface {
created, err := repo.Studio().Create(scrapedToStudioInput(studio)) Create(ctx context.Context, newStudio models.Studio) (*models.Studio, error)
UpdateStashIDs(ctx context.Context, studioID int, stashIDs []models.StashID) error
}
func createMissingStudio(ctx context.Context, endpoint string, w StudioCreator, studio *models.ScrapedStudio) (*int64, error) {
created, err := w.Create(ctx, scrapedToStudioInput(studio))
if err != nil { if err != nil {
return nil, fmt.Errorf("error creating studio: %w", err) return nil, fmt.Errorf("error creating studio: %w", err)
} }
if endpoint != "" && studio.RemoteSiteID != nil { if endpoint != "" && studio.RemoteSiteID != nil {
if err := repo.Studio().UpdateStashIDs(created.ID, []models.StashID{ if err := w.UpdateStashIDs(ctx, created.ID, []models.StashID{
{ {
Endpoint: endpoint, Endpoint: endpoint,
StashID: *studio.RemoteSiteID, StashID: *studio.RemoteSiteID,

View File

@@ -20,23 +20,24 @@ func Test_createMissingStudio(t *testing.T) {
createdID := 1 createdID := 1
createdID64 := int64(createdID) createdID64 := int64(createdID)
repo := mocks.NewTransactionManager() repo := mocks.NewTxnRepository()
repo.StudioMock().On("Create", mock.MatchedBy(func(p models.Studio) bool { mockStudioReaderWriter := repo.Studio.(*mocks.StudioReaderWriter)
mockStudioReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p models.Studio) bool {
return p.Name.String == validName return p.Name.String == validName
})).Return(&models.Studio{ })).Return(&models.Studio{
ID: createdID, ID: createdID,
}, nil) }, nil)
repo.StudioMock().On("Create", mock.MatchedBy(func(p models.Studio) bool { mockStudioReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p models.Studio) bool {
return p.Name.String == invalidName return p.Name.String == invalidName
})).Return(nil, errors.New("error creating performer")) })).Return(nil, errors.New("error creating performer"))
repo.StudioMock().On("UpdateStashIDs", createdID, []models.StashID{ mockStudioReaderWriter.On("UpdateStashIDs", testCtx, createdID, []models.StashID{
{ {
Endpoint: invalidEndpoint, Endpoint: invalidEndpoint,
StashID: remoteSiteID, StashID: remoteSiteID,
}, },
}).Return(errors.New("error updating stash ids")) }).Return(errors.New("error updating stash ids"))
repo.StudioMock().On("UpdateStashIDs", createdID, []models.StashID{ mockStudioReaderWriter.On("UpdateStashIDs", testCtx, createdID, []models.StashID{
{ {
Endpoint: validEndpoint, Endpoint: validEndpoint,
StashID: remoteSiteID, StashID: remoteSiteID,
@@ -102,7 +103,7 @@ func Test_createMissingStudio(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := createMissingStudio(tt.args.endpoint, repo, tt.args.studio) got, err := createMissingStudio(testCtx, tt.args.endpoint, mockStudioReaderWriter, tt.args.studio)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("createMissingStudio() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("createMissingStudio() error = %v, wantErr %v", err, tt.wantErr)
return return

View File

@@ -7,16 +7,21 @@ import (
"github.com/stashapp/stash/internal/manager/config" "github.com/stashapp/stash/internal/manager/config"
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
) )
func setInitialMD5Config(ctx context.Context, txnManager models.TransactionManager) { type SceneCounter interface {
Count(ctx context.Context) (int, error)
}
func setInitialMD5Config(ctx context.Context, txnManager txn.Manager, counter SceneCounter) {
// if there are no scene files in the database, then default the // if there are no scene files in the database, then default the
// VideoFileNamingAlgorithm config setting to oshash and calculateMD5 to // VideoFileNamingAlgorithm config setting to oshash and calculateMD5 to
// false, otherwise set them to true for backwards compatibility purposes // false, otherwise set them to true for backwards compatibility purposes
var count int var count int
if err := txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { if err := txn.WithTxn(ctx, txnManager, func(ctx context.Context) error {
var err error var err error
count, err = r.Scene().Count() count, err = counter.Count(ctx)
return err return err
}); err != nil { }); err != nil {
logger.Errorf("Error while counting scenes: %s", err.Error()) logger.Errorf("Error while counting scenes: %s", err.Error())
@@ -36,6 +41,11 @@ func setInitialMD5Config(ctx context.Context, txnManager models.TransactionManag
} }
} }
type SceneMissingHashCounter interface {
CountMissingChecksum(ctx context.Context) (int, error)
CountMissingOSHash(ctx context.Context) (int, error)
}
// ValidateVideoFileNamingAlgorithm validates changing the // ValidateVideoFileNamingAlgorithm validates changing the
// VideoFileNamingAlgorithm configuration flag. // VideoFileNamingAlgorithm configuration flag.
// //
@@ -44,12 +54,10 @@ func setInitialMD5Config(ctx context.Context, txnManager models.TransactionManag
// //
// Likewise, if VideoFileNamingAlgorithm is set to oshash, then this function // Likewise, if VideoFileNamingAlgorithm is set to oshash, then this function
// will ensure that all oshash values are set on all scenes. // will ensure that all oshash values are set on all scenes.
func ValidateVideoFileNamingAlgorithm(txnManager models.TransactionManager, newValue models.HashAlgorithm) error { func ValidateVideoFileNamingAlgorithm(ctx context.Context, qb SceneMissingHashCounter, newValue models.HashAlgorithm) error {
// if algorithm is being set to MD5, then all checksums must be present // if algorithm is being set to MD5, then all checksums must be present
return txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
qb := r.Scene()
if newValue == models.HashAlgorithmMd5 { if newValue == models.HashAlgorithmMd5 {
missingMD5, err := qb.CountMissingChecksum() missingMD5, err := qb.CountMissingChecksum(ctx)
if err != nil { if err != nil {
return err return err
} }
@@ -58,7 +66,7 @@ func ValidateVideoFileNamingAlgorithm(txnManager models.TransactionManager, newV
return errors.New("some checksums are missing on scenes. Run Scan with calculateMD5 set to true") return errors.New("some checksums are missing on scenes. Run Scan with calculateMD5 set to true")
} }
} else if newValue == models.HashAlgorithmOshash { } else if newValue == models.HashAlgorithmOshash {
missingOSHash, err := qb.CountMissingOSHash() missingOSHash, err := qb.CountMissingOSHash(ctx)
if err != nil { if err != nil {
return err return err
} }
@@ -69,5 +77,4 @@ func ValidateVideoFileNamingAlgorithm(txnManager models.TransactionManager, newV
} }
return nil return nil
})
} }

View File

@@ -1,6 +1,7 @@
package manager package manager
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"path/filepath" "path/filepath"
@@ -470,7 +471,15 @@ func (p *SceneFilenameParser) initWhiteSpaceRegex() {
} }
} }
func (p *SceneFilenameParser) Parse(repo models.ReaderRepository) ([]*SceneParserResult, int, error) { type SceneFilenameParserRepository struct {
Scene scene.Queryer
Performer PerformerNamesFinder
Studio studio.Queryer
Movie MovieNameFinder
Tag tag.Queryer
}
func (p *SceneFilenameParser) Parse(ctx context.Context, repo SceneFilenameParserRepository) ([]*SceneParserResult, int, error) {
// perform the query to find the scenes // perform the query to find the scenes
mapper, err := newParseMapper(p.Pattern, p.ParserInput.IgnoreWords) mapper, err := newParseMapper(p.Pattern, p.ParserInput.IgnoreWords)
@@ -492,17 +501,17 @@ func (p *SceneFilenameParser) Parse(repo models.ReaderRepository) ([]*SceneParse
p.Filter.Q = nil p.Filter.Q = nil
scenes, total, err := scene.QueryWithCount(repo.Scene(), sceneFilter, p.Filter) scenes, total, err := scene.QueryWithCount(ctx, repo.Scene, sceneFilter, p.Filter)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
ret := p.parseScenes(repo, scenes, mapper) ret := p.parseScenes(ctx, repo, scenes, mapper)
return ret, total, nil return ret, total, nil
} }
func (p *SceneFilenameParser) parseScenes(repo models.ReaderRepository, scenes []*models.Scene, mapper *parseMapper) []*SceneParserResult { func (p *SceneFilenameParser) parseScenes(ctx context.Context, repo SceneFilenameParserRepository, scenes []*models.Scene, mapper *parseMapper) []*SceneParserResult {
var ret []*SceneParserResult var ret []*SceneParserResult
for _, scene := range scenes { for _, scene := range scenes {
sceneHolder := mapper.parse(scene) sceneHolder := mapper.parse(scene)
@@ -511,7 +520,7 @@ func (p *SceneFilenameParser) parseScenes(repo models.ReaderRepository, scenes [
r := &SceneParserResult{ r := &SceneParserResult{
Scene: scene, Scene: scene,
} }
p.setParserResult(repo, *sceneHolder, r) p.setParserResult(ctx, repo, *sceneHolder, r)
ret = append(ret, r) ret = append(ret, r)
} }
@@ -530,7 +539,11 @@ func (p SceneFilenameParser) replaceWhitespaceCharacters(value string) string {
return value return value
} }
func (p *SceneFilenameParser) queryPerformer(qb models.PerformerReader, performerName string) *models.Performer { type PerformerNamesFinder interface {
FindByNames(ctx context.Context, names []string, nocase bool) ([]*models.Performer, error)
}
func (p *SceneFilenameParser) queryPerformer(ctx context.Context, qb PerformerNamesFinder, performerName string) *models.Performer {
// massage the performer name // massage the performer name
performerName = delimiterRE.ReplaceAllString(performerName, " ") performerName = delimiterRE.ReplaceAllString(performerName, " ")
@@ -540,7 +553,7 @@ func (p *SceneFilenameParser) queryPerformer(qb models.PerformerReader, performe
} }
// perform an exact match and grab the first // perform an exact match and grab the first
performers, _ := qb.FindByNames([]string{performerName}, true) performers, _ := qb.FindByNames(ctx, []string{performerName}, true)
var ret *models.Performer var ret *models.Performer
if len(performers) > 0 { if len(performers) > 0 {
@@ -553,7 +566,7 @@ func (p *SceneFilenameParser) queryPerformer(qb models.PerformerReader, performe
return ret return ret
} }
func (p *SceneFilenameParser) queryStudio(qb models.StudioReader, studioName string) *models.Studio { func (p *SceneFilenameParser) queryStudio(ctx context.Context, qb studio.Queryer, studioName string) *models.Studio {
// massage the performer name // massage the performer name
studioName = delimiterRE.ReplaceAllString(studioName, " ") studioName = delimiterRE.ReplaceAllString(studioName, " ")
@@ -562,11 +575,11 @@ func (p *SceneFilenameParser) queryStudio(qb models.StudioReader, studioName str
return ret return ret
} }
ret, _ := studio.ByName(qb, studioName) ret, _ := studio.ByName(ctx, qb, studioName)
// try to match on alias // try to match on alias
if ret == nil { if ret == nil {
ret, _ = studio.ByAlias(qb, studioName) ret, _ = studio.ByAlias(ctx, qb, studioName)
} }
// add result to cache // add result to cache
@@ -575,7 +588,11 @@ func (p *SceneFilenameParser) queryStudio(qb models.StudioReader, studioName str
return ret return ret
} }
func (p *SceneFilenameParser) queryMovie(qb models.MovieReader, movieName string) *models.Movie { type MovieNameFinder interface {
FindByName(ctx context.Context, name string, nocase bool) (*models.Movie, error)
}
func (p *SceneFilenameParser) queryMovie(ctx context.Context, qb MovieNameFinder, movieName string) *models.Movie {
// massage the movie name // massage the movie name
movieName = delimiterRE.ReplaceAllString(movieName, " ") movieName = delimiterRE.ReplaceAllString(movieName, " ")
@@ -584,7 +601,7 @@ func (p *SceneFilenameParser) queryMovie(qb models.MovieReader, movieName string
return ret return ret
} }
ret, _ := qb.FindByName(movieName, true) ret, _ := qb.FindByName(ctx, movieName, true)
// add result to cache // add result to cache
p.movieCache[movieName] = ret p.movieCache[movieName] = ret
@@ -592,7 +609,7 @@ func (p *SceneFilenameParser) queryMovie(qb models.MovieReader, movieName string
return ret return ret
} }
func (p *SceneFilenameParser) queryTag(qb models.TagReader, tagName string) *models.Tag { func (p *SceneFilenameParser) queryTag(ctx context.Context, qb tag.Queryer, tagName string) *models.Tag {
// massage the tag name // massage the tag name
tagName = delimiterRE.ReplaceAllString(tagName, " ") tagName = delimiterRE.ReplaceAllString(tagName, " ")
@@ -602,11 +619,11 @@ func (p *SceneFilenameParser) queryTag(qb models.TagReader, tagName string) *mod
} }
// match tag name exactly // match tag name exactly
ret, _ := tag.ByName(qb, tagName) ret, _ := tag.ByName(ctx, qb, tagName)
// try to match on alias // try to match on alias
if ret == nil { if ret == nil {
ret, _ = tag.ByAlias(qb, tagName) ret, _ = tag.ByAlias(ctx, qb, tagName)
} }
// add result to cache // add result to cache
@@ -615,12 +632,12 @@ func (p *SceneFilenameParser) queryTag(qb models.TagReader, tagName string) *mod
return ret return ret
} }
func (p *SceneFilenameParser) setPerformers(qb models.PerformerReader, h sceneHolder, result *SceneParserResult) { func (p *SceneFilenameParser) setPerformers(ctx context.Context, qb PerformerNamesFinder, h sceneHolder, result *SceneParserResult) {
// query for each performer // query for each performer
performersSet := make(map[int]bool) performersSet := make(map[int]bool)
for _, performerName := range h.performers { for _, performerName := range h.performers {
if performerName != "" { if performerName != "" {
performer := p.queryPerformer(qb, performerName) performer := p.queryPerformer(ctx, qb, performerName)
if performer != nil { if performer != nil {
if _, found := performersSet[performer.ID]; !found { if _, found := performersSet[performer.ID]; !found {
result.PerformerIds = append(result.PerformerIds, strconv.Itoa(performer.ID)) result.PerformerIds = append(result.PerformerIds, strconv.Itoa(performer.ID))
@@ -631,12 +648,12 @@ func (p *SceneFilenameParser) setPerformers(qb models.PerformerReader, h sceneHo
} }
} }
func (p *SceneFilenameParser) setTags(qb models.TagReader, h sceneHolder, result *SceneParserResult) { func (p *SceneFilenameParser) setTags(ctx context.Context, qb tag.Queryer, h sceneHolder, result *SceneParserResult) {
// query for each performer // query for each performer
tagsSet := make(map[int]bool) tagsSet := make(map[int]bool)
for _, tagName := range h.tags { for _, tagName := range h.tags {
if tagName != "" { if tagName != "" {
tag := p.queryTag(qb, tagName) tag := p.queryTag(ctx, qb, tagName)
if tag != nil { if tag != nil {
if _, found := tagsSet[tag.ID]; !found { if _, found := tagsSet[tag.ID]; !found {
result.TagIds = append(result.TagIds, strconv.Itoa(tag.ID)) result.TagIds = append(result.TagIds, strconv.Itoa(tag.ID))
@@ -647,10 +664,10 @@ func (p *SceneFilenameParser) setTags(qb models.TagReader, h sceneHolder, result
} }
} }
func (p *SceneFilenameParser) setStudio(qb models.StudioReader, h sceneHolder, result *SceneParserResult) { func (p *SceneFilenameParser) setStudio(ctx context.Context, qb studio.Queryer, h sceneHolder, result *SceneParserResult) {
// query for each performer // query for each performer
if h.studio != "" { if h.studio != "" {
studio := p.queryStudio(qb, h.studio) studio := p.queryStudio(ctx, qb, h.studio)
if studio != nil { if studio != nil {
studioID := strconv.Itoa(studio.ID) studioID := strconv.Itoa(studio.ID)
result.StudioID = &studioID result.StudioID = &studioID
@@ -658,12 +675,12 @@ func (p *SceneFilenameParser) setStudio(qb models.StudioReader, h sceneHolder, r
} }
} }
func (p *SceneFilenameParser) setMovies(qb models.MovieReader, h sceneHolder, result *SceneParserResult) { func (p *SceneFilenameParser) setMovies(ctx context.Context, qb MovieNameFinder, h sceneHolder, result *SceneParserResult) {
// query for each movie // query for each movie
moviesSet := make(map[int]bool) moviesSet := make(map[int]bool)
for _, movieName := range h.movies { for _, movieName := range h.movies {
if movieName != "" { if movieName != "" {
movie := p.queryMovie(qb, movieName) movie := p.queryMovie(ctx, qb, movieName)
if movie != nil { if movie != nil {
if _, found := moviesSet[movie.ID]; !found { if _, found := moviesSet[movie.ID]; !found {
result.Movies = append(result.Movies, &SceneMovieID{ result.Movies = append(result.Movies, &SceneMovieID{
@@ -676,7 +693,7 @@ func (p *SceneFilenameParser) setMovies(qb models.MovieReader, h sceneHolder, re
} }
} }
func (p *SceneFilenameParser) setParserResult(repo models.ReaderRepository, h sceneHolder, result *SceneParserResult) { func (p *SceneFilenameParser) setParserResult(ctx context.Context, repo SceneFilenameParserRepository, h sceneHolder, result *SceneParserResult) {
if h.result.Title.Valid { if h.result.Title.Valid {
title := h.result.Title.String title := h.result.Title.String
title = p.replaceWhitespaceCharacters(title) title = p.replaceWhitespaceCharacters(title)
@@ -698,15 +715,15 @@ func (p *SceneFilenameParser) setParserResult(repo models.ReaderRepository, h sc
} }
if len(h.performers) > 0 { if len(h.performers) > 0 {
p.setPerformers(repo.Performer(), h, result) p.setPerformers(ctx, repo.Performer, h, result)
} }
if len(h.tags) > 0 { if len(h.tags) > 0 {
p.setTags(repo.Tag(), h, result) p.setTags(ctx, repo.Tag, h, result)
} }
p.setStudio(repo.Studio(), h, result) p.setStudio(ctx, repo.Studio, h, result)
if len(h.movies) > 0 { if len(h.movies) > 0 {
p.setMovies(repo.Movie(), h, result) p.setMovies(ctx, repo.Movie, h, result)
} }
} }

View File

@@ -1,6 +1,7 @@
package manager package manager
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"strconv" "strconv"
@@ -52,22 +53,22 @@ func (e ImportDuplicateEnum) MarshalGQL(w io.Writer) {
} }
type importer interface { type importer interface {
PreImport() error PreImport(ctx context.Context) error
PostImport(id int) error PostImport(ctx context.Context, id int) error
Name() string Name() string
FindExistingID() (*int, error) FindExistingID(ctx context.Context) (*int, error)
Create() (*int, error) Create(ctx context.Context) (*int, error)
Update(id int) error Update(ctx context.Context, id int) error
} }
func performImport(i importer, duplicateBehaviour ImportDuplicateEnum) error { func performImport(ctx context.Context, i importer, duplicateBehaviour ImportDuplicateEnum) error {
if err := i.PreImport(); err != nil { if err := i.PreImport(ctx); err != nil {
return err return err
} }
// try to find an existing object with the same name // try to find an existing object with the same name
name := i.Name() name := i.Name()
existing, err := i.FindExistingID() existing, err := i.FindExistingID(ctx)
if err != nil { if err != nil {
return fmt.Errorf("error finding existing objects: %v", err) return fmt.Errorf("error finding existing objects: %v", err)
} }
@@ -84,12 +85,12 @@ func performImport(i importer, duplicateBehaviour ImportDuplicateEnum) error {
// must be overwriting // must be overwriting
id = *existing id = *existing
if err := i.Update(id); err != nil { if err := i.Update(ctx, id); err != nil {
return fmt.Errorf("error updating existing object: %v", err) return fmt.Errorf("error updating existing object: %v", err)
} }
} else { } else {
// creating // creating
createdID, err := i.Create() createdID, err := i.Create(ctx)
if err != nil { if err != nil {
return fmt.Errorf("error creating object: %v", err) return fmt.Errorf("error creating object: %v", err)
} }
@@ -97,7 +98,7 @@ func performImport(i importer, duplicateBehaviour ImportDuplicateEnum) error {
id = *createdID id = *createdID
} }
if err := i.PostImport(id); err != nil { if err := i.PostImport(ctx, id); err != nil {
return err return err
} }

View File

@@ -17,7 +17,6 @@ import (
"github.com/stashapp/stash/internal/dlna" "github.com/stashapp/stash/internal/dlna"
"github.com/stashapp/stash/internal/log" "github.com/stashapp/stash/internal/log"
"github.com/stashapp/stash/internal/manager/config" "github.com/stashapp/stash/internal/manager/config"
"github.com/stashapp/stash/pkg/database"
"github.com/stashapp/stash/pkg/ffmpeg" "github.com/stashapp/stash/pkg/ffmpeg"
"github.com/stashapp/stash/pkg/fsutil" "github.com/stashapp/stash/pkg/fsutil"
"github.com/stashapp/stash/pkg/job" "github.com/stashapp/stash/pkg/job"
@@ -115,7 +114,8 @@ type Manager struct {
DLNAService *dlna.Service DLNAService *dlna.Service
TxnManager models.TransactionManager Database *sqlite.Database
Repository models.Repository
scanSubs *subscriptionManager scanSubs *subscriptionManager
} }
@@ -150,6 +150,8 @@ func initialize() error {
l := initLog() l := initLog()
initProfiling(cfg.GetCPUProfilePath()) initProfiling(cfg.GetCPUProfilePath())
db := &sqlite.Database{}
instance = &Manager{ instance = &Manager{
Config: cfg, Config: cfg,
Logger: l, Logger: l,
@@ -157,7 +159,20 @@ func initialize() error {
DownloadStore: NewDownloadStore(), DownloadStore: NewDownloadStore(),
PluginCache: plugin.NewCache(cfg), PluginCache: plugin.NewCache(cfg),
TxnManager: sqlite.NewTransactionManager(), Database: db,
Repository: models.Repository{
TxnManager: db,
Gallery: sqlite.GalleryReaderWriter,
Image: sqlite.ImageReaderWriter,
Movie: sqlite.MovieReaderWriter,
Performer: sqlite.PerformerReaderWriter,
Scene: sqlite.SceneReaderWriter,
SceneMarker: sqlite.SceneMarkerReaderWriter,
ScrapedItem: sqlite.ScrapedItemReaderWriter,
Studio: sqlite.StudioReaderWriter,
Tag: sqlite.TagReaderWriter,
SavedFilter: sqlite.SavedFilterReaderWriter,
},
scanSubs: &subscriptionManager{}, scanSubs: &subscriptionManager{},
} }
@@ -165,9 +180,17 @@ func initialize() error {
instance.JobManager = initJobManager() instance.JobManager = initJobManager()
sceneServer := SceneServer{ sceneServer := SceneServer{
TXNManager: instance.TxnManager, TxnManager: instance.Repository,
SceneCoverGetter: instance.Repository.Scene,
} }
instance.DLNAService = dlna.NewService(instance.TxnManager, instance.Config, &sceneServer)
instance.DLNAService = dlna.NewService(instance.Repository, dlna.Repository{
SceneFinder: instance.Repository.Scene,
StudioFinder: instance.Repository.Studio,
TagFinder: instance.Repository.Tag,
PerformerFinder: instance.Repository.Performer,
MovieFinder: instance.Repository.Movie,
}, instance.Config, &sceneServer)
if !cfg.IsNewSystem() { if !cfg.IsNewSystem() {
logger.Infof("using config file: %s", cfg.GetConfigFile()) logger.Infof("using config file: %s", cfg.GetConfigFile())
@@ -177,9 +200,14 @@ func initialize() error {
} }
if err != nil { if err != nil {
return fmt.Errorf("error initializing configuration: %w", err) panic(fmt.Sprintf("error initializing configuration: %s", err.Error()))
} else if err := instance.PostInit(ctx); err != nil { } else if err := instance.PostInit(ctx); err != nil {
return err var migrationNeededErr *sqlite.MigrationNeededError
if errors.As(err, &migrationNeededErr) {
logger.Warn(err.Error())
} else {
panic(err)
}
} }
initSecurity(cfg) initSecurity(cfg)
@@ -352,7 +380,8 @@ func (s *Manager) PostInit(ctx context.Context) error {
}) })
} }
if err := database.Initialize(s.Config.GetDatabasePath()); err != nil { database := s.Database
if err := database.Open(s.Config.GetDatabasePath()); err != nil {
return err return err
} }
@@ -377,7 +406,14 @@ func writeStashIcon() {
// initScraperCache initializes a new scraper cache and returns it. // initScraperCache initializes a new scraper cache and returns it.
func (s *Manager) initScraperCache() *scraper.Cache { func (s *Manager) initScraperCache() *scraper.Cache {
ret, err := scraper.NewCache(config.GetInstance(), s.TxnManager) ret, err := scraper.NewCache(config.GetInstance(), s.Repository, scraper.Repository{
SceneFinder: s.Repository.Scene,
GalleryFinder: s.Repository.Gallery,
TagFinder: s.Repository.Tag,
PerformerFinder: s.Repository.Performer,
MovieFinder: s.Repository.Movie,
StudioFinder: s.Repository.Studio,
})
if err != nil { if err != nil {
logger.Errorf("Error reading scraper configs: %s", err.Error()) logger.Errorf("Error reading scraper configs: %s", err.Error())
@@ -476,8 +512,13 @@ func (s *Manager) Setup(ctx context.Context, input SetupInput) error {
// initialise the database // initialise the database
if err := s.PostInit(ctx); err != nil { if err := s.PostInit(ctx); err != nil {
var migrationNeededErr *sqlite.MigrationNeededError
if errors.As(err, &migrationNeededErr) {
logger.Warn(err.Error())
} else {
return fmt.Errorf("error initializing the database: %v", err) return fmt.Errorf("error initializing the database: %v", err)
} }
}
s.Config.FinalizeSetup() s.Config.FinalizeSetup()
@@ -501,6 +542,8 @@ type MigrateInput struct {
} }
func (s *Manager) Migrate(ctx context.Context, input MigrateInput) error { func (s *Manager) Migrate(ctx context.Context, input MigrateInput) error {
database := s.Database
// always backup so that we can roll back to the previous version if // always backup so that we can roll back to the previous version if
// migration fails // migration fails
backupPath := input.BackupPath backupPath := input.BackupPath
@@ -509,7 +552,7 @@ func (s *Manager) Migrate(ctx context.Context, input MigrateInput) error {
} }
// perform database backup // perform database backup
if err := database.Backup(database.DB, backupPath); err != nil { if err := database.Backup(backupPath); err != nil {
return fmt.Errorf("error backing up database: %s", err) return fmt.Errorf("error backing up database: %s", err)
} }
@@ -541,6 +584,7 @@ func (s *Manager) Migrate(ctx context.Context, input MigrateInput) error {
} }
func (s *Manager) GetSystemStatus() *SystemStatus { func (s *Manager) GetSystemStatus() *SystemStatus {
database := s.Database
status := SystemStatusEnumOk status := SystemStatusEnumOk
dbSchema := int(database.Version()) dbSchema := int(database.Version())
dbPath := database.DatabasePath() dbPath := database.DatabasePath()
@@ -569,7 +613,7 @@ func (s *Manager) Shutdown(code int) {
// TODO: Each part of the manager needs to gracefully stop at some point // TODO: Each part of the manager needs to gracefully stop at some point
// for now, we just close the database. // for now, we just close the database.
err := database.Close() err := s.Database.Close()
if err != nil { if err != nil {
logger.Errorf("Error closing database: %s", err) logger.Errorf("Error closing database: %s", err)
if code == 0 { if code == 0 {

View File

@@ -84,7 +84,7 @@ func (s *Manager) Scan(ctx context.Context, input ScanMetadataInput) (int, error
} }
scanJob := ScanJob{ scanJob := ScanJob{
txnManager: s.TxnManager, txnManager: s.Repository,
input: input, input: input,
subscriptions: s.scanSubs, subscriptions: s.scanSubs,
} }
@@ -101,7 +101,7 @@ func (s *Manager) Import(ctx context.Context) (int, error) {
j := job.MakeJobExec(func(ctx context.Context, progress *job.Progress) { j := job.MakeJobExec(func(ctx context.Context, progress *job.Progress) {
task := ImportTask{ task := ImportTask{
txnManager: s.TxnManager, txnManager: s.Repository,
BaseDir: metadataPath, BaseDir: metadataPath,
Reset: true, Reset: true,
DuplicateBehaviour: ImportDuplicateEnumFail, DuplicateBehaviour: ImportDuplicateEnumFail,
@@ -125,7 +125,7 @@ func (s *Manager) Export(ctx context.Context) (int, error) {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
task := ExportTask{ task := ExportTask{
txnManager: s.TxnManager, txnManager: s.Repository,
full: true, full: true,
fileNamingAlgorithm: config.GetVideoFileNamingAlgorithm(), fileNamingAlgorithm: config.GetVideoFileNamingAlgorithm(),
} }
@@ -156,7 +156,7 @@ func (s *Manager) Generate(ctx context.Context, input GenerateMetadataInput) (in
} }
j := &GenerateJob{ j := &GenerateJob{
txnManager: s.TxnManager, txnManager: s.Repository,
input: input, input: input,
} }
@@ -185,9 +185,9 @@ func (s *Manager) generateScreenshot(ctx context.Context, sceneId string, at *fl
} }
var scene *models.Scene var scene *models.Scene
if err := s.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { if err := s.Repository.WithTxn(ctx, func(ctx context.Context) error {
var err error var err error
scene, err = r.Scene().Find(sceneIdInt) scene, err = s.Repository.Scene.Find(ctx, sceneIdInt)
return err return err
}); err != nil || scene == nil { }); err != nil || scene == nil {
logger.Errorf("failed to get scene for generate: %s", err.Error()) logger.Errorf("failed to get scene for generate: %s", err.Error())
@@ -195,7 +195,7 @@ func (s *Manager) generateScreenshot(ctx context.Context, sceneId string, at *fl
} }
task := GenerateScreenshotTask{ task := GenerateScreenshotTask{
txnManager: s.TxnManager, txnManager: s.Repository,
Scene: *scene, Scene: *scene,
ScreenshotAt: at, ScreenshotAt: at,
fileNamingAlgorithm: config.GetInstance().GetVideoFileNamingAlgorithm(), fileNamingAlgorithm: config.GetInstance().GetVideoFileNamingAlgorithm(),
@@ -222,7 +222,7 @@ type AutoTagMetadataInput struct {
func (s *Manager) AutoTag(ctx context.Context, input AutoTagMetadataInput) int { func (s *Manager) AutoTag(ctx context.Context, input AutoTagMetadataInput) int {
j := autoTagJob{ j := autoTagJob{
txnManager: s.TxnManager, txnManager: s.Repository,
input: input, input: input,
} }
@@ -237,7 +237,7 @@ type CleanMetadataInput struct {
func (s *Manager) Clean(ctx context.Context, input CleanMetadataInput) int { func (s *Manager) Clean(ctx context.Context, input CleanMetadataInput) int {
j := cleanJob{ j := cleanJob{
txnManager: s.TxnManager, txnManager: s.Repository,
input: input, input: input,
scanSubs: s.scanSubs, scanSubs: s.scanSubs,
} }
@@ -251,9 +251,9 @@ func (s *Manager) MigrateHash(ctx context.Context) int {
logger.Infof("Migrating generated files for %s naming hash", fileNamingAlgo.String()) logger.Infof("Migrating generated files for %s naming hash", fileNamingAlgo.String())
var scenes []*models.Scene var scenes []*models.Scene
if err := s.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { if err := s.Repository.WithTxn(ctx, func(ctx context.Context) error {
var err error var err error
scenes, err = r.Scene().All() scenes, err = s.Repository.Scene.All(ctx)
return err return err
}); err != nil { }); err != nil {
logger.Errorf("failed to fetch list of scenes for migration: %s", err.Error()) logger.Errorf("failed to fetch list of scenes for migration: %s", err.Error())
@@ -327,15 +327,14 @@ func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxB
// This is why we mark this section nolint. In principle, we should look to // This is why we mark this section nolint. In principle, we should look to
// rewrite the section at some point, to avoid the linter warning. // rewrite the section at some point, to avoid the linter warning.
if len(input.PerformerIds) > 0 { //nolint:gocritic if len(input.PerformerIds) > 0 { //nolint:gocritic
if err := s.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { if err := s.Repository.WithTxn(ctx, func(ctx context.Context) error {
performerQuery := r.Performer() performerQuery := s.Repository.Performer
for _, performerID := range input.PerformerIds { for _, performerID := range input.PerformerIds {
if id, err := strconv.Atoi(performerID); err == nil { if id, err := strconv.Atoi(performerID); err == nil {
performer, err := performerQuery.Find(id) performer, err := performerQuery.Find(ctx, id)
if err == nil { if err == nil {
tasks = append(tasks, StashBoxPerformerTagTask{ tasks = append(tasks, StashBoxPerformerTagTask{
txnManager: s.TxnManager,
performer: performer, performer: performer,
refresh: input.Refresh, refresh: input.Refresh,
box: box, box: box,
@@ -354,7 +353,6 @@ func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxB
for i := range input.PerformerNames { for i := range input.PerformerNames {
if len(input.PerformerNames[i]) > 0 { if len(input.PerformerNames[i]) > 0 {
tasks = append(tasks, StashBoxPerformerTagTask{ tasks = append(tasks, StashBoxPerformerTagTask{
txnManager: s.TxnManager,
name: &input.PerformerNames[i], name: &input.PerformerNames[i],
refresh: input.Refresh, refresh: input.Refresh,
box: box, box: box,
@@ -367,14 +365,14 @@ func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxB
// However, this doesn't really help with readability of the current section. Mark it // However, this doesn't really help with readability of the current section. Mark it
// as nolint for now. In the future we'd like to rewrite this code by factoring some of // as nolint for now. In the future we'd like to rewrite this code by factoring some of
// this into separate functions. // this into separate functions.
if err := s.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { if err := s.Repository.WithTxn(ctx, func(ctx context.Context) error {
performerQuery := r.Performer() performerQuery := s.Repository.Performer
var performers []*models.Performer var performers []*models.Performer
var err error var err error
if input.Refresh { if input.Refresh {
performers, err = performerQuery.FindByStashIDStatus(true, box.Endpoint) performers, err = performerQuery.FindByStashIDStatus(ctx, true, box.Endpoint)
} else { } else {
performers, err = performerQuery.FindByStashIDStatus(false, box.Endpoint) performers, err = performerQuery.FindByStashIDStatus(ctx, false, box.Endpoint)
} }
if err != nil { if err != nil {
return fmt.Errorf("error querying performers: %v", err) return fmt.Errorf("error querying performers: %v", err)
@@ -382,7 +380,6 @@ func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxB
for _, performer := range performers { for _, performer := range performers {
tasks = append(tasks, StashBoxPerformerTagTask{ tasks = append(tasks, StashBoxPerformerTagTask{
txnManager: s.TxnManager,
performer: performer, performer: performer,
refresh: input.Refresh, refresh: input.Refresh,
box: box, box: box,

View File

@@ -4,5 +4,5 @@ import "context"
// PostMigrate is executed after migrations have been executed. // PostMigrate is executed after migrations have been executed.
func (s *Manager) PostMigrate(ctx context.Context) { func (s *Manager) PostMigrate(ctx context.Context) {
setInitialMD5Config(ctx, s.TxnManager) setInitialMD5Config(ctx, s.Repository, s.Repository.Scene)
} }

View File

@@ -8,6 +8,7 @@ import (
"github.com/stashapp/stash/pkg/fsutil" "github.com/stashapp/stash/pkg/fsutil"
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
@@ -49,8 +50,13 @@ func KillRunningStreams(scene *models.Scene, fileNamingAlgo models.HashAlgorithm
instance.ReadLockManager.Cancel(transcodePath) instance.ReadLockManager.Cancel(transcodePath)
} }
type SceneCoverGetter interface {
GetCover(ctx context.Context, sceneID int) ([]byte, error)
}
type SceneServer struct { type SceneServer struct {
TXNManager models.TransactionManager TxnManager txn.Manager
SceneCoverGetter SceneCoverGetter
} }
func (s *SceneServer) StreamSceneDirect(scene *models.Scene, w http.ResponseWriter, r *http.Request) { func (s *SceneServer) StreamSceneDirect(scene *models.Scene, w http.ResponseWriter, r *http.Request) {
@@ -75,8 +81,8 @@ func (s *SceneServer) ServeScreenshot(scene *models.Scene, w http.ResponseWriter
http.ServeFile(w, r, filepath) http.ServeFile(w, r, filepath)
} else { } else {
var cover []byte var cover []byte
err := s.TXNManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { err := txn.WithTxn(r.Context(), s.TxnManager, func(ctx context.Context) error {
cover, _ = repo.Scene().GetCover(scene.ID) cover, _ = s.SceneCoverGetter.GetCover(ctx, scene.ID)
return nil return nil
}) })
if err != nil { if err != nil {

View File

@@ -1,13 +1,15 @@
package manager package manager
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/studio"
) )
func ValidateModifyStudio(studio models.StudioPartial, qb models.StudioReader) error { func ValidateModifyStudio(ctx context.Context, studio models.StudioPartial, qb studio.Finder) error {
if studio.ParentID == nil || !studio.ParentID.Valid { if studio.ParentID == nil || !studio.ParentID.Valid {
return nil return nil
} }
@@ -22,7 +24,7 @@ func ValidateModifyStudio(studio models.StudioPartial, qb models.StudioReader) e
return errors.New("studio cannot be an ancestor of itself") return errors.New("studio cannot be an ancestor of itself")
} }
currentStudio, err := qb.Find(int(currentParentID.Int64)) currentStudio, err := qb.Find(ctx, int(currentParentID.Int64))
if err != nil { if err != nil {
return fmt.Errorf("error finding parent studio: %v", err) return fmt.Errorf("error finding parent studio: %v", err)
} }

View File

@@ -19,7 +19,7 @@ import (
) )
type autoTagJob struct { type autoTagJob struct {
txnManager models.TransactionManager txnManager models.Repository
input AutoTagMetadataInput input AutoTagMetadataInput
cache match.Cache cache match.Cache
@@ -73,27 +73,28 @@ func (j *autoTagJob) autoTagSpecific(ctx context.Context, progress *job.Progress
studioCount := len(studioIds) studioCount := len(studioIds)
tagCount := len(tagIds) tagCount := len(tagIds)
if err := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
performerQuery := r.Performer() r := j.txnManager
studioQuery := r.Studio() performerQuery := r.Performer
tagQuery := r.Tag() studioQuery := r.Studio
tagQuery := r.Tag
const wildcard = "*" const wildcard = "*"
var err error var err error
if performerCount == 1 && performerIds[0] == wildcard { if performerCount == 1 && performerIds[0] == wildcard {
performerCount, err = performerQuery.Count() performerCount, err = performerQuery.Count(ctx)
if err != nil { if err != nil {
return fmt.Errorf("error getting performer count: %v", err) return fmt.Errorf("error getting performer count: %v", err)
} }
} }
if studioCount == 1 && studioIds[0] == wildcard { if studioCount == 1 && studioIds[0] == wildcard {
studioCount, err = studioQuery.Count() studioCount, err = studioQuery.Count(ctx)
if err != nil { if err != nil {
return fmt.Errorf("error getting studio count: %v", err) return fmt.Errorf("error getting studio count: %v", err)
} }
} }
if tagCount == 1 && tagIds[0] == wildcard { if tagCount == 1 && tagIds[0] == wildcard {
tagCount, err = tagQuery.Count() tagCount, err = tagQuery.Count(ctx)
if err != nil { if err != nil {
return fmt.Errorf("error getting tag count: %v", err) return fmt.Errorf("error getting tag count: %v", err)
} }
@@ -123,14 +124,14 @@ func (j *autoTagJob) autoTagPerformers(ctx context.Context, progress *job.Progre
for _, performerId := range performerIds { for _, performerId := range performerIds {
var performers []*models.Performer var performers []*models.Performer
if err := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
performerQuery := r.Performer() performerQuery := j.txnManager.Performer
ignoreAutoTag := false ignoreAutoTag := false
perPage := -1 perPage := -1
if performerId == "*" { if performerId == "*" {
var err error var err error
performers, _, err = performerQuery.Query(&models.PerformerFilterType{ performers, _, err = performerQuery.Query(ctx, &models.PerformerFilterType{
IgnoreAutoTag: &ignoreAutoTag, IgnoreAutoTag: &ignoreAutoTag,
}, &models.FindFilterType{ }, &models.FindFilterType{
PerPage: &perPage, PerPage: &perPage,
@@ -144,7 +145,7 @@ func (j *autoTagJob) autoTagPerformers(ctx context.Context, progress *job.Progre
return fmt.Errorf("error parsing performer id %s: %s", performerId, err.Error()) return fmt.Errorf("error parsing performer id %s: %s", performerId, err.Error())
} }
performer, err := performerQuery.Find(performerIdInt) performer, err := performerQuery.Find(ctx, performerIdInt)
if err != nil { if err != nil {
return fmt.Errorf("error finding performer id %s: %s", performerId, err.Error()) return fmt.Errorf("error finding performer id %s: %s", performerId, err.Error())
} }
@@ -161,14 +162,15 @@ func (j *autoTagJob) autoTagPerformers(ctx context.Context, progress *job.Progre
return nil return nil
} }
if err := j.txnManager.WithTxn(ctx, func(r models.Repository) error { if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
if err := autotag.PerformerScenes(performer, paths, r.Scene(), &j.cache); err != nil { r := j.txnManager
if err := autotag.PerformerScenes(ctx, performer, paths, r.Scene, &j.cache); err != nil {
return err return err
} }
if err := autotag.PerformerImages(performer, paths, r.Image(), &j.cache); err != nil { if err := autotag.PerformerImages(ctx, performer, paths, r.Image, &j.cache); err != nil {
return err return err
} }
if err := autotag.PerformerGalleries(performer, paths, r.Gallery(), &j.cache); err != nil { if err := autotag.PerformerGalleries(ctx, performer, paths, r.Gallery, &j.cache); err != nil {
return err return err
} }
@@ -193,16 +195,18 @@ func (j *autoTagJob) autoTagStudios(ctx context.Context, progress *job.Progress,
return return
} }
r := j.txnManager
for _, studioId := range studioIds { for _, studioId := range studioIds {
var studios []*models.Studio var studios []*models.Studio
if err := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { if err := r.WithTxn(ctx, func(ctx context.Context) error {
studioQuery := r.Studio() studioQuery := r.Studio
ignoreAutoTag := false ignoreAutoTag := false
perPage := -1 perPage := -1
if studioId == "*" { if studioId == "*" {
var err error var err error
studios, _, err = studioQuery.Query(&models.StudioFilterType{ studios, _, err = studioQuery.Query(ctx, &models.StudioFilterType{
IgnoreAutoTag: &ignoreAutoTag, IgnoreAutoTag: &ignoreAutoTag,
}, &models.FindFilterType{ }, &models.FindFilterType{
PerPage: &perPage, PerPage: &perPage,
@@ -216,7 +220,7 @@ func (j *autoTagJob) autoTagStudios(ctx context.Context, progress *job.Progress,
return fmt.Errorf("error parsing studio id %s: %s", studioId, err.Error()) return fmt.Errorf("error parsing studio id %s: %s", studioId, err.Error())
} }
studio, err := studioQuery.Find(studioIdInt) studio, err := studioQuery.Find(ctx, studioIdInt)
if err != nil { if err != nil {
return fmt.Errorf("error finding studio id %s: %s", studioId, err.Error()) return fmt.Errorf("error finding studio id %s: %s", studioId, err.Error())
} }
@@ -234,19 +238,19 @@ func (j *autoTagJob) autoTagStudios(ctx context.Context, progress *job.Progress,
return nil return nil
} }
if err := j.txnManager.WithTxn(ctx, func(r models.Repository) error { if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
aliases, err := r.Studio().GetAliases(studio.ID) aliases, err := r.Studio.GetAliases(ctx, studio.ID)
if err != nil { if err != nil {
return err return err
} }
if err := autotag.StudioScenes(studio, paths, aliases, r.Scene(), &j.cache); err != nil { if err := autotag.StudioScenes(ctx, studio, paths, aliases, r.Scene, &j.cache); err != nil {
return err return err
} }
if err := autotag.StudioImages(studio, paths, aliases, r.Image(), &j.cache); err != nil { if err := autotag.StudioImages(ctx, studio, paths, aliases, r.Image, &j.cache); err != nil {
return err return err
} }
if err := autotag.StudioGalleries(studio, paths, aliases, r.Gallery(), &j.cache); err != nil { if err := autotag.StudioGalleries(ctx, studio, paths, aliases, r.Gallery, &j.cache); err != nil {
return err return err
} }
@@ -271,15 +275,17 @@ func (j *autoTagJob) autoTagTags(ctx context.Context, progress *job.Progress, pa
return return
} }
r := j.txnManager
for _, tagId := range tagIds { for _, tagId := range tagIds {
var tags []*models.Tag var tags []*models.Tag
if err := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
tagQuery := r.Tag() tagQuery := r.Tag
ignoreAutoTag := false ignoreAutoTag := false
perPage := -1 perPage := -1
if tagId == "*" { if tagId == "*" {
var err error var err error
tags, _, err = tagQuery.Query(&models.TagFilterType{ tags, _, err = tagQuery.Query(ctx, &models.TagFilterType{
IgnoreAutoTag: &ignoreAutoTag, IgnoreAutoTag: &ignoreAutoTag,
}, &models.FindFilterType{ }, &models.FindFilterType{
PerPage: &perPage, PerPage: &perPage,
@@ -293,7 +299,7 @@ func (j *autoTagJob) autoTagTags(ctx context.Context, progress *job.Progress, pa
return fmt.Errorf("error parsing tag id %s: %s", tagId, err.Error()) return fmt.Errorf("error parsing tag id %s: %s", tagId, err.Error())
} }
tag, err := tagQuery.Find(tagIdInt) tag, err := tagQuery.Find(ctx, tagIdInt)
if err != nil { if err != nil {
return fmt.Errorf("error finding tag id %s: %s", tagId, err.Error()) return fmt.Errorf("error finding tag id %s: %s", tagId, err.Error())
} }
@@ -306,19 +312,19 @@ func (j *autoTagJob) autoTagTags(ctx context.Context, progress *job.Progress, pa
return nil return nil
} }
if err := j.txnManager.WithTxn(ctx, func(r models.Repository) error { if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
aliases, err := r.Tag().GetAliases(tag.ID) aliases, err := r.Tag.GetAliases(ctx, tag.ID)
if err != nil { if err != nil {
return err return err
} }
if err := autotag.TagScenes(tag, paths, aliases, r.Scene(), &j.cache); err != nil { if err := autotag.TagScenes(ctx, tag, paths, aliases, r.Scene, &j.cache); err != nil {
return err return err
} }
if err := autotag.TagImages(tag, paths, aliases, r.Image(), &j.cache); err != nil { if err := autotag.TagImages(ctx, tag, paths, aliases, r.Image, &j.cache); err != nil {
return err return err
} }
if err := autotag.TagGalleries(tag, paths, aliases, r.Gallery(), &j.cache); err != nil { if err := autotag.TagGalleries(ctx, tag, paths, aliases, r.Gallery, &j.cache); err != nil {
return err return err
} }
@@ -345,7 +351,7 @@ type autoTagFilesTask struct {
tags bool tags bool
progress *job.Progress progress *job.Progress
txnManager models.TransactionManager txnManager models.Repository
cache *match.Cache cache *match.Cache
} }
@@ -425,13 +431,13 @@ func (t *autoTagFilesTask) makeGalleryFilter() *models.GalleryFilterType {
return ret return ret
} }
func (t *autoTagFilesTask) getCount(r models.ReaderRepository) (int, error) { func (t *autoTagFilesTask) getCount(ctx context.Context, r models.Repository) (int, error) {
pp := 0 pp := 0
findFilter := &models.FindFilterType{ findFilter := &models.FindFilterType{
PerPage: &pp, PerPage: &pp,
} }
sceneResults, err := r.Scene().Query(models.SceneQueryOptions{ sceneResults, err := r.Scene.Query(ctx, models.SceneQueryOptions{
QueryOptions: models.QueryOptions{ QueryOptions: models.QueryOptions{
FindFilter: findFilter, FindFilter: findFilter,
Count: true, Count: true,
@@ -444,7 +450,7 @@ func (t *autoTagFilesTask) getCount(r models.ReaderRepository) (int, error) {
sceneCount := sceneResults.Count sceneCount := sceneResults.Count
imageResults, err := r.Image().Query(models.ImageQueryOptions{ imageResults, err := r.Image.Query(ctx, models.ImageQueryOptions{
QueryOptions: models.QueryOptions{ QueryOptions: models.QueryOptions{
FindFilter: findFilter, FindFilter: findFilter,
Count: true, Count: true,
@@ -457,7 +463,7 @@ func (t *autoTagFilesTask) getCount(r models.ReaderRepository) (int, error) {
imageCount := imageResults.Count imageCount := imageResults.Count
_, galleryCount, err := r.Gallery().Query(t.makeGalleryFilter(), findFilter) _, galleryCount, err := r.Gallery.Query(ctx, t.makeGalleryFilter(), findFilter)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -465,7 +471,7 @@ func (t *autoTagFilesTask) getCount(r models.ReaderRepository) (int, error) {
return sceneCount + imageCount + galleryCount, nil return sceneCount + imageCount + galleryCount, nil
} }
func (t *autoTagFilesTask) processScenes(ctx context.Context, r models.ReaderRepository) error { func (t *autoTagFilesTask) processScenes(ctx context.Context, r models.Repository) error {
if job.IsCancelled(ctx) { if job.IsCancelled(ctx) {
return nil return nil
} }
@@ -477,7 +483,7 @@ func (t *autoTagFilesTask) processScenes(ctx context.Context, r models.ReaderRep
more := true more := true
for more { for more {
scenes, err := scene.Query(r.Scene(), sceneFilter, findFilter) scenes, err := scene.Query(ctx, r.Scene, sceneFilter, findFilter)
if err != nil { if err != nil {
return err return err
} }
@@ -518,7 +524,7 @@ func (t *autoTagFilesTask) processScenes(ctx context.Context, r models.ReaderRep
return nil return nil
} }
func (t *autoTagFilesTask) processImages(ctx context.Context, r models.ReaderRepository) error { func (t *autoTagFilesTask) processImages(ctx context.Context, r models.Repository) error {
if job.IsCancelled(ctx) { if job.IsCancelled(ctx) {
return nil return nil
} }
@@ -530,7 +536,7 @@ func (t *autoTagFilesTask) processImages(ctx context.Context, r models.ReaderRep
more := true more := true
for more { for more {
images, err := image.Query(r.Image(), imageFilter, findFilter) images, err := image.Query(ctx, r.Image, imageFilter, findFilter)
if err != nil { if err != nil {
return err return err
} }
@@ -571,7 +577,7 @@ func (t *autoTagFilesTask) processImages(ctx context.Context, r models.ReaderRep
return nil return nil
} }
func (t *autoTagFilesTask) processGalleries(ctx context.Context, r models.ReaderRepository) error { func (t *autoTagFilesTask) processGalleries(ctx context.Context, r models.Repository) error {
if job.IsCancelled(ctx) { if job.IsCancelled(ctx) {
return nil return nil
} }
@@ -583,7 +589,7 @@ func (t *autoTagFilesTask) processGalleries(ctx context.Context, r models.Reader
more := true more := true
for more { for more {
galleries, _, err := r.Gallery().Query(galleryFilter, findFilter) galleries, _, err := r.Gallery.Query(ctx, galleryFilter, findFilter)
if err != nil { if err != nil {
return err return err
} }
@@ -625,8 +631,9 @@ func (t *autoTagFilesTask) processGalleries(ctx context.Context, r models.Reader
} }
func (t *autoTagFilesTask) process(ctx context.Context) { func (t *autoTagFilesTask) process(ctx context.Context) {
if err := t.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { r := t.txnManager
total, err := t.getCount(r) if err := r.WithTxn(ctx, func(ctx context.Context) error {
total, err := t.getCount(ctx, t.txnManager)
if err != nil { if err != nil {
return err return err
} }
@@ -661,7 +668,7 @@ func (t *autoTagFilesTask) process(ctx context.Context) {
} }
type autoTagSceneTask struct { type autoTagSceneTask struct {
txnManager models.TransactionManager txnManager models.Repository
scene *models.Scene scene *models.Scene
performers bool performers bool
@@ -673,19 +680,20 @@ type autoTagSceneTask struct {
func (t *autoTagSceneTask) Start(ctx context.Context, wg *sync.WaitGroup) { func (t *autoTagSceneTask) Start(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done() defer wg.Done()
if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { r := t.txnManager
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
if t.performers { if t.performers {
if err := autotag.ScenePerformers(t.scene, r.Scene(), r.Performer(), t.cache); err != nil { if err := autotag.ScenePerformers(ctx, t.scene, r.Scene, r.Performer, t.cache); err != nil {
return fmt.Errorf("error tagging scene performers for %s: %v", t.scene.Path, err) return fmt.Errorf("error tagging scene performers for %s: %v", t.scene.Path, err)
} }
} }
if t.studios { if t.studios {
if err := autotag.SceneStudios(t.scene, r.Scene(), r.Studio(), t.cache); err != nil { if err := autotag.SceneStudios(ctx, t.scene, r.Scene, r.Studio, t.cache); err != nil {
return fmt.Errorf("error tagging scene studio for %s: %v", t.scene.Path, err) return fmt.Errorf("error tagging scene studio for %s: %v", t.scene.Path, err)
} }
} }
if t.tags { if t.tags {
if err := autotag.SceneTags(t.scene, r.Scene(), r.Tag(), t.cache); err != nil { if err := autotag.SceneTags(ctx, t.scene, r.Scene, r.Tag, t.cache); err != nil {
return fmt.Errorf("error tagging scene tags for %s: %v", t.scene.Path, err) return fmt.Errorf("error tagging scene tags for %s: %v", t.scene.Path, err)
} }
} }
@@ -697,7 +705,7 @@ func (t *autoTagSceneTask) Start(ctx context.Context, wg *sync.WaitGroup) {
} }
type autoTagImageTask struct { type autoTagImageTask struct {
txnManager models.TransactionManager txnManager models.Repository
image *models.Image image *models.Image
performers bool performers bool
@@ -709,19 +717,20 @@ type autoTagImageTask struct {
func (t *autoTagImageTask) Start(ctx context.Context, wg *sync.WaitGroup) { func (t *autoTagImageTask) Start(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done() defer wg.Done()
if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { r := t.txnManager
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
if t.performers { if t.performers {
if err := autotag.ImagePerformers(t.image, r.Image(), r.Performer(), t.cache); err != nil { if err := autotag.ImagePerformers(ctx, t.image, r.Image, r.Performer, t.cache); err != nil {
return fmt.Errorf("error tagging image performers for %s: %v", t.image.Path, err) return fmt.Errorf("error tagging image performers for %s: %v", t.image.Path, err)
} }
} }
if t.studios { if t.studios {
if err := autotag.ImageStudios(t.image, r.Image(), r.Studio(), t.cache); err != nil { if err := autotag.ImageStudios(ctx, t.image, r.Image, r.Studio, t.cache); err != nil {
return fmt.Errorf("error tagging image studio for %s: %v", t.image.Path, err) return fmt.Errorf("error tagging image studio for %s: %v", t.image.Path, err)
} }
} }
if t.tags { if t.tags {
if err := autotag.ImageTags(t.image, r.Image(), r.Tag(), t.cache); err != nil { if err := autotag.ImageTags(ctx, t.image, r.Image, r.Tag, t.cache); err != nil {
return fmt.Errorf("error tagging image tags for %s: %v", t.image.Path, err) return fmt.Errorf("error tagging image tags for %s: %v", t.image.Path, err)
} }
} }
@@ -733,7 +742,7 @@ func (t *autoTagImageTask) Start(ctx context.Context, wg *sync.WaitGroup) {
} }
type autoTagGalleryTask struct { type autoTagGalleryTask struct {
txnManager models.TransactionManager txnManager models.Repository
gallery *models.Gallery gallery *models.Gallery
performers bool performers bool
@@ -745,19 +754,20 @@ type autoTagGalleryTask struct {
func (t *autoTagGalleryTask) Start(ctx context.Context, wg *sync.WaitGroup) { func (t *autoTagGalleryTask) Start(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done() defer wg.Done()
if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { r := t.txnManager
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
if t.performers { if t.performers {
if err := autotag.GalleryPerformers(t.gallery, r.Gallery(), r.Performer(), t.cache); err != nil { if err := autotag.GalleryPerformers(ctx, t.gallery, r.Gallery, r.Performer, t.cache); err != nil {
return fmt.Errorf("error tagging gallery performers for %s: %v", t.gallery.Path.String, err) return fmt.Errorf("error tagging gallery performers for %s: %v", t.gallery.Path.String, err)
} }
} }
if t.studios { if t.studios {
if err := autotag.GalleryStudios(t.gallery, r.Gallery(), r.Studio(), t.cache); err != nil { if err := autotag.GalleryStudios(ctx, t.gallery, r.Gallery, r.Studio, t.cache); err != nil {
return fmt.Errorf("error tagging gallery studio for %s: %v", t.gallery.Path.String, err) return fmt.Errorf("error tagging gallery studio for %s: %v", t.gallery.Path.String, err)
} }
} }
if t.tags { if t.tags {
if err := autotag.GalleryTags(t.gallery, r.Gallery(), r.Tag(), t.cache); err != nil { if err := autotag.GalleryTags(ctx, t.gallery, r.Gallery, r.Tag, t.cache); err != nil {
return fmt.Errorf("error tagging gallery tags for %s: %v", t.gallery.Path.String, err) return fmt.Errorf("error tagging gallery tags for %s: %v", t.gallery.Path.String, err)
} }
} }

View File

@@ -18,7 +18,7 @@ import (
) )
type cleanJob struct { type cleanJob struct {
txnManager models.TransactionManager txnManager models.Repository
input CleanMetadataInput input CleanMetadataInput
scanSubs *subscriptionManager scanSubs *subscriptionManager
} }
@@ -29,8 +29,10 @@ func (j *cleanJob) Execute(ctx context.Context, progress *job.Progress) {
logger.Infof("Running in Dry Mode") logger.Infof("Running in Dry Mode")
} }
if err := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { r := j.txnManager
total, err := j.getCount(r)
if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
total, err := j.getCount(ctx, r)
if err != nil { if err != nil {
return fmt.Errorf("error getting count: %w", err) return fmt.Errorf("error getting count: %w", err)
} }
@@ -41,13 +43,13 @@ func (j *cleanJob) Execute(ctx context.Context, progress *job.Progress) {
return nil return nil
} }
if err := j.processScenes(ctx, progress, r.Scene()); err != nil { if err := j.processScenes(ctx, progress, r.Scene); err != nil {
return fmt.Errorf("error cleaning scenes: %w", err) return fmt.Errorf("error cleaning scenes: %w", err)
} }
if err := j.processImages(ctx, progress, r.Image()); err != nil { if err := j.processImages(ctx, progress, r.Image); err != nil {
return fmt.Errorf("error cleaning images: %w", err) return fmt.Errorf("error cleaning images: %w", err)
} }
if err := j.processGalleries(ctx, progress, r.Gallery(), r.Image()); err != nil { if err := j.processGalleries(ctx, progress, r.Gallery, r.Image); err != nil {
return fmt.Errorf("error cleaning galleries: %w", err) return fmt.Errorf("error cleaning galleries: %w", err)
} }
@@ -66,9 +68,9 @@ func (j *cleanJob) Execute(ctx context.Context, progress *job.Progress) {
logger.Info("Finished Cleaning") logger.Info("Finished Cleaning")
} }
func (j *cleanJob) getCount(r models.ReaderRepository) (int, error) { func (j *cleanJob) getCount(ctx context.Context, r models.Repository) (int, error) {
sceneFilter := scene.PathsFilter(j.input.Paths) sceneFilter := scene.PathsFilter(j.input.Paths)
sceneResult, err := r.Scene().Query(models.SceneQueryOptions{ sceneResult, err := r.Scene.Query(ctx, models.SceneQueryOptions{
QueryOptions: models.QueryOptions{ QueryOptions: models.QueryOptions{
Count: true, Count: true,
}, },
@@ -78,12 +80,12 @@ func (j *cleanJob) getCount(r models.ReaderRepository) (int, error) {
return 0, err return 0, err
} }
imageCount, err := r.Image().QueryCount(image.PathsFilter(j.input.Paths), nil) imageCount, err := r.Image.QueryCount(ctx, image.PathsFilter(j.input.Paths), nil)
if err != nil { if err != nil {
return 0, err return 0, err
} }
galleryCount, err := r.Gallery().QueryCount(gallery.PathsFilter(j.input.Paths), nil) galleryCount, err := r.Gallery.QueryCount(ctx, gallery.PathsFilter(j.input.Paths), nil)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -91,7 +93,7 @@ func (j *cleanJob) getCount(r models.ReaderRepository) (int, error) {
return sceneResult.Count + imageCount + galleryCount, nil return sceneResult.Count + imageCount + galleryCount, nil
} }
func (j *cleanJob) processScenes(ctx context.Context, progress *job.Progress, qb models.SceneReader) error { func (j *cleanJob) processScenes(ctx context.Context, progress *job.Progress, qb scene.Queryer) error {
batchSize := 1000 batchSize := 1000
findFilter := models.BatchFindFilter(batchSize) findFilter := models.BatchFindFilter(batchSize)
@@ -107,7 +109,7 @@ func (j *cleanJob) processScenes(ctx context.Context, progress *job.Progress, qb
return nil return nil
} }
scenes, err := scene.Query(qb, sceneFilter, findFilter) scenes, err := scene.Query(ctx, qb, sceneFilter, findFilter)
if err != nil { if err != nil {
return fmt.Errorf("error querying for scenes: %w", err) return fmt.Errorf("error querying for scenes: %w", err)
} }
@@ -154,7 +156,7 @@ func (j *cleanJob) processScenes(ctx context.Context, progress *job.Progress, qb
return nil return nil
} }
func (j *cleanJob) processGalleries(ctx context.Context, progress *job.Progress, qb models.GalleryReader, iqb models.ImageReader) error { func (j *cleanJob) processGalleries(ctx context.Context, progress *job.Progress, qb gallery.Queryer, iqb models.ImageReader) error {
batchSize := 1000 batchSize := 1000
findFilter := models.BatchFindFilter(batchSize) findFilter := models.BatchFindFilter(batchSize)
@@ -170,14 +172,14 @@ func (j *cleanJob) processGalleries(ctx context.Context, progress *job.Progress,
return nil return nil
} }
galleries, _, err := qb.Query(galleryFilter, findFilter) galleries, _, err := qb.Query(ctx, galleryFilter, findFilter)
if err != nil { if err != nil {
return fmt.Errorf("error querying for galleries: %w", err) return fmt.Errorf("error querying for galleries: %w", err)
} }
for _, gallery := range galleries { for _, gallery := range galleries {
progress.ExecuteTask(fmt.Sprintf("Assessing gallery %s for clean", gallery.GetTitle()), func() { progress.ExecuteTask(fmt.Sprintf("Assessing gallery %s for clean", gallery.GetTitle()), func() {
if j.shouldCleanGallery(gallery, iqb) { if j.shouldCleanGallery(ctx, gallery, iqb) {
toDelete = append(toDelete, gallery.ID) toDelete = append(toDelete, gallery.ID)
} else { } else {
// increment progress, no further processing // increment progress, no further processing
@@ -215,7 +217,7 @@ func (j *cleanJob) processGalleries(ctx context.Context, progress *job.Progress,
return nil return nil
} }
func (j *cleanJob) processImages(ctx context.Context, progress *job.Progress, qb models.ImageReader) error { func (j *cleanJob) processImages(ctx context.Context, progress *job.Progress, qb image.Queryer) error {
batchSize := 1000 batchSize := 1000
findFilter := models.BatchFindFilter(batchSize) findFilter := models.BatchFindFilter(batchSize)
@@ -234,7 +236,7 @@ func (j *cleanJob) processImages(ctx context.Context, progress *job.Progress, qb
return nil return nil
} }
images, err := image.Query(qb, imageFilter, findFilter) images, err := image.Query(ctx, qb, imageFilter, findFilter)
if err != nil { if err != nil {
return fmt.Errorf("error querying for images: %w", err) return fmt.Errorf("error querying for images: %w", err)
} }
@@ -318,7 +320,7 @@ func (j *cleanJob) shouldCleanScene(s *models.Scene) bool {
return false return false
} }
func (j *cleanJob) shouldCleanGallery(g *models.Gallery, qb models.ImageReader) bool { func (j *cleanJob) shouldCleanGallery(ctx context.Context, g *models.Gallery, qb models.ImageReader) bool {
// never clean manually created galleries // never clean manually created galleries
if !g.Path.Valid { if !g.Path.Valid {
return false return false
@@ -348,7 +350,7 @@ func (j *cleanJob) shouldCleanGallery(g *models.Gallery, qb models.ImageReader)
} }
} else { } else {
// folder-based - delete if it has no images // folder-based - delete if it has no images
count, err := qb.CountByGalleryID(g.ID) count, err := qb.CountByGalleryID(ctx, g.ID)
if err != nil { if err != nil {
logger.Warnf("Error trying to count gallery images for %q: %v", path, err) logger.Warnf("Error trying to count gallery images for %q: %v", path, err)
return false return false
@@ -401,16 +403,17 @@ func (j *cleanJob) deleteScene(ctx context.Context, fileNamingAlgorithm models.H
Paths: GetInstance().Paths, Paths: GetInstance().Paths,
} }
var s *models.Scene var s *models.Scene
if err := j.txnManager.WithTxn(ctx, func(repo models.Repository) error { if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
qb := repo.Scene() repo := j.txnManager
qb := repo.Scene
var err error var err error
s, err = qb.Find(sceneID) s, err = qb.Find(ctx, sceneID)
if err != nil { if err != nil {
return err return err
} }
return scene.Destroy(s, repo, fileDeleter, true, false) return scene.Destroy(ctx, s, repo.Scene, repo.SceneMarker, fileDeleter, true, false)
}); err != nil { }); err != nil {
fileDeleter.Rollback() fileDeleter.Rollback()
@@ -431,16 +434,16 @@ func (j *cleanJob) deleteScene(ctx context.Context, fileNamingAlgorithm models.H
func (j *cleanJob) deleteGallery(ctx context.Context, galleryID int) { func (j *cleanJob) deleteGallery(ctx context.Context, galleryID int) {
var g *models.Gallery var g *models.Gallery
if err := j.txnManager.WithTxn(ctx, func(repo models.Repository) error { if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
qb := repo.Gallery() qb := j.txnManager.Gallery
var err error var err error
g, err = qb.Find(galleryID) g, err = qb.Find(ctx, galleryID)
if err != nil { if err != nil {
return err return err
} }
return qb.Destroy(galleryID) return qb.Destroy(ctx, galleryID)
}); err != nil { }); err != nil {
logger.Errorf("Error deleting gallery from database: %s", err.Error()) logger.Errorf("Error deleting gallery from database: %s", err.Error())
return return
@@ -459,11 +462,11 @@ func (j *cleanJob) deleteImage(ctx context.Context, imageID int) {
} }
var i *models.Image var i *models.Image
if err := j.txnManager.WithTxn(ctx, func(repo models.Repository) error { if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
qb := repo.Image() qb := j.txnManager.Image
var err error var err error
i, err = qb.Find(imageID) i, err = qb.Find(ctx, imageID)
if err != nil { if err != nil {
return err return err
} }
@@ -472,7 +475,7 @@ func (j *cleanJob) deleteImage(ctx context.Context, imageID int) {
return fmt.Errorf("image not found: %d", imageID) return fmt.Errorf("image not found: %d", imageID)
} }
return image.Destroy(i, qb, fileDeleter, true, false) return image.Destroy(ctx, i, qb, fileDeleter, true, false)
}); err != nil { }); err != nil {
fileDeleter.Rollback() fileDeleter.Rollback()

View File

@@ -32,7 +32,7 @@ import (
) )
type ExportTask struct { type ExportTask struct {
txnManager models.TransactionManager txnManager models.Repository
full bool full bool
baseDir string baseDir string
@@ -100,7 +100,7 @@ func CreateExportTask(a models.HashAlgorithm, input ExportObjectsInput) *ExportT
} }
return &ExportTask{ return &ExportTask{
txnManager: GetInstance().TxnManager, txnManager: GetInstance().Repository,
fileNamingAlgorithm: a, fileNamingAlgorithm: a,
scenes: newExportSpec(input.Scenes), scenes: newExportSpec(input.Scenes),
images: newExportSpec(input.Images), images: newExportSpec(input.Images),
@@ -146,30 +146,32 @@ func (t *ExportTask) Start(ctx context.Context, wg *sync.WaitGroup) {
paths.EnsureJSONDirs(t.baseDir) paths.EnsureJSONDirs(t.baseDir)
txnErr := t.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { txnErr := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
r := t.txnManager
// include movie scenes and gallery images // include movie scenes and gallery images
if !t.full { if !t.full {
// only include movie scenes if includeDependencies is also set // only include movie scenes if includeDependencies is also set
if !t.scenes.all && t.includeDependencies { if !t.scenes.all && t.includeDependencies {
t.populateMovieScenes(r) t.populateMovieScenes(ctx, r)
} }
// always export gallery images // always export gallery images
if !t.images.all { if !t.images.all {
t.populateGalleryImages(r) t.populateGalleryImages(ctx, r)
} }
} }
t.ExportScenes(workerCount, r) t.ExportScenes(ctx, workerCount, r)
t.ExportImages(workerCount, r) t.ExportImages(ctx, workerCount, r)
t.ExportGalleries(workerCount, r) t.ExportGalleries(ctx, workerCount, r)
t.ExportMovies(workerCount, r) t.ExportMovies(ctx, workerCount, r)
t.ExportPerformers(workerCount, r) t.ExportPerformers(ctx, workerCount, r)
t.ExportStudios(workerCount, r) t.ExportStudios(ctx, workerCount, r)
t.ExportTags(workerCount, r) t.ExportTags(ctx, workerCount, r)
if t.full { if t.full {
t.ExportScrapedItems(r) t.ExportScrapedItems(ctx, r)
} }
return nil return nil
@@ -284,17 +286,17 @@ func (t *ExportTask) zipFile(fn, outDir string, z *zip.Writer) error {
return nil return nil
} }
func (t *ExportTask) populateMovieScenes(repo models.ReaderRepository) { func (t *ExportTask) populateMovieScenes(ctx context.Context, repo models.Repository) {
reader := repo.Movie() reader := repo.Movie
sceneReader := repo.Scene() sceneReader := repo.Scene
var movies []*models.Movie var movies []*models.Movie
var err error var err error
all := t.full || (t.movies != nil && t.movies.all) all := t.full || (t.movies != nil && t.movies.all)
if all { if all {
movies, err = reader.All() movies, err = reader.All(ctx)
} else if t.movies != nil && len(t.movies.IDs) > 0 { } else if t.movies != nil && len(t.movies.IDs) > 0 {
movies, err = reader.FindMany(t.movies.IDs) movies, err = reader.FindMany(ctx, t.movies.IDs)
} }
if err != nil { if err != nil {
@@ -302,7 +304,7 @@ func (t *ExportTask) populateMovieScenes(repo models.ReaderRepository) {
} }
for _, m := range movies { for _, m := range movies {
scenes, err := sceneReader.FindByMovieID(m.ID) scenes, err := sceneReader.FindByMovieID(ctx, m.ID)
if err != nil { if err != nil {
logger.Errorf("[movies] <%s> failed to fetch scenes for movie: %s", m.Checksum, err.Error()) logger.Errorf("[movies] <%s> failed to fetch scenes for movie: %s", m.Checksum, err.Error())
continue continue
@@ -314,17 +316,17 @@ func (t *ExportTask) populateMovieScenes(repo models.ReaderRepository) {
} }
} }
func (t *ExportTask) populateGalleryImages(repo models.ReaderRepository) { func (t *ExportTask) populateGalleryImages(ctx context.Context, repo models.Repository) {
reader := repo.Gallery() reader := repo.Gallery
imageReader := repo.Image() imageReader := repo.Image
var galleries []*models.Gallery var galleries []*models.Gallery
var err error var err error
all := t.full || (t.galleries != nil && t.galleries.all) all := t.full || (t.galleries != nil && t.galleries.all)
if all { if all {
galleries, err = reader.All() galleries, err = reader.All(ctx)
} else if t.galleries != nil && len(t.galleries.IDs) > 0 { } else if t.galleries != nil && len(t.galleries.IDs) > 0 {
galleries, err = reader.FindMany(t.galleries.IDs) galleries, err = reader.FindMany(ctx, t.galleries.IDs)
} }
if err != nil { if err != nil {
@@ -332,7 +334,7 @@ func (t *ExportTask) populateGalleryImages(repo models.ReaderRepository) {
} }
for _, g := range galleries { for _, g := range galleries {
images, err := imageReader.FindByGalleryID(g.ID) images, err := imageReader.FindByGalleryID(ctx, g.ID)
if err != nil { if err != nil {
logger.Errorf("[galleries] <%s> failed to fetch images for gallery: %s", g.Checksum, err.Error()) logger.Errorf("[galleries] <%s> failed to fetch images for gallery: %s", g.Checksum, err.Error())
continue continue
@@ -344,18 +346,18 @@ func (t *ExportTask) populateGalleryImages(repo models.ReaderRepository) {
} }
} }
func (t *ExportTask) ExportScenes(workers int, repo models.ReaderRepository) { func (t *ExportTask) ExportScenes(ctx context.Context, workers int, repo models.Repository) {
var scenesWg sync.WaitGroup var scenesWg sync.WaitGroup
sceneReader := repo.Scene() sceneReader := repo.Scene
var scenes []*models.Scene var scenes []*models.Scene
var err error var err error
all := t.full || (t.scenes != nil && t.scenes.all) all := t.full || (t.scenes != nil && t.scenes.all)
if all { if all {
scenes, err = sceneReader.All() scenes, err = sceneReader.All(ctx)
} else if t.scenes != nil && len(t.scenes.IDs) > 0 { } else if t.scenes != nil && len(t.scenes.IDs) > 0 {
scenes, err = sceneReader.FindMany(t.scenes.IDs) scenes, err = sceneReader.FindMany(ctx, t.scenes.IDs)
} }
if err != nil { if err != nil {
@@ -369,7 +371,7 @@ func (t *ExportTask) ExportScenes(workers int, repo models.ReaderRepository) {
for w := 0; w < workers; w++ { // create export Scene workers for w := 0; w < workers; w++ { // create export Scene workers
scenesWg.Add(1) scenesWg.Add(1)
go exportScene(&scenesWg, jobCh, repo, t) go exportScene(ctx, &scenesWg, jobCh, repo, t)
} }
for i, scene := range scenes { for i, scene := range scenes {
@@ -388,32 +390,32 @@ func (t *ExportTask) ExportScenes(workers int, repo models.ReaderRepository) {
logger.Infof("[scenes] export complete in %s. %d workers used.", time.Since(startTime), workers) logger.Infof("[scenes] export complete in %s. %d workers used.", time.Since(startTime), workers)
} }
func exportScene(wg *sync.WaitGroup, jobChan <-chan *models.Scene, repo models.ReaderRepository, t *ExportTask) { func exportScene(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Scene, repo models.Repository, t *ExportTask) {
defer wg.Done() defer wg.Done()
sceneReader := repo.Scene() sceneReader := repo.Scene
studioReader := repo.Studio() studioReader := repo.Studio
movieReader := repo.Movie() movieReader := repo.Movie
galleryReader := repo.Gallery() galleryReader := repo.Gallery
performerReader := repo.Performer() performerReader := repo.Performer
tagReader := repo.Tag() tagReader := repo.Tag
sceneMarkerReader := repo.SceneMarker() sceneMarkerReader := repo.SceneMarker
for s := range jobChan { for s := range jobChan {
sceneHash := s.GetHash(t.fileNamingAlgorithm) sceneHash := s.GetHash(t.fileNamingAlgorithm)
newSceneJSON, err := scene.ToBasicJSON(sceneReader, s) newSceneJSON, err := scene.ToBasicJSON(ctx, sceneReader, s)
if err != nil { if err != nil {
logger.Errorf("[scenes] <%s> error getting scene JSON: %s", sceneHash, err.Error()) logger.Errorf("[scenes] <%s> error getting scene JSON: %s", sceneHash, err.Error())
continue continue
} }
newSceneJSON.Studio, err = scene.GetStudioName(studioReader, s) newSceneJSON.Studio, err = scene.GetStudioName(ctx, studioReader, s)
if err != nil { if err != nil {
logger.Errorf("[scenes] <%s> error getting scene studio name: %s", sceneHash, err.Error()) logger.Errorf("[scenes] <%s> error getting scene studio name: %s", sceneHash, err.Error())
continue continue
} }
galleries, err := galleryReader.FindBySceneID(s.ID) galleries, err := galleryReader.FindBySceneID(ctx, s.ID)
if err != nil { if err != nil {
logger.Errorf("[scenes] <%s> error getting scene gallery checksums: %s", sceneHash, err.Error()) logger.Errorf("[scenes] <%s> error getting scene gallery checksums: %s", sceneHash, err.Error())
continue continue
@@ -421,7 +423,7 @@ func exportScene(wg *sync.WaitGroup, jobChan <-chan *models.Scene, repo models.R
newSceneJSON.Galleries = gallery.GetChecksums(galleries) newSceneJSON.Galleries = gallery.GetChecksums(galleries)
performers, err := performerReader.FindBySceneID(s.ID) performers, err := performerReader.FindBySceneID(ctx, s.ID)
if err != nil { if err != nil {
logger.Errorf("[scenes] <%s> error getting scene performer names: %s", sceneHash, err.Error()) logger.Errorf("[scenes] <%s> error getting scene performer names: %s", sceneHash, err.Error())
continue continue
@@ -429,19 +431,19 @@ func exportScene(wg *sync.WaitGroup, jobChan <-chan *models.Scene, repo models.R
newSceneJSON.Performers = performer.GetNames(performers) newSceneJSON.Performers = performer.GetNames(performers)
newSceneJSON.Tags, err = scene.GetTagNames(tagReader, s) newSceneJSON.Tags, err = scene.GetTagNames(ctx, tagReader, s)
if err != nil { if err != nil {
logger.Errorf("[scenes] <%s> error getting scene tag names: %s", sceneHash, err.Error()) logger.Errorf("[scenes] <%s> error getting scene tag names: %s", sceneHash, err.Error())
continue continue
} }
newSceneJSON.Markers, err = scene.GetSceneMarkersJSON(sceneMarkerReader, tagReader, s) newSceneJSON.Markers, err = scene.GetSceneMarkersJSON(ctx, sceneMarkerReader, tagReader, s)
if err != nil { if err != nil {
logger.Errorf("[scenes] <%s> error getting scene markers JSON: %s", sceneHash, err.Error()) logger.Errorf("[scenes] <%s> error getting scene markers JSON: %s", sceneHash, err.Error())
continue continue
} }
newSceneJSON.Movies, err = scene.GetSceneMoviesJSON(movieReader, sceneReader, s) newSceneJSON.Movies, err = scene.GetSceneMoviesJSON(ctx, movieReader, sceneReader, s)
if err != nil { if err != nil {
logger.Errorf("[scenes] <%s> error getting scene movies JSON: %s", sceneHash, err.Error()) logger.Errorf("[scenes] <%s> error getting scene movies JSON: %s", sceneHash, err.Error())
continue continue
@@ -454,14 +456,14 @@ func exportScene(wg *sync.WaitGroup, jobChan <-chan *models.Scene, repo models.R
t.galleries.IDs = intslice.IntAppendUniques(t.galleries.IDs, gallery.GetIDs(galleries)) t.galleries.IDs = intslice.IntAppendUniques(t.galleries.IDs, gallery.GetIDs(galleries))
tagIDs, err := scene.GetDependentTagIDs(tagReader, sceneMarkerReader, s) tagIDs, err := scene.GetDependentTagIDs(ctx, tagReader, sceneMarkerReader, s)
if err != nil { if err != nil {
logger.Errorf("[scenes] <%s> error getting scene tags: %s", sceneHash, err.Error()) logger.Errorf("[scenes] <%s> error getting scene tags: %s", sceneHash, err.Error())
continue continue
} }
t.tags.IDs = intslice.IntAppendUniques(t.tags.IDs, tagIDs) t.tags.IDs = intslice.IntAppendUniques(t.tags.IDs, tagIDs)
movieIDs, err := scene.GetDependentMovieIDs(sceneReader, s) movieIDs, err := scene.GetDependentMovieIDs(ctx, sceneReader, s)
if err != nil { if err != nil {
logger.Errorf("[scenes] <%s> error getting scene movies: %s", sceneHash, err.Error()) logger.Errorf("[scenes] <%s> error getting scene movies: %s", sceneHash, err.Error())
continue continue
@@ -482,18 +484,18 @@ func exportScene(wg *sync.WaitGroup, jobChan <-chan *models.Scene, repo models.R
} }
} }
func (t *ExportTask) ExportImages(workers int, repo models.ReaderRepository) { func (t *ExportTask) ExportImages(ctx context.Context, workers int, repo models.Repository) {
var imagesWg sync.WaitGroup var imagesWg sync.WaitGroup
imageReader := repo.Image() imageReader := repo.Image
var images []*models.Image var images []*models.Image
var err error var err error
all := t.full || (t.images != nil && t.images.all) all := t.full || (t.images != nil && t.images.all)
if all { if all {
images, err = imageReader.All() images, err = imageReader.All(ctx)
} else if t.images != nil && len(t.images.IDs) > 0 { } else if t.images != nil && len(t.images.IDs) > 0 {
images, err = imageReader.FindMany(t.images.IDs) images, err = imageReader.FindMany(ctx, t.images.IDs)
} }
if err != nil { if err != nil {
@@ -507,7 +509,7 @@ func (t *ExportTask) ExportImages(workers int, repo models.ReaderRepository) {
for w := 0; w < workers; w++ { // create export Image workers for w := 0; w < workers; w++ { // create export Image workers
imagesWg.Add(1) imagesWg.Add(1)
go exportImage(&imagesWg, jobCh, repo, t) go exportImage(ctx, &imagesWg, jobCh, repo, t)
} }
for i, image := range images { for i, image := range images {
@@ -526,12 +528,12 @@ func (t *ExportTask) ExportImages(workers int, repo models.ReaderRepository) {
logger.Infof("[images] export complete in %s. %d workers used.", time.Since(startTime), workers) logger.Infof("[images] export complete in %s. %d workers used.", time.Since(startTime), workers)
} }
func exportImage(wg *sync.WaitGroup, jobChan <-chan *models.Image, repo models.ReaderRepository, t *ExportTask) { func exportImage(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Image, repo models.Repository, t *ExportTask) {
defer wg.Done() defer wg.Done()
studioReader := repo.Studio() studioReader := repo.Studio
galleryReader := repo.Gallery() galleryReader := repo.Gallery
performerReader := repo.Performer() performerReader := repo.Performer
tagReader := repo.Tag() tagReader := repo.Tag
for s := range jobChan { for s := range jobChan {
imageHash := s.Checksum imageHash := s.Checksum
@@ -539,13 +541,13 @@ func exportImage(wg *sync.WaitGroup, jobChan <-chan *models.Image, repo models.R
newImageJSON := image.ToBasicJSON(s) newImageJSON := image.ToBasicJSON(s)
var err error var err error
newImageJSON.Studio, err = image.GetStudioName(studioReader, s) newImageJSON.Studio, err = image.GetStudioName(ctx, studioReader, s)
if err != nil { if err != nil {
logger.Errorf("[images] <%s> error getting image studio name: %s", imageHash, err.Error()) logger.Errorf("[images] <%s> error getting image studio name: %s", imageHash, err.Error())
continue continue
} }
imageGalleries, err := galleryReader.FindByImageID(s.ID) imageGalleries, err := galleryReader.FindByImageID(ctx, s.ID)
if err != nil { if err != nil {
logger.Errorf("[images] <%s> error getting image galleries: %s", imageHash, err.Error()) logger.Errorf("[images] <%s> error getting image galleries: %s", imageHash, err.Error())
continue continue
@@ -553,7 +555,7 @@ func exportImage(wg *sync.WaitGroup, jobChan <-chan *models.Image, repo models.R
newImageJSON.Galleries = t.getGalleryChecksums(imageGalleries) newImageJSON.Galleries = t.getGalleryChecksums(imageGalleries)
performers, err := performerReader.FindByImageID(s.ID) performers, err := performerReader.FindByImageID(ctx, s.ID)
if err != nil { if err != nil {
logger.Errorf("[images] <%s> error getting image performer names: %s", imageHash, err.Error()) logger.Errorf("[images] <%s> error getting image performer names: %s", imageHash, err.Error())
continue continue
@@ -561,7 +563,7 @@ func exportImage(wg *sync.WaitGroup, jobChan <-chan *models.Image, repo models.R
newImageJSON.Performers = performer.GetNames(performers) newImageJSON.Performers = performer.GetNames(performers)
tags, err := tagReader.FindByImageID(s.ID) tags, err := tagReader.FindByImageID(ctx, s.ID)
if err != nil { if err != nil {
logger.Errorf("[images] <%s> error getting image tag names: %s", imageHash, err.Error()) logger.Errorf("[images] <%s> error getting image tag names: %s", imageHash, err.Error())
continue continue
@@ -597,18 +599,18 @@ func (t *ExportTask) getGalleryChecksums(galleries []*models.Gallery) (ret []str
return return
} }
func (t *ExportTask) ExportGalleries(workers int, repo models.ReaderRepository) { func (t *ExportTask) ExportGalleries(ctx context.Context, workers int, repo models.Repository) {
var galleriesWg sync.WaitGroup var galleriesWg sync.WaitGroup
reader := repo.Gallery() reader := repo.Gallery
var galleries []*models.Gallery var galleries []*models.Gallery
var err error var err error
all := t.full || (t.galleries != nil && t.galleries.all) all := t.full || (t.galleries != nil && t.galleries.all)
if all { if all {
galleries, err = reader.All() galleries, err = reader.All(ctx)
} else if t.galleries != nil && len(t.galleries.IDs) > 0 { } else if t.galleries != nil && len(t.galleries.IDs) > 0 {
galleries, err = reader.FindMany(t.galleries.IDs) galleries, err = reader.FindMany(ctx, t.galleries.IDs)
} }
if err != nil { if err != nil {
@@ -622,7 +624,7 @@ func (t *ExportTask) ExportGalleries(workers int, repo models.ReaderRepository)
for w := 0; w < workers; w++ { // create export Scene workers for w := 0; w < workers; w++ { // create export Scene workers
galleriesWg.Add(1) galleriesWg.Add(1)
go exportGallery(&galleriesWg, jobCh, repo, t) go exportGallery(ctx, &galleriesWg, jobCh, repo, t)
} }
for i, gallery := range galleries { for i, gallery := range galleries {
@@ -646,11 +648,11 @@ func (t *ExportTask) ExportGalleries(workers int, repo models.ReaderRepository)
logger.Infof("[galleries] export complete in %s. %d workers used.", time.Since(startTime), workers) logger.Infof("[galleries] export complete in %s. %d workers used.", time.Since(startTime), workers)
} }
func exportGallery(wg *sync.WaitGroup, jobChan <-chan *models.Gallery, repo models.ReaderRepository, t *ExportTask) { func exportGallery(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Gallery, repo models.Repository, t *ExportTask) {
defer wg.Done() defer wg.Done()
studioReader := repo.Studio() studioReader := repo.Studio
performerReader := repo.Performer() performerReader := repo.Performer
tagReader := repo.Tag() tagReader := repo.Tag
for g := range jobChan { for g := range jobChan {
galleryHash := g.Checksum galleryHash := g.Checksum
@@ -661,13 +663,13 @@ func exportGallery(wg *sync.WaitGroup, jobChan <-chan *models.Gallery, repo mode
continue continue
} }
newGalleryJSON.Studio, err = gallery.GetStudioName(studioReader, g) newGalleryJSON.Studio, err = gallery.GetStudioName(ctx, studioReader, g)
if err != nil { if err != nil {
logger.Errorf("[galleries] <%s> error getting gallery studio name: %s", galleryHash, err.Error()) logger.Errorf("[galleries] <%s> error getting gallery studio name: %s", galleryHash, err.Error())
continue continue
} }
performers, err := performerReader.FindByGalleryID(g.ID) performers, err := performerReader.FindByGalleryID(ctx, g.ID)
if err != nil { if err != nil {
logger.Errorf("[galleries] <%s> error getting gallery performer names: %s", galleryHash, err.Error()) logger.Errorf("[galleries] <%s> error getting gallery performer names: %s", galleryHash, err.Error())
continue continue
@@ -675,7 +677,7 @@ func exportGallery(wg *sync.WaitGroup, jobChan <-chan *models.Gallery, repo mode
newGalleryJSON.Performers = performer.GetNames(performers) newGalleryJSON.Performers = performer.GetNames(performers)
tags, err := tagReader.FindByGalleryID(g.ID) tags, err := tagReader.FindByGalleryID(ctx, g.ID)
if err != nil { if err != nil {
logger.Errorf("[galleries] <%s> error getting gallery tag names: %s", galleryHash, err.Error()) logger.Errorf("[galleries] <%s> error getting gallery tag names: %s", galleryHash, err.Error())
continue continue
@@ -703,17 +705,17 @@ func exportGallery(wg *sync.WaitGroup, jobChan <-chan *models.Gallery, repo mode
} }
} }
func (t *ExportTask) ExportPerformers(workers int, repo models.ReaderRepository) { func (t *ExportTask) ExportPerformers(ctx context.Context, workers int, repo models.Repository) {
var performersWg sync.WaitGroup var performersWg sync.WaitGroup
reader := repo.Performer() reader := repo.Performer
var performers []*models.Performer var performers []*models.Performer
var err error var err error
all := t.full || (t.performers != nil && t.performers.all) all := t.full || (t.performers != nil && t.performers.all)
if all { if all {
performers, err = reader.All() performers, err = reader.All(ctx)
} else if t.performers != nil && len(t.performers.IDs) > 0 { } else if t.performers != nil && len(t.performers.IDs) > 0 {
performers, err = reader.FindMany(t.performers.IDs) performers, err = reader.FindMany(ctx, t.performers.IDs)
} }
if err != nil { if err != nil {
@@ -726,7 +728,7 @@ func (t *ExportTask) ExportPerformers(workers int, repo models.ReaderRepository)
for w := 0; w < workers; w++ { // create export Performer workers for w := 0; w < workers; w++ { // create export Performer workers
performersWg.Add(1) performersWg.Add(1)
go t.exportPerformer(&performersWg, jobCh, repo) go t.exportPerformer(ctx, &performersWg, jobCh, repo)
} }
for i, performer := range performers { for i, performer := range performers {
@@ -743,20 +745,20 @@ func (t *ExportTask) ExportPerformers(workers int, repo models.ReaderRepository)
logger.Infof("[performers] export complete in %s. %d workers used.", time.Since(startTime), workers) logger.Infof("[performers] export complete in %s. %d workers used.", time.Since(startTime), workers)
} }
func (t *ExportTask) exportPerformer(wg *sync.WaitGroup, jobChan <-chan *models.Performer, repo models.ReaderRepository) { func (t *ExportTask) exportPerformer(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Performer, repo models.Repository) {
defer wg.Done() defer wg.Done()
performerReader := repo.Performer() performerReader := repo.Performer
for p := range jobChan { for p := range jobChan {
newPerformerJSON, err := performer.ToJSON(performerReader, p) newPerformerJSON, err := performer.ToJSON(ctx, performerReader, p)
if err != nil { if err != nil {
logger.Errorf("[performers] <%s> error getting performer JSON: %s", p.Checksum, err.Error()) logger.Errorf("[performers] <%s> error getting performer JSON: %s", p.Checksum, err.Error())
continue continue
} }
tags, err := repo.Tag().FindByPerformerID(p.ID) tags, err := repo.Tag.FindByPerformerID(ctx, p.ID)
if err != nil { if err != nil {
logger.Errorf("[performers] <%s> error getting performer tags: %s", p.Checksum, err.Error()) logger.Errorf("[performers] <%s> error getting performer tags: %s", p.Checksum, err.Error())
continue continue
@@ -781,17 +783,17 @@ func (t *ExportTask) exportPerformer(wg *sync.WaitGroup, jobChan <-chan *models.
} }
} }
func (t *ExportTask) ExportStudios(workers int, repo models.ReaderRepository) { func (t *ExportTask) ExportStudios(ctx context.Context, workers int, repo models.Repository) {
var studiosWg sync.WaitGroup var studiosWg sync.WaitGroup
reader := repo.Studio() reader := repo.Studio
var studios []*models.Studio var studios []*models.Studio
var err error var err error
all := t.full || (t.studios != nil && t.studios.all) all := t.full || (t.studios != nil && t.studios.all)
if all { if all {
studios, err = reader.All() studios, err = reader.All(ctx)
} else if t.studios != nil && len(t.studios.IDs) > 0 { } else if t.studios != nil && len(t.studios.IDs) > 0 {
studios, err = reader.FindMany(t.studios.IDs) studios, err = reader.FindMany(ctx, t.studios.IDs)
} }
if err != nil { if err != nil {
@@ -805,7 +807,7 @@ func (t *ExportTask) ExportStudios(workers int, repo models.ReaderRepository) {
for w := 0; w < workers; w++ { // create export Studio workers for w := 0; w < workers; w++ { // create export Studio workers
studiosWg.Add(1) studiosWg.Add(1)
go t.exportStudio(&studiosWg, jobCh, repo) go t.exportStudio(ctx, &studiosWg, jobCh, repo)
} }
for i, studio := range studios { for i, studio := range studios {
@@ -822,13 +824,13 @@ func (t *ExportTask) ExportStudios(workers int, repo models.ReaderRepository) {
logger.Infof("[studios] export complete in %s. %d workers used.", time.Since(startTime), workers) logger.Infof("[studios] export complete in %s. %d workers used.", time.Since(startTime), workers)
} }
func (t *ExportTask) exportStudio(wg *sync.WaitGroup, jobChan <-chan *models.Studio, repo models.ReaderRepository) { func (t *ExportTask) exportStudio(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Studio, repo models.Repository) {
defer wg.Done() defer wg.Done()
studioReader := repo.Studio() studioReader := repo.Studio
for s := range jobChan { for s := range jobChan {
newStudioJSON, err := studio.ToJSON(studioReader, s) newStudioJSON, err := studio.ToJSON(ctx, studioReader, s)
if err != nil { if err != nil {
logger.Errorf("[studios] <%s> error getting studio JSON: %s", s.Checksum, err.Error()) logger.Errorf("[studios] <%s> error getting studio JSON: %s", s.Checksum, err.Error())
@@ -846,17 +848,17 @@ func (t *ExportTask) exportStudio(wg *sync.WaitGroup, jobChan <-chan *models.Stu
} }
} }
func (t *ExportTask) ExportTags(workers int, repo models.ReaderRepository) { func (t *ExportTask) ExportTags(ctx context.Context, workers int, repo models.Repository) {
var tagsWg sync.WaitGroup var tagsWg sync.WaitGroup
reader := repo.Tag() reader := repo.Tag
var tags []*models.Tag var tags []*models.Tag
var err error var err error
all := t.full || (t.tags != nil && t.tags.all) all := t.full || (t.tags != nil && t.tags.all)
if all { if all {
tags, err = reader.All() tags, err = reader.All(ctx)
} else if t.tags != nil && len(t.tags.IDs) > 0 { } else if t.tags != nil && len(t.tags.IDs) > 0 {
tags, err = reader.FindMany(t.tags.IDs) tags, err = reader.FindMany(ctx, t.tags.IDs)
} }
if err != nil { if err != nil {
@@ -870,7 +872,7 @@ func (t *ExportTask) ExportTags(workers int, repo models.ReaderRepository) {
for w := 0; w < workers; w++ { // create export Tag workers for w := 0; w < workers; w++ { // create export Tag workers
tagsWg.Add(1) tagsWg.Add(1)
go t.exportTag(&tagsWg, jobCh, repo) go t.exportTag(ctx, &tagsWg, jobCh, repo)
} }
for i, tag := range tags { for i, tag := range tags {
@@ -890,13 +892,13 @@ func (t *ExportTask) ExportTags(workers int, repo models.ReaderRepository) {
logger.Infof("[tags] export complete in %s. %d workers used.", time.Since(startTime), workers) logger.Infof("[tags] export complete in %s. %d workers used.", time.Since(startTime), workers)
} }
func (t *ExportTask) exportTag(wg *sync.WaitGroup, jobChan <-chan *models.Tag, repo models.ReaderRepository) { func (t *ExportTask) exportTag(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Tag, repo models.Repository) {
defer wg.Done() defer wg.Done()
tagReader := repo.Tag() tagReader := repo.Tag
for thisTag := range jobChan { for thisTag := range jobChan {
newTagJSON, err := tag.ToJSON(tagReader, thisTag) newTagJSON, err := tag.ToJSON(ctx, tagReader, thisTag)
if err != nil { if err != nil {
logger.Errorf("[tags] <%s> error getting tag JSON: %s", thisTag.Name, err.Error()) logger.Errorf("[tags] <%s> error getting tag JSON: %s", thisTag.Name, err.Error())
@@ -917,17 +919,17 @@ func (t *ExportTask) exportTag(wg *sync.WaitGroup, jobChan <-chan *models.Tag, r
} }
} }
func (t *ExportTask) ExportMovies(workers int, repo models.ReaderRepository) { func (t *ExportTask) ExportMovies(ctx context.Context, workers int, repo models.Repository) {
var moviesWg sync.WaitGroup var moviesWg sync.WaitGroup
reader := repo.Movie() reader := repo.Movie
var movies []*models.Movie var movies []*models.Movie
var err error var err error
all := t.full || (t.movies != nil && t.movies.all) all := t.full || (t.movies != nil && t.movies.all)
if all { if all {
movies, err = reader.All() movies, err = reader.All(ctx)
} else if t.movies != nil && len(t.movies.IDs) > 0 { } else if t.movies != nil && len(t.movies.IDs) > 0 {
movies, err = reader.FindMany(t.movies.IDs) movies, err = reader.FindMany(ctx, t.movies.IDs)
} }
if err != nil { if err != nil {
@@ -941,7 +943,7 @@ func (t *ExportTask) ExportMovies(workers int, repo models.ReaderRepository) {
for w := 0; w < workers; w++ { // create export Studio workers for w := 0; w < workers; w++ { // create export Studio workers
moviesWg.Add(1) moviesWg.Add(1)
go t.exportMovie(&moviesWg, jobCh, repo) go t.exportMovie(ctx, &moviesWg, jobCh, repo)
} }
for i, movie := range movies { for i, movie := range movies {
@@ -958,14 +960,14 @@ func (t *ExportTask) ExportMovies(workers int, repo models.ReaderRepository) {
logger.Infof("[movies] export complete in %s. %d workers used.", time.Since(startTime), workers) logger.Infof("[movies] export complete in %s. %d workers used.", time.Since(startTime), workers)
} }
func (t *ExportTask) exportMovie(wg *sync.WaitGroup, jobChan <-chan *models.Movie, repo models.ReaderRepository) { func (t *ExportTask) exportMovie(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Movie, repo models.Repository) {
defer wg.Done() defer wg.Done()
movieReader := repo.Movie() movieReader := repo.Movie
studioReader := repo.Studio() studioReader := repo.Studio
for m := range jobChan { for m := range jobChan {
newMovieJSON, err := movie.ToJSON(movieReader, studioReader, m) newMovieJSON, err := movie.ToJSON(ctx, movieReader, studioReader, m)
if err != nil { if err != nil {
logger.Errorf("[movies] <%s> error getting tag JSON: %s", m.Checksum, err.Error()) logger.Errorf("[movies] <%s> error getting tag JSON: %s", m.Checksum, err.Error())
@@ -991,10 +993,10 @@ func (t *ExportTask) exportMovie(wg *sync.WaitGroup, jobChan <-chan *models.Movi
} }
} }
func (t *ExportTask) ExportScrapedItems(repo models.ReaderRepository) { func (t *ExportTask) ExportScrapedItems(ctx context.Context, repo models.Repository) {
qb := repo.ScrapedItem() qb := repo.ScrapedItem
sqb := repo.Studio() sqb := repo.Studio
scrapedItems, err := qb.All() scrapedItems, err := qb.All(ctx)
if err != nil { if err != nil {
logger.Errorf("[scraped sites] failed to fetch all items: %s", err.Error()) logger.Errorf("[scraped sites] failed to fetch all items: %s", err.Error())
} }
@@ -1009,7 +1011,7 @@ func (t *ExportTask) ExportScrapedItems(repo models.ReaderRepository) {
var studioName string var studioName string
if scrapedItem.StudioID.Valid { if scrapedItem.StudioID.Valid {
studio, _ := sqb.Find(int(scrapedItem.StudioID.Int64)) studio, _ := sqb.Find(ctx, int(scrapedItem.StudioID.Int64))
if studio != nil { if studio != nil {
studioName = studio.Name.String studioName = studio.Name.String
} }

View File

@@ -54,7 +54,7 @@ type GeneratePreviewOptionsInput struct {
const generateQueueSize = 200000 const generateQueueSize = 200000
type GenerateJob struct { type GenerateJob struct {
txnManager models.TransactionManager txnManager models.Repository
input GenerateMetadataInput input GenerateMetadataInput
overwrite bool overwrite bool
@@ -110,20 +110,20 @@ func (j *GenerateJob) Execute(ctx context.Context, progress *job.Progress) {
Overwrite: j.overwrite, Overwrite: j.overwrite,
} }
if err := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
qb := r.Scene() qb := j.txnManager.Scene
if len(j.input.SceneIDs) == 0 && len(j.input.MarkerIDs) == 0 { if len(j.input.SceneIDs) == 0 && len(j.input.MarkerIDs) == 0 {
totals = j.queueTasks(ctx, g, queue) totals = j.queueTasks(ctx, g, queue)
} else { } else {
if len(j.input.SceneIDs) > 0 { if len(j.input.SceneIDs) > 0 {
scenes, err = qb.FindMany(sceneIDs) scenes, err = qb.FindMany(ctx, sceneIDs)
for _, s := range scenes { for _, s := range scenes {
j.queueSceneJobs(ctx, g, s, queue, &totals) j.queueSceneJobs(ctx, g, s, queue, &totals)
} }
} }
if len(j.input.MarkerIDs) > 0 { if len(j.input.MarkerIDs) > 0 {
markers, err = r.SceneMarker().FindMany(markerIDs) markers, err = j.txnManager.SceneMarker.FindMany(ctx, markerIDs)
if err != nil { if err != nil {
return err return err
} }
@@ -192,13 +192,13 @@ func (j *GenerateJob) queueTasks(ctx context.Context, g *generate.Generator, que
findFilter := models.BatchFindFilter(batchSize) findFilter := models.BatchFindFilter(batchSize)
if err := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
for more := true; more; { for more := true; more; {
if job.IsCancelled(ctx) { if job.IsCancelled(ctx) {
return context.Canceled return context.Canceled
} }
scenes, err := scene.Query(r.Scene(), nil, findFilter) scenes, err := scene.Query(ctx, j.txnManager.Scene, nil, findFilter)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -15,7 +15,7 @@ type GenerateInteractiveHeatmapSpeedTask struct {
Scene models.Scene Scene models.Scene
Overwrite bool Overwrite bool
fileNamingAlgorithm models.HashAlgorithm fileNamingAlgorithm models.HashAlgorithm
TxnManager models.TransactionManager TxnManager models.Repository
} }
func (t *GenerateInteractiveHeatmapSpeedTask) GetDescription() string { func (t *GenerateInteractiveHeatmapSpeedTask) GetDescription() string {
@@ -47,22 +47,22 @@ func (t *GenerateInteractiveHeatmapSpeedTask) Start(ctx context.Context) {
var s *models.Scene var s *models.Scene
if err := t.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
var err error var err error
s, err = r.Scene().FindByPath(t.Scene.Path) s, err = t.TxnManager.Scene.FindByPath(ctx, t.Scene.Path)
return err return err
}); err != nil { }); err != nil {
logger.Error(err.Error()) logger.Error(err.Error())
return return
} }
if err := t.TxnManager.WithTxn(ctx, func(r models.Repository) error { if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
qb := r.Scene() qb := t.TxnManager.Scene
scenePartial := models.ScenePartial{ scenePartial := models.ScenePartial{
ID: s.ID, ID: s.ID,
InteractiveSpeed: &median, InteractiveSpeed: &median,
} }
_, err := qb.Update(scenePartial) _, err := qb.Update(ctx, scenePartial)
return err return err
}); err != nil { }); err != nil {
logger.Error(err.Error()) logger.Error(err.Error())

View File

@@ -13,7 +13,7 @@ import (
) )
type GenerateMarkersTask struct { type GenerateMarkersTask struct {
TxnManager models.TransactionManager TxnManager models.Repository
Scene *models.Scene Scene *models.Scene
Marker *models.SceneMarker Marker *models.SceneMarker
Overwrite bool Overwrite bool
@@ -42,9 +42,9 @@ func (t *GenerateMarkersTask) Start(ctx context.Context) {
if t.Marker != nil { if t.Marker != nil {
var scene *models.Scene var scene *models.Scene
if err := t.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
var err error var err error
scene, err = r.Scene().Find(int(t.Marker.SceneID.Int64)) scene, err = t.TxnManager.Scene.Find(ctx, int(t.Marker.SceneID.Int64))
return err return err
}); err != nil { }); err != nil {
logger.Errorf("error finding scene for marker: %s", err.Error()) logger.Errorf("error finding scene for marker: %s", err.Error())
@@ -69,9 +69,9 @@ func (t *GenerateMarkersTask) Start(ctx context.Context) {
func (t *GenerateMarkersTask) generateSceneMarkers(ctx context.Context) { func (t *GenerateMarkersTask) generateSceneMarkers(ctx context.Context) {
var sceneMarkers []*models.SceneMarker var sceneMarkers []*models.SceneMarker
if err := t.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
var err error var err error
sceneMarkers, err = r.SceneMarker().FindBySceneID(t.Scene.ID) sceneMarkers, err = t.TxnManager.SceneMarker.FindBySceneID(ctx, t.Scene.ID)
return err return err
}); err != nil { }); err != nil {
logger.Errorf("error getting scene markers: %s", err.Error()) logger.Errorf("error getting scene markers: %s", err.Error())
@@ -134,9 +134,9 @@ func (t *GenerateMarkersTask) generateMarker(videoFile *ffmpeg.VideoFile, scene
func (t *GenerateMarkersTask) markersNeeded(ctx context.Context) int { func (t *GenerateMarkersTask) markersNeeded(ctx context.Context) int {
markers := 0 markers := 0
var sceneMarkers []*models.SceneMarker var sceneMarkers []*models.SceneMarker
if err := t.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
var err error var err error
sceneMarkers, err = r.SceneMarker().FindBySceneID(t.Scene.ID) sceneMarkers, err = t.TxnManager.SceneMarker.FindBySceneID(ctx, t.Scene.ID)
return err return err
}); err != nil { }); err != nil {
logger.Errorf("errror finding scene markers: %s", err.Error()) logger.Errorf("errror finding scene markers: %s", err.Error())

View File

@@ -14,7 +14,7 @@ type GeneratePhashTask struct {
Scene models.Scene Scene models.Scene
Overwrite bool Overwrite bool
fileNamingAlgorithm models.HashAlgorithm fileNamingAlgorithm models.HashAlgorithm
txnManager models.TransactionManager txnManager models.Repository
} }
func (t *GeneratePhashTask) GetDescription() string { func (t *GeneratePhashTask) GetDescription() string {
@@ -40,14 +40,14 @@ func (t *GeneratePhashTask) Start(ctx context.Context) {
return return
} }
if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
qb := r.Scene() qb := t.txnManager.Scene
hashValue := sql.NullInt64{Int64: int64(*hash), Valid: true} hashValue := sql.NullInt64{Int64: int64(*hash), Valid: true}
scenePartial := models.ScenePartial{ scenePartial := models.ScenePartial{
ID: t.Scene.ID, ID: t.Scene.ID,
Phash: &hashValue, Phash: &hashValue,
} }
_, err := qb.Update(scenePartial) _, err := qb.Update(ctx, scenePartial)
return err return err
}); err != nil { }); err != nil {
logger.Error(err.Error()) logger.Error(err.Error())

View File

@@ -17,7 +17,7 @@ type GenerateScreenshotTask struct {
Scene models.Scene Scene models.Scene
ScreenshotAt *float64 ScreenshotAt *float64
fileNamingAlgorithm models.HashAlgorithm fileNamingAlgorithm models.HashAlgorithm
txnManager models.TransactionManager txnManager models.Repository
} }
func (t *GenerateScreenshotTask) Start(ctx context.Context) { func (t *GenerateScreenshotTask) Start(ctx context.Context) {
@@ -74,8 +74,8 @@ func (t *GenerateScreenshotTask) Start(ctx context.Context) {
return return
} }
if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
qb := r.Scene() qb := t.txnManager.Scene
updatedTime := time.Now() updatedTime := time.Now()
updatedScene := models.ScenePartial{ updatedScene := models.ScenePartial{
ID: t.Scene.ID, ID: t.Scene.ID,
@@ -87,12 +87,12 @@ func (t *GenerateScreenshotTask) Start(ctx context.Context) {
} }
// update the scene cover table // update the scene cover table
if err := qb.UpdateCover(t.Scene.ID, coverImageData); err != nil { if err := qb.UpdateCover(ctx, t.Scene.ID, coverImageData); err != nil {
return fmt.Errorf("error setting screenshot: %v", err) return fmt.Errorf("error setting screenshot: %v", err)
} }
// update the scene with the update date // update the scene with the update date
_, err = qb.Update(updatedScene) _, err = qb.Update(ctx, updatedScene)
if err != nil { if err != nil {
return fmt.Errorf("error updating scene: %v", err) return fmt.Errorf("error updating scene: %v", err)
} }

View File

@@ -14,12 +14,12 @@ import (
"github.com/stashapp/stash/pkg/scraper" "github.com/stashapp/stash/pkg/scraper"
"github.com/stashapp/stash/pkg/scraper/stashbox" "github.com/stashapp/stash/pkg/scraper/stashbox"
"github.com/stashapp/stash/pkg/sliceutil/stringslice" "github.com/stashapp/stash/pkg/sliceutil/stringslice"
"github.com/stashapp/stash/pkg/txn"
) )
var ErrInput = errors.New("invalid request input") var ErrInput = errors.New("invalid request input")
type IdentifyJob struct { type IdentifyJob struct {
txnManager models.TransactionManager
postHookExecutor identify.SceneUpdatePostHookExecutor postHookExecutor identify.SceneUpdatePostHookExecutor
input identify.Options input identify.Options
@@ -29,7 +29,6 @@ type IdentifyJob struct {
func CreateIdentifyJob(input identify.Options) *IdentifyJob { func CreateIdentifyJob(input identify.Options) *IdentifyJob {
return &IdentifyJob{ return &IdentifyJob{
txnManager: instance.TxnManager,
postHookExecutor: instance.PluginCache, postHookExecutor: instance.PluginCache,
input: input, input: input,
stashBoxes: instance.Config.GetStashBoxes(), stashBoxes: instance.Config.GetStashBoxes(),
@@ -52,9 +51,9 @@ func (j *IdentifyJob) Execute(ctx context.Context, progress *job.Progress) {
// if scene ids provided, use those // if scene ids provided, use those
// otherwise, batch query for all scenes - ordering by path // otherwise, batch query for all scenes - ordering by path
if err := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { if err := txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error {
if len(j.input.SceneIDs) == 0 { if len(j.input.SceneIDs) == 0 {
return j.identifyAllScenes(ctx, r, sources) return j.identifyAllScenes(ctx, sources)
} }
sceneIDs, err := stringslice.StringSliceToIntSlice(j.input.SceneIDs) sceneIDs, err := stringslice.StringSliceToIntSlice(j.input.SceneIDs)
@@ -70,7 +69,7 @@ func (j *IdentifyJob) Execute(ctx context.Context, progress *job.Progress) {
// find the scene // find the scene
var err error var err error
scene, err := r.Scene().Find(id) scene, err := instance.Repository.Scene.Find(ctx, id)
if err != nil { if err != nil {
return fmt.Errorf("error finding scene with id %d: %w", id, err) return fmt.Errorf("error finding scene with id %d: %w", id, err)
} }
@@ -88,7 +87,7 @@ func (j *IdentifyJob) Execute(ctx context.Context, progress *job.Progress) {
} }
} }
func (j *IdentifyJob) identifyAllScenes(ctx context.Context, r models.ReaderRepository, sources []identify.ScraperSource) error { func (j *IdentifyJob) identifyAllScenes(ctx context.Context, sources []identify.ScraperSource) error {
// exclude organised // exclude organised
organised := false organised := false
sceneFilter := scene.FilterFromPaths(j.input.Paths) sceneFilter := scene.FilterFromPaths(j.input.Paths)
@@ -102,7 +101,7 @@ func (j *IdentifyJob) identifyAllScenes(ctx context.Context, r models.ReaderRepo
// get the count // get the count
pp := 0 pp := 0
findFilter.PerPage = &pp findFilter.PerPage = &pp
countResult, err := r.Scene().Query(models.SceneQueryOptions{ countResult, err := instance.Repository.Scene.Query(ctx, models.SceneQueryOptions{
QueryOptions: models.QueryOptions{ QueryOptions: models.QueryOptions{
FindFilter: findFilter, FindFilter: findFilter,
Count: true, Count: true,
@@ -115,7 +114,7 @@ func (j *IdentifyJob) identifyAllScenes(ctx context.Context, r models.ReaderRepo
j.progress.SetTotal(countResult.Count) j.progress.SetTotal(countResult.Count)
return scene.BatchProcess(ctx, r.Scene(), sceneFilter, findFilter, func(scene *models.Scene) error { return scene.BatchProcess(ctx, instance.Repository.Scene, sceneFilter, findFilter, func(scene *models.Scene) error {
if job.IsCancelled(ctx) { if job.IsCancelled(ctx) {
return nil return nil
} }
@@ -133,6 +132,11 @@ func (j *IdentifyJob) identifyScene(ctx context.Context, s *models.Scene, source
var taskError error var taskError error
j.progress.ExecuteTask("Identifying "+s.Path, func() { j.progress.ExecuteTask("Identifying "+s.Path, func() {
task := identify.SceneIdentifier{ task := identify.SceneIdentifier{
SceneReaderUpdater: instance.Repository.Scene,
StudioCreator: instance.Repository.Studio,
PerformerCreator: instance.Repository.Performer,
TagCreator: instance.Repository.Tag,
DefaultOptions: j.input.Options, DefaultOptions: j.input.Options,
Sources: sources, Sources: sources,
ScreenshotSetter: &scene.PathsScreenshotSetter{ ScreenshotSetter: &scene.PathsScreenshotSetter{
@@ -142,7 +146,7 @@ func (j *IdentifyJob) identifyScene(ctx context.Context, s *models.Scene, source
SceneUpdatePostHookExecutor: j.postHookExecutor, SceneUpdatePostHookExecutor: j.postHookExecutor,
} }
taskError = task.Identify(ctx, j.txnManager, s) taskError = task.Identify(ctx, instance.Repository, s)
}) })
if taskError != nil { if taskError != nil {
@@ -166,7 +170,12 @@ func (j *IdentifyJob) getSources() ([]identify.ScraperSource, error) {
src = identify.ScraperSource{ src = identify.ScraperSource{
Name: "stash-box: " + stashBox.Endpoint, Name: "stash-box: " + stashBox.Endpoint,
Scraper: stashboxSource{ Scraper: stashboxSource{
stashbox.NewClient(*stashBox, j.txnManager), stashbox.NewClient(*stashBox, instance.Repository, stashbox.Repository{
Scene: instance.Repository.Scene,
Performer: instance.Repository.Performer,
Tag: instance.Repository.Tag,
Studio: instance.Repository.Studio,
}),
stashBox.Endpoint, stashBox.Endpoint,
}, },
RemoteSite: stashBox.Endpoint, RemoteSite: stashBox.Endpoint,

View File

@@ -12,8 +12,6 @@ import (
"time" "time"
"github.com/99designs/gqlgen/graphql" "github.com/99designs/gqlgen/graphql"
"github.com/stashapp/stash/internal/manager/config"
"github.com/stashapp/stash/pkg/database"
"github.com/stashapp/stash/pkg/fsutil" "github.com/stashapp/stash/pkg/fsutil"
"github.com/stashapp/stash/pkg/gallery" "github.com/stashapp/stash/pkg/gallery"
"github.com/stashapp/stash/pkg/image" "github.com/stashapp/stash/pkg/image"
@@ -30,7 +28,7 @@ import (
) )
type ImportTask struct { type ImportTask struct {
txnManager models.TransactionManager txnManager models.Repository
json jsonUtils json jsonUtils
BaseDir string BaseDir string
@@ -73,7 +71,7 @@ func CreateImportTask(a models.HashAlgorithm, input ImportObjectsInput) (*Import
} }
return &ImportTask{ return &ImportTask{
txnManager: GetInstance().TxnManager, txnManager: GetInstance().Repository,
BaseDir: baseDir, BaseDir: baseDir,
TmpZip: tmpZip, TmpZip: tmpZip,
Reset: false, Reset: false,
@@ -126,7 +124,7 @@ func (t *ImportTask) Start(ctx context.Context) {
t.scraped = scraped t.scraped = scraped
if t.Reset { if t.Reset {
err := database.Reset(config.GetInstance().GetDatabasePath()) err := t.txnManager.Reset()
if err != nil { if err != nil {
logger.Errorf("Error resetting database: %s", err.Error()) logger.Errorf("Error resetting database: %s", err.Error())
@@ -211,15 +209,16 @@ func (t *ImportTask) ImportPerformers(ctx context.Context) {
logger.Progressf("[performers] %d of %d", index, len(t.mappings.Performers)) logger.Progressf("[performers] %d of %d", index, len(t.mappings.Performers))
if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
readerWriter := r.Performer() r := t.txnManager
readerWriter := r.Performer
importer := &performer.Importer{ importer := &performer.Importer{
ReaderWriter: readerWriter, ReaderWriter: readerWriter,
TagWriter: r.Tag(), TagWriter: r.Tag,
Input: *performerJSON, Input: *performerJSON,
} }
return performImport(importer, t.DuplicateBehaviour) return performImport(ctx, importer, t.DuplicateBehaviour)
}); err != nil { }); err != nil {
logger.Errorf("[performers] <%s> import failed: %s", mappingJSON.Checksum, err.Error()) logger.Errorf("[performers] <%s> import failed: %s", mappingJSON.Checksum, err.Error())
} }
@@ -243,8 +242,8 @@ func (t *ImportTask) ImportStudios(ctx context.Context) {
logger.Progressf("[studios] %d of %d", index, len(t.mappings.Studios)) logger.Progressf("[studios] %d of %d", index, len(t.mappings.Studios))
if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
return t.ImportStudio(studioJSON, pendingParent, r.Studio()) return t.ImportStudio(ctx, studioJSON, pendingParent, t.txnManager.Studio)
}); err != nil { }); err != nil {
if errors.Is(err, studio.ErrParentStudioNotExist) { if errors.Is(err, studio.ErrParentStudioNotExist) {
// add to the pending parent list so that it is created after the parent // add to the pending parent list so that it is created after the parent
@@ -265,8 +264,8 @@ func (t *ImportTask) ImportStudios(ctx context.Context) {
for _, s := range pendingParent { for _, s := range pendingParent {
for _, orphanStudioJSON := range s { for _, orphanStudioJSON := range s {
if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
return t.ImportStudio(orphanStudioJSON, nil, r.Studio()) return t.ImportStudio(ctx, orphanStudioJSON, nil, t.txnManager.Studio)
}); err != nil { }); err != nil {
logger.Errorf("[studios] <%s> failed to create: %s", orphanStudioJSON.Name, err.Error()) logger.Errorf("[studios] <%s> failed to create: %s", orphanStudioJSON.Name, err.Error())
continue continue
@@ -278,7 +277,7 @@ func (t *ImportTask) ImportStudios(ctx context.Context) {
logger.Info("[studios] import complete") logger.Info("[studios] import complete")
} }
func (t *ImportTask) ImportStudio(studioJSON *jsonschema.Studio, pendingParent map[string][]*jsonschema.Studio, readerWriter models.StudioReaderWriter) error { func (t *ImportTask) ImportStudio(ctx context.Context, studioJSON *jsonschema.Studio, pendingParent map[string][]*jsonschema.Studio, readerWriter studio.NameFinderCreatorUpdater) error {
importer := &studio.Importer{ importer := &studio.Importer{
ReaderWriter: readerWriter, ReaderWriter: readerWriter,
Input: *studioJSON, Input: *studioJSON,
@@ -290,7 +289,7 @@ func (t *ImportTask) ImportStudio(studioJSON *jsonschema.Studio, pendingParent m
importer.MissingRefBehaviour = models.ImportMissingRefEnumFail importer.MissingRefBehaviour = models.ImportMissingRefEnumFail
} }
if err := performImport(importer, t.DuplicateBehaviour); err != nil { if err := performImport(ctx, importer, t.DuplicateBehaviour); err != nil {
return err return err
} }
@@ -298,7 +297,7 @@ func (t *ImportTask) ImportStudio(studioJSON *jsonschema.Studio, pendingParent m
s := pendingParent[studioJSON.Name] s := pendingParent[studioJSON.Name]
for _, childStudioJSON := range s { for _, childStudioJSON := range s {
// map is nil since we're not checking parent studios at this point // map is nil since we're not checking parent studios at this point
if err := t.ImportStudio(childStudioJSON, nil, readerWriter); err != nil { if err := t.ImportStudio(ctx, childStudioJSON, nil, readerWriter); err != nil {
return fmt.Errorf("failed to create child studio <%s>: %s", childStudioJSON.Name, err.Error()) return fmt.Errorf("failed to create child studio <%s>: %s", childStudioJSON.Name, err.Error())
} }
} }
@@ -322,9 +321,10 @@ func (t *ImportTask) ImportMovies(ctx context.Context) {
logger.Progressf("[movies] %d of %d", index, len(t.mappings.Movies)) logger.Progressf("[movies] %d of %d", index, len(t.mappings.Movies))
if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
readerWriter := r.Movie() r := t.txnManager
studioReaderWriter := r.Studio() readerWriter := r.Movie
studioReaderWriter := r.Studio
movieImporter := &movie.Importer{ movieImporter := &movie.Importer{
ReaderWriter: readerWriter, ReaderWriter: readerWriter,
@@ -333,7 +333,7 @@ func (t *ImportTask) ImportMovies(ctx context.Context) {
MissingRefBehaviour: t.MissingRefBehaviour, MissingRefBehaviour: t.MissingRefBehaviour,
} }
return performImport(movieImporter, t.DuplicateBehaviour) return performImport(ctx, movieImporter, t.DuplicateBehaviour)
}); err != nil { }); err != nil {
logger.Errorf("[movies] <%s> import failed: %s", mappingJSON.Checksum, err.Error()) logger.Errorf("[movies] <%s> import failed: %s", mappingJSON.Checksum, err.Error())
continue continue
@@ -356,11 +356,12 @@ func (t *ImportTask) ImportGalleries(ctx context.Context) {
logger.Progressf("[galleries] %d of %d", index, len(t.mappings.Galleries)) logger.Progressf("[galleries] %d of %d", index, len(t.mappings.Galleries))
if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
readerWriter := r.Gallery() r := t.txnManager
tagWriter := r.Tag() readerWriter := r.Gallery
performerWriter := r.Performer() tagWriter := r.Tag
studioWriter := r.Studio() performerWriter := r.Performer
studioWriter := r.Studio
galleryImporter := &gallery.Importer{ galleryImporter := &gallery.Importer{
ReaderWriter: readerWriter, ReaderWriter: readerWriter,
@@ -371,7 +372,7 @@ func (t *ImportTask) ImportGalleries(ctx context.Context) {
MissingRefBehaviour: t.MissingRefBehaviour, MissingRefBehaviour: t.MissingRefBehaviour,
} }
return performImport(galleryImporter, t.DuplicateBehaviour) return performImport(ctx, galleryImporter, t.DuplicateBehaviour)
}); err != nil { }); err != nil {
logger.Errorf("[galleries] <%s> import failed to commit: %s", mappingJSON.Checksum, err.Error()) logger.Errorf("[galleries] <%s> import failed to commit: %s", mappingJSON.Checksum, err.Error())
continue continue
@@ -395,8 +396,8 @@ func (t *ImportTask) ImportTags(ctx context.Context) {
logger.Progressf("[tags] %d of %d", index, len(t.mappings.Tags)) logger.Progressf("[tags] %d of %d", index, len(t.mappings.Tags))
if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
return t.ImportTag(tagJSON, pendingParent, false, r.Tag()) return t.ImportTag(ctx, tagJSON, pendingParent, false, t.txnManager.Tag)
}); err != nil { }); err != nil {
var parentError tag.ParentTagNotExistError var parentError tag.ParentTagNotExistError
if errors.As(err, &parentError) { if errors.As(err, &parentError) {
@@ -411,8 +412,8 @@ func (t *ImportTask) ImportTags(ctx context.Context) {
for _, s := range pendingParent { for _, s := range pendingParent {
for _, orphanTagJSON := range s { for _, orphanTagJSON := range s {
if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
return t.ImportTag(orphanTagJSON, nil, true, r.Tag()) return t.ImportTag(ctx, orphanTagJSON, nil, true, t.txnManager.Tag)
}); err != nil { }); err != nil {
logger.Errorf("[tags] <%s> failed to create: %s", orphanTagJSON.Name, err.Error()) logger.Errorf("[tags] <%s> failed to create: %s", orphanTagJSON.Name, err.Error())
continue continue
@@ -423,7 +424,7 @@ func (t *ImportTask) ImportTags(ctx context.Context) {
logger.Info("[tags] import complete") logger.Info("[tags] import complete")
} }
func (t *ImportTask) ImportTag(tagJSON *jsonschema.Tag, pendingParent map[string][]*jsonschema.Tag, fail bool, readerWriter models.TagReaderWriter) error { func (t *ImportTask) ImportTag(ctx context.Context, tagJSON *jsonschema.Tag, pendingParent map[string][]*jsonschema.Tag, fail bool, readerWriter tag.NameFinderCreatorUpdater) error {
importer := &tag.Importer{ importer := &tag.Importer{
ReaderWriter: readerWriter, ReaderWriter: readerWriter,
Input: *tagJSON, Input: *tagJSON,
@@ -435,12 +436,12 @@ func (t *ImportTask) ImportTag(tagJSON *jsonschema.Tag, pendingParent map[string
importer.MissingRefBehaviour = models.ImportMissingRefEnumFail importer.MissingRefBehaviour = models.ImportMissingRefEnumFail
} }
if err := performImport(importer, t.DuplicateBehaviour); err != nil { if err := performImport(ctx, importer, t.DuplicateBehaviour); err != nil {
return err return err
} }
for _, childTagJSON := range pendingParent[tagJSON.Name] { for _, childTagJSON := range pendingParent[tagJSON.Name] {
if err := t.ImportTag(childTagJSON, pendingParent, fail, readerWriter); err != nil { if err := t.ImportTag(ctx, childTagJSON, pendingParent, fail, readerWriter); err != nil {
var parentError tag.ParentTagNotExistError var parentError tag.ParentTagNotExistError
if errors.As(err, &parentError) { if errors.As(err, &parentError) {
pendingParent[parentError.MissingParent()] = append(pendingParent[parentError.MissingParent()], tagJSON) pendingParent[parentError.MissingParent()] = append(pendingParent[parentError.MissingParent()], tagJSON)
@@ -457,10 +458,11 @@ func (t *ImportTask) ImportTag(tagJSON *jsonschema.Tag, pendingParent map[string
} }
func (t *ImportTask) ImportScrapedItems(ctx context.Context) { func (t *ImportTask) ImportScrapedItems(ctx context.Context) {
if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
logger.Info("[scraped sites] importing") logger.Info("[scraped sites] importing")
qb := r.ScrapedItem() r := t.txnManager
sqb := r.Studio() qb := r.ScrapedItem
sqb := r.Studio
currentTime := time.Now() currentTime := time.Now()
for i, mappingJSON := range t.scraped { for i, mappingJSON := range t.scraped {
@@ -484,7 +486,7 @@ func (t *ImportTask) ImportScrapedItems(ctx context.Context) {
UpdatedAt: models.SQLiteTimestamp{Timestamp: t.getTimeFromJSONTime(mappingJSON.UpdatedAt)}, UpdatedAt: models.SQLiteTimestamp{Timestamp: t.getTimeFromJSONTime(mappingJSON.UpdatedAt)},
} }
studio, err := sqb.FindByName(mappingJSON.Studio, false) studio, err := sqb.FindByName(ctx, mappingJSON.Studio, false)
if err != nil { if err != nil {
logger.Errorf("[scraped sites] failed to fetch studio: %s", err.Error()) logger.Errorf("[scraped sites] failed to fetch studio: %s", err.Error())
} }
@@ -492,7 +494,7 @@ func (t *ImportTask) ImportScrapedItems(ctx context.Context) {
newScrapedItem.StudioID = sql.NullInt64{Int64: int64(studio.ID), Valid: true} newScrapedItem.StudioID = sql.NullInt64{Int64: int64(studio.ID), Valid: true}
} }
_, err = qb.Create(newScrapedItem) _, err = qb.Create(ctx, newScrapedItem)
if err != nil { if err != nil {
logger.Errorf("[scraped sites] <%s> failed to create: %s", newScrapedItem.Title.String, err.Error()) logger.Errorf("[scraped sites] <%s> failed to create: %s", newScrapedItem.Title.String, err.Error())
} }
@@ -522,14 +524,15 @@ func (t *ImportTask) ImportScenes(ctx context.Context) {
sceneHash := mappingJSON.Checksum sceneHash := mappingJSON.Checksum
if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
readerWriter := r.Scene() r := t.txnManager
tagWriter := r.Tag() readerWriter := r.Scene
galleryWriter := r.Gallery() tagWriter := r.Tag
movieWriter := r.Movie() galleryWriter := r.Gallery
performerWriter := r.Performer() movieWriter := r.Movie
studioWriter := r.Studio() performerWriter := r.Performer
markerWriter := r.SceneMarker() studioWriter := r.Studio
markerWriter := r.SceneMarker
sceneImporter := &scene.Importer{ sceneImporter := &scene.Importer{
ReaderWriter: readerWriter, ReaderWriter: readerWriter,
@@ -546,7 +549,7 @@ func (t *ImportTask) ImportScenes(ctx context.Context) {
TagWriter: tagWriter, TagWriter: tagWriter,
} }
if err := performImport(sceneImporter, t.DuplicateBehaviour); err != nil { if err := performImport(ctx, sceneImporter, t.DuplicateBehaviour); err != nil {
return err return err
} }
@@ -560,7 +563,7 @@ func (t *ImportTask) ImportScenes(ctx context.Context) {
TagWriter: tagWriter, TagWriter: tagWriter,
} }
if err := performImport(markerImporter, t.DuplicateBehaviour); err != nil { if err := performImport(ctx, markerImporter, t.DuplicateBehaviour); err != nil {
return err return err
} }
} }
@@ -590,12 +593,13 @@ func (t *ImportTask) ImportImages(ctx context.Context) {
imageHash := mappingJSON.Checksum imageHash := mappingJSON.Checksum
if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
readerWriter := r.Image() r := t.txnManager
tagWriter := r.Tag() readerWriter := r.Image
galleryWriter := r.Gallery() tagWriter := r.Tag
performerWriter := r.Performer() galleryWriter := r.Gallery
studioWriter := r.Studio() performerWriter := r.Performer
studioWriter := r.Studio
imageImporter := &image.Importer{ imageImporter := &image.Importer{
ReaderWriter: readerWriter, ReaderWriter: readerWriter,
@@ -610,7 +614,7 @@ func (t *ImportTask) ImportImages(ctx context.Context) {
TagWriter: tagWriter, TagWriter: tagWriter,
} }
return performImport(imageImporter, t.DuplicateBehaviour) return performImport(ctx, imageImporter, t.DuplicateBehaviour)
}); err != nil { }); err != nil {
logger.Errorf("[images] <%s> import failed: %s", imageHash, err.Error()) logger.Errorf("[images] <%s> import failed: %s", imageHash, err.Error())
} }

View File

@@ -24,7 +24,7 @@ import (
const scanQueueSize = 200000 const scanQueueSize = 200000
type ScanJob struct { type ScanJob struct {
txnManager models.TransactionManager txnManager models.Repository
input ScanMetadataInput input ScanMetadataInput
subscriptions *subscriptionManager subscriptions *subscriptionManager
} }
@@ -220,20 +220,21 @@ func (j *ScanJob) doesPathExist(ctx context.Context, path string) bool {
gExt := config.GetGalleryExtensions() gExt := config.GetGalleryExtensions()
ret := false ret := false
txnErr := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { txnErr := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
r := j.txnManager
switch { switch {
case fsutil.MatchExtension(path, gExt): case fsutil.MatchExtension(path, gExt):
g, _ := r.Gallery().FindByPath(path) g, _ := r.Gallery.FindByPath(ctx, path)
if g != nil { if g != nil {
ret = true ret = true
} }
case fsutil.MatchExtension(path, vidExt): case fsutil.MatchExtension(path, vidExt):
s, _ := r.Scene().FindByPath(path) s, _ := r.Scene.FindByPath(ctx, path)
if s != nil { if s != nil {
ret = true ret = true
} }
case fsutil.MatchExtension(path, imgExt): case fsutil.MatchExtension(path, imgExt):
i, _ := r.Image().FindByPath(path) i, _ := r.Image.FindByPath(ctx, path)
if i != nil { if i != nil {
ret = true ret = true
} }
@@ -249,7 +250,7 @@ func (j *ScanJob) doesPathExist(ctx context.Context, path string) bool {
} }
type ScanTask struct { type ScanTask struct {
TxnManager models.TransactionManager TxnManager models.Repository
file file.SourceFile file file.SourceFile
UseFileMetadata bool UseFileMetadata bool
StripFileExtension bool StripFileExtension bool

View File

@@ -21,12 +21,12 @@ func (t *ScanTask) scanGallery(ctx context.Context) {
images := 0 images := 0
scanImages := false scanImages := false
if err := t.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
var err error var err error
g, err = r.Gallery().FindByPath(path) g, err = t.TxnManager.Gallery.FindByPath(ctx, path)
if g != nil && err == nil { if g != nil && err == nil {
images, err = r.Image().CountByGalleryID(g.ID) images, err = t.TxnManager.Image.CountByGalleryID(ctx, g.ID)
if err != nil { if err != nil {
return fmt.Errorf("error getting images for zip gallery %s: %s", path, err.Error()) return fmt.Errorf("error getting images for zip gallery %s: %s", path, err.Error())
} }
@@ -43,7 +43,7 @@ func (t *ScanTask) scanGallery(ctx context.Context) {
ImageExtensions: instance.Config.GetImageExtensions(), ImageExtensions: instance.Config.GetImageExtensions(),
StripFileExtension: t.StripFileExtension, StripFileExtension: t.StripFileExtension,
CaseSensitiveFs: t.CaseSensitiveFs, CaseSensitiveFs: t.CaseSensitiveFs,
TxnManager: t.TxnManager, CreatorUpdater: t.TxnManager.Gallery,
Paths: instance.Paths, Paths: instance.Paths,
PluginCache: instance.PluginCache, PluginCache: instance.PluginCache,
MutexManager: t.mutexManager, MutexManager: t.mutexManager,
@@ -79,10 +79,11 @@ func (t *ScanTask) scanGallery(ctx context.Context) {
// associates a gallery to a scene with the same basename // associates a gallery to a scene with the same basename
func (t *ScanTask) associateGallery(ctx context.Context, wg *sizedwaitgroup.SizedWaitGroup) { func (t *ScanTask) associateGallery(ctx context.Context, wg *sizedwaitgroup.SizedWaitGroup) {
path := t.file.Path() path := t.file.Path()
if err := t.TxnManager.WithTxn(ctx, func(r models.Repository) error { if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
qb := r.Gallery() r := t.TxnManager
sqb := r.Scene() qb := r.Gallery
g, err := qb.FindByPath(path) sqb := r.Scene
g, err := qb.FindByPath(ctx, path)
if err != nil { if err != nil {
return err return err
} }
@@ -106,10 +107,10 @@ func (t *ScanTask) associateGallery(ctx context.Context, wg *sizedwaitgroup.Size
} }
} }
for _, scenePath := range relatedFiles { for _, scenePath := range relatedFiles {
scene, _ := sqb.FindByPath(scenePath) scene, _ := sqb.FindByPath(ctx, scenePath)
// found related Scene // found related Scene
if scene != nil { if scene != nil {
sceneGalleries, _ := sqb.FindByGalleryID(g.ID) // check if gallery is already associated to the scene sceneGalleries, _ := sqb.FindByGalleryID(ctx, g.ID) // check if gallery is already associated to the scene
isAssoc := false isAssoc := false
for _, sg := range sceneGalleries { for _, sg := range sceneGalleries {
if scene.ID == sg.ID { if scene.ID == sg.ID {
@@ -119,7 +120,7 @@ func (t *ScanTask) associateGallery(ctx context.Context, wg *sizedwaitgroup.Size
} }
if !isAssoc { if !isAssoc {
logger.Infof("associate: Gallery %s is related to scene: %d", path, scene.ID) logger.Infof("associate: Gallery %s is related to scene: %d", path, scene.ID)
if err := sqb.UpdateGalleries(scene.ID, []int{g.ID}); err != nil { if err := sqb.UpdateGalleries(ctx, scene.ID, []int{g.ID}); err != nil {
return err return err
} }
} }
@@ -152,11 +153,11 @@ func (t *ScanTask) scanZipImages(ctx context.Context, zipGallery *models.Gallery
func (t *ScanTask) regenerateZipImages(ctx context.Context, zipGallery *models.Gallery) { func (t *ScanTask) regenerateZipImages(ctx context.Context, zipGallery *models.Gallery) {
var images []*models.Image var images []*models.Image
if err := t.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
iqb := r.Image() iqb := t.TxnManager.Image
var err error var err error
images, err = iqb.FindByGalleryID(zipGallery.ID) images, err = iqb.FindByGalleryID(ctx, zipGallery.ID)
return err return err
}); err != nil { }); err != nil {
logger.Warnf("failed to find gallery images: %s", err.Error()) logger.Warnf("failed to find gallery images: %s", err.Error())

View File

@@ -23,9 +23,9 @@ func (t *ScanTask) scanImage(ctx context.Context) {
var i *models.Image var i *models.Image
path := t.file.Path() path := t.file.Path()
if err := t.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
var err error var err error
i, err = r.Image().FindByPath(path) i, err = t.TxnManager.Image.FindByPath(ctx, path)
return err return err
}); err != nil { }); err != nil {
logger.Error(err.Error()) logger.Error(err.Error())
@@ -36,6 +36,8 @@ func (t *ScanTask) scanImage(ctx context.Context) {
Scanner: image.FileScanner(&file.FSHasher{}), Scanner: image.FileScanner(&file.FSHasher{}),
StripFileExtension: t.StripFileExtension, StripFileExtension: t.StripFileExtension,
TxnManager: t.TxnManager, TxnManager: t.TxnManager,
CreatorUpdater: t.TxnManager.Image,
CaseSensitiveFs: t.CaseSensitiveFs,
Paths: GetInstance().Paths, Paths: GetInstance().Paths,
PluginCache: instance.PluginCache, PluginCache: instance.PluginCache,
MutexManager: t.mutexManager, MutexManager: t.mutexManager,
@@ -58,8 +60,8 @@ func (t *ScanTask) scanImage(ctx context.Context) {
if i != nil { if i != nil {
if t.zipGallery != nil { if t.zipGallery != nil {
// associate with gallery // associate with gallery
if err := t.TxnManager.WithTxn(ctx, func(r models.Repository) error { if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
return gallery.AddImage(r.Gallery(), t.zipGallery.ID, i.ID) return gallery.AddImage(ctx, t.TxnManager.Gallery, t.zipGallery.ID, i.ID)
}); err != nil { }); err != nil {
logger.Error(err.Error()) logger.Error(err.Error())
return return
@@ -69,9 +71,9 @@ func (t *ScanTask) scanImage(ctx context.Context) {
logger.Infof("Associating image %s with folder gallery", i.Path) logger.Infof("Associating image %s with folder gallery", i.Path)
var galleryID int var galleryID int
var isNewGallery bool var isNewGallery bool
if err := t.TxnManager.WithTxn(ctx, func(r models.Repository) error { if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
var err error var err error
galleryID, isNewGallery, err = t.associateImageWithFolderGallery(i.ID, r.Gallery()) galleryID, isNewGallery, err = t.associateImageWithFolderGallery(ctx, i.ID, t.TxnManager.Gallery)
return err return err
}); err != nil { }); err != nil {
logger.Error(err.Error()) logger.Error(err.Error())
@@ -90,11 +92,17 @@ func (t *ScanTask) scanImage(ctx context.Context) {
} }
} }
func (t *ScanTask) associateImageWithFolderGallery(imageID int, qb models.GalleryReaderWriter) (galleryID int, isNew bool, err error) { type GalleryImageAssociator interface {
FindByPath(ctx context.Context, path string) (*models.Gallery, error)
Create(ctx context.Context, newGallery models.Gallery) (*models.Gallery, error)
gallery.ImageUpdater
}
func (t *ScanTask) associateImageWithFolderGallery(ctx context.Context, imageID int, qb GalleryImageAssociator) (galleryID int, isNew bool, err error) {
// find a gallery with the path specified // find a gallery with the path specified
path := filepath.Dir(t.file.Path()) path := filepath.Dir(t.file.Path())
var g *models.Gallery var g *models.Gallery
g, err = qb.FindByPath(path) g, err = qb.FindByPath(ctx, path)
if err != nil { if err != nil {
return return
} }
@@ -120,7 +128,7 @@ func (t *ScanTask) associateImageWithFolderGallery(imageID int, qb models.Galler
} }
logger.Infof("Creating gallery for folder %s", path) logger.Infof("Creating gallery for folder %s", path)
g, err = qb.Create(newGallery) g, err = qb.Create(ctx, newGallery)
if err != nil { if err != nil {
return 0, false, err return 0, false, err
} }
@@ -129,7 +137,7 @@ func (t *ScanTask) associateImageWithFolderGallery(imageID int, qb models.Galler
} }
// associate image with gallery // associate image with gallery
err = gallery.AddImage(qb, g.ID, imageID) err = gallery.AddImage(ctx, qb, g.ID, imageID)
galleryID = g.ID galleryID = g.ID
return return
} }

View File

@@ -34,9 +34,9 @@ func (t *ScanTask) scanScene(ctx context.Context) *models.Scene {
var retScene *models.Scene var retScene *models.Scene
var s *models.Scene var s *models.Scene
if err := t.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
var err error var err error
s, err = r.Scene().FindByPath(t.file.Path()) s, err = t.TxnManager.Scene.FindByPath(ctx, t.file.Path())
return err return err
}); err != nil { }); err != nil {
logger.Error(err.Error()) logger.Error(err.Error())
@@ -54,7 +54,9 @@ func (t *ScanTask) scanScene(ctx context.Context) *models.Scene {
StripFileExtension: t.StripFileExtension, StripFileExtension: t.StripFileExtension,
FileNamingAlgorithm: t.fileNamingAlgorithm, FileNamingAlgorithm: t.fileNamingAlgorithm,
TxnManager: t.TxnManager, TxnManager: t.TxnManager,
CreatorUpdater: t.TxnManager.Scene,
Paths: GetInstance().Paths, Paths: GetInstance().Paths,
CaseSensitiveFs: t.CaseSensitiveFs,
Screenshotter: &sceneScreenshotter{ Screenshotter: &sceneScreenshotter{
g: g, g: g,
}, },
@@ -88,12 +90,12 @@ func (t *ScanTask) associateCaptions(ctx context.Context) {
captionLang := scene.GetCaptionsLangFromPath(captionPath) captionLang := scene.GetCaptionsLangFromPath(captionPath)
relatedFiles := scene.GenerateCaptionCandidates(captionPath, vExt) relatedFiles := scene.GenerateCaptionCandidates(captionPath, vExt)
if err := t.TxnManager.WithTxn(ctx, func(r models.Repository) error { if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
var err error var err error
sqb := r.Scene() sqb := t.TxnManager.Scene
for _, scenePath := range relatedFiles { for _, scenePath := range relatedFiles {
s, er := sqb.FindByPath(scenePath) s, er := sqb.FindByPath(ctx, scenePath)
if er != nil { if er != nil {
logger.Errorf("Error searching for scene %s: %v", scenePath, er) logger.Errorf("Error searching for scene %s: %v", scenePath, er)
@@ -101,7 +103,7 @@ func (t *ScanTask) associateCaptions(ctx context.Context) {
} }
if s != nil { // found related Scene if s != nil { // found related Scene
logger.Debugf("Matched captions to scene %s", s.Path) logger.Debugf("Matched captions to scene %s", s.Path)
captions, er := sqb.GetCaptions(s.ID) captions, er := sqb.GetCaptions(ctx, s.ID)
if er == nil { if er == nil {
fileExt := filepath.Ext(captionPath) fileExt := filepath.Ext(captionPath)
ext := fileExt[1:] ext := fileExt[1:]
@@ -112,7 +114,7 @@ func (t *ScanTask) associateCaptions(ctx context.Context) {
CaptionType: ext, CaptionType: ext,
} }
captions = append(captions, newCaption) captions = append(captions, newCaption)
er = sqb.UpdateCaptions(s.ID, captions) er = sqb.UpdateCaptions(ctx, s.ID, captions)
if er == nil { if er == nil {
logger.Debugf("Updated captions for scene %s. Added %s", s.Path, captionLang) logger.Debugf("Updated captions for scene %s. Added %s", s.Path, captionLang)
} }

View File

@@ -10,11 +10,11 @@ import (
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scraper/stashbox" "github.com/stashapp/stash/pkg/scraper/stashbox"
"github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
type StashBoxPerformerTagTask struct { type StashBoxPerformerTagTask struct {
txnManager models.TransactionManager
box *models.StashBox box *models.StashBox
name *string name *string
performer *models.Performer performer *models.Performer
@@ -41,12 +41,17 @@ func (t *StashBoxPerformerTagTask) stashBoxPerformerTag(ctx context.Context) {
var performer *models.ScrapedPerformer var performer *models.ScrapedPerformer
var err error var err error
client := stashbox.NewClient(*t.box, t.txnManager) client := stashbox.NewClient(*t.box, instance.Repository, stashbox.Repository{
Scene: instance.Repository.Scene,
Performer: instance.Repository.Performer,
Tag: instance.Repository.Tag,
Studio: instance.Repository.Studio,
})
if t.refresh { if t.refresh {
var performerID string var performerID string
txnErr := t.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error { txnErr := txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error {
stashids, _ := r.Performer().GetStashIDs(t.performer.ID) stashids, _ := instance.Repository.Performer.GetStashIDs(ctx, t.performer.ID)
for _, id := range stashids { for _, id := range stashids {
if id.Endpoint == t.box.Endpoint { if id.Endpoint == t.box.Endpoint {
performerID = id.StashID performerID = id.StashID
@@ -156,11 +161,12 @@ func (t *StashBoxPerformerTagTask) stashBoxPerformerTag(ctx context.Context) {
partial.URL = &value partial.URL = &value
} }
txnErr := t.txnManager.WithTxn(ctx, func(r models.Repository) error { txnErr := txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error {
_, err := r.Performer().Update(partial) r := instance.Repository
_, err := r.Performer.Update(ctx, partial)
if !t.refresh { if !t.refresh {
err = r.Performer().UpdateStashIDs(t.performer.ID, []models.StashID{ err = r.Performer.UpdateStashIDs(ctx, t.performer.ID, []models.StashID{
{ {
Endpoint: t.box.Endpoint, Endpoint: t.box.Endpoint,
StashID: *performer.RemoteSiteID, StashID: *performer.RemoteSiteID,
@@ -176,7 +182,7 @@ func (t *StashBoxPerformerTagTask) stashBoxPerformerTag(ctx context.Context) {
if err != nil { if err != nil {
return err return err
} }
err = r.Performer().UpdateImage(t.performer.ID, image) err = r.Performer.UpdateImage(ctx, t.performer.ID, image)
if err != nil { if err != nil {
return err return err
} }
@@ -218,13 +224,14 @@ func (t *StashBoxPerformerTagTask) stashBoxPerformerTag(ctx context.Context) {
URL: getNullString(performer.URL), URL: getNullString(performer.URL),
UpdatedAt: models.SQLiteTimestamp{Timestamp: currentTime}, UpdatedAt: models.SQLiteTimestamp{Timestamp: currentTime},
} }
err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { err := txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error {
createdPerformer, err := r.Performer().Create(newPerformer) r := instance.Repository
createdPerformer, err := r.Performer.Create(ctx, newPerformer)
if err != nil { if err != nil {
return err return err
} }
err = r.Performer().UpdateStashIDs(createdPerformer.ID, []models.StashID{ err = r.Performer.UpdateStashIDs(ctx, createdPerformer.ID, []models.StashID{
{ {
Endpoint: t.box.Endpoint, Endpoint: t.box.Endpoint,
StashID: *performer.RemoteSiteID, StashID: *performer.RemoteSiteID,
@@ -239,7 +246,7 @@ func (t *StashBoxPerformerTagTask) stashBoxPerformerTag(ctx context.Context) {
if imageErr != nil { if imageErr != nil {
return imageErr return imageErr
} }
err = r.Performer().UpdateImage(createdPerformer.ID, image) err = r.Performer.UpdateImage(ctx, createdPerformer.ID, image)
} }
return err return err
}) })

View File

@@ -1,40 +0,0 @@
package database
import (
"context"
"github.com/jmoiron/sqlx"
"github.com/stashapp/stash/pkg/logger"
)
// WithTxn executes the provided function within a transaction. It rolls back
// the transaction if the function returns an error, otherwise the transaction
// is committed.
func WithTxn(fn func(tx *sqlx.Tx) error) error {
ctx := context.TODO()
tx := DB.MustBeginTx(ctx, nil)
var err error
defer func() {
if p := recover(); p != nil {
// a panic occurred, rollback and repanic
if err := tx.Rollback(); err != nil {
logger.Warnf("failure when performing transaction rollback: %v", err)
}
panic(p)
}
if err != nil {
// something went wrong, rollback
if err := tx.Rollback(); err != nil {
logger.Warnf("failure when performing transaction rollback: %v", err)
}
} else {
// all good, commit
err = tx.Commit()
}
}()
err = fn(tx)
return err
}

View File

@@ -1,9 +1,12 @@
package gallery package gallery
import ( import (
"context"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/json" "github.com/stashapp/stash/pkg/models/json"
"github.com/stashapp/stash/pkg/models/jsonschema" "github.com/stashapp/stash/pkg/models/jsonschema"
"github.com/stashapp/stash/pkg/studio"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
@@ -52,9 +55,9 @@ func ToBasicJSON(gallery *models.Gallery) (*jsonschema.Gallery, error) {
// GetStudioName returns the name of the provided gallery's studio. It returns an // GetStudioName returns the name of the provided gallery's studio. It returns an
// empty string if there is no studio assigned to the gallery. // empty string if there is no studio assigned to the gallery.
func GetStudioName(reader models.StudioReader, gallery *models.Gallery) (string, error) { func GetStudioName(ctx context.Context, reader studio.Finder, gallery *models.Gallery) (string, error) {
if gallery.StudioID.Valid { if gallery.StudioID.Valid {
studio, err := reader.Find(int(gallery.StudioID.Int64)) studio, err := reader.Find(ctx, int(gallery.StudioID.Int64))
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@@ -154,15 +154,15 @@ func TestGetStudioName(t *testing.T) {
studioErr := errors.New("error getting image") studioErr := errors.New("error getting image")
mockStudioReader.On("Find", studioID).Return(&models.Studio{ mockStudioReader.On("Find", testCtx, studioID).Return(&models.Studio{
Name: models.NullString(studioName), Name: models.NullString(studioName),
}, nil).Once() }, nil).Once()
mockStudioReader.On("Find", missingStudioID).Return(nil, nil).Once() mockStudioReader.On("Find", testCtx, missingStudioID).Return(nil, nil).Once()
mockStudioReader.On("Find", errStudioID).Return(nil, studioErr).Once() mockStudioReader.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once()
for i, s := range getStudioScenarios { for i, s := range getStudioScenarios {
gallery := s.input gallery := s.input
json, err := GetStudioName(mockStudioReader, &gallery) json, err := GetStudioName(testCtx, mockStudioReader, &gallery)
switch { switch {
case !s.err && err != nil: case !s.err && err != nil:

View File

@@ -1,20 +1,30 @@
package gallery package gallery
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"strings" "strings"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/jsonschema" "github.com/stashapp/stash/pkg/models/jsonschema"
"github.com/stashapp/stash/pkg/performer"
"github.com/stashapp/stash/pkg/sliceutil/stringslice" "github.com/stashapp/stash/pkg/sliceutil/stringslice"
"github.com/stashapp/stash/pkg/studio"
"github.com/stashapp/stash/pkg/tag"
) )
type FullCreatorUpdater interface {
FinderCreatorUpdater
UpdatePerformers(ctx context.Context, galleryID int, performerIDs []int) error
UpdateTags(ctx context.Context, galleryID int, tagIDs []int) error
}
type Importer struct { type Importer struct {
ReaderWriter models.GalleryReaderWriter ReaderWriter FullCreatorUpdater
StudioWriter models.StudioReaderWriter StudioWriter studio.NameFinderCreator
PerformerWriter models.PerformerReaderWriter PerformerWriter performer.NameFinderCreator
TagWriter models.TagReaderWriter TagWriter tag.NameFinderCreator
Input jsonschema.Gallery Input jsonschema.Gallery
MissingRefBehaviour models.ImportMissingRefEnum MissingRefBehaviour models.ImportMissingRefEnum
@@ -23,18 +33,18 @@ type Importer struct {
tags []*models.Tag tags []*models.Tag
} }
func (i *Importer) PreImport() error { func (i *Importer) PreImport(ctx context.Context) error {
i.gallery = i.galleryJSONToGallery(i.Input) i.gallery = i.galleryJSONToGallery(i.Input)
if err := i.populateStudio(); err != nil { if err := i.populateStudio(ctx); err != nil {
return err return err
} }
if err := i.populatePerformers(); err != nil { if err := i.populatePerformers(ctx); err != nil {
return err return err
} }
if err := i.populateTags(); err != nil { if err := i.populateTags(ctx); err != nil {
return err return err
} }
@@ -74,9 +84,9 @@ func (i *Importer) galleryJSONToGallery(galleryJSON jsonschema.Gallery) models.G
return newGallery return newGallery
} }
func (i *Importer) populateStudio() error { func (i *Importer) populateStudio(ctx context.Context) error {
if i.Input.Studio != "" { if i.Input.Studio != "" {
studio, err := i.StudioWriter.FindByName(i.Input.Studio, false) studio, err := i.StudioWriter.FindByName(ctx, i.Input.Studio, false)
if err != nil { if err != nil {
return fmt.Errorf("error finding studio by name: %v", err) return fmt.Errorf("error finding studio by name: %v", err)
} }
@@ -91,7 +101,7 @@ func (i *Importer) populateStudio() error {
} }
if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate { if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate {
studioID, err := i.createStudio(i.Input.Studio) studioID, err := i.createStudio(ctx, i.Input.Studio)
if err != nil { if err != nil {
return err return err
} }
@@ -108,10 +118,10 @@ func (i *Importer) populateStudio() error {
return nil return nil
} }
func (i *Importer) createStudio(name string) (int, error) { func (i *Importer) createStudio(ctx context.Context, name string) (int, error) {
newStudio := *models.NewStudio(name) newStudio := *models.NewStudio(name)
created, err := i.StudioWriter.Create(newStudio) created, err := i.StudioWriter.Create(ctx, newStudio)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -119,10 +129,10 @@ func (i *Importer) createStudio(name string) (int, error) {
return created.ID, nil return created.ID, nil
} }
func (i *Importer) populatePerformers() error { func (i *Importer) populatePerformers(ctx context.Context) error {
if len(i.Input.Performers) > 0 { if len(i.Input.Performers) > 0 {
names := i.Input.Performers names := i.Input.Performers
performers, err := i.PerformerWriter.FindByNames(names, false) performers, err := i.PerformerWriter.FindByNames(ctx, names, false)
if err != nil { if err != nil {
return err return err
} }
@@ -145,7 +155,7 @@ func (i *Importer) populatePerformers() error {
} }
if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate { if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate {
createdPerformers, err := i.createPerformers(missingPerformers) createdPerformers, err := i.createPerformers(ctx, missingPerformers)
if err != nil { if err != nil {
return fmt.Errorf("error creating gallery performers: %v", err) return fmt.Errorf("error creating gallery performers: %v", err)
} }
@@ -162,12 +172,12 @@ func (i *Importer) populatePerformers() error {
return nil return nil
} }
func (i *Importer) createPerformers(names []string) ([]*models.Performer, error) { func (i *Importer) createPerformers(ctx context.Context, names []string) ([]*models.Performer, error) {
var ret []*models.Performer var ret []*models.Performer
for _, name := range names { for _, name := range names {
newPerformer := *models.NewPerformer(name) newPerformer := *models.NewPerformer(name)
created, err := i.PerformerWriter.Create(newPerformer) created, err := i.PerformerWriter.Create(ctx, newPerformer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -178,10 +188,10 @@ func (i *Importer) createPerformers(names []string) ([]*models.Performer, error)
return ret, nil return ret, nil
} }
func (i *Importer) populateTags() error { func (i *Importer) populateTags(ctx context.Context) error {
if len(i.Input.Tags) > 0 { if len(i.Input.Tags) > 0 {
names := i.Input.Tags names := i.Input.Tags
tags, err := i.TagWriter.FindByNames(names, false) tags, err := i.TagWriter.FindByNames(ctx, names, false)
if err != nil { if err != nil {
return err return err
} }
@@ -201,7 +211,7 @@ func (i *Importer) populateTags() error {
} }
if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate { if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate {
createdTags, err := i.createTags(missingTags) createdTags, err := i.createTags(ctx, missingTags)
if err != nil { if err != nil {
return fmt.Errorf("error creating gallery tags: %v", err) return fmt.Errorf("error creating gallery tags: %v", err)
} }
@@ -218,12 +228,12 @@ func (i *Importer) populateTags() error {
return nil return nil
} }
func (i *Importer) createTags(names []string) ([]*models.Tag, error) { func (i *Importer) createTags(ctx context.Context, names []string) ([]*models.Tag, error) {
var ret []*models.Tag var ret []*models.Tag
for _, name := range names { for _, name := range names {
newTag := *models.NewTag(name) newTag := *models.NewTag(name)
created, err := i.TagWriter.Create(newTag) created, err := i.TagWriter.Create(ctx, newTag)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -234,14 +244,14 @@ func (i *Importer) createTags(names []string) ([]*models.Tag, error) {
return ret, nil return ret, nil
} }
func (i *Importer) PostImport(id int) error { func (i *Importer) PostImport(ctx context.Context, id int) error {
if len(i.performers) > 0 { if len(i.performers) > 0 {
var performerIDs []int var performerIDs []int
for _, performer := range i.performers { for _, performer := range i.performers {
performerIDs = append(performerIDs, performer.ID) performerIDs = append(performerIDs, performer.ID)
} }
if err := i.ReaderWriter.UpdatePerformers(id, performerIDs); err != nil { if err := i.ReaderWriter.UpdatePerformers(ctx, id, performerIDs); err != nil {
return fmt.Errorf("failed to associate performers: %v", err) return fmt.Errorf("failed to associate performers: %v", err)
} }
} }
@@ -251,7 +261,7 @@ func (i *Importer) PostImport(id int) error {
for _, t := range i.tags { for _, t := range i.tags {
tagIDs = append(tagIDs, t.ID) tagIDs = append(tagIDs, t.ID)
} }
if err := i.ReaderWriter.UpdateTags(id, tagIDs); err != nil { if err := i.ReaderWriter.UpdateTags(ctx, id, tagIDs); err != nil {
return fmt.Errorf("failed to associate tags: %v", err) return fmt.Errorf("failed to associate tags: %v", err)
} }
} }
@@ -263,8 +273,8 @@ func (i *Importer) Name() string {
return i.Input.Path return i.Input.Path
} }
func (i *Importer) FindExistingID() (*int, error) { func (i *Importer) FindExistingID(ctx context.Context) (*int, error) {
existing, err := i.ReaderWriter.FindByChecksum(i.Input.Checksum) existing, err := i.ReaderWriter.FindByChecksum(ctx, i.Input.Checksum)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -277,8 +287,8 @@ func (i *Importer) FindExistingID() (*int, error) {
return nil, nil return nil, nil
} }
func (i *Importer) Create() (*int, error) { func (i *Importer) Create(ctx context.Context) (*int, error) {
created, err := i.ReaderWriter.Create(i.gallery) created, err := i.ReaderWriter.Create(ctx, i.gallery)
if err != nil { if err != nil {
return nil, fmt.Errorf("error creating gallery: %v", err) return nil, fmt.Errorf("error creating gallery: %v", err)
} }
@@ -287,10 +297,10 @@ func (i *Importer) Create() (*int, error) {
return &id, nil return &id, nil
} }
func (i *Importer) Update(id int) error { func (i *Importer) Update(ctx context.Context, id int) error {
gallery := i.gallery gallery := i.gallery
gallery.ID = id gallery.ID = id
_, err := i.ReaderWriter.Update(gallery) _, err := i.ReaderWriter.Update(ctx, gallery)
if err != nil { if err != nil {
return fmt.Errorf("error updating existing gallery: %v", err) return fmt.Errorf("error updating existing gallery: %v", err)
} }

View File

@@ -1,6 +1,7 @@
package gallery package gallery
import ( import (
"context"
"errors" "errors"
"testing" "testing"
"time" "time"
@@ -40,6 +41,8 @@ const (
errChecksum = "errChecksum" errChecksum = "errChecksum"
) )
var testCtx = context.Background()
var ( var (
createdAt = time.Date(2001, time.January, 2, 1, 2, 3, 4, time.Local) createdAt = time.Date(2001, time.January, 2, 1, 2, 3, 4, time.Local)
updatedAt = time.Date(2002, time.January, 2, 1, 2, 3, 4, time.Local) updatedAt = time.Date(2002, time.January, 2, 1, 2, 3, 4, time.Local)
@@ -75,7 +78,7 @@ func TestImporterPreImport(t *testing.T) {
}, },
} }
err := i.PreImport() err := i.PreImport(testCtx)
assert.Nil(t, err) assert.Nil(t, err)
expectedGallery := models.Gallery{ expectedGallery := models.Gallery{
@@ -112,17 +115,17 @@ func TestImporterPreImportWithStudio(t *testing.T) {
}, },
} }
studioReaderWriter.On("FindByName", existingStudioName, false).Return(&models.Studio{ studioReaderWriter.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
ID: existingStudioID, ID: existingStudioID,
}, nil).Once() }, nil).Once()
studioReaderWriter.On("FindByName", existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once() studioReaderWriter.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
err := i.PreImport() err := i.PreImport(testCtx)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, int64(existingStudioID), i.gallery.StudioID.Int64) assert.Equal(t, int64(existingStudioID), i.gallery.StudioID.Int64)
i.Input.Studio = existingStudioErr i.Input.Studio = existingStudioErr
err = i.PreImport() err = i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
studioReaderWriter.AssertExpectations(t) studioReaderWriter.AssertExpectations(t)
@@ -140,20 +143,20 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail, MissingRefBehaviour: models.ImportMissingRefEnumFail,
} }
studioReaderWriter.On("FindByName", missingStudioName, false).Return(nil, nil).Times(3) studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
studioReaderWriter.On("Create", mock.AnythingOfType("models.Studio")).Return(&models.Studio{ studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Studio")).Return(&models.Studio{
ID: existingStudioID, ID: existingStudioID,
}, nil) }, nil)
err := i.PreImport() err := i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore
err = i.PreImport() err = i.PreImport(testCtx)
assert.Nil(t, err) assert.Nil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumCreate i.MissingRefBehaviour = models.ImportMissingRefEnumCreate
err = i.PreImport() err = i.PreImport(testCtx)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, int64(existingStudioID), i.gallery.StudioID.Int64) assert.Equal(t, int64(existingStudioID), i.gallery.StudioID.Int64)
@@ -172,10 +175,10 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate, MissingRefBehaviour: models.ImportMissingRefEnumCreate,
} }
studioReaderWriter.On("FindByName", missingStudioName, false).Return(nil, nil).Once() studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
studioReaderWriter.On("Create", mock.AnythingOfType("models.Studio")).Return(nil, errors.New("Create error")) studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Studio")).Return(nil, errors.New("Create error"))
err := i.PreImport() err := i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
} }
@@ -193,20 +196,20 @@ func TestImporterPreImportWithPerformer(t *testing.T) {
}, },
} }
performerReaderWriter.On("FindByNames", []string{existingPerformerName}, false).Return([]*models.Performer{ performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{
{ {
ID: existingPerformerID, ID: existingPerformerID,
Name: models.NullString(existingPerformerName), Name: models.NullString(existingPerformerName),
}, },
}, nil).Once() }, nil).Once()
performerReaderWriter.On("FindByNames", []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once() performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once()
err := i.PreImport() err := i.PreImport(testCtx)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, existingPerformerID, i.performers[0].ID) assert.Equal(t, existingPerformerID, i.performers[0].ID)
i.Input.Performers = []string{existingPerformerErr} i.Input.Performers = []string{existingPerformerErr}
err = i.PreImport() err = i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
performerReaderWriter.AssertExpectations(t) performerReaderWriter.AssertExpectations(t)
@@ -226,20 +229,20 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail, MissingRefBehaviour: models.ImportMissingRefEnumFail,
} }
performerReaderWriter.On("FindByNames", []string{missingPerformerName}, false).Return(nil, nil).Times(3) performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3)
performerReaderWriter.On("Create", mock.AnythingOfType("models.Performer")).Return(&models.Performer{ performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Performer")).Return(&models.Performer{
ID: existingPerformerID, ID: existingPerformerID,
}, nil) }, nil)
err := i.PreImport() err := i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore
err = i.PreImport() err = i.PreImport(testCtx)
assert.Nil(t, err) assert.Nil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumCreate i.MissingRefBehaviour = models.ImportMissingRefEnumCreate
err = i.PreImport() err = i.PreImport(testCtx)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, existingPerformerID, i.performers[0].ID) assert.Equal(t, existingPerformerID, i.performers[0].ID)
@@ -260,10 +263,10 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate, MissingRefBehaviour: models.ImportMissingRefEnumCreate,
} }
performerReaderWriter.On("FindByNames", []string{missingPerformerName}, false).Return(nil, nil).Once() performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once()
performerReaderWriter.On("Create", mock.AnythingOfType("models.Performer")).Return(nil, errors.New("Create error")) performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Performer")).Return(nil, errors.New("Create error"))
err := i.PreImport() err := i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
} }
@@ -281,20 +284,20 @@ func TestImporterPreImportWithTag(t *testing.T) {
}, },
} }
tagReaderWriter.On("FindByNames", []string{existingTagName}, false).Return([]*models.Tag{ tagReaderWriter.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
{ {
ID: existingTagID, ID: existingTagID,
Name: existingTagName, Name: existingTagName,
}, },
}, nil).Once() }, nil).Once()
tagReaderWriter.On("FindByNames", []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once() tagReaderWriter.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
err := i.PreImport() err := i.PreImport(testCtx)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, existingTagID, i.tags[0].ID) assert.Equal(t, existingTagID, i.tags[0].ID)
i.Input.Tags = []string{existingTagErr} i.Input.Tags = []string{existingTagErr}
err = i.PreImport() err = i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
tagReaderWriter.AssertExpectations(t) tagReaderWriter.AssertExpectations(t)
@@ -314,20 +317,20 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail, MissingRefBehaviour: models.ImportMissingRefEnumFail,
} }
tagReaderWriter.On("FindByNames", []string{missingTagName}, false).Return(nil, nil).Times(3) tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
tagReaderWriter.On("Create", mock.AnythingOfType("models.Tag")).Return(&models.Tag{ tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Tag")).Return(&models.Tag{
ID: existingTagID, ID: existingTagID,
}, nil) }, nil)
err := i.PreImport() err := i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore
err = i.PreImport() err = i.PreImport(testCtx)
assert.Nil(t, err) assert.Nil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumCreate i.MissingRefBehaviour = models.ImportMissingRefEnumCreate
err = i.PreImport() err = i.PreImport(testCtx)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, existingTagID, i.tags[0].ID) assert.Equal(t, existingTagID, i.tags[0].ID)
@@ -348,10 +351,10 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate, MissingRefBehaviour: models.ImportMissingRefEnumCreate,
} }
tagReaderWriter.On("FindByNames", []string{missingTagName}, false).Return(nil, nil).Once() tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
tagReaderWriter.On("Create", mock.AnythingOfType("models.Tag")).Return(nil, errors.New("Create error")) tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Tag")).Return(nil, errors.New("Create error"))
err := i.PreImport() err := i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
} }
@@ -369,13 +372,13 @@ func TestImporterPostImportUpdatePerformers(t *testing.T) {
updateErr := errors.New("UpdatePerformers error") updateErr := errors.New("UpdatePerformers error")
galleryReaderWriter.On("UpdatePerformers", galleryID, []int{existingPerformerID}).Return(nil).Once() galleryReaderWriter.On("UpdatePerformers", testCtx, galleryID, []int{existingPerformerID}).Return(nil).Once()
galleryReaderWriter.On("UpdatePerformers", errPerformersID, mock.AnythingOfType("[]int")).Return(updateErr).Once() galleryReaderWriter.On("UpdatePerformers", testCtx, errPerformersID, mock.AnythingOfType("[]int")).Return(updateErr).Once()
err := i.PostImport(galleryID) err := i.PostImport(testCtx, galleryID)
assert.Nil(t, err) assert.Nil(t, err)
err = i.PostImport(errPerformersID) err = i.PostImport(testCtx, errPerformersID)
assert.NotNil(t, err) assert.NotNil(t, err)
galleryReaderWriter.AssertExpectations(t) galleryReaderWriter.AssertExpectations(t)
@@ -395,13 +398,13 @@ func TestImporterPostImportUpdateTags(t *testing.T) {
updateErr := errors.New("UpdateTags error") updateErr := errors.New("UpdateTags error")
galleryReaderWriter.On("UpdateTags", galleryID, []int{existingTagID}).Return(nil).Once() galleryReaderWriter.On("UpdateTags", testCtx, galleryID, []int{existingTagID}).Return(nil).Once()
galleryReaderWriter.On("UpdateTags", errTagsID, mock.AnythingOfType("[]int")).Return(updateErr).Once() galleryReaderWriter.On("UpdateTags", testCtx, errTagsID, mock.AnythingOfType("[]int")).Return(updateErr).Once()
err := i.PostImport(galleryID) err := i.PostImport(testCtx, galleryID)
assert.Nil(t, err) assert.Nil(t, err)
err = i.PostImport(errTagsID) err = i.PostImport(testCtx, errTagsID)
assert.NotNil(t, err) assert.NotNil(t, err)
galleryReaderWriter.AssertExpectations(t) galleryReaderWriter.AssertExpectations(t)
@@ -419,23 +422,23 @@ func TestImporterFindExistingID(t *testing.T) {
} }
expectedErr := errors.New("FindBy* error") expectedErr := errors.New("FindBy* error")
readerWriter.On("FindByChecksum", missingChecksum).Return(nil, nil).Once() readerWriter.On("FindByChecksum", testCtx, missingChecksum).Return(nil, nil).Once()
readerWriter.On("FindByChecksum", checksum).Return(&models.Gallery{ readerWriter.On("FindByChecksum", testCtx, checksum).Return(&models.Gallery{
ID: existingGalleryID, ID: existingGalleryID,
}, nil).Once() }, nil).Once()
readerWriter.On("FindByChecksum", errChecksum).Return(nil, expectedErr).Once() readerWriter.On("FindByChecksum", testCtx, errChecksum).Return(nil, expectedErr).Once()
id, err := i.FindExistingID() id, err := i.FindExistingID(testCtx)
assert.Nil(t, id) assert.Nil(t, id)
assert.Nil(t, err) assert.Nil(t, err)
i.Input.Checksum = checksum i.Input.Checksum = checksum
id, err = i.FindExistingID() id, err = i.FindExistingID(testCtx)
assert.Equal(t, existingGalleryID, *id) assert.Equal(t, existingGalleryID, *id)
assert.Nil(t, err) assert.Nil(t, err)
i.Input.Checksum = errChecksum i.Input.Checksum = errChecksum
id, err = i.FindExistingID() id, err = i.FindExistingID(testCtx)
assert.Nil(t, id) assert.Nil(t, id)
assert.NotNil(t, err) assert.NotNil(t, err)
@@ -459,17 +462,17 @@ func TestCreate(t *testing.T) {
} }
errCreate := errors.New("Create error") errCreate := errors.New("Create error")
readerWriter.On("Create", gallery).Return(&models.Gallery{ readerWriter.On("Create", testCtx, gallery).Return(&models.Gallery{
ID: galleryID, ID: galleryID,
}, nil).Once() }, nil).Once()
readerWriter.On("Create", galleryErr).Return(nil, errCreate).Once() readerWriter.On("Create", testCtx, galleryErr).Return(nil, errCreate).Once()
id, err := i.Create() id, err := i.Create(testCtx)
assert.Equal(t, galleryID, *id) assert.Equal(t, galleryID, *id)
assert.Nil(t, err) assert.Nil(t, err)
i.gallery = galleryErr i.gallery = galleryErr
id, err = i.Create() id, err = i.Create(testCtx)
assert.Nil(t, id) assert.Nil(t, id)
assert.NotNil(t, err) assert.NotNil(t, err)
@@ -490,9 +493,9 @@ func TestUpdate(t *testing.T) {
// id needs to be set for the mock input // id needs to be set for the mock input
gallery.ID = galleryID gallery.ID = galleryID
readerWriter.On("Update", gallery).Return(nil, nil).Once() readerWriter.On("Update", testCtx, gallery).Return(nil, nil).Once()
err := i.Update(galleryID) err := i.Update(testCtx, galleryID)
assert.Nil(t, err) assert.Nil(t, err)
readerWriter.AssertExpectations(t) readerWriter.AssertExpectations(t)

View File

@@ -1,12 +1,25 @@
package gallery package gallery
import ( import (
"context"
"strconv" "strconv"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
) )
func CountByPerformerID(r models.GalleryReader, id int) (int, error) { type Queryer interface {
Query(ctx context.Context, galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) ([]*models.Gallery, int, error)
}
type CountQueryer interface {
QueryCount(ctx context.Context, galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) (int, error)
}
type ChecksumsFinder interface {
FindByChecksums(ctx context.Context, checksums []string) ([]*models.Gallery, error)
}
func CountByPerformerID(ctx context.Context, r CountQueryer, id int) (int, error) {
filter := &models.GalleryFilterType{ filter := &models.GalleryFilterType{
Performers: &models.MultiCriterionInput{ Performers: &models.MultiCriterionInput{
Value: []string{strconv.Itoa(id)}, Value: []string{strconv.Itoa(id)},
@@ -14,10 +27,10 @@ func CountByPerformerID(r models.GalleryReader, id int) (int, error) {
}, },
} }
return r.QueryCount(filter, nil) return r.QueryCount(ctx, filter, nil)
} }
func CountByStudioID(r models.GalleryReader, id int) (int, error) { func CountByStudioID(ctx context.Context, r CountQueryer, id int) (int, error) {
filter := &models.GalleryFilterType{ filter := &models.GalleryFilterType{
Studios: &models.HierarchicalMultiCriterionInput{ Studios: &models.HierarchicalMultiCriterionInput{
Value: []string{strconv.Itoa(id)}, Value: []string{strconv.Itoa(id)},
@@ -25,10 +38,10 @@ func CountByStudioID(r models.GalleryReader, id int) (int, error) {
}, },
} }
return r.QueryCount(filter, nil) return r.QueryCount(ctx, filter, nil)
} }
func CountByTagID(r models.GalleryReader, id int) (int, error) { func CountByTagID(ctx context.Context, r CountQueryer, id int) (int, error) {
filter := &models.GalleryFilterType{ filter := &models.GalleryFilterType{
Tags: &models.HierarchicalMultiCriterionInput{ Tags: &models.HierarchicalMultiCriterionInput{
Value: []string{strconv.Itoa(id)}, Value: []string{strconv.Itoa(id)},
@@ -36,5 +49,5 @@ func CountByTagID(r models.GalleryReader, id int) (int, error) {
}, },
} }
return r.QueryCount(filter, nil) return r.QueryCount(ctx, filter, nil)
} }

View File

@@ -14,18 +14,26 @@ import (
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/paths" "github.com/stashapp/stash/pkg/models/paths"
"github.com/stashapp/stash/pkg/plugin" "github.com/stashapp/stash/pkg/plugin"
"github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
const mutexType = "gallery" const mutexType = "gallery"
type FinderCreatorUpdater interface {
FindByChecksum(ctx context.Context, checksum string) (*models.Gallery, error)
Create(ctx context.Context, newGallery models.Gallery) (*models.Gallery, error)
Update(ctx context.Context, updatedGallery models.Gallery) (*models.Gallery, error)
}
type Scanner struct { type Scanner struct {
file.Scanner file.Scanner
ImageExtensions []string ImageExtensions []string
StripFileExtension bool StripFileExtension bool
CaseSensitiveFs bool CaseSensitiveFs bool
TxnManager models.TransactionManager TxnManager txn.Manager
CreatorUpdater FinderCreatorUpdater
Paths *paths.Paths Paths *paths.Paths
PluginCache *plugin.Cache PluginCache *plugin.Cache
MutexManager *utils.MutexManager MutexManager *utils.MutexManager
@@ -75,19 +83,19 @@ func (scanner *Scanner) ScanExisting(ctx context.Context, existing file.FileBase
done := make(chan struct{}) done := make(chan struct{})
scanner.MutexManager.Claim(mutexType, scanned.New.Checksum, done) scanner.MutexManager.Claim(mutexType, scanned.New.Checksum, done)
if err := scanner.TxnManager.WithTxn(ctx, func(r models.Repository) error { if err := txn.WithTxn(ctx, scanner.TxnManager, func(ctx context.Context) error {
// free the mutex once transaction is complete // free the mutex once transaction is complete
defer close(done) defer close(done)
// ensure no clashes of hashes // ensure no clashes of hashes
if scanned.New.Checksum != "" && scanned.Old.Checksum != scanned.New.Checksum { if scanned.New.Checksum != "" && scanned.Old.Checksum != scanned.New.Checksum {
dupe, _ := r.Gallery().FindByChecksum(retGallery.Checksum) dupe, _ := scanner.CreatorUpdater.FindByChecksum(ctx, retGallery.Checksum)
if dupe != nil { if dupe != nil {
return fmt.Errorf("MD5 for file %s is the same as that of %s", path, dupe.Path.String) return fmt.Errorf("MD5 for file %s is the same as that of %s", path, dupe.Path.String)
} }
} }
retGallery, err = r.Gallery().Update(*retGallery) retGallery, err = scanner.CreatorUpdater.Update(ctx, *retGallery)
return err return err
}); err != nil { }); err != nil {
return nil, false, err return nil, false, err
@@ -116,10 +124,10 @@ func (scanner *Scanner) ScanNew(ctx context.Context, file file.SourceFile) (retG
scanner.MutexManager.Claim(mutexType, checksum, done) scanner.MutexManager.Claim(mutexType, checksum, done)
defer close(done) defer close(done)
if err := scanner.TxnManager.WithTxn(ctx, func(r models.Repository) error { if err := txn.WithTxn(ctx, scanner.TxnManager, func(ctx context.Context) error {
qb := r.Gallery() qb := scanner.CreatorUpdater
g, _ = qb.FindByChecksum(checksum) g, _ = qb.FindByChecksum(ctx, checksum)
if g != nil { if g != nil {
exists, _ := fsutil.FileExists(g.Path.String) exists, _ := fsutil.FileExists(g.Path.String)
if !scanner.CaseSensitiveFs { if !scanner.CaseSensitiveFs {
@@ -138,7 +146,7 @@ func (scanner *Scanner) ScanNew(ctx context.Context, file file.SourceFile) (retG
String: path, String: path,
Valid: true, Valid: true,
} }
g, err = qb.Update(*g) g, err = qb.Update(ctx, *g)
if err != nil { if err != nil {
return err return err
} }
@@ -167,7 +175,7 @@ func (scanner *Scanner) ScanNew(ctx context.Context, file file.SourceFile) (retG
} }
logger.Infof("%s doesn't exist. Creating new item...", path) logger.Infof("%s doesn't exist. Creating new item...", path)
g, err = qb.Create(*g) g, err = qb.Create(ctx, *g)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -1,29 +1,50 @@
package gallery package gallery
import ( import (
"context"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/sliceutil/intslice" "github.com/stashapp/stash/pkg/sliceutil/intslice"
) )
func UpdateFileModTime(qb models.GalleryWriter, id int, modTime models.NullSQLiteTimestamp) (*models.Gallery, error) { type PartialUpdater interface {
return qb.UpdatePartial(models.GalleryPartial{ UpdatePartial(ctx context.Context, updatedGallery models.GalleryPartial) (*models.Gallery, error)
}
type ImageUpdater interface {
GetImageIDs(ctx context.Context, galleryID int) ([]int, error)
UpdateImages(ctx context.Context, galleryID int, imageIDs []int) error
}
type PerformerUpdater interface {
GetPerformerIDs(ctx context.Context, galleryID int) ([]int, error)
UpdatePerformers(ctx context.Context, galleryID int, performerIDs []int) error
}
type TagUpdater interface {
GetTagIDs(ctx context.Context, galleryID int) ([]int, error)
UpdateTags(ctx context.Context, galleryID int, tagIDs []int) error
}
func UpdateFileModTime(ctx context.Context, qb PartialUpdater, id int, modTime models.NullSQLiteTimestamp) (*models.Gallery, error) {
return qb.UpdatePartial(ctx, models.GalleryPartial{
ID: id, ID: id,
FileModTime: &modTime, FileModTime: &modTime,
}) })
} }
func AddImage(qb models.GalleryReaderWriter, galleryID int, imageID int) error { func AddImage(ctx context.Context, qb ImageUpdater, galleryID int, imageID int) error {
imageIDs, err := qb.GetImageIDs(galleryID) imageIDs, err := qb.GetImageIDs(ctx, galleryID)
if err != nil { if err != nil {
return err return err
} }
imageIDs = intslice.IntAppendUnique(imageIDs, imageID) imageIDs = intslice.IntAppendUnique(imageIDs, imageID)
return qb.UpdateImages(galleryID, imageIDs) return qb.UpdateImages(ctx, galleryID, imageIDs)
} }
func AddPerformer(qb models.GalleryReaderWriter, id int, performerID int) (bool, error) { func AddPerformer(ctx context.Context, qb PerformerUpdater, id int, performerID int) (bool, error) {
performerIDs, err := qb.GetPerformerIDs(id) performerIDs, err := qb.GetPerformerIDs(ctx, id)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -32,7 +53,7 @@ func AddPerformer(qb models.GalleryReaderWriter, id int, performerID int) (bool,
performerIDs = intslice.IntAppendUnique(performerIDs, performerID) performerIDs = intslice.IntAppendUnique(performerIDs, performerID)
if len(performerIDs) != oldLen { if len(performerIDs) != oldLen {
if err := qb.UpdatePerformers(id, performerIDs); err != nil { if err := qb.UpdatePerformers(ctx, id, performerIDs); err != nil {
return false, err return false, err
} }
@@ -42,8 +63,8 @@ func AddPerformer(qb models.GalleryReaderWriter, id int, performerID int) (bool,
return false, nil return false, nil
} }
func AddTag(qb models.GalleryReaderWriter, id int, tagID int) (bool, error) { func AddTag(ctx context.Context, qb TagUpdater, id int, tagID int) (bool, error) {
tagIDs, err := qb.GetTagIDs(id) tagIDs, err := qb.GetTagIDs(ctx, id)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -52,7 +73,7 @@ func AddTag(qb models.GalleryReaderWriter, id int, tagID int) (bool, error) {
tagIDs = intslice.IntAppendUnique(tagIDs, tagID) tagIDs = intslice.IntAppendUnique(tagIDs, tagID)
if len(tagIDs) != oldLen { if len(tagIDs) != oldLen {
if err := qb.UpdateTags(id, tagIDs); err != nil { if err := qb.UpdateTags(ctx, id, tagIDs); err != nil {
return false, err return false, err
} }

View File

@@ -1,6 +1,8 @@
package image package image
import ( import (
"context"
"github.com/stashapp/stash/pkg/file" "github.com/stashapp/stash/pkg/file"
"github.com/stashapp/stash/pkg/fsutil" "github.com/stashapp/stash/pkg/fsutil"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
@@ -8,7 +10,7 @@ import (
) )
type Destroyer interface { type Destroyer interface {
Destroy(id int) error Destroy(ctx context.Context, id int) error
} }
// FileDeleter is an extension of file.Deleter that handles deletion of image files. // FileDeleter is an extension of file.Deleter that handles deletion of image files.
@@ -30,7 +32,7 @@ func (d *FileDeleter) MarkGeneratedFiles(image *models.Image) error {
} }
// Destroy destroys an image, optionally marking the file and generated files for deletion. // Destroy destroys an image, optionally marking the file and generated files for deletion.
func Destroy(i *models.Image, destroyer Destroyer, fileDeleter *FileDeleter, deleteGenerated, deleteFile bool) error { func Destroy(ctx context.Context, i *models.Image, destroyer Destroyer, fileDeleter *FileDeleter, deleteGenerated, deleteFile bool) error {
// don't try to delete if the image is in a zip file // don't try to delete if the image is in a zip file
if deleteFile && !file.IsZipPath(i.Path) { if deleteFile && !file.IsZipPath(i.Path) {
if err := fileDeleter.Files([]string{i.Path}); err != nil { if err := fileDeleter.Files([]string{i.Path}); err != nil {
@@ -44,5 +46,5 @@ func Destroy(i *models.Image, destroyer Destroyer, fileDeleter *FileDeleter, del
} }
} }
return destroyer.Destroy(i.ID) return destroyer.Destroy(ctx, i.ID)
} }

View File

@@ -1,9 +1,12 @@
package image package image
import ( import (
"context"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/json" "github.com/stashapp/stash/pkg/models/json"
"github.com/stashapp/stash/pkg/models/jsonschema" "github.com/stashapp/stash/pkg/models/jsonschema"
"github.com/stashapp/stash/pkg/studio"
) )
// ToBasicJSON converts a image object into its JSON object equivalent. It // ToBasicJSON converts a image object into its JSON object equivalent. It
@@ -56,9 +59,9 @@ func getImageFileJSON(image *models.Image) *jsonschema.ImageFile {
// GetStudioName returns the name of the provided image's studio. It returns an // GetStudioName returns the name of the provided image's studio. It returns an
// empty string if there is no studio assigned to the image. // empty string if there is no studio assigned to the image.
func GetStudioName(reader models.StudioReader, image *models.Image) (string, error) { func GetStudioName(ctx context.Context, reader studio.Finder, image *models.Image) (string, error) {
if image.StudioID.Valid { if image.StudioID.Valid {
studio, err := reader.Find(int(image.StudioID.Int64)) studio, err := reader.Find(ctx, int(image.StudioID.Int64))
if err != nil { if err != nil {
return "", err return "", err
} }

Some files were not shown because too many files have changed in this diff Show More