Model refactor, part 3 (#4152)

* Remove manager.Repository
* Refactor other repositories
* Fix tests and add database mock
* Add AssertExpectations method
* Refactor routes
* Move default movie image to internal/static and add convenience methods
* Refactor default performer image boxes
This commit is contained in:
DingDongSoLong4
2023-10-16 05:26:34 +02:00
committed by GitHub
parent 40bcb4baa5
commit 33f2ebf2a3
87 changed files with 1843 additions and 1651 deletions

View File

@@ -1,12 +1,13 @@
package api package api
import ( import (
"errors"
"fmt"
"io" "io"
"io/fs" "io/fs"
"os" "os"
"strings" "strings"
"github.com/stashapp/stash/internal/manager/config"
"github.com/stashapp/stash/internal/static" "github.com/stashapp/stash/internal/static"
"github.com/stashapp/stash/pkg/hash" "github.com/stashapp/stash/pkg/hash"
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
@@ -18,7 +19,7 @@ type imageBox struct {
files []string files []string
} }
var imageExtensions = []string{ var imageBoxExts = []string{
".jpg", ".jpg",
".jpeg", ".jpeg",
".png", ".png",
@@ -42,7 +43,7 @@ func newImageBox(box fs.FS) (*imageBox, error) {
} }
baseName := strings.ToLower(d.Name()) baseName := strings.ToLower(d.Name())
for _, ext := range imageExtensions { for _, ext := range imageBoxExts {
if strings.HasSuffix(baseName, ext) { if strings.HasSuffix(baseName, ext) {
ret.files = append(ret.files, path) ret.files = append(ret.files, path)
break break
@@ -55,65 +56,14 @@ func newImageBox(box fs.FS) (*imageBox, error) {
return ret, err return ret, err
} }
var performerBox *imageBox func (box *imageBox) GetRandomImageByName(name string) ([]byte, error) {
var performerBoxMale *imageBox files := box.files
var performerBoxCustom *imageBox if len(files) == 0 {
return nil, errors.New("box is empty")
func initialiseImages() {
var err error
performerBox, err = newImageBox(&static.Performer)
if err != nil {
logger.Warnf("error loading performer images: %v", err)
}
performerBoxMale, err = newImageBox(&static.PerformerMale)
if err != nil {
logger.Warnf("error loading male performer images: %v", err)
}
initialiseCustomImages()
}
func initialiseCustomImages() {
customPath := config.GetInstance().GetCustomPerformerImageLocation()
if customPath != "" {
logger.Debugf("Loading custom performer images from %s", customPath)
// We need to set performerBoxCustom at runtime, as this is a custom path, and store it in a pointer.
var err error
performerBoxCustom, err = newImageBox(os.DirFS(customPath))
if err != nil {
logger.Warnf("error loading custom performer from %s: %v", customPath, err)
}
} else {
performerBoxCustom = nil
}
}
func getRandomPerformerImageUsingName(name string, gender *models.GenderEnum, customPath string) ([]byte, error) {
var box *imageBox
// If we have a custom path, we should return a new box in the given path.
if performerBoxCustom != nil && len(performerBoxCustom.files) > 0 {
box = performerBoxCustom
} }
var g models.GenderEnum index := hash.IntFromString(name) % uint64(len(files))
if gender != nil { img, err := box.box.Open(files[index])
g = *gender
}
if box == nil {
switch g {
case models.GenderEnumFemale, models.GenderEnumTransgenderFemale:
box = performerBox
case models.GenderEnumMale, models.GenderEnumTransgenderMale:
box = performerBoxMale
default:
box = performerBox
}
}
imageFiles := box.files
index := hash.IntFromString(name) % uint64(len(imageFiles))
img, err := box.box.Open(imageFiles[index])
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -121,3 +71,64 @@ func getRandomPerformerImageUsingName(name string, gender *models.GenderEnum, cu
return io.ReadAll(img) return io.ReadAll(img)
} }
var performerBox *imageBox
var performerBoxMale *imageBox
var performerBoxCustom *imageBox
func init() {
var err error
performerBox, err = newImageBox(static.Sub(static.Performer))
if err != nil {
panic(fmt.Sprintf("loading performer images: %v", err))
}
performerBoxMale, err = newImageBox(static.Sub(static.PerformerMale))
if err != nil {
panic(fmt.Sprintf("loading male performer images: %v", err))
}
}
func initCustomPerformerImages(customPath string) {
if customPath != "" {
logger.Debugf("Loading custom performer images from %s", customPath)
var err error
performerBoxCustom, err = newImageBox(os.DirFS(customPath))
if err != nil {
logger.Warnf("error loading custom performer images from %s: %v", customPath, err)
}
} else {
performerBoxCustom = nil
}
}
func getDefaultPerformerImage(name string, gender *models.GenderEnum) []byte {
// try the custom box first if we have one
if performerBoxCustom != nil {
ret, err := performerBoxCustom.GetRandomImageByName(name)
if err == nil {
return ret
}
logger.Warnf("error loading custom default performer image: %v", err)
}
var g models.GenderEnum
if gender != nil {
g = *gender
}
var box *imageBox
switch g {
case models.GenderEnumFemale, models.GenderEnumTransgenderFemale:
box = performerBox
case models.GenderEnumMale, models.GenderEnumTransgenderMale:
box = performerBoxMale
default:
box = performerBox
}
ret, err := box.GetRandomImageByName(name)
if err != nil {
logger.Warnf("error loading default performer image: %v", err)
}
return ret
}

View File

@@ -17,9 +17,7 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
) )
type contextKey struct{ name string } type contextKey struct{ name string }
@@ -49,8 +47,7 @@ type Loaders struct {
} }
type Middleware struct { type Middleware struct {
DatabaseProvider txn.DatabaseProvider Repository models.Repository
Repository manager.Repository
} }
func (m Middleware) Middleware(next http.Handler) http.Handler { func (m Middleware) Middleware(next http.Handler) http.Handler {
@@ -131,13 +128,9 @@ func toErrorSlice(err error) []error {
return nil return nil
} }
func (m Middleware) withTxn(ctx context.Context, fn func(ctx context.Context) error) error {
return txn.WithDatabase(ctx, m.DatabaseProvider, fn)
}
func (m Middleware) fetchScenes(ctx context.Context) func(keys []int) ([]*models.Scene, []error) { func (m Middleware) fetchScenes(ctx context.Context) func(keys []int) ([]*models.Scene, []error) {
return func(keys []int) (ret []*models.Scene, errs []error) { return func(keys []int) (ret []*models.Scene, errs []error) {
err := m.withTxn(ctx, func(ctx context.Context) error { err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
var err error var err error
ret, err = m.Repository.Scene.FindMany(ctx, keys) ret, err = m.Repository.Scene.FindMany(ctx, keys)
return err return err
@@ -148,7 +141,7 @@ func (m Middleware) fetchScenes(ctx context.Context) func(keys []int) ([]*models
func (m Middleware) fetchImages(ctx context.Context) func(keys []int) ([]*models.Image, []error) { func (m Middleware) fetchImages(ctx context.Context) func(keys []int) ([]*models.Image, []error) {
return func(keys []int) (ret []*models.Image, errs []error) { return func(keys []int) (ret []*models.Image, errs []error) {
err := m.withTxn(ctx, func(ctx context.Context) error { err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
var err error var err error
ret, err = m.Repository.Image.FindMany(ctx, keys) ret, err = m.Repository.Image.FindMany(ctx, keys)
return err return err
@@ -160,7 +153,7 @@ func (m Middleware) fetchImages(ctx context.Context) func(keys []int) ([]*models
func (m Middleware) fetchGalleries(ctx context.Context) func(keys []int) ([]*models.Gallery, []error) { func (m Middleware) fetchGalleries(ctx context.Context) func(keys []int) ([]*models.Gallery, []error) {
return func(keys []int) (ret []*models.Gallery, errs []error) { return func(keys []int) (ret []*models.Gallery, errs []error) {
err := m.withTxn(ctx, func(ctx context.Context) error { err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
var err error var err error
ret, err = m.Repository.Gallery.FindMany(ctx, keys) ret, err = m.Repository.Gallery.FindMany(ctx, keys)
return err return err
@@ -172,7 +165,7 @@ func (m Middleware) fetchGalleries(ctx context.Context) func(keys []int) ([]*mod
func (m Middleware) fetchPerformers(ctx context.Context) func(keys []int) ([]*models.Performer, []error) { func (m Middleware) fetchPerformers(ctx context.Context) func(keys []int) ([]*models.Performer, []error) {
return func(keys []int) (ret []*models.Performer, errs []error) { return func(keys []int) (ret []*models.Performer, errs []error) {
err := m.withTxn(ctx, func(ctx context.Context) error { err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
var err error var err error
ret, err = m.Repository.Performer.FindMany(ctx, keys) ret, err = m.Repository.Performer.FindMany(ctx, keys)
return err return err
@@ -184,7 +177,7 @@ func (m Middleware) fetchPerformers(ctx context.Context) func(keys []int) ([]*mo
func (m Middleware) fetchStudios(ctx context.Context) func(keys []int) ([]*models.Studio, []error) { func (m Middleware) fetchStudios(ctx context.Context) func(keys []int) ([]*models.Studio, []error) {
return func(keys []int) (ret []*models.Studio, errs []error) { return func(keys []int) (ret []*models.Studio, errs []error) {
err := m.withTxn(ctx, func(ctx context.Context) error { err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
var err error var err error
ret, err = m.Repository.Studio.FindMany(ctx, keys) ret, err = m.Repository.Studio.FindMany(ctx, keys)
return err return err
@@ -195,7 +188,7 @@ func (m Middleware) fetchStudios(ctx context.Context) func(keys []int) ([]*model
func (m Middleware) fetchTags(ctx context.Context) func(keys []int) ([]*models.Tag, []error) { func (m Middleware) fetchTags(ctx context.Context) func(keys []int) ([]*models.Tag, []error) {
return func(keys []int) (ret []*models.Tag, errs []error) { return func(keys []int) (ret []*models.Tag, errs []error) {
err := m.withTxn(ctx, func(ctx context.Context) error { err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
var err error var err error
ret, err = m.Repository.Tag.FindMany(ctx, keys) ret, err = m.Repository.Tag.FindMany(ctx, keys)
return err return err
@@ -206,7 +199,7 @@ func (m Middleware) fetchTags(ctx context.Context) func(keys []int) ([]*models.T
func (m Middleware) fetchMovies(ctx context.Context) func(keys []int) ([]*models.Movie, []error) { func (m Middleware) fetchMovies(ctx context.Context) func(keys []int) ([]*models.Movie, []error) {
return func(keys []int) (ret []*models.Movie, errs []error) { return func(keys []int) (ret []*models.Movie, errs []error) {
err := m.withTxn(ctx, func(ctx context.Context) error { err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
var err error var err error
ret, err = m.Repository.Movie.FindMany(ctx, keys) ret, err = m.Repository.Movie.FindMany(ctx, keys)
return err return err
@@ -217,7 +210,7 @@ func (m Middleware) fetchMovies(ctx context.Context) func(keys []int) ([]*models
func (m Middleware) fetchFiles(ctx context.Context) func(keys []models.FileID) ([]models.File, []error) { func (m Middleware) fetchFiles(ctx context.Context) func(keys []models.FileID) ([]models.File, []error) {
return func(keys []models.FileID) (ret []models.File, errs []error) { return func(keys []models.FileID) (ret []models.File, errs []error) {
err := m.withTxn(ctx, func(ctx context.Context) error { err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
var err error var err error
ret, err = m.Repository.File.Find(ctx, keys...) ret, err = m.Repository.File.Find(ctx, keys...)
return err return err
@@ -228,7 +221,7 @@ func (m Middleware) fetchFiles(ctx context.Context) func(keys []models.FileID) (
func (m Middleware) fetchScenesFileIDs(ctx context.Context) func(keys []int) ([][]models.FileID, []error) { func (m Middleware) fetchScenesFileIDs(ctx context.Context) func(keys []int) ([][]models.FileID, []error) {
return func(keys []int) (ret [][]models.FileID, errs []error) { return func(keys []int) (ret [][]models.FileID, errs []error) {
err := m.withTxn(ctx, func(ctx context.Context) error { err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
var err error var err error
ret, err = m.Repository.Scene.GetManyFileIDs(ctx, keys) ret, err = m.Repository.Scene.GetManyFileIDs(ctx, keys)
return err return err
@@ -239,7 +232,7 @@ func (m Middleware) fetchScenesFileIDs(ctx context.Context) func(keys []int) ([]
func (m Middleware) fetchImagesFileIDs(ctx context.Context) func(keys []int) ([][]models.FileID, []error) { func (m Middleware) fetchImagesFileIDs(ctx context.Context) func(keys []int) ([][]models.FileID, []error) {
return func(keys []int) (ret [][]models.FileID, errs []error) { return func(keys []int) (ret [][]models.FileID, errs []error) {
err := m.withTxn(ctx, func(ctx context.Context) error { err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
var err error var err error
ret, err = m.Repository.Image.GetManyFileIDs(ctx, keys) ret, err = m.Repository.Image.GetManyFileIDs(ctx, keys)
return err return err
@@ -250,7 +243,7 @@ func (m Middleware) fetchImagesFileIDs(ctx context.Context) func(keys []int) ([]
func (m Middleware) fetchGalleriesFileIDs(ctx context.Context) func(keys []int) ([][]models.FileID, []error) { func (m Middleware) fetchGalleriesFileIDs(ctx context.Context) func(keys []int) ([][]models.FileID, []error) {
return func(keys []int) (ret [][]models.FileID, errs []error) { return func(keys []int) (ret [][]models.FileID, errs []error) {
err := m.withTxn(ctx, func(ctx context.Context) error { err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
var err error var err error
ret, err = m.Repository.Gallery.GetManyFileIDs(ctx, keys) ret, err = m.Repository.Gallery.GetManyFileIDs(ctx, keys)
return err return err

View File

@@ -13,7 +13,7 @@ import (
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin" "github.com/stashapp/stash/pkg/plugin"
"github.com/stashapp/stash/pkg/scraper" "github.com/stashapp/stash/pkg/scraper"
"github.com/stashapp/stash/pkg/txn" "github.com/stashapp/stash/pkg/scraper/stashbox"
) )
var ( var (
@@ -33,8 +33,7 @@ type hookExecutor interface {
} }
type Resolver struct { type Resolver struct {
txnManager txn.Manager repository models.Repository
repository manager.Repository
sceneService manager.SceneService sceneService manager.SceneService
imageService manager.ImageService imageService manager.ImageService
galleryService manager.GalleryService galleryService manager.GalleryService
@@ -102,11 +101,15 @@ type tagResolver struct{ *Resolver }
type savedFilterResolver struct{ *Resolver } type savedFilterResolver struct{ *Resolver }
func (r *Resolver) withTxn(ctx context.Context, fn func(ctx context.Context) error) error { func (r *Resolver) withTxn(ctx context.Context, fn func(ctx context.Context) error) error {
return txn.WithTxn(ctx, r.txnManager, fn) return r.repository.WithTxn(ctx, fn)
} }
func (r *Resolver) withReadTxn(ctx context.Context, fn func(ctx context.Context) error) error { func (r *Resolver) withReadTxn(ctx context.Context, fn func(ctx context.Context) error) error {
return txn.WithReadTxn(ctx, r.txnManager, fn) return r.repository.WithReadTxn(ctx, fn)
}
func (r *Resolver) stashboxRepository() stashbox.Repository {
return stashbox.NewRepository(r.repository)
} }
func (r *queryResolver) MarkerWall(ctx context.Context, q *string) (ret []*models.SceneMarker, err error) { func (r *queryResolver) MarkerWall(ctx context.Context, q *string) (ret []*models.SceneMarker, err error) {

View File

@@ -316,7 +316,7 @@ func (r *mutationResolver) ConfigureGeneral(ctx context.Context, input ConfigGen
if input.CustomPerformerImageLocation != nil { if input.CustomPerformerImageLocation != nil {
c.Set(config.CustomPerformerImageLocation, *input.CustomPerformerImageLocation) c.Set(config.CustomPerformerImageLocation, *input.CustomPerformerImageLocation)
initialiseCustomImages() initCustomPerformerImages(*input.CustomPerformerImageLocation)
} }
if input.ScraperUserAgent != nil { if input.ScraperUserAgent != nil {

View File

@@ -17,7 +17,7 @@ func (r *mutationResolver) MoveFiles(ctx context.Context, input MoveFilesInput)
fileStore := r.repository.File fileStore := r.repository.File
folderStore := r.repository.Folder folderStore := r.repository.Folder
mover := file.NewMover(fileStore, folderStore) mover := file.NewMover(fileStore, folderStore)
mover.RegisterHooks(ctx, r.txnManager) mover.RegisterHooks(ctx)
var ( var (
folder *models.Folder folder *models.Folder

View File

@@ -11,30 +11,30 @@ import (
) )
func (r *mutationResolver) MigrateSceneScreenshots(ctx context.Context, input MigrateSceneScreenshotsInput) (string, error) { func (r *mutationResolver) MigrateSceneScreenshots(ctx context.Context, input MigrateSceneScreenshotsInput) (string, error) {
db := manager.GetInstance().Database mgr := manager.GetInstance()
t := &task.MigrateSceneScreenshotsJob{ t := &task.MigrateSceneScreenshotsJob{
ScreenshotsPath: manager.GetInstance().Paths.Generated.Screenshots, ScreenshotsPath: manager.GetInstance().Paths.Generated.Screenshots,
Input: scene.MigrateSceneScreenshotsInput{ Input: scene.MigrateSceneScreenshotsInput{
DeleteFiles: utils.IsTrue(input.DeleteFiles), DeleteFiles: utils.IsTrue(input.DeleteFiles),
OverwriteExisting: utils.IsTrue(input.OverwriteExisting), OverwriteExisting: utils.IsTrue(input.OverwriteExisting),
}, },
SceneRepo: db.Scene, SceneRepo: mgr.Repository.Scene,
TxnManager: db, TxnManager: mgr.Repository.TxnManager,
} }
jobID := manager.GetInstance().JobManager.Add(ctx, "Migrating scene screenshots to blobs...", t) jobID := mgr.JobManager.Add(ctx, "Migrating scene screenshots to blobs...", t)
return strconv.Itoa(jobID), nil return strconv.Itoa(jobID), nil
} }
func (r *mutationResolver) MigrateBlobs(ctx context.Context, input MigrateBlobsInput) (string, error) { func (r *mutationResolver) MigrateBlobs(ctx context.Context, input MigrateBlobsInput) (string, error) {
db := manager.GetInstance().Database mgr := manager.GetInstance()
t := &task.MigrateBlobsJob{ t := &task.MigrateBlobsJob{
TxnManager: db, TxnManager: mgr.Database,
BlobStore: db.Blobs, BlobStore: mgr.Database.Blobs,
Vacuumer: db, Vacuumer: mgr.Database,
DeleteOld: utils.IsTrue(input.DeleteOld), DeleteOld: utils.IsTrue(input.DeleteOld),
} }
jobID := manager.GetInstance().JobManager.Add(ctx, "Migrating blobs...", t) jobID := mgr.JobManager.Add(ctx, "Migrating blobs...", t)
return strconv.Itoa(jobID), nil return strconv.Itoa(jobID), nil
} }

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"strconv" "strconv"
"github.com/stashapp/stash/internal/static"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin" "github.com/stashapp/stash/pkg/plugin"
"github.com/stashapp/stash/pkg/sliceutil/stringslice" "github.com/stashapp/stash/pkg/sliceutil/stringslice"
@@ -50,12 +51,6 @@ func (r *mutationResolver) MovieCreate(ctx context.Context, input MovieCreateInp
return nil, fmt.Errorf("converting studio id: %w", err) return nil, fmt.Errorf("converting studio id: %w", err)
} }
// HACK: if back image is being set, set the front image to the default.
// This is because we can't have a null front image with a non-null back image.
if input.FrontImage == nil && input.BackImage != nil {
input.FrontImage = &models.DefaultMovieImage
}
// Process the base 64 encoded image string // Process the base 64 encoded image string
var frontimageData []byte var frontimageData []byte
if input.FrontImage != nil { if input.FrontImage != nil {
@@ -74,6 +69,12 @@ func (r *mutationResolver) MovieCreate(ctx context.Context, input MovieCreateInp
} }
} }
// HACK: if back image is being set, set the front image to the default.
// This is because we can't have a null front image with a non-null back image.
if len(frontimageData) == 0 && len(backimageData) != 0 {
frontimageData = static.ReadAll(static.DefaultMovieImage)
}
// Start the transaction and save the movie // Start the transaction and save the movie
if err := r.withTxn(ctx, func(ctx context.Context) error { if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Movie qb := r.repository.Movie

View File

@@ -11,15 +11,6 @@ import (
"github.com/stashapp/stash/pkg/scraper/stashbox" "github.com/stashapp/stash/pkg/scraper/stashbox"
) )
func (r *Resolver) stashboxRepository() stashbox.Repository {
return stashbox.Repository{
Scene: r.repository.Scene,
Performer: r.repository.Performer,
Tag: r.repository.Tag,
Studio: r.repository.Studio,
}
}
func (r *mutationResolver) SubmitStashBoxFingerprints(ctx context.Context, input StashBoxFingerprintSubmissionInput) (bool, error) { func (r *mutationResolver) SubmitStashBoxFingerprints(ctx context.Context, input StashBoxFingerprintSubmissionInput) (bool, error) {
boxes := config.GetInstance().GetStashBoxes() boxes := config.GetInstance().GetStashBoxes()
@@ -27,7 +18,7 @@ func (r *mutationResolver) SubmitStashBoxFingerprints(ctx context.Context, input
return false, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex) return false, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex)
} }
client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager, r.stashboxRepository()) client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.stashboxRepository())
return client.SubmitStashBoxFingerprints(ctx, input.SceneIds, boxes[input.StashBoxIndex].Endpoint) return client.SubmitStashBoxFingerprints(ctx, input.SceneIds, boxes[input.StashBoxIndex].Endpoint)
} }
@@ -49,7 +40,7 @@ func (r *mutationResolver) SubmitStashBoxSceneDraft(ctx context.Context, input S
return nil, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex) return nil, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex)
} }
client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager, r.stashboxRepository()) client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.stashboxRepository())
id, err := strconv.Atoi(input.ID) id, err := strconv.Atoi(input.ID)
if err != nil { if err != nil {
@@ -91,7 +82,7 @@ func (r *mutationResolver) SubmitStashBoxPerformerDraft(ctx context.Context, inp
return nil, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex) return nil, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex)
} }
client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager, r.stashboxRepository()) client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.stashboxRepository())
id, err := strconv.Atoi(input.ID) id, err := strconv.Atoi(input.ID)
if err != nil { if err != nil {

View File

@@ -5,7 +5,6 @@ import (
"errors" "errors"
"testing" "testing"
"github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/mocks" "github.com/stashapp/stash/pkg/models/mocks"
"github.com/stashapp/stash/pkg/plugin" "github.com/stashapp/stash/pkg/plugin"
@@ -15,14 +14,9 @@ import (
) )
// TODO - move this into a common area // TODO - move this into a common area
func newResolver() *Resolver { func newResolver(db *mocks.Database) *Resolver {
txnMgr := &mocks.TxnManager{}
return &Resolver{ return &Resolver{
txnManager: txnMgr, repository: db.Repository(),
repository: manager.Repository{
TxnManager: txnMgr,
Tag: &mocks.TagReaderWriter{},
},
hookExecutor: &mockHookExecutor{}, hookExecutor: &mockHookExecutor{},
} }
} }
@@ -45,9 +39,8 @@ func (*mockHookExecutor) ExecutePostHooks(ctx context.Context, id int, hookType
} }
func TestTagCreate(t *testing.T) { func TestTagCreate(t *testing.T) {
r := newResolver() db := mocks.NewDatabase()
r := newResolver(db)
tagRW := r.repository.Tag.(*mocks.TagReaderWriter)
pp := 1 pp := 1
findFilter := &models.FindFilterType{ findFilter := &models.FindFilterType{
@@ -72,17 +65,17 @@ func TestTagCreate(t *testing.T) {
} }
} }
tagRW.On("Query", mock.Anything, tagFilterForName(existingTagName), findFilter).Return([]*models.Tag{ db.Tag.On("Query", mock.Anything, tagFilterForName(existingTagName), findFilter).Return([]*models.Tag{
{ {
ID: existingTagID, ID: existingTagID,
Name: existingTagName, Name: existingTagName,
}, },
}, 1, nil).Once() }, 1, nil).Once()
tagRW.On("Query", mock.Anything, tagFilterForName(errTagName), findFilter).Return(nil, 0, nil).Once() db.Tag.On("Query", mock.Anything, tagFilterForName(errTagName), findFilter).Return(nil, 0, nil).Once()
tagRW.On("Query", mock.Anything, tagFilterForAlias(errTagName), findFilter).Return(nil, 0, nil).Once() db.Tag.On("Query", mock.Anything, tagFilterForAlias(errTagName), findFilter).Return(nil, 0, nil).Once()
expectedErr := errors.New("TagCreate error") expectedErr := errors.New("TagCreate error")
tagRW.On("Create", mock.Anything, mock.AnythingOfType("*models.Tag")).Return(expectedErr) db.Tag.On("Create", mock.Anything, mock.AnythingOfType("*models.Tag")).Return(expectedErr)
// fails here because testCtx is empty // fails here because testCtx is empty
// TODO: Fix this // TODO: Fix this
@@ -101,22 +94,22 @@ func TestTagCreate(t *testing.T) {
}) })
assert.Equal(t, expectedErr, err) assert.Equal(t, expectedErr, err)
tagRW.AssertExpectations(t) db.AssertExpectations(t)
r = newResolver() db = mocks.NewDatabase()
tagRW = r.repository.Tag.(*mocks.TagReaderWriter) r = newResolver(db)
tagRW.On("Query", mock.Anything, tagFilterForName(tagName), findFilter).Return(nil, 0, nil).Once() db.Tag.On("Query", mock.Anything, tagFilterForName(tagName), findFilter).Return(nil, 0, nil).Once()
tagRW.On("Query", mock.Anything, tagFilterForAlias(tagName), findFilter).Return(nil, 0, nil).Once() db.Tag.On("Query", mock.Anything, tagFilterForAlias(tagName), findFilter).Return(nil, 0, nil).Once()
newTag := &models.Tag{ newTag := &models.Tag{
ID: newTagID, ID: newTagID,
Name: tagName, Name: tagName,
} }
tagRW.On("Create", mock.Anything, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) { db.Tag.On("Create", mock.Anything, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
arg := args.Get(1).(*models.Tag) arg := args.Get(1).(*models.Tag)
arg.ID = newTagID arg.ID = newTagID
}).Return(nil) }).Return(nil)
tagRW.On("Find", mock.Anything, newTagID).Return(newTag, nil) db.Tag.On("Find", mock.Anything, newTagID).Return(newTag, nil)
tag, err := r.Mutation().TagCreate(testCtx, TagCreateInput{ tag, err := r.Mutation().TagCreate(testCtx, TagCreateInput{
Name: tagName, Name: tagName,
@@ -124,4 +117,5 @@ func TestTagCreate(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.NotNil(t, tag) assert.NotNil(t, tag)
db.AssertExpectations(t)
} }

View File

@@ -243,7 +243,9 @@ func makeConfigUIResult() map[string]interface{} {
} }
func (r *queryResolver) ValidateStashBoxCredentials(ctx context.Context, input config.StashBoxInput) (*StashBoxValidationResult, error) { func (r *queryResolver) ValidateStashBoxCredentials(ctx context.Context, input config.StashBoxInput) (*StashBoxValidationResult, error) {
client := stashbox.NewClient(models.StashBox{Endpoint: input.Endpoint, APIKey: input.APIKey}, r.txnManager, r.stashboxRepository()) box := models.StashBox{Endpoint: input.Endpoint, APIKey: input.APIKey}
client := stashbox.NewClient(box, r.stashboxRepository())
user, err := client.GetUser(ctx) user, err := client.GetUser(ctx)
valid := user != nil && user.Me != nil valid := user != nil && user.Me != nil

View File

@@ -191,16 +191,11 @@ func (r *queryResolver) FindScenesByPathRegex(ctx context.Context, filter *model
} }
func (r *queryResolver) ParseSceneFilenames(ctx context.Context, filter *models.FindFilterType, config models.SceneParserInput) (ret *SceneParserResultType, err error) { func (r *queryResolver) ParseSceneFilenames(ctx context.Context, filter *models.FindFilterType, config models.SceneParserInput) (ret *SceneParserResultType, err error) {
parser := scene.NewFilenameParser(filter, config) repo := scene.NewFilenameParserRepository(r.repository)
parser := scene.NewFilenameParser(filter, config, repo)
if err := r.withReadTxn(ctx, func(ctx context.Context) error { if err := r.withReadTxn(ctx, func(ctx context.Context) error {
result, count, err := parser.Parse(ctx, scene.FilenameParserRepository{ result, count, err := parser.Parse(ctx)
Scene: r.repository.Scene,
Performer: r.repository.Performer,
Studio: r.repository.Studio,
Movie: r.repository.Movie,
Tag: r.repository.Tag,
})
if err != nil { if err != nil {
return err return err

View File

@@ -238,7 +238,7 @@ func (r *queryResolver) getStashBoxClient(index int) (*stashbox.Client, error) {
return nil, fmt.Errorf("%w: invalid stash_box_index %d", ErrInput, index) return nil, fmt.Errorf("%w: invalid stash_box_index %d", ErrInput, index)
} }
return stashbox.NewClient(*boxes[index], r.txnManager, r.stashboxRepository()), nil return stashbox.NewClient(*boxes[index], r.stashboxRepository()), nil
} }
func (r *queryResolver) ScrapeSingleScene(ctx context.Context, source scraper.Source, input ScrapeSingleSceneInput) ([]*scraper.ScrapedScene, error) { func (r *queryResolver) ScrapeSingleScene(ctx context.Context, source scraper.Source, input ScrapeSingleSceneInput) ([]*scraper.ScrapedScene, error) {

15
internal/api/routes.go Normal file
View File

@@ -0,0 +1,15 @@
package api
import (
"net/http"
"github.com/stashapp/stash/pkg/txn"
)
type routes struct {
txnManager txn.Manager
}
func (rs routes) withReadTxn(r *http.Request, fn txn.TxnFunc) error {
return txn.WithReadTxn(r.Context(), rs.txnManager, fn)
}

View File

@@ -12,6 +12,10 @@ type customRoutes struct {
servedFolders config.URLMap servedFolders config.URLMap
} }
func getCustomRoutes(servedFolders config.URLMap) chi.Router {
return customRoutes{servedFolders: servedFolders}.Routes()
}
func (rs customRoutes) Routes() chi.Router { func (rs customRoutes) Routes() chi.Router {
r := chi.NewRouter() r := chi.NewRouter()

View File

@@ -10,6 +10,10 @@ import (
type downloadsRoutes struct{} type downloadsRoutes struct{}
func getDownloadsRoutes() chi.Router {
return downloadsRoutes{}.Routes()
}
func (rs downloadsRoutes) Routes() chi.Router { func (rs downloadsRoutes) Routes() chi.Router {
r := chi.NewRouter() r := chi.NewRouter()

View File

@@ -3,7 +3,6 @@ package api
import ( import (
"context" "context"
"errors" "errors"
"io"
"io/fs" "io/fs"
"net/http" "net/http"
"os/exec" "os/exec"
@@ -17,7 +16,6 @@ import (
"github.com/stashapp/stash/pkg/image" "github.com/stashapp/stash/pkg/image"
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
@@ -27,11 +25,19 @@ type ImageFinder interface {
} }
type imageRoutes struct { type imageRoutes struct {
txnManager txn.Manager routes
imageFinder ImageFinder imageFinder ImageFinder
fileGetter models.FileGetter fileGetter models.FileGetter
} }
func getImageRoutes(repo models.Repository) chi.Router {
return imageRoutes{
routes: routes{txnManager: repo.TxnManager},
imageFinder: repo.Image,
fileGetter: repo.File,
}.Routes()
}
func (rs imageRoutes) Routes() chi.Router { func (rs imageRoutes) Routes() chi.Router {
r := chi.NewRouter() r := chi.NewRouter()
@@ -46,8 +52,6 @@ func (rs imageRoutes) Routes() chi.Router {
return r return r
} }
// region Handlers
func (rs imageRoutes) Thumbnail(w http.ResponseWriter, r *http.Request) { func (rs imageRoutes) Thumbnail(w http.ResponseWriter, r *http.Request) {
img := r.Context().Value(imageKey).(*models.Image) img := r.Context().Value(imageKey).(*models.Image)
filepath := manager.GetInstance().Paths.Generated.GetThumbnailPath(img.Checksum, models.DefaultGthumbWidth) filepath := manager.GetInstance().Paths.Generated.GetThumbnailPath(img.Checksum, models.DefaultGthumbWidth)
@@ -119,8 +123,6 @@ func (rs imageRoutes) Image(w http.ResponseWriter, r *http.Request) {
} }
func (rs imageRoutes) serveImage(w http.ResponseWriter, r *http.Request, i *models.Image, useDefault bool) { func (rs imageRoutes) serveImage(w http.ResponseWriter, r *http.Request, i *models.Image, useDefault bool) {
const defaultImageImage = "image/image.svg"
if i.Files.Primary() != nil { if i.Files.Primary() != nil {
err := i.Files.Primary().Base().Serve(&file.OsFS{}, w, r) err := i.Files.Primary().Base().Serve(&file.OsFS{}, w, r)
if err == nil { if err == nil {
@@ -141,22 +143,18 @@ func (rs imageRoutes) serveImage(w http.ResponseWriter, r *http.Request, i *mode
return return
} }
// fall back to static image // fallback to default image
f, _ := static.Image.Open(defaultImageImage) image := static.ReadAll(static.DefaultImageImage)
defer f.Close()
image, _ := io.ReadAll(f)
utils.ServeImage(w, r, image) utils.ServeImage(w, r, image)
} }
// endregion
func (rs imageRoutes) ImageCtx(next http.Handler) http.Handler { func (rs imageRoutes) ImageCtx(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
imageIdentifierQueryParam := chi.URLParam(r, "imageId") imageIdentifierQueryParam := chi.URLParam(r, "imageId")
imageID, _ := strconv.Atoi(imageIdentifierQueryParam) imageID, _ := strconv.Atoi(imageIdentifierQueryParam)
var image *models.Image var image *models.Image
_ = txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { _ = rs.withReadTxn(r, func(ctx context.Context) error {
qb := rs.imageFinder qb := rs.imageFinder
if imageID == 0 { if imageID == 0 {
images, _ := qb.FindByChecksum(ctx, imageIdentifierQueryParam) images, _ := qb.FindByChecksum(ctx, imageIdentifierQueryParam)

View File

@@ -7,9 +7,10 @@ import (
"strconv" "strconv"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/stashapp/stash/internal/static"
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
@@ -20,10 +21,17 @@ type MovieFinder interface {
} }
type movieRoutes struct { type movieRoutes struct {
txnManager txn.Manager routes
movieFinder MovieFinder movieFinder MovieFinder
} }
func getMovieRoutes(repo models.Repository) chi.Router {
return movieRoutes{
routes: routes{txnManager: repo.TxnManager},
movieFinder: repo.Movie,
}.Routes()
}
func (rs movieRoutes) Routes() chi.Router { func (rs movieRoutes) Routes() chi.Router {
r := chi.NewRouter() r := chi.NewRouter()
@@ -41,7 +49,7 @@ func (rs movieRoutes) FrontImage(w http.ResponseWriter, r *http.Request) {
defaultParam := r.URL.Query().Get("default") defaultParam := r.URL.Query().Get("default")
var image []byte var image []byte
if defaultParam != "true" { if defaultParam != "true" {
readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error {
var err error var err error
image, err = rs.movieFinder.GetFrontImage(ctx, movie.ID) image, err = rs.movieFinder.GetFrontImage(ctx, movie.ID)
return err return err
@@ -54,8 +62,9 @@ func (rs movieRoutes) FrontImage(w http.ResponseWriter, r *http.Request) {
} }
} }
// fallback to default image
if len(image) == 0 { if len(image) == 0 {
image, _ = utils.ProcessBase64Image(models.DefaultMovieImage) image = static.ReadAll(static.DefaultMovieImage)
} }
utils.ServeImage(w, r, image) utils.ServeImage(w, r, image)
@@ -66,7 +75,7 @@ func (rs movieRoutes) BackImage(w http.ResponseWriter, r *http.Request) {
defaultParam := r.URL.Query().Get("default") defaultParam := r.URL.Query().Get("default")
var image []byte var image []byte
if defaultParam != "true" { if defaultParam != "true" {
readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error {
var err error var err error
image, err = rs.movieFinder.GetBackImage(ctx, movie.ID) image, err = rs.movieFinder.GetBackImage(ctx, movie.ID)
return err return err
@@ -79,8 +88,9 @@ func (rs movieRoutes) BackImage(w http.ResponseWriter, r *http.Request) {
} }
} }
// fallback to default image
if len(image) == 0 { if len(image) == 0 {
image, _ = utils.ProcessBase64Image(models.DefaultMovieImage) image = static.ReadAll(static.DefaultMovieImage)
} }
utils.ServeImage(w, r, image) utils.ServeImage(w, r, image)
@@ -95,7 +105,7 @@ func (rs movieRoutes) MovieCtx(next http.Handler) http.Handler {
} }
var movie *models.Movie var movie *models.Movie
_ = txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { _ = rs.withReadTxn(r, func(ctx context.Context) error {
movie, _ = rs.movieFinder.Find(ctx, movieID) movie, _ = rs.movieFinder.Find(ctx, movieID)
return nil return nil
}) })

View File

@@ -7,10 +7,8 @@ import (
"strconv" "strconv"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/stashapp/stash/internal/manager/config"
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
@@ -20,10 +18,17 @@ type PerformerFinder interface {
} }
type performerRoutes struct { type performerRoutes struct {
txnManager txn.Manager routes
performerFinder PerformerFinder performerFinder PerformerFinder
} }
func getPerformerRoutes(repo models.Repository) chi.Router {
return performerRoutes{
routes: routes{txnManager: repo.TxnManager},
performerFinder: repo.Performer,
}.Routes()
}
func (rs performerRoutes) Routes() chi.Router { func (rs performerRoutes) Routes() chi.Router {
r := chi.NewRouter() r := chi.NewRouter()
@@ -41,7 +46,7 @@ func (rs performerRoutes) Image(w http.ResponseWriter, r *http.Request) {
var image []byte var image []byte
if defaultParam != "true" { if defaultParam != "true" {
readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error {
var err error var err error
image, err = rs.performerFinder.GetImage(ctx, performer.ID) image, err = rs.performerFinder.GetImage(ctx, performer.ID)
return err return err
@@ -55,7 +60,7 @@ func (rs performerRoutes) Image(w http.ResponseWriter, r *http.Request) {
} }
if len(image) == 0 { if len(image) == 0 {
image, _ = getRandomPerformerImageUsingName(performer.Name, performer.Gender, config.GetInstance().GetCustomPerformerImageLocation()) image = getDefaultPerformerImage(performer.Name, performer.Gender)
} }
utils.ServeImage(w, r, image) utils.ServeImage(w, r, image)
@@ -70,7 +75,7 @@ func (rs performerRoutes) PerformerCtx(next http.Handler) http.Handler {
} }
var performer *models.Performer var performer *models.Performer
_ = txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { _ = rs.withReadTxn(r, func(ctx context.Context) error {
var err error var err error
performer, err = rs.performerFinder.Find(ctx, performerID) performer, err = rs.performerFinder.Find(ctx, performerID)
return err return err

View File

@@ -16,7 +16,6 @@ import (
"github.com/stashapp/stash/pkg/fsutil" "github.com/stashapp/stash/pkg/fsutil"
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
@@ -43,7 +42,7 @@ type CaptionFinder interface {
} }
type sceneRoutes struct { type sceneRoutes struct {
txnManager txn.Manager routes
sceneFinder SceneFinder sceneFinder SceneFinder
fileGetter models.FileGetter fileGetter models.FileGetter
captionFinder CaptionFinder captionFinder CaptionFinder
@@ -51,6 +50,17 @@ type sceneRoutes struct {
tagFinder SceneMarkerTagFinder tagFinder SceneMarkerTagFinder
} }
func getSceneRoutes(repo models.Repository) chi.Router {
return sceneRoutes{
routes: routes{txnManager: repo.TxnManager},
sceneFinder: repo.Scene,
fileGetter: repo.File,
captionFinder: repo.File,
sceneMarkerFinder: repo.SceneMarker,
tagFinder: repo.Tag,
}.Routes()
}
func (rs sceneRoutes) Routes() chi.Router { func (rs sceneRoutes) Routes() chi.Router {
r := chi.NewRouter() r := chi.NewRouter()
@@ -89,8 +99,6 @@ func (rs sceneRoutes) Routes() chi.Router {
return r return r
} }
// region Handlers
func (rs sceneRoutes) StreamDirect(w http.ResponseWriter, r *http.Request) { func (rs sceneRoutes) StreamDirect(w http.ResponseWriter, r *http.Request) {
scene := r.Context().Value(sceneKey).(*models.Scene) scene := r.Context().Value(sceneKey).(*models.Scene)
ss := manager.SceneServer{ ss := manager.SceneServer{
@@ -270,13 +278,13 @@ func (rs sceneRoutes) Webp(w http.ResponseWriter, r *http.Request) {
utils.ServeStaticFile(w, r, filepath) utils.ServeStaticFile(w, r, filepath)
} }
func (rs sceneRoutes) getChapterVttTitle(ctx context.Context, marker *models.SceneMarker) (*string, error) { func (rs sceneRoutes) getChapterVttTitle(r *http.Request, marker *models.SceneMarker) (*string, error) {
if marker.Title != "" { if marker.Title != "" {
return &marker.Title, nil return &marker.Title, nil
} }
var title string var title string
if err := txn.WithReadTxn(ctx, rs.txnManager, func(ctx context.Context) error { if err := rs.withReadTxn(r, func(ctx context.Context) error {
qb := rs.tagFinder qb := rs.tagFinder
primaryTag, err := qb.Find(ctx, marker.PrimaryTagID) primaryTag, err := qb.Find(ctx, marker.PrimaryTagID)
if err != nil { if err != nil {
@@ -305,7 +313,7 @@ func (rs sceneRoutes) getChapterVttTitle(ctx context.Context, marker *models.Sce
func (rs sceneRoutes) VttChapter(w http.ResponseWriter, r *http.Request) { func (rs sceneRoutes) VttChapter(w http.ResponseWriter, r *http.Request) {
scene := r.Context().Value(sceneKey).(*models.Scene) scene := r.Context().Value(sceneKey).(*models.Scene)
var sceneMarkers []*models.SceneMarker var sceneMarkers []*models.SceneMarker
readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error {
var err error var err error
sceneMarkers, err = rs.sceneMarkerFinder.FindBySceneID(ctx, scene.ID) sceneMarkers, err = rs.sceneMarkerFinder.FindBySceneID(ctx, scene.ID)
return err return err
@@ -325,7 +333,7 @@ func (rs sceneRoutes) VttChapter(w http.ResponseWriter, r *http.Request) {
time := utils.GetVTTTime(marker.Seconds) time := utils.GetVTTTime(marker.Seconds)
vttLines = append(vttLines, time+" --> "+time) vttLines = append(vttLines, time+" --> "+time)
vttTitle, err := rs.getChapterVttTitle(r.Context(), marker) vttTitle, err := rs.getChapterVttTitle(r, marker)
if errors.Is(err, context.Canceled) { if errors.Is(err, context.Canceled) {
return return
} }
@@ -404,7 +412,7 @@ func (rs sceneRoutes) Caption(w http.ResponseWriter, r *http.Request, lang strin
s := r.Context().Value(sceneKey).(*models.Scene) s := r.Context().Value(sceneKey).(*models.Scene)
var captions []*models.VideoCaption var captions []*models.VideoCaption
readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error {
var err error var err error
primaryFile := s.Files.Primary() primaryFile := s.Files.Primary()
if primaryFile == nil { if primaryFile == nil {
@@ -466,7 +474,7 @@ func (rs sceneRoutes) SceneMarkerStream(w http.ResponseWriter, r *http.Request)
sceneHash := scene.GetHash(config.GetInstance().GetVideoFileNamingAlgorithm()) sceneHash := scene.GetHash(config.GetInstance().GetVideoFileNamingAlgorithm())
sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId")) sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId"))
var sceneMarker *models.SceneMarker var sceneMarker *models.SceneMarker
readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error {
var err error var err error
sceneMarker, err = rs.sceneMarkerFinder.Find(ctx, sceneMarkerID) sceneMarker, err = rs.sceneMarkerFinder.Find(ctx, sceneMarkerID)
return err return err
@@ -494,7 +502,7 @@ func (rs sceneRoutes) SceneMarkerPreview(w http.ResponseWriter, r *http.Request)
sceneHash := scene.GetHash(config.GetInstance().GetVideoFileNamingAlgorithm()) sceneHash := scene.GetHash(config.GetInstance().GetVideoFileNamingAlgorithm())
sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId")) sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId"))
var sceneMarker *models.SceneMarker var sceneMarker *models.SceneMarker
readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error {
var err error var err error
sceneMarker, err = rs.sceneMarkerFinder.Find(ctx, sceneMarkerID) sceneMarker, err = rs.sceneMarkerFinder.Find(ctx, sceneMarkerID)
return err return err
@@ -530,7 +538,7 @@ func (rs sceneRoutes) SceneMarkerScreenshot(w http.ResponseWriter, r *http.Reque
sceneHash := scene.GetHash(config.GetInstance().GetVideoFileNamingAlgorithm()) sceneHash := scene.GetHash(config.GetInstance().GetVideoFileNamingAlgorithm())
sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId")) sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId"))
var sceneMarker *models.SceneMarker var sceneMarker *models.SceneMarker
readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error {
var err error var err error
sceneMarker, err = rs.sceneMarkerFinder.Find(ctx, sceneMarkerID) sceneMarker, err = rs.sceneMarkerFinder.Find(ctx, sceneMarkerID)
return err return err
@@ -561,8 +569,6 @@ func (rs sceneRoutes) SceneMarkerScreenshot(w http.ResponseWriter, r *http.Reque
} }
} }
// endregion
func (rs sceneRoutes) SceneCtx(next http.Handler) http.Handler { func (rs sceneRoutes) SceneCtx(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sceneID, err := strconv.Atoi(chi.URLParam(r, "sceneId")) sceneID, err := strconv.Atoi(chi.URLParam(r, "sceneId"))
@@ -572,7 +578,7 @@ func (rs sceneRoutes) SceneCtx(next http.Handler) http.Handler {
} }
var scene *models.Scene var scene *models.Scene
_ = txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { _ = rs.withReadTxn(r, func(ctx context.Context) error {
qb := rs.sceneFinder qb := rs.sceneFinder
scene, _ = qb.Find(ctx, sceneID) scene, _ = qb.Find(ctx, sceneID)

View File

@@ -3,7 +3,6 @@ package api
import ( import (
"context" "context"
"errors" "errors"
"io"
"net/http" "net/http"
"strconv" "strconv"
@@ -11,7 +10,6 @@ import (
"github.com/stashapp/stash/internal/static" "github.com/stashapp/stash/internal/static"
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
@@ -21,10 +19,17 @@ type StudioFinder interface {
} }
type studioRoutes struct { type studioRoutes struct {
txnManager txn.Manager routes
studioFinder StudioFinder studioFinder StudioFinder
} }
func getStudioRoutes(repo models.Repository) chi.Router {
return studioRoutes{
routes: routes{txnManager: repo.TxnManager},
studioFinder: repo.Studio,
}.Routes()
}
func (rs studioRoutes) Routes() chi.Router { func (rs studioRoutes) Routes() chi.Router {
r := chi.NewRouter() r := chi.NewRouter()
@@ -42,7 +47,7 @@ func (rs studioRoutes) Image(w http.ResponseWriter, r *http.Request) {
var image []byte var image []byte
if defaultParam != "true" { if defaultParam != "true" {
readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error {
var err error var err error
image, err = rs.studioFinder.GetImage(ctx, studio.ID) image, err = rs.studioFinder.GetImage(ctx, studio.ID)
return err return err
@@ -55,15 +60,9 @@ func (rs studioRoutes) Image(w http.ResponseWriter, r *http.Request) {
} }
} }
// fallback to default image
if len(image) == 0 { if len(image) == 0 {
const defaultStudioImage = "studio/studio.svg" image = static.ReadAll(static.DefaultStudioImage)
// fall back to static image
f, _ := static.Studio.Open(defaultStudioImage)
defer f.Close()
stat, _ := f.Stat()
http.ServeContent(w, r, "studio.svg", stat.ModTime(), f.(io.ReadSeeker))
return
} }
utils.ServeImage(w, r, image) utils.ServeImage(w, r, image)
@@ -78,7 +77,7 @@ func (rs studioRoutes) StudioCtx(next http.Handler) http.Handler {
} }
var studio *models.Studio var studio *models.Studio
_ = txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { _ = rs.withReadTxn(r, func(ctx context.Context) error {
var err error var err error
studio, err = rs.studioFinder.Find(ctx, studioID) studio, err = rs.studioFinder.Find(ctx, studioID)
return err return err

View File

@@ -3,7 +3,6 @@ package api
import ( import (
"context" "context"
"errors" "errors"
"io"
"net/http" "net/http"
"strconv" "strconv"
@@ -11,7 +10,6 @@ import (
"github.com/stashapp/stash/internal/static" "github.com/stashapp/stash/internal/static"
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
@@ -21,8 +19,15 @@ type TagFinder interface {
} }
type tagRoutes struct { type tagRoutes struct {
txnManager txn.Manager routes
tagFinder TagFinder tagFinder TagFinder
}
func getTagRoutes(repo models.Repository) chi.Router {
return tagRoutes{
routes: routes{txnManager: repo.TxnManager},
tagFinder: repo.Tag,
}.Routes()
} }
func (rs tagRoutes) Routes() chi.Router { func (rs tagRoutes) Routes() chi.Router {
@@ -42,7 +47,7 @@ func (rs tagRoutes) Image(w http.ResponseWriter, r *http.Request) {
var image []byte var image []byte
if defaultParam != "true" { if defaultParam != "true" {
readTxnErr := txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { readTxnErr := rs.withReadTxn(r, func(ctx context.Context) error {
var err error var err error
image, err = rs.tagFinder.GetImage(ctx, tag.ID) image, err = rs.tagFinder.GetImage(ctx, tag.ID)
return err return err
@@ -55,15 +60,9 @@ func (rs tagRoutes) Image(w http.ResponseWriter, r *http.Request) {
} }
} }
// fallback to default image
if len(image) == 0 { if len(image) == 0 {
const defaultTagImage = "tag/tag.svg" image = static.ReadAll(static.DefaultTagImage)
// fall back to static image
f, _ := static.Tag.Open(defaultTagImage)
defer f.Close()
stat, _ := f.Stat()
http.ServeContent(w, r, "tag.svg", stat.ModTime(), f.(io.ReadSeeker))
return
} }
utils.ServeImage(w, r, image) utils.ServeImage(w, r, image)
@@ -78,7 +77,7 @@ func (rs tagRoutes) TagCtx(next http.Handler) http.Handler {
} }
var tag *models.Tag var tag *models.Tag
_ = txn.WithReadTxn(r.Context(), rs.txnManager, func(ctx context.Context) error { _ = rs.withReadTxn(r, func(ctx context.Context) error {
var err error var err error
tag, err = rs.tagFinder.Find(ctx, tagID) tag, err = rs.tagFinder.Find(ctx, tagID)
return err return err

View File

@@ -50,7 +50,9 @@ var uiBox = ui.UIBox
var loginUIBox = ui.LoginUIBox var loginUIBox = ui.LoginUIBox
func Start() error { func Start() error {
initialiseImages() c := config.GetInstance()
initCustomPerformerImages(c.GetCustomPerformerImageLocation())
r := chi.NewRouter() r := chi.NewRouter()
@@ -62,7 +64,6 @@ func Start() error {
r.Use(middleware.Recoverer) r.Use(middleware.Recoverer)
c := config.GetInstance()
if c.GetLogAccess() { if c.GetLogAccess() {
httpLogger := httplog.NewLogger("Stash", httplog.Options{ httpLogger := httplog.NewLogger("Stash", httplog.Options{
Concise: true, Concise: true,
@@ -82,11 +83,10 @@ func Start() error {
return errors.New(message) return errors.New(message)
} }
txnManager := manager.GetInstance().Repository repo := manager.GetInstance().Repository
dataloaders := loaders.Middleware{ dataloaders := loaders.Middleware{
DatabaseProvider: txnManager, Repository: repo,
Repository: txnManager,
} }
r.Use(dataloaders.Middleware) r.Use(dataloaders.Middleware)
@@ -96,8 +96,7 @@ func Start() error {
imageService := manager.GetInstance().ImageService imageService := manager.GetInstance().ImageService
galleryService := manager.GetInstance().GalleryService galleryService := manager.GetInstance().GalleryService
resolver := &Resolver{ resolver := &Resolver{
txnManager: txnManager, repository: repo,
repository: txnManager,
sceneService: sceneService, sceneService: sceneService,
imageService: imageService, imageService: imageService,
galleryService: galleryService, galleryService: galleryService,
@@ -144,36 +143,13 @@ func Start() error {
gqlPlayground.Handler("GraphQL playground", endpoint)(w, r) gqlPlayground.Handler("GraphQL playground", endpoint)(w, r)
}) })
r.Mount("/performer", performerRoutes{ r.Mount("/performer", getPerformerRoutes(repo))
txnManager: txnManager, r.Mount("/scene", getSceneRoutes(repo))
performerFinder: txnManager.Performer, r.Mount("/image", getImageRoutes(repo))
}.Routes()) r.Mount("/studio", getStudioRoutes(repo))
r.Mount("/scene", sceneRoutes{ r.Mount("/movie", getMovieRoutes(repo))
txnManager: txnManager, r.Mount("/tag", getTagRoutes(repo))
sceneFinder: txnManager.Scene, r.Mount("/downloads", getDownloadsRoutes())
fileGetter: txnManager.File,
captionFinder: txnManager.File,
sceneMarkerFinder: txnManager.SceneMarker,
tagFinder: txnManager.Tag,
}.Routes())
r.Mount("/image", imageRoutes{
txnManager: txnManager,
imageFinder: txnManager.Image,
fileGetter: txnManager.File,
}.Routes())
r.Mount("/studio", studioRoutes{
txnManager: txnManager,
studioFinder: txnManager.Studio,
}.Routes())
r.Mount("/movie", movieRoutes{
txnManager: txnManager,
movieFinder: txnManager.Movie,
}.Routes())
r.Mount("/tag", tagRoutes{
txnManager: txnManager,
tagFinder: txnManager.Tag,
}.Routes())
r.Mount("/downloads", downloadsRoutes{}.Routes())
r.HandleFunc("/css", cssHandler(c, pluginCache)) r.HandleFunc("/css", cssHandler(c, pluginCache))
r.HandleFunc("/javascript", javascriptHandler(c, pluginCache)) r.HandleFunc("/javascript", javascriptHandler(c, pluginCache))
@@ -193,9 +169,7 @@ func Start() error {
// Serve static folders // Serve static folders
customServedFolders := c.GetCustomServedFolders() customServedFolders := c.GetCustomServedFolders()
if customServedFolders != nil { if customServedFolders != nil {
r.Mount("/custom", customRoutes{ r.Mount("/custom", getCustomRoutes(customServedFolders))
servedFolders: customServedFolders,
}.Routes())
} }
customUILocation := c.GetCustomUILocation() customUILocation := c.GetCustomUILocation()

View File

@@ -52,11 +52,10 @@ func TestGalleryPerformers(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
for _, test := range testTables { for _, test := range testTables {
mockPerformerReader := &mocks.PerformerReaderWriter{} db := mocks.NewDatabase()
mockGalleryReader := &mocks.GalleryReaderWriter{}
mockPerformerReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) db.Performer.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockPerformerReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once() db.Performer.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once()
if test.Matches { if test.Matches {
matchPartial := mock.MatchedBy(func(got models.GalleryPartial) bool { matchPartial := mock.MatchedBy(func(got models.GalleryPartial) bool {
@@ -69,7 +68,7 @@ func TestGalleryPerformers(t *testing.T) {
return galleryPartialsEqual(got, expected) return galleryPartialsEqual(got, expected)
}) })
mockGalleryReader.On("UpdatePartial", testCtx, galleryID, matchPartial).Return(nil, nil).Once() db.Gallery.On("UpdatePartial", testCtx, galleryID, matchPartial).Return(nil, nil).Once()
} }
gallery := models.Gallery{ gallery := models.Gallery{
@@ -77,11 +76,10 @@ func TestGalleryPerformers(t *testing.T) {
Path: test.Path, Path: test.Path,
PerformerIDs: models.NewRelatedIDs([]int{}), PerformerIDs: models.NewRelatedIDs([]int{}),
} }
err := GalleryPerformers(testCtx, &gallery, mockGalleryReader, mockPerformerReader, nil) err := GalleryPerformers(testCtx, &gallery, db.Gallery, db.Performer, nil)
assert.Nil(err) assert.Nil(err)
mockPerformerReader.AssertExpectations(t) db.AssertExpectations(t)
mockGalleryReader.AssertExpectations(t)
} }
} }
@@ -107,7 +105,7 @@ func TestGalleryStudios(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
doTest := func(mockStudioReader *mocks.StudioReaderWriter, mockGalleryReader *mocks.GalleryReaderWriter, test pathTestTable) { doTest := func(db *mocks.Database, test pathTestTable) {
if test.Matches { if test.Matches {
matchPartial := mock.MatchedBy(func(got models.GalleryPartial) bool { matchPartial := mock.MatchedBy(func(got models.GalleryPartial) bool {
expected := models.GalleryPartial{ expected := models.GalleryPartial{
@@ -116,29 +114,27 @@ func TestGalleryStudios(t *testing.T) {
return galleryPartialsEqual(got, expected) return galleryPartialsEqual(got, expected)
}) })
mockGalleryReader.On("UpdatePartial", testCtx, galleryID, matchPartial).Return(nil, nil).Once() db.Gallery.On("UpdatePartial", testCtx, galleryID, matchPartial).Return(nil, nil).Once()
} }
gallery := models.Gallery{ gallery := models.Gallery{
ID: galleryID, ID: galleryID,
Path: test.Path, Path: test.Path,
} }
err := GalleryStudios(testCtx, &gallery, mockGalleryReader, mockStudioReader, nil) err := GalleryStudios(testCtx, &gallery, db.Gallery, db.Studio, nil)
assert.Nil(err) assert.Nil(err)
mockStudioReader.AssertExpectations(t) db.AssertExpectations(t)
mockGalleryReader.AssertExpectations(t)
} }
for _, test := range testTables { for _, test := range testTables {
mockStudioReader := &mocks.StudioReaderWriter{} db := mocks.NewDatabase()
mockGalleryReader := &mocks.GalleryReaderWriter{}
mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) db.Studio.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() db.Studio.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
mockStudioReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe() db.Studio.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
doTest(mockStudioReader, mockGalleryReader, test) doTest(db, test)
} }
// test against aliases // test against aliases
@@ -146,17 +142,16 @@ func TestGalleryStudios(t *testing.T) {
studio.Name = unmatchedName studio.Name = unmatchedName
for _, test := range testTables { for _, test := range testTables {
mockStudioReader := &mocks.StudioReaderWriter{} db := mocks.NewDatabase()
mockGalleryReader := &mocks.GalleryReaderWriter{}
mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) db.Studio.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() db.Studio.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
mockStudioReader.On("GetAliases", testCtx, studioID).Return([]string{ db.Studio.On("GetAliases", testCtx, studioID).Return([]string{
studioName, studioName,
}, nil).Once() }, nil).Once()
mockStudioReader.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once() db.Studio.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once()
doTest(mockStudioReader, mockGalleryReader, test) doTest(db, test)
} }
} }
@@ -182,7 +177,7 @@ func TestGalleryTags(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
doTest := func(mockTagReader *mocks.TagReaderWriter, mockGalleryReader *mocks.GalleryReaderWriter, test pathTestTable) { doTest := func(db *mocks.Database, test pathTestTable) {
if test.Matches { if test.Matches {
matchPartial := mock.MatchedBy(func(got models.GalleryPartial) bool { matchPartial := mock.MatchedBy(func(got models.GalleryPartial) bool {
expected := models.GalleryPartial{ expected := models.GalleryPartial{
@@ -194,7 +189,7 @@ func TestGalleryTags(t *testing.T) {
return galleryPartialsEqual(got, expected) return galleryPartialsEqual(got, expected)
}) })
mockGalleryReader.On("UpdatePartial", testCtx, galleryID, matchPartial).Return(nil, nil).Once() db.Gallery.On("UpdatePartial", testCtx, galleryID, matchPartial).Return(nil, nil).Once()
} }
gallery := models.Gallery{ gallery := models.Gallery{
@@ -202,38 +197,35 @@ func TestGalleryTags(t *testing.T) {
Path: test.Path, Path: test.Path,
TagIDs: models.NewRelatedIDs([]int{}), TagIDs: models.NewRelatedIDs([]int{}),
} }
err := GalleryTags(testCtx, &gallery, mockGalleryReader, mockTagReader, nil) err := GalleryTags(testCtx, &gallery, db.Gallery, db.Tag, nil)
assert.Nil(err) assert.Nil(err)
mockTagReader.AssertExpectations(t) db.AssertExpectations(t)
mockGalleryReader.AssertExpectations(t)
} }
for _, test := range testTables { for _, test := range testTables {
mockTagReader := &mocks.TagReaderWriter{} db := mocks.NewDatabase()
mockGalleryReader := &mocks.GalleryReaderWriter{}
mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) db.Tag.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() db.Tag.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
mockTagReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe() db.Tag.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
doTest(mockTagReader, mockGalleryReader, test) doTest(db, test)
} }
const unmatchedName = "unmatched" const unmatchedName = "unmatched"
tag.Name = unmatchedName tag.Name = unmatchedName
for _, test := range testTables { for _, test := range testTables {
mockTagReader := &mocks.TagReaderWriter{} db := mocks.NewDatabase()
mockGalleryReader := &mocks.GalleryReaderWriter{}
mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) db.Tag.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() db.Tag.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
mockTagReader.On("GetAliases", testCtx, tagID).Return([]string{ db.Tag.On("GetAliases", testCtx, tagID).Return([]string{
tagName, tagName,
}, nil).Once() }, nil).Once()
mockTagReader.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once() db.Tag.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once()
doTest(mockTagReader, mockGalleryReader, test) doTest(db, test)
} }
} }

View File

@@ -49,11 +49,10 @@ func TestImagePerformers(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
for _, test := range testTables { for _, test := range testTables {
mockPerformerReader := &mocks.PerformerReaderWriter{} db := mocks.NewDatabase()
mockImageReader := &mocks.ImageReaderWriter{}
mockPerformerReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) db.Performer.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockPerformerReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once() db.Performer.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once()
if test.Matches { if test.Matches {
matchPartial := mock.MatchedBy(func(got models.ImagePartial) bool { matchPartial := mock.MatchedBy(func(got models.ImagePartial) bool {
@@ -66,7 +65,7 @@ func TestImagePerformers(t *testing.T) {
return imagePartialsEqual(got, expected) return imagePartialsEqual(got, expected)
}) })
mockImageReader.On("UpdatePartial", testCtx, imageID, matchPartial).Return(nil, nil).Once() db.Image.On("UpdatePartial", testCtx, imageID, matchPartial).Return(nil, nil).Once()
} }
image := models.Image{ image := models.Image{
@@ -74,11 +73,10 @@ func TestImagePerformers(t *testing.T) {
Path: test.Path, Path: test.Path,
PerformerIDs: models.NewRelatedIDs([]int{}), PerformerIDs: models.NewRelatedIDs([]int{}),
} }
err := ImagePerformers(testCtx, &image, mockImageReader, mockPerformerReader, nil) err := ImagePerformers(testCtx, &image, db.Image, db.Performer, nil)
assert.Nil(err) assert.Nil(err)
mockPerformerReader.AssertExpectations(t) db.AssertExpectations(t)
mockImageReader.AssertExpectations(t)
} }
} }
@@ -104,7 +102,7 @@ func TestImageStudios(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
doTest := func(mockStudioReader *mocks.StudioReaderWriter, mockImageReader *mocks.ImageReaderWriter, test pathTestTable) { doTest := func(db *mocks.Database, test pathTestTable) {
if test.Matches { if test.Matches {
matchPartial := mock.MatchedBy(func(got models.ImagePartial) bool { matchPartial := mock.MatchedBy(func(got models.ImagePartial) bool {
expected := models.ImagePartial{ expected := models.ImagePartial{
@@ -113,29 +111,27 @@ func TestImageStudios(t *testing.T) {
return imagePartialsEqual(got, expected) return imagePartialsEqual(got, expected)
}) })
mockImageReader.On("UpdatePartial", testCtx, imageID, matchPartial).Return(nil, nil).Once() db.Image.On("UpdatePartial", testCtx, imageID, matchPartial).Return(nil, nil).Once()
} }
image := models.Image{ image := models.Image{
ID: imageID, ID: imageID,
Path: test.Path, Path: test.Path,
} }
err := ImageStudios(testCtx, &image, mockImageReader, mockStudioReader, nil) err := ImageStudios(testCtx, &image, db.Image, db.Studio, nil)
assert.Nil(err) assert.Nil(err)
mockStudioReader.AssertExpectations(t) db.AssertExpectations(t)
mockImageReader.AssertExpectations(t)
} }
for _, test := range testTables { for _, test := range testTables {
mockStudioReader := &mocks.StudioReaderWriter{} db := mocks.NewDatabase()
mockImageReader := &mocks.ImageReaderWriter{}
mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) db.Studio.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() db.Studio.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
mockStudioReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe() db.Studio.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
doTest(mockStudioReader, mockImageReader, test) doTest(db, test)
} }
// test against aliases // test against aliases
@@ -143,17 +139,16 @@ func TestImageStudios(t *testing.T) {
studio.Name = unmatchedName studio.Name = unmatchedName
for _, test := range testTables { for _, test := range testTables {
mockStudioReader := &mocks.StudioReaderWriter{} db := mocks.NewDatabase()
mockImageReader := &mocks.ImageReaderWriter{}
mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) db.Studio.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() db.Studio.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
mockStudioReader.On("GetAliases", testCtx, studioID).Return([]string{ db.Studio.On("GetAliases", testCtx, studioID).Return([]string{
studioName, studioName,
}, nil).Once() }, nil).Once()
mockStudioReader.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once() db.Studio.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once()
doTest(mockStudioReader, mockImageReader, test) doTest(db, test)
} }
} }
@@ -179,7 +174,7 @@ func TestImageTags(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
doTest := func(mockTagReader *mocks.TagReaderWriter, mockImageReader *mocks.ImageReaderWriter, test pathTestTable) { doTest := func(db *mocks.Database, test pathTestTable) {
if test.Matches { if test.Matches {
matchPartial := mock.MatchedBy(func(got models.ImagePartial) bool { matchPartial := mock.MatchedBy(func(got models.ImagePartial) bool {
expected := models.ImagePartial{ expected := models.ImagePartial{
@@ -191,7 +186,7 @@ func TestImageTags(t *testing.T) {
return imagePartialsEqual(got, expected) return imagePartialsEqual(got, expected)
}) })
mockImageReader.On("UpdatePartial", testCtx, imageID, matchPartial).Return(nil, nil).Once() db.Image.On("UpdatePartial", testCtx, imageID, matchPartial).Return(nil, nil).Once()
} }
image := models.Image{ image := models.Image{
@@ -199,22 +194,20 @@ func TestImageTags(t *testing.T) {
Path: test.Path, Path: test.Path,
TagIDs: models.NewRelatedIDs([]int{}), TagIDs: models.NewRelatedIDs([]int{}),
} }
err := ImageTags(testCtx, &image, mockImageReader, mockTagReader, nil) err := ImageTags(testCtx, &image, db.Image, db.Tag, nil)
assert.Nil(err) assert.Nil(err)
mockTagReader.AssertExpectations(t) db.AssertExpectations(t)
mockImageReader.AssertExpectations(t)
} }
for _, test := range testTables { for _, test := range testTables {
mockTagReader := &mocks.TagReaderWriter{} db := mocks.NewDatabase()
mockImageReader := &mocks.ImageReaderWriter{}
mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) db.Tag.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() db.Tag.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
mockTagReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe() db.Tag.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
doTest(mockTagReader, mockImageReader, test) doTest(db, test)
} }
// test against aliases // test against aliases
@@ -222,16 +215,15 @@ func TestImageTags(t *testing.T) {
tag.Name = unmatchedName tag.Name = unmatchedName
for _, test := range testTables { for _, test := range testTables {
mockTagReader := &mocks.TagReaderWriter{} db := mocks.NewDatabase()
mockImageReader := &mocks.ImageReaderWriter{}
mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) db.Tag.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() db.Tag.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
mockTagReader.On("GetAliases", testCtx, tagID).Return([]string{ db.Tag.On("GetAliases", testCtx, tagID).Return([]string{
tagName, tagName,
}, nil).Once() }, nil).Once()
mockTagReader.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once() db.Tag.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once()
doTest(mockTagReader, mockImageReader, test) doTest(db, test)
} }
} }

View File

@@ -62,7 +62,7 @@ func runTests(m *testing.M) int {
panic(fmt.Sprintf("Could not initialize database: %s", err.Error())) panic(fmt.Sprintf("Could not initialize database: %s", err.Error()))
} }
r = db.TxnRepository() r = db.Repository()
// defer close and delete the database // defer close and delete the database
defer testTeardown(databaseFile) defer testTeardown(databaseFile)
@@ -474,11 +474,11 @@ func createGallery(ctx context.Context, w models.GalleryWriter, o *models.Galler
} }
func withTxn(f func(ctx context.Context) error) error { func withTxn(f func(ctx context.Context) error) error {
return txn.WithTxn(context.TODO(), db, f) return txn.WithTxn(testCtx, db, f)
} }
func withDB(f func(ctx context.Context) error) error { func withDB(f func(ctx context.Context) error) error {
return txn.WithDatabase(context.TODO(), db, f) return txn.WithDatabase(testCtx, db, f)
} }
func populateDB() error { func populateDB() error {

View File

@@ -45,7 +45,7 @@ func TestPerformerScenes(t *testing.T) {
} }
func testPerformerScenes(t *testing.T, performerName, expectedRegex string) { func testPerformerScenes(t *testing.T, performerName, expectedRegex string) {
mockSceneReader := &mocks.SceneReaderWriter{} db := mocks.NewDatabase()
const performerID = 2 const performerID = 2
@@ -84,7 +84,7 @@ func testPerformerScenes(t *testing.T, performerName, expectedRegex string) {
Direction: &direction, Direction: &direction,
} }
mockSceneReader.On("Query", mock.Anything, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false)). db.Scene.On("Query", mock.Anything, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false)).
Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once() Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once()
for i := range matchingPaths { for i := range matchingPaths {
@@ -100,19 +100,19 @@ func testPerformerScenes(t *testing.T, performerName, expectedRegex string) {
return scenePartialsEqual(got, expected) return scenePartialsEqual(got, expected)
}) })
mockSceneReader.On("UpdatePartial", mock.Anything, sceneID, matchPartial).Return(nil, nil).Once() db.Scene.On("UpdatePartial", mock.Anything, sceneID, matchPartial).Return(nil, nil).Once()
} }
tagger := Tagger{ tagger := Tagger{
TxnManager: &mocks.TxnManager{}, TxnManager: db,
} }
err := tagger.PerformerScenes(testCtx, &performer, nil, mockSceneReader) err := tagger.PerformerScenes(testCtx, &performer, nil, db.Scene)
assert := assert.New(t) assert := assert.New(t)
assert.Nil(err) assert.Nil(err)
mockSceneReader.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestPerformerImages(t *testing.T) { func TestPerformerImages(t *testing.T) {
@@ -140,7 +140,7 @@ func TestPerformerImages(t *testing.T) {
} }
func testPerformerImages(t *testing.T, performerName, expectedRegex string) { func testPerformerImages(t *testing.T, performerName, expectedRegex string) {
mockImageReader := &mocks.ImageReaderWriter{} db := mocks.NewDatabase()
const performerID = 2 const performerID = 2
@@ -179,7 +179,7 @@ func testPerformerImages(t *testing.T, performerName, expectedRegex string) {
Direction: &direction, Direction: &direction,
} }
mockImageReader.On("Query", mock.Anything, image.QueryOptions(expectedImageFilter, expectedFindFilter, false)). db.Image.On("Query", mock.Anything, image.QueryOptions(expectedImageFilter, expectedFindFilter, false)).
Return(mocks.ImageQueryResult(images, len(images)), nil).Once() Return(mocks.ImageQueryResult(images, len(images)), nil).Once()
for i := range matchingPaths { for i := range matchingPaths {
@@ -195,19 +195,19 @@ func testPerformerImages(t *testing.T, performerName, expectedRegex string) {
return imagePartialsEqual(got, expected) return imagePartialsEqual(got, expected)
}) })
mockImageReader.On("UpdatePartial", mock.Anything, imageID, matchPartial).Return(nil, nil).Once() db.Image.On("UpdatePartial", mock.Anything, imageID, matchPartial).Return(nil, nil).Once()
} }
tagger := Tagger{ tagger := Tagger{
TxnManager: &mocks.TxnManager{}, TxnManager: db,
} }
err := tagger.PerformerImages(testCtx, &performer, nil, mockImageReader) err := tagger.PerformerImages(testCtx, &performer, nil, db.Image)
assert := assert.New(t) assert := assert.New(t)
assert.Nil(err) assert.Nil(err)
mockImageReader.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestPerformerGalleries(t *testing.T) { func TestPerformerGalleries(t *testing.T) {
@@ -235,7 +235,7 @@ func TestPerformerGalleries(t *testing.T) {
} }
func testPerformerGalleries(t *testing.T, performerName, expectedRegex string) { func testPerformerGalleries(t *testing.T, performerName, expectedRegex string) {
mockGalleryReader := &mocks.GalleryReaderWriter{} db := mocks.NewDatabase()
const performerID = 2 const performerID = 2
@@ -275,7 +275,7 @@ func testPerformerGalleries(t *testing.T, performerName, expectedRegex string) {
Direction: &direction, Direction: &direction,
} }
mockGalleryReader.On("Query", mock.Anything, expectedGalleryFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once() db.Gallery.On("Query", mock.Anything, expectedGalleryFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once()
for i := range matchingPaths { for i := range matchingPaths {
galleryID := i + 1 galleryID := i + 1
@@ -290,17 +290,17 @@ func testPerformerGalleries(t *testing.T, performerName, expectedRegex string) {
return galleryPartialsEqual(got, expected) return galleryPartialsEqual(got, expected)
}) })
mockGalleryReader.On("UpdatePartial", mock.Anything, galleryID, matchPartial).Return(nil, nil).Once() db.Gallery.On("UpdatePartial", mock.Anything, galleryID, matchPartial).Return(nil, nil).Once()
} }
tagger := Tagger{ tagger := Tagger{
TxnManager: &mocks.TxnManager{}, TxnManager: db,
} }
err := tagger.PerformerGalleries(testCtx, &performer, nil, mockGalleryReader) err := tagger.PerformerGalleries(testCtx, &performer, nil, db.Gallery)
assert := assert.New(t) assert := assert.New(t)
assert.Nil(err) assert.Nil(err)
mockGalleryReader.AssertExpectations(t) db.AssertExpectations(t)
} }

View File

@@ -182,11 +182,10 @@ func TestScenePerformers(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
for _, test := range testTables { for _, test := range testTables {
mockPerformerReader := &mocks.PerformerReaderWriter{} db := mocks.NewDatabase()
mockSceneReader := &mocks.SceneReaderWriter{}
mockPerformerReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) db.Performer.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockPerformerReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once() db.Performer.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once()
scene := models.Scene{ scene := models.Scene{
ID: sceneID, ID: sceneID,
@@ -205,14 +204,13 @@ func TestScenePerformers(t *testing.T) {
return scenePartialsEqual(got, expected) return scenePartialsEqual(got, expected)
}) })
mockSceneReader.On("UpdatePartial", testCtx, sceneID, matchPartial).Return(nil, nil).Once() db.Scene.On("UpdatePartial", testCtx, sceneID, matchPartial).Return(nil, nil).Once()
} }
err := ScenePerformers(testCtx, &scene, mockSceneReader, mockPerformerReader, nil) err := ScenePerformers(testCtx, &scene, db.Scene, db.Performer, nil)
assert.Nil(err) assert.Nil(err)
mockPerformerReader.AssertExpectations(t) db.AssertExpectations(t)
mockSceneReader.AssertExpectations(t)
} }
} }
@@ -240,7 +238,7 @@ func TestSceneStudios(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
doTest := func(mockStudioReader *mocks.StudioReaderWriter, mockSceneReader *mocks.SceneReaderWriter, test pathTestTable) { doTest := func(db *mocks.Database, test pathTestTable) {
if test.Matches { if test.Matches {
matchPartial := mock.MatchedBy(func(got models.ScenePartial) bool { matchPartial := mock.MatchedBy(func(got models.ScenePartial) bool {
expected := models.ScenePartial{ expected := models.ScenePartial{
@@ -249,29 +247,27 @@ func TestSceneStudios(t *testing.T) {
return scenePartialsEqual(got, expected) return scenePartialsEqual(got, expected)
}) })
mockSceneReader.On("UpdatePartial", testCtx, sceneID, matchPartial).Return(nil, nil).Once() db.Scene.On("UpdatePartial", testCtx, sceneID, matchPartial).Return(nil, nil).Once()
} }
scene := models.Scene{ scene := models.Scene{
ID: sceneID, ID: sceneID,
Path: test.Path, Path: test.Path,
} }
err := SceneStudios(testCtx, &scene, mockSceneReader, mockStudioReader, nil) err := SceneStudios(testCtx, &scene, db.Scene, db.Studio, nil)
assert.Nil(err) assert.Nil(err)
mockStudioReader.AssertExpectations(t) db.AssertExpectations(t)
mockSceneReader.AssertExpectations(t)
} }
for _, test := range testTables { for _, test := range testTables {
mockStudioReader := &mocks.StudioReaderWriter{} db := mocks.NewDatabase()
mockSceneReader := &mocks.SceneReaderWriter{}
mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) db.Studio.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() db.Studio.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
mockStudioReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe() db.Studio.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
doTest(mockStudioReader, mockSceneReader, test) doTest(db, test)
} }
const unmatchedName = "unmatched" const unmatchedName = "unmatched"
@@ -279,17 +275,16 @@ func TestSceneStudios(t *testing.T) {
// test against aliases // test against aliases
for _, test := range testTables { for _, test := range testTables {
mockStudioReader := &mocks.StudioReaderWriter{} db := mocks.NewDatabase()
mockSceneReader := &mocks.SceneReaderWriter{}
mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) db.Studio.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once() db.Studio.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
mockStudioReader.On("GetAliases", testCtx, studioID).Return([]string{ db.Studio.On("GetAliases", testCtx, studioID).Return([]string{
studioName, studioName,
}, nil).Once() }, nil).Once()
mockStudioReader.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once() db.Studio.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once()
doTest(mockStudioReader, mockSceneReader, test) doTest(db, test)
} }
} }
@@ -315,7 +310,7 @@ func TestSceneTags(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
doTest := func(mockTagReader *mocks.TagReaderWriter, mockSceneReader *mocks.SceneReaderWriter, test pathTestTable) { doTest := func(db *mocks.Database, test pathTestTable) {
if test.Matches { if test.Matches {
matchPartial := mock.MatchedBy(func(got models.ScenePartial) bool { matchPartial := mock.MatchedBy(func(got models.ScenePartial) bool {
expected := models.ScenePartial{ expected := models.ScenePartial{
@@ -327,7 +322,7 @@ func TestSceneTags(t *testing.T) {
return scenePartialsEqual(got, expected) return scenePartialsEqual(got, expected)
}) })
mockSceneReader.On("UpdatePartial", testCtx, sceneID, matchPartial).Return(nil, nil).Once() db.Scene.On("UpdatePartial", testCtx, sceneID, matchPartial).Return(nil, nil).Once()
} }
scene := models.Scene{ scene := models.Scene{
@@ -335,22 +330,20 @@ func TestSceneTags(t *testing.T) {
Path: test.Path, Path: test.Path,
TagIDs: models.NewRelatedIDs([]int{}), TagIDs: models.NewRelatedIDs([]int{}),
} }
err := SceneTags(testCtx, &scene, mockSceneReader, mockTagReader, nil) err := SceneTags(testCtx, &scene, db.Scene, db.Tag, nil)
assert.Nil(err) assert.Nil(err)
mockTagReader.AssertExpectations(t) db.AssertExpectations(t)
mockSceneReader.AssertExpectations(t)
} }
for _, test := range testTables { for _, test := range testTables {
mockTagReader := &mocks.TagReaderWriter{} db := mocks.NewDatabase()
mockSceneReader := &mocks.SceneReaderWriter{}
mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) db.Tag.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() db.Tag.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
mockTagReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe() db.Tag.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
doTest(mockTagReader, mockSceneReader, test) doTest(db, test)
} }
const unmatchedName = "unmatched" const unmatchedName = "unmatched"
@@ -358,16 +351,15 @@ func TestSceneTags(t *testing.T) {
// test against aliases // test against aliases
for _, test := range testTables { for _, test := range testTables {
mockTagReader := &mocks.TagReaderWriter{} db := mocks.NewDatabase()
mockSceneReader := &mocks.SceneReaderWriter{}
mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil) db.Tag.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once() db.Tag.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
mockTagReader.On("GetAliases", testCtx, tagID).Return([]string{ db.Tag.On("GetAliases", testCtx, tagID).Return([]string{
tagName, tagName,
}, nil).Once() }, nil).Once()
mockTagReader.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once() db.Tag.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once()
doTest(mockTagReader, mockSceneReader, test) doTest(db, test)
} }
} }

View File

@@ -83,7 +83,7 @@ func testStudioScenes(t *testing.T, tc testStudioCase) {
aliasName := tc.aliasName aliasName := tc.aliasName
aliasRegex := tc.aliasRegex aliasRegex := tc.aliasRegex
mockSceneReader := &mocks.SceneReaderWriter{} db := mocks.NewDatabase()
var studioID = 2 var studioID = 2
@@ -130,7 +130,7 @@ func testStudioScenes(t *testing.T, tc testStudioCase) {
} }
// if alias provided, then don't find by name // if alias provided, then don't find by name
onNameQuery := mockSceneReader.On("Query", testCtx, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false)) onNameQuery := db.Scene.On("Query", testCtx, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false))
if aliasName == "" { if aliasName == "" {
onNameQuery.Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once() onNameQuery.Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once()
@@ -145,7 +145,7 @@ func testStudioScenes(t *testing.T, tc testStudioCase) {
}, },
} }
mockSceneReader.On("Query", mock.Anything, scene.QueryOptions(expectedAliasFilter, expectedFindFilter, false)). db.Scene.On("Query", mock.Anything, scene.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once() Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once()
} }
@@ -159,19 +159,19 @@ func testStudioScenes(t *testing.T, tc testStudioCase) {
return scenePartialsEqual(got, expected) return scenePartialsEqual(got, expected)
}) })
mockSceneReader.On("UpdatePartial", mock.Anything, sceneID, matchPartial).Return(nil, nil).Once() db.Scene.On("UpdatePartial", mock.Anything, sceneID, matchPartial).Return(nil, nil).Once()
} }
tagger := Tagger{ tagger := Tagger{
TxnManager: &mocks.TxnManager{}, TxnManager: db,
} }
err := tagger.StudioScenes(testCtx, &studio, nil, aliases, mockSceneReader) err := tagger.StudioScenes(testCtx, &studio, nil, aliases, db.Scene)
assert := assert.New(t) assert := assert.New(t)
assert.Nil(err) assert.Nil(err)
mockSceneReader.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestStudioImages(t *testing.T) { func TestStudioImages(t *testing.T) {
@@ -188,7 +188,7 @@ func testStudioImages(t *testing.T, tc testStudioCase) {
aliasName := tc.aliasName aliasName := tc.aliasName
aliasRegex := tc.aliasRegex aliasRegex := tc.aliasRegex
mockImageReader := &mocks.ImageReaderWriter{} db := mocks.NewDatabase()
var studioID = 2 var studioID = 2
@@ -234,7 +234,7 @@ func testStudioImages(t *testing.T, tc testStudioCase) {
} }
// if alias provided, then don't find by name // if alias provided, then don't find by name
onNameQuery := mockImageReader.On("Query", mock.Anything, image.QueryOptions(expectedImageFilter, expectedFindFilter, false)) onNameQuery := db.Image.On("Query", mock.Anything, image.QueryOptions(expectedImageFilter, expectedFindFilter, false))
if aliasName == "" { if aliasName == "" {
onNameQuery.Return(mocks.ImageQueryResult(images, len(images)), nil).Once() onNameQuery.Return(mocks.ImageQueryResult(images, len(images)), nil).Once()
} else { } else {
@@ -248,7 +248,7 @@ func testStudioImages(t *testing.T, tc testStudioCase) {
}, },
} }
mockImageReader.On("Query", mock.Anything, image.QueryOptions(expectedAliasFilter, expectedFindFilter, false)). db.Image.On("Query", mock.Anything, image.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
Return(mocks.ImageQueryResult(images, len(images)), nil).Once() Return(mocks.ImageQueryResult(images, len(images)), nil).Once()
} }
@@ -262,19 +262,19 @@ func testStudioImages(t *testing.T, tc testStudioCase) {
return imagePartialsEqual(got, expected) return imagePartialsEqual(got, expected)
}) })
mockImageReader.On("UpdatePartial", mock.Anything, imageID, matchPartial).Return(nil, nil).Once() db.Image.On("UpdatePartial", mock.Anything, imageID, matchPartial).Return(nil, nil).Once()
} }
tagger := Tagger{ tagger := Tagger{
TxnManager: &mocks.TxnManager{}, TxnManager: db,
} }
err := tagger.StudioImages(testCtx, &studio, nil, aliases, mockImageReader) err := tagger.StudioImages(testCtx, &studio, nil, aliases, db.Image)
assert := assert.New(t) assert := assert.New(t)
assert.Nil(err) assert.Nil(err)
mockImageReader.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestStudioGalleries(t *testing.T) { func TestStudioGalleries(t *testing.T) {
@@ -290,7 +290,8 @@ func testStudioGalleries(t *testing.T, tc testStudioCase) {
expectedRegex := tc.expectedRegex expectedRegex := tc.expectedRegex
aliasName := tc.aliasName aliasName := tc.aliasName
aliasRegex := tc.aliasRegex aliasRegex := tc.aliasRegex
mockGalleryReader := &mocks.GalleryReaderWriter{}
db := mocks.NewDatabase()
var studioID = 2 var studioID = 2
@@ -337,7 +338,7 @@ func testStudioGalleries(t *testing.T, tc testStudioCase) {
} }
// if alias provided, then don't find by name // if alias provided, then don't find by name
onNameQuery := mockGalleryReader.On("Query", mock.Anything, expectedGalleryFilter, expectedFindFilter) onNameQuery := db.Gallery.On("Query", mock.Anything, expectedGalleryFilter, expectedFindFilter)
if aliasName == "" { if aliasName == "" {
onNameQuery.Return(galleries, len(galleries), nil).Once() onNameQuery.Return(galleries, len(galleries), nil).Once()
} else { } else {
@@ -351,7 +352,7 @@ func testStudioGalleries(t *testing.T, tc testStudioCase) {
}, },
} }
mockGalleryReader.On("Query", mock.Anything, expectedAliasFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once() db.Gallery.On("Query", mock.Anything, expectedAliasFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once()
} }
for i := range matchingPaths { for i := range matchingPaths {
@@ -364,17 +365,17 @@ func testStudioGalleries(t *testing.T, tc testStudioCase) {
return galleryPartialsEqual(got, expected) return galleryPartialsEqual(got, expected)
}) })
mockGalleryReader.On("UpdatePartial", mock.Anything, galleryID, matchPartial).Return(nil, nil).Once() db.Gallery.On("UpdatePartial", mock.Anything, galleryID, matchPartial).Return(nil, nil).Once()
} }
tagger := Tagger{ tagger := Tagger{
TxnManager: &mocks.TxnManager{}, TxnManager: db,
} }
err := tagger.StudioGalleries(testCtx, &studio, nil, aliases, mockGalleryReader) err := tagger.StudioGalleries(testCtx, &studio, nil, aliases, db.Gallery)
assert := assert.New(t) assert := assert.New(t)
assert.Nil(err) assert.Nil(err)
mockGalleryReader.AssertExpectations(t) db.AssertExpectations(t)
} }

View File

@@ -83,7 +83,7 @@ func testTagScenes(t *testing.T, tc testTagCase) {
aliasName := tc.aliasName aliasName := tc.aliasName
aliasRegex := tc.aliasRegex aliasRegex := tc.aliasRegex
mockSceneReader := &mocks.SceneReaderWriter{} db := mocks.NewDatabase()
const tagID = 2 const tagID = 2
@@ -131,7 +131,7 @@ func testTagScenes(t *testing.T, tc testTagCase) {
} }
// if alias provided, then don't find by name // if alias provided, then don't find by name
onNameQuery := mockSceneReader.On("Query", testCtx, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false)) onNameQuery := db.Scene.On("Query", testCtx, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false))
if aliasName == "" { if aliasName == "" {
onNameQuery.Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once() onNameQuery.Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once()
} else { } else {
@@ -145,7 +145,7 @@ func testTagScenes(t *testing.T, tc testTagCase) {
}, },
} }
mockSceneReader.On("Query", mock.Anything, scene.QueryOptions(expectedAliasFilter, expectedFindFilter, false)). db.Scene.On("Query", mock.Anything, scene.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once() Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once()
} }
@@ -162,19 +162,19 @@ func testTagScenes(t *testing.T, tc testTagCase) {
return scenePartialsEqual(got, expected) return scenePartialsEqual(got, expected)
}) })
mockSceneReader.On("UpdatePartial", mock.Anything, sceneID, matchPartial).Return(nil, nil).Once() db.Scene.On("UpdatePartial", mock.Anything, sceneID, matchPartial).Return(nil, nil).Once()
} }
tagger := Tagger{ tagger := Tagger{
TxnManager: &mocks.TxnManager{}, TxnManager: db,
} }
err := tagger.TagScenes(testCtx, &tag, nil, aliases, mockSceneReader) err := tagger.TagScenes(testCtx, &tag, nil, aliases, db.Scene)
assert := assert.New(t) assert := assert.New(t)
assert.Nil(err) assert.Nil(err)
mockSceneReader.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestTagImages(t *testing.T) { func TestTagImages(t *testing.T) {
@@ -191,7 +191,7 @@ func testTagImages(t *testing.T, tc testTagCase) {
aliasName := tc.aliasName aliasName := tc.aliasName
aliasRegex := tc.aliasRegex aliasRegex := tc.aliasRegex
mockImageReader := &mocks.ImageReaderWriter{} db := mocks.NewDatabase()
const tagID = 2 const tagID = 2
@@ -238,7 +238,7 @@ func testTagImages(t *testing.T, tc testTagCase) {
} }
// if alias provided, then don't find by name // if alias provided, then don't find by name
onNameQuery := mockImageReader.On("Query", testCtx, image.QueryOptions(expectedImageFilter, expectedFindFilter, false)) onNameQuery := db.Image.On("Query", testCtx, image.QueryOptions(expectedImageFilter, expectedFindFilter, false))
if aliasName == "" { if aliasName == "" {
onNameQuery.Return(mocks.ImageQueryResult(images, len(images)), nil).Once() onNameQuery.Return(mocks.ImageQueryResult(images, len(images)), nil).Once()
} else { } else {
@@ -252,7 +252,7 @@ func testTagImages(t *testing.T, tc testTagCase) {
}, },
} }
mockImageReader.On("Query", mock.Anything, image.QueryOptions(expectedAliasFilter, expectedFindFilter, false)). db.Image.On("Query", mock.Anything, image.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
Return(mocks.ImageQueryResult(images, len(images)), nil).Once() Return(mocks.ImageQueryResult(images, len(images)), nil).Once()
} }
@@ -269,19 +269,19 @@ func testTagImages(t *testing.T, tc testTagCase) {
return imagePartialsEqual(got, expected) return imagePartialsEqual(got, expected)
}) })
mockImageReader.On("UpdatePartial", mock.Anything, imageID, matchPartial).Return(nil, nil).Once() db.Image.On("UpdatePartial", mock.Anything, imageID, matchPartial).Return(nil, nil).Once()
} }
tagger := Tagger{ tagger := Tagger{
TxnManager: &mocks.TxnManager{}, TxnManager: db,
} }
err := tagger.TagImages(testCtx, &tag, nil, aliases, mockImageReader) err := tagger.TagImages(testCtx, &tag, nil, aliases, db.Image)
assert := assert.New(t) assert := assert.New(t)
assert.Nil(err) assert.Nil(err)
mockImageReader.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestTagGalleries(t *testing.T) { func TestTagGalleries(t *testing.T) {
@@ -298,7 +298,7 @@ func testTagGalleries(t *testing.T, tc testTagCase) {
aliasName := tc.aliasName aliasName := tc.aliasName
aliasRegex := tc.aliasRegex aliasRegex := tc.aliasRegex
mockGalleryReader := &mocks.GalleryReaderWriter{} db := mocks.NewDatabase()
const tagID = 2 const tagID = 2
@@ -346,7 +346,7 @@ func testTagGalleries(t *testing.T, tc testTagCase) {
} }
// if alias provided, then don't find by name // if alias provided, then don't find by name
onNameQuery := mockGalleryReader.On("Query", testCtx, expectedGalleryFilter, expectedFindFilter) onNameQuery := db.Gallery.On("Query", testCtx, expectedGalleryFilter, expectedFindFilter)
if aliasName == "" { if aliasName == "" {
onNameQuery.Return(galleries, len(galleries), nil).Once() onNameQuery.Return(galleries, len(galleries), nil).Once()
} else { } else {
@@ -360,7 +360,7 @@ func testTagGalleries(t *testing.T, tc testTagCase) {
}, },
} }
mockGalleryReader.On("Query", mock.Anything, expectedAliasFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once() db.Gallery.On("Query", mock.Anything, expectedAliasFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once()
} }
for i := range matchingPaths { for i := range matchingPaths {
@@ -376,18 +376,18 @@ func testTagGalleries(t *testing.T, tc testTagCase) {
return galleryPartialsEqual(got, expected) return galleryPartialsEqual(got, expected)
}) })
mockGalleryReader.On("UpdatePartial", mock.Anything, galleryID, matchPartial).Return(nil, nil).Once() db.Gallery.On("UpdatePartial", mock.Anything, galleryID, matchPartial).Return(nil, nil).Once()
} }
tagger := Tagger{ tagger := Tagger{
TxnManager: &mocks.TxnManager{}, TxnManager: db,
} }
err := tagger.TagGalleries(testCtx, &tag, nil, aliases, mockGalleryReader) err := tagger.TagGalleries(testCtx, &tag, nil, aliases, db.Gallery)
assert := assert.New(t) assert := assert.New(t)
assert.Nil(err) assert.Nil(err)
mockGalleryReader.AssertExpectations(t) db.AssertExpectations(t)
} }

View File

@@ -41,7 +41,6 @@ import (
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene" "github.com/stashapp/stash/pkg/scene"
"github.com/stashapp/stash/pkg/sliceutil/stringslice" "github.com/stashapp/stash/pkg/sliceutil/stringslice"
"github.com/stashapp/stash/pkg/txn"
) )
var pageSize = 100 var pageSize = 100
@@ -360,10 +359,11 @@ func (me *contentDirectoryService) handleBrowseMetadata(obj object, host string)
} else { } else {
var scene *models.Scene var scene *models.Scene
if err := txn.WithReadTxn(context.TODO(), me.txnManager, func(ctx context.Context) error { r := me.repository
scene, err = me.repository.SceneFinder.Find(ctx, sceneID) if err := r.WithReadTxn(context.TODO(), func(ctx context.Context) error {
scene, err = r.SceneFinder.Find(ctx, sceneID)
if scene != nil { if scene != nil {
err = scene.LoadPrimaryFile(ctx, me.repository.FileGetter) err = scene.LoadPrimaryFile(ctx, r.FileGetter)
} }
if err != nil { if err != nil {
@@ -452,7 +452,8 @@ func getSortDirection(sceneFilter *models.SceneFilterType, sort string) models.S
func (me *contentDirectoryService) getVideos(sceneFilter *models.SceneFilterType, parentID string, host string) []interface{} { func (me *contentDirectoryService) getVideos(sceneFilter *models.SceneFilterType, parentID string, host string) []interface{} {
var objs []interface{} var objs []interface{}
if err := txn.WithReadTxn(context.TODO(), me.txnManager, func(ctx context.Context) error { r := me.repository
if err := r.WithReadTxn(context.TODO(), func(ctx context.Context) error {
sort := me.VideoSortOrder sort := me.VideoSortOrder
direction := getSortDirection(sceneFilter, sort) direction := getSortDirection(sceneFilter, sort)
findFilter := &models.FindFilterType{ findFilter := &models.FindFilterType{
@@ -461,7 +462,7 @@ func (me *contentDirectoryService) getVideos(sceneFilter *models.SceneFilterType
Direction: &direction, Direction: &direction,
} }
scenes, total, err := scene.QueryWithCount(ctx, me.repository.SceneFinder, sceneFilter, findFilter) scenes, total, err := scene.QueryWithCount(ctx, r.SceneFinder, sceneFilter, findFilter)
if err != nil { if err != nil {
return err return err
} }
@@ -472,13 +473,13 @@ func (me *contentDirectoryService) getVideos(sceneFilter *models.SceneFilterType
parentID: parentID, parentID: parentID,
} }
objs, err = pager.getPages(ctx, me.repository.SceneFinder, total) objs, err = pager.getPages(ctx, r.SceneFinder, total)
if err != nil { if err != nil {
return err return err
} }
} else { } else {
for _, s := range scenes { for _, s := range scenes {
if err := s.LoadPrimaryFile(ctx, me.repository.FileGetter); err != nil { if err := s.LoadPrimaryFile(ctx, r.FileGetter); err != nil {
return err return err
} }
@@ -497,7 +498,8 @@ func (me *contentDirectoryService) getVideos(sceneFilter *models.SceneFilterType
func (me *contentDirectoryService) getPageVideos(sceneFilter *models.SceneFilterType, parentID string, page int, host string) []interface{} { func (me *contentDirectoryService) getPageVideos(sceneFilter *models.SceneFilterType, parentID string, page int, host string) []interface{} {
var objs []interface{} var objs []interface{}
if err := txn.WithReadTxn(context.TODO(), me.txnManager, func(ctx context.Context) error { r := me.repository
if err := r.WithReadTxn(context.TODO(), func(ctx context.Context) error {
pager := scenePager{ pager := scenePager{
sceneFilter: sceneFilter, sceneFilter: sceneFilter,
parentID: parentID, parentID: parentID,
@@ -506,7 +508,7 @@ func (me *contentDirectoryService) getPageVideos(sceneFilter *models.SceneFilter
sort := me.VideoSortOrder sort := me.VideoSortOrder
direction := getSortDirection(sceneFilter, sort) direction := getSortDirection(sceneFilter, sort)
var err error var err error
objs, err = pager.getPageVideos(ctx, me.repository.SceneFinder, me.repository.FileGetter, page, host, sort, direction) objs, err = pager.getPageVideos(ctx, r.SceneFinder, r.FileGetter, page, host, sort, direction)
if err != nil { if err != nil {
return err return err
} }
@@ -540,8 +542,9 @@ func (me *contentDirectoryService) getAllScenes(host string) []interface{} {
func (me *contentDirectoryService) getStudios() []interface{} { func (me *contentDirectoryService) getStudios() []interface{} {
var objs []interface{} var objs []interface{}
if err := txn.WithReadTxn(context.TODO(), me.txnManager, func(ctx context.Context) error { r := me.repository
studios, err := me.repository.StudioFinder.All(ctx) if err := r.WithReadTxn(context.TODO(), func(ctx context.Context) error {
studios, err := r.StudioFinder.All(ctx)
if err != nil { if err != nil {
return err return err
} }
@@ -579,8 +582,9 @@ func (me *contentDirectoryService) getStudioScenes(paths []string, host string)
func (me *contentDirectoryService) getTags() []interface{} { func (me *contentDirectoryService) getTags() []interface{} {
var objs []interface{} var objs []interface{}
if err := txn.WithReadTxn(context.TODO(), me.txnManager, func(ctx context.Context) error { r := me.repository
tags, err := me.repository.TagFinder.All(ctx) if err := r.WithReadTxn(context.TODO(), func(ctx context.Context) error {
tags, err := r.TagFinder.All(ctx)
if err != nil { if err != nil {
return err return err
} }
@@ -618,8 +622,9 @@ func (me *contentDirectoryService) getTagScenes(paths []string, host string) []i
func (me *contentDirectoryService) getPerformers() []interface{} { func (me *contentDirectoryService) getPerformers() []interface{} {
var objs []interface{} var objs []interface{}
if err := txn.WithReadTxn(context.TODO(), me.txnManager, func(ctx context.Context) error { r := me.repository
performers, err := me.repository.PerformerFinder.All(ctx) if err := r.WithReadTxn(context.TODO(), func(ctx context.Context) error {
performers, err := r.PerformerFinder.All(ctx)
if err != nil { if err != nil {
return err return err
} }
@@ -657,8 +662,9 @@ func (me *contentDirectoryService) getPerformerScenes(paths []string, host strin
func (me *contentDirectoryService) getMovies() []interface{} { func (me *contentDirectoryService) getMovies() []interface{} {
var objs []interface{} var objs []interface{}
if err := txn.WithReadTxn(context.TODO(), me.txnManager, func(ctx context.Context) error { r := me.repository
movies, err := me.repository.MovieFinder.All(ctx) if err := r.WithReadTxn(context.TODO(), func(ctx context.Context) error {
movies, err := r.MovieFinder.All(ctx)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -48,7 +48,6 @@ import (
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
) )
type SceneFinder interface { type SceneFinder interface {
@@ -271,7 +270,6 @@ type Server struct {
// Time interval between SSPD announces // Time interval between SSPD announces
NotifyInterval time.Duration NotifyInterval time.Duration
txnManager txn.Manager
repository Repository repository Repository
sceneServer sceneServer sceneServer sceneServer
ipWhitelistManager *ipWhitelistManager ipWhitelistManager *ipWhitelistManager
@@ -439,12 +437,13 @@ func (me *Server) serveIcon(w http.ResponseWriter, r *http.Request) {
} }
var scene *models.Scene var scene *models.Scene
err := txn.WithReadTxn(r.Context(), me.txnManager, func(ctx context.Context) error { repo := me.repository
err := repo.WithReadTxn(r.Context(), func(ctx context.Context) error {
idInt, err := strconv.Atoi(sceneId) idInt, err := strconv.Atoi(sceneId)
if err != nil { if err != nil {
return nil return nil
} }
scene, _ = me.repository.SceneFinder.Find(ctx, idInt) scene, _ = repo.SceneFinder.Find(ctx, idInt)
return nil return nil
}) })
if err != nil { if err != nil {
@@ -579,12 +578,13 @@ func (me *Server) initMux(mux *http.ServeMux) {
mux.HandleFunc(resPath, func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc(resPath, func(w http.ResponseWriter, r *http.Request) {
sceneId := r.URL.Query().Get("scene") sceneId := r.URL.Query().Get("scene")
var scene *models.Scene var scene *models.Scene
err := txn.WithReadTxn(r.Context(), me.txnManager, func(ctx context.Context) error { repo := me.repository
err := repo.WithReadTxn(r.Context(), func(ctx context.Context) error {
sceneIdInt, err := strconv.Atoi(sceneId) sceneIdInt, err := strconv.Atoi(sceneId)
if err != nil { if err != nil {
return nil return nil
} }
scene, _ = me.repository.SceneFinder.Find(ctx, sceneIdInt) scene, _ = repo.SceneFinder.Find(ctx, sceneIdInt)
return nil return nil
}) })
if err != nil { if err != nil {

View File

@@ -1,6 +1,7 @@
package dlna package dlna
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@@ -14,6 +15,8 @@ import (
) )
type Repository struct { type Repository struct {
TxnManager models.TxnManager
SceneFinder SceneFinder SceneFinder SceneFinder
FileGetter models.FileGetter FileGetter models.FileGetter
StudioFinder StudioFinder StudioFinder StudioFinder
@@ -22,6 +25,22 @@ type Repository struct {
MovieFinder MovieFinder MovieFinder MovieFinder
} }
func NewRepository(repo models.Repository) Repository {
return Repository{
TxnManager: repo.TxnManager,
FileGetter: repo.File,
SceneFinder: repo.Scene,
StudioFinder: repo.Studio,
TagFinder: repo.Tag,
PerformerFinder: repo.Performer,
MovieFinder: repo.Movie,
}
}
func (r *Repository) WithReadTxn(ctx context.Context, fn txn.TxnFunc) error {
return txn.WithReadTxn(ctx, r.TxnManager, fn)
}
type Status struct { type Status struct {
Running bool `json:"running"` Running bool `json:"running"`
// If not currently running, time until it will be started. If running, time until it will be stopped // If not currently running, time until it will be started. If running, time until it will be stopped
@@ -60,7 +79,6 @@ type Config interface {
} }
type Service struct { type Service struct {
txnManager txn.Manager
repository Repository repository Repository
config Config config Config
sceneServer sceneServer sceneServer sceneServer
@@ -133,9 +151,8 @@ func (s *Service) init() error {
} }
s.server = &Server{ s.server = &Server{
txnManager: s.txnManager,
sceneServer: s.sceneServer,
repository: s.repository, repository: s.repository,
sceneServer: s.sceneServer,
ipWhitelistManager: s.ipWhitelistMgr, ipWhitelistManager: s.ipWhitelistMgr,
Interfaces: interfaces, Interfaces: interfaces,
HTTPConn: func() net.Listener { HTTPConn: func() net.Listener {
@@ -197,9 +214,8 @@ func (s *Service) init() error {
// } // }
// NewService initialises and returns a new DLNA service. // NewService initialises and returns a new DLNA service.
func NewService(txnManager txn.Manager, repo Repository, cfg Config, sceneServer sceneServer) *Service { func NewService(repo Repository, cfg Config, sceneServer sceneServer) *Service {
ret := &Service{ ret := &Service{
txnManager: txnManager,
repository: repo, repository: repo,
sceneServer: sceneServer, sceneServer: sceneServer,
config: cfg, config: cfg,

View File

@@ -43,6 +43,7 @@ type ScraperSource struct {
} }
type SceneIdentifier struct { type SceneIdentifier struct {
TxnManager txn.Manager
SceneReaderUpdater SceneReaderUpdater SceneReaderUpdater SceneReaderUpdater
StudioReaderWriter models.StudioReaderWriter StudioReaderWriter models.StudioReaderWriter
PerformerCreator PerformerCreator PerformerCreator PerformerCreator
@@ -53,8 +54,8 @@ type SceneIdentifier struct {
SceneUpdatePostHookExecutor SceneUpdatePostHookExecutor SceneUpdatePostHookExecutor SceneUpdatePostHookExecutor
} }
func (t *SceneIdentifier) Identify(ctx context.Context, txnManager txn.Manager, scene *models.Scene) error { func (t *SceneIdentifier) Identify(ctx context.Context, scene *models.Scene) error {
result, err := t.scrapeScene(ctx, txnManager, scene) result, err := t.scrapeScene(ctx, scene)
var multipleMatchErr *MultipleMatchesFoundError var multipleMatchErr *MultipleMatchesFoundError
if err != nil { if err != nil {
if !errors.As(err, &multipleMatchErr) { if !errors.As(err, &multipleMatchErr) {
@@ -70,7 +71,7 @@ func (t *SceneIdentifier) Identify(ctx context.Context, txnManager txn.Manager,
options := t.getOptions(multipleMatchErr.Source) options := t.getOptions(multipleMatchErr.Source)
if options.SkipMultipleMatchTag != nil && len(*options.SkipMultipleMatchTag) > 0 { if options.SkipMultipleMatchTag != nil && len(*options.SkipMultipleMatchTag) > 0 {
// Tag it with the multiple results tag // Tag it with the multiple results tag
err := t.addTagToScene(ctx, txnManager, scene, *options.SkipMultipleMatchTag) err := t.addTagToScene(ctx, scene, *options.SkipMultipleMatchTag)
if err != nil { if err != nil {
return err return err
} }
@@ -83,7 +84,7 @@ func (t *SceneIdentifier) Identify(ctx context.Context, txnManager txn.Manager,
} }
// results were found, modify the scene // results were found, modify the scene
if err := t.modifyScene(ctx, txnManager, scene, result); err != nil { if err := t.modifyScene(ctx, scene, result); err != nil {
return fmt.Errorf("error modifying scene: %v", err) return fmt.Errorf("error modifying scene: %v", err)
} }
@@ -95,7 +96,7 @@ type scrapeResult struct {
source ScraperSource source ScraperSource
} }
func (t *SceneIdentifier) scrapeScene(ctx context.Context, txnManager txn.Manager, scene *models.Scene) (*scrapeResult, error) { func (t *SceneIdentifier) scrapeScene(ctx context.Context, scene *models.Scene) (*scrapeResult, error) {
// iterate through the input sources // iterate through the input sources
for _, source := range t.Sources { for _, source := range t.Sources {
// scrape using the source // scrape using the source
@@ -261,9 +262,9 @@ func (t *SceneIdentifier) getSceneUpdater(ctx context.Context, s *models.Scene,
return ret, nil return ret, nil
} }
func (t *SceneIdentifier) modifyScene(ctx context.Context, txnManager txn.Manager, s *models.Scene, result *scrapeResult) error { func (t *SceneIdentifier) modifyScene(ctx context.Context, s *models.Scene, result *scrapeResult) error {
var updater *scene.UpdateSet var updater *scene.UpdateSet
if err := txn.WithTxn(ctx, txnManager, func(ctx context.Context) error { if err := txn.WithTxn(ctx, t.TxnManager, func(ctx context.Context) error {
// load scene relationships // load scene relationships
if err := s.LoadURLs(ctx, t.SceneReaderUpdater); err != nil { if err := s.LoadURLs(ctx, t.SceneReaderUpdater); err != nil {
return err return err
@@ -316,8 +317,8 @@ func (t *SceneIdentifier) modifyScene(ctx context.Context, txnManager txn.Manage
return nil return nil
} }
func (t *SceneIdentifier) addTagToScene(ctx context.Context, txnManager txn.Manager, s *models.Scene, tagToAdd string) error { func (t *SceneIdentifier) addTagToScene(ctx context.Context, s *models.Scene, tagToAdd string) error {
if err := txn.WithTxn(ctx, txnManager, func(ctx context.Context) error { if err := txn.WithTxn(ctx, t.TxnManager, func(ctx context.Context) error {
tagID, err := strconv.Atoi(tagToAdd) tagID, err := strconv.Atoi(tagToAdd)
if err != nil { if err != nil {
return fmt.Errorf("error converting tag ID %s: %w", tagToAdd, err) return fmt.Errorf("error converting tag ID %s: %w", tagToAdd, err)

View File

@@ -108,17 +108,17 @@ func TestSceneIdentifier_Identify(t *testing.T) {
}, },
} }
mockSceneReaderWriter := &mocks.SceneReaderWriter{} db := mocks.NewDatabase()
mockSceneReaderWriter.On("GetURLs", mock.Anything, mock.Anything).Return(nil, nil)
mockSceneReaderWriter.On("UpdatePartial", mock.Anything, mock.MatchedBy(func(id int) bool { db.Scene.On("GetURLs", mock.Anything, mock.Anything).Return(nil, nil)
db.Scene.On("UpdatePartial", mock.Anything, mock.MatchedBy(func(id int) bool {
return id == errUpdateID return id == errUpdateID
}), mock.Anything).Return(nil, errors.New("update error")) }), mock.Anything).Return(nil, errors.New("update error"))
mockSceneReaderWriter.On("UpdatePartial", mock.Anything, mock.MatchedBy(func(id int) bool { db.Scene.On("UpdatePartial", mock.Anything, mock.MatchedBy(func(id int) bool {
return id != errUpdateID return id != errUpdateID
}), mock.Anything).Return(nil, nil) }), mock.Anything).Return(nil, nil)
mockTagFinderCreator := &mocks.TagReaderWriter{} db.Tag.On("Find", mock.Anything, skipMultipleTagID).Return(&models.Tag{
mockTagFinderCreator.On("Find", mock.Anything, skipMultipleTagID).Return(&models.Tag{
ID: skipMultipleTagID, ID: skipMultipleTagID,
Name: skipMultipleTagIDStr, Name: skipMultipleTagIDStr,
}, nil) }, nil)
@@ -185,8 +185,11 @@ func TestSceneIdentifier_Identify(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
identifier := SceneIdentifier{ identifier := SceneIdentifier{
SceneReaderUpdater: mockSceneReaderWriter, TxnManager: db,
TagFinderCreator: mockTagFinderCreator, SceneReaderUpdater: db.Scene,
StudioReaderWriter: db.Studio,
PerformerCreator: db.Performer,
TagFinderCreator: db.Tag,
DefaultOptions: defaultOptions, DefaultOptions: defaultOptions,
Sources: sources, Sources: sources,
SceneUpdatePostHookExecutor: mockHookExecutor{}, SceneUpdatePostHookExecutor: mockHookExecutor{},
@@ -202,7 +205,7 @@ func TestSceneIdentifier_Identify(t *testing.T) {
TagIDs: models.NewRelatedIDs([]int{}), TagIDs: models.NewRelatedIDs([]int{}),
StashIDs: models.NewRelatedStashIDs([]models.StashID{}), StashIDs: models.NewRelatedStashIDs([]models.StashID{}),
} }
if err := identifier.Identify(testCtx, &mocks.TxnManager{}, scene); (err != nil) != tt.wantErr { if err := identifier.Identify(testCtx, scene); (err != nil) != tt.wantErr {
t.Errorf("SceneIdentifier.Identify() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("SceneIdentifier.Identify() error = %v, wantErr %v", err, tt.wantErr)
} }
}) })
@@ -210,9 +213,8 @@ func TestSceneIdentifier_Identify(t *testing.T) {
} }
func TestSceneIdentifier_modifyScene(t *testing.T) { func TestSceneIdentifier_modifyScene(t *testing.T) {
repo := models.Repository{ db := mocks.NewDatabase()
TxnManager: &mocks.TxnManager{},
}
boolFalse := false boolFalse := false
defaultOptions := &MetadataOptions{ defaultOptions := &MetadataOptions{
SetOrganized: &boolFalse, SetOrganized: &boolFalse,
@@ -221,7 +223,12 @@ func TestSceneIdentifier_modifyScene(t *testing.T) {
SkipSingleNamePerformers: &boolFalse, SkipSingleNamePerformers: &boolFalse,
} }
tr := &SceneIdentifier{ tr := &SceneIdentifier{
DefaultOptions: defaultOptions, TxnManager: db,
SceneReaderUpdater: db.Scene,
StudioReaderWriter: db.Studio,
PerformerCreator: db.Performer,
TagFinderCreator: db.Tag,
DefaultOptions: defaultOptions,
} }
type args struct { type args struct {
@@ -254,7 +261,7 @@ func TestSceneIdentifier_modifyScene(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := tr.modifyScene(testCtx, repo, tt.args.scene, tt.args.result); (err != nil) != tt.wantErr { if err := tr.modifyScene(testCtx, tt.args.scene, tt.args.result); (err != nil) != tt.wantErr {
t.Errorf("SceneIdentifier.modifyScene() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("SceneIdentifier.modifyScene() error = %v, wantErr %v", err, tt.wantErr)
} }
}) })

View File

@@ -22,8 +22,9 @@ func Test_getPerformerID(t *testing.T) {
remoteSiteID := "2" remoteSiteID := "2"
name := "name" name := "name"
mockPerformerReaderWriter := mocks.PerformerReaderWriter{} db := mocks.NewDatabase()
mockPerformerReaderWriter.On("Create", testCtx, mock.Anything).Run(func(args mock.Arguments) {
db.Performer.On("Create", testCtx, mock.Anything).Run(func(args mock.Arguments) {
p := args.Get(1).(*models.Performer) p := args.Get(1).(*models.Performer)
p.ID = validStoredID p.ID = validStoredID
}).Return(nil) }).Return(nil)
@@ -131,7 +132,7 @@ func Test_getPerformerID(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := getPerformerID(testCtx, tt.args.endpoint, &mockPerformerReaderWriter, tt.args.p, tt.args.createMissing, tt.args.skipSingleName) got, err := getPerformerID(testCtx, tt.args.endpoint, db.Performer, tt.args.p, tt.args.createMissing, tt.args.skipSingleName)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("getPerformerID() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("getPerformerID() error = %v, wantErr %v", err, tt.wantErr)
return return
@@ -151,15 +152,16 @@ func Test_createMissingPerformer(t *testing.T) {
invalidName := "invalidName" invalidName := "invalidName"
performerID := 1 performerID := 1
mockPerformerReaderWriter := mocks.PerformerReaderWriter{} db := mocks.NewDatabase()
mockPerformerReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p *models.Performer) bool {
db.Performer.On("Create", testCtx, mock.MatchedBy(func(p *models.Performer) bool {
return p.Name == validName return p.Name == validName
})).Run(func(args mock.Arguments) { })).Run(func(args mock.Arguments) {
p := args.Get(1).(*models.Performer) p := args.Get(1).(*models.Performer)
p.ID = performerID p.ID = performerID
}).Return(nil) }).Return(nil)
mockPerformerReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p *models.Performer) bool { db.Performer.On("Create", testCtx, mock.MatchedBy(func(p *models.Performer) bool {
return p.Name == invalidName return p.Name == invalidName
})).Return(errors.New("error creating performer")) })).Return(errors.New("error creating performer"))
@@ -212,7 +214,7 @@ func Test_createMissingPerformer(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := createMissingPerformer(testCtx, tt.args.endpoint, &mockPerformerReaderWriter, tt.args.p) got, err := createMissingPerformer(testCtx, tt.args.endpoint, db.Performer, tt.args.p)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("createMissingPerformer() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("createMissingPerformer() error = %v, wantErr %v", err, tt.wantErr)
return return

View File

@@ -24,14 +24,15 @@ func Test_sceneRelationships_studio(t *testing.T) {
Strategy: FieldStrategyMerge, Strategy: FieldStrategyMerge,
} }
mockStudioReaderWriter := &mocks.StudioReaderWriter{} db := mocks.NewDatabase()
mockStudioReaderWriter.On("Create", testCtx, mock.Anything).Run(func(args mock.Arguments) {
db.Studio.On("Create", testCtx, mock.Anything).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.Studio) s := args.Get(1).(*models.Studio)
s.ID = validStoredIDInt s.ID = validStoredIDInt
}).Return(nil) }).Return(nil)
tr := sceneRelationships{ tr := sceneRelationships{
studioReaderWriter: mockStudioReaderWriter, studioReaderWriter: db.Studio,
fieldOptions: make(map[string]*FieldOptions), fieldOptions: make(map[string]*FieldOptions),
} }
@@ -174,8 +175,10 @@ func Test_sceneRelationships_performers(t *testing.T) {
}), }),
} }
db := mocks.NewDatabase()
tr := sceneRelationships{ tr := sceneRelationships{
sceneReader: &mocks.SceneReaderWriter{}, sceneReader: db.Scene,
fieldOptions: make(map[string]*FieldOptions), fieldOptions: make(map[string]*FieldOptions),
} }
@@ -363,22 +366,21 @@ func Test_sceneRelationships_tags(t *testing.T) {
StashIDs: models.NewRelatedStashIDs([]models.StashID{}), StashIDs: models.NewRelatedStashIDs([]models.StashID{}),
} }
mockSceneReaderWriter := &mocks.SceneReaderWriter{} db := mocks.NewDatabase()
mockTagReaderWriter := &mocks.TagReaderWriter{}
mockTagReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p *models.Tag) bool { db.Tag.On("Create", testCtx, mock.MatchedBy(func(p *models.Tag) bool {
return p.Name == validName return p.Name == validName
})).Run(func(args mock.Arguments) { })).Run(func(args mock.Arguments) {
t := args.Get(1).(*models.Tag) t := args.Get(1).(*models.Tag)
t.ID = validStoredIDInt t.ID = validStoredIDInt
}).Return(nil) }).Return(nil)
mockTagReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p *models.Tag) bool { db.Tag.On("Create", testCtx, mock.MatchedBy(func(p *models.Tag) bool {
return p.Name == invalidName return p.Name == invalidName
})).Return(errors.New("error creating tag")) })).Return(errors.New("error creating tag"))
tr := sceneRelationships{ tr := sceneRelationships{
sceneReader: mockSceneReaderWriter, sceneReader: db.Scene,
tagCreator: mockTagReaderWriter, tagCreator: db.Tag,
fieldOptions: make(map[string]*FieldOptions), fieldOptions: make(map[string]*FieldOptions),
} }
@@ -552,10 +554,10 @@ func Test_sceneRelationships_stashIDs(t *testing.T) {
}), }),
} }
mockSceneReaderWriter := &mocks.SceneReaderWriter{} db := mocks.NewDatabase()
tr := sceneRelationships{ tr := sceneRelationships{
sceneReader: mockSceneReaderWriter, sceneReader: db.Scene,
fieldOptions: make(map[string]*FieldOptions), fieldOptions: make(map[string]*FieldOptions),
} }
@@ -706,12 +708,13 @@ func Test_sceneRelationships_cover(t *testing.T) {
newDataEncoded := base64Prefix + utils.GetBase64StringFromData(newData) newDataEncoded := base64Prefix + utils.GetBase64StringFromData(newData)
invalidData := newDataEncoded + "!!!" invalidData := newDataEncoded + "!!!"
mockSceneReaderWriter := &mocks.SceneReaderWriter{} db := mocks.NewDatabase()
mockSceneReaderWriter.On("GetCover", testCtx, sceneID).Return(existingData, nil)
mockSceneReaderWriter.On("GetCover", testCtx, errSceneID).Return(nil, errors.New("error getting cover")) db.Scene.On("GetCover", testCtx, sceneID).Return(existingData, nil)
db.Scene.On("GetCover", testCtx, errSceneID).Return(nil, errors.New("error getting cover"))
tr := sceneRelationships{ tr := sceneRelationships{
sceneReader: mockSceneReaderWriter, sceneReader: db.Scene,
fieldOptions: make(map[string]*FieldOptions), fieldOptions: make(map[string]*FieldOptions),
} }

View File

@@ -19,18 +19,19 @@ func Test_createMissingStudio(t *testing.T) {
invalidName := "invalidName" invalidName := "invalidName"
createdID := 1 createdID := 1
mockStudioReaderWriter := &mocks.StudioReaderWriter{} db := mocks.NewDatabase()
mockStudioReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p *models.Studio) bool {
db.Studio.On("Create", testCtx, mock.MatchedBy(func(p *models.Studio) bool {
return p.Name == validName return p.Name == validName
})).Run(func(args mock.Arguments) { })).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.Studio) s := args.Get(1).(*models.Studio)
s.ID = createdID s.ID = createdID
}).Return(nil) }).Return(nil)
mockStudioReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p *models.Studio) bool { db.Studio.On("Create", testCtx, mock.MatchedBy(func(p *models.Studio) bool {
return p.Name == invalidName return p.Name == invalidName
})).Return(errors.New("error creating studio")) })).Return(errors.New("error creating studio"))
mockStudioReaderWriter.On("UpdatePartial", testCtx, models.StudioPartial{ db.Studio.On("UpdatePartial", testCtx, models.StudioPartial{
ID: createdID, ID: createdID,
StashIDs: &models.UpdateStashIDs{ StashIDs: &models.UpdateStashIDs{
StashIDs: []models.StashID{ StashIDs: []models.StashID{
@@ -42,7 +43,7 @@ func Test_createMissingStudio(t *testing.T) {
Mode: models.RelationshipUpdateModeSet, Mode: models.RelationshipUpdateModeSet,
}, },
}).Return(nil, errors.New("error updating stash ids")) }).Return(nil, errors.New("error updating stash ids"))
mockStudioReaderWriter.On("UpdatePartial", testCtx, models.StudioPartial{ db.Studio.On("UpdatePartial", testCtx, models.StudioPartial{
ID: createdID, ID: createdID,
StashIDs: &models.UpdateStashIDs{ StashIDs: &models.UpdateStashIDs{
StashIDs: []models.StashID{ StashIDs: []models.StashID{
@@ -106,7 +107,7 @@ func Test_createMissingStudio(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := createMissingStudio(testCtx, tt.args.endpoint, mockStudioReaderWriter, tt.args.studio) got, err := createMissingStudio(testCtx, tt.args.endpoint, db.Studio, tt.args.studio)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("createMissingStudio() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("createMissingStudio() error = %v, wantErr %v", err, tt.wantErr)
return return

View File

@@ -131,7 +131,7 @@ type Manager struct {
DLNAService *dlna.Service DLNAService *dlna.Service
Database *sqlite.Database Database *sqlite.Database
Repository Repository Repository models.Repository
SceneService SceneService SceneService SceneService
ImageService ImageService ImageService ImageService
@@ -174,6 +174,7 @@ func initialize() error {
initProfiling(cfg.GetCPUProfilePath()) initProfiling(cfg.GetCPUProfilePath())
db := sqlite.NewDatabase() db := sqlite.NewDatabase()
repo := db.Repository()
// start with empty paths // start with empty paths
emptyPaths := paths.Paths{} emptyPaths := paths.Paths{}
@@ -186,49 +187,43 @@ func initialize() error {
PluginCache: plugin.NewCache(cfg), PluginCache: plugin.NewCache(cfg),
Database: db, Database: db,
Repository: sqliteRepository(db), Repository: repo,
Paths: &emptyPaths, Paths: &emptyPaths,
scanSubs: &subscriptionManager{}, scanSubs: &subscriptionManager{},
} }
instance.SceneService = &scene.Service{ instance.SceneService = &scene.Service{
File: db.File, File: repo.File,
Repository: db.Scene, Repository: repo.Scene,
MarkerRepository: db.SceneMarker, MarkerRepository: repo.SceneMarker,
PluginCache: instance.PluginCache, PluginCache: instance.PluginCache,
Paths: instance.Paths, Paths: instance.Paths,
Config: cfg, Config: cfg,
} }
instance.ImageService = &image.Service{ instance.ImageService = &image.Service{
File: db.File, File: repo.File,
Repository: db.Image, Repository: repo.Image,
} }
instance.GalleryService = &gallery.Service{ instance.GalleryService = &gallery.Service{
Repository: db.Gallery, Repository: repo.Gallery,
ImageFinder: db.Image, ImageFinder: repo.Image,
ImageService: instance.ImageService, ImageService: instance.ImageService,
File: db.File, File: repo.File,
Folder: db.Folder, Folder: repo.Folder,
} }
instance.JobManager = initJobManager() instance.JobManager = initJobManager()
sceneServer := SceneServer{ sceneServer := SceneServer{
TxnManager: instance.Repository, TxnManager: repo.TxnManager,
SceneCoverGetter: instance.Repository.Scene, SceneCoverGetter: repo.Scene,
} }
instance.DLNAService = dlna.NewService(instance.Repository, dlna.Repository{ dlnaRepository := dlna.NewRepository(repo)
SceneFinder: instance.Repository.Scene, instance.DLNAService = dlna.NewService(dlnaRepository, cfg, &sceneServer)
FileGetter: instance.Repository.File,
StudioFinder: instance.Repository.Studio,
TagFinder: instance.Repository.Tag,
PerformerFinder: instance.Repository.Performer,
MovieFinder: instance.Repository.Movie,
}, instance.Config, &sceneServer)
if !cfg.IsNewSystem() { if !cfg.IsNewSystem() {
logger.Infof("using config file: %s", cfg.GetConfigFile()) logger.Infof("using config file: %s", cfg.GetConfigFile())
@@ -268,8 +263,8 @@ func initialize() error {
logger.Warnf("could not initialize FFMPEG subsystem: %v", err) logger.Warnf("could not initialize FFMPEG subsystem: %v", err)
} }
instance.Scanner = makeScanner(db, instance.PluginCache) instance.Scanner = makeScanner(repo, instance.PluginCache)
instance.Cleaner = makeCleaner(db, instance.PluginCache) instance.Cleaner = makeCleaner(repo, instance.PluginCache)
// if DLNA is enabled, start it now // if DLNA is enabled, start it now
if instance.Config.GetDLNADefaultEnabled() { if instance.Config.GetDLNADefaultEnabled() {
@@ -293,14 +288,9 @@ func galleryFileFilter(ctx context.Context, f models.File) bool {
return isZip(f.Base().Basename) return isZip(f.Base().Basename)
} }
func makeScanner(db *sqlite.Database, pluginCache *plugin.Cache) *file.Scanner { func makeScanner(repo models.Repository, pluginCache *plugin.Cache) *file.Scanner {
return &file.Scanner{ return &file.Scanner{
Repository: file.Repository{ Repository: file.NewRepository(repo),
Manager: db,
DatabaseProvider: db,
FileStore: db.File,
FolderStore: db.Folder,
},
FileDecorators: []file.Decorator{ FileDecorators: []file.Decorator{
&file.FilteredDecorator{ &file.FilteredDecorator{
Decorator: &video.Decorator{ Decorator: &video.Decorator{
@@ -320,15 +310,10 @@ func makeScanner(db *sqlite.Database, pluginCache *plugin.Cache) *file.Scanner {
} }
} }
func makeCleaner(db *sqlite.Database, pluginCache *plugin.Cache) *file.Cleaner { func makeCleaner(repo models.Repository, pluginCache *plugin.Cache) *file.Cleaner {
return &file.Cleaner{ return &file.Cleaner{
FS: &file.OsFS{}, FS: &file.OsFS{},
Repository: file.Repository{ Repository: file.NewRepository(repo),
Manager: db,
DatabaseProvider: db,
FileStore: db.File,
FolderStore: db.Folder,
},
Handlers: []file.CleanHandler{ Handlers: []file.CleanHandler{
&cleanHandler{}, &cleanHandler{},
}, },
@@ -523,14 +508,8 @@ func writeStashIcon() {
// initScraperCache initializes a new scraper cache and returns it. // initScraperCache initializes a new scraper cache and returns it.
func (s *Manager) initScraperCache() *scraper.Cache { func (s *Manager) initScraperCache() *scraper.Cache {
ret, err := scraper.NewCache(config.GetInstance(), s.Repository, scraper.Repository{ scraperRepository := scraper.NewRepository(s.Repository)
SceneFinder: s.Repository.Scene, ret, err := scraper.NewCache(s.Config, scraperRepository)
GalleryFinder: s.Repository.Gallery,
TagFinder: s.Repository.Tag,
PerformerFinder: s.Repository.Performer,
MovieFinder: s.Repository.Movie,
StudioFinder: s.Repository.Studio,
})
if err != nil { if err != nil {
logger.Errorf("Error reading scraper configs: %s", err.Error()) logger.Errorf("Error reading scraper configs: %s", err.Error())
@@ -697,7 +676,7 @@ func (s *Manager) Setup(ctx context.Context, input SetupInput) error {
return fmt.Errorf("error initializing FFMPEG subsystem: %v", err) return fmt.Errorf("error initializing FFMPEG subsystem: %v", err)
} }
instance.Scanner = makeScanner(instance.Database, instance.PluginCache) instance.Scanner = makeScanner(instance.Repository, instance.PluginCache)
return nil return nil
} }

View File

@@ -112,7 +112,8 @@ func (s *Manager) Import(ctx context.Context) (int, error) {
j := job.MakeJobExec(func(ctx context.Context, progress *job.Progress) { j := job.MakeJobExec(func(ctx context.Context, progress *job.Progress) {
task := ImportTask{ task := ImportTask{
txnManager: s.Repository, repository: s.Repository,
resetter: s.Database,
BaseDir: metadataPath, BaseDir: metadataPath,
Reset: true, Reset: true,
DuplicateBehaviour: ImportDuplicateEnumFail, DuplicateBehaviour: ImportDuplicateEnumFail,
@@ -136,7 +137,7 @@ func (s *Manager) Export(ctx context.Context) (int, error) {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
task := ExportTask{ task := ExportTask{
txnManager: s.Repository, repository: s.Repository,
full: true, full: true,
fileNamingAlgorithm: config.GetVideoFileNamingAlgorithm(), fileNamingAlgorithm: config.GetVideoFileNamingAlgorithm(),
} }
@@ -167,7 +168,7 @@ func (s *Manager) Generate(ctx context.Context, input GenerateMetadataInput) (in
} }
j := &GenerateJob{ j := &GenerateJob{
txnManager: s.Repository, repository: s.Repository,
input: input, input: input,
} }
@@ -212,7 +213,7 @@ func (s *Manager) generateScreenshot(ctx context.Context, sceneId string, at *fl
} }
task := GenerateCoverTask{ task := GenerateCoverTask{
txnManager: s.Repository, repository: s.Repository,
Scene: *scene, Scene: *scene,
ScreenshotAt: at, ScreenshotAt: at,
Overwrite: true, Overwrite: true,
@@ -239,7 +240,7 @@ type AutoTagMetadataInput struct {
func (s *Manager) AutoTag(ctx context.Context, input AutoTagMetadataInput) int { func (s *Manager) AutoTag(ctx context.Context, input AutoTagMetadataInput) int {
j := autoTagJob{ j := autoTagJob{
txnManager: s.Repository, repository: s.Repository,
input: input, input: input,
} }
@@ -255,7 +256,7 @@ type CleanMetadataInput struct {
func (s *Manager) Clean(ctx context.Context, input CleanMetadataInput) int { func (s *Manager) Clean(ctx context.Context, input CleanMetadataInput) int {
j := cleanJob{ j := cleanJob{
cleaner: s.Cleaner, cleaner: s.Cleaner,
txnManager: s.Repository, repository: s.Repository,
sceneService: s.SceneService, sceneService: s.SceneService,
imageService: s.ImageService, imageService: s.ImageService,
input: input, input: input,

View File

@@ -6,59 +6,8 @@ import (
"github.com/stashapp/stash/pkg/image" "github.com/stashapp/stash/pkg/image"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene" "github.com/stashapp/stash/pkg/scene"
"github.com/stashapp/stash/pkg/sqlite"
"github.com/stashapp/stash/pkg/txn"
) )
type Repository struct {
models.TxnManager
File models.FileReaderWriter
Folder models.FolderReaderWriter
Gallery models.GalleryReaderWriter
GalleryChapter models.GalleryChapterReaderWriter
Image models.ImageReaderWriter
Movie models.MovieReaderWriter
Performer models.PerformerReaderWriter
Scene models.SceneReaderWriter
SceneMarker models.SceneMarkerReaderWriter
Studio models.StudioReaderWriter
Tag models.TagReaderWriter
SavedFilter models.SavedFilterReaderWriter
}
func (r *Repository) WithTxn(ctx context.Context, fn txn.TxnFunc) error {
return txn.WithTxn(ctx, r, fn)
}
func (r *Repository) WithReadTxn(ctx context.Context, fn txn.TxnFunc) error {
return txn.WithReadTxn(ctx, r, fn)
}
func (r *Repository) WithDB(ctx context.Context, fn txn.TxnFunc) error {
return txn.WithDatabase(ctx, r, fn)
}
func sqliteRepository(d *sqlite.Database) Repository {
txnRepo := d.TxnRepository()
return Repository{
TxnManager: txnRepo,
File: d.File,
Folder: d.Folder,
Gallery: d.Gallery,
GalleryChapter: txnRepo.GalleryChapter,
Image: d.Image,
Movie: txnRepo.Movie,
Performer: txnRepo.Performer,
Scene: d.Scene,
SceneMarker: txnRepo.SceneMarker,
Studio: txnRepo.Studio,
Tag: txnRepo.Tag,
SavedFilter: txnRepo.SavedFilter,
}
}
type SceneService interface { type SceneService interface {
Create(ctx context.Context, input *models.Scene, fileIDs []models.FileID, coverImage []byte) (*models.Scene, error) Create(ctx context.Context, input *models.Scene, fileIDs []models.FileID, coverImage []byte) (*models.Scene, error)
AssignFile(ctx context.Context, sceneID int, fileID models.FileID) error AssignFile(ctx context.Context, sceneID int, fileID models.FileID) error

View File

@@ -3,7 +3,6 @@ package manager
import ( import (
"context" "context"
"errors" "errors"
"io"
"net/http" "net/http"
"github.com/stashapp/stash/internal/manager/config" "github.com/stashapp/stash/internal/manager/config"
@@ -58,8 +57,6 @@ func (s *SceneServer) StreamSceneDirect(scene *models.Scene, w http.ResponseWrit
} }
func (s *SceneServer) ServeScreenshot(scene *models.Scene, w http.ResponseWriter, r *http.Request) { func (s *SceneServer) ServeScreenshot(scene *models.Scene, w http.ResponseWriter, r *http.Request) {
const defaultSceneImage = "scene/scene.svg"
var cover []byte var cover []byte
readTxnErr := txn.WithReadTxn(r.Context(), s.TxnManager, func(ctx context.Context) error { readTxnErr := txn.WithReadTxn(r.Context(), s.TxnManager, func(ctx context.Context) error {
cover, _ = s.SceneCoverGetter.GetCover(ctx, scene.ID) cover, _ = s.SceneCoverGetter.GetCover(ctx, scene.ID)
@@ -92,10 +89,7 @@ func (s *SceneServer) ServeScreenshot(scene *models.Scene, w http.ResponseWriter
} }
// fallback to default cover if none found // fallback to default cover if none found
// should always be there cover = static.ReadAll(static.DefaultSceneImage)
f, _ := static.Scene.Open(defaultSceneImage)
defer f.Close()
cover, _ = io.ReadAll(f)
} }
utils.ServeImage(w, r, cover) utils.ServeImage(w, r, cover)

View File

@@ -19,7 +19,7 @@ import (
) )
type autoTagJob struct { type autoTagJob struct {
txnManager Repository repository models.Repository
input AutoTagMetadataInput input AutoTagMetadataInput
cache match.Cache cache match.Cache
@@ -56,7 +56,7 @@ func (j *autoTagJob) autoTagFiles(ctx context.Context, progress *job.Progress, p
studios: studios, studios: studios,
tags: tags, tags: tags,
progress: progress, progress: progress,
txnManager: j.txnManager, repository: j.repository,
cache: &j.cache, cache: &j.cache,
} }
@@ -73,8 +73,8 @@ func (j *autoTagJob) autoTagSpecific(ctx context.Context, progress *job.Progress
studioCount := len(studioIds) studioCount := len(studioIds)
tagCount := len(tagIds) tagCount := len(tagIds)
if err := j.txnManager.WithReadTxn(ctx, func(ctx context.Context) error { r := j.repository
r := j.txnManager if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
performerQuery := r.Performer performerQuery := r.Performer
studioQuery := r.Studio studioQuery := r.Studio
tagQuery := r.Tag tagQuery := r.Tag
@@ -123,16 +123,17 @@ func (j *autoTagJob) autoTagPerformers(ctx context.Context, progress *job.Progre
return return
} }
r := j.repository
tagger := autotag.Tagger{ tagger := autotag.Tagger{
TxnManager: j.txnManager, TxnManager: r.TxnManager,
Cache: &j.cache, Cache: &j.cache,
} }
for _, performerId := range performerIds { for _, performerId := range performerIds {
var performers []*models.Performer var performers []*models.Performer
if err := j.txnManager.WithDB(ctx, func(ctx context.Context) error { if err := r.WithDB(ctx, func(ctx context.Context) error {
performerQuery := j.txnManager.Performer performerQuery := r.Performer
ignoreAutoTag := false ignoreAutoTag := false
perPage := -1 perPage := -1
@@ -161,7 +162,7 @@ func (j *autoTagJob) autoTagPerformers(ctx context.Context, progress *job.Progre
return fmt.Errorf("performer with id %s not found", performerId) return fmt.Errorf("performer with id %s not found", performerId)
} }
if err := performer.LoadAliases(ctx, j.txnManager.Performer); err != nil { if err := performer.LoadAliases(ctx, r.Performer); err != nil {
return fmt.Errorf("loading aliases for performer %d: %w", performer.ID, err) return fmt.Errorf("loading aliases for performer %d: %w", performer.ID, err)
} }
performers = append(performers, performer) performers = append(performers, performer)
@@ -173,7 +174,6 @@ func (j *autoTagJob) autoTagPerformers(ctx context.Context, progress *job.Progre
} }
err := func() error { err := func() error {
r := j.txnManager
if err := tagger.PerformerScenes(ctx, performer, paths, r.Scene); err != nil { if err := tagger.PerformerScenes(ctx, performer, paths, r.Scene); err != nil {
return fmt.Errorf("processing scenes: %w", err) return fmt.Errorf("processing scenes: %w", err)
} }
@@ -215,9 +215,9 @@ func (j *autoTagJob) autoTagStudios(ctx context.Context, progress *job.Progress,
return return
} }
r := j.txnManager r := j.repository
tagger := autotag.Tagger{ tagger := autotag.Tagger{
TxnManager: j.txnManager, TxnManager: r.TxnManager,
Cache: &j.cache, Cache: &j.cache,
} }
@@ -308,15 +308,15 @@ func (j *autoTagJob) autoTagTags(ctx context.Context, progress *job.Progress, pa
return return
} }
r := j.txnManager r := j.repository
tagger := autotag.Tagger{ tagger := autotag.Tagger{
TxnManager: j.txnManager, TxnManager: r.TxnManager,
Cache: &j.cache, Cache: &j.cache,
} }
for _, tagId := range tagIds { for _, tagId := range tagIds {
var tags []*models.Tag var tags []*models.Tag
if err := j.txnManager.WithDB(ctx, func(ctx context.Context) error { if err := r.WithDB(ctx, func(ctx context.Context) error {
tagQuery := r.Tag tagQuery := r.Tag
ignoreAutoTag := false ignoreAutoTag := false
perPage := -1 perPage := -1
@@ -402,7 +402,7 @@ type autoTagFilesTask struct {
tags bool tags bool
progress *job.Progress progress *job.Progress
txnManager Repository repository models.Repository
cache *match.Cache cache *match.Cache
} }
@@ -482,7 +482,9 @@ func (t *autoTagFilesTask) makeGalleryFilter() *models.GalleryFilterType {
return ret return ret
} }
func (t *autoTagFilesTask) getCount(ctx context.Context, r Repository) (int, error) { func (t *autoTagFilesTask) getCount(ctx context.Context) (int, error) {
r := t.repository
pp := 0 pp := 0
findFilter := &models.FindFilterType{ findFilter := &models.FindFilterType{
PerPage: &pp, PerPage: &pp,
@@ -522,7 +524,7 @@ func (t *autoTagFilesTask) getCount(ctx context.Context, r Repository) (int, err
return sceneCount + imageCount + galleryCount, nil return sceneCount + imageCount + galleryCount, nil
} }
func (t *autoTagFilesTask) processScenes(ctx context.Context, r Repository) { func (t *autoTagFilesTask) processScenes(ctx context.Context) {
if job.IsCancelled(ctx) { if job.IsCancelled(ctx) {
return return
} }
@@ -534,10 +536,12 @@ func (t *autoTagFilesTask) processScenes(ctx context.Context, r Repository) {
findFilter := models.BatchFindFilter(batchSize) findFilter := models.BatchFindFilter(batchSize)
sceneFilter := t.makeSceneFilter() sceneFilter := t.makeSceneFilter()
r := t.repository
more := true more := true
for more { for more {
var scenes []*models.Scene var scenes []*models.Scene
if err := t.txnManager.WithReadTxn(ctx, func(ctx context.Context) error { if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
var err error var err error
scenes, err = scene.Query(ctx, r.Scene, sceneFilter, findFilter) scenes, err = scene.Query(ctx, r.Scene, sceneFilter, findFilter)
return err return err
@@ -555,7 +559,7 @@ func (t *autoTagFilesTask) processScenes(ctx context.Context, r Repository) {
} }
tt := autoTagSceneTask{ tt := autoTagSceneTask{
txnManager: t.txnManager, repository: r,
scene: ss, scene: ss,
performers: t.performers, performers: t.performers,
studios: t.studios, studios: t.studios,
@@ -583,7 +587,7 @@ func (t *autoTagFilesTask) processScenes(ctx context.Context, r Repository) {
} }
} }
func (t *autoTagFilesTask) processImages(ctx context.Context, r Repository) { func (t *autoTagFilesTask) processImages(ctx context.Context) {
if job.IsCancelled(ctx) { if job.IsCancelled(ctx) {
return return
} }
@@ -595,10 +599,12 @@ func (t *autoTagFilesTask) processImages(ctx context.Context, r Repository) {
findFilter := models.BatchFindFilter(batchSize) findFilter := models.BatchFindFilter(batchSize)
imageFilter := t.makeImageFilter() imageFilter := t.makeImageFilter()
r := t.repository
more := true more := true
for more { for more {
var images []*models.Image var images []*models.Image
if err := t.txnManager.WithReadTxn(ctx, func(ctx context.Context) error { if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
var err error var err error
images, err = image.Query(ctx, r.Image, imageFilter, findFilter) images, err = image.Query(ctx, r.Image, imageFilter, findFilter)
return err return err
@@ -616,7 +622,7 @@ func (t *autoTagFilesTask) processImages(ctx context.Context, r Repository) {
} }
tt := autoTagImageTask{ tt := autoTagImageTask{
txnManager: t.txnManager, repository: t.repository,
image: ss, image: ss,
performers: t.performers, performers: t.performers,
studios: t.studios, studios: t.studios,
@@ -644,7 +650,7 @@ func (t *autoTagFilesTask) processImages(ctx context.Context, r Repository) {
} }
} }
func (t *autoTagFilesTask) processGalleries(ctx context.Context, r Repository) { func (t *autoTagFilesTask) processGalleries(ctx context.Context) {
if job.IsCancelled(ctx) { if job.IsCancelled(ctx) {
return return
} }
@@ -656,10 +662,12 @@ func (t *autoTagFilesTask) processGalleries(ctx context.Context, r Repository) {
findFilter := models.BatchFindFilter(batchSize) findFilter := models.BatchFindFilter(batchSize)
galleryFilter := t.makeGalleryFilter() galleryFilter := t.makeGalleryFilter()
r := t.repository
more := true more := true
for more { for more {
var galleries []*models.Gallery var galleries []*models.Gallery
if err := t.txnManager.WithReadTxn(ctx, func(ctx context.Context) error { if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
var err error var err error
galleries, _, err = r.Gallery.Query(ctx, galleryFilter, findFilter) galleries, _, err = r.Gallery.Query(ctx, galleryFilter, findFilter)
return err return err
@@ -677,7 +685,7 @@ func (t *autoTagFilesTask) processGalleries(ctx context.Context, r Repository) {
} }
tt := autoTagGalleryTask{ tt := autoTagGalleryTask{
txnManager: t.txnManager, repository: t.repository,
gallery: ss, gallery: ss,
performers: t.performers, performers: t.performers,
studios: t.studios, studios: t.studios,
@@ -706,9 +714,8 @@ func (t *autoTagFilesTask) processGalleries(ctx context.Context, r Repository) {
} }
func (t *autoTagFilesTask) process(ctx context.Context) { func (t *autoTagFilesTask) process(ctx context.Context) {
r := t.txnManager if err := t.repository.WithReadTxn(ctx, func(ctx context.Context) error {
if err := r.WithReadTxn(ctx, func(ctx context.Context) error { total, err := t.getCount(ctx)
total, err := t.getCount(ctx, t.txnManager)
if err != nil { if err != nil {
return err return err
} }
@@ -724,13 +731,13 @@ func (t *autoTagFilesTask) process(ctx context.Context) {
return return
} }
t.processScenes(ctx, r) t.processScenes(ctx)
t.processImages(ctx, r) t.processImages(ctx)
t.processGalleries(ctx, r) t.processGalleries(ctx)
} }
type autoTagSceneTask struct { type autoTagSceneTask struct {
txnManager Repository repository models.Repository
scene *models.Scene scene *models.Scene
performers bool performers bool
@@ -742,8 +749,8 @@ type autoTagSceneTask struct {
func (t *autoTagSceneTask) Start(ctx context.Context, wg *sync.WaitGroup) { func (t *autoTagSceneTask) Start(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done() defer wg.Done()
r := t.txnManager r := t.repository
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { if err := r.WithTxn(ctx, func(ctx context.Context) error {
if t.scene.Path == "" { if t.scene.Path == "" {
// nothing to do // nothing to do
return nil return nil
@@ -774,7 +781,7 @@ func (t *autoTagSceneTask) Start(ctx context.Context, wg *sync.WaitGroup) {
} }
type autoTagImageTask struct { type autoTagImageTask struct {
txnManager Repository repository models.Repository
image *models.Image image *models.Image
performers bool performers bool
@@ -786,8 +793,8 @@ type autoTagImageTask struct {
func (t *autoTagImageTask) Start(ctx context.Context, wg *sync.WaitGroup) { func (t *autoTagImageTask) Start(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done() defer wg.Done()
r := t.txnManager r := t.repository
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { if err := r.WithTxn(ctx, func(ctx context.Context) error {
if t.performers { if t.performers {
if err := autotag.ImagePerformers(ctx, t.image, r.Image, r.Performer, t.cache); err != nil { if err := autotag.ImagePerformers(ctx, t.image, r.Image, r.Performer, t.cache); err != nil {
return fmt.Errorf("tagging image performers for %s: %v", t.image.DisplayName(), err) return fmt.Errorf("tagging image performers for %s: %v", t.image.DisplayName(), err)
@@ -813,7 +820,7 @@ func (t *autoTagImageTask) Start(ctx context.Context, wg *sync.WaitGroup) {
} }
type autoTagGalleryTask struct { type autoTagGalleryTask struct {
txnManager Repository repository models.Repository
gallery *models.Gallery gallery *models.Gallery
performers bool performers bool
@@ -825,8 +832,8 @@ type autoTagGalleryTask struct {
func (t *autoTagGalleryTask) Start(ctx context.Context, wg *sync.WaitGroup) { func (t *autoTagGalleryTask) Start(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done() defer wg.Done()
r := t.txnManager r := t.repository
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { if err := r.WithTxn(ctx, func(ctx context.Context) error {
if t.performers { if t.performers {
if err := autotag.GalleryPerformers(ctx, t.gallery, r.Gallery, r.Performer, t.cache); err != nil { if err := autotag.GalleryPerformers(ctx, t.gallery, r.Gallery, r.Performer, t.cache); err != nil {
return fmt.Errorf("tagging gallery performers for %s: %v", t.gallery.DisplayName(), err) return fmt.Errorf("tagging gallery performers for %s: %v", t.gallery.DisplayName(), err)

View File

@@ -16,7 +16,6 @@ import (
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin" "github.com/stashapp/stash/pkg/plugin"
"github.com/stashapp/stash/pkg/scene" "github.com/stashapp/stash/pkg/scene"
"github.com/stashapp/stash/pkg/txn"
) )
type cleaner interface { type cleaner interface {
@@ -25,7 +24,7 @@ type cleaner interface {
type cleanJob struct { type cleanJob struct {
cleaner cleaner cleaner cleaner
txnManager Repository repository models.Repository
input CleanMetadataInput input CleanMetadataInput
sceneService SceneService sceneService SceneService
imageService ImageService imageService ImageService
@@ -61,10 +60,11 @@ func (j *cleanJob) cleanEmptyGalleries(ctx context.Context) {
const batchSize = 1000 const batchSize = 1000
var toClean []int var toClean []int
findFilter := models.BatchFindFilter(batchSize) findFilter := models.BatchFindFilter(batchSize)
if err := txn.WithTxn(ctx, j.txnManager, func(ctx context.Context) error { r := j.repository
if err := r.WithTxn(ctx, func(ctx context.Context) error {
found := true found := true
for found { for found {
emptyGalleries, _, err := j.txnManager.Gallery.Query(ctx, &models.GalleryFilterType{ emptyGalleries, _, err := r.Gallery.Query(ctx, &models.GalleryFilterType{
ImageCount: &models.IntCriterionInput{ ImageCount: &models.IntCriterionInput{
Value: 0, Value: 0,
Modifier: models.CriterionModifierEquals, Modifier: models.CriterionModifierEquals,
@@ -108,9 +108,10 @@ func (j *cleanJob) cleanEmptyGalleries(ctx context.Context) {
func (j *cleanJob) deleteGallery(ctx context.Context, id int) { func (j *cleanJob) deleteGallery(ctx context.Context, id int) {
pluginCache := GetInstance().PluginCache pluginCache := GetInstance().PluginCache
qb := j.txnManager.Gallery
if err := txn.WithTxn(ctx, j.txnManager, func(ctx context.Context) error { r := j.repository
if err := r.WithTxn(ctx, func(ctx context.Context) error {
qb := r.Gallery
g, err := qb.Find(ctx, id) g, err := qb.Find(ctx, id)
if err != nil { if err != nil {
return err return err
@@ -120,7 +121,7 @@ func (j *cleanJob) deleteGallery(ctx context.Context, id int) {
return fmt.Errorf("gallery with id %d not found", id) return fmt.Errorf("gallery with id %d not found", id)
} }
if err := g.LoadPrimaryFile(ctx, j.txnManager.File); err != nil { if err := g.LoadPrimaryFile(ctx, r.File); err != nil {
return err return err
} }
@@ -253,9 +254,7 @@ func (f *cleanFilter) shouldCleanImage(path string, stash *config.StashConfig) b
return false return false
} }
type cleanHandler struct { type cleanHandler struct{}
PluginCache *plugin.Cache
}
func (h *cleanHandler) HandleFile(ctx context.Context, fileDeleter *file.Deleter, fileID models.FileID) error { func (h *cleanHandler) HandleFile(ctx context.Context, fileDeleter *file.Deleter, fileID models.FileID) error {
if err := h.handleRelatedScenes(ctx, fileDeleter, fileID); err != nil { if err := h.handleRelatedScenes(ctx, fileDeleter, fileID); err != nil {
@@ -277,7 +276,7 @@ func (h *cleanHandler) HandleFolder(ctx context.Context, fileDeleter *file.Delet
func (h *cleanHandler) handleRelatedScenes(ctx context.Context, fileDeleter *file.Deleter, fileID models.FileID) error { func (h *cleanHandler) handleRelatedScenes(ctx context.Context, fileDeleter *file.Deleter, fileID models.FileID) error {
mgr := GetInstance() mgr := GetInstance()
sceneQB := mgr.Database.Scene sceneQB := mgr.Repository.Scene
scenes, err := sceneQB.FindByFileID(ctx, fileID) scenes, err := sceneQB.FindByFileID(ctx, fileID)
if err != nil { if err != nil {
return err return err
@@ -303,12 +302,9 @@ func (h *cleanHandler) handleRelatedScenes(ctx context.Context, fileDeleter *fil
return err return err
} }
checksum := scene.Checksum
oshash := scene.OSHash
mgr.PluginCache.RegisterPostHooks(ctx, scene.ID, plugin.SceneDestroyPost, plugin.SceneDestroyInput{ mgr.PluginCache.RegisterPostHooks(ctx, scene.ID, plugin.SceneDestroyPost, plugin.SceneDestroyInput{
Checksum: checksum, Checksum: scene.Checksum,
OSHash: oshash, OSHash: scene.OSHash,
Path: scene.Path, Path: scene.Path,
}, nil) }, nil)
} else { } else {
@@ -335,7 +331,7 @@ func (h *cleanHandler) handleRelatedScenes(ctx context.Context, fileDeleter *fil
func (h *cleanHandler) handleRelatedGalleries(ctx context.Context, fileID models.FileID) error { func (h *cleanHandler) handleRelatedGalleries(ctx context.Context, fileID models.FileID) error {
mgr := GetInstance() mgr := GetInstance()
qb := mgr.Database.Gallery qb := mgr.Repository.Gallery
galleries, err := qb.FindByFileID(ctx, fileID) galleries, err := qb.FindByFileID(ctx, fileID)
if err != nil { if err != nil {
return err return err
@@ -381,7 +377,7 @@ func (h *cleanHandler) handleRelatedGalleries(ctx context.Context, fileID models
func (h *cleanHandler) deleteRelatedFolderGalleries(ctx context.Context, folderID models.FolderID) error { func (h *cleanHandler) deleteRelatedFolderGalleries(ctx context.Context, folderID models.FolderID) error {
mgr := GetInstance() mgr := GetInstance()
qb := mgr.Database.Gallery qb := mgr.Repository.Gallery
galleries, err := qb.FindByFolderID(ctx, folderID) galleries, err := qb.FindByFolderID(ctx, folderID)
if err != nil { if err != nil {
return err return err
@@ -405,7 +401,7 @@ func (h *cleanHandler) deleteRelatedFolderGalleries(ctx context.Context, folderI
func (h *cleanHandler) handleRelatedImages(ctx context.Context, fileDeleter *file.Deleter, fileID models.FileID) error { func (h *cleanHandler) handleRelatedImages(ctx context.Context, fileDeleter *file.Deleter, fileID models.FileID) error {
mgr := GetInstance() mgr := GetInstance()
imageQB := mgr.Database.Image imageQB := mgr.Repository.Image
images, err := imageQB.FindByFileID(ctx, fileID) images, err := imageQB.FindByFileID(ctx, fileID)
if err != nil { if err != nil {
return err return err
@@ -413,7 +409,7 @@ func (h *cleanHandler) handleRelatedImages(ctx context.Context, fileDeleter *fil
imageFileDeleter := &image.FileDeleter{ imageFileDeleter := &image.FileDeleter{
Deleter: fileDeleter, Deleter: fileDeleter,
Paths: GetInstance().Paths, Paths: mgr.Paths,
} }
for _, i := range images { for _, i := range images {

View File

@@ -31,7 +31,7 @@ import (
) )
type ExportTask struct { type ExportTask struct {
txnManager Repository repository models.Repository
full bool full bool
baseDir string baseDir string
@@ -98,7 +98,7 @@ func CreateExportTask(a models.HashAlgorithm, input ExportObjectsInput) *ExportT
} }
return &ExportTask{ return &ExportTask{
txnManager: GetInstance().Repository, repository: GetInstance().Repository,
fileNamingAlgorithm: a, fileNamingAlgorithm: a,
scenes: newExportSpec(input.Scenes), scenes: newExportSpec(input.Scenes),
images: newExportSpec(input.Images), images: newExportSpec(input.Images),
@@ -148,29 +148,27 @@ func (t *ExportTask) Start(ctx context.Context, wg *sync.WaitGroup) {
paths.EmptyJSONDirs(t.baseDir) paths.EmptyJSONDirs(t.baseDir)
paths.EnsureJSONDirs(t.baseDir) paths.EnsureJSONDirs(t.baseDir)
txnErr := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { txnErr := t.repository.WithTxn(ctx, func(ctx context.Context) error {
r := t.txnManager
// include movie scenes and gallery images // include movie scenes and gallery images
if !t.full { if !t.full {
// only include movie scenes if includeDependencies is also set // only include movie scenes if includeDependencies is also set
if !t.scenes.all && t.includeDependencies { if !t.scenes.all && t.includeDependencies {
t.populateMovieScenes(ctx, r) t.populateMovieScenes(ctx)
} }
// always export gallery images // always export gallery images
if !t.images.all { if !t.images.all {
t.populateGalleryImages(ctx, r) t.populateGalleryImages(ctx)
} }
} }
t.ExportScenes(ctx, workerCount, r) t.ExportScenes(ctx, workerCount)
t.ExportImages(ctx, workerCount, r) t.ExportImages(ctx, workerCount)
t.ExportGalleries(ctx, workerCount, r) t.ExportGalleries(ctx, workerCount)
t.ExportMovies(ctx, workerCount, r) t.ExportMovies(ctx, workerCount)
t.ExportPerformers(ctx, workerCount, r) t.ExportPerformers(ctx, workerCount)
t.ExportStudios(ctx, workerCount, r) t.ExportStudios(ctx, workerCount)
t.ExportTags(ctx, workerCount, r) t.ExportTags(ctx, workerCount)
return nil return nil
}) })
@@ -277,9 +275,10 @@ func (t *ExportTask) zipFile(fn, outDir string, z *zip.Writer) error {
return nil return nil
} }
func (t *ExportTask) populateMovieScenes(ctx context.Context, repo Repository) { func (t *ExportTask) populateMovieScenes(ctx context.Context) {
reader := repo.Movie r := t.repository
sceneReader := repo.Scene reader := r.Movie
sceneReader := r.Scene
var movies []*models.Movie var movies []*models.Movie
var err error var err error
@@ -307,9 +306,10 @@ func (t *ExportTask) populateMovieScenes(ctx context.Context, repo Repository) {
} }
} }
func (t *ExportTask) populateGalleryImages(ctx context.Context, repo Repository) { func (t *ExportTask) populateGalleryImages(ctx context.Context) {
reader := repo.Gallery r := t.repository
imageReader := repo.Image reader := r.Gallery
imageReader := r.Image
var galleries []*models.Gallery var galleries []*models.Gallery
var err error var err error
@@ -342,10 +342,10 @@ func (t *ExportTask) populateGalleryImages(ctx context.Context, repo Repository)
} }
} }
func (t *ExportTask) ExportScenes(ctx context.Context, workers int, repo Repository) { func (t *ExportTask) ExportScenes(ctx context.Context, workers int) {
var scenesWg sync.WaitGroup var scenesWg sync.WaitGroup
sceneReader := repo.Scene sceneReader := t.repository.Scene
var scenes []*models.Scene var scenes []*models.Scene
var err error var err error
@@ -367,7 +367,7 @@ func (t *ExportTask) ExportScenes(ctx context.Context, workers int, repo Reposit
for w := 0; w < workers; w++ { // create export Scene workers for w := 0; w < workers; w++ { // create export Scene workers
scenesWg.Add(1) scenesWg.Add(1)
go exportScene(ctx, &scenesWg, jobCh, repo, t) go t.exportScene(ctx, &scenesWg, jobCh)
} }
for i, scene := range scenes { for i, scene := range scenes {
@@ -385,7 +385,7 @@ func (t *ExportTask) ExportScenes(ctx context.Context, workers int, repo Reposit
logger.Infof("[scenes] export complete in %s. %d workers used.", time.Since(startTime), workers) logger.Infof("[scenes] export complete in %s. %d workers used.", time.Since(startTime), workers)
} }
func exportFile(f models.File, t *ExportTask) { func (t *ExportTask) exportFile(f models.File) {
newFileJSON := fileToJSON(f) newFileJSON := fileToJSON(f)
fn := newFileJSON.Filename() fn := newFileJSON.Filename()
@@ -449,7 +449,7 @@ func fileToJSON(f models.File) jsonschema.DirEntry {
return &base return &base
} }
func exportFolder(f models.Folder, t *ExportTask) { func (t *ExportTask) exportFolder(f models.Folder) {
newFileJSON := folderToJSON(f) newFileJSON := folderToJSON(f)
fn := newFileJSON.Filename() fn := newFileJSON.Filename()
@@ -475,15 +475,17 @@ func folderToJSON(f models.Folder) jsonschema.DirEntry {
return &base return &base
} }
func exportScene(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Scene, repo Repository, t *ExportTask) { func (t *ExportTask) exportScene(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Scene) {
defer wg.Done() defer wg.Done()
sceneReader := repo.Scene
studioReader := repo.Studio r := t.repository
movieReader := repo.Movie sceneReader := r.Scene
galleryReader := repo.Gallery studioReader := r.Studio
performerReader := repo.Performer movieReader := r.Movie
tagReader := repo.Tag galleryReader := r.Gallery
sceneMarkerReader := repo.SceneMarker performerReader := r.Performer
tagReader := r.Tag
sceneMarkerReader := r.SceneMarker
for s := range jobChan { for s := range jobChan {
sceneHash := s.GetHash(t.fileNamingAlgorithm) sceneHash := s.GetHash(t.fileNamingAlgorithm)
@@ -500,7 +502,7 @@ func exportScene(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models
// export files // export files
for _, f := range s.Files.List() { for _, f := range s.Files.List() {
exportFile(f, t) t.exportFile(f)
} }
newSceneJSON.Studio, err = scene.GetStudioName(ctx, studioReader, s) newSceneJSON.Studio, err = scene.GetStudioName(ctx, studioReader, s)
@@ -589,10 +591,11 @@ func exportScene(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models
} }
} }
func (t *ExportTask) ExportImages(ctx context.Context, workers int, repo Repository) { func (t *ExportTask) ExportImages(ctx context.Context, workers int) {
var imagesWg sync.WaitGroup var imagesWg sync.WaitGroup
imageReader := repo.Image r := t.repository
imageReader := r.Image
var images []*models.Image var images []*models.Image
var err error var err error
@@ -614,7 +617,7 @@ func (t *ExportTask) ExportImages(ctx context.Context, workers int, repo Reposit
for w := 0; w < workers; w++ { // create export Image workers for w := 0; w < workers; w++ { // create export Image workers
imagesWg.Add(1) imagesWg.Add(1)
go exportImage(ctx, &imagesWg, jobCh, repo, t) go t.exportImage(ctx, &imagesWg, jobCh)
} }
for i, image := range images { for i, image := range images {
@@ -632,22 +635,24 @@ func (t *ExportTask) ExportImages(ctx context.Context, workers int, repo Reposit
logger.Infof("[images] export complete in %s. %d workers used.", time.Since(startTime), workers) logger.Infof("[images] export complete in %s. %d workers used.", time.Since(startTime), workers)
} }
func exportImage(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Image, repo Repository, t *ExportTask) { func (t *ExportTask) exportImage(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Image) {
defer wg.Done() defer wg.Done()
studioReader := repo.Studio
galleryReader := repo.Gallery r := t.repository
performerReader := repo.Performer studioReader := r.Studio
tagReader := repo.Tag galleryReader := r.Gallery
performerReader := r.Performer
tagReader := r.Tag
for s := range jobChan { for s := range jobChan {
imageHash := s.Checksum imageHash := s.Checksum
if err := s.LoadFiles(ctx, repo.Image); err != nil { if err := s.LoadFiles(ctx, r.Image); err != nil {
logger.Errorf("[images] <%s> error getting image files: %s", imageHash, err.Error()) logger.Errorf("[images] <%s> error getting image files: %s", imageHash, err.Error())
continue continue
} }
if err := s.LoadURLs(ctx, repo.Image); err != nil { if err := s.LoadURLs(ctx, r.Image); err != nil {
logger.Errorf("[images] <%s> error getting image urls: %s", imageHash, err.Error()) logger.Errorf("[images] <%s> error getting image urls: %s", imageHash, err.Error())
continue continue
} }
@@ -656,7 +661,7 @@ func exportImage(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models
// export files // export files
for _, f := range s.Files.List() { for _, f := range s.Files.List() {
exportFile(f, t) t.exportFile(f)
} }
var err error var err error
@@ -715,10 +720,10 @@ func exportImage(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models
} }
} }
func (t *ExportTask) ExportGalleries(ctx context.Context, workers int, repo Repository) { func (t *ExportTask) ExportGalleries(ctx context.Context, workers int) {
var galleriesWg sync.WaitGroup var galleriesWg sync.WaitGroup
reader := repo.Gallery reader := t.repository.Gallery
var galleries []*models.Gallery var galleries []*models.Gallery
var err error var err error
@@ -740,7 +745,7 @@ func (t *ExportTask) ExportGalleries(ctx context.Context, workers int, repo Repo
for w := 0; w < workers; w++ { // create export Scene workers for w := 0; w < workers; w++ { // create export Scene workers
galleriesWg.Add(1) galleriesWg.Add(1)
go exportGallery(ctx, &galleriesWg, jobCh, repo, t) go t.exportGallery(ctx, &galleriesWg, jobCh)
} }
for i, gallery := range galleries { for i, gallery := range galleries {
@@ -759,15 +764,17 @@ func (t *ExportTask) ExportGalleries(ctx context.Context, workers int, repo Repo
logger.Infof("[galleries] export complete in %s. %d workers used.", time.Since(startTime), workers) logger.Infof("[galleries] export complete in %s. %d workers used.", time.Since(startTime), workers)
} }
func exportGallery(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Gallery, repo Repository, t *ExportTask) { func (t *ExportTask) exportGallery(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Gallery) {
defer wg.Done() defer wg.Done()
studioReader := repo.Studio
performerReader := repo.Performer r := t.repository
tagReader := repo.Tag studioReader := r.Studio
galleryChapterReader := repo.GalleryChapter performerReader := r.Performer
tagReader := r.Tag
galleryChapterReader := r.GalleryChapter
for g := range jobChan { for g := range jobChan {
if err := g.LoadFiles(ctx, repo.Gallery); err != nil { if err := g.LoadFiles(ctx, r.Gallery); err != nil {
logger.Errorf("[galleries] <%s> failed to fetch files for gallery: %s", g.DisplayName(), err.Error()) logger.Errorf("[galleries] <%s> failed to fetch files for gallery: %s", g.DisplayName(), err.Error())
continue continue
} }
@@ -782,12 +789,12 @@ func exportGallery(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *mode
// export files // export files
for _, f := range g.Files.List() { for _, f := range g.Files.List() {
exportFile(f, t) t.exportFile(f)
} }
// export folder if necessary // export folder if necessary
if g.FolderID != nil { if g.FolderID != nil {
folder, err := repo.Folder.Find(ctx, *g.FolderID) folder, err := r.Folder.Find(ctx, *g.FolderID)
if err != nil { if err != nil {
logger.Errorf("[galleries] <%s> error getting gallery folder: %v", galleryHash, err) logger.Errorf("[galleries] <%s> error getting gallery folder: %v", galleryHash, err)
continue continue
@@ -798,7 +805,7 @@ func exportGallery(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *mode
continue continue
} }
exportFolder(*folder, t) t.exportFolder(*folder)
} }
newGalleryJSON.Studio, err = gallery.GetStudioName(ctx, studioReader, g) newGalleryJSON.Studio, err = gallery.GetStudioName(ctx, studioReader, g)
@@ -857,10 +864,10 @@ func exportGallery(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *mode
} }
} }
func (t *ExportTask) ExportPerformers(ctx context.Context, workers int, repo Repository) { func (t *ExportTask) ExportPerformers(ctx context.Context, workers int) {
var performersWg sync.WaitGroup var performersWg sync.WaitGroup
reader := repo.Performer reader := t.repository.Performer
var performers []*models.Performer var performers []*models.Performer
var err error var err error
all := t.full || (t.performers != nil && t.performers.all) all := t.full || (t.performers != nil && t.performers.all)
@@ -880,7 +887,7 @@ func (t *ExportTask) ExportPerformers(ctx context.Context, workers int, repo Rep
for w := 0; w < workers; w++ { // create export Performer workers for w := 0; w < workers; w++ { // create export Performer workers
performersWg.Add(1) performersWg.Add(1)
go t.exportPerformer(ctx, &performersWg, jobCh, repo) go t.exportPerformer(ctx, &performersWg, jobCh)
} }
for i, performer := range performers { for i, performer := range performers {
@@ -896,10 +903,11 @@ func (t *ExportTask) ExportPerformers(ctx context.Context, workers int, repo Rep
logger.Infof("[performers] export complete in %s. %d workers used.", time.Since(startTime), workers) logger.Infof("[performers] export complete in %s. %d workers used.", time.Since(startTime), workers)
} }
func (t *ExportTask) exportPerformer(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Performer, repo Repository) { func (t *ExportTask) exportPerformer(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Performer) {
defer wg.Done() defer wg.Done()
performerReader := repo.Performer r := t.repository
performerReader := r.Performer
for p := range jobChan { for p := range jobChan {
newPerformerJSON, err := performer.ToJSON(ctx, performerReader, p) newPerformerJSON, err := performer.ToJSON(ctx, performerReader, p)
@@ -909,7 +917,7 @@ func (t *ExportTask) exportPerformer(ctx context.Context, wg *sync.WaitGroup, jo
continue continue
} }
tags, err := repo.Tag.FindByPerformerID(ctx, p.ID) tags, err := r.Tag.FindByPerformerID(ctx, p.ID)
if err != nil { if err != nil {
logger.Errorf("[performers] <%s> error getting performer tags: %s", p.Name, err.Error()) logger.Errorf("[performers] <%s> error getting performer tags: %s", p.Name, err.Error())
continue continue
@@ -929,10 +937,10 @@ func (t *ExportTask) exportPerformer(ctx context.Context, wg *sync.WaitGroup, jo
} }
} }
func (t *ExportTask) ExportStudios(ctx context.Context, workers int, repo Repository) { func (t *ExportTask) ExportStudios(ctx context.Context, workers int) {
var studiosWg sync.WaitGroup var studiosWg sync.WaitGroup
reader := repo.Studio reader := t.repository.Studio
var studios []*models.Studio var studios []*models.Studio
var err error var err error
all := t.full || (t.studios != nil && t.studios.all) all := t.full || (t.studios != nil && t.studios.all)
@@ -953,7 +961,7 @@ func (t *ExportTask) ExportStudios(ctx context.Context, workers int, repo Reposi
for w := 0; w < workers; w++ { // create export Studio workers for w := 0; w < workers; w++ { // create export Studio workers
studiosWg.Add(1) studiosWg.Add(1)
go t.exportStudio(ctx, &studiosWg, jobCh, repo) go t.exportStudio(ctx, &studiosWg, jobCh)
} }
for i, studio := range studios { for i, studio := range studios {
@@ -969,10 +977,10 @@ func (t *ExportTask) ExportStudios(ctx context.Context, workers int, repo Reposi
logger.Infof("[studios] export complete in %s. %d workers used.", time.Since(startTime), workers) logger.Infof("[studios] export complete in %s. %d workers used.", time.Since(startTime), workers)
} }
func (t *ExportTask) exportStudio(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Studio, repo Repository) { func (t *ExportTask) exportStudio(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Studio) {
defer wg.Done() defer wg.Done()
studioReader := repo.Studio studioReader := t.repository.Studio
for s := range jobChan { for s := range jobChan {
newStudioJSON, err := studio.ToJSON(ctx, studioReader, s) newStudioJSON, err := studio.ToJSON(ctx, studioReader, s)
@@ -990,10 +998,10 @@ func (t *ExportTask) exportStudio(ctx context.Context, wg *sync.WaitGroup, jobCh
} }
} }
func (t *ExportTask) ExportTags(ctx context.Context, workers int, repo Repository) { func (t *ExportTask) ExportTags(ctx context.Context, workers int) {
var tagsWg sync.WaitGroup var tagsWg sync.WaitGroup
reader := repo.Tag reader := t.repository.Tag
var tags []*models.Tag var tags []*models.Tag
var err error var err error
all := t.full || (t.tags != nil && t.tags.all) all := t.full || (t.tags != nil && t.tags.all)
@@ -1014,7 +1022,7 @@ func (t *ExportTask) ExportTags(ctx context.Context, workers int, repo Repositor
for w := 0; w < workers; w++ { // create export Tag workers for w := 0; w < workers; w++ { // create export Tag workers
tagsWg.Add(1) tagsWg.Add(1)
go t.exportTag(ctx, &tagsWg, jobCh, repo) go t.exportTag(ctx, &tagsWg, jobCh)
} }
for i, tag := range tags { for i, tag := range tags {
@@ -1030,10 +1038,10 @@ func (t *ExportTask) ExportTags(ctx context.Context, workers int, repo Repositor
logger.Infof("[tags] export complete in %s. %d workers used.", time.Since(startTime), workers) logger.Infof("[tags] export complete in %s. %d workers used.", time.Since(startTime), workers)
} }
func (t *ExportTask) exportTag(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Tag, repo Repository) { func (t *ExportTask) exportTag(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Tag) {
defer wg.Done() defer wg.Done()
tagReader := repo.Tag tagReader := t.repository.Tag
for thisTag := range jobChan { for thisTag := range jobChan {
newTagJSON, err := tag.ToJSON(ctx, tagReader, thisTag) newTagJSON, err := tag.ToJSON(ctx, tagReader, thisTag)
@@ -1051,10 +1059,10 @@ func (t *ExportTask) exportTag(ctx context.Context, wg *sync.WaitGroup, jobChan
} }
} }
func (t *ExportTask) ExportMovies(ctx context.Context, workers int, repo Repository) { func (t *ExportTask) ExportMovies(ctx context.Context, workers int) {
var moviesWg sync.WaitGroup var moviesWg sync.WaitGroup
reader := repo.Movie reader := t.repository.Movie
var movies []*models.Movie var movies []*models.Movie
var err error var err error
all := t.full || (t.movies != nil && t.movies.all) all := t.full || (t.movies != nil && t.movies.all)
@@ -1075,7 +1083,7 @@ func (t *ExportTask) ExportMovies(ctx context.Context, workers int, repo Reposit
for w := 0; w < workers; w++ { // create export Studio workers for w := 0; w < workers; w++ { // create export Studio workers
moviesWg.Add(1) moviesWg.Add(1)
go t.exportMovie(ctx, &moviesWg, jobCh, repo) go t.exportMovie(ctx, &moviesWg, jobCh)
} }
for i, movie := range movies { for i, movie := range movies {
@@ -1091,11 +1099,12 @@ func (t *ExportTask) ExportMovies(ctx context.Context, workers int, repo Reposit
logger.Infof("[movies] export complete in %s. %d workers used.", time.Since(startTime), workers) logger.Infof("[movies] export complete in %s. %d workers used.", time.Since(startTime), workers)
} }
func (t *ExportTask) exportMovie(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Movie, repo Repository) { func (t *ExportTask) exportMovie(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Movie) {
defer wg.Done() defer wg.Done()
movieReader := repo.Movie r := t.repository
studioReader := repo.Studio movieReader := r.Movie
studioReader := r.Studio
for m := range jobChan { for m := range jobChan {
newMovieJSON, err := movie.ToJSON(ctx, movieReader, studioReader, m) newMovieJSON, err := movie.ToJSON(ctx, movieReader, studioReader, m)

View File

@@ -55,7 +55,7 @@ type GeneratePreviewOptionsInput struct {
const generateQueueSize = 200000 const generateQueueSize = 200000
type GenerateJob struct { type GenerateJob struct {
txnManager Repository repository models.Repository
input GenerateMetadataInput input GenerateMetadataInput
overwrite bool overwrite bool
@@ -112,8 +112,9 @@ func (j *GenerateJob) Execute(ctx context.Context, progress *job.Progress) {
Overwrite: j.overwrite, Overwrite: j.overwrite,
} }
if err := j.txnManager.WithReadTxn(ctx, func(ctx context.Context) error { r := j.repository
qb := j.txnManager.Scene if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
qb := r.Scene
if len(j.input.SceneIDs) == 0 && len(j.input.MarkerIDs) == 0 { if len(j.input.SceneIDs) == 0 && len(j.input.MarkerIDs) == 0 {
totals = j.queueTasks(ctx, g, queue) totals = j.queueTasks(ctx, g, queue)
} else { } else {
@@ -129,7 +130,7 @@ func (j *GenerateJob) Execute(ctx context.Context, progress *job.Progress) {
} }
if len(j.input.MarkerIDs) > 0 { if len(j.input.MarkerIDs) > 0 {
markers, err = j.txnManager.SceneMarker.FindMany(ctx, markerIDs) markers, err = r.SceneMarker.FindMany(ctx, markerIDs)
if err != nil { if err != nil {
return err return err
} }
@@ -229,12 +230,14 @@ func (j *GenerateJob) queueTasks(ctx context.Context, g *generate.Generator, que
findFilter := models.BatchFindFilter(batchSize) findFilter := models.BatchFindFilter(batchSize)
r := j.repository
for more := true; more; { for more := true; more; {
if job.IsCancelled(ctx) { if job.IsCancelled(ctx) {
return totals return totals
} }
scenes, err := scene.Query(ctx, j.txnManager.Scene, nil, findFilter) scenes, err := scene.Query(ctx, r.Scene, nil, findFilter)
if err != nil { if err != nil {
logger.Errorf("Error encountered queuing files to scan: %s", err.Error()) logger.Errorf("Error encountered queuing files to scan: %s", err.Error())
return totals return totals
@@ -245,7 +248,7 @@ func (j *GenerateJob) queueTasks(ctx context.Context, g *generate.Generator, que
return totals return totals
} }
if err := ss.LoadFiles(ctx, j.txnManager.Scene); err != nil { if err := ss.LoadFiles(ctx, r.Scene); err != nil {
logger.Errorf("Error encountered queuing files to scan: %s", err.Error()) logger.Errorf("Error encountered queuing files to scan: %s", err.Error())
return totals return totals
} }
@@ -266,7 +269,7 @@ func (j *GenerateJob) queueTasks(ctx context.Context, g *generate.Generator, que
return totals return totals
} }
images, err := image.Query(ctx, j.txnManager.Image, nil, findFilter) images, err := image.Query(ctx, r.Image, nil, findFilter)
if err != nil { if err != nil {
logger.Errorf("Error encountered queuing files to scan: %s", err.Error()) logger.Errorf("Error encountered queuing files to scan: %s", err.Error())
return totals return totals
@@ -277,7 +280,7 @@ func (j *GenerateJob) queueTasks(ctx context.Context, g *generate.Generator, que
return totals return totals
} }
if err := ss.LoadFiles(ctx, j.txnManager.Image); err != nil { if err := ss.LoadFiles(ctx, r.Image); err != nil {
logger.Errorf("Error encountered queuing files to scan: %s", err.Error()) logger.Errorf("Error encountered queuing files to scan: %s", err.Error())
return totals return totals
} }
@@ -331,9 +334,11 @@ func getGeneratePreviewOptions(optionsInput GeneratePreviewOptionsInput) generat
} }
func (j *GenerateJob) queueSceneJobs(ctx context.Context, g *generate.Generator, scene *models.Scene, queue chan<- Task, totals *totalsGenerate) { func (j *GenerateJob) queueSceneJobs(ctx context.Context, g *generate.Generator, scene *models.Scene, queue chan<- Task, totals *totalsGenerate) {
r := j.repository
if j.input.Covers { if j.input.Covers {
task := &GenerateCoverTask{ task := &GenerateCoverTask{
txnManager: j.txnManager, repository: r,
Scene: *scene, Scene: *scene,
Overwrite: j.overwrite, Overwrite: j.overwrite,
} }
@@ -390,7 +395,7 @@ func (j *GenerateJob) queueSceneJobs(ctx context.Context, g *generate.Generator,
if j.input.Markers { if j.input.Markers {
task := &GenerateMarkersTask{ task := &GenerateMarkersTask{
TxnManager: j.txnManager, repository: r,
Scene: scene, Scene: scene,
Overwrite: j.overwrite, Overwrite: j.overwrite,
fileNamingAlgorithm: j.fileNamingAlgo, fileNamingAlgorithm: j.fileNamingAlgo,
@@ -429,10 +434,9 @@ func (j *GenerateJob) queueSceneJobs(ctx context.Context, g *generate.Generator,
// generate for all files in scene // generate for all files in scene
for _, f := range scene.Files.List() { for _, f := range scene.Files.List() {
task := &GeneratePhashTask{ task := &GeneratePhashTask{
repository: r,
File: f, File: f,
fileNamingAlgorithm: j.fileNamingAlgo, fileNamingAlgorithm: j.fileNamingAlgo,
txnManager: j.txnManager,
fileUpdater: j.txnManager.File,
Overwrite: j.overwrite, Overwrite: j.overwrite,
} }
@@ -446,10 +450,10 @@ func (j *GenerateJob) queueSceneJobs(ctx context.Context, g *generate.Generator,
if j.input.InteractiveHeatmapsSpeeds { if j.input.InteractiveHeatmapsSpeeds {
task := &GenerateInteractiveHeatmapSpeedTask{ task := &GenerateInteractiveHeatmapSpeedTask{
repository: r,
Scene: *scene, Scene: *scene,
Overwrite: j.overwrite, Overwrite: j.overwrite,
fileNamingAlgorithm: j.fileNamingAlgo, fileNamingAlgorithm: j.fileNamingAlgo,
TxnManager: j.txnManager,
} }
if task.required() { if task.required() {
@@ -462,7 +466,7 @@ func (j *GenerateJob) queueSceneJobs(ctx context.Context, g *generate.Generator,
func (j *GenerateJob) queueMarkerJob(g *generate.Generator, marker *models.SceneMarker, queue chan<- Task, totals *totalsGenerate) { func (j *GenerateJob) queueMarkerJob(g *generate.Generator, marker *models.SceneMarker, queue chan<- Task, totals *totalsGenerate) {
task := &GenerateMarkersTask{ task := &GenerateMarkersTask{
TxnManager: j.txnManager, repository: j.repository,
Marker: marker, Marker: marker,
Overwrite: j.overwrite, Overwrite: j.overwrite,
fileNamingAlgorithm: j.fileNamingAlgo, fileNamingAlgorithm: j.fileNamingAlgo,

View File

@@ -11,10 +11,10 @@ import (
) )
type GenerateInteractiveHeatmapSpeedTask struct { type GenerateInteractiveHeatmapSpeedTask struct {
repository models.Repository
Scene models.Scene Scene models.Scene
Overwrite bool Overwrite bool
fileNamingAlgorithm models.HashAlgorithm fileNamingAlgorithm models.HashAlgorithm
TxnManager Repository
} }
func (t *GenerateInteractiveHeatmapSpeedTask) GetDescription() string { func (t *GenerateInteractiveHeatmapSpeedTask) GetDescription() string {
@@ -42,10 +42,11 @@ func (t *GenerateInteractiveHeatmapSpeedTask) Start(ctx context.Context) {
median := generator.InteractiveSpeed median := generator.InteractiveSpeed
if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error { r := t.repository
if err := r.WithTxn(ctx, func(ctx context.Context) error {
primaryFile := t.Scene.Files.Primary() primaryFile := t.Scene.Files.Primary()
primaryFile.InteractiveSpeed = &median primaryFile.InteractiveSpeed = &median
qb := t.TxnManager.File qb := r.File
return qb.Update(ctx, primaryFile) return qb.Update(ctx, primaryFile)
}); err != nil && ctx.Err() == nil { }); err != nil && ctx.Err() == nil {
logger.Error(err.Error()) logger.Error(err.Error())

View File

@@ -12,7 +12,7 @@ import (
) )
type GenerateMarkersTask struct { type GenerateMarkersTask struct {
TxnManager Repository repository models.Repository
Scene *models.Scene Scene *models.Scene
Marker *models.SceneMarker Marker *models.SceneMarker
Overwrite bool Overwrite bool
@@ -41,9 +41,10 @@ func (t *GenerateMarkersTask) Start(ctx context.Context) {
if t.Marker != nil { if t.Marker != nil {
var scene *models.Scene var scene *models.Scene
if err := t.TxnManager.WithReadTxn(ctx, func(ctx context.Context) error { r := t.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
var err error var err error
scene, err = t.TxnManager.Scene.Find(ctx, t.Marker.SceneID) scene, err = r.Scene.Find(ctx, t.Marker.SceneID)
if err != nil { if err != nil {
return err return err
} }
@@ -51,7 +52,7 @@ func (t *GenerateMarkersTask) Start(ctx context.Context) {
return fmt.Errorf("scene with id %d not found", t.Marker.SceneID) return fmt.Errorf("scene with id %d not found", t.Marker.SceneID)
} }
return scene.LoadPrimaryFile(ctx, t.TxnManager.File) return scene.LoadPrimaryFile(ctx, r.File)
}); err != nil { }); err != nil {
logger.Errorf("error finding scene for marker generation: %v", err) logger.Errorf("error finding scene for marker generation: %v", err)
return return
@@ -70,9 +71,10 @@ func (t *GenerateMarkersTask) Start(ctx context.Context) {
func (t *GenerateMarkersTask) generateSceneMarkers(ctx context.Context) { func (t *GenerateMarkersTask) generateSceneMarkers(ctx context.Context) {
var sceneMarkers []*models.SceneMarker var sceneMarkers []*models.SceneMarker
if err := t.TxnManager.WithReadTxn(ctx, func(ctx context.Context) error { r := t.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
var err error var err error
sceneMarkers, err = t.TxnManager.SceneMarker.FindBySceneID(ctx, t.Scene.ID) sceneMarkers, err = r.SceneMarker.FindBySceneID(ctx, t.Scene.ID)
return err return err
}); err != nil { }); err != nil {
logger.Errorf("error getting scene markers: %s", err.Error()) logger.Errorf("error getting scene markers: %s", err.Error())
@@ -129,7 +131,7 @@ func (t *GenerateMarkersTask) generateMarker(videoFile *models.VideoFile, scene
func (t *GenerateMarkersTask) markersNeeded(ctx context.Context) int { func (t *GenerateMarkersTask) markersNeeded(ctx context.Context) int {
markers := 0 markers := 0
sceneMarkers, err := t.TxnManager.SceneMarker.FindBySceneID(ctx, t.Scene.ID) sceneMarkers, err := t.repository.SceneMarker.FindBySceneID(ctx, t.Scene.ID)
if err != nil { if err != nil {
logger.Errorf("error finding scene markers: %s", err.Error()) logger.Errorf("error finding scene markers: %s", err.Error())
return 0 return 0

View File

@@ -7,15 +7,13 @@ import (
"github.com/stashapp/stash/pkg/hash/videophash" "github.com/stashapp/stash/pkg/hash/videophash"
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
) )
type GeneratePhashTask struct { type GeneratePhashTask struct {
repository models.Repository
File *models.VideoFile File *models.VideoFile
Overwrite bool Overwrite bool
fileNamingAlgorithm models.HashAlgorithm fileNamingAlgorithm models.HashAlgorithm
txnManager txn.Manager
fileUpdater models.FileUpdater
} }
func (t *GeneratePhashTask) GetDescription() string { func (t *GeneratePhashTask) GetDescription() string {
@@ -34,15 +32,15 @@ func (t *GeneratePhashTask) Start(ctx context.Context) {
return return
} }
if err := txn.WithTxn(ctx, t.txnManager, func(ctx context.Context) error { r := t.repository
qb := t.fileUpdater if err := r.WithTxn(ctx, func(ctx context.Context) error {
hashValue := int64(*hash) hashValue := int64(*hash)
t.File.Fingerprints = t.File.Fingerprints.AppendUnique(models.Fingerprint{ t.File.Fingerprints = t.File.Fingerprints.AppendUnique(models.Fingerprint{
Type: models.FingerprintTypePhash, Type: models.FingerprintTypePhash,
Fingerprint: hashValue, Fingerprint: hashValue,
}) })
return qb.Update(ctx, t.File) return r.File.Update(ctx, t.File)
}); err != nil && ctx.Err() == nil { }); err != nil && ctx.Err() == nil {
logger.Errorf("Error setting phash: %v", err) logger.Errorf("Error setting phash: %v", err)
} }

View File

@@ -10,9 +10,9 @@ import (
) )
type GenerateCoverTask struct { type GenerateCoverTask struct {
repository models.Repository
Scene models.Scene Scene models.Scene
ScreenshotAt *float64 ScreenshotAt *float64
txnManager Repository
Overwrite bool Overwrite bool
} }
@@ -23,11 +23,13 @@ func (t *GenerateCoverTask) GetDescription() string {
func (t *GenerateCoverTask) Start(ctx context.Context) { func (t *GenerateCoverTask) Start(ctx context.Context) {
scenePath := t.Scene.Path scenePath := t.Scene.Path
r := t.repository
var required bool var required bool
if err := t.txnManager.WithReadTxn(ctx, func(ctx context.Context) error { if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
required = t.required(ctx) required = t.required(ctx)
return t.Scene.LoadPrimaryFile(ctx, t.txnManager.File) return t.Scene.LoadPrimaryFile(ctx, r.File)
}); err != nil { }); err != nil {
logger.Error(err) logger.Error(err)
} }
@@ -70,8 +72,8 @@ func (t *GenerateCoverTask) Start(ctx context.Context) {
return return
} }
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { if err := r.WithTxn(ctx, func(ctx context.Context) error {
qb := t.txnManager.Scene qb := r.Scene
scenePartial := models.NewScenePartial() scenePartial := models.NewScenePartial()
// update the scene cover table // update the scene cover table
@@ -103,7 +105,7 @@ func (t *GenerateCoverTask) required(ctx context.Context) bool {
} }
// if the scene has a cover, then we don't need to generate it // if the scene has a cover, then we don't need to generate it
hasCover, err := t.txnManager.Scene.HasCover(ctx, t.Scene.ID) hasCover, err := t.repository.Scene.HasCover(ctx, t.Scene.ID)
if err != nil { if err != nil {
logger.Errorf("Error getting cover: %v", err) logger.Errorf("Error getting cover: %v", err)
return false return false

View File

@@ -14,7 +14,6 @@ import (
"github.com/stashapp/stash/pkg/scraper" "github.com/stashapp/stash/pkg/scraper"
"github.com/stashapp/stash/pkg/scraper/stashbox" "github.com/stashapp/stash/pkg/scraper/stashbox"
"github.com/stashapp/stash/pkg/sliceutil/stringslice" "github.com/stashapp/stash/pkg/sliceutil/stringslice"
"github.com/stashapp/stash/pkg/txn"
) )
var ErrInput = errors.New("invalid request input") var ErrInput = errors.New("invalid request input")
@@ -52,7 +51,8 @@ func (j *IdentifyJob) Execute(ctx context.Context, progress *job.Progress) {
// if scene ids provided, use those // if scene ids provided, use those
// otherwise, batch query for all scenes - ordering by path // otherwise, batch query for all scenes - ordering by path
// don't use a transaction to query scenes // don't use a transaction to query scenes
if err := txn.WithDatabase(ctx, instance.Repository, func(ctx context.Context) error { r := instance.Repository
if err := r.WithDB(ctx, func(ctx context.Context) error {
if len(j.input.SceneIDs) == 0 { if len(j.input.SceneIDs) == 0 {
return j.identifyAllScenes(ctx, sources) return j.identifyAllScenes(ctx, sources)
} }
@@ -70,7 +70,7 @@ func (j *IdentifyJob) Execute(ctx context.Context, progress *job.Progress) {
// find the scene // find the scene
var err error var err error
scene, err := instance.Repository.Scene.Find(ctx, id) scene, err := r.Scene.Find(ctx, id)
if err != nil { if err != nil {
return fmt.Errorf("finding scene id %d: %w", id, err) return fmt.Errorf("finding scene id %d: %w", id, err)
} }
@@ -89,6 +89,8 @@ func (j *IdentifyJob) Execute(ctx context.Context, progress *job.Progress) {
} }
func (j *IdentifyJob) identifyAllScenes(ctx context.Context, sources []identify.ScraperSource) error { func (j *IdentifyJob) identifyAllScenes(ctx context.Context, sources []identify.ScraperSource) error {
r := instance.Repository
// exclude organised // exclude organised
organised := false organised := false
sceneFilter := scene.FilterFromPaths(j.input.Paths) sceneFilter := scene.FilterFromPaths(j.input.Paths)
@@ -102,7 +104,7 @@ func (j *IdentifyJob) identifyAllScenes(ctx context.Context, sources []identify.
// get the count // get the count
pp := 0 pp := 0
findFilter.PerPage = &pp findFilter.PerPage = &pp
countResult, err := instance.Repository.Scene.Query(ctx, models.SceneQueryOptions{ countResult, err := r.Scene.Query(ctx, models.SceneQueryOptions{
QueryOptions: models.QueryOptions{ QueryOptions: models.QueryOptions{
FindFilter: findFilter, FindFilter: findFilter,
Count: true, Count: true,
@@ -115,7 +117,7 @@ func (j *IdentifyJob) identifyAllScenes(ctx context.Context, sources []identify.
j.progress.SetTotal(countResult.Count) j.progress.SetTotal(countResult.Count)
return scene.BatchProcess(ctx, instance.Repository.Scene, sceneFilter, findFilter, func(scene *models.Scene) error { return scene.BatchProcess(ctx, r.Scene, sceneFilter, findFilter, func(scene *models.Scene) error {
if job.IsCancelled(ctx) { if job.IsCancelled(ctx) {
return nil return nil
} }
@@ -132,18 +134,20 @@ func (j *IdentifyJob) identifyScene(ctx context.Context, s *models.Scene, source
var taskError error var taskError error
j.progress.ExecuteTask("Identifying "+s.Path, func() { j.progress.ExecuteTask("Identifying "+s.Path, func() {
r := instance.Repository
task := identify.SceneIdentifier{ task := identify.SceneIdentifier{
SceneReaderUpdater: instance.Repository.Scene, TxnManager: r.TxnManager,
StudioReaderWriter: instance.Repository.Studio, SceneReaderUpdater: r.Scene,
PerformerCreator: instance.Repository.Performer, StudioReaderWriter: r.Studio,
TagFinderCreator: instance.Repository.Tag, PerformerCreator: r.Performer,
TagFinderCreator: r.Tag,
DefaultOptions: j.input.Options, DefaultOptions: j.input.Options,
Sources: sources, Sources: sources,
SceneUpdatePostHookExecutor: j.postHookExecutor, SceneUpdatePostHookExecutor: j.postHookExecutor,
} }
taskError = task.Identify(ctx, instance.Repository, s) taskError = task.Identify(ctx, s)
}) })
if taskError != nil { if taskError != nil {
@@ -164,15 +168,11 @@ func (j *IdentifyJob) getSources() ([]identify.ScraperSource, error) {
var src identify.ScraperSource var src identify.ScraperSource
if stashBox != nil { if stashBox != nil {
stashboxRepository := stashbox.NewRepository(instance.Repository)
src = identify.ScraperSource{ src = identify.ScraperSource{
Name: "stash-box: " + stashBox.Endpoint, Name: "stash-box: " + stashBox.Endpoint,
Scraper: stashboxSource{ Scraper: stashboxSource{
stashbox.NewClient(*stashBox, instance.Repository, stashbox.Repository{ stashbox.NewClient(*stashBox, stashboxRepository),
Scene: instance.Repository.Scene,
Performer: instance.Repository.Performer,
Tag: instance.Repository.Tag,
Studio: instance.Repository.Studio,
}),
stashBox.Endpoint, stashBox.Endpoint,
}, },
RemoteSite: stashBox.Endpoint, RemoteSite: stashBox.Endpoint,

View File

@@ -25,8 +25,13 @@ import (
"github.com/stashapp/stash/pkg/tag" "github.com/stashapp/stash/pkg/tag"
) )
type Resetter interface {
Reset() error
}
type ImportTask struct { type ImportTask struct {
txnManager Repository repository models.Repository
resetter Resetter
json jsonUtils json jsonUtils
BaseDir string BaseDir string
@@ -66,8 +71,10 @@ func CreateImportTask(a models.HashAlgorithm, input ImportObjectsInput) (*Import
} }
} }
mgr := GetInstance()
return &ImportTask{ return &ImportTask{
txnManager: GetInstance().Repository, repository: mgr.Repository,
resetter: mgr.Database,
BaseDir: baseDir, BaseDir: baseDir,
TmpZip: tmpZip, TmpZip: tmpZip,
Reset: false, Reset: false,
@@ -109,7 +116,7 @@ func (t *ImportTask) Start(ctx context.Context) {
} }
if t.Reset { if t.Reset {
err := t.txnManager.Reset() err := t.resetter.Reset()
if err != nil { if err != nil {
logger.Errorf("Error resetting database: %s", err.Error()) logger.Errorf("Error resetting database: %s", err.Error())
@@ -194,6 +201,8 @@ func (t *ImportTask) ImportPerformers(ctx context.Context) {
return return
} }
r := t.repository
for i, fi := range files { for i, fi := range files {
index := i + 1 index := i + 1
performerJSON, err := jsonschema.LoadPerformerFile(filepath.Join(path, fi.Name())) performerJSON, err := jsonschema.LoadPerformerFile(filepath.Join(path, fi.Name()))
@@ -204,11 +213,9 @@ func (t *ImportTask) ImportPerformers(ctx context.Context) {
logger.Progressf("[performers] %d of %d", index, len(files)) logger.Progressf("[performers] %d of %d", index, len(files))
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { if err := r.WithTxn(ctx, func(ctx context.Context) error {
r := t.txnManager
readerWriter := r.Performer
importer := &performer.Importer{ importer := &performer.Importer{
ReaderWriter: readerWriter, ReaderWriter: r.Performer,
TagWriter: r.Tag, TagWriter: r.Tag,
Input: *performerJSON, Input: *performerJSON,
} }
@@ -237,6 +244,8 @@ func (t *ImportTask) ImportStudios(ctx context.Context) {
return return
} }
r := t.repository
for i, fi := range files { for i, fi := range files {
index := i + 1 index := i + 1
studioJSON, err := jsonschema.LoadStudioFile(filepath.Join(path, fi.Name())) studioJSON, err := jsonschema.LoadStudioFile(filepath.Join(path, fi.Name()))
@@ -247,8 +256,8 @@ func (t *ImportTask) ImportStudios(ctx context.Context) {
logger.Progressf("[studios] %d of %d", index, len(files)) logger.Progressf("[studios] %d of %d", index, len(files))
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { if err := r.WithTxn(ctx, func(ctx context.Context) error {
return t.ImportStudio(ctx, studioJSON, pendingParent, t.txnManager.Studio) return t.importStudio(ctx, studioJSON, pendingParent)
}); err != nil { }); err != nil {
if errors.Is(err, studio.ErrParentStudioNotExist) { if errors.Is(err, studio.ErrParentStudioNotExist) {
// add to the pending parent list so that it is created after the parent // add to the pending parent list so that it is created after the parent
@@ -269,8 +278,8 @@ func (t *ImportTask) ImportStudios(ctx context.Context) {
for _, s := range pendingParent { for _, s := range pendingParent {
for _, orphanStudioJSON := range s { for _, orphanStudioJSON := range s {
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { if err := r.WithTxn(ctx, func(ctx context.Context) error {
return t.ImportStudio(ctx, orphanStudioJSON, nil, t.txnManager.Studio) return t.importStudio(ctx, orphanStudioJSON, nil)
}); err != nil { }); err != nil {
logger.Errorf("[studios] <%s> failed to create: %s", orphanStudioJSON.Name, err.Error()) logger.Errorf("[studios] <%s> failed to create: %s", orphanStudioJSON.Name, err.Error())
continue continue
@@ -282,9 +291,9 @@ func (t *ImportTask) ImportStudios(ctx context.Context) {
logger.Info("[studios] import complete") logger.Info("[studios] import complete")
} }
func (t *ImportTask) ImportStudio(ctx context.Context, studioJSON *jsonschema.Studio, pendingParent map[string][]*jsonschema.Studio, readerWriter studio.ImporterReaderWriter) error { func (t *ImportTask) importStudio(ctx context.Context, studioJSON *jsonschema.Studio, pendingParent map[string][]*jsonschema.Studio) error {
importer := &studio.Importer{ importer := &studio.Importer{
ReaderWriter: readerWriter, ReaderWriter: t.repository.Studio,
Input: *studioJSON, Input: *studioJSON,
MissingRefBehaviour: t.MissingRefBehaviour, MissingRefBehaviour: t.MissingRefBehaviour,
} }
@@ -302,7 +311,7 @@ func (t *ImportTask) ImportStudio(ctx context.Context, studioJSON *jsonschema.St
s := pendingParent[studioJSON.Name] s := pendingParent[studioJSON.Name]
for _, childStudioJSON := range s { for _, childStudioJSON := range s {
// map is nil since we're not checking parent studios at this point // map is nil since we're not checking parent studios at this point
if err := t.ImportStudio(ctx, childStudioJSON, nil, readerWriter); err != nil { if err := t.importStudio(ctx, childStudioJSON, nil); err != nil {
return fmt.Errorf("failed to create child studio <%s>: %s", childStudioJSON.Name, err.Error()) return fmt.Errorf("failed to create child studio <%s>: %s", childStudioJSON.Name, err.Error())
} }
} }
@@ -326,6 +335,8 @@ func (t *ImportTask) ImportMovies(ctx context.Context) {
return return
} }
r := t.repository
for i, fi := range files { for i, fi := range files {
index := i + 1 index := i + 1
movieJSON, err := jsonschema.LoadMovieFile(filepath.Join(path, fi.Name())) movieJSON, err := jsonschema.LoadMovieFile(filepath.Join(path, fi.Name()))
@@ -336,14 +347,10 @@ func (t *ImportTask) ImportMovies(ctx context.Context) {
logger.Progressf("[movies] %d of %d", index, len(files)) logger.Progressf("[movies] %d of %d", index, len(files))
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { if err := r.WithTxn(ctx, func(ctx context.Context) error {
r := t.txnManager
readerWriter := r.Movie
studioReaderWriter := r.Studio
movieImporter := &movie.Importer{ movieImporter := &movie.Importer{
ReaderWriter: readerWriter, ReaderWriter: r.Movie,
StudioWriter: studioReaderWriter, StudioWriter: r.Studio,
Input: *movieJSON, Input: *movieJSON,
MissingRefBehaviour: t.MissingRefBehaviour, MissingRefBehaviour: t.MissingRefBehaviour,
} }
@@ -371,6 +378,8 @@ func (t *ImportTask) ImportFiles(ctx context.Context) {
return return
} }
r := t.repository
pendingParent := make(map[string][]jsonschema.DirEntry) pendingParent := make(map[string][]jsonschema.DirEntry)
for i, fi := range files { for i, fi := range files {
@@ -383,8 +392,8 @@ func (t *ImportTask) ImportFiles(ctx context.Context) {
logger.Progressf("[files] %d of %d", index, len(files)) logger.Progressf("[files] %d of %d", index, len(files))
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { if err := r.WithTxn(ctx, func(ctx context.Context) error {
return t.ImportFile(ctx, fileJSON, pendingParent) return t.importFile(ctx, fileJSON, pendingParent)
}); err != nil { }); err != nil {
if errors.Is(err, file.ErrZipFileNotExist) { if errors.Is(err, file.ErrZipFileNotExist) {
// add to the pending parent list so that it is created after the parent // add to the pending parent list so that it is created after the parent
@@ -405,8 +414,8 @@ func (t *ImportTask) ImportFiles(ctx context.Context) {
for _, s := range pendingParent { for _, s := range pendingParent {
for _, orphanFileJSON := range s { for _, orphanFileJSON := range s {
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { if err := r.WithTxn(ctx, func(ctx context.Context) error {
return t.ImportFile(ctx, orphanFileJSON, nil) return t.importFile(ctx, orphanFileJSON, nil)
}); err != nil { }); err != nil {
logger.Errorf("[files] <%s> failed to create: %s", orphanFileJSON.DirEntry().Path, err.Error()) logger.Errorf("[files] <%s> failed to create: %s", orphanFileJSON.DirEntry().Path, err.Error())
continue continue
@@ -418,12 +427,11 @@ func (t *ImportTask) ImportFiles(ctx context.Context) {
logger.Info("[files] import complete") logger.Info("[files] import complete")
} }
func (t *ImportTask) ImportFile(ctx context.Context, fileJSON jsonschema.DirEntry, pendingParent map[string][]jsonschema.DirEntry) error { func (t *ImportTask) importFile(ctx context.Context, fileJSON jsonschema.DirEntry, pendingParent map[string][]jsonschema.DirEntry) error {
r := t.txnManager r := t.repository
readerWriter := r.File
fileImporter := &file.Importer{ fileImporter := &file.Importer{
ReaderWriter: readerWriter, ReaderWriter: r.File,
FolderStore: r.Folder, FolderStore: r.Folder,
Input: fileJSON, Input: fileJSON,
} }
@@ -437,7 +445,7 @@ func (t *ImportTask) ImportFile(ctx context.Context, fileJSON jsonschema.DirEntr
s := pendingParent[fileJSON.DirEntry().Path] s := pendingParent[fileJSON.DirEntry().Path]
for _, childFileJSON := range s { for _, childFileJSON := range s {
// map is nil since we're not checking parent studios at this point // map is nil since we're not checking parent studios at this point
if err := t.ImportFile(ctx, childFileJSON, nil); err != nil { if err := t.importFile(ctx, childFileJSON, nil); err != nil {
return fmt.Errorf("failed to create child file <%s>: %s", childFileJSON.DirEntry().Path, err.Error()) return fmt.Errorf("failed to create child file <%s>: %s", childFileJSON.DirEntry().Path, err.Error())
} }
} }
@@ -461,6 +469,8 @@ func (t *ImportTask) ImportGalleries(ctx context.Context) {
return return
} }
r := t.repository
for i, fi := range files { for i, fi := range files {
index := i + 1 index := i + 1
galleryJSON, err := jsonschema.LoadGalleryFile(filepath.Join(path, fi.Name())) galleryJSON, err := jsonschema.LoadGalleryFile(filepath.Join(path, fi.Name()))
@@ -471,21 +481,14 @@ func (t *ImportTask) ImportGalleries(ctx context.Context) {
logger.Progressf("[galleries] %d of %d", index, len(files)) logger.Progressf("[galleries] %d of %d", index, len(files))
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { if err := r.WithTxn(ctx, func(ctx context.Context) error {
r := t.txnManager
readerWriter := r.Gallery
tagWriter := r.Tag
performerWriter := r.Performer
studioWriter := r.Studio
chapterWriter := r.GalleryChapter
galleryImporter := &gallery.Importer{ galleryImporter := &gallery.Importer{
ReaderWriter: readerWriter, ReaderWriter: r.Gallery,
FolderFinder: r.Folder, FolderFinder: r.Folder,
FileFinder: r.File, FileFinder: r.File,
PerformerWriter: performerWriter, PerformerWriter: r.Performer,
StudioWriter: studioWriter, StudioWriter: r.Studio,
TagWriter: tagWriter, TagWriter: r.Tag,
Input: *galleryJSON, Input: *galleryJSON,
MissingRefBehaviour: t.MissingRefBehaviour, MissingRefBehaviour: t.MissingRefBehaviour,
} }
@@ -500,7 +503,7 @@ func (t *ImportTask) ImportGalleries(ctx context.Context) {
GalleryID: galleryImporter.ID, GalleryID: galleryImporter.ID,
Input: m, Input: m,
MissingRefBehaviour: t.MissingRefBehaviour, MissingRefBehaviour: t.MissingRefBehaviour,
ReaderWriter: chapterWriter, ReaderWriter: r.GalleryChapter,
} }
if err := performImport(ctx, chapterImporter, t.DuplicateBehaviour); err != nil { if err := performImport(ctx, chapterImporter, t.DuplicateBehaviour); err != nil {
@@ -532,6 +535,8 @@ func (t *ImportTask) ImportTags(ctx context.Context) {
return return
} }
r := t.repository
for i, fi := range files { for i, fi := range files {
index := i + 1 index := i + 1
tagJSON, err := jsonschema.LoadTagFile(filepath.Join(path, fi.Name())) tagJSON, err := jsonschema.LoadTagFile(filepath.Join(path, fi.Name()))
@@ -542,8 +547,8 @@ func (t *ImportTask) ImportTags(ctx context.Context) {
logger.Progressf("[tags] %d of %d", index, len(files)) logger.Progressf("[tags] %d of %d", index, len(files))
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { if err := r.WithTxn(ctx, func(ctx context.Context) error {
return t.ImportTag(ctx, tagJSON, pendingParent, false, t.txnManager.Tag) return t.importTag(ctx, tagJSON, pendingParent, false)
}); err != nil { }); err != nil {
var parentError tag.ParentTagNotExistError var parentError tag.ParentTagNotExistError
if errors.As(err, &parentError) { if errors.As(err, &parentError) {
@@ -558,8 +563,8 @@ func (t *ImportTask) ImportTags(ctx context.Context) {
for _, s := range pendingParent { for _, s := range pendingParent {
for _, orphanTagJSON := range s { for _, orphanTagJSON := range s {
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { if err := r.WithTxn(ctx, func(ctx context.Context) error {
return t.ImportTag(ctx, orphanTagJSON, nil, true, t.txnManager.Tag) return t.importTag(ctx, orphanTagJSON, nil, true)
}); err != nil { }); err != nil {
logger.Errorf("[tags] <%s> failed to create: %s", orphanTagJSON.Name, err.Error()) logger.Errorf("[tags] <%s> failed to create: %s", orphanTagJSON.Name, err.Error())
continue continue
@@ -570,9 +575,9 @@ func (t *ImportTask) ImportTags(ctx context.Context) {
logger.Info("[tags] import complete") logger.Info("[tags] import complete")
} }
func (t *ImportTask) ImportTag(ctx context.Context, tagJSON *jsonschema.Tag, pendingParent map[string][]*jsonschema.Tag, fail bool, readerWriter tag.ImporterReaderWriter) error { func (t *ImportTask) importTag(ctx context.Context, tagJSON *jsonschema.Tag, pendingParent map[string][]*jsonschema.Tag, fail bool) error {
importer := &tag.Importer{ importer := &tag.Importer{
ReaderWriter: readerWriter, ReaderWriter: t.repository.Tag,
Input: *tagJSON, Input: *tagJSON,
MissingRefBehaviour: t.MissingRefBehaviour, MissingRefBehaviour: t.MissingRefBehaviour,
} }
@@ -587,7 +592,7 @@ func (t *ImportTask) ImportTag(ctx context.Context, tagJSON *jsonschema.Tag, pen
} }
for _, childTagJSON := range pendingParent[tagJSON.Name] { for _, childTagJSON := range pendingParent[tagJSON.Name] {
if err := t.ImportTag(ctx, childTagJSON, pendingParent, fail, readerWriter); err != nil { if err := t.importTag(ctx, childTagJSON, pendingParent, fail); err != nil {
var parentError tag.ParentTagNotExistError var parentError tag.ParentTagNotExistError
if errors.As(err, &parentError) { if errors.As(err, &parentError) {
pendingParent[parentError.MissingParent()] = append(pendingParent[parentError.MissingParent()], childTagJSON) pendingParent[parentError.MissingParent()] = append(pendingParent[parentError.MissingParent()], childTagJSON)
@@ -616,6 +621,8 @@ func (t *ImportTask) ImportScenes(ctx context.Context) {
return return
} }
r := t.repository
for i, fi := range files { for i, fi := range files {
index := i + 1 index := i + 1
@@ -627,29 +634,20 @@ func (t *ImportTask) ImportScenes(ctx context.Context) {
continue continue
} }
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { if err := r.WithTxn(ctx, func(ctx context.Context) error {
r := t.txnManager
readerWriter := r.Scene
tagWriter := r.Tag
galleryWriter := r.Gallery
movieWriter := r.Movie
performerWriter := r.Performer
studioWriter := r.Studio
markerWriter := r.SceneMarker
sceneImporter := &scene.Importer{ sceneImporter := &scene.Importer{
ReaderWriter: readerWriter, ReaderWriter: r.Scene,
Input: *sceneJSON, Input: *sceneJSON,
FileFinder: r.File, FileFinder: r.File,
FileNamingAlgorithm: t.fileNamingAlgorithm, FileNamingAlgorithm: t.fileNamingAlgorithm,
MissingRefBehaviour: t.MissingRefBehaviour, MissingRefBehaviour: t.MissingRefBehaviour,
GalleryFinder: galleryWriter, GalleryFinder: r.Gallery,
MovieWriter: movieWriter, MovieWriter: r.Movie,
PerformerWriter: performerWriter, PerformerWriter: r.Performer,
StudioWriter: studioWriter, StudioWriter: r.Studio,
TagWriter: tagWriter, TagWriter: r.Tag,
} }
if err := performImport(ctx, sceneImporter, t.DuplicateBehaviour); err != nil { if err := performImport(ctx, sceneImporter, t.DuplicateBehaviour); err != nil {
@@ -662,8 +660,8 @@ func (t *ImportTask) ImportScenes(ctx context.Context) {
SceneID: sceneImporter.ID, SceneID: sceneImporter.ID,
Input: m, Input: m,
MissingRefBehaviour: t.MissingRefBehaviour, MissingRefBehaviour: t.MissingRefBehaviour,
ReaderWriter: markerWriter, ReaderWriter: r.SceneMarker,
TagWriter: tagWriter, TagWriter: r.Tag,
} }
if err := performImport(ctx, markerImporter, t.DuplicateBehaviour); err != nil { if err := performImport(ctx, markerImporter, t.DuplicateBehaviour); err != nil {
@@ -693,6 +691,8 @@ func (t *ImportTask) ImportImages(ctx context.Context) {
return return
} }
r := t.repository
for i, fi := range files { for i, fi := range files {
index := i + 1 index := i + 1
@@ -704,25 +704,18 @@ func (t *ImportTask) ImportImages(ctx context.Context) {
continue continue
} }
if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error { if err := r.WithTxn(ctx, func(ctx context.Context) error {
r := t.txnManager
readerWriter := r.Image
tagWriter := r.Tag
galleryWriter := r.Gallery
performerWriter := r.Performer
studioWriter := r.Studio
imageImporter := &image.Importer{ imageImporter := &image.Importer{
ReaderWriter: readerWriter, ReaderWriter: r.Image,
FileFinder: r.File, FileFinder: r.File,
Input: *imageJSON, Input: *imageJSON,
MissingRefBehaviour: t.MissingRefBehaviour, MissingRefBehaviour: t.MissingRefBehaviour,
GalleryFinder: galleryWriter, GalleryFinder: r.Gallery,
PerformerWriter: performerWriter, PerformerWriter: r.Performer,
StudioWriter: studioWriter, StudioWriter: r.Studio,
TagWriter: tagWriter, TagWriter: r.Tag,
} }
return performImport(ctx, imageImporter, t.DuplicateBehaviour) return performImport(ctx, imageImporter, t.DuplicateBehaviour)

View File

@@ -19,6 +19,7 @@ import (
"github.com/stashapp/stash/pkg/job" "github.com/stashapp/stash/pkg/job"
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/paths"
"github.com/stashapp/stash/pkg/scene" "github.com/stashapp/stash/pkg/scene"
"github.com/stashapp/stash/pkg/scene/generate" "github.com/stashapp/stash/pkg/scene/generate"
"github.com/stashapp/stash/pkg/txn" "github.com/stashapp/stash/pkg/txn"
@@ -48,10 +49,14 @@ func (j *ScanJob) Execute(ctx context.Context, progress *job.Progress) {
paths[i] = p.Path paths[i] = p.Path
} }
mgr := GetInstance()
c := mgr.Config
repo := mgr.Repository
start := time.Now() start := time.Now()
const taskQueueSize = 200000 const taskQueueSize = 200000
taskQueue := job.NewTaskQueue(ctx, progress, taskQueueSize, instance.Config.GetParallelTasksWithAutoDetection()) taskQueue := job.NewTaskQueue(ctx, progress, taskQueueSize, c.GetParallelTasksWithAutoDetection())
var minModTime time.Time var minModTime time.Time
if j.input.Filter != nil && j.input.Filter.MinModTime != nil { if j.input.Filter != nil && j.input.Filter.MinModTime != nil {
@@ -59,13 +64,11 @@ func (j *ScanJob) Execute(ctx context.Context, progress *job.Progress) {
} }
j.scanner.Scan(ctx, getScanHandlers(j.input, taskQueue, progress), file.ScanOptions{ j.scanner.Scan(ctx, getScanHandlers(j.input, taskQueue, progress), file.ScanOptions{
Paths: paths, Paths: paths,
ScanFilters: []file.PathFilter{newScanFilter(instance.Config, minModTime)}, ScanFilters: []file.PathFilter{newScanFilter(c, repo, minModTime)},
ZipFileExtensions: instance.Config.GetGalleryExtensions(), ZipFileExtensions: c.GetGalleryExtensions(),
ParallelTasks: instance.Config.GetParallelTasksWithAutoDetection(), ParallelTasks: c.GetParallelTasksWithAutoDetection(),
HandlerRequiredFilters: []file.Filter{ HandlerRequiredFilters: []file.Filter{newHandlerRequiredFilter(c, repo)},
newHandlerRequiredFilter(instance.Config),
},
}, progress) }, progress)
taskQueue.Close() taskQueue.Close()
@@ -123,17 +126,16 @@ type handlerRequiredFilter struct {
videoFileNamingAlgorithm models.HashAlgorithm videoFileNamingAlgorithm models.HashAlgorithm
} }
func newHandlerRequiredFilter(c *config.Instance) *handlerRequiredFilter { func newHandlerRequiredFilter(c *config.Instance, repo models.Repository) *handlerRequiredFilter {
db := instance.Database
processes := c.GetParallelTasksWithAutoDetection() processes := c.GetParallelTasksWithAutoDetection()
return &handlerRequiredFilter{ return &handlerRequiredFilter{
extensionConfig: newExtensionConfig(c), extensionConfig: newExtensionConfig(c),
txnManager: db, txnManager: repo.TxnManager,
SceneFinder: db.Scene, SceneFinder: repo.Scene,
ImageFinder: db.Image, ImageFinder: repo.Image,
GalleryFinder: db.Gallery, GalleryFinder: repo.Gallery,
CaptionUpdater: db.File, CaptionUpdater: repo.File,
FolderCache: lru.New(processes * 2), FolderCache: lru.New(processes * 2),
videoFileNamingAlgorithm: c.GetVideoFileNamingAlgorithm(), videoFileNamingAlgorithm: c.GetVideoFileNamingAlgorithm(),
} }
@@ -226,6 +228,10 @@ func (f *handlerRequiredFilter) Accept(ctx context.Context, ff models.File) bool
type scanFilter struct { type scanFilter struct {
extensionConfig extensionConfig
txnManager txn.Manager
FileFinder models.FileFinder
CaptionUpdater video.CaptionUpdater
stashPaths config.StashConfigs stashPaths config.StashConfigs
generatedPath string generatedPath string
videoExcludeRegex []*regexp.Regexp videoExcludeRegex []*regexp.Regexp
@@ -233,9 +239,12 @@ type scanFilter struct {
minModTime time.Time minModTime time.Time
} }
func newScanFilter(c *config.Instance, minModTime time.Time) *scanFilter { func newScanFilter(c *config.Instance, repo models.Repository, minModTime time.Time) *scanFilter {
return &scanFilter{ return &scanFilter{
extensionConfig: newExtensionConfig(c), extensionConfig: newExtensionConfig(c),
txnManager: repo.TxnManager,
FileFinder: repo.File,
CaptionUpdater: repo.File,
stashPaths: c.GetStashPaths(), stashPaths: c.GetStashPaths(),
generatedPath: c.GetGeneratedPath(), generatedPath: c.GetGeneratedPath(),
videoExcludeRegex: generateRegexps(c.GetExcludes()), videoExcludeRegex: generateRegexps(c.GetExcludes()),
@@ -263,7 +272,7 @@ func (f *scanFilter) Accept(ctx context.Context, path string, info fs.FileInfo)
if fsutil.MatchExtension(path, video.CaptionExts) { if fsutil.MatchExtension(path, video.CaptionExts) {
// we don't include caption files in the file scan, but we do need // we don't include caption files in the file scan, but we do need
// to handle them // to handle them
video.AssociateCaptions(ctx, path, instance.Repository, instance.Database.File, instance.Database.File) video.AssociateCaptions(ctx, path, f.txnManager, f.FileFinder, f.CaptionUpdater)
return false return false
} }
@@ -308,30 +317,37 @@ func (f *scanFilter) Accept(ctx context.Context, path string, info fs.FileInfo)
type scanConfig struct { type scanConfig struct {
isGenerateThumbnails bool isGenerateThumbnails bool
isGenerateClipPreviews bool isGenerateClipPreviews bool
createGalleriesFromFolders bool
} }
func (c *scanConfig) GetCreateGalleriesFromFolders() bool { func (c *scanConfig) GetCreateGalleriesFromFolders() bool {
return instance.Config.GetCreateGalleriesFromFolders() return c.createGalleriesFromFolders
} }
func getScanHandlers(options ScanMetadataInput, taskQueue *job.TaskQueue, progress *job.Progress) []file.Handler { func getScanHandlers(options ScanMetadataInput, taskQueue *job.TaskQueue, progress *job.Progress) []file.Handler {
db := instance.Database mgr := GetInstance()
pluginCache := instance.PluginCache c := mgr.Config
r := mgr.Repository
pluginCache := mgr.PluginCache
return []file.Handler{ return []file.Handler{
&file.FilteredHandler{ &file.FilteredHandler{
Filter: file.FilterFunc(imageFileFilter), Filter: file.FilterFunc(imageFileFilter),
Handler: &image.ScanHandler{ Handler: &image.ScanHandler{
CreatorUpdater: db.Image, CreatorUpdater: r.Image,
GalleryFinder: db.Gallery, GalleryFinder: r.Gallery,
ScanGenerator: &imageGenerators{ ScanGenerator: &imageGenerators{
input: options, input: options,
taskQueue: taskQueue, taskQueue: taskQueue,
progress: progress, progress: progress,
paths: mgr.Paths,
sequentialScanning: c.GetSequentialScanning(),
}, },
ScanConfig: &scanConfig{ ScanConfig: &scanConfig{
isGenerateThumbnails: options.ScanGenerateThumbnails, isGenerateThumbnails: options.ScanGenerateThumbnails,
isGenerateClipPreviews: options.ScanGenerateClipPreviews, isGenerateClipPreviews: options.ScanGenerateClipPreviews,
createGalleriesFromFolders: c.GetCreateGalleriesFromFolders(),
}, },
PluginCache: pluginCache, PluginCache: pluginCache,
Paths: instance.Paths, Paths: instance.Paths,
@@ -340,25 +356,28 @@ func getScanHandlers(options ScanMetadataInput, taskQueue *job.TaskQueue, progre
&file.FilteredHandler{ &file.FilteredHandler{
Filter: file.FilterFunc(galleryFileFilter), Filter: file.FilterFunc(galleryFileFilter),
Handler: &gallery.ScanHandler{ Handler: &gallery.ScanHandler{
CreatorUpdater: db.Gallery, CreatorUpdater: r.Gallery,
SceneFinderUpdater: db.Scene, SceneFinderUpdater: r.Scene,
ImageFinderUpdater: db.Image, ImageFinderUpdater: r.Image,
PluginCache: pluginCache, PluginCache: pluginCache,
}, },
}, },
&file.FilteredHandler{ &file.FilteredHandler{
Filter: file.FilterFunc(videoFileFilter), Filter: file.FilterFunc(videoFileFilter),
Handler: &scene.ScanHandler{ Handler: &scene.ScanHandler{
CreatorUpdater: db.Scene, CreatorUpdater: r.Scene,
CaptionUpdater: r.File,
PluginCache: pluginCache, PluginCache: pluginCache,
CaptionUpdater: db.File,
ScanGenerator: &sceneGenerators{ ScanGenerator: &sceneGenerators{
input: options, input: options,
taskQueue: taskQueue, taskQueue: taskQueue,
progress: progress, progress: progress,
paths: mgr.Paths,
fileNamingAlgorithm: c.GetVideoFileNamingAlgorithm(),
sequentialScanning: c.GetSequentialScanning(),
}, },
FileNamingAlgorithm: instance.Config.GetVideoFileNamingAlgorithm(), FileNamingAlgorithm: c.GetVideoFileNamingAlgorithm(),
Paths: instance.Paths, Paths: mgr.Paths,
}, },
}, },
} }
@@ -368,6 +387,9 @@ type imageGenerators struct {
input ScanMetadataInput input ScanMetadataInput
taskQueue *job.TaskQueue taskQueue *job.TaskQueue
progress *job.Progress progress *job.Progress
paths *paths.Paths
sequentialScanning bool
} }
func (g *imageGenerators) Generate(ctx context.Context, i *models.Image, f models.File) error { func (g *imageGenerators) Generate(ctx context.Context, i *models.Image, f models.File) error {
@@ -376,8 +398,6 @@ func (g *imageGenerators) Generate(ctx context.Context, i *models.Image, f model
progress := g.progress progress := g.progress
t := g.input t := g.input
path := f.Base().Path path := f.Base().Path
config := instance.Config
sequentialScanning := config.GetSequentialScanning()
if t.ScanGenerateThumbnails { if t.ScanGenerateThumbnails {
// this should be quick, so always generate sequentially // this should be quick, so always generate sequentially
@@ -405,7 +425,7 @@ func (g *imageGenerators) Generate(ctx context.Context, i *models.Image, f model
progress.Increment() progress.Increment()
} }
if sequentialScanning { if g.sequentialScanning {
previewsFn(ctx) previewsFn(ctx)
} else { } else {
g.taskQueue.Add(fmt.Sprintf("Generating preview for %s", path), previewsFn) g.taskQueue.Add(fmt.Sprintf("Generating preview for %s", path), previewsFn)
@@ -416,7 +436,7 @@ func (g *imageGenerators) Generate(ctx context.Context, i *models.Image, f model
} }
func (g *imageGenerators) generateThumbnail(ctx context.Context, i *models.Image, f models.File) error { func (g *imageGenerators) generateThumbnail(ctx context.Context, i *models.Image, f models.File) error {
thumbPath := GetInstance().Paths.Generated.GetThumbnailPath(i.Checksum, models.DefaultGthumbWidth) thumbPath := g.paths.Generated.GetThumbnailPath(i.Checksum, models.DefaultGthumbWidth)
exists, _ := fsutil.FileExists(thumbPath) exists, _ := fsutil.FileExists(thumbPath)
if exists { if exists {
return nil return nil
@@ -435,13 +455,16 @@ func (g *imageGenerators) generateThumbnail(ctx context.Context, i *models.Image
logger.Debugf("Generating thumbnail for %s", path) logger.Debugf("Generating thumbnail for %s", path)
mgr := GetInstance()
c := mgr.Config
clipPreviewOptions := image.ClipPreviewOptions{ clipPreviewOptions := image.ClipPreviewOptions{
InputArgs: instance.Config.GetTranscodeInputArgs(), InputArgs: c.GetTranscodeInputArgs(),
OutputArgs: instance.Config.GetTranscodeOutputArgs(), OutputArgs: c.GetTranscodeOutputArgs(),
Preset: instance.Config.GetPreviewPreset().String(), Preset: c.GetPreviewPreset().String(),
} }
encoder := image.NewThumbnailEncoder(instance.FFMPEG, instance.FFProbe, clipPreviewOptions) encoder := image.NewThumbnailEncoder(mgr.FFMPEG, mgr.FFProbe, clipPreviewOptions)
data, err := encoder.GetThumbnail(f, models.DefaultGthumbWidth) data, err := encoder.GetThumbnail(f, models.DefaultGthumbWidth)
if err != nil { if err != nil {
@@ -464,6 +487,10 @@ type sceneGenerators struct {
input ScanMetadataInput input ScanMetadataInput
taskQueue *job.TaskQueue taskQueue *job.TaskQueue
progress *job.Progress progress *job.Progress
paths *paths.Paths
fileNamingAlgorithm models.HashAlgorithm
sequentialScanning bool
} }
func (g *sceneGenerators) Generate(ctx context.Context, s *models.Scene, f *models.VideoFile) error { func (g *sceneGenerators) Generate(ctx context.Context, s *models.Scene, f *models.VideoFile) error {
@@ -472,9 +499,8 @@ func (g *sceneGenerators) Generate(ctx context.Context, s *models.Scene, f *mode
progress := g.progress progress := g.progress
t := g.input t := g.input
path := f.Path path := f.Path
config := instance.Config
fileNamingAlgorithm := config.GetVideoFileNamingAlgorithm() mgr := GetInstance()
sequentialScanning := config.GetSequentialScanning()
if t.ScanGenerateSprites { if t.ScanGenerateSprites {
progress.AddTotal(1) progress.AddTotal(1)
@@ -482,13 +508,13 @@ func (g *sceneGenerators) Generate(ctx context.Context, s *models.Scene, f *mode
taskSprite := GenerateSpriteTask{ taskSprite := GenerateSpriteTask{
Scene: *s, Scene: *s,
Overwrite: overwrite, Overwrite: overwrite,
fileNamingAlgorithm: fileNamingAlgorithm, fileNamingAlgorithm: g.fileNamingAlgorithm,
} }
taskSprite.Start(ctx) taskSprite.Start(ctx)
progress.Increment() progress.Increment()
} }
if sequentialScanning { if g.sequentialScanning {
spriteFn(ctx) spriteFn(ctx)
} else { } else {
g.taskQueue.Add(fmt.Sprintf("Generating sprites for %s", path), spriteFn) g.taskQueue.Add(fmt.Sprintf("Generating sprites for %s", path), spriteFn)
@@ -499,17 +525,16 @@ func (g *sceneGenerators) Generate(ctx context.Context, s *models.Scene, f *mode
progress.AddTotal(1) progress.AddTotal(1)
phashFn := func(ctx context.Context) { phashFn := func(ctx context.Context) {
taskPhash := GeneratePhashTask{ taskPhash := GeneratePhashTask{
repository: mgr.Repository,
File: f, File: f,
fileNamingAlgorithm: fileNamingAlgorithm,
txnManager: instance.Database,
fileUpdater: instance.Database.File,
Overwrite: overwrite, Overwrite: overwrite,
fileNamingAlgorithm: g.fileNamingAlgorithm,
} }
taskPhash.Start(ctx) taskPhash.Start(ctx)
progress.Increment() progress.Increment()
} }
if sequentialScanning { if g.sequentialScanning {
phashFn(ctx) phashFn(ctx)
} else { } else {
g.taskQueue.Add(fmt.Sprintf("Generating phash for %s", path), phashFn) g.taskQueue.Add(fmt.Sprintf("Generating phash for %s", path), phashFn)
@@ -521,12 +546,12 @@ func (g *sceneGenerators) Generate(ctx context.Context, s *models.Scene, f *mode
previewsFn := func(ctx context.Context) { previewsFn := func(ctx context.Context) {
options := getGeneratePreviewOptions(GeneratePreviewOptionsInput{}) options := getGeneratePreviewOptions(GeneratePreviewOptionsInput{})
g := &generate.Generator{ generator := &generate.Generator{
Encoder: instance.FFMPEG, Encoder: mgr.FFMPEG,
FFMpegConfig: instance.Config, FFMpegConfig: mgr.Config,
LockManager: instance.ReadLockManager, LockManager: mgr.ReadLockManager,
MarkerPaths: instance.Paths.SceneMarkers, MarkerPaths: g.paths.SceneMarkers,
ScenePaths: instance.Paths.Scene, ScenePaths: g.paths.Scene,
Overwrite: overwrite, Overwrite: overwrite,
} }
@@ -535,14 +560,14 @@ func (g *sceneGenerators) Generate(ctx context.Context, s *models.Scene, f *mode
ImagePreview: t.ScanGenerateImagePreviews, ImagePreview: t.ScanGenerateImagePreviews,
Options: options, Options: options,
Overwrite: overwrite, Overwrite: overwrite,
fileNamingAlgorithm: fileNamingAlgorithm, fileNamingAlgorithm: g.fileNamingAlgorithm,
generator: g, generator: generator,
} }
taskPreview.Start(ctx) taskPreview.Start(ctx)
progress.Increment() progress.Increment()
} }
if sequentialScanning { if g.sequentialScanning {
previewsFn(ctx) previewsFn(ctx)
} else { } else {
g.taskQueue.Add(fmt.Sprintf("Generating preview for %s", path), previewsFn) g.taskQueue.Add(fmt.Sprintf("Generating preview for %s", path), previewsFn)
@@ -553,8 +578,8 @@ func (g *sceneGenerators) Generate(ctx context.Context, s *models.Scene, f *mode
progress.AddTotal(1) progress.AddTotal(1)
g.taskQueue.Add(fmt.Sprintf("Generating cover for %s", path), func(ctx context.Context) { g.taskQueue.Add(fmt.Sprintf("Generating cover for %s", path), func(ctx context.Context) {
taskCover := GenerateCoverTask{ taskCover := GenerateCoverTask{
repository: mgr.Repository,
Scene: *s, Scene: *s,
txnManager: instance.Repository,
Overwrite: overwrite, Overwrite: overwrite,
} }
taskCover.Start(ctx) taskCover.Start(ctx)

View File

@@ -9,7 +9,6 @@ import (
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scraper/stashbox" "github.com/stashapp/stash/pkg/scraper/stashbox"
"github.com/stashapp/stash/pkg/studio" "github.com/stashapp/stash/pkg/studio"
"github.com/stashapp/stash/pkg/txn"
) )
type StashBoxTagTaskType int type StashBoxTagTaskType int
@@ -92,18 +91,18 @@ func (t *StashBoxBatchTagTask) findStashBoxPerformer(ctx context.Context) (*mode
var performer *models.ScrapedPerformer var performer *models.ScrapedPerformer
var err error var err error
client := stashbox.NewClient(*t.box, instance.Repository, stashbox.Repository{ r := instance.Repository
Scene: instance.Repository.Scene,
Performer: instance.Repository.Performer, stashboxRepository := stashbox.NewRepository(r)
Tag: instance.Repository.Tag, client := stashbox.NewClient(*t.box, stashboxRepository)
Studio: instance.Repository.Studio,
})
if t.refresh { if t.refresh {
var remoteID string var remoteID string
if err := txn.WithReadTxn(ctx, instance.Repository, func(ctx context.Context) error { if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
qb := r.Performer
if !t.performer.StashIDs.Loaded() { if !t.performer.StashIDs.Loaded() {
err = t.performer.LoadStashIDs(ctx, instance.Repository.Performer) err = t.performer.LoadStashIDs(ctx, qb)
if err != nil { if err != nil {
return err return err
} }
@@ -145,8 +144,9 @@ func (t *StashBoxBatchTagTask) processMatchedPerformer(ctx context.Context, p *m
} }
// Start the transaction and update the performer // Start the transaction and update the performer
err = txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error { r := instance.Repository
qb := instance.Repository.Performer err = r.WithTxn(ctx, func(ctx context.Context) error {
qb := r.Performer
existingStashIDs, err := qb.GetStashIDs(ctx, storedID) existingStashIDs, err := qb.GetStashIDs(ctx, storedID)
if err != nil { if err != nil {
@@ -181,8 +181,10 @@ func (t *StashBoxBatchTagTask) processMatchedPerformer(ctx context.Context, p *m
return return
} }
err = txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error { r := instance.Repository
qb := instance.Repository.Performer err = r.WithTxn(ctx, func(ctx context.Context) error {
qb := r.Performer
if err := qb.Create(ctx, newPerformer); err != nil { if err := qb.Create(ctx, newPerformer); err != nil {
return err return err
} }
@@ -233,18 +235,16 @@ func (t *StashBoxBatchTagTask) findStashBoxStudio(ctx context.Context) (*models.
var studio *models.ScrapedStudio var studio *models.ScrapedStudio
var err error var err error
client := stashbox.NewClient(*t.box, instance.Repository, stashbox.Repository{ r := instance.Repository
Scene: instance.Repository.Scene,
Performer: instance.Repository.Performer, stashboxRepository := stashbox.NewRepository(r)
Tag: instance.Repository.Tag, client := stashbox.NewClient(*t.box, stashboxRepository)
Studio: instance.Repository.Studio,
})
if t.refresh { if t.refresh {
var remoteID string var remoteID string
if err := txn.WithReadTxn(ctx, instance.Repository, func(ctx context.Context) error { if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
if !t.studio.StashIDs.Loaded() { if !t.studio.StashIDs.Loaded() {
err = t.studio.LoadStashIDs(ctx, instance.Repository.Studio) err = t.studio.LoadStashIDs(ctx, r.Studio)
if err != nil { if err != nil {
return err return err
} }
@@ -293,8 +293,9 @@ func (t *StashBoxBatchTagTask) processMatchedStudio(ctx context.Context, s *mode
} }
// Start the transaction and update the studio // Start the transaction and update the studio
err = txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error { r := instance.Repository
qb := instance.Repository.Studio err = r.WithTxn(ctx, func(ctx context.Context) error {
qb := r.Studio
existingStashIDs, err := qb.GetStashIDs(ctx, storedID) existingStashIDs, err := qb.GetStashIDs(ctx, storedID)
if err != nil { if err != nil {
@@ -341,8 +342,10 @@ func (t *StashBoxBatchTagTask) processMatchedStudio(ctx context.Context, s *mode
} }
// Start the transaction and save the studio // Start the transaction and save the studio
err = txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error { r := instance.Repository
qb := instance.Repository.Studio err = r.WithTxn(ctx, func(ctx context.Context) error {
qb := r.Studio
if err := qb.Create(ctx, newStudio); err != nil { if err := qb.Create(ctx, newStudio); err != nil {
return err return err
} }
@@ -375,8 +378,10 @@ func (t *StashBoxBatchTagTask) processParentStudio(ctx context.Context, parent *
} }
// Start the transaction and save the studio // Start the transaction and save the studio
err = txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error { r := instance.Repository
qb := instance.Repository.Studio err = r.WithTxn(ctx, func(ctx context.Context) error {
qb := r.Studio
if err := qb.Create(ctx, newParentStudio); err != nil { if err := qb.Create(ctx, newParentStudio); err != nil {
return err return err
} }
@@ -408,8 +413,9 @@ func (t *StashBoxBatchTagTask) processParentStudio(ctx context.Context, parent *
} }
// Start the transaction and update the studio // Start the transaction and update the studio
err = txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error { r := instance.Repository
qb := instance.Repository.Studio err = r.WithTxn(ctx, func(ctx context.Context) error {
qb := r.Studio
existingStashIDs, err := qb.GetStashIDs(ctx, storedID) existingStashIDs, err := qb.GetStashIDs(ctx, storedID)
if err != nil { if err != nil {

View File

@@ -1,21 +1,62 @@
package static package static
import "embed" import (
"embed"
"fmt"
"io"
"io/fs"
)
//go:embed performer //go:embed performer performer_male scene image tag studio movie
var Performer embed.FS var data embed.FS
//go:embed performer_male const (
var PerformerMale embed.FS Performer = "performer"
PerformerMale = "performer_male"
//go:embed scene Scene = "scene"
var Scene embed.FS DefaultSceneImage = "scene/scene.svg"
//go:embed image Image = "image"
var Image embed.FS DefaultImageImage = "image/image.svg"
//go:embed tag Tag = "tag"
var Tag embed.FS DefaultTagImage = "tag/tag.svg"
//go:embed studio Studio = "studio"
var Studio embed.FS DefaultStudioImage = "studio/studio.svg"
Movie = "movie"
DefaultMovieImage = "movie/movie.png"
)
// Sub returns an FS rooted at path, using fs.Sub.
// It will panic if an error occurs.
func Sub(path string) fs.FS {
ret, err := fs.Sub(data, path)
if err != nil {
panic(fmt.Sprintf("creating static SubFS: %v", err))
}
return ret
}
// Open opens the file at path for reading.
// It will panic if an error occurs.
func Open(path string) fs.File {
f, err := data.Open(path)
if err != nil {
panic(fmt.Sprintf("opening static file: %v", err))
}
return f
}
// ReadAll returns the contents of the file at path.
// It will panic if an error occurs.
func ReadAll(path string) []byte {
f := Open(path)
ret, err := io.ReadAll(f)
if err != nil {
panic(fmt.Sprintf("reading static file: %v", err))
}
return ret
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 405 B

View File

@@ -11,7 +11,6 @@ import (
"github.com/stashapp/stash/pkg/job" "github.com/stashapp/stash/pkg/job"
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
) )
// Cleaner scans through stored file and folder instances and removes those that are no longer present on disk. // Cleaner scans through stored file and folder instances and removes those that are no longer present on disk.
@@ -112,14 +111,15 @@ func (j *cleanJob) execute(ctx context.Context) error {
folderCount int folderCount int
) )
if err := txn.WithReadTxn(ctx, j.Repository, func(ctx context.Context) error { r := j.Repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
var err error var err error
fileCount, err = j.Repository.FileStore.CountAllInPaths(ctx, j.options.Paths) fileCount, err = r.File.CountAllInPaths(ctx, j.options.Paths)
if err != nil { if err != nil {
return err return err
} }
folderCount, err = j.Repository.FolderStore.CountAllInPaths(ctx, j.options.Paths) folderCount, err = r.Folder.CountAllInPaths(ctx, j.options.Paths)
if err != nil { if err != nil {
return err return err
} }
@@ -172,13 +172,14 @@ func (j *cleanJob) assessFiles(ctx context.Context, toDelete *deleteSet) error {
progress := j.progress progress := j.progress
more := true more := true
if err := txn.WithReadTxn(ctx, j.Repository, func(ctx context.Context) error { r := j.Repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
for more { for more {
if job.IsCancelled(ctx) { if job.IsCancelled(ctx) {
return nil return nil
} }
files, err := j.Repository.FileStore.FindAllInPaths(ctx, j.options.Paths, batchSize, offset) files, err := r.File.FindAllInPaths(ctx, j.options.Paths, batchSize, offset)
if err != nil { if err != nil {
return fmt.Errorf("error querying for files: %w", err) return fmt.Errorf("error querying for files: %w", err)
} }
@@ -223,8 +224,9 @@ func (j *cleanJob) assessFiles(ctx context.Context, toDelete *deleteSet) error {
// flagFolderForDelete adds folders to the toDelete set, with the leaf folders added first // flagFolderForDelete adds folders to the toDelete set, with the leaf folders added first
func (j *cleanJob) flagFileForDelete(ctx context.Context, toDelete *deleteSet, f models.File) error { func (j *cleanJob) flagFileForDelete(ctx context.Context, toDelete *deleteSet, f models.File) error {
r := j.Repository
// add contained files first // add contained files first
containedFiles, err := j.Repository.FileStore.FindByZipFileID(ctx, f.Base().ID) containedFiles, err := r.File.FindByZipFileID(ctx, f.Base().ID)
if err != nil { if err != nil {
return fmt.Errorf("error finding contained files for %q: %w", f.Base().Path, err) return fmt.Errorf("error finding contained files for %q: %w", f.Base().Path, err)
} }
@@ -235,7 +237,7 @@ func (j *cleanJob) flagFileForDelete(ctx context.Context, toDelete *deleteSet, f
} }
// add contained folders as well // add contained folders as well
containedFolders, err := j.Repository.FolderStore.FindByZipFileID(ctx, f.Base().ID) containedFolders, err := r.Folder.FindByZipFileID(ctx, f.Base().ID)
if err != nil { if err != nil {
return fmt.Errorf("error finding contained folders for %q: %w", f.Base().Path, err) return fmt.Errorf("error finding contained folders for %q: %w", f.Base().Path, err)
} }
@@ -256,13 +258,14 @@ func (j *cleanJob) assessFolders(ctx context.Context, toDelete *deleteSet) error
progress := j.progress progress := j.progress
more := true more := true
if err := txn.WithReadTxn(ctx, j.Repository, func(ctx context.Context) error { r := j.Repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
for more { for more {
if job.IsCancelled(ctx) { if job.IsCancelled(ctx) {
return nil return nil
} }
folders, err := j.Repository.FolderStore.FindAllInPaths(ctx, j.options.Paths, batchSize, offset) folders, err := r.Folder.FindAllInPaths(ctx, j.options.Paths, batchSize, offset)
if err != nil { if err != nil {
return fmt.Errorf("error querying for folders: %w", err) return fmt.Errorf("error querying for folders: %w", err)
} }
@@ -380,14 +383,15 @@ func (j *cleanJob) shouldCleanFolder(ctx context.Context, f *models.Folder) bool
func (j *cleanJob) deleteFile(ctx context.Context, fileID models.FileID, fn string) { func (j *cleanJob) deleteFile(ctx context.Context, fileID models.FileID, fn string) {
// delete associated objects // delete associated objects
fileDeleter := NewDeleter() fileDeleter := NewDeleter()
if err := txn.WithTxn(ctx, j.Repository, func(ctx context.Context) error { r := j.Repository
if err := r.WithTxn(ctx, func(ctx context.Context) error {
fileDeleter.RegisterHooks(ctx) fileDeleter.RegisterHooks(ctx)
if err := j.fireHandlers(ctx, fileDeleter, fileID); err != nil { if err := j.fireHandlers(ctx, fileDeleter, fileID); err != nil {
return err return err
} }
return j.Repository.FileStore.Destroy(ctx, fileID) return r.File.Destroy(ctx, fileID)
}); err != nil { }); err != nil {
logger.Errorf("Error deleting file %q from database: %s", fn, err.Error()) logger.Errorf("Error deleting file %q from database: %s", fn, err.Error())
return return
@@ -397,14 +401,15 @@ func (j *cleanJob) deleteFile(ctx context.Context, fileID models.FileID, fn stri
func (j *cleanJob) deleteFolder(ctx context.Context, folderID models.FolderID, fn string) { func (j *cleanJob) deleteFolder(ctx context.Context, folderID models.FolderID, fn string) {
// delete associated objects // delete associated objects
fileDeleter := NewDeleter() fileDeleter := NewDeleter()
if err := txn.WithTxn(ctx, j.Repository, func(ctx context.Context) error { r := j.Repository
if err := r.WithTxn(ctx, func(ctx context.Context) error {
fileDeleter.RegisterHooks(ctx) fileDeleter.RegisterHooks(ctx)
if err := j.fireFolderHandlers(ctx, fileDeleter, folderID); err != nil { if err := j.fireFolderHandlers(ctx, fileDeleter, folderID); err != nil {
return err return err
} }
return j.Repository.FolderStore.Destroy(ctx, folderID) return r.Folder.Destroy(ctx, folderID)
}); err != nil { }); err != nil {
logger.Errorf("Error deleting folder %q from database: %s", fn, err.Error()) logger.Errorf("Error deleting folder %q from database: %s", fn, err.Error())
return return

View File

@@ -1,15 +1,36 @@
package file package file
import ( import (
"context"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn" "github.com/stashapp/stash/pkg/txn"
) )
// Repository provides access to storage methods for files and folders. // Repository provides access to storage methods for files and folders.
type Repository struct { type Repository struct {
txn.Manager TxnManager models.TxnManager
txn.DatabaseProvider
FileStore models.FileReaderWriter File models.FileReaderWriter
FolderStore models.FolderReaderWriter Folder models.FolderReaderWriter
}
func NewRepository(repo models.Repository) Repository {
return Repository{
TxnManager: repo.TxnManager,
File: repo.File,
Folder: repo.Folder,
}
}
func (r *Repository) WithTxn(ctx context.Context, fn txn.TxnFunc) error {
return txn.WithTxn(ctx, r.TxnManager, fn)
}
func (r *Repository) WithReadTxn(ctx context.Context, fn txn.TxnFunc) error {
return txn.WithReadTxn(ctx, r.TxnManager, fn)
}
func (r *Repository) WithDB(ctx context.Context, fn txn.TxnFunc) error {
return txn.WithDatabase(ctx, r.TxnManager, fn)
} }

View File

@@ -86,6 +86,8 @@ func (s *scanJob) detectFolderMove(ctx context.Context, file scanFile) (*models.
} }
// rejects is a set of folder ids which were found to still exist // rejects is a set of folder ids which were found to still exist
r := s.Repository
if err := symWalk(file.fs, file.Path, func(path string, d fs.DirEntry, err error) error { if err := symWalk(file.fs, file.Path, func(path string, d fs.DirEntry, err error) error {
if err != nil { if err != nil {
// don't let errors prevent scanning // don't let errors prevent scanning
@@ -118,7 +120,7 @@ func (s *scanJob) detectFolderMove(ctx context.Context, file scanFile) (*models.
} }
// check if the file exists in the database based on basename, size and mod time // check if the file exists in the database based on basename, size and mod time
existing, err := s.Repository.FileStore.FindByFileInfo(ctx, info, size) existing, err := r.File.FindByFileInfo(ctx, info, size)
if err != nil { if err != nil {
return fmt.Errorf("checking for existing file %q: %w", path, err) return fmt.Errorf("checking for existing file %q: %w", path, err)
} }
@@ -140,7 +142,7 @@ func (s *scanJob) detectFolderMove(ctx context.Context, file scanFile) (*models.
if c == nil { if c == nil {
// need to check if the folder exists in the filesystem // need to check if the folder exists in the filesystem
pf, err := s.Repository.FolderStore.Find(ctx, e.Base().ParentFolderID) pf, err := r.Folder.Find(ctx, e.Base().ParentFolderID)
if err != nil { if err != nil {
return fmt.Errorf("getting parent folder %d: %w", e.Base().ParentFolderID, err) return fmt.Errorf("getting parent folder %d: %w", e.Base().ParentFolderID, err)
} }
@@ -164,7 +166,7 @@ func (s *scanJob) detectFolderMove(ctx context.Context, file scanFile) (*models.
// parent folder is missing, possible candidate // parent folder is missing, possible candidate
// count the total number of files in the existing folder // count the total number of files in the existing folder
count, err := s.Repository.FileStore.CountByFolderID(ctx, parentFolderID) count, err := r.File.CountByFolderID(ctx, parentFolderID)
if err != nil { if err != nil {
return fmt.Errorf("counting files in folder %d: %w", parentFolderID, err) return fmt.Errorf("counting files in folder %d: %w", parentFolderID, err)
} }

View File

@@ -181,7 +181,7 @@ func (m *Mover) moveFile(oldPath, newPath string) error {
return nil return nil
} }
func (m *Mover) RegisterHooks(ctx context.Context, mgr txn.Manager) { func (m *Mover) RegisterHooks(ctx context.Context) {
txn.AddPostCommitHook(ctx, func(ctx context.Context) { txn.AddPostCommitHook(ctx, func(ctx context.Context) {
m.commit() m.commit()
}) })

View File

@@ -144,7 +144,7 @@ func (s *Scanner) Scan(ctx context.Context, handlers []Handler, options ScanOpti
ProgressReports: progressReporter, ProgressReports: progressReporter,
options: options, options: options,
txnRetryer: txn.Retryer{ txnRetryer: txn.Retryer{
Manager: s.Repository, Manager: s.Repository.TxnManager,
Retries: maxRetries, Retries: maxRetries,
}, },
} }
@@ -163,7 +163,7 @@ func (s *scanJob) withTxn(ctx context.Context, fn func(ctx context.Context) erro
} }
func (s *scanJob) withDB(ctx context.Context, fn func(ctx context.Context) error) error { func (s *scanJob) withDB(ctx context.Context, fn func(ctx context.Context) error) error {
return txn.WithDatabase(ctx, s.Repository, fn) return s.Repository.WithDB(ctx, fn)
} }
func (s *scanJob) execute(ctx context.Context) { func (s *scanJob) execute(ctx context.Context) {
@@ -439,7 +439,7 @@ func (s *scanJob) getFolderID(ctx context.Context, path string) (*models.FolderI
return &v, nil return &v, nil
} }
ret, err := s.Repository.FolderStore.FindByPath(ctx, path) ret, err := s.Repository.Folder.FindByPath(ctx, path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -469,7 +469,7 @@ func (s *scanJob) getZipFileID(ctx context.Context, zipFile *scanFile) (*models.
return &v, nil return &v, nil
} }
ret, err := s.Repository.FileStore.FindByPath(ctx, path) ret, err := s.Repository.File.FindByPath(ctx, path)
if err != nil { if err != nil {
return nil, fmt.Errorf("getting zip file ID for %q: %w", path, err) return nil, fmt.Errorf("getting zip file ID for %q: %w", path, err)
} }
@@ -489,7 +489,7 @@ func (s *scanJob) handleFolder(ctx context.Context, file scanFile) error {
defer s.incrementProgress(file) defer s.incrementProgress(file)
// determine if folder already exists in data store (by path) // determine if folder already exists in data store (by path)
f, err := s.Repository.FolderStore.FindByPath(ctx, path) f, err := s.Repository.Folder.FindByPath(ctx, path)
if err != nil { if err != nil {
return fmt.Errorf("checking for existing folder %q: %w", path, err) return fmt.Errorf("checking for existing folder %q: %w", path, err)
} }
@@ -553,7 +553,7 @@ func (s *scanJob) onNewFolder(ctx context.Context, file scanFile) (*models.Folde
logger.Infof("%s doesn't exist. Creating new folder entry...", file.Path) logger.Infof("%s doesn't exist. Creating new folder entry...", file.Path)
}) })
if err := s.Repository.FolderStore.Create(ctx, toCreate); err != nil { if err := s.Repository.Folder.Create(ctx, toCreate); err != nil {
return nil, fmt.Errorf("creating folder %q: %w", file.Path, err) return nil, fmt.Errorf("creating folder %q: %w", file.Path, err)
} }
@@ -589,12 +589,12 @@ func (s *scanJob) handleFolderRename(ctx context.Context, file scanFile) (*model
renamedFrom.ParentFolderID = parentFolderID renamedFrom.ParentFolderID = parentFolderID
if err := s.Repository.FolderStore.Update(ctx, renamedFrom); err != nil { if err := s.Repository.Folder.Update(ctx, renamedFrom); err != nil {
return nil, fmt.Errorf("updating folder for rename %q: %w", renamedFrom.Path, err) return nil, fmt.Errorf("updating folder for rename %q: %w", renamedFrom.Path, err)
} }
// #4146 - correct sub-folders to have the correct path // #4146 - correct sub-folders to have the correct path
if err := correctSubFolderHierarchy(ctx, s.Repository.FolderStore, renamedFrom); err != nil { if err := correctSubFolderHierarchy(ctx, s.Repository.Folder, renamedFrom); err != nil {
return nil, fmt.Errorf("correcting sub folder hierarchy for %q: %w", renamedFrom.Path, err) return nil, fmt.Errorf("correcting sub folder hierarchy for %q: %w", renamedFrom.Path, err)
} }
@@ -626,7 +626,7 @@ func (s *scanJob) onExistingFolder(ctx context.Context, f scanFile, existing *mo
if update { if update {
var err error var err error
if err = s.Repository.FolderStore.Update(ctx, existing); err != nil { if err = s.Repository.Folder.Update(ctx, existing); err != nil {
return nil, fmt.Errorf("updating folder %q: %w", f.Path, err) return nil, fmt.Errorf("updating folder %q: %w", f.Path, err)
} }
} }
@@ -647,7 +647,7 @@ func (s *scanJob) handleFile(ctx context.Context, f scanFile) error {
if err := s.withDB(ctx, func(ctx context.Context) error { if err := s.withDB(ctx, func(ctx context.Context) error {
// determine if file already exists in data store // determine if file already exists in data store
var err error var err error
ff, err = s.Repository.FileStore.FindByPath(ctx, f.Path) ff, err = s.Repository.File.FindByPath(ctx, f.Path)
if err != nil { if err != nil {
return fmt.Errorf("checking for existing file %q: %w", f.Path, err) return fmt.Errorf("checking for existing file %q: %w", f.Path, err)
} }
@@ -745,7 +745,7 @@ func (s *scanJob) onNewFile(ctx context.Context, f scanFile) (models.File, error
// if not renamed, queue file for creation // if not renamed, queue file for creation
if err := s.withTxn(ctx, func(ctx context.Context) error { if err := s.withTxn(ctx, func(ctx context.Context) error {
if err := s.Repository.FileStore.Create(ctx, file); err != nil { if err := s.Repository.File.Create(ctx, file); err != nil {
return fmt.Errorf("creating file %q: %w", path, err) return fmt.Errorf("creating file %q: %w", path, err)
} }
@@ -838,7 +838,7 @@ func (s *scanJob) handleRename(ctx context.Context, f models.File, fp []models.F
var others []models.File var others []models.File
for _, tfp := range fp { for _, tfp := range fp {
thisOthers, err := s.Repository.FileStore.FindByFingerprint(ctx, tfp) thisOthers, err := s.Repository.File.FindByFingerprint(ctx, tfp)
if err != nil { if err != nil {
return nil, fmt.Errorf("getting files by fingerprint %v: %w", tfp, err) return nil, fmt.Errorf("getting files by fingerprint %v: %w", tfp, err)
} }
@@ -896,12 +896,12 @@ func (s *scanJob) handleRename(ctx context.Context, f models.File, fp []models.F
fBase.Fingerprints = otherBase.Fingerprints fBase.Fingerprints = otherBase.Fingerprints
if err := s.withTxn(ctx, func(ctx context.Context) error { if err := s.withTxn(ctx, func(ctx context.Context) error {
if err := s.Repository.FileStore.Update(ctx, f); err != nil { if err := s.Repository.File.Update(ctx, f); err != nil {
return fmt.Errorf("updating file for rename %q: %w", fBase.Path, err) return fmt.Errorf("updating file for rename %q: %w", fBase.Path, err)
} }
if s.isZipFile(fBase.Basename) { if s.isZipFile(fBase.Basename) {
if err := TransferZipFolderHierarchy(ctx, s.Repository.FolderStore, fBase.ID, otherBase.Path, fBase.Path); err != nil { if err := TransferZipFolderHierarchy(ctx, s.Repository.Folder, fBase.ID, otherBase.Path, fBase.Path); err != nil {
return fmt.Errorf("moving folder hierarchy for renamed zip file %q: %w", fBase.Path, err) return fmt.Errorf("moving folder hierarchy for renamed zip file %q: %w", fBase.Path, err)
} }
} }
@@ -963,7 +963,7 @@ func (s *scanJob) setMissingMetadata(ctx context.Context, f scanFile, existing m
// queue file for update // queue file for update
if err := s.withTxn(ctx, func(ctx context.Context) error { if err := s.withTxn(ctx, func(ctx context.Context) error {
if err := s.Repository.FileStore.Update(ctx, existing); err != nil { if err := s.Repository.File.Update(ctx, existing); err != nil {
return fmt.Errorf("updating file %q: %w", path, err) return fmt.Errorf("updating file %q: %w", path, err)
} }
@@ -986,7 +986,7 @@ func (s *scanJob) setMissingFingerprints(ctx context.Context, f scanFile, existi
existing.SetFingerprints(fp) existing.SetFingerprints(fp)
if err := s.withTxn(ctx, func(ctx context.Context) error { if err := s.withTxn(ctx, func(ctx context.Context) error {
if err := s.Repository.FileStore.Update(ctx, existing); err != nil { if err := s.Repository.File.Update(ctx, existing); err != nil {
return fmt.Errorf("updating file %q: %w", f.Path, err) return fmt.Errorf("updating file %q: %w", f.Path, err)
} }
@@ -1035,7 +1035,7 @@ func (s *scanJob) onExistingFile(ctx context.Context, f scanFile, existing model
// queue file for update // queue file for update
if err := s.withTxn(ctx, func(ctx context.Context) error { if err := s.withTxn(ctx, func(ctx context.Context) error {
if err := s.Repository.FileStore.Update(ctx, existing); err != nil { if err := s.Repository.File.Update(ctx, existing); err != nil {
return fmt.Errorf("updating file %q: %w", path, err) return fmt.Errorf("updating file %q: %w", path, err)
} }

View File

@@ -157,19 +157,19 @@ var getStudioScenarios = []stringTestScenario{
} }
func TestGetStudioName(t *testing.T) { func TestGetStudioName(t *testing.T) {
mockStudioReader := &mocks.StudioReaderWriter{} db := mocks.NewDatabase()
studioErr := errors.New("error getting image") studioErr := errors.New("error getting image")
mockStudioReader.On("Find", testCtx, studioID).Return(&models.Studio{ db.Studio.On("Find", testCtx, studioID).Return(&models.Studio{
Name: studioName, Name: studioName,
}, nil).Once() }, nil).Once()
mockStudioReader.On("Find", testCtx, missingStudioID).Return(nil, nil).Once() db.Studio.On("Find", testCtx, missingStudioID).Return(nil, nil).Once()
mockStudioReader.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once() db.Studio.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once()
for i, s := range getStudioScenarios { for i, s := range getStudioScenarios {
gallery := s.input gallery := s.input
json, err := GetStudioName(testCtx, mockStudioReader, &gallery) json, err := GetStudioName(testCtx, db.Studio, &gallery)
switch { switch {
case !s.err && err != nil: case !s.err && err != nil:
@@ -181,7 +181,7 @@ func TestGetStudioName(t *testing.T) {
} }
} }
mockStudioReader.AssertExpectations(t) db.AssertExpectations(t)
} }
const ( const (
@@ -258,17 +258,17 @@ var validChapters = []*models.GalleryChapter{
} }
func TestGetGalleryChaptersJSON(t *testing.T) { func TestGetGalleryChaptersJSON(t *testing.T) {
mockChapterReader := &mocks.GalleryChapterReaderWriter{} db := mocks.NewDatabase()
chaptersErr := errors.New("error getting gallery chapters") chaptersErr := errors.New("error getting gallery chapters")
mockChapterReader.On("FindByGalleryID", testCtx, galleryID).Return(validChapters, nil).Once() db.GalleryChapter.On("FindByGalleryID", testCtx, galleryID).Return(validChapters, nil).Once()
mockChapterReader.On("FindByGalleryID", testCtx, noChaptersID).Return(nil, nil).Once() db.GalleryChapter.On("FindByGalleryID", testCtx, noChaptersID).Return(nil, nil).Once()
mockChapterReader.On("FindByGalleryID", testCtx, errChaptersID).Return(nil, chaptersErr).Once() db.GalleryChapter.On("FindByGalleryID", testCtx, errChaptersID).Return(nil, chaptersErr).Once()
for i, s := range getGalleryChaptersJSONScenarios { for i, s := range getGalleryChaptersJSONScenarios {
gallery := s.input gallery := s.input
json, err := GetGalleryChaptersJSON(testCtx, mockChapterReader, &gallery) json, err := GetGalleryChaptersJSON(testCtx, db.GalleryChapter, &gallery)
switch { switch {
case !s.err && err != nil: case !s.err && err != nil:
@@ -280,4 +280,5 @@ func TestGetGalleryChaptersJSON(t *testing.T) {
} }
} }
db.AssertExpectations(t)
} }

View File

@@ -78,19 +78,19 @@ func TestImporterPreImport(t *testing.T) {
} }
func TestImporterPreImportWithStudio(t *testing.T) { func TestImporterPreImportWithStudio(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
StudioWriter: studioReaderWriter, StudioWriter: db.Studio,
Input: jsonschema.Gallery{ Input: jsonschema.Gallery{
Studio: existingStudioName, Studio: existingStudioName,
}, },
} }
studioReaderWriter.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{ db.Studio.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
ID: existingStudioID, ID: existingStudioID,
}, nil).Once() }, nil).Once()
studioReaderWriter.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once() db.Studio.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
err := i.PreImport(testCtx) err := i.PreImport(testCtx)
assert.Nil(t, err) assert.Nil(t, err)
@@ -100,22 +100,22 @@ func TestImporterPreImportWithStudio(t *testing.T) {
err = i.PreImport(testCtx) err = i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
studioReaderWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestImporterPreImportWithMissingStudio(t *testing.T) { func TestImporterPreImportWithMissingStudio(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
StudioWriter: studioReaderWriter, StudioWriter: db.Studio,
Input: jsonschema.Gallery{ Input: jsonschema.Gallery{
Studio: missingStudioName, Studio: missingStudioName,
}, },
MissingRefBehaviour: models.ImportMissingRefEnumFail, MissingRefBehaviour: models.ImportMissingRefEnumFail,
} }
studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3) db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) { db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.Studio) s := args.Get(1).(*models.Studio)
s.ID = existingStudioID s.ID = existingStudioID
}).Return(nil) }).Return(nil)
@@ -132,32 +132,34 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, existingStudioID, *i.gallery.StudioID) assert.Equal(t, existingStudioID, *i.gallery.StudioID)
studioReaderWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) { func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
StudioWriter: studioReaderWriter, StudioWriter: db.Studio,
Input: jsonschema.Gallery{ Input: jsonschema.Gallery{
Studio: missingStudioName, Studio: missingStudioName,
}, },
MissingRefBehaviour: models.ImportMissingRefEnumCreate, MissingRefBehaviour: models.ImportMissingRefEnumCreate,
} }
studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once() db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error")) db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
err := i.PreImport(testCtx) err := i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
db.AssertExpectations(t)
} }
func TestImporterPreImportWithPerformer(t *testing.T) { func TestImporterPreImportWithPerformer(t *testing.T) {
performerReaderWriter := &mocks.PerformerReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
PerformerWriter: performerReaderWriter, PerformerWriter: db.Performer,
MissingRefBehaviour: models.ImportMissingRefEnumFail, MissingRefBehaviour: models.ImportMissingRefEnumFail,
Input: jsonschema.Gallery{ Input: jsonschema.Gallery{
Performers: []string{ Performers: []string{
@@ -166,13 +168,13 @@ func TestImporterPreImportWithPerformer(t *testing.T) {
}, },
} }
performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{ db.Performer.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{
{ {
ID: existingPerformerID, ID: existingPerformerID,
Name: existingPerformerName, Name: existingPerformerName,
}, },
}, nil).Once() }, nil).Once()
performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once() db.Performer.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once()
err := i.PreImport(testCtx) err := i.PreImport(testCtx)
assert.Nil(t, err) assert.Nil(t, err)
@@ -182,14 +184,14 @@ func TestImporterPreImportWithPerformer(t *testing.T) {
err = i.PreImport(testCtx) err = i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
performerReaderWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestImporterPreImportWithMissingPerformer(t *testing.T) { func TestImporterPreImportWithMissingPerformer(t *testing.T) {
performerReaderWriter := &mocks.PerformerReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
PerformerWriter: performerReaderWriter, PerformerWriter: db.Performer,
Input: jsonschema.Gallery{ Input: jsonschema.Gallery{
Performers: []string{ Performers: []string{
missingPerformerName, missingPerformerName,
@@ -198,8 +200,8 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail, MissingRefBehaviour: models.ImportMissingRefEnumFail,
} }
performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3) db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3)
performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) { db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) {
performer := args.Get(1).(*models.Performer) performer := args.Get(1).(*models.Performer)
performer.ID = existingPerformerID performer.ID = existingPerformerID
}).Return(nil) }).Return(nil)
@@ -216,14 +218,14 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, []int{existingPerformerID}, i.gallery.PerformerIDs.List()) assert.Equal(t, []int{existingPerformerID}, i.gallery.PerformerIDs.List())
performerReaderWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) { func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
performerReaderWriter := &mocks.PerformerReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
PerformerWriter: performerReaderWriter, PerformerWriter: db.Performer,
Input: jsonschema.Gallery{ Input: jsonschema.Gallery{
Performers: []string{ Performers: []string{
missingPerformerName, missingPerformerName,
@@ -232,18 +234,20 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate, MissingRefBehaviour: models.ImportMissingRefEnumCreate,
} }
performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once() db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once()
performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error")) db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error"))
err := i.PreImport(testCtx) err := i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
db.AssertExpectations(t)
} }
func TestImporterPreImportWithTag(t *testing.T) { func TestImporterPreImportWithTag(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
TagWriter: tagReaderWriter, TagWriter: db.Tag,
MissingRefBehaviour: models.ImportMissingRefEnumFail, MissingRefBehaviour: models.ImportMissingRefEnumFail,
Input: jsonschema.Gallery{ Input: jsonschema.Gallery{
Tags: []string{ Tags: []string{
@@ -252,13 +256,13 @@ func TestImporterPreImportWithTag(t *testing.T) {
}, },
} }
tagReaderWriter.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{ db.Tag.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
{ {
ID: existingTagID, ID: existingTagID,
Name: existingTagName, Name: existingTagName,
}, },
}, nil).Once() }, nil).Once()
tagReaderWriter.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once() db.Tag.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
err := i.PreImport(testCtx) err := i.PreImport(testCtx)
assert.Nil(t, err) assert.Nil(t, err)
@@ -268,14 +272,14 @@ func TestImporterPreImportWithTag(t *testing.T) {
err = i.PreImport(testCtx) err = i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
tagReaderWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestImporterPreImportWithMissingTag(t *testing.T) { func TestImporterPreImportWithMissingTag(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
TagWriter: tagReaderWriter, TagWriter: db.Tag,
Input: jsonschema.Gallery{ Input: jsonschema.Gallery{
Tags: []string{ Tags: []string{
missingTagName, missingTagName,
@@ -284,8 +288,8 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail, MissingRefBehaviour: models.ImportMissingRefEnumFail,
} }
tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3) db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) { db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
t := args.Get(1).(*models.Tag) t := args.Get(1).(*models.Tag)
t.ID = existingTagID t.ID = existingTagID
}).Return(nil) }).Return(nil)
@@ -302,14 +306,14 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, []int{existingTagID}, i.gallery.TagIDs.List()) assert.Equal(t, []int{existingTagID}, i.gallery.TagIDs.List())
tagReaderWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) { func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
TagWriter: tagReaderWriter, TagWriter: db.Tag,
Input: jsonschema.Gallery{ Input: jsonschema.Gallery{
Tags: []string{ Tags: []string{
missingTagName, missingTagName,
@@ -318,9 +322,11 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate, MissingRefBehaviour: models.ImportMissingRefEnumCreate,
} }
tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once() db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error")) db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
err := i.PreImport(testCtx) err := i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
db.AssertExpectations(t)
} }

View File

@@ -130,19 +130,19 @@ var getStudioScenarios = []stringTestScenario{
} }
func TestGetStudioName(t *testing.T) { func TestGetStudioName(t *testing.T) {
mockStudioReader := &mocks.StudioReaderWriter{} db := mocks.NewDatabase()
studioErr := errors.New("error getting image") studioErr := errors.New("error getting image")
mockStudioReader.On("Find", testCtx, studioID).Return(&models.Studio{ db.Studio.On("Find", testCtx, studioID).Return(&models.Studio{
Name: studioName, Name: studioName,
}, nil).Once() }, nil).Once()
mockStudioReader.On("Find", testCtx, missingStudioID).Return(nil, nil).Once() db.Studio.On("Find", testCtx, missingStudioID).Return(nil, nil).Once()
mockStudioReader.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once() db.Studio.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once()
for i, s := range getStudioScenarios { for i, s := range getStudioScenarios {
image := s.input image := s.input
json, err := GetStudioName(testCtx, mockStudioReader, &image) json, err := GetStudioName(testCtx, db.Studio, &image)
switch { switch {
case !s.err && err != nil: case !s.err && err != nil:
@@ -154,5 +154,5 @@ func TestGetStudioName(t *testing.T) {
} }
} }
mockStudioReader.AssertExpectations(t) db.AssertExpectations(t)
} }

View File

@@ -40,19 +40,19 @@ func TestImporterPreImport(t *testing.T) {
} }
func TestImporterPreImportWithStudio(t *testing.T) { func TestImporterPreImportWithStudio(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
StudioWriter: studioReaderWriter, StudioWriter: db.Studio,
Input: jsonschema.Image{ Input: jsonschema.Image{
Studio: existingStudioName, Studio: existingStudioName,
}, },
} }
studioReaderWriter.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{ db.Studio.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
ID: existingStudioID, ID: existingStudioID,
}, nil).Once() }, nil).Once()
studioReaderWriter.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once() db.Studio.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
err := i.PreImport(testCtx) err := i.PreImport(testCtx)
assert.Nil(t, err) assert.Nil(t, err)
@@ -62,22 +62,22 @@ func TestImporterPreImportWithStudio(t *testing.T) {
err = i.PreImport(testCtx) err = i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
studioReaderWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestImporterPreImportWithMissingStudio(t *testing.T) { func TestImporterPreImportWithMissingStudio(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
StudioWriter: studioReaderWriter, StudioWriter: db.Studio,
Input: jsonschema.Image{ Input: jsonschema.Image{
Studio: missingStudioName, Studio: missingStudioName,
}, },
MissingRefBehaviour: models.ImportMissingRefEnumFail, MissingRefBehaviour: models.ImportMissingRefEnumFail,
} }
studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3) db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) { db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.Studio) s := args.Get(1).(*models.Studio)
s.ID = existingStudioID s.ID = existingStudioID
}).Return(nil) }).Return(nil)
@@ -94,32 +94,34 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, existingStudioID, *i.image.StudioID) assert.Equal(t, existingStudioID, *i.image.StudioID)
studioReaderWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) { func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
StudioWriter: studioReaderWriter, StudioWriter: db.Studio,
Input: jsonschema.Image{ Input: jsonschema.Image{
Studio: missingStudioName, Studio: missingStudioName,
}, },
MissingRefBehaviour: models.ImportMissingRefEnumCreate, MissingRefBehaviour: models.ImportMissingRefEnumCreate,
} }
studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once() db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error")) db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
err := i.PreImport(testCtx) err := i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
db.AssertExpectations(t)
} }
func TestImporterPreImportWithPerformer(t *testing.T) { func TestImporterPreImportWithPerformer(t *testing.T) {
performerReaderWriter := &mocks.PerformerReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
PerformerWriter: performerReaderWriter, PerformerWriter: db.Performer,
MissingRefBehaviour: models.ImportMissingRefEnumFail, MissingRefBehaviour: models.ImportMissingRefEnumFail,
Input: jsonschema.Image{ Input: jsonschema.Image{
Performers: []string{ Performers: []string{
@@ -128,13 +130,13 @@ func TestImporterPreImportWithPerformer(t *testing.T) {
}, },
} }
performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{ db.Performer.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{
{ {
ID: existingPerformerID, ID: existingPerformerID,
Name: existingPerformerName, Name: existingPerformerName,
}, },
}, nil).Once() }, nil).Once()
performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once() db.Performer.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once()
err := i.PreImport(testCtx) err := i.PreImport(testCtx)
assert.Nil(t, err) assert.Nil(t, err)
@@ -144,14 +146,14 @@ func TestImporterPreImportWithPerformer(t *testing.T) {
err = i.PreImport(testCtx) err = i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
performerReaderWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestImporterPreImportWithMissingPerformer(t *testing.T) { func TestImporterPreImportWithMissingPerformer(t *testing.T) {
performerReaderWriter := &mocks.PerformerReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
PerformerWriter: performerReaderWriter, PerformerWriter: db.Performer,
Input: jsonschema.Image{ Input: jsonschema.Image{
Performers: []string{ Performers: []string{
missingPerformerName, missingPerformerName,
@@ -160,8 +162,8 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail, MissingRefBehaviour: models.ImportMissingRefEnumFail,
} }
performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3) db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3)
performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) { db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) {
performer := args.Get(1).(*models.Performer) performer := args.Get(1).(*models.Performer)
performer.ID = existingPerformerID performer.ID = existingPerformerID
}).Return(nil) }).Return(nil)
@@ -178,14 +180,14 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, []int{existingPerformerID}, i.image.PerformerIDs.List()) assert.Equal(t, []int{existingPerformerID}, i.image.PerformerIDs.List())
performerReaderWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) { func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
performerReaderWriter := &mocks.PerformerReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
PerformerWriter: performerReaderWriter, PerformerWriter: db.Performer,
Input: jsonschema.Image{ Input: jsonschema.Image{
Performers: []string{ Performers: []string{
missingPerformerName, missingPerformerName,
@@ -194,18 +196,20 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate, MissingRefBehaviour: models.ImportMissingRefEnumCreate,
} }
performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once() db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once()
performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error")) db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error"))
err := i.PreImport(testCtx) err := i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
db.AssertExpectations(t)
} }
func TestImporterPreImportWithTag(t *testing.T) { func TestImporterPreImportWithTag(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
TagWriter: tagReaderWriter, TagWriter: db.Tag,
MissingRefBehaviour: models.ImportMissingRefEnumFail, MissingRefBehaviour: models.ImportMissingRefEnumFail,
Input: jsonschema.Image{ Input: jsonschema.Image{
Tags: []string{ Tags: []string{
@@ -214,13 +218,13 @@ func TestImporterPreImportWithTag(t *testing.T) {
}, },
} }
tagReaderWriter.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{ db.Tag.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
{ {
ID: existingTagID, ID: existingTagID,
Name: existingTagName, Name: existingTagName,
}, },
}, nil).Once() }, nil).Once()
tagReaderWriter.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once() db.Tag.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
err := i.PreImport(testCtx) err := i.PreImport(testCtx)
assert.Nil(t, err) assert.Nil(t, err)
@@ -230,14 +234,14 @@ func TestImporterPreImportWithTag(t *testing.T) {
err = i.PreImport(testCtx) err = i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
tagReaderWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestImporterPreImportWithMissingTag(t *testing.T) { func TestImporterPreImportWithMissingTag(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
TagWriter: tagReaderWriter, TagWriter: db.Tag,
Input: jsonschema.Image{ Input: jsonschema.Image{
Tags: []string{ Tags: []string{
missingTagName, missingTagName,
@@ -246,8 +250,8 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail, MissingRefBehaviour: models.ImportMissingRefEnumFail,
} }
tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3) db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) { db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
t := args.Get(1).(*models.Tag) t := args.Get(1).(*models.Tag)
t.ID = existingTagID t.ID = existingTagID
}).Return(nil) }).Return(nil)
@@ -264,14 +268,14 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, []int{existingTagID}, i.image.TagIDs.List()) assert.Equal(t, []int{existingTagID}, i.image.TagIDs.List())
tagReaderWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) { func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
TagWriter: tagReaderWriter, TagWriter: db.Tag,
Input: jsonschema.Image{ Input: jsonschema.Image{
Tags: []string{ Tags: []string{
missingTagName, missingTagName,
@@ -280,9 +284,11 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate, MissingRefBehaviour: models.ImportMissingRefEnumCreate,
} }
tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once() db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error")) db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
err := i.PreImport(testCtx) err := i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
db.AssertExpectations(t)
} }

View File

@@ -0,0 +1,107 @@
package mocks
import (
"context"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
"github.com/stretchr/testify/mock"
)
type Database struct {
File *FileReaderWriter
Folder *FolderReaderWriter
Gallery *GalleryReaderWriter
GalleryChapter *GalleryChapterReaderWriter
Image *ImageReaderWriter
Movie *MovieReaderWriter
Performer *PerformerReaderWriter
Scene *SceneReaderWriter
SceneMarker *SceneMarkerReaderWriter
Studio *StudioReaderWriter
Tag *TagReaderWriter
SavedFilter *SavedFilterReaderWriter
}
func (*Database) Begin(ctx context.Context, exclusive bool) (context.Context, error) {
return ctx, nil
}
func (*Database) WithDatabase(ctx context.Context) (context.Context, error) {
return ctx, nil
}
func (*Database) Commit(ctx context.Context) error {
return nil
}
func (*Database) Rollback(ctx context.Context) error {
return nil
}
func (*Database) Complete(ctx context.Context) {
}
func (*Database) AddPostCommitHook(ctx context.Context, hook txn.TxnFunc) {
}
func (*Database) AddPostRollbackHook(ctx context.Context, hook txn.TxnFunc) {
}
func (*Database) IsLocked(err error) bool {
return false
}
func (*Database) Reset() error {
return nil
}
func NewDatabase() *Database {
return &Database{
File: &FileReaderWriter{},
Folder: &FolderReaderWriter{},
Gallery: &GalleryReaderWriter{},
GalleryChapter: &GalleryChapterReaderWriter{},
Image: &ImageReaderWriter{},
Movie: &MovieReaderWriter{},
Performer: &PerformerReaderWriter{},
Scene: &SceneReaderWriter{},
SceneMarker: &SceneMarkerReaderWriter{},
Studio: &StudioReaderWriter{},
Tag: &TagReaderWriter{},
SavedFilter: &SavedFilterReaderWriter{},
}
}
func (db *Database) AssertExpectations(t mock.TestingT) {
db.File.AssertExpectations(t)
db.Folder.AssertExpectations(t)
db.Gallery.AssertExpectations(t)
db.GalleryChapter.AssertExpectations(t)
db.Image.AssertExpectations(t)
db.Movie.AssertExpectations(t)
db.Performer.AssertExpectations(t)
db.Scene.AssertExpectations(t)
db.SceneMarker.AssertExpectations(t)
db.Studio.AssertExpectations(t)
db.Tag.AssertExpectations(t)
db.SavedFilter.AssertExpectations(t)
}
func (db *Database) Repository() models.Repository {
return models.Repository{
TxnManager: db,
File: db.File,
Folder: db.Folder,
Gallery: db.Gallery,
GalleryChapter: db.GalleryChapter,
Image: db.Image,
Movie: db.Movie,
Performer: db.Performer,
Scene: db.Scene,
SceneMarker: db.SceneMarker,
Studio: db.Studio,
Tag: db.Tag,
SavedFilter: db.SavedFilter,
}
}

View File

@@ -1,59 +0,0 @@
package mocks
import (
context "context"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
)
type TxnManager struct{}
func (*TxnManager) Begin(ctx context.Context, exclusive bool) (context.Context, error) {
return ctx, nil
}
func (*TxnManager) WithDatabase(ctx context.Context) (context.Context, error) {
return ctx, nil
}
func (*TxnManager) Commit(ctx context.Context) error {
return nil
}
func (*TxnManager) Rollback(ctx context.Context) error {
return nil
}
func (*TxnManager) Complete(ctx context.Context) {
}
func (*TxnManager) AddPostCommitHook(ctx context.Context, hook txn.TxnFunc) {
}
func (*TxnManager) AddPostRollbackHook(ctx context.Context, hook txn.TxnFunc) {
}
func (*TxnManager) IsLocked(err error) bool {
return false
}
func (*TxnManager) Reset() error {
return nil
}
func NewTxnRepository() models.Repository {
return models.Repository{
TxnManager: &TxnManager{},
Gallery: &GalleryReaderWriter{},
GalleryChapter: &GalleryChapterReaderWriter{},
Image: &ImageReaderWriter{},
Movie: &MovieReaderWriter{},
Performer: &PerformerReaderWriter{},
Scene: &SceneReaderWriter{},
SceneMarker: &SceneMarkerReaderWriter{},
Studio: &StudioReaderWriter{},
Tag: &TagReaderWriter{},
SavedFilter: &SavedFilterReaderWriter{},
}
}

View File

@@ -49,5 +49,3 @@ func NewMoviePartial() MoviePartial {
UpdatedAt: NewOptionalTime(currentTime), UpdatedAt: NewOptionalTime(currentTime),
} }
} }
var DefaultMovieImage = ""

View File

@@ -1,17 +1,18 @@
package models package models
import ( import (
"context"
"github.com/stashapp/stash/pkg/txn" "github.com/stashapp/stash/pkg/txn"
) )
type TxnManager interface { type TxnManager interface {
txn.Manager txn.Manager
txn.DatabaseProvider txn.DatabaseProvider
Reset() error
} }
type Repository struct { type Repository struct {
TxnManager TxnManager TxnManager
File FileReaderWriter File FileReaderWriter
Folder FolderReaderWriter Folder FolderReaderWriter
@@ -26,3 +27,15 @@ type Repository struct {
Tag TagReaderWriter Tag TagReaderWriter
SavedFilter SavedFilterReaderWriter SavedFilter SavedFilterReaderWriter
} }
func (r *Repository) WithTxn(ctx context.Context, fn txn.TxnFunc) error {
return txn.WithTxn(ctx, r.TxnManager, fn)
}
func (r *Repository) WithReadTxn(ctx context.Context, fn txn.TxnFunc) error {
return txn.WithReadTxn(ctx, r.TxnManager, fn)
}
func (r *Repository) WithDB(ctx context.Context, fn txn.TxnFunc) error {
return txn.WithDatabase(ctx, r.TxnManager, fn)
}

View File

@@ -168,34 +168,32 @@ func initTestTable() {
func TestToJSON(t *testing.T) { func TestToJSON(t *testing.T) {
initTestTable() initTestTable()
mockMovieReader := &mocks.MovieReaderWriter{} db := mocks.NewDatabase()
imageErr := errors.New("error getting image") imageErr := errors.New("error getting image")
mockMovieReader.On("GetFrontImage", testCtx, movieID).Return(frontImageBytes, nil).Once() db.Movie.On("GetFrontImage", testCtx, movieID).Return(frontImageBytes, nil).Once()
mockMovieReader.On("GetFrontImage", testCtx, missingStudioMovieID).Return(frontImageBytes, nil).Once() db.Movie.On("GetFrontImage", testCtx, missingStudioMovieID).Return(frontImageBytes, nil).Once()
mockMovieReader.On("GetFrontImage", testCtx, emptyID).Return(nil, nil).Once().Maybe() db.Movie.On("GetFrontImage", testCtx, emptyID).Return(nil, nil).Once().Maybe()
mockMovieReader.On("GetFrontImage", testCtx, errFrontImageID).Return(nil, imageErr).Once() db.Movie.On("GetFrontImage", testCtx, errFrontImageID).Return(nil, imageErr).Once()
mockMovieReader.On("GetFrontImage", testCtx, errBackImageID).Return(frontImageBytes, nil).Once() db.Movie.On("GetFrontImage", testCtx, errBackImageID).Return(frontImageBytes, nil).Once()
mockMovieReader.On("GetBackImage", testCtx, movieID).Return(backImageBytes, nil).Once() db.Movie.On("GetBackImage", testCtx, movieID).Return(backImageBytes, nil).Once()
mockMovieReader.On("GetBackImage", testCtx, missingStudioMovieID).Return(backImageBytes, nil).Once() db.Movie.On("GetBackImage", testCtx, missingStudioMovieID).Return(backImageBytes, nil).Once()
mockMovieReader.On("GetBackImage", testCtx, emptyID).Return(nil, nil).Once() db.Movie.On("GetBackImage", testCtx, emptyID).Return(nil, nil).Once()
mockMovieReader.On("GetBackImage", testCtx, errBackImageID).Return(nil, imageErr).Once() db.Movie.On("GetBackImage", testCtx, errBackImageID).Return(nil, imageErr).Once()
mockMovieReader.On("GetBackImage", testCtx, errFrontImageID).Return(backImageBytes, nil).Maybe() db.Movie.On("GetBackImage", testCtx, errFrontImageID).Return(backImageBytes, nil).Maybe()
mockMovieReader.On("GetBackImage", testCtx, errStudioMovieID).Return(backImageBytes, nil).Maybe() db.Movie.On("GetBackImage", testCtx, errStudioMovieID).Return(backImageBytes, nil).Maybe()
mockStudioReader := &mocks.StudioReaderWriter{}
studioErr := errors.New("error getting studio") studioErr := errors.New("error getting studio")
mockStudioReader.On("Find", testCtx, studioID).Return(&movieStudio, nil) db.Studio.On("Find", testCtx, studioID).Return(&movieStudio, nil)
mockStudioReader.On("Find", testCtx, missingStudioID).Return(nil, nil) db.Studio.On("Find", testCtx, missingStudioID).Return(nil, nil)
mockStudioReader.On("Find", testCtx, errStudioID).Return(nil, studioErr) db.Studio.On("Find", testCtx, errStudioID).Return(nil, studioErr)
for i, s := range scenarios { for i, s := range scenarios {
movie := s.movie movie := s.movie
json, err := ToJSON(testCtx, mockMovieReader, mockStudioReader, &movie) json, err := ToJSON(testCtx, db.Movie, db.Studio, &movie)
switch { switch {
case !s.err && err != nil: case !s.err && err != nil:
@@ -207,6 +205,5 @@ func TestToJSON(t *testing.T) {
} }
} }
mockMovieReader.AssertExpectations(t) db.AssertExpectations(t)
mockStudioReader.AssertExpectations(t)
} }

View File

@@ -69,10 +69,11 @@ func TestImporterPreImport(t *testing.T) {
} }
func TestImporterPreImportWithStudio(t *testing.T) { func TestImporterPreImportWithStudio(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
StudioWriter: studioReaderWriter, ReaderWriter: db.Movie,
StudioWriter: db.Studio,
Input: jsonschema.Movie{ Input: jsonschema.Movie{
Name: movieName, Name: movieName,
FrontImage: frontImage, FrontImage: frontImage,
@@ -82,10 +83,10 @@ func TestImporterPreImportWithStudio(t *testing.T) {
}, },
} }
studioReaderWriter.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{ db.Studio.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
ID: existingStudioID, ID: existingStudioID,
}, nil).Once() }, nil).Once()
studioReaderWriter.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once() db.Studio.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
err := i.PreImport(testCtx) err := i.PreImport(testCtx)
assert.Nil(t, err) assert.Nil(t, err)
@@ -95,14 +96,15 @@ func TestImporterPreImportWithStudio(t *testing.T) {
err = i.PreImport(testCtx) err = i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
studioReaderWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestImporterPreImportWithMissingStudio(t *testing.T) { func TestImporterPreImportWithMissingStudio(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
StudioWriter: studioReaderWriter, ReaderWriter: db.Movie,
StudioWriter: db.Studio,
Input: jsonschema.Movie{ Input: jsonschema.Movie{
Name: movieName, Name: movieName,
FrontImage: frontImage, FrontImage: frontImage,
@@ -111,8 +113,8 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail, MissingRefBehaviour: models.ImportMissingRefEnumFail,
} }
studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3) db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) { db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.Studio) s := args.Get(1).(*models.Studio)
s.ID = existingStudioID s.ID = existingStudioID
}).Return(nil) }).Return(nil)
@@ -129,14 +131,15 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, existingStudioID, *i.movie.StudioID) assert.Equal(t, existingStudioID, *i.movie.StudioID)
studioReaderWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) { func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
StudioWriter: studioReaderWriter, ReaderWriter: db.Movie,
StudioWriter: db.Studio,
Input: jsonschema.Movie{ Input: jsonschema.Movie{
Name: movieName, Name: movieName,
FrontImage: frontImage, FrontImage: frontImage,
@@ -145,27 +148,30 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate, MissingRefBehaviour: models.ImportMissingRefEnumCreate,
} }
studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once() db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error")) db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
err := i.PreImport(testCtx) err := i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
db.AssertExpectations(t)
} }
func TestImporterPostImport(t *testing.T) { func TestImporterPostImport(t *testing.T) {
readerWriter := &mocks.MovieReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
ReaderWriter: readerWriter, ReaderWriter: db.Movie,
StudioWriter: db.Studio,
frontImageData: frontImageBytes, frontImageData: frontImageBytes,
backImageData: backImageBytes, backImageData: backImageBytes,
} }
updateMovieImageErr := errors.New("UpdateImages error") updateMovieImageErr := errors.New("UpdateImages error")
readerWriter.On("UpdateFrontImage", testCtx, movieID, frontImageBytes).Return(nil).Once() db.Movie.On("UpdateFrontImage", testCtx, movieID, frontImageBytes).Return(nil).Once()
readerWriter.On("UpdateBackImage", testCtx, movieID, backImageBytes).Return(nil).Once() db.Movie.On("UpdateBackImage", testCtx, movieID, backImageBytes).Return(nil).Once()
readerWriter.On("UpdateFrontImage", testCtx, errImageID, frontImageBytes).Return(updateMovieImageErr).Once() db.Movie.On("UpdateFrontImage", testCtx, errImageID, frontImageBytes).Return(updateMovieImageErr).Once()
err := i.PostImport(testCtx, movieID) err := i.PostImport(testCtx, movieID)
assert.Nil(t, err) assert.Nil(t, err)
@@ -173,25 +179,26 @@ func TestImporterPostImport(t *testing.T) {
err = i.PostImport(testCtx, errImageID) err = i.PostImport(testCtx, errImageID)
assert.NotNil(t, err) assert.NotNil(t, err)
readerWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestImporterFindExistingID(t *testing.T) { func TestImporterFindExistingID(t *testing.T) {
readerWriter := &mocks.MovieReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
ReaderWriter: readerWriter, ReaderWriter: db.Movie,
StudioWriter: db.Studio,
Input: jsonschema.Movie{ Input: jsonschema.Movie{
Name: movieName, Name: movieName,
}, },
} }
errFindByName := errors.New("FindByName error") errFindByName := errors.New("FindByName error")
readerWriter.On("FindByName", testCtx, movieName, false).Return(nil, nil).Once() db.Movie.On("FindByName", testCtx, movieName, false).Return(nil, nil).Once()
readerWriter.On("FindByName", testCtx, existingMovieName, false).Return(&models.Movie{ db.Movie.On("FindByName", testCtx, existingMovieName, false).Return(&models.Movie{
ID: existingMovieID, ID: existingMovieID,
}, nil).Once() }, nil).Once()
readerWriter.On("FindByName", testCtx, movieNameErr, false).Return(nil, errFindByName).Once() db.Movie.On("FindByName", testCtx, movieNameErr, false).Return(nil, errFindByName).Once()
id, err := i.FindExistingID(testCtx) id, err := i.FindExistingID(testCtx)
assert.Nil(t, id) assert.Nil(t, id)
@@ -207,11 +214,11 @@ func TestImporterFindExistingID(t *testing.T) {
assert.Nil(t, id) assert.Nil(t, id)
assert.NotNil(t, err) assert.NotNil(t, err)
readerWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestCreate(t *testing.T) { func TestCreate(t *testing.T) {
readerWriter := &mocks.MovieReaderWriter{} db := mocks.NewDatabase()
movie := models.Movie{ movie := models.Movie{
Name: movieName, Name: movieName,
@@ -222,16 +229,17 @@ func TestCreate(t *testing.T) {
} }
i := Importer{ i := Importer{
ReaderWriter: readerWriter, ReaderWriter: db.Movie,
StudioWriter: db.Studio,
movie: movie, movie: movie,
} }
errCreate := errors.New("Create error") errCreate := errors.New("Create error")
readerWriter.On("Create", testCtx, &movie).Run(func(args mock.Arguments) { db.Movie.On("Create", testCtx, &movie).Run(func(args mock.Arguments) {
m := args.Get(1).(*models.Movie) m := args.Get(1).(*models.Movie)
m.ID = movieID m.ID = movieID
}).Return(nil).Once() }).Return(nil).Once()
readerWriter.On("Create", testCtx, &movieErr).Return(errCreate).Once() db.Movie.On("Create", testCtx, &movieErr).Return(errCreate).Once()
id, err := i.Create(testCtx) id, err := i.Create(testCtx)
assert.Equal(t, movieID, *id) assert.Equal(t, movieID, *id)
@@ -242,11 +250,11 @@ func TestCreate(t *testing.T) {
assert.Nil(t, id) assert.Nil(t, id)
assert.NotNil(t, err) assert.NotNil(t, err)
readerWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestUpdate(t *testing.T) { func TestUpdate(t *testing.T) {
readerWriter := &mocks.MovieReaderWriter{} db := mocks.NewDatabase()
movie := models.Movie{ movie := models.Movie{
Name: movieName, Name: movieName,
@@ -257,7 +265,8 @@ func TestUpdate(t *testing.T) {
} }
i := Importer{ i := Importer{
ReaderWriter: readerWriter, ReaderWriter: db.Movie,
StudioWriter: db.Studio,
movie: movie, movie: movie,
} }
@@ -265,7 +274,7 @@ func TestUpdate(t *testing.T) {
// id needs to be set for the mock input // id needs to be set for the mock input
movie.ID = movieID movie.ID = movieID
readerWriter.On("Update", testCtx, &movie).Return(nil).Once() db.Movie.On("Update", testCtx, &movie).Return(nil).Once()
err := i.Update(testCtx, movieID) err := i.Update(testCtx, movieID)
assert.Nil(t, err) assert.Nil(t, err)
@@ -274,10 +283,10 @@ func TestUpdate(t *testing.T) {
// need to set id separately // need to set id separately
movieErr.ID = errImageID movieErr.ID = errImageID
readerWriter.On("Update", testCtx, &movieErr).Return(errUpdate).Once() db.Movie.On("Update", testCtx, &movieErr).Return(errUpdate).Once()
err = i.Update(testCtx, errImageID) err = i.Update(testCtx, errImageID)
assert.NotNil(t, err) assert.NotNil(t, err)
readerWriter.AssertExpectations(t) db.AssertExpectations(t)
} }

View File

@@ -203,17 +203,17 @@ func initTestTable() {
func TestToJSON(t *testing.T) { func TestToJSON(t *testing.T) {
initTestTable() initTestTable()
mockPerformerReader := &mocks.PerformerReaderWriter{} db := mocks.NewDatabase()
imageErr := errors.New("error getting image") imageErr := errors.New("error getting image")
mockPerformerReader.On("GetImage", testCtx, performerID).Return(imageBytes, nil).Once() db.Performer.On("GetImage", testCtx, performerID).Return(imageBytes, nil).Once()
mockPerformerReader.On("GetImage", testCtx, noImageID).Return(nil, nil).Once() db.Performer.On("GetImage", testCtx, noImageID).Return(nil, nil).Once()
mockPerformerReader.On("GetImage", testCtx, errImageID).Return(nil, imageErr).Once() db.Performer.On("GetImage", testCtx, errImageID).Return(nil, imageErr).Once()
for i, s := range scenarios { for i, s := range scenarios {
tag := s.input tag := s.input
json, err := ToJSON(testCtx, mockPerformerReader, &tag) json, err := ToJSON(testCtx, db.Performer, &tag)
switch { switch {
case !s.err && err != nil: case !s.err && err != nil:
@@ -225,5 +225,5 @@ func TestToJSON(t *testing.T) {
} }
} }
mockPerformerReader.AssertExpectations(t) db.AssertExpectations(t)
} }

View File

@@ -63,10 +63,11 @@ func TestImporterPreImport(t *testing.T) {
} }
func TestImporterPreImportWithTag(t *testing.T) { func TestImporterPreImportWithTag(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
TagWriter: tagReaderWriter, ReaderWriter: db.Performer,
TagWriter: db.Tag,
MissingRefBehaviour: models.ImportMissingRefEnumFail, MissingRefBehaviour: models.ImportMissingRefEnumFail,
Input: jsonschema.Performer{ Input: jsonschema.Performer{
Tags: []string{ Tags: []string{
@@ -75,13 +76,13 @@ func TestImporterPreImportWithTag(t *testing.T) {
}, },
} }
tagReaderWriter.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{ db.Tag.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
{ {
ID: existingTagID, ID: existingTagID,
Name: existingTagName, Name: existingTagName,
}, },
}, nil).Once() }, nil).Once()
tagReaderWriter.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once() db.Tag.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
err := i.PreImport(testCtx) err := i.PreImport(testCtx)
assert.Nil(t, err) assert.Nil(t, err)
@@ -91,14 +92,15 @@ func TestImporterPreImportWithTag(t *testing.T) {
err = i.PreImport(testCtx) err = i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
tagReaderWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestImporterPreImportWithMissingTag(t *testing.T) { func TestImporterPreImportWithMissingTag(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
TagWriter: tagReaderWriter, ReaderWriter: db.Performer,
TagWriter: db.Tag,
Input: jsonschema.Performer{ Input: jsonschema.Performer{
Tags: []string{ Tags: []string{
missingTagName, missingTagName,
@@ -107,8 +109,8 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail, MissingRefBehaviour: models.ImportMissingRefEnumFail,
} }
tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3) db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) { db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
t := args.Get(1).(*models.Tag) t := args.Get(1).(*models.Tag)
t.ID = existingTagID t.ID = existingTagID
}).Return(nil) }).Return(nil)
@@ -125,14 +127,15 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, existingTagID, i.performer.TagIDs.List()[0]) assert.Equal(t, existingTagID, i.performer.TagIDs.List()[0])
tagReaderWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) { func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
TagWriter: tagReaderWriter, ReaderWriter: db.Performer,
TagWriter: db.Tag,
Input: jsonschema.Performer{ Input: jsonschema.Performer{
Tags: []string{ Tags: []string{
missingTagName, missingTagName,
@@ -141,25 +144,28 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate, MissingRefBehaviour: models.ImportMissingRefEnumCreate,
} }
tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once() db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error")) db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
err := i.PreImport(testCtx) err := i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
db.AssertExpectations(t)
} }
func TestImporterPostImport(t *testing.T) { func TestImporterPostImport(t *testing.T) {
readerWriter := &mocks.PerformerReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
ReaderWriter: readerWriter, ReaderWriter: db.Performer,
TagWriter: db.Tag,
imageData: imageBytes, imageData: imageBytes,
} }
updatePerformerImageErr := errors.New("UpdateImage error") updatePerformerImageErr := errors.New("UpdateImage error")
readerWriter.On("UpdateImage", testCtx, performerID, imageBytes).Return(nil).Once() db.Performer.On("UpdateImage", testCtx, performerID, imageBytes).Return(nil).Once()
readerWriter.On("UpdateImage", testCtx, errImageID, imageBytes).Return(updatePerformerImageErr).Once() db.Performer.On("UpdateImage", testCtx, errImageID, imageBytes).Return(updatePerformerImageErr).Once()
err := i.PostImport(testCtx, performerID) err := i.PostImport(testCtx, performerID)
assert.Nil(t, err) assert.Nil(t, err)
@@ -167,14 +173,15 @@ func TestImporterPostImport(t *testing.T) {
err = i.PostImport(testCtx, errImageID) err = i.PostImport(testCtx, errImageID)
assert.NotNil(t, err) assert.NotNil(t, err)
readerWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestImporterFindExistingID(t *testing.T) { func TestImporterFindExistingID(t *testing.T) {
readerWriter := &mocks.PerformerReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
ReaderWriter: readerWriter, ReaderWriter: db.Performer,
TagWriter: db.Tag,
Input: jsonschema.Performer{ Input: jsonschema.Performer{
Name: performerName, Name: performerName,
}, },
@@ -195,13 +202,13 @@ func TestImporterFindExistingID(t *testing.T) {
} }
errFindByNames := errors.New("FindByNames error") errFindByNames := errors.New("FindByNames error")
readerWriter.On("Query", testCtx, performerFilter(performerName), findFilter).Return(nil, 0, nil).Once() db.Performer.On("Query", testCtx, performerFilter(performerName), findFilter).Return(nil, 0, nil).Once()
readerWriter.On("Query", testCtx, performerFilter(existingPerformerName), findFilter).Return([]*models.Performer{ db.Performer.On("Query", testCtx, performerFilter(existingPerformerName), findFilter).Return([]*models.Performer{
{ {
ID: existingPerformerID, ID: existingPerformerID,
}, },
}, 1, nil).Once() }, 1, nil).Once()
readerWriter.On("Query", testCtx, performerFilter(performerNameErr), findFilter).Return(nil, 0, errFindByNames).Once() db.Performer.On("Query", testCtx, performerFilter(performerNameErr), findFilter).Return(nil, 0, errFindByNames).Once()
id, err := i.FindExistingID(testCtx) id, err := i.FindExistingID(testCtx)
assert.Nil(t, id) assert.Nil(t, id)
@@ -217,11 +224,11 @@ func TestImporterFindExistingID(t *testing.T) {
assert.Nil(t, id) assert.Nil(t, id)
assert.NotNil(t, err) assert.NotNil(t, err)
readerWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestCreate(t *testing.T) { func TestCreate(t *testing.T) {
readerWriter := &mocks.PerformerReaderWriter{} db := mocks.NewDatabase()
performer := models.Performer{ performer := models.Performer{
Name: performerName, Name: performerName,
@@ -232,16 +239,17 @@ func TestCreate(t *testing.T) {
} }
i := Importer{ i := Importer{
ReaderWriter: readerWriter, ReaderWriter: db.Performer,
TagWriter: db.Tag,
performer: performer, performer: performer,
} }
errCreate := errors.New("Create error") errCreate := errors.New("Create error")
readerWriter.On("Create", testCtx, &performer).Run(func(args mock.Arguments) { db.Performer.On("Create", testCtx, &performer).Run(func(args mock.Arguments) {
arg := args.Get(1).(*models.Performer) arg := args.Get(1).(*models.Performer)
arg.ID = performerID arg.ID = performerID
}).Return(nil).Once() }).Return(nil).Once()
readerWriter.On("Create", testCtx, &performerErr).Return(errCreate).Once() db.Performer.On("Create", testCtx, &performerErr).Return(errCreate).Once()
id, err := i.Create(testCtx) id, err := i.Create(testCtx)
assert.Equal(t, performerID, *id) assert.Equal(t, performerID, *id)
@@ -252,11 +260,11 @@ func TestCreate(t *testing.T) {
assert.Nil(t, id) assert.Nil(t, id)
assert.NotNil(t, err) assert.NotNil(t, err)
readerWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestUpdate(t *testing.T) { func TestUpdate(t *testing.T) {
readerWriter := &mocks.PerformerReaderWriter{} db := mocks.NewDatabase()
performer := models.Performer{ performer := models.Performer{
Name: performerName, Name: performerName,
@@ -267,7 +275,8 @@ func TestUpdate(t *testing.T) {
} }
i := Importer{ i := Importer{
ReaderWriter: readerWriter, ReaderWriter: db.Performer,
TagWriter: db.Tag,
performer: performer, performer: performer,
} }
@@ -275,7 +284,7 @@ func TestUpdate(t *testing.T) {
// id needs to be set for the mock input // id needs to be set for the mock input
performer.ID = performerID performer.ID = performerID
readerWriter.On("Update", testCtx, &performer).Return(nil).Once() db.Performer.On("Update", testCtx, &performer).Return(nil).Once()
err := i.Update(testCtx, performerID) err := i.Update(testCtx, performerID)
assert.Nil(t, err) assert.Nil(t, err)
@@ -284,10 +293,10 @@ func TestUpdate(t *testing.T) {
// need to set id separately // need to set id separately
performerErr.ID = errImageID performerErr.ID = errImageID
readerWriter.On("Update", testCtx, &performerErr).Return(errUpdate).Once() db.Performer.On("Update", testCtx, &performerErr).Return(errUpdate).Once()
err = i.Update(testCtx, errImageID) err = i.Update(testCtx, errImageID)
assert.NotNil(t, err) assert.NotNil(t, err)
readerWriter.AssertExpectations(t) db.AssertExpectations(t)
} }

View File

@@ -186,17 +186,17 @@ var scenarios = []basicTestScenario{
} }
func TestToJSON(t *testing.T) { func TestToJSON(t *testing.T) {
mockSceneReader := &mocks.SceneReaderWriter{} db := mocks.NewDatabase()
imageErr := errors.New("error getting image") imageErr := errors.New("error getting image")
mockSceneReader.On("GetCover", testCtx, sceneID).Return(imageBytes, nil).Once() db.Scene.On("GetCover", testCtx, sceneID).Return(imageBytes, nil).Once()
mockSceneReader.On("GetCover", testCtx, noImageID).Return(nil, nil).Once() db.Scene.On("GetCover", testCtx, noImageID).Return(nil, nil).Once()
mockSceneReader.On("GetCover", testCtx, errImageID).Return(nil, imageErr).Once() db.Scene.On("GetCover", testCtx, errImageID).Return(nil, imageErr).Once()
for i, s := range scenarios { for i, s := range scenarios {
scene := s.input scene := s.input
json, err := ToBasicJSON(testCtx, mockSceneReader, &scene) json, err := ToBasicJSON(testCtx, db.Scene, &scene)
switch { switch {
case !s.err && err != nil: case !s.err && err != nil:
@@ -208,7 +208,7 @@ func TestToJSON(t *testing.T) {
} }
} }
mockSceneReader.AssertExpectations(t) db.AssertExpectations(t)
} }
func createStudioScene(studioID int) models.Scene { func createStudioScene(studioID int) models.Scene {
@@ -242,19 +242,19 @@ var getStudioScenarios = []stringTestScenario{
} }
func TestGetStudioName(t *testing.T) { func TestGetStudioName(t *testing.T) {
mockStudioReader := &mocks.StudioReaderWriter{} db := mocks.NewDatabase()
studioErr := errors.New("error getting image") studioErr := errors.New("error getting image")
mockStudioReader.On("Find", testCtx, studioID).Return(&models.Studio{ db.Studio.On("Find", testCtx, studioID).Return(&models.Studio{
Name: studioName, Name: studioName,
}, nil).Once() }, nil).Once()
mockStudioReader.On("Find", testCtx, missingStudioID).Return(nil, nil).Once() db.Studio.On("Find", testCtx, missingStudioID).Return(nil, nil).Once()
mockStudioReader.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once() db.Studio.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once()
for i, s := range getStudioScenarios { for i, s := range getStudioScenarios {
scene := s.input scene := s.input
json, err := GetStudioName(testCtx, mockStudioReader, &scene) json, err := GetStudioName(testCtx, db.Studio, &scene)
switch { switch {
case !s.err && err != nil: case !s.err && err != nil:
@@ -266,7 +266,7 @@ func TestGetStudioName(t *testing.T) {
} }
} }
mockStudioReader.AssertExpectations(t) db.AssertExpectations(t)
} }
type stringSliceTestScenario struct { type stringSliceTestScenario struct {
@@ -305,17 +305,17 @@ func getTags(names []string) []*models.Tag {
} }
func TestGetTagNames(t *testing.T) { func TestGetTagNames(t *testing.T) {
mockTagReader := &mocks.TagReaderWriter{} db := mocks.NewDatabase()
tagErr := errors.New("error getting tag") tagErr := errors.New("error getting tag")
mockTagReader.On("FindBySceneID", testCtx, sceneID).Return(getTags(names), nil).Once() db.Tag.On("FindBySceneID", testCtx, sceneID).Return(getTags(names), nil).Once()
mockTagReader.On("FindBySceneID", testCtx, noTagsID).Return(nil, nil).Once() db.Tag.On("FindBySceneID", testCtx, noTagsID).Return(nil, nil).Once()
mockTagReader.On("FindBySceneID", testCtx, errTagsID).Return(nil, tagErr).Once() db.Tag.On("FindBySceneID", testCtx, errTagsID).Return(nil, tagErr).Once()
for i, s := range getTagNamesScenarios { for i, s := range getTagNamesScenarios {
scene := s.input scene := s.input
json, err := GetTagNames(testCtx, mockTagReader, &scene) json, err := GetTagNames(testCtx, db.Tag, &scene)
switch { switch {
case !s.err && err != nil: case !s.err && err != nil:
@@ -327,7 +327,7 @@ func TestGetTagNames(t *testing.T) {
} }
} }
mockTagReader.AssertExpectations(t) db.AssertExpectations(t)
} }
type sceneMoviesTestScenario struct { type sceneMoviesTestScenario struct {
@@ -391,20 +391,21 @@ var getSceneMoviesJSONScenarios = []sceneMoviesTestScenario{
} }
func TestGetSceneMoviesJSON(t *testing.T) { func TestGetSceneMoviesJSON(t *testing.T) {
mockMovieReader := &mocks.MovieReaderWriter{} db := mocks.NewDatabase()
movieErr := errors.New("error getting movie") movieErr := errors.New("error getting movie")
mockMovieReader.On("Find", testCtx, validMovie1).Return(&models.Movie{ db.Movie.On("Find", testCtx, validMovie1).Return(&models.Movie{
Name: movie1Name, Name: movie1Name,
}, nil).Once() }, nil).Once()
mockMovieReader.On("Find", testCtx, validMovie2).Return(&models.Movie{ db.Movie.On("Find", testCtx, validMovie2).Return(&models.Movie{
Name: movie2Name, Name: movie2Name,
}, nil).Once() }, nil).Once()
mockMovieReader.On("Find", testCtx, invalidMovie).Return(nil, movieErr).Once() db.Movie.On("Find", testCtx, invalidMovie).Return(nil, movieErr).Once()
for i, s := range getSceneMoviesJSONScenarios { for i, s := range getSceneMoviesJSONScenarios {
scene := s.input scene := s.input
json, err := GetSceneMoviesJSON(testCtx, mockMovieReader, &scene) json, err := GetSceneMoviesJSON(testCtx, db.Movie, &scene)
switch { switch {
case !s.err && err != nil: case !s.err && err != nil:
@@ -416,7 +417,7 @@ func TestGetSceneMoviesJSON(t *testing.T) {
} }
} }
mockMovieReader.AssertExpectations(t) db.AssertExpectations(t)
} }
const ( const (
@@ -542,27 +543,26 @@ var invalidMarkers2 = []*models.SceneMarker{
} }
func TestGetSceneMarkersJSON(t *testing.T) { func TestGetSceneMarkersJSON(t *testing.T) {
mockTagReader := &mocks.TagReaderWriter{} db := mocks.NewDatabase()
mockMarkerReader := &mocks.SceneMarkerReaderWriter{}
markersErr := errors.New("error getting scene markers") markersErr := errors.New("error getting scene markers")
tagErr := errors.New("error getting tags") tagErr := errors.New("error getting tags")
mockMarkerReader.On("FindBySceneID", testCtx, sceneID).Return(validMarkers, nil).Once() db.SceneMarker.On("FindBySceneID", testCtx, sceneID).Return(validMarkers, nil).Once()
mockMarkerReader.On("FindBySceneID", testCtx, noMarkersID).Return(nil, nil).Once() db.SceneMarker.On("FindBySceneID", testCtx, noMarkersID).Return(nil, nil).Once()
mockMarkerReader.On("FindBySceneID", testCtx, errMarkersID).Return(nil, markersErr).Once() db.SceneMarker.On("FindBySceneID", testCtx, errMarkersID).Return(nil, markersErr).Once()
mockMarkerReader.On("FindBySceneID", testCtx, errFindPrimaryTagID).Return(invalidMarkers1, nil).Once() db.SceneMarker.On("FindBySceneID", testCtx, errFindPrimaryTagID).Return(invalidMarkers1, nil).Once()
mockMarkerReader.On("FindBySceneID", testCtx, errFindByMarkerID).Return(invalidMarkers2, nil).Once() db.SceneMarker.On("FindBySceneID", testCtx, errFindByMarkerID).Return(invalidMarkers2, nil).Once()
mockTagReader.On("Find", testCtx, validTagID1).Return(&models.Tag{ db.Tag.On("Find", testCtx, validTagID1).Return(&models.Tag{
Name: validTagName1, Name: validTagName1,
}, nil) }, nil)
mockTagReader.On("Find", testCtx, validTagID2).Return(&models.Tag{ db.Tag.On("Find", testCtx, validTagID2).Return(&models.Tag{
Name: validTagName2, Name: validTagName2,
}, nil) }, nil)
mockTagReader.On("Find", testCtx, invalidTagID).Return(nil, tagErr) db.Tag.On("Find", testCtx, invalidTagID).Return(nil, tagErr)
mockTagReader.On("FindBySceneMarkerID", testCtx, validMarkerID1).Return([]*models.Tag{ db.Tag.On("FindBySceneMarkerID", testCtx, validMarkerID1).Return([]*models.Tag{
{ {
Name: validTagName1, Name: validTagName1,
}, },
@@ -570,16 +570,16 @@ func TestGetSceneMarkersJSON(t *testing.T) {
Name: validTagName2, Name: validTagName2,
}, },
}, nil) }, nil)
mockTagReader.On("FindBySceneMarkerID", testCtx, validMarkerID2).Return([]*models.Tag{ db.Tag.On("FindBySceneMarkerID", testCtx, validMarkerID2).Return([]*models.Tag{
{ {
Name: validTagName2, Name: validTagName2,
}, },
}, nil) }, nil)
mockTagReader.On("FindBySceneMarkerID", testCtx, invalidMarkerID2).Return(nil, tagErr).Once() db.Tag.On("FindBySceneMarkerID", testCtx, invalidMarkerID2).Return(nil, tagErr).Once()
for i, s := range getSceneMarkersJSONScenarios { for i, s := range getSceneMarkersJSONScenarios {
scene := s.input scene := s.input
json, err := GetSceneMarkersJSON(testCtx, mockMarkerReader, mockTagReader, &scene) json, err := GetSceneMarkersJSON(testCtx, db.SceneMarker, db.Tag, &scene)
switch { switch {
case !s.err && err != nil: case !s.err && err != nil:
@@ -591,5 +591,5 @@ func TestGetSceneMarkersJSON(t *testing.T) {
} }
} }
mockTagReader.AssertExpectations(t) db.AssertExpectations(t)
} }

View File

@@ -410,17 +410,19 @@ type FilenameParser struct {
ParserInput models.SceneParserInput ParserInput models.SceneParserInput
Filter *models.FindFilterType Filter *models.FindFilterType
whitespaceRE *regexp.Regexp whitespaceRE *regexp.Regexp
repository FilenameParserRepository
performerCache map[string]*models.Performer performerCache map[string]*models.Performer
studioCache map[string]*models.Studio studioCache map[string]*models.Studio
movieCache map[string]*models.Movie movieCache map[string]*models.Movie
tagCache map[string]*models.Tag tagCache map[string]*models.Tag
} }
func NewFilenameParser(filter *models.FindFilterType, config models.SceneParserInput) *FilenameParser { func NewFilenameParser(filter *models.FindFilterType, config models.SceneParserInput, repo FilenameParserRepository) *FilenameParser {
p := &FilenameParser{ p := &FilenameParser{
Pattern: *filter.Q, Pattern: *filter.Q,
ParserInput: config, ParserInput: config,
Filter: filter, Filter: filter,
repository: repo,
} }
p.performerCache = make(map[string]*models.Performer) p.performerCache = make(map[string]*models.Performer)
@@ -457,7 +459,17 @@ type FilenameParserRepository struct {
Tag models.TagQueryer Tag models.TagQueryer
} }
func (p *FilenameParser) Parse(ctx context.Context, repo FilenameParserRepository) ([]*models.SceneParserResult, int, error) { func NewFilenameParserRepository(repo models.Repository) FilenameParserRepository {
return FilenameParserRepository{
Scene: repo.Scene,
Performer: repo.Performer,
Studio: repo.Studio,
Movie: repo.Movie,
Tag: repo.Tag,
}
}
func (p *FilenameParser) Parse(ctx context.Context) ([]*models.SceneParserResult, int, error) {
// perform the query to find the scenes // perform the query to find the scenes
mapper, err := newParseMapper(p.Pattern, p.ParserInput.IgnoreWords) mapper, err := newParseMapper(p.Pattern, p.ParserInput.IgnoreWords)
@@ -479,17 +491,17 @@ func (p *FilenameParser) Parse(ctx context.Context, repo FilenameParserRepositor
p.Filter.Q = nil p.Filter.Q = nil
scenes, total, err := QueryWithCount(ctx, repo.Scene, sceneFilter, p.Filter) scenes, total, err := QueryWithCount(ctx, p.repository.Scene, sceneFilter, p.Filter)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
ret := p.parseScenes(ctx, repo, scenes, mapper) ret := p.parseScenes(ctx, scenes, mapper)
return ret, total, nil return ret, total, nil
} }
func (p *FilenameParser) parseScenes(ctx context.Context, repo FilenameParserRepository, scenes []*models.Scene, mapper *parseMapper) []*models.SceneParserResult { func (p *FilenameParser) parseScenes(ctx context.Context, scenes []*models.Scene, mapper *parseMapper) []*models.SceneParserResult {
var ret []*models.SceneParserResult var ret []*models.SceneParserResult
for _, scene := range scenes { for _, scene := range scenes {
sceneHolder := mapper.parse(scene) sceneHolder := mapper.parse(scene)
@@ -498,7 +510,7 @@ func (p *FilenameParser) parseScenes(ctx context.Context, repo FilenameParserRep
r := &models.SceneParserResult{ r := &models.SceneParserResult{
Scene: scene, Scene: scene,
} }
p.setParserResult(ctx, repo, *sceneHolder, r) p.setParserResult(ctx, *sceneHolder, r)
ret = append(ret, r) ret = append(ret, r)
} }
@@ -671,7 +683,7 @@ func (p *FilenameParser) setMovies(ctx context.Context, qb MovieNameFinder, h sc
} }
} }
func (p *FilenameParser) setParserResult(ctx context.Context, repo FilenameParserRepository, h sceneHolder, result *models.SceneParserResult) { func (p *FilenameParser) setParserResult(ctx context.Context, h sceneHolder, result *models.SceneParserResult) {
if h.result.Title != "" { if h.result.Title != "" {
title := h.result.Title title := h.result.Title
title = p.replaceWhitespaceCharacters(title) title = p.replaceWhitespaceCharacters(title)
@@ -692,15 +704,17 @@ func (p *FilenameParser) setParserResult(ctx context.Context, repo FilenameParse
result.Rating = h.result.Rating result.Rating = h.result.Rating
} }
r := p.repository
if len(h.performers) > 0 { if len(h.performers) > 0 {
p.setPerformers(ctx, repo.Performer, h, result) p.setPerformers(ctx, r.Performer, h, result)
} }
if len(h.tags) > 0 { if len(h.tags) > 0 {
p.setTags(ctx, repo.Tag, h, result) p.setTags(ctx, r.Tag, h, result)
} }
p.setStudio(ctx, repo.Studio, h, result) p.setStudio(ctx, r.Studio, h, result)
if len(h.movies) > 0 { if len(h.movies) > 0 {
p.setMovies(ctx, repo.Movie, h, result) p.setMovies(ctx, r.Movie, h, result)
} }
} }

View File

@@ -56,20 +56,19 @@ func TestImporterPreImport(t *testing.T) {
} }
func TestImporterPreImportWithStudio(t *testing.T) { func TestImporterPreImportWithStudio(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{} db := mocks.NewDatabase()
testCtx := context.Background()
i := Importer{ i := Importer{
StudioWriter: studioReaderWriter, StudioWriter: db.Studio,
Input: jsonschema.Scene{ Input: jsonschema.Scene{
Studio: existingStudioName, Studio: existingStudioName,
}, },
} }
studioReaderWriter.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{ db.Studio.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
ID: existingStudioID, ID: existingStudioID,
}, nil).Once() }, nil).Once()
studioReaderWriter.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once() db.Studio.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
err := i.PreImport(testCtx) err := i.PreImport(testCtx)
assert.Nil(t, err) assert.Nil(t, err)
@@ -79,22 +78,22 @@ func TestImporterPreImportWithStudio(t *testing.T) {
err = i.PreImport(testCtx) err = i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
studioReaderWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestImporterPreImportWithMissingStudio(t *testing.T) { func TestImporterPreImportWithMissingStudio(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
StudioWriter: studioReaderWriter, StudioWriter: db.Studio,
Input: jsonschema.Scene{ Input: jsonschema.Scene{
Studio: missingStudioName, Studio: missingStudioName,
}, },
MissingRefBehaviour: models.ImportMissingRefEnumFail, MissingRefBehaviour: models.ImportMissingRefEnumFail,
} }
studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3) db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) { db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.Studio) s := args.Get(1).(*models.Studio)
s.ID = existingStudioID s.ID = existingStudioID
}).Return(nil) }).Return(nil)
@@ -111,32 +110,34 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, existingStudioID, *i.scene.StudioID) assert.Equal(t, existingStudioID, *i.scene.StudioID)
studioReaderWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) { func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
StudioWriter: studioReaderWriter, StudioWriter: db.Studio,
Input: jsonschema.Scene{ Input: jsonschema.Scene{
Studio: missingStudioName, Studio: missingStudioName,
}, },
MissingRefBehaviour: models.ImportMissingRefEnumCreate, MissingRefBehaviour: models.ImportMissingRefEnumCreate,
} }
studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once() db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error")) db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
err := i.PreImport(testCtx) err := i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
db.AssertExpectations(t)
} }
func TestImporterPreImportWithPerformer(t *testing.T) { func TestImporterPreImportWithPerformer(t *testing.T) {
performerReaderWriter := &mocks.PerformerReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
PerformerWriter: performerReaderWriter, PerformerWriter: db.Performer,
MissingRefBehaviour: models.ImportMissingRefEnumFail, MissingRefBehaviour: models.ImportMissingRefEnumFail,
Input: jsonschema.Scene{ Input: jsonschema.Scene{
Performers: []string{ Performers: []string{
@@ -145,13 +146,13 @@ func TestImporterPreImportWithPerformer(t *testing.T) {
}, },
} }
performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{ db.Performer.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{
{ {
ID: existingPerformerID, ID: existingPerformerID,
Name: existingPerformerName, Name: existingPerformerName,
}, },
}, nil).Once() }, nil).Once()
performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once() db.Performer.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once()
err := i.PreImport(testCtx) err := i.PreImport(testCtx)
assert.Nil(t, err) assert.Nil(t, err)
@@ -161,14 +162,14 @@ func TestImporterPreImportWithPerformer(t *testing.T) {
err = i.PreImport(testCtx) err = i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
performerReaderWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestImporterPreImportWithMissingPerformer(t *testing.T) { func TestImporterPreImportWithMissingPerformer(t *testing.T) {
performerReaderWriter := &mocks.PerformerReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
PerformerWriter: performerReaderWriter, PerformerWriter: db.Performer,
Input: jsonschema.Scene{ Input: jsonschema.Scene{
Performers: []string{ Performers: []string{
missingPerformerName, missingPerformerName,
@@ -177,8 +178,8 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail, MissingRefBehaviour: models.ImportMissingRefEnumFail,
} }
performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3) db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3)
performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) { db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) {
p := args.Get(1).(*models.Performer) p := args.Get(1).(*models.Performer)
p.ID = existingPerformerID p.ID = existingPerformerID
}).Return(nil) }).Return(nil)
@@ -195,14 +196,14 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, []int{existingPerformerID}, i.scene.PerformerIDs.List()) assert.Equal(t, []int{existingPerformerID}, i.scene.PerformerIDs.List())
performerReaderWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) { func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
performerReaderWriter := &mocks.PerformerReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
PerformerWriter: performerReaderWriter, PerformerWriter: db.Performer,
Input: jsonschema.Scene{ Input: jsonschema.Scene{
Performers: []string{ Performers: []string{
missingPerformerName, missingPerformerName,
@@ -211,19 +212,20 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate, MissingRefBehaviour: models.ImportMissingRefEnumCreate,
} }
performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once() db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once()
performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error")) db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error"))
err := i.PreImport(testCtx) err := i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
db.AssertExpectations(t)
} }
func TestImporterPreImportWithMovie(t *testing.T) { func TestImporterPreImportWithMovie(t *testing.T) {
movieReaderWriter := &mocks.MovieReaderWriter{} db := mocks.NewDatabase()
testCtx := context.Background()
i := Importer{ i := Importer{
MovieWriter: movieReaderWriter, MovieWriter: db.Movie,
MissingRefBehaviour: models.ImportMissingRefEnumFail, MissingRefBehaviour: models.ImportMissingRefEnumFail,
Input: jsonschema.Scene{ Input: jsonschema.Scene{
Movies: []jsonschema.SceneMovie{ Movies: []jsonschema.SceneMovie{
@@ -235,11 +237,11 @@ func TestImporterPreImportWithMovie(t *testing.T) {
}, },
} }
movieReaderWriter.On("FindByName", testCtx, existingMovieName, false).Return(&models.Movie{ db.Movie.On("FindByName", testCtx, existingMovieName, false).Return(&models.Movie{
ID: existingMovieID, ID: existingMovieID,
Name: existingMovieName, Name: existingMovieName,
}, nil).Once() }, nil).Once()
movieReaderWriter.On("FindByName", testCtx, existingMovieErr, false).Return(nil, errors.New("FindByName error")).Once() db.Movie.On("FindByName", testCtx, existingMovieErr, false).Return(nil, errors.New("FindByName error")).Once()
err := i.PreImport(testCtx) err := i.PreImport(testCtx)
assert.Nil(t, err) assert.Nil(t, err)
@@ -249,15 +251,14 @@ func TestImporterPreImportWithMovie(t *testing.T) {
err = i.PreImport(testCtx) err = i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
movieReaderWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestImporterPreImportWithMissingMovie(t *testing.T) { func TestImporterPreImportWithMissingMovie(t *testing.T) {
movieReaderWriter := &mocks.MovieReaderWriter{} db := mocks.NewDatabase()
testCtx := context.Background()
i := Importer{ i := Importer{
MovieWriter: movieReaderWriter, MovieWriter: db.Movie,
Input: jsonschema.Scene{ Input: jsonschema.Scene{
Movies: []jsonschema.SceneMovie{ Movies: []jsonschema.SceneMovie{
{ {
@@ -268,8 +269,8 @@ func TestImporterPreImportWithMissingMovie(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail, MissingRefBehaviour: models.ImportMissingRefEnumFail,
} }
movieReaderWriter.On("FindByName", testCtx, missingMovieName, false).Return(nil, nil).Times(3) db.Movie.On("FindByName", testCtx, missingMovieName, false).Return(nil, nil).Times(3)
movieReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Movie")).Run(func(args mock.Arguments) { db.Movie.On("Create", testCtx, mock.AnythingOfType("*models.Movie")).Run(func(args mock.Arguments) {
m := args.Get(1).(*models.Movie) m := args.Get(1).(*models.Movie)
m.ID = existingMovieID m.ID = existingMovieID
}).Return(nil) }).Return(nil)
@@ -286,14 +287,14 @@ func TestImporterPreImportWithMissingMovie(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, existingMovieID, i.scene.Movies.List()[0].MovieID) assert.Equal(t, existingMovieID, i.scene.Movies.List()[0].MovieID)
movieReaderWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestImporterPreImportWithMissingMovieCreateErr(t *testing.T) { func TestImporterPreImportWithMissingMovieCreateErr(t *testing.T) {
movieReaderWriter := &mocks.MovieReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
MovieWriter: movieReaderWriter, MovieWriter: db.Movie,
Input: jsonschema.Scene{ Input: jsonschema.Scene{
Movies: []jsonschema.SceneMovie{ Movies: []jsonschema.SceneMovie{
{ {
@@ -304,18 +305,20 @@ func TestImporterPreImportWithMissingMovieCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate, MissingRefBehaviour: models.ImportMissingRefEnumCreate,
} }
movieReaderWriter.On("FindByName", testCtx, missingMovieName, false).Return(nil, nil).Once() db.Movie.On("FindByName", testCtx, missingMovieName, false).Return(nil, nil).Once()
movieReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Movie")).Return(errors.New("Create error")) db.Movie.On("Create", testCtx, mock.AnythingOfType("*models.Movie")).Return(errors.New("Create error"))
err := i.PreImport(testCtx) err := i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
db.AssertExpectations(t)
} }
func TestImporterPreImportWithTag(t *testing.T) { func TestImporterPreImportWithTag(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
TagWriter: tagReaderWriter, TagWriter: db.Tag,
MissingRefBehaviour: models.ImportMissingRefEnumFail, MissingRefBehaviour: models.ImportMissingRefEnumFail,
Input: jsonschema.Scene{ Input: jsonschema.Scene{
Tags: []string{ Tags: []string{
@@ -324,13 +327,13 @@ func TestImporterPreImportWithTag(t *testing.T) {
}, },
} }
tagReaderWriter.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{ db.Tag.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
{ {
ID: existingTagID, ID: existingTagID,
Name: existingTagName, Name: existingTagName,
}, },
}, nil).Once() }, nil).Once()
tagReaderWriter.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once() db.Tag.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
err := i.PreImport(testCtx) err := i.PreImport(testCtx)
assert.Nil(t, err) assert.Nil(t, err)
@@ -340,14 +343,14 @@ func TestImporterPreImportWithTag(t *testing.T) {
err = i.PreImport(testCtx) err = i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
tagReaderWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestImporterPreImportWithMissingTag(t *testing.T) { func TestImporterPreImportWithMissingTag(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
TagWriter: tagReaderWriter, TagWriter: db.Tag,
Input: jsonschema.Scene{ Input: jsonschema.Scene{
Tags: []string{ Tags: []string{
missingTagName, missingTagName,
@@ -356,8 +359,8 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail, MissingRefBehaviour: models.ImportMissingRefEnumFail,
} }
tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3) db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) { db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
t := args.Get(1).(*models.Tag) t := args.Get(1).(*models.Tag)
t.ID = existingTagID t.ID = existingTagID
}).Return(nil) }).Return(nil)
@@ -374,14 +377,14 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, []int{existingTagID}, i.scene.TagIDs.List()) assert.Equal(t, []int{existingTagID}, i.scene.TagIDs.List())
tagReaderWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) { func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
TagWriter: tagReaderWriter, TagWriter: db.Tag,
Input: jsonschema.Scene{ Input: jsonschema.Scene{
Tags: []string{ Tags: []string{
missingTagName, missingTagName,
@@ -390,9 +393,11 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate, MissingRefBehaviour: models.ImportMissingRefEnumCreate,
} }
tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once() db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error")) db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
err := i.PreImport(testCtx) err := i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
db.AssertExpectations(t)
} }

View File

@@ -1,7 +1,6 @@
package scene package scene
import ( import (
"context"
"errors" "errors"
"strconv" "strconv"
"testing" "testing"
@@ -105,8 +104,6 @@ func TestUpdater_Update(t *testing.T) {
tagID tagID
) )
ctx := context.Background()
performerIDs := []int{performerID} performerIDs := []int{performerID}
tagIDs := []int{tagID} tagIDs := []int{tagID}
stashID := "stashID" stashID := "stashID"
@@ -119,14 +116,15 @@ func TestUpdater_Update(t *testing.T) {
updateErr := errors.New("error updating") updateErr := errors.New("error updating")
qb := mocks.SceneReaderWriter{} db := mocks.NewDatabase()
qb.On("UpdatePartial", ctx, mock.MatchedBy(func(id int) bool {
db.Scene.On("UpdatePartial", testCtx, mock.MatchedBy(func(id int) bool {
return id != badUpdateID return id != badUpdateID
}), mock.Anything).Return(validScene, nil) }), mock.Anything).Return(validScene, nil)
qb.On("UpdatePartial", ctx, badUpdateID, mock.Anything).Return(nil, updateErr) db.Scene.On("UpdatePartial", testCtx, badUpdateID, mock.Anything).Return(nil, updateErr)
qb.On("UpdateCover", ctx, sceneID, cover).Return(nil).Once() db.Scene.On("UpdateCover", testCtx, sceneID, cover).Return(nil).Once()
qb.On("UpdateCover", ctx, badCoverID, cover).Return(updateErr).Once() db.Scene.On("UpdateCover", testCtx, badCoverID, cover).Return(updateErr).Once()
tests := []struct { tests := []struct {
name string name string
@@ -204,7 +202,7 @@ func TestUpdater_Update(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := tt.u.Update(ctx, &qb) got, err := tt.u.Update(testCtx, db.Scene)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("Updater.Update() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Updater.Update() error = %v, wantErr %v", err, tt.wantErr)
return return
@@ -215,7 +213,7 @@ func TestUpdater_Update(t *testing.T) {
}) })
} }
qb.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestUpdateSet_UpdateInput(t *testing.T) { func TestUpdateSet_UpdateInput(t *testing.T) {

View File

@@ -18,7 +18,6 @@ const (
) )
type autotagScraper struct { type autotagScraper struct {
// repository models.Repository
txnManager txn.Manager txnManager txn.Manager
performerReader models.PerformerAutoTagQueryer performerReader models.PerformerAutoTagQueryer
studioReader models.StudioAutoTagQueryer studioReader models.StudioAutoTagQueryer
@@ -208,9 +207,9 @@ func (s autotagScraper) spec() Scraper {
} }
} }
func getAutoTagScraper(txnManager txn.Manager, repo Repository, globalConfig GlobalConfig) scraper { func getAutoTagScraper(repo Repository, globalConfig GlobalConfig) scraper {
base := autotagScraper{ base := autotagScraper{
txnManager: txnManager, txnManager: repo.TxnManager,
performerReader: repo.PerformerFinder, performerReader: repo.PerformerFinder,
studioReader: repo.StudioFinder, studioReader: repo.StudioFinder,
tagReader: repo.TagFinder, tagReader: repo.TagFinder,

View File

@@ -77,6 +77,8 @@ type GalleryFinder interface {
} }
type Repository struct { type Repository struct {
TxnManager models.TxnManager
SceneFinder SceneFinder SceneFinder SceneFinder
GalleryFinder GalleryFinder GalleryFinder GalleryFinder
TagFinder TagFinder TagFinder TagFinder
@@ -85,12 +87,27 @@ type Repository struct {
StudioFinder StudioFinder StudioFinder StudioFinder
} }
func NewRepository(repo models.Repository) Repository {
return Repository{
TxnManager: repo.TxnManager,
SceneFinder: repo.Scene,
GalleryFinder: repo.Gallery,
TagFinder: repo.Tag,
PerformerFinder: repo.Performer,
MovieFinder: repo.Movie,
StudioFinder: repo.Studio,
}
}
func (r *Repository) WithReadTxn(ctx context.Context, fn txn.TxnFunc) error {
return txn.WithReadTxn(ctx, r.TxnManager, fn)
}
// Cache stores the database of scrapers // Cache stores the database of scrapers
type Cache struct { type Cache struct {
client *http.Client client *http.Client
scrapers map[string]scraper // Scraper ID -> Scraper scrapers map[string]scraper // Scraper ID -> Scraper
globalConfig GlobalConfig globalConfig GlobalConfig
txnManager txn.Manager
repository Repository repository Repository
} }
@@ -122,14 +139,13 @@ func newClient(gc GlobalConfig) *http.Client {
// //
// Scraper configurations are loaded from yml files in the provided scrapers // Scraper configurations are loaded from yml files in the provided scrapers
// directory and any subdirectories. // directory and any subdirectories.
func NewCache(globalConfig GlobalConfig, txnManager txn.Manager, repo Repository) (*Cache, error) { func NewCache(globalConfig GlobalConfig, repo Repository) (*Cache, error) {
// HTTP Client setup // HTTP Client setup
client := newClient(globalConfig) client := newClient(globalConfig)
ret := &Cache{ ret := &Cache{
client: client, client: client,
globalConfig: globalConfig, globalConfig: globalConfig,
txnManager: txnManager,
repository: repo, repository: repo,
} }
@@ -148,7 +164,7 @@ func (c *Cache) loadScrapers() (map[string]scraper, error) {
// Add built-in scrapers // Add built-in scrapers
freeOnes := getFreeonesScraper(c.globalConfig) freeOnes := getFreeonesScraper(c.globalConfig)
autoTag := getAutoTagScraper(c.txnManager, c.repository, c.globalConfig) autoTag := getAutoTagScraper(c.repository, c.globalConfig)
scrapers[freeOnes.spec().ID] = freeOnes scrapers[freeOnes.spec().ID] = freeOnes
scrapers[autoTag.spec().ID] = autoTag scrapers[autoTag.spec().ID] = autoTag
@@ -369,9 +385,12 @@ func (c Cache) ScrapeID(ctx context.Context, scraperID string, id int, ty Scrape
func (c Cache) getScene(ctx context.Context, sceneID int) (*models.Scene, error) { func (c Cache) getScene(ctx context.Context, sceneID int) (*models.Scene, error) {
var ret *models.Scene var ret *models.Scene
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error { r := c.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
qb := r.SceneFinder
var err error var err error
ret, err = c.repository.SceneFinder.Find(ctx, sceneID) ret, err = qb.Find(ctx, sceneID)
if err != nil { if err != nil {
return err return err
} }
@@ -380,7 +399,7 @@ func (c Cache) getScene(ctx context.Context, sceneID int) (*models.Scene, error)
return fmt.Errorf("scene with id %d not found", sceneID) return fmt.Errorf("scene with id %d not found", sceneID)
} }
return ret.LoadURLs(ctx, c.repository.SceneFinder) return ret.LoadURLs(ctx, qb)
}); err != nil { }); err != nil {
return nil, err return nil, err
} }
@@ -389,9 +408,12 @@ func (c Cache) getScene(ctx context.Context, sceneID int) (*models.Scene, error)
func (c Cache) getGallery(ctx context.Context, galleryID int) (*models.Gallery, error) { func (c Cache) getGallery(ctx context.Context, galleryID int) (*models.Gallery, error) {
var ret *models.Gallery var ret *models.Gallery
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error { r := c.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
qb := r.GalleryFinder
var err error var err error
ret, err = c.repository.GalleryFinder.Find(ctx, galleryID) ret, err = qb.Find(ctx, galleryID)
if err != nil { if err != nil {
return err return err
} }
@@ -400,12 +422,12 @@ func (c Cache) getGallery(ctx context.Context, galleryID int) (*models.Gallery,
return fmt.Errorf("gallery with id %d not found", galleryID) return fmt.Errorf("gallery with id %d not found", galleryID)
} }
err = ret.LoadFiles(ctx, c.repository.GalleryFinder) err = ret.LoadFiles(ctx, qb)
if err != nil { if err != nil {
return err return err
} }
return ret.LoadURLs(ctx, c.repository.GalleryFinder) return ret.LoadURLs(ctx, qb)
}); err != nil { }); err != nil {
return nil, err return nil, err
} }

View File

@@ -6,7 +6,6 @@ import (
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/match" "github.com/stashapp/stash/pkg/match"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
) )
// postScrape handles post-processing of scraped content. If the content // postScrape handles post-processing of scraped content. If the content
@@ -46,8 +45,9 @@ func (c Cache) postScrape(ctx context.Context, content ScrapedContent) (ScrapedC
} }
func (c Cache) postScrapePerformer(ctx context.Context, p models.ScrapedPerformer) (ScrapedContent, error) { func (c Cache) postScrapePerformer(ctx context.Context, p models.ScrapedPerformer) (ScrapedContent, error) {
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error { r := c.repository
tqb := c.repository.TagFinder if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
tqb := r.TagFinder
tags, err := postProcessTags(ctx, tqb, p.Tags) tags, err := postProcessTags(ctx, tqb, p.Tags)
if err != nil { if err != nil {
@@ -72,8 +72,9 @@ func (c Cache) postScrapePerformer(ctx context.Context, p models.ScrapedPerforme
func (c Cache) postScrapeMovie(ctx context.Context, m models.ScrapedMovie) (ScrapedContent, error) { func (c Cache) postScrapeMovie(ctx context.Context, m models.ScrapedMovie) (ScrapedContent, error) {
if m.Studio != nil { if m.Studio != nil {
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error { r := c.repository
return match.ScrapedStudio(ctx, c.repository.StudioFinder, m.Studio, nil) if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
return match.ScrapedStudio(ctx, r.StudioFinder, m.Studio, nil)
}); err != nil { }); err != nil {
return nil, err return nil, err
} }
@@ -113,11 +114,12 @@ func (c Cache) postScrapeScene(ctx context.Context, scene ScrapedScene) (Scraped
scene.URLs = []string{*scene.URL} scene.URLs = []string{*scene.URL}
} }
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error { r := c.repository
pqb := c.repository.PerformerFinder if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
mqb := c.repository.MovieFinder pqb := r.PerformerFinder
tqb := c.repository.TagFinder mqb := r.MovieFinder
sqb := c.repository.StudioFinder tqb := r.TagFinder
sqb := r.StudioFinder
for _, p := range scene.Performers { for _, p := range scene.Performers {
if p == nil { if p == nil {
@@ -175,10 +177,11 @@ func (c Cache) postScrapeGallery(ctx context.Context, g ScrapedGallery) (Scraped
g.URLs = []string{*g.URL} g.URLs = []string{*g.URL}
} }
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error { r := c.repository
pqb := c.repository.PerformerFinder if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
tqb := c.repository.TagFinder pqb := r.PerformerFinder
sqb := c.repository.StudioFinder tqb := r.TagFinder
sqb := r.StudioFinder
for _, p := range g.Performers { for _, p := range g.Performers {
err := match.ScrapedPerformer(ctx, pqb, p, nil) err := match.ScrapedPerformer(ctx, pqb, p, nil)

View File

@@ -56,22 +56,37 @@ type TagFinder interface {
} }
type Repository struct { type Repository struct {
TxnManager models.TxnManager
Scene SceneReader Scene SceneReader
Performer PerformerReader Performer PerformerReader
Tag TagFinder Tag TagFinder
Studio StudioReader Studio StudioReader
} }
func NewRepository(repo models.Repository) Repository {
return Repository{
TxnManager: repo.TxnManager,
Scene: repo.Scene,
Performer: repo.Performer,
Tag: repo.Tag,
Studio: repo.Studio,
}
}
func (r *Repository) WithReadTxn(ctx context.Context, fn txn.TxnFunc) error {
return txn.WithReadTxn(ctx, r.TxnManager, fn)
}
// Client represents the client interface to a stash-box server instance. // Client represents the client interface to a stash-box server instance.
type Client struct { type Client struct {
client *graphql.Client client *graphql.Client
txnManager txn.Manager
repository Repository repository Repository
box models.StashBox box models.StashBox
} }
// NewClient returns a new instance of a stash-box client. // NewClient returns a new instance of a stash-box client.
func NewClient(box models.StashBox, txnManager txn.Manager, repo Repository) *Client { func NewClient(box models.StashBox, repo Repository) *Client {
authHeader := func(req *http.Request) { authHeader := func(req *http.Request) {
req.Header.Set("ApiKey", box.APIKey) req.Header.Set("ApiKey", box.APIKey)
} }
@@ -82,7 +97,6 @@ func NewClient(box models.StashBox, txnManager txn.Manager, repo Repository) *Cl
return &Client{ return &Client{
client: client, client: client,
txnManager: txnManager,
repository: repo, repository: repo,
box: box, box: box,
} }
@@ -129,8 +143,9 @@ func (c Client) FindStashBoxSceneByFingerprints(ctx context.Context, sceneID int
func (c Client) FindStashBoxScenesByFingerprints(ctx context.Context, ids []int) ([][]*scraper.ScrapedScene, error) { func (c Client) FindStashBoxScenesByFingerprints(ctx context.Context, ids []int) ([][]*scraper.ScrapedScene, error) {
var fingerprints [][]*graphql.FingerprintQueryInput var fingerprints [][]*graphql.FingerprintQueryInput
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error { r := c.repository
qb := c.repository.Scene if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
qb := r.Scene
for _, sceneID := range ids { for _, sceneID := range ids {
scene, err := qb.Find(ctx, sceneID) scene, err := qb.Find(ctx, sceneID)
@@ -142,7 +157,7 @@ func (c Client) FindStashBoxScenesByFingerprints(ctx context.Context, ids []int)
return fmt.Errorf("scene with id %d not found", sceneID) return fmt.Errorf("scene with id %d not found", sceneID)
} }
if err := scene.LoadFiles(ctx, c.repository.Scene); err != nil { if err := scene.LoadFiles(ctx, r.Scene); err != nil {
return err return err
} }
@@ -243,8 +258,9 @@ func (c Client) SubmitStashBoxFingerprints(ctx context.Context, sceneIDs []strin
var fingerprints []graphql.FingerprintSubmission var fingerprints []graphql.FingerprintSubmission
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error { r := c.repository
qb := c.repository.Scene if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
qb := r.Scene
for _, sceneID := range ids { for _, sceneID := range ids {
scene, err := qb.Find(ctx, sceneID) scene, err := qb.Find(ctx, sceneID)
@@ -382,9 +398,9 @@ func (c Client) FindStashBoxPerformersByNames(ctx context.Context, performerIDs
} }
var performers []*models.Performer var performers []*models.Performer
r := c.repository
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error { if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
qb := c.repository.Performer qb := r.Performer
for _, performerID := range ids { for _, performerID := range ids {
performer, err := qb.Find(ctx, performerID) performer, err := qb.Find(ctx, performerID)
@@ -417,8 +433,9 @@ func (c Client) FindStashBoxPerformersByPerformerNames(ctx context.Context, perf
var performers []*models.Performer var performers []*models.Performer
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error { r := c.repository
qb := c.repository.Performer if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
qb := r.Performer
for _, performerID := range ids { for _, performerID := range ids {
performer, err := qb.Find(ctx, performerID) performer, err := qb.Find(ctx, performerID)
@@ -739,14 +756,15 @@ func (c Client) sceneFragmentToScrapedScene(ctx context.Context, s *graphql.Scen
ss.URL = &s.Urls[0].URL ss.URL = &s.Urls[0].URL
} }
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error { r := c.repository
pqb := c.repository.Performer if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
tqb := c.repository.Tag pqb := r.Performer
tqb := r.Tag
if s.Studio != nil { if s.Studio != nil {
ss.Studio = studioFragmentToScrapedStudio(*s.Studio) ss.Studio = studioFragmentToScrapedStudio(*s.Studio)
err := match.ScrapedStudio(ctx, c.repository.Studio, ss.Studio, &c.box.Endpoint) err := match.ScrapedStudio(ctx, r.Studio, ss.Studio, &c.box.Endpoint)
if err != nil { if err != nil {
return err return err
} }
@@ -761,7 +779,7 @@ func (c Client) sceneFragmentToScrapedScene(ctx context.Context, s *graphql.Scen
if parentStudio.FindStudio != nil { if parentStudio.FindStudio != nil {
ss.Studio.Parent = studioFragmentToScrapedStudio(*parentStudio.FindStudio) ss.Studio.Parent = studioFragmentToScrapedStudio(*parentStudio.FindStudio)
err = match.ScrapedStudio(ctx, c.repository.Studio, ss.Studio.Parent, &c.box.Endpoint) err = match.ScrapedStudio(ctx, r.Studio, ss.Studio.Parent, &c.box.Endpoint)
if err != nil { if err != nil {
return err return err
} }
@@ -809,8 +827,9 @@ func (c Client) FindStashBoxPerformerByID(ctx context.Context, id string) (*mode
ret := performerFragmentToScrapedPerformer(*performer.FindPerformer) ret := performerFragmentToScrapedPerformer(*performer.FindPerformer)
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error { r := c.repository
err := match.ScrapedPerformer(ctx, c.repository.Performer, ret, &c.box.Endpoint) if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
err := match.ScrapedPerformer(ctx, r.Performer, ret, &c.box.Endpoint)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -836,8 +855,9 @@ func (c Client) FindStashBoxPerformerByName(ctx context.Context, name string) (*
return nil, nil return nil, nil
} }
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error { r := c.repository
err := match.ScrapedPerformer(ctx, c.repository.Performer, ret, &c.box.Endpoint) if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
err := match.ScrapedPerformer(ctx, r.Performer, ret, &c.box.Endpoint)
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -864,10 +884,11 @@ func (c Client) FindStashBoxStudio(ctx context.Context, query string) (*models.S
var ret *models.ScrapedStudio var ret *models.ScrapedStudio
if studio.FindStudio != nil { if studio.FindStudio != nil {
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error { r := c.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
ret = studioFragmentToScrapedStudio(*studio.FindStudio) ret = studioFragmentToScrapedStudio(*studio.FindStudio)
err = match.ScrapedStudio(ctx, c.repository.Studio, ret, &c.box.Endpoint) err = match.ScrapedStudio(ctx, r.Studio, ret, &c.box.Endpoint)
if err != nil { if err != nil {
return err return err
} }
@@ -881,7 +902,7 @@ func (c Client) FindStashBoxStudio(ctx context.Context, query string) (*models.S
if parentStudio.FindStudio != nil { if parentStudio.FindStudio != nil {
ret.Parent = studioFragmentToScrapedStudio(*parentStudio.FindStudio) ret.Parent = studioFragmentToScrapedStudio(*parentStudio.FindStudio)
err = match.ScrapedStudio(ctx, c.repository.Studio, ret.Parent, &c.box.Endpoint) err = match.ScrapedStudio(ctx, r.Studio, ret.Parent, &c.box.Endpoint)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -1176,7 +1176,7 @@ func makeImage(i int) *models.Image {
} }
func createImages(ctx context.Context, n int) error { func createImages(ctx context.Context, n int) error {
qb := db.TxnRepository().Image qb := db.Image
fqb := db.File fqb := db.File
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
@@ -1273,7 +1273,7 @@ func makeGallery(i int, includeScenes bool) *models.Gallery {
} }
func createGalleries(ctx context.Context, n int) error { func createGalleries(ctx context.Context, n int) error {
gqb := db.TxnRepository().Gallery gqb := db.Gallery
fqb := db.File fqb := db.File
for i := 0; i < n; i++ { for i := 0; i < n; i++ {

View File

@@ -123,7 +123,7 @@ func (db *Database) IsLocked(err error) bool {
return false return false
} }
func (db *Database) TxnRepository() models.Repository { func (db *Database) Repository() models.Repository {
return models.Repository{ return models.Repository{
TxnManager: db, TxnManager: db,
File: db.File, File: db.File,

View File

@@ -1,7 +1,6 @@
package studio package studio
import ( import (
"context"
"errors" "errors"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
@@ -162,27 +161,26 @@ func initTestTable() {
func TestToJSON(t *testing.T) { func TestToJSON(t *testing.T) {
initTestTable() initTestTable()
ctx := context.Background()
mockStudioReader := &mocks.StudioReaderWriter{} db := mocks.NewDatabase()
imageErr := errors.New("error getting image") imageErr := errors.New("error getting image")
mockStudioReader.On("GetImage", ctx, studioID).Return(imageBytes, nil).Once() db.Studio.On("GetImage", testCtx, studioID).Return(imageBytes, nil).Once()
mockStudioReader.On("GetImage", ctx, noImageID).Return(nil, nil).Once() db.Studio.On("GetImage", testCtx, noImageID).Return(nil, nil).Once()
mockStudioReader.On("GetImage", ctx, errImageID).Return(nil, imageErr).Once() db.Studio.On("GetImage", testCtx, errImageID).Return(nil, imageErr).Once()
mockStudioReader.On("GetImage", ctx, missingParentStudioID).Return(imageBytes, nil).Maybe() db.Studio.On("GetImage", testCtx, missingParentStudioID).Return(imageBytes, nil).Maybe()
mockStudioReader.On("GetImage", ctx, errStudioID).Return(imageBytes, nil).Maybe() db.Studio.On("GetImage", testCtx, errStudioID).Return(imageBytes, nil).Maybe()
parentStudioErr := errors.New("error getting parent studio") parentStudioErr := errors.New("error getting parent studio")
mockStudioReader.On("Find", ctx, parentStudioID).Return(&parentStudio, nil) db.Studio.On("Find", testCtx, parentStudioID).Return(&parentStudio, nil)
mockStudioReader.On("Find", ctx, missingStudioID).Return(nil, nil) db.Studio.On("Find", testCtx, missingStudioID).Return(nil, nil)
mockStudioReader.On("Find", ctx, errParentStudioID).Return(nil, parentStudioErr) db.Studio.On("Find", testCtx, errParentStudioID).Return(nil, parentStudioErr)
for i, s := range scenarios { for i, s := range scenarios {
studio := s.input studio := s.input
json, err := ToJSON(ctx, mockStudioReader, &studio) json, err := ToJSON(testCtx, db.Studio, &studio)
switch { switch {
case !s.err && err != nil: case !s.err && err != nil:
@@ -194,5 +192,5 @@ func TestToJSON(t *testing.T) {
} }
} }
mockStudioReader.AssertExpectations(t) db.AssertExpectations(t)
} }

View File

@@ -25,6 +25,8 @@ const (
missingParentStudioName = "existingParentStudioName" missingParentStudioName = "existingParentStudioName"
) )
var testCtx = context.Background()
func TestImporterName(t *testing.T) { func TestImporterName(t *testing.T) {
i := Importer{ i := Importer{
Input: jsonschema.Studio{ Input: jsonschema.Studio{
@@ -43,22 +45,21 @@ func TestImporterPreImport(t *testing.T) {
IgnoreAutoTag: autoTagIgnored, IgnoreAutoTag: autoTagIgnored,
}, },
} }
ctx := context.Background()
err := i.PreImport(ctx) err := i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
i.Input.Image = image i.Input.Image = image
err = i.PreImport(ctx) err = i.PreImport(testCtx)
assert.Nil(t, err) assert.Nil(t, err)
i.Input = *createFullJSONStudio(studioName, image, []string{"alias"}) i.Input = *createFullJSONStudio(studioName, image, []string{"alias"})
i.Input.ParentStudio = "" i.Input.ParentStudio = ""
err = i.PreImport(ctx) err = i.PreImport(testCtx)
assert.Nil(t, err) assert.Nil(t, err)
expectedStudio := createFullStudio(0, 0) expectedStudio := createFullStudio(0, 0)
@@ -67,11 +68,10 @@ func TestImporterPreImport(t *testing.T) {
} }
func TestImporterPreImportWithParent(t *testing.T) { func TestImporterPreImportWithParent(t *testing.T) {
readerWriter := &mocks.StudioReaderWriter{} db := mocks.NewDatabase()
ctx := context.Background()
i := Importer{ i := Importer{
ReaderWriter: readerWriter, ReaderWriter: db.Studio,
Input: jsonschema.Studio{ Input: jsonschema.Studio{
Name: studioName, Name: studioName,
Image: image, Image: image,
@@ -79,28 +79,27 @@ func TestImporterPreImportWithParent(t *testing.T) {
}, },
} }
readerWriter.On("FindByName", ctx, existingParentStudioName, false).Return(&models.Studio{ db.Studio.On("FindByName", testCtx, existingParentStudioName, false).Return(&models.Studio{
ID: existingStudioID, ID: existingStudioID,
}, nil).Once() }, nil).Once()
readerWriter.On("FindByName", ctx, existingParentStudioErr, false).Return(nil, errors.New("FindByName error")).Once() db.Studio.On("FindByName", testCtx, existingParentStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
err := i.PreImport(ctx) err := i.PreImport(testCtx)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, existingStudioID, *i.studio.ParentID) assert.Equal(t, existingStudioID, *i.studio.ParentID)
i.Input.ParentStudio = existingParentStudioErr i.Input.ParentStudio = existingParentStudioErr
err = i.PreImport(ctx) err = i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
readerWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestImporterPreImportWithMissingParent(t *testing.T) { func TestImporterPreImportWithMissingParent(t *testing.T) {
readerWriter := &mocks.StudioReaderWriter{} db := mocks.NewDatabase()
ctx := context.Background()
i := Importer{ i := Importer{
ReaderWriter: readerWriter, ReaderWriter: db.Studio,
Input: jsonschema.Studio{ Input: jsonschema.Studio{
Name: studioName, Name: studioName,
Image: image, Image: image,
@@ -109,33 +108,32 @@ func TestImporterPreImportWithMissingParent(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail, MissingRefBehaviour: models.ImportMissingRefEnumFail,
} }
readerWriter.On("FindByName", ctx, missingParentStudioName, false).Return(nil, nil).Times(3) db.Studio.On("FindByName", testCtx, missingParentStudioName, false).Return(nil, nil).Times(3)
readerWriter.On("Create", ctx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) { db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.Studio) s := args.Get(1).(*models.Studio)
s.ID = existingStudioID s.ID = existingStudioID
}).Return(nil) }).Return(nil)
err := i.PreImport(ctx) err := i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore
err = i.PreImport(ctx) err = i.PreImport(testCtx)
assert.Nil(t, err) assert.Nil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumCreate i.MissingRefBehaviour = models.ImportMissingRefEnumCreate
err = i.PreImport(ctx) err = i.PreImport(testCtx)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, existingStudioID, *i.studio.ParentID) assert.Equal(t, existingStudioID, *i.studio.ParentID)
readerWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestImporterPreImportWithMissingParentCreateErr(t *testing.T) { func TestImporterPreImportWithMissingParentCreateErr(t *testing.T) {
readerWriter := &mocks.StudioReaderWriter{} db := mocks.NewDatabase()
ctx := context.Background()
i := Importer{ i := Importer{
ReaderWriter: readerWriter, ReaderWriter: db.Studio,
Input: jsonschema.Studio{ Input: jsonschema.Studio{
Name: studioName, Name: studioName,
Image: image, Image: image,
@@ -144,19 +142,20 @@ func TestImporterPreImportWithMissingParentCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate, MissingRefBehaviour: models.ImportMissingRefEnumCreate,
} }
readerWriter.On("FindByName", ctx, missingParentStudioName, false).Return(nil, nil).Once() db.Studio.On("FindByName", testCtx, missingParentStudioName, false).Return(nil, nil).Once()
readerWriter.On("Create", ctx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error")) db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
err := i.PreImport(ctx) err := i.PreImport(testCtx)
assert.NotNil(t, err) assert.NotNil(t, err)
db.AssertExpectations(t)
} }
func TestImporterPostImport(t *testing.T) { func TestImporterPostImport(t *testing.T) {
readerWriter := &mocks.StudioReaderWriter{} db := mocks.NewDatabase()
ctx := context.Background()
i := Importer{ i := Importer{
ReaderWriter: readerWriter, ReaderWriter: db.Studio,
Input: jsonschema.Studio{ Input: jsonschema.Studio{
Aliases: []string{"alias"}, Aliases: []string{"alias"},
}, },
@@ -165,56 +164,54 @@ func TestImporterPostImport(t *testing.T) {
updateStudioImageErr := errors.New("UpdateImage error") updateStudioImageErr := errors.New("UpdateImage error")
readerWriter.On("UpdateImage", ctx, studioID, imageBytes).Return(nil).Once() db.Studio.On("UpdateImage", testCtx, studioID, imageBytes).Return(nil).Once()
readerWriter.On("UpdateImage", ctx, errImageID, imageBytes).Return(updateStudioImageErr).Once() db.Studio.On("UpdateImage", testCtx, errImageID, imageBytes).Return(updateStudioImageErr).Once()
err := i.PostImport(ctx, studioID) err := i.PostImport(testCtx, studioID)
assert.Nil(t, err) assert.Nil(t, err)
err = i.PostImport(ctx, errImageID) err = i.PostImport(testCtx, errImageID)
assert.NotNil(t, err) assert.NotNil(t, err)
readerWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestImporterFindExistingID(t *testing.T) { func TestImporterFindExistingID(t *testing.T) {
readerWriter := &mocks.StudioReaderWriter{} db := mocks.NewDatabase()
ctx := context.Background()
i := Importer{ i := Importer{
ReaderWriter: readerWriter, ReaderWriter: db.Studio,
Input: jsonschema.Studio{ Input: jsonschema.Studio{
Name: studioName, Name: studioName,
}, },
} }
errFindByName := errors.New("FindByName error") errFindByName := errors.New("FindByName error")
readerWriter.On("FindByName", ctx, studioName, false).Return(nil, nil).Once() db.Studio.On("FindByName", testCtx, studioName, false).Return(nil, nil).Once()
readerWriter.On("FindByName", ctx, existingStudioName, false).Return(&models.Studio{ db.Studio.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
ID: existingStudioID, ID: existingStudioID,
}, nil).Once() }, nil).Once()
readerWriter.On("FindByName", ctx, studioNameErr, false).Return(nil, errFindByName).Once() db.Studio.On("FindByName", testCtx, studioNameErr, false).Return(nil, errFindByName).Once()
id, err := i.FindExistingID(ctx) id, err := i.FindExistingID(testCtx)
assert.Nil(t, id) assert.Nil(t, id)
assert.Nil(t, err) assert.Nil(t, err)
i.Input.Name = existingStudioName i.Input.Name = existingStudioName
id, err = i.FindExistingID(ctx) id, err = i.FindExistingID(testCtx)
assert.Equal(t, existingStudioID, *id) assert.Equal(t, existingStudioID, *id)
assert.Nil(t, err) assert.Nil(t, err)
i.Input.Name = studioNameErr i.Input.Name = studioNameErr
id, err = i.FindExistingID(ctx) id, err = i.FindExistingID(testCtx)
assert.Nil(t, id) assert.Nil(t, id)
assert.NotNil(t, err) assert.NotNil(t, err)
readerWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestCreate(t *testing.T) { func TestCreate(t *testing.T) {
readerWriter := &mocks.StudioReaderWriter{} db := mocks.NewDatabase()
ctx := context.Background()
studio := models.Studio{ studio := models.Studio{
Name: studioName, Name: studioName,
@@ -225,32 +222,31 @@ func TestCreate(t *testing.T) {
} }
i := Importer{ i := Importer{
ReaderWriter: readerWriter, ReaderWriter: db.Studio,
studio: studio, studio: studio,
} }
errCreate := errors.New("Create error") errCreate := errors.New("Create error")
readerWriter.On("Create", ctx, &studio).Run(func(args mock.Arguments) { db.Studio.On("Create", testCtx, &studio).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.Studio) s := args.Get(1).(*models.Studio)
s.ID = studioID s.ID = studioID
}).Return(nil).Once() }).Return(nil).Once()
readerWriter.On("Create", ctx, &studioErr).Return(errCreate).Once() db.Studio.On("Create", testCtx, &studioErr).Return(errCreate).Once()
id, err := i.Create(ctx) id, err := i.Create(testCtx)
assert.Equal(t, studioID, *id) assert.Equal(t, studioID, *id)
assert.Nil(t, err) assert.Nil(t, err)
i.studio = studioErr i.studio = studioErr
id, err = i.Create(ctx) id, err = i.Create(testCtx)
assert.Nil(t, id) assert.Nil(t, id)
assert.NotNil(t, err) assert.NotNil(t, err)
readerWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestUpdate(t *testing.T) { func TestUpdate(t *testing.T) {
readerWriter := &mocks.StudioReaderWriter{} db := mocks.NewDatabase()
ctx := context.Background()
studio := models.Studio{ studio := models.Studio{
Name: studioName, Name: studioName,
@@ -261,7 +257,7 @@ func TestUpdate(t *testing.T) {
} }
i := Importer{ i := Importer{
ReaderWriter: readerWriter, ReaderWriter: db.Studio,
studio: studio, studio: studio,
} }
@@ -269,19 +265,19 @@ func TestUpdate(t *testing.T) {
// id needs to be set for the mock input // id needs to be set for the mock input
studio.ID = studioID studio.ID = studioID
readerWriter.On("Update", ctx, &studio).Return(nil).Once() db.Studio.On("Update", testCtx, &studio).Return(nil).Once()
err := i.Update(ctx, studioID) err := i.Update(testCtx, studioID)
assert.Nil(t, err) assert.Nil(t, err)
i.studio = studioErr i.studio = studioErr
// need to set id separately // need to set id separately
studioErr.ID = errImageID studioErr.ID = errImageID
readerWriter.On("Update", ctx, &studioErr).Return(errUpdate).Once() db.Studio.On("Update", testCtx, &studioErr).Return(errUpdate).Once()
err = i.Update(ctx, errImageID) err = i.Update(testCtx, errImageID)
assert.NotNil(t, err) assert.NotNil(t, err)
readerWriter.AssertExpectations(t) db.AssertExpectations(t)
} }

View File

@@ -1,7 +1,6 @@
package tag package tag
import ( import (
"context"
"errors" "errors"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
@@ -109,35 +108,34 @@ func initTestTable() {
func TestToJSON(t *testing.T) { func TestToJSON(t *testing.T) {
initTestTable() initTestTable()
ctx := context.Background() db := mocks.NewDatabase()
mockTagReader := &mocks.TagReaderWriter{}
imageErr := errors.New("error getting image") imageErr := errors.New("error getting image")
aliasErr := errors.New("error getting aliases") aliasErr := errors.New("error getting aliases")
parentsErr := errors.New("error getting parents") parentsErr := errors.New("error getting parents")
mockTagReader.On("GetAliases", ctx, tagID).Return([]string{"alias"}, nil).Once() db.Tag.On("GetAliases", testCtx, tagID).Return([]string{"alias"}, nil).Once()
mockTagReader.On("GetAliases", ctx, noImageID).Return(nil, nil).Once() db.Tag.On("GetAliases", testCtx, noImageID).Return(nil, nil).Once()
mockTagReader.On("GetAliases", ctx, errImageID).Return(nil, nil).Once() db.Tag.On("GetAliases", testCtx, errImageID).Return(nil, nil).Once()
mockTagReader.On("GetAliases", ctx, errAliasID).Return(nil, aliasErr).Once() db.Tag.On("GetAliases", testCtx, errAliasID).Return(nil, aliasErr).Once()
mockTagReader.On("GetAliases", ctx, withParentsID).Return(nil, nil).Once() db.Tag.On("GetAliases", testCtx, withParentsID).Return(nil, nil).Once()
mockTagReader.On("GetAliases", ctx, errParentsID).Return(nil, nil).Once() db.Tag.On("GetAliases", testCtx, errParentsID).Return(nil, nil).Once()
mockTagReader.On("GetImage", ctx, tagID).Return(imageBytes, nil).Once() db.Tag.On("GetImage", testCtx, tagID).Return(imageBytes, nil).Once()
mockTagReader.On("GetImage", ctx, noImageID).Return(nil, nil).Once() db.Tag.On("GetImage", testCtx, noImageID).Return(nil, nil).Once()
mockTagReader.On("GetImage", ctx, errImageID).Return(nil, imageErr).Once() db.Tag.On("GetImage", testCtx, errImageID).Return(nil, imageErr).Once()
mockTagReader.On("GetImage", ctx, withParentsID).Return(imageBytes, nil).Once() db.Tag.On("GetImage", testCtx, withParentsID).Return(imageBytes, nil).Once()
mockTagReader.On("GetImage", ctx, errParentsID).Return(nil, nil).Once() db.Tag.On("GetImage", testCtx, errParentsID).Return(nil, nil).Once()
mockTagReader.On("FindByChildTagID", ctx, tagID).Return(nil, nil).Once() db.Tag.On("FindByChildTagID", testCtx, tagID).Return(nil, nil).Once()
mockTagReader.On("FindByChildTagID", ctx, noImageID).Return(nil, nil).Once() db.Tag.On("FindByChildTagID", testCtx, noImageID).Return(nil, nil).Once()
mockTagReader.On("FindByChildTagID", ctx, withParentsID).Return([]*models.Tag{{Name: "parent"}}, nil).Once() db.Tag.On("FindByChildTagID", testCtx, withParentsID).Return([]*models.Tag{{Name: "parent"}}, nil).Once()
mockTagReader.On("FindByChildTagID", ctx, errParentsID).Return(nil, parentsErr).Once() db.Tag.On("FindByChildTagID", testCtx, errParentsID).Return(nil, parentsErr).Once()
mockTagReader.On("FindByChildTagID", ctx, errImageID).Return(nil, nil).Once() db.Tag.On("FindByChildTagID", testCtx, errImageID).Return(nil, nil).Once()
for i, s := range scenarios { for i, s := range scenarios {
tag := s.tag tag := s.tag
json, err := ToJSON(ctx, mockTagReader, &tag) json, err := ToJSON(testCtx, db.Tag, &tag)
switch { switch {
case !s.err && err != nil: case !s.err && err != nil:
@@ -149,5 +147,5 @@ func TestToJSON(t *testing.T) {
} }
} }
mockTagReader.AssertExpectations(t) db.AssertExpectations(t)
} }

View File

@@ -58,10 +58,10 @@ func TestImporterPreImport(t *testing.T) {
} }
func TestImporterPostImport(t *testing.T) { func TestImporterPostImport(t *testing.T) {
readerWriter := &mocks.TagReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
ReaderWriter: readerWriter, ReaderWriter: db.Tag,
Input: jsonschema.Tag{ Input: jsonschema.Tag{
Aliases: []string{"alias"}, Aliases: []string{"alias"},
}, },
@@ -72,23 +72,23 @@ func TestImporterPostImport(t *testing.T) {
updateTagAliasErr := errors.New("UpdateAlias error") updateTagAliasErr := errors.New("UpdateAlias error")
updateTagParentsErr := errors.New("UpdateParentTags error") updateTagParentsErr := errors.New("UpdateParentTags error")
readerWriter.On("UpdateAliases", testCtx, tagID, i.Input.Aliases).Return(nil).Once() db.Tag.On("UpdateAliases", testCtx, tagID, i.Input.Aliases).Return(nil).Once()
readerWriter.On("UpdateAliases", testCtx, errAliasID, i.Input.Aliases).Return(updateTagAliasErr).Once() db.Tag.On("UpdateAliases", testCtx, errAliasID, i.Input.Aliases).Return(updateTagAliasErr).Once()
readerWriter.On("UpdateAliases", testCtx, withParentsID, i.Input.Aliases).Return(nil).Once() db.Tag.On("UpdateAliases", testCtx, withParentsID, i.Input.Aliases).Return(nil).Once()
readerWriter.On("UpdateAliases", testCtx, errParentsID, i.Input.Aliases).Return(nil).Once() db.Tag.On("UpdateAliases", testCtx, errParentsID, i.Input.Aliases).Return(nil).Once()
readerWriter.On("UpdateImage", testCtx, tagID, imageBytes).Return(nil).Once() db.Tag.On("UpdateImage", testCtx, tagID, imageBytes).Return(nil).Once()
readerWriter.On("UpdateImage", testCtx, errAliasID, imageBytes).Return(nil).Once() db.Tag.On("UpdateImage", testCtx, errAliasID, imageBytes).Return(nil).Once()
readerWriter.On("UpdateImage", testCtx, errImageID, imageBytes).Return(updateTagImageErr).Once() db.Tag.On("UpdateImage", testCtx, errImageID, imageBytes).Return(updateTagImageErr).Once()
readerWriter.On("UpdateImage", testCtx, withParentsID, imageBytes).Return(nil).Once() db.Tag.On("UpdateImage", testCtx, withParentsID, imageBytes).Return(nil).Once()
readerWriter.On("UpdateImage", testCtx, errParentsID, imageBytes).Return(nil).Once() db.Tag.On("UpdateImage", testCtx, errParentsID, imageBytes).Return(nil).Once()
var parentTags []int var parentTags []int
readerWriter.On("UpdateParentTags", testCtx, tagID, parentTags).Return(nil).Once() db.Tag.On("UpdateParentTags", testCtx, tagID, parentTags).Return(nil).Once()
readerWriter.On("UpdateParentTags", testCtx, withParentsID, []int{100}).Return(nil).Once() db.Tag.On("UpdateParentTags", testCtx, withParentsID, []int{100}).Return(nil).Once()
readerWriter.On("UpdateParentTags", testCtx, errParentsID, []int{100}).Return(updateTagParentsErr).Once() db.Tag.On("UpdateParentTags", testCtx, errParentsID, []int{100}).Return(updateTagParentsErr).Once()
readerWriter.On("FindByName", testCtx, "Parent", false).Return(&models.Tag{ID: 100}, nil) db.Tag.On("FindByName", testCtx, "Parent", false).Return(&models.Tag{ID: 100}, nil)
err := i.PostImport(testCtx, tagID) err := i.PostImport(testCtx, tagID)
assert.Nil(t, err) assert.Nil(t, err)
@@ -106,14 +106,14 @@ func TestImporterPostImport(t *testing.T) {
err = i.PostImport(testCtx, errParentsID) err = i.PostImport(testCtx, errParentsID)
assert.NotNil(t, err) assert.NotNil(t, err)
readerWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestImporterPostImportParentMissing(t *testing.T) { func TestImporterPostImportParentMissing(t *testing.T) {
readerWriter := &mocks.TagReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
ReaderWriter: readerWriter, ReaderWriter: db.Tag,
Input: jsonschema.Tag{}, Input: jsonschema.Tag{},
imageData: imageBytes, imageData: imageBytes,
} }
@@ -133,33 +133,33 @@ func TestImporterPostImportParentMissing(t *testing.T) {
var emptyParents []int var emptyParents []int
readerWriter.On("UpdateImage", testCtx, mock.Anything, mock.Anything).Return(nil) db.Tag.On("UpdateImage", testCtx, mock.Anything, mock.Anything).Return(nil)
readerWriter.On("UpdateAliases", testCtx, mock.Anything, mock.Anything).Return(nil) db.Tag.On("UpdateAliases", testCtx, mock.Anything, mock.Anything).Return(nil)
readerWriter.On("FindByName", testCtx, "Create", false).Return(nil, nil).Once() db.Tag.On("FindByName", testCtx, "Create", false).Return(nil, nil).Once()
readerWriter.On("FindByName", testCtx, "CreateError", false).Return(nil, nil).Once() db.Tag.On("FindByName", testCtx, "CreateError", false).Return(nil, nil).Once()
readerWriter.On("FindByName", testCtx, "CreateFindError", false).Return(nil, findError).Once() db.Tag.On("FindByName", testCtx, "CreateFindError", false).Return(nil, findError).Once()
readerWriter.On("FindByName", testCtx, "CreateFound", false).Return(&models.Tag{ID: 101}, nil).Once() db.Tag.On("FindByName", testCtx, "CreateFound", false).Return(&models.Tag{ID: 101}, nil).Once()
readerWriter.On("FindByName", testCtx, "Fail", false).Return(nil, nil).Once() db.Tag.On("FindByName", testCtx, "Fail", false).Return(nil, nil).Once()
readerWriter.On("FindByName", testCtx, "FailFindError", false).Return(nil, findError) db.Tag.On("FindByName", testCtx, "FailFindError", false).Return(nil, findError)
readerWriter.On("FindByName", testCtx, "FailFound", false).Return(&models.Tag{ID: 102}, nil).Once() db.Tag.On("FindByName", testCtx, "FailFound", false).Return(&models.Tag{ID: 102}, nil).Once()
readerWriter.On("FindByName", testCtx, "Ignore", false).Return(nil, nil).Once() db.Tag.On("FindByName", testCtx, "Ignore", false).Return(nil, nil).Once()
readerWriter.On("FindByName", testCtx, "IgnoreFindError", false).Return(nil, findError) db.Tag.On("FindByName", testCtx, "IgnoreFindError", false).Return(nil, findError)
readerWriter.On("FindByName", testCtx, "IgnoreFound", false).Return(&models.Tag{ID: 103}, nil).Once() db.Tag.On("FindByName", testCtx, "IgnoreFound", false).Return(&models.Tag{ID: 103}, nil).Once()
readerWriter.On("UpdateParentTags", testCtx, createID, []int{100}).Return(nil).Once() db.Tag.On("UpdateParentTags", testCtx, createID, []int{100}).Return(nil).Once()
readerWriter.On("UpdateParentTags", testCtx, createFoundID, []int{101}).Return(nil).Once() db.Tag.On("UpdateParentTags", testCtx, createFoundID, []int{101}).Return(nil).Once()
readerWriter.On("UpdateParentTags", testCtx, failFoundID, []int{102}).Return(nil).Once() db.Tag.On("UpdateParentTags", testCtx, failFoundID, []int{102}).Return(nil).Once()
readerWriter.On("UpdateParentTags", testCtx, ignoreID, emptyParents).Return(nil).Once() db.Tag.On("UpdateParentTags", testCtx, ignoreID, emptyParents).Return(nil).Once()
readerWriter.On("UpdateParentTags", testCtx, ignoreFoundID, []int{103}).Return(nil).Once() db.Tag.On("UpdateParentTags", testCtx, ignoreFoundID, []int{103}).Return(nil).Once()
readerWriter.On("Create", testCtx, mock.MatchedBy(func(t *models.Tag) bool { db.Tag.On("Create", testCtx, mock.MatchedBy(func(t *models.Tag) bool {
return t.Name == "Create" return t.Name == "Create"
})).Run(func(args mock.Arguments) { })).Run(func(args mock.Arguments) {
t := args.Get(1).(*models.Tag) t := args.Get(1).(*models.Tag)
t.ID = 100 t.ID = 100
}).Return(nil).Once() }).Return(nil).Once()
readerWriter.On("Create", testCtx, mock.MatchedBy(func(t *models.Tag) bool { db.Tag.On("Create", testCtx, mock.MatchedBy(func(t *models.Tag) bool {
return t.Name == "CreateError" return t.Name == "CreateError"
})).Return(errors.New("failed creating parent")).Once() })).Return(errors.New("failed creating parent")).Once()
@@ -206,25 +206,25 @@ func TestImporterPostImportParentMissing(t *testing.T) {
err = i.PostImport(testCtx, ignoreFoundID) err = i.PostImport(testCtx, ignoreFoundID)
assert.Nil(t, err) assert.Nil(t, err)
readerWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestImporterFindExistingID(t *testing.T) { func TestImporterFindExistingID(t *testing.T) {
readerWriter := &mocks.TagReaderWriter{} db := mocks.NewDatabase()
i := Importer{ i := Importer{
ReaderWriter: readerWriter, ReaderWriter: db.Tag,
Input: jsonschema.Tag{ Input: jsonschema.Tag{
Name: tagName, Name: tagName,
}, },
} }
errFindByName := errors.New("FindByName error") errFindByName := errors.New("FindByName error")
readerWriter.On("FindByName", testCtx, tagName, false).Return(nil, nil).Once() db.Tag.On("FindByName", testCtx, tagName, false).Return(nil, nil).Once()
readerWriter.On("FindByName", testCtx, existingTagName, false).Return(&models.Tag{ db.Tag.On("FindByName", testCtx, existingTagName, false).Return(&models.Tag{
ID: existingTagID, ID: existingTagID,
}, nil).Once() }, nil).Once()
readerWriter.On("FindByName", testCtx, tagNameErr, false).Return(nil, errFindByName).Once() db.Tag.On("FindByName", testCtx, tagNameErr, false).Return(nil, errFindByName).Once()
id, err := i.FindExistingID(testCtx) id, err := i.FindExistingID(testCtx)
assert.Nil(t, id) assert.Nil(t, id)
@@ -240,11 +240,11 @@ func TestImporterFindExistingID(t *testing.T) {
assert.Nil(t, id) assert.Nil(t, id)
assert.NotNil(t, err) assert.NotNil(t, err)
readerWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestCreate(t *testing.T) { func TestCreate(t *testing.T) {
readerWriter := &mocks.TagReaderWriter{} db := mocks.NewDatabase()
tag := models.Tag{ tag := models.Tag{
Name: tagName, Name: tagName,
@@ -255,16 +255,16 @@ func TestCreate(t *testing.T) {
} }
i := Importer{ i := Importer{
ReaderWriter: readerWriter, ReaderWriter: db.Tag,
tag: tag, tag: tag,
} }
errCreate := errors.New("Create error") errCreate := errors.New("Create error")
readerWriter.On("Create", testCtx, &tag).Run(func(args mock.Arguments) { db.Tag.On("Create", testCtx, &tag).Run(func(args mock.Arguments) {
t := args.Get(1).(*models.Tag) t := args.Get(1).(*models.Tag)
t.ID = tagID t.ID = tagID
}).Return(nil).Once() }).Return(nil).Once()
readerWriter.On("Create", testCtx, &tagErr).Return(errCreate).Once() db.Tag.On("Create", testCtx, &tagErr).Return(errCreate).Once()
id, err := i.Create(testCtx) id, err := i.Create(testCtx)
assert.Equal(t, tagID, *id) assert.Equal(t, tagID, *id)
@@ -275,11 +275,11 @@ func TestCreate(t *testing.T) {
assert.Nil(t, id) assert.Nil(t, id)
assert.NotNil(t, err) assert.NotNil(t, err)
readerWriter.AssertExpectations(t) db.AssertExpectations(t)
} }
func TestUpdate(t *testing.T) { func TestUpdate(t *testing.T) {
readerWriter := &mocks.TagReaderWriter{} db := mocks.NewDatabase()
tag := models.Tag{ tag := models.Tag{
Name: tagName, Name: tagName,
@@ -290,7 +290,7 @@ func TestUpdate(t *testing.T) {
} }
i := Importer{ i := Importer{
ReaderWriter: readerWriter, ReaderWriter: db.Tag,
tag: tag, tag: tag,
} }
@@ -298,7 +298,7 @@ func TestUpdate(t *testing.T) {
// id needs to be set for the mock input // id needs to be set for the mock input
tag.ID = tagID tag.ID = tagID
readerWriter.On("Update", testCtx, &tag).Return(nil).Once() db.Tag.On("Update", testCtx, &tag).Return(nil).Once()
err := i.Update(testCtx, tagID) err := i.Update(testCtx, tagID)
assert.Nil(t, err) assert.Nil(t, err)
@@ -307,10 +307,10 @@ func TestUpdate(t *testing.T) {
// need to set id separately // need to set id separately
tagErr.ID = errImageID tagErr.ID = errImageID
readerWriter.On("Update", testCtx, &tagErr).Return(errUpdate).Once() db.Tag.On("Update", testCtx, &tagErr).Return(errUpdate).Once()
err = i.Update(testCtx, errImageID) err = i.Update(testCtx, errImageID)
assert.NotNil(t, err) assert.NotNil(t, err)
readerWriter.AssertExpectations(t) db.AssertExpectations(t)
} }

View File

@@ -219,8 +219,7 @@ func TestEnsureHierarchy(t *testing.T) {
} }
func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents, queryChildren bool) { func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents, queryChildren bool) {
mockTagReader := &mocks.TagReaderWriter{} db := mocks.NewDatabase()
ctx := context.Background()
var parentIDs, childIDs []int var parentIDs, childIDs []int
find := make(map[int]*models.Tag) find := make(map[int]*models.Tag)
@@ -247,15 +246,15 @@ func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents,
if queryParents { if queryParents {
parentIDs = nil parentIDs = nil
mockTagReader.On("FindByChildTagID", ctx, tc.id).Return(tc.parents, nil).Once() db.Tag.On("FindByChildTagID", testCtx, tc.id).Return(tc.parents, nil).Once()
} }
if queryChildren { if queryChildren {
childIDs = nil childIDs = nil
mockTagReader.On("FindByParentTagID", ctx, tc.id).Return(tc.children, nil).Once() db.Tag.On("FindByParentTagID", testCtx, tc.id).Return(tc.children, nil).Once()
} }
mockTagReader.On("FindAllAncestors", ctx, mock.AnythingOfType("int"), []int(nil)).Return(func(ctx context.Context, tagID int, excludeIDs []int) []*models.TagPath { db.Tag.On("FindAllAncestors", testCtx, mock.AnythingOfType("int"), []int(nil)).Return(func(ctx context.Context, tagID int, excludeIDs []int) []*models.TagPath {
return tc.onFindAllAncestors return tc.onFindAllAncestors
}, func(ctx context.Context, tagID int, excludeIDs []int) error { }, func(ctx context.Context, tagID int, excludeIDs []int) error {
if tc.onFindAllAncestors != nil { if tc.onFindAllAncestors != nil {
@@ -264,7 +263,7 @@ func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents,
return fmt.Errorf("undefined ancestors for: %d", tagID) return fmt.Errorf("undefined ancestors for: %d", tagID)
}).Maybe() }).Maybe()
mockTagReader.On("FindAllDescendants", ctx, mock.AnythingOfType("int"), []int(nil)).Return(func(ctx context.Context, tagID int, excludeIDs []int) []*models.TagPath { db.Tag.On("FindAllDescendants", testCtx, mock.AnythingOfType("int"), []int(nil)).Return(func(ctx context.Context, tagID int, excludeIDs []int) []*models.TagPath {
return tc.onFindAllDescendants return tc.onFindAllDescendants
}, func(ctx context.Context, tagID int, excludeIDs []int) error { }, func(ctx context.Context, tagID int, excludeIDs []int) error {
if tc.onFindAllDescendants != nil { if tc.onFindAllDescendants != nil {
@@ -273,7 +272,7 @@ func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents,
return fmt.Errorf("undefined descendants for: %d", tagID) return fmt.Errorf("undefined descendants for: %d", tagID)
}).Maybe() }).Maybe()
res := ValidateHierarchy(ctx, testUniqueHierarchyTags[tc.id], parentIDs, childIDs, mockTagReader) res := ValidateHierarchy(testCtx, testUniqueHierarchyTags[tc.id], parentIDs, childIDs, db.Tag)
assert := assert.New(t) assert := assert.New(t)
@@ -285,5 +284,5 @@ func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents,
assert.Nil(res) assert.Nil(res)
} }
mockTagReader.AssertExpectations(t) db.AssertExpectations(t)
} }