Model refactor, part 3 (#4152)

* Remove manager.Repository
* Refactor other repositories
* Fix tests and add database mock
* Add AssertExpectations method
* Refactor routes
* Move default movie image to internal/static and add convenience methods
* Refactor default performer image boxes
Authored by DingDongSoLong4 on 2023-10-16 05:26:34 +02:00, committed by GitHub
parent 40bcb4baa5
commit 33f2ebf2a3
87 changed files with 1843 additions and 1651 deletions
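The recurring change in the diffs below is the calling convention around transactions: instead of wrapping work in txn.WithReadTxn(ctx, j.Repository, ...) and reaching through fields like FileStore and FolderStore, callers take a local r := j.Repository and use the repository's own WithReadTxn/WithTxn/WithDB helpers together with the shortened File and Folder store names. A minimal sketch of that shape, using simplified stand-in types rather than the real pkg/file and pkg/txn definitions:

```go
// Illustrative stand-ins only; the real interfaces live in pkg/models and pkg/txn.
package main

import (
	"context"
	"fmt"
)

// FileStore is a stand-in for the subset of models.FileReaderWriter used here.
type FileStore interface {
	CountAllInPaths(ctx context.Context, paths []string) (int, error)
}

// Repository bundles its transaction helpers with its stores, so callers no
// longer pass a separate manager into txn.WithReadTxn.
type Repository struct {
	File FileStore
}

// WithReadTxn stands in for the real helper, which delegates to
// txn.WithReadTxn(ctx, r.TxnManager, fn).
func (r *Repository) WithReadTxn(ctx context.Context, fn func(ctx context.Context) error) error {
	return fn(ctx)
}

type fakeFileStore struct{}

func (fakeFileStore) CountAllInPaths(ctx context.Context, paths []string) (int, error) {
	return 42, nil // fixed placeholder count
}

func main() {
	r := Repository{File: fakeFileStore{}}
	ctx := context.Background()

	// Old shape (removed by this PR):
	//   txn.WithReadTxn(ctx, j.Repository, func(ctx context.Context) error {
	//       fileCount, err = j.Repository.FileStore.CountAllInPaths(ctx, j.options.Paths)
	//       ...
	//   })
	//
	// New shape: the repository owns the helper and exposes shorter store names.
	if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
		count, err := r.File.CountAllInPaths(ctx, []string{"/media"})
		if err != nil {
			return err
		}
		fmt.Println("file count:", count)
		return nil
	}); err != nil {
		fmt.Println("error:", err)
	}
}
```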

View File

@@ -11,7 +11,6 @@ import (
"github.com/stashapp/stash/pkg/job"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
)
// Cleaner scans through stored file and folder instances and removes those that are no longer present on disk.
@@ -112,14 +111,15 @@ func (j *cleanJob) execute(ctx context.Context) error {
folderCount int
)
if err := txn.WithReadTxn(ctx, j.Repository, func(ctx context.Context) error {
r := j.Repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
var err error
fileCount, err = j.Repository.FileStore.CountAllInPaths(ctx, j.options.Paths)
fileCount, err = r.File.CountAllInPaths(ctx, j.options.Paths)
if err != nil {
return err
}
folderCount, err = j.Repository.FolderStore.CountAllInPaths(ctx, j.options.Paths)
folderCount, err = r.Folder.CountAllInPaths(ctx, j.options.Paths)
if err != nil {
return err
}
@@ -172,13 +172,14 @@ func (j *cleanJob) assessFiles(ctx context.Context, toDelete *deleteSet) error {
progress := j.progress
more := true
if err := txn.WithReadTxn(ctx, j.Repository, func(ctx context.Context) error {
r := j.Repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
for more {
if job.IsCancelled(ctx) {
return nil
}
files, err := j.Repository.FileStore.FindAllInPaths(ctx, j.options.Paths, batchSize, offset)
files, err := r.File.FindAllInPaths(ctx, j.options.Paths, batchSize, offset)
if err != nil {
return fmt.Errorf("error querying for files: %w", err)
}
@@ -223,8 +224,9 @@ func (j *cleanJob) assessFiles(ctx context.Context, toDelete *deleteSet) error {
// flagFolderForDelete adds folders to the toDelete set, with the leaf folders added first
func (j *cleanJob) flagFileForDelete(ctx context.Context, toDelete *deleteSet, f models.File) error {
r := j.Repository
// add contained files first
containedFiles, err := j.Repository.FileStore.FindByZipFileID(ctx, f.Base().ID)
containedFiles, err := r.File.FindByZipFileID(ctx, f.Base().ID)
if err != nil {
return fmt.Errorf("error finding contained files for %q: %w", f.Base().Path, err)
}
@@ -235,7 +237,7 @@ func (j *cleanJob) flagFileForDelete(ctx context.Context, toDelete *deleteSet, f
}
// add contained folders as well
containedFolders, err := j.Repository.FolderStore.FindByZipFileID(ctx, f.Base().ID)
containedFolders, err := r.Folder.FindByZipFileID(ctx, f.Base().ID)
if err != nil {
return fmt.Errorf("error finding contained folders for %q: %w", f.Base().Path, err)
}
@@ -256,13 +258,14 @@ func (j *cleanJob) assessFolders(ctx context.Context, toDelete *deleteSet) error
progress := j.progress
more := true
if err := txn.WithReadTxn(ctx, j.Repository, func(ctx context.Context) error {
r := j.Repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
for more {
if job.IsCancelled(ctx) {
return nil
}
folders, err := j.Repository.FolderStore.FindAllInPaths(ctx, j.options.Paths, batchSize, offset)
folders, err := r.Folder.FindAllInPaths(ctx, j.options.Paths, batchSize, offset)
if err != nil {
return fmt.Errorf("error querying for folders: %w", err)
}
@@ -380,14 +383,15 @@ func (j *cleanJob) shouldCleanFolder(ctx context.Context, f *models.Folder) bool
func (j *cleanJob) deleteFile(ctx context.Context, fileID models.FileID, fn string) {
// delete associated objects
fileDeleter := NewDeleter()
if err := txn.WithTxn(ctx, j.Repository, func(ctx context.Context) error {
r := j.Repository
if err := r.WithTxn(ctx, func(ctx context.Context) error {
fileDeleter.RegisterHooks(ctx)
if err := j.fireHandlers(ctx, fileDeleter, fileID); err != nil {
return err
}
return j.Repository.FileStore.Destroy(ctx, fileID)
return r.File.Destroy(ctx, fileID)
}); err != nil {
logger.Errorf("Error deleting file %q from database: %s", fn, err.Error())
return
@@ -397,14 +401,15 @@ func (j *cleanJob) deleteFile(ctx context.Context, fileID models.FileID, fn stri
func (j *cleanJob) deleteFolder(ctx context.Context, folderID models.FolderID, fn string) {
// delete associated objects
fileDeleter := NewDeleter()
if err := txn.WithTxn(ctx, j.Repository, func(ctx context.Context) error {
r := j.Repository
if err := r.WithTxn(ctx, func(ctx context.Context) error {
fileDeleter.RegisterHooks(ctx)
if err := j.fireFolderHandlers(ctx, fileDeleter, folderID); err != nil {
return err
}
return j.Repository.FolderStore.Destroy(ctx, folderID)
return r.Folder.Destroy(ctx, folderID)
}); err != nil {
logger.Errorf("Error deleting folder %q from database: %s", fn, err.Error())
return

View File

@@ -1,15 +1,36 @@
package file
import (
"context"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
)
// Repository provides access to storage methods for files and folders.
type Repository struct {
txn.Manager
txn.DatabaseProvider
TxnManager models.TxnManager
FileStore models.FileReaderWriter
FolderStore models.FolderReaderWriter
File models.FileReaderWriter
Folder models.FolderReaderWriter
}
func NewRepository(repo models.Repository) Repository {
return Repository{
TxnManager: repo.TxnManager,
File: repo.File,
Folder: repo.Folder,
}
}
func (r *Repository) WithTxn(ctx context.Context, fn txn.TxnFunc) error {
return txn.WithTxn(ctx, r.TxnManager, fn)
}
func (r *Repository) WithReadTxn(ctx context.Context, fn txn.TxnFunc) error {
return txn.WithReadTxn(ctx, r.TxnManager, fn)
}
func (r *Repository) WithDB(ctx context.Context, fn txn.TxnFunc) error {
return txn.WithDatabase(ctx, r.TxnManager, fn)
}
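pkg/file now gets its Repository built from the full models.Repository via NewRepository, and code that still needs a raw txn.Manager (such as the scanner's retryer further down) takes the named TxnManager field instead of passing the repository itself. A sketch of that wiring under the package layout in this PR; newScannerDeps and the retry count are illustrative, not code from the change:

```go
// Illustrative wiring only; file.NewRepository and txn.Retryer are used as
// they appear in the diff, but this helper function is hypothetical.
package example

import (
	"github.com/stashapp/stash/pkg/file"
	"github.com/stashapp/stash/pkg/models"
	"github.com/stashapp/stash/pkg/txn"
)

func newScannerDeps(repo models.Repository) (file.Repository, txn.Retryer) {
	// NewRepository copies the TxnManager plus the File and Folder stores out
	// of the full models.Repository.
	fileRepo := file.NewRepository(repo)

	// Code that still wants a plain txn.Manager (the scanner's retryer, for
	// example) now uses the named TxnManager field rather than the repository.
	retryer := txn.Retryer{
		Manager: fileRepo.TxnManager,
		Retries: 5, // placeholder; the scanner uses its own maxRetries constant
	}
	return fileRepo, retryer
}
```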

View File

@@ -86,6 +86,8 @@ func (s *scanJob) detectFolderMove(ctx context.Context, file scanFile) (*models.
}
// rejects is a set of folder ids which were found to still exist
r := s.Repository
if err := symWalk(file.fs, file.Path, func(path string, d fs.DirEntry, err error) error {
if err != nil {
// don't let errors prevent scanning
@@ -118,7 +120,7 @@ func (s *scanJob) detectFolderMove(ctx context.Context, file scanFile) (*models.
}
// check if the file exists in the database based on basename, size and mod time
existing, err := s.Repository.FileStore.FindByFileInfo(ctx, info, size)
existing, err := r.File.FindByFileInfo(ctx, info, size)
if err != nil {
return fmt.Errorf("checking for existing file %q: %w", path, err)
}
@@ -140,7 +142,7 @@ func (s *scanJob) detectFolderMove(ctx context.Context, file scanFile) (*models.
if c == nil {
// need to check if the folder exists in the filesystem
pf, err := s.Repository.FolderStore.Find(ctx, e.Base().ParentFolderID)
pf, err := r.Folder.Find(ctx, e.Base().ParentFolderID)
if err != nil {
return fmt.Errorf("getting parent folder %d: %w", e.Base().ParentFolderID, err)
}
@@ -164,7 +166,7 @@ func (s *scanJob) detectFolderMove(ctx context.Context, file scanFile) (*models.
// parent folder is missing, possible candidate
// count the total number of files in the existing folder
count, err := s.Repository.FileStore.CountByFolderID(ctx, parentFolderID)
count, err := r.File.CountByFolderID(ctx, parentFolderID)
if err != nil {
return fmt.Errorf("counting files in folder %d: %w", parentFolderID, err)
}

View File

@@ -181,7 +181,7 @@ func (m *Mover) moveFile(oldPath, newPath string) error {
return nil
}
func (m *Mover) RegisterHooks(ctx context.Context, mgr txn.Manager) {
func (m *Mover) RegisterHooks(ctx context.Context) {
txn.AddPostCommitHook(ctx, func(ctx context.Context) {
m.commit()
})

View File

@@ -144,7 +144,7 @@ func (s *Scanner) Scan(ctx context.Context, handlers []Handler, options ScanOpti
ProgressReports: progressReporter,
options: options,
txnRetryer: txn.Retryer{
Manager: s.Repository,
Manager: s.Repository.TxnManager,
Retries: maxRetries,
},
}
@@ -163,7 +163,7 @@ func (s *scanJob) withTxn(ctx context.Context, fn func(ctx context.Context) erro
}
func (s *scanJob) withDB(ctx context.Context, fn func(ctx context.Context) error) error {
return txn.WithDatabase(ctx, s.Repository, fn)
return s.Repository.WithDB(ctx, fn)
}
func (s *scanJob) execute(ctx context.Context) {
@@ -439,7 +439,7 @@ func (s *scanJob) getFolderID(ctx context.Context, path string) (*models.FolderI
return &v, nil
}
ret, err := s.Repository.FolderStore.FindByPath(ctx, path)
ret, err := s.Repository.Folder.FindByPath(ctx, path)
if err != nil {
return nil, err
}
@@ -469,7 +469,7 @@ func (s *scanJob) getZipFileID(ctx context.Context, zipFile *scanFile) (*models.
return &v, nil
}
ret, err := s.Repository.FileStore.FindByPath(ctx, path)
ret, err := s.Repository.File.FindByPath(ctx, path)
if err != nil {
return nil, fmt.Errorf("getting zip file ID for %q: %w", path, err)
}
@@ -489,7 +489,7 @@ func (s *scanJob) handleFolder(ctx context.Context, file scanFile) error {
defer s.incrementProgress(file)
// determine if folder already exists in data store (by path)
f, err := s.Repository.FolderStore.FindByPath(ctx, path)
f, err := s.Repository.Folder.FindByPath(ctx, path)
if err != nil {
return fmt.Errorf("checking for existing folder %q: %w", path, err)
}
@@ -553,7 +553,7 @@ func (s *scanJob) onNewFolder(ctx context.Context, file scanFile) (*models.Folde
logger.Infof("%s doesn't exist. Creating new folder entry...", file.Path)
})
if err := s.Repository.FolderStore.Create(ctx, toCreate); err != nil {
if err := s.Repository.Folder.Create(ctx, toCreate); err != nil {
return nil, fmt.Errorf("creating folder %q: %w", file.Path, err)
}
@@ -589,12 +589,12 @@ func (s *scanJob) handleFolderRename(ctx context.Context, file scanFile) (*model
renamedFrom.ParentFolderID = parentFolderID
if err := s.Repository.FolderStore.Update(ctx, renamedFrom); err != nil {
if err := s.Repository.Folder.Update(ctx, renamedFrom); err != nil {
return nil, fmt.Errorf("updating folder for rename %q: %w", renamedFrom.Path, err)
}
// #4146 - correct sub-folders to have the correct path
if err := correctSubFolderHierarchy(ctx, s.Repository.FolderStore, renamedFrom); err != nil {
if err := correctSubFolderHierarchy(ctx, s.Repository.Folder, renamedFrom); err != nil {
return nil, fmt.Errorf("correcting sub folder hierarchy for %q: %w", renamedFrom.Path, err)
}
@@ -626,7 +626,7 @@ func (s *scanJob) onExistingFolder(ctx context.Context, f scanFile, existing *mo
if update {
var err error
if err = s.Repository.FolderStore.Update(ctx, existing); err != nil {
if err = s.Repository.Folder.Update(ctx, existing); err != nil {
return nil, fmt.Errorf("updating folder %q: %w", f.Path, err)
}
}
@@ -647,7 +647,7 @@ func (s *scanJob) handleFile(ctx context.Context, f scanFile) error {
if err := s.withDB(ctx, func(ctx context.Context) error {
// determine if file already exists in data store
var err error
ff, err = s.Repository.FileStore.FindByPath(ctx, f.Path)
ff, err = s.Repository.File.FindByPath(ctx, f.Path)
if err != nil {
return fmt.Errorf("checking for existing file %q: %w", f.Path, err)
}
@@ -745,7 +745,7 @@ func (s *scanJob) onNewFile(ctx context.Context, f scanFile) (models.File, error
// if not renamed, queue file for creation
if err := s.withTxn(ctx, func(ctx context.Context) error {
if err := s.Repository.FileStore.Create(ctx, file); err != nil {
if err := s.Repository.File.Create(ctx, file); err != nil {
return fmt.Errorf("creating file %q: %w", path, err)
}
@@ -838,7 +838,7 @@ func (s *scanJob) handleRename(ctx context.Context, f models.File, fp []models.F
var others []models.File
for _, tfp := range fp {
thisOthers, err := s.Repository.FileStore.FindByFingerprint(ctx, tfp)
thisOthers, err := s.Repository.File.FindByFingerprint(ctx, tfp)
if err != nil {
return nil, fmt.Errorf("getting files by fingerprint %v: %w", tfp, err)
}
@@ -896,12 +896,12 @@ func (s *scanJob) handleRename(ctx context.Context, f models.File, fp []models.F
fBase.Fingerprints = otherBase.Fingerprints
if err := s.withTxn(ctx, func(ctx context.Context) error {
if err := s.Repository.FileStore.Update(ctx, f); err != nil {
if err := s.Repository.File.Update(ctx, f); err != nil {
return fmt.Errorf("updating file for rename %q: %w", fBase.Path, err)
}
if s.isZipFile(fBase.Basename) {
if err := TransferZipFolderHierarchy(ctx, s.Repository.FolderStore, fBase.ID, otherBase.Path, fBase.Path); err != nil {
if err := TransferZipFolderHierarchy(ctx, s.Repository.Folder, fBase.ID, otherBase.Path, fBase.Path); err != nil {
return fmt.Errorf("moving folder hierarchy for renamed zip file %q: %w", fBase.Path, err)
}
}
@@ -963,7 +963,7 @@ func (s *scanJob) setMissingMetadata(ctx context.Context, f scanFile, existing m
// queue file for update
if err := s.withTxn(ctx, func(ctx context.Context) error {
if err := s.Repository.FileStore.Update(ctx, existing); err != nil {
if err := s.Repository.File.Update(ctx, existing); err != nil {
return fmt.Errorf("updating file %q: %w", path, err)
}
@@ -986,7 +986,7 @@ func (s *scanJob) setMissingFingerprints(ctx context.Context, f scanFile, existi
existing.SetFingerprints(fp)
if err := s.withTxn(ctx, func(ctx context.Context) error {
if err := s.Repository.FileStore.Update(ctx, existing); err != nil {
if err := s.Repository.File.Update(ctx, existing); err != nil {
return fmt.Errorf("updating file %q: %w", f.Path, err)
}
@@ -1035,7 +1035,7 @@ func (s *scanJob) onExistingFile(ctx context.Context, f scanFile, existing model
// queue file for update
if err := s.withTxn(ctx, func(ctx context.Context) error {
if err := s.Repository.FileStore.Update(ctx, existing); err != nil {
if err := s.Repository.File.Update(ctx, existing); err != nil {
return fmt.Errorf("updating file %q: %w", path, err)
}

View File

@@ -157,19 +157,19 @@ var getStudioScenarios = []stringTestScenario{
}
func TestGetStudioName(t *testing.T) {
mockStudioReader := &mocks.StudioReaderWriter{}
db := mocks.NewDatabase()
studioErr := errors.New("error getting image")
mockStudioReader.On("Find", testCtx, studioID).Return(&models.Studio{
db.Studio.On("Find", testCtx, studioID).Return(&models.Studio{
Name: studioName,
}, nil).Once()
mockStudioReader.On("Find", testCtx, missingStudioID).Return(nil, nil).Once()
mockStudioReader.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once()
db.Studio.On("Find", testCtx, missingStudioID).Return(nil, nil).Once()
db.Studio.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once()
for i, s := range getStudioScenarios {
gallery := s.input
json, err := GetStudioName(testCtx, mockStudioReader, &gallery)
json, err := GetStudioName(testCtx, db.Studio, &gallery)
switch {
case !s.err && err != nil:
@@ -181,7 +181,7 @@ func TestGetStudioName(t *testing.T) {
}
}
mockStudioReader.AssertExpectations(t)
db.AssertExpectations(t)
}
const (
@@ -258,17 +258,17 @@ var validChapters = []*models.GalleryChapter{
}
func TestGetGalleryChaptersJSON(t *testing.T) {
mockChapterReader := &mocks.GalleryChapterReaderWriter{}
db := mocks.NewDatabase()
chaptersErr := errors.New("error getting gallery chapters")
mockChapterReader.On("FindByGalleryID", testCtx, galleryID).Return(validChapters, nil).Once()
mockChapterReader.On("FindByGalleryID", testCtx, noChaptersID).Return(nil, nil).Once()
mockChapterReader.On("FindByGalleryID", testCtx, errChaptersID).Return(nil, chaptersErr).Once()
db.GalleryChapter.On("FindByGalleryID", testCtx, galleryID).Return(validChapters, nil).Once()
db.GalleryChapter.On("FindByGalleryID", testCtx, noChaptersID).Return(nil, nil).Once()
db.GalleryChapter.On("FindByGalleryID", testCtx, errChaptersID).Return(nil, chaptersErr).Once()
for i, s := range getGalleryChaptersJSONScenarios {
gallery := s.input
json, err := GetGalleryChaptersJSON(testCtx, mockChapterReader, &gallery)
json, err := GetGalleryChaptersJSON(testCtx, db.GalleryChapter, &gallery)
switch {
case !s.err && err != nil:
@@ -280,4 +280,5 @@ func TestGetGalleryChaptersJSON(t *testing.T) {
}
}
db.AssertExpectations(t)
}

View File

@@ -78,19 +78,19 @@ func TestImporterPreImport(t *testing.T) {
}
func TestImporterPreImportWithStudio(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
StudioWriter: studioReaderWriter,
StudioWriter: db.Studio,
Input: jsonschema.Gallery{
Studio: existingStudioName,
},
}
studioReaderWriter.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
db.Studio.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
ID: existingStudioID,
}, nil).Once()
studioReaderWriter.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
db.Studio.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
err := i.PreImport(testCtx)
assert.Nil(t, err)
@@ -100,22 +100,22 @@ func TestImporterPreImportWithStudio(t *testing.T) {
err = i.PreImport(testCtx)
assert.NotNil(t, err)
studioReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingStudio(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
StudioWriter: studioReaderWriter,
StudioWriter: db.Studio,
Input: jsonschema.Gallery{
Studio: missingStudioName,
},
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.Studio)
s.ID = existingStudioID
}).Return(nil)
@@ -132,32 +132,34 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, existingStudioID, *i.gallery.StudioID)
studioReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
StudioWriter: studioReaderWriter,
StudioWriter: db.Studio,
Input: jsonschema.Gallery{
Studio: missingStudioName,
},
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)
db.AssertExpectations(t)
}
func TestImporterPreImportWithPerformer(t *testing.T) {
performerReaderWriter := &mocks.PerformerReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
PerformerWriter: performerReaderWriter,
PerformerWriter: db.Performer,
MissingRefBehaviour: models.ImportMissingRefEnumFail,
Input: jsonschema.Gallery{
Performers: []string{
@@ -166,13 +168,13 @@ func TestImporterPreImportWithPerformer(t *testing.T) {
},
}
performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{
db.Performer.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{
{
ID: existingPerformerID,
Name: existingPerformerName,
},
}, nil).Once()
performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once()
db.Performer.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once()
err := i.PreImport(testCtx)
assert.Nil(t, err)
@@ -182,14 +184,14 @@ func TestImporterPreImportWithPerformer(t *testing.T) {
err = i.PreImport(testCtx)
assert.NotNil(t, err)
performerReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingPerformer(t *testing.T) {
performerReaderWriter := &mocks.PerformerReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
PerformerWriter: performerReaderWriter,
PerformerWriter: db.Performer,
Input: jsonschema.Gallery{
Performers: []string{
missingPerformerName,
@@ -198,8 +200,8 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3)
performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) {
db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3)
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) {
performer := args.Get(1).(*models.Performer)
performer.ID = existingPerformerID
}).Return(nil)
@@ -216,14 +218,14 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, []int{existingPerformerID}, i.gallery.PerformerIDs.List())
performerReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
performerReaderWriter := &mocks.PerformerReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
PerformerWriter: performerReaderWriter,
PerformerWriter: db.Performer,
Input: jsonschema.Gallery{
Performers: []string{
missingPerformerName,
@@ -232,18 +234,20 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once()
performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error"))
db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once()
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)
db.AssertExpectations(t)
}
func TestImporterPreImportWithTag(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
TagWriter: tagReaderWriter,
TagWriter: db.Tag,
MissingRefBehaviour: models.ImportMissingRefEnumFail,
Input: jsonschema.Gallery{
Tags: []string{
@@ -252,13 +256,13 @@ func TestImporterPreImportWithTag(t *testing.T) {
},
}
tagReaderWriter.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
db.Tag.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
{
ID: existingTagID,
Name: existingTagName,
},
}, nil).Once()
tagReaderWriter.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
db.Tag.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
err := i.PreImport(testCtx)
assert.Nil(t, err)
@@ -268,14 +272,14 @@ func TestImporterPreImportWithTag(t *testing.T) {
err = i.PreImport(testCtx)
assert.NotNil(t, err)
tagReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingTag(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
TagWriter: tagReaderWriter,
TagWriter: db.Tag,
Input: jsonschema.Gallery{
Tags: []string{
missingTagName,
@@ -284,8 +288,8 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
t := args.Get(1).(*models.Tag)
t.ID = existingTagID
}).Return(nil)
@@ -302,14 +306,14 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, []int{existingTagID}, i.gallery.TagIDs.List())
tagReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
TagWriter: tagReaderWriter,
TagWriter: db.Tag,
Input: jsonschema.Gallery{
Tags: []string{
missingTagName,
@@ -318,9 +322,11 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)
db.AssertExpectations(t)
}

View File

@@ -130,19 +130,19 @@ var getStudioScenarios = []stringTestScenario{
}
func TestGetStudioName(t *testing.T) {
mockStudioReader := &mocks.StudioReaderWriter{}
db := mocks.NewDatabase()
studioErr := errors.New("error getting image")
mockStudioReader.On("Find", testCtx, studioID).Return(&models.Studio{
db.Studio.On("Find", testCtx, studioID).Return(&models.Studio{
Name: studioName,
}, nil).Once()
mockStudioReader.On("Find", testCtx, missingStudioID).Return(nil, nil).Once()
mockStudioReader.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once()
db.Studio.On("Find", testCtx, missingStudioID).Return(nil, nil).Once()
db.Studio.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once()
for i, s := range getStudioScenarios {
image := s.input
json, err := GetStudioName(testCtx, mockStudioReader, &image)
json, err := GetStudioName(testCtx, db.Studio, &image)
switch {
case !s.err && err != nil:
@@ -154,5 +154,5 @@ func TestGetStudioName(t *testing.T) {
}
}
mockStudioReader.AssertExpectations(t)
db.AssertExpectations(t)
}

View File

@@ -40,19 +40,19 @@ func TestImporterPreImport(t *testing.T) {
}
func TestImporterPreImportWithStudio(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
StudioWriter: studioReaderWriter,
StudioWriter: db.Studio,
Input: jsonschema.Image{
Studio: existingStudioName,
},
}
studioReaderWriter.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
db.Studio.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
ID: existingStudioID,
}, nil).Once()
studioReaderWriter.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
db.Studio.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
err := i.PreImport(testCtx)
assert.Nil(t, err)
@@ -62,22 +62,22 @@ func TestImporterPreImportWithStudio(t *testing.T) {
err = i.PreImport(testCtx)
assert.NotNil(t, err)
studioReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingStudio(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
StudioWriter: studioReaderWriter,
StudioWriter: db.Studio,
Input: jsonschema.Image{
Studio: missingStudioName,
},
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.Studio)
s.ID = existingStudioID
}).Return(nil)
@@ -94,32 +94,34 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, existingStudioID, *i.image.StudioID)
studioReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
StudioWriter: studioReaderWriter,
StudioWriter: db.Studio,
Input: jsonschema.Image{
Studio: missingStudioName,
},
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)
db.AssertExpectations(t)
}
func TestImporterPreImportWithPerformer(t *testing.T) {
performerReaderWriter := &mocks.PerformerReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
PerformerWriter: performerReaderWriter,
PerformerWriter: db.Performer,
MissingRefBehaviour: models.ImportMissingRefEnumFail,
Input: jsonschema.Image{
Performers: []string{
@@ -128,13 +130,13 @@ func TestImporterPreImportWithPerformer(t *testing.T) {
},
}
performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{
db.Performer.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{
{
ID: existingPerformerID,
Name: existingPerformerName,
},
}, nil).Once()
performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once()
db.Performer.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once()
err := i.PreImport(testCtx)
assert.Nil(t, err)
@@ -144,14 +146,14 @@ func TestImporterPreImportWithPerformer(t *testing.T) {
err = i.PreImport(testCtx)
assert.NotNil(t, err)
performerReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingPerformer(t *testing.T) {
performerReaderWriter := &mocks.PerformerReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
PerformerWriter: performerReaderWriter,
PerformerWriter: db.Performer,
Input: jsonschema.Image{
Performers: []string{
missingPerformerName,
@@ -160,8 +162,8 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3)
performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) {
db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3)
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) {
performer := args.Get(1).(*models.Performer)
performer.ID = existingPerformerID
}).Return(nil)
@@ -178,14 +180,14 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, []int{existingPerformerID}, i.image.PerformerIDs.List())
performerReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
performerReaderWriter := &mocks.PerformerReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
PerformerWriter: performerReaderWriter,
PerformerWriter: db.Performer,
Input: jsonschema.Image{
Performers: []string{
missingPerformerName,
@@ -194,18 +196,20 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once()
performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error"))
db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once()
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)
db.AssertExpectations(t)
}
func TestImporterPreImportWithTag(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
TagWriter: tagReaderWriter,
TagWriter: db.Tag,
MissingRefBehaviour: models.ImportMissingRefEnumFail,
Input: jsonschema.Image{
Tags: []string{
@@ -214,13 +218,13 @@ func TestImporterPreImportWithTag(t *testing.T) {
},
}
tagReaderWriter.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
db.Tag.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
{
ID: existingTagID,
Name: existingTagName,
},
}, nil).Once()
tagReaderWriter.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
db.Tag.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
err := i.PreImport(testCtx)
assert.Nil(t, err)
@@ -230,14 +234,14 @@ func TestImporterPreImportWithTag(t *testing.T) {
err = i.PreImport(testCtx)
assert.NotNil(t, err)
tagReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingTag(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
TagWriter: tagReaderWriter,
TagWriter: db.Tag,
Input: jsonschema.Image{
Tags: []string{
missingTagName,
@@ -246,8 +250,8 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
t := args.Get(1).(*models.Tag)
t.ID = existingTagID
}).Return(nil)
@@ -264,14 +268,14 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, []int{existingTagID}, i.image.TagIDs.List())
tagReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
TagWriter: tagReaderWriter,
TagWriter: db.Tag,
Input: jsonschema.Image{
Tags: []string{
missingTagName,
@@ -280,9 +284,11 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)
db.AssertExpectations(t)
}

View File

@@ -0,0 +1,107 @@
package mocks
import (
"context"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
"github.com/stretchr/testify/mock"
)
type Database struct {
File *FileReaderWriter
Folder *FolderReaderWriter
Gallery *GalleryReaderWriter
GalleryChapter *GalleryChapterReaderWriter
Image *ImageReaderWriter
Movie *MovieReaderWriter
Performer *PerformerReaderWriter
Scene *SceneReaderWriter
SceneMarker *SceneMarkerReaderWriter
Studio *StudioReaderWriter
Tag *TagReaderWriter
SavedFilter *SavedFilterReaderWriter
}
func (*Database) Begin(ctx context.Context, exclusive bool) (context.Context, error) {
return ctx, nil
}
func (*Database) WithDatabase(ctx context.Context) (context.Context, error) {
return ctx, nil
}
func (*Database) Commit(ctx context.Context) error {
return nil
}
func (*Database) Rollback(ctx context.Context) error {
return nil
}
func (*Database) Complete(ctx context.Context) {
}
func (*Database) AddPostCommitHook(ctx context.Context, hook txn.TxnFunc) {
}
func (*Database) AddPostRollbackHook(ctx context.Context, hook txn.TxnFunc) {
}
func (*Database) IsLocked(err error) bool {
return false
}
func (*Database) Reset() error {
return nil
}
func NewDatabase() *Database {
return &Database{
File: &FileReaderWriter{},
Folder: &FolderReaderWriter{},
Gallery: &GalleryReaderWriter{},
GalleryChapter: &GalleryChapterReaderWriter{},
Image: &ImageReaderWriter{},
Movie: &MovieReaderWriter{},
Performer: &PerformerReaderWriter{},
Scene: &SceneReaderWriter{},
SceneMarker: &SceneMarkerReaderWriter{},
Studio: &StudioReaderWriter{},
Tag: &TagReaderWriter{},
SavedFilter: &SavedFilterReaderWriter{},
}
}
func (db *Database) AssertExpectations(t mock.TestingT) {
db.File.AssertExpectations(t)
db.Folder.AssertExpectations(t)
db.Gallery.AssertExpectations(t)
db.GalleryChapter.AssertExpectations(t)
db.Image.AssertExpectations(t)
db.Movie.AssertExpectations(t)
db.Performer.AssertExpectations(t)
db.Scene.AssertExpectations(t)
db.SceneMarker.AssertExpectations(t)
db.Studio.AssertExpectations(t)
db.Tag.AssertExpectations(t)
db.SavedFilter.AssertExpectations(t)
}
func (db *Database) Repository() models.Repository {
return models.Repository{
TxnManager: db,
File: db.File,
Folder: db.Folder,
Gallery: db.Gallery,
GalleryChapter: db.GalleryChapter,
Image: db.Image,
Movie: db.Movie,
Performer: db.Performer,
Scene: db.Scene,
SceneMarker: db.SceneMarker,
Studio: db.Studio,
Tag: db.Tag,
SavedFilter: db.SavedFilter,
}
}
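The new mocks.Database bundles one mock per store plus a no-op transaction manager, so tests build a single db, register expectations on whichever store they need, and verify everything with one AssertExpectations call. A minimal sketch of that pattern, assuming the pkg/models/mocks import path; the test name, the studio ID, and the direct db.Studio.Find call are invented for illustration:

```go
// Sketch of the test pattern introduced in this PR.
package example

import (
	"context"
	"testing"

	"github.com/stashapp/stash/pkg/models"
	"github.com/stashapp/stash/pkg/models/mocks"
)

func TestDatabaseMockSketch(t *testing.T) {
	db := mocks.NewDatabase()
	ctx := context.Background()

	// Expectations are declared on the per-store mocks hanging off Database.
	db.Studio.On("Find", ctx, 1).Return(&models.Studio{Name: "studio"}, nil).Once()

	// The code under test receives only the store it needs; here we call the
	// mock directly to keep the sketch self-contained.
	s, err := db.Studio.Find(ctx, 1)
	if err != nil || s == nil || s.Name != "studio" {
		t.Errorf("unexpected result: %v, %v", s, err)
	}

	// One call verifies the expectations of every store at once.
	db.AssertExpectations(t)
}
```

Call sites that previously built a models.Repository with mocks.NewTxnRepository() (the file deleted below) can obtain the equivalent via mocks.NewDatabase().Repository(), which keeps the individual store mocks reachable for setting expectations.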

View File

@@ -1,59 +0,0 @@
package mocks
import (
context "context"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
)
type TxnManager struct{}
func (*TxnManager) Begin(ctx context.Context, exclusive bool) (context.Context, error) {
return ctx, nil
}
func (*TxnManager) WithDatabase(ctx context.Context) (context.Context, error) {
return ctx, nil
}
func (*TxnManager) Commit(ctx context.Context) error {
return nil
}
func (*TxnManager) Rollback(ctx context.Context) error {
return nil
}
func (*TxnManager) Complete(ctx context.Context) {
}
func (*TxnManager) AddPostCommitHook(ctx context.Context, hook txn.TxnFunc) {
}
func (*TxnManager) AddPostRollbackHook(ctx context.Context, hook txn.TxnFunc) {
}
func (*TxnManager) IsLocked(err error) bool {
return false
}
func (*TxnManager) Reset() error {
return nil
}
func NewTxnRepository() models.Repository {
return models.Repository{
TxnManager: &TxnManager{},
Gallery: &GalleryReaderWriter{},
GalleryChapter: &GalleryChapterReaderWriter{},
Image: &ImageReaderWriter{},
Movie: &MovieReaderWriter{},
Performer: &PerformerReaderWriter{},
Scene: &SceneReaderWriter{},
SceneMarker: &SceneMarkerReaderWriter{},
Studio: &StudioReaderWriter{},
Tag: &TagReaderWriter{},
SavedFilter: &SavedFilterReaderWriter{},
}
}

View File

@@ -49,5 +49,3 @@ func NewMoviePartial() MoviePartial {
UpdatedAt: NewOptionalTime(currentTime),
}
}
var DefaultMovieImage = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAYAAABw4pVUAAAABmJLR0QA/wD/AP+gvaeTAAAACXBIWXMAAA3XAAAN1wFCKJt4AAAAB3RJTUUH4wgVBQsJl1CMZAAAASJJREFUeNrt3N0JwyAYhlEj3cj9R3Cm5rbkqtAP+qrnGaCYHPwJpLlaa++mmLpbAERAgAgIEAEBIiBABERAgAgIEAEBIiBABERAgAgIEAHZuVflj40x4i94zhk9vqsVvEq6AsQqMP1EjORx20OACAgQRRx7T+zzcFBxcjNDfoB4ntQqTm5Awo7MlqywZxcgYQ+RlqywJ3ozJAQCSBiEJSsQA0gYBpDAgAARECACAkRAgAgIEAERECACAmSjUv6eAOSB8m8YIGGzBUjYbAESBgMkbBkDEjZbgITBAClcxiqQvEoatreYIWEBASIgJ4Gkf11ntXH3nS9uxfGWfJ5J9hAgAgJEQAQEiIAAERAgAgJEQAQEiIAAERAgAgJEQAQEiL7qBuc6RKLHxr0CAAAAAElFTkSuQmCC"

View File

@@ -1,17 +1,18 @@
package models
import (
"context"
"github.com/stashapp/stash/pkg/txn"
)
type TxnManager interface {
txn.Manager
txn.DatabaseProvider
Reset() error
}
type Repository struct {
TxnManager
TxnManager TxnManager
File FileReaderWriter
Folder FolderReaderWriter
@@ -26,3 +27,15 @@ type Repository struct {
Tag TagReaderWriter
SavedFilter SavedFilterReaderWriter
}
func (r *Repository) WithTxn(ctx context.Context, fn txn.TxnFunc) error {
return txn.WithTxn(ctx, r.TxnManager, fn)
}
func (r *Repository) WithReadTxn(ctx context.Context, fn txn.TxnFunc) error {
return txn.WithReadTxn(ctx, r.TxnManager, fn)
}
func (r *Repository) WithDB(ctx context.Context, fn txn.TxnFunc) error {
return txn.WithDatabase(ctx, r.TxnManager, fn)
}
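models.Repository keeps the same stores but now holds its TxnManager as a named field and gains WithTxn, WithReadTxn, and WithDB wrappers, so callers that previously leaned on the embedded txn.Manager (txn.WithTxn(ctx, repo, fn)) call the methods directly. A sketch under those assumptions; countExisting and its path handling are hypothetical, while WithReadTxn and File.FindByPath are used as elsewhere in this diff:

```go
// Hypothetical caller showing the new convenience methods on models.Repository.
package example

import (
	"context"

	"github.com/stashapp/stash/pkg/models"
)

func countExisting(ctx context.Context, repo models.Repository, paths []string) (int, error) {
	var n int
	// WithTxn/WithReadTxn/WithDB wrap pkg/txn around the repository's own
	// TxnManager, so callers no longer need to import pkg/txn for this.
	err := repo.WithReadTxn(ctx, func(ctx context.Context) error {
		for _, p := range paths {
			f, err := repo.File.FindByPath(ctx, p)
			if err != nil {
				return err
			}
			if f != nil {
				n++
			}
		}
		return nil
	})
	return n, err
}
```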

View File

@@ -168,34 +168,32 @@ func initTestTable() {
func TestToJSON(t *testing.T) {
initTestTable()
mockMovieReader := &mocks.MovieReaderWriter{}
db := mocks.NewDatabase()
imageErr := errors.New("error getting image")
mockMovieReader.On("GetFrontImage", testCtx, movieID).Return(frontImageBytes, nil).Once()
mockMovieReader.On("GetFrontImage", testCtx, missingStudioMovieID).Return(frontImageBytes, nil).Once()
mockMovieReader.On("GetFrontImage", testCtx, emptyID).Return(nil, nil).Once().Maybe()
mockMovieReader.On("GetFrontImage", testCtx, errFrontImageID).Return(nil, imageErr).Once()
mockMovieReader.On("GetFrontImage", testCtx, errBackImageID).Return(frontImageBytes, nil).Once()
db.Movie.On("GetFrontImage", testCtx, movieID).Return(frontImageBytes, nil).Once()
db.Movie.On("GetFrontImage", testCtx, missingStudioMovieID).Return(frontImageBytes, nil).Once()
db.Movie.On("GetFrontImage", testCtx, emptyID).Return(nil, nil).Once().Maybe()
db.Movie.On("GetFrontImage", testCtx, errFrontImageID).Return(nil, imageErr).Once()
db.Movie.On("GetFrontImage", testCtx, errBackImageID).Return(frontImageBytes, nil).Once()
mockMovieReader.On("GetBackImage", testCtx, movieID).Return(backImageBytes, nil).Once()
mockMovieReader.On("GetBackImage", testCtx, missingStudioMovieID).Return(backImageBytes, nil).Once()
mockMovieReader.On("GetBackImage", testCtx, emptyID).Return(nil, nil).Once()
mockMovieReader.On("GetBackImage", testCtx, errBackImageID).Return(nil, imageErr).Once()
mockMovieReader.On("GetBackImage", testCtx, errFrontImageID).Return(backImageBytes, nil).Maybe()
mockMovieReader.On("GetBackImage", testCtx, errStudioMovieID).Return(backImageBytes, nil).Maybe()
mockStudioReader := &mocks.StudioReaderWriter{}
db.Movie.On("GetBackImage", testCtx, movieID).Return(backImageBytes, nil).Once()
db.Movie.On("GetBackImage", testCtx, missingStudioMovieID).Return(backImageBytes, nil).Once()
db.Movie.On("GetBackImage", testCtx, emptyID).Return(nil, nil).Once()
db.Movie.On("GetBackImage", testCtx, errBackImageID).Return(nil, imageErr).Once()
db.Movie.On("GetBackImage", testCtx, errFrontImageID).Return(backImageBytes, nil).Maybe()
db.Movie.On("GetBackImage", testCtx, errStudioMovieID).Return(backImageBytes, nil).Maybe()
studioErr := errors.New("error getting studio")
mockStudioReader.On("Find", testCtx, studioID).Return(&movieStudio, nil)
mockStudioReader.On("Find", testCtx, missingStudioID).Return(nil, nil)
mockStudioReader.On("Find", testCtx, errStudioID).Return(nil, studioErr)
db.Studio.On("Find", testCtx, studioID).Return(&movieStudio, nil)
db.Studio.On("Find", testCtx, missingStudioID).Return(nil, nil)
db.Studio.On("Find", testCtx, errStudioID).Return(nil, studioErr)
for i, s := range scenarios {
movie := s.movie
json, err := ToJSON(testCtx, mockMovieReader, mockStudioReader, &movie)
json, err := ToJSON(testCtx, db.Movie, db.Studio, &movie)
switch {
case !s.err && err != nil:
@@ -207,6 +205,5 @@ func TestToJSON(t *testing.T) {
}
}
mockMovieReader.AssertExpectations(t)
mockStudioReader.AssertExpectations(t)
db.AssertExpectations(t)
}

View File

@@ -69,10 +69,11 @@ func TestImporterPreImport(t *testing.T) {
}
func TestImporterPreImportWithStudio(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
StudioWriter: studioReaderWriter,
ReaderWriter: db.Movie,
StudioWriter: db.Studio,
Input: jsonschema.Movie{
Name: movieName,
FrontImage: frontImage,
@@ -82,10 +83,10 @@ func TestImporterPreImportWithStudio(t *testing.T) {
},
}
studioReaderWriter.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
db.Studio.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
ID: existingStudioID,
}, nil).Once()
studioReaderWriter.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
db.Studio.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
err := i.PreImport(testCtx)
assert.Nil(t, err)
@@ -95,14 +96,15 @@ func TestImporterPreImportWithStudio(t *testing.T) {
err = i.PreImport(testCtx)
assert.NotNil(t, err)
studioReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingStudio(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
StudioWriter: studioReaderWriter,
ReaderWriter: db.Movie,
StudioWriter: db.Studio,
Input: jsonschema.Movie{
Name: movieName,
FrontImage: frontImage,
@@ -111,8 +113,8 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.Studio)
s.ID = existingStudioID
}).Return(nil)
@@ -129,14 +131,15 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, existingStudioID, *i.movie.StudioID)
studioReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
StudioWriter: studioReaderWriter,
ReaderWriter: db.Movie,
StudioWriter: db.Studio,
Input: jsonschema.Movie{
Name: movieName,
FrontImage: frontImage,
@@ -145,27 +148,30 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)
db.AssertExpectations(t)
}
func TestImporterPostImport(t *testing.T) {
readerWriter := &mocks.MovieReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Movie,
StudioWriter: db.Studio,
frontImageData: frontImageBytes,
backImageData: backImageBytes,
}
updateMovieImageErr := errors.New("UpdateImages error")
readerWriter.On("UpdateFrontImage", testCtx, movieID, frontImageBytes).Return(nil).Once()
readerWriter.On("UpdateBackImage", testCtx, movieID, backImageBytes).Return(nil).Once()
readerWriter.On("UpdateFrontImage", testCtx, errImageID, frontImageBytes).Return(updateMovieImageErr).Once()
db.Movie.On("UpdateFrontImage", testCtx, movieID, frontImageBytes).Return(nil).Once()
db.Movie.On("UpdateBackImage", testCtx, movieID, backImageBytes).Return(nil).Once()
db.Movie.On("UpdateFrontImage", testCtx, errImageID, frontImageBytes).Return(updateMovieImageErr).Once()
err := i.PostImport(testCtx, movieID)
assert.Nil(t, err)
@@ -173,25 +179,26 @@ func TestImporterPostImport(t *testing.T) {
err = i.PostImport(testCtx, errImageID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterFindExistingID(t *testing.T) {
readerWriter := &mocks.MovieReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Movie,
StudioWriter: db.Studio,
Input: jsonschema.Movie{
Name: movieName,
},
}
errFindByName := errors.New("FindByName error")
readerWriter.On("FindByName", testCtx, movieName, false).Return(nil, nil).Once()
readerWriter.On("FindByName", testCtx, existingMovieName, false).Return(&models.Movie{
db.Movie.On("FindByName", testCtx, movieName, false).Return(nil, nil).Once()
db.Movie.On("FindByName", testCtx, existingMovieName, false).Return(&models.Movie{
ID: existingMovieID,
}, nil).Once()
readerWriter.On("FindByName", testCtx, movieNameErr, false).Return(nil, errFindByName).Once()
db.Movie.On("FindByName", testCtx, movieNameErr, false).Return(nil, errFindByName).Once()
id, err := i.FindExistingID(testCtx)
assert.Nil(t, id)
@@ -207,11 +214,11 @@ func TestImporterFindExistingID(t *testing.T) {
assert.Nil(t, id)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestCreate(t *testing.T) {
readerWriter := &mocks.MovieReaderWriter{}
db := mocks.NewDatabase()
movie := models.Movie{
Name: movieName,
@@ -222,16 +229,17 @@ func TestCreate(t *testing.T) {
}
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Movie,
StudioWriter: db.Studio,
movie: movie,
}
errCreate := errors.New("Create error")
readerWriter.On("Create", testCtx, &movie).Run(func(args mock.Arguments) {
db.Movie.On("Create", testCtx, &movie).Run(func(args mock.Arguments) {
m := args.Get(1).(*models.Movie)
m.ID = movieID
}).Return(nil).Once()
readerWriter.On("Create", testCtx, &movieErr).Return(errCreate).Once()
db.Movie.On("Create", testCtx, &movieErr).Return(errCreate).Once()
id, err := i.Create(testCtx)
assert.Equal(t, movieID, *id)
@@ -242,11 +250,11 @@ func TestCreate(t *testing.T) {
assert.Nil(t, id)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestUpdate(t *testing.T) {
readerWriter := &mocks.MovieReaderWriter{}
db := mocks.NewDatabase()
movie := models.Movie{
Name: movieName,
@@ -257,7 +265,8 @@ func TestUpdate(t *testing.T) {
}
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Movie,
StudioWriter: db.Studio,
movie: movie,
}
@@ -265,7 +274,7 @@ func TestUpdate(t *testing.T) {
// id needs to be set for the mock input
movie.ID = movieID
readerWriter.On("Update", testCtx, &movie).Return(nil).Once()
db.Movie.On("Update", testCtx, &movie).Return(nil).Once()
err := i.Update(testCtx, movieID)
assert.Nil(t, err)
@@ -274,10 +283,10 @@ func TestUpdate(t *testing.T) {
// need to set id separately
movieErr.ID = errImageID
readerWriter.On("Update", testCtx, &movieErr).Return(errUpdate).Once()
db.Movie.On("Update", testCtx, &movieErr).Return(errUpdate).Once()
err = i.Update(testCtx, errImageID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}

View File

@@ -203,17 +203,17 @@ func initTestTable() {
func TestToJSON(t *testing.T) {
initTestTable()
mockPerformerReader := &mocks.PerformerReaderWriter{}
db := mocks.NewDatabase()
imageErr := errors.New("error getting image")
mockPerformerReader.On("GetImage", testCtx, performerID).Return(imageBytes, nil).Once()
mockPerformerReader.On("GetImage", testCtx, noImageID).Return(nil, nil).Once()
mockPerformerReader.On("GetImage", testCtx, errImageID).Return(nil, imageErr).Once()
db.Performer.On("GetImage", testCtx, performerID).Return(imageBytes, nil).Once()
db.Performer.On("GetImage", testCtx, noImageID).Return(nil, nil).Once()
db.Performer.On("GetImage", testCtx, errImageID).Return(nil, imageErr).Once()
for i, s := range scenarios {
tag := s.input
json, err := ToJSON(testCtx, mockPerformerReader, &tag)
json, err := ToJSON(testCtx, db.Performer, &tag)
switch {
case !s.err && err != nil:
@@ -225,5 +225,5 @@ func TestToJSON(t *testing.T) {
}
}
mockPerformerReader.AssertExpectations(t)
db.AssertExpectations(t)
}

View File

@@ -63,10 +63,11 @@ func TestImporterPreImport(t *testing.T) {
}
func TestImporterPreImportWithTag(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
TagWriter: tagReaderWriter,
ReaderWriter: db.Performer,
TagWriter: db.Tag,
MissingRefBehaviour: models.ImportMissingRefEnumFail,
Input: jsonschema.Performer{
Tags: []string{
@@ -75,13 +76,13 @@ func TestImporterPreImportWithTag(t *testing.T) {
},
}
tagReaderWriter.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
db.Tag.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
{
ID: existingTagID,
Name: existingTagName,
},
}, nil).Once()
tagReaderWriter.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
db.Tag.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
err := i.PreImport(testCtx)
assert.Nil(t, err)
@@ -91,14 +92,15 @@ func TestImporterPreImportWithTag(t *testing.T) {
err = i.PreImport(testCtx)
assert.NotNil(t, err)
tagReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingTag(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
TagWriter: tagReaderWriter,
ReaderWriter: db.Performer,
TagWriter: db.Tag,
Input: jsonschema.Performer{
Tags: []string{
missingTagName,
@@ -107,8 +109,8 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
t := args.Get(1).(*models.Tag)
t.ID = existingTagID
}).Return(nil)
@@ -125,14 +127,15 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, existingTagID, i.performer.TagIDs.List()[0])
tagReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
TagWriter: tagReaderWriter,
ReaderWriter: db.Performer,
TagWriter: db.Tag,
Input: jsonschema.Performer{
Tags: []string{
missingTagName,
@@ -141,25 +144,28 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)
db.AssertExpectations(t)
}
func TestImporterPostImport(t *testing.T) {
readerWriter := &mocks.PerformerReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Performer,
TagWriter: db.Tag,
imageData: imageBytes,
}
updatePerformerImageErr := errors.New("UpdateImage error")
readerWriter.On("UpdateImage", testCtx, performerID, imageBytes).Return(nil).Once()
readerWriter.On("UpdateImage", testCtx, errImageID, imageBytes).Return(updatePerformerImageErr).Once()
db.Performer.On("UpdateImage", testCtx, performerID, imageBytes).Return(nil).Once()
db.Performer.On("UpdateImage", testCtx, errImageID, imageBytes).Return(updatePerformerImageErr).Once()
err := i.PostImport(testCtx, performerID)
assert.Nil(t, err)
@@ -167,14 +173,15 @@ func TestImporterPostImport(t *testing.T) {
err = i.PostImport(testCtx, errImageID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterFindExistingID(t *testing.T) {
readerWriter := &mocks.PerformerReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Performer,
TagWriter: db.Tag,
Input: jsonschema.Performer{
Name: performerName,
},
@@ -195,13 +202,13 @@ func TestImporterFindExistingID(t *testing.T) {
}
errFindByNames := errors.New("FindByNames error")
readerWriter.On("Query", testCtx, performerFilter(performerName), findFilter).Return(nil, 0, nil).Once()
readerWriter.On("Query", testCtx, performerFilter(existingPerformerName), findFilter).Return([]*models.Performer{
db.Performer.On("Query", testCtx, performerFilter(performerName), findFilter).Return(nil, 0, nil).Once()
db.Performer.On("Query", testCtx, performerFilter(existingPerformerName), findFilter).Return([]*models.Performer{
{
ID: existingPerformerID,
},
}, 1, nil).Once()
readerWriter.On("Query", testCtx, performerFilter(performerNameErr), findFilter).Return(nil, 0, errFindByNames).Once()
db.Performer.On("Query", testCtx, performerFilter(performerNameErr), findFilter).Return(nil, 0, errFindByNames).Once()
id, err := i.FindExistingID(testCtx)
assert.Nil(t, id)
@@ -217,11 +224,11 @@ func TestImporterFindExistingID(t *testing.T) {
assert.Nil(t, id)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestCreate(t *testing.T) {
readerWriter := &mocks.PerformerReaderWriter{}
db := mocks.NewDatabase()
performer := models.Performer{
Name: performerName,
@@ -232,16 +239,17 @@ func TestCreate(t *testing.T) {
}
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Performer,
TagWriter: db.Tag,
performer: performer,
}
errCreate := errors.New("Create error")
readerWriter.On("Create", testCtx, &performer).Run(func(args mock.Arguments) {
db.Performer.On("Create", testCtx, &performer).Run(func(args mock.Arguments) {
arg := args.Get(1).(*models.Performer)
arg.ID = performerID
}).Return(nil).Once()
readerWriter.On("Create", testCtx, &performerErr).Return(errCreate).Once()
db.Performer.On("Create", testCtx, &performerErr).Return(errCreate).Once()
id, err := i.Create(testCtx)
assert.Equal(t, performerID, *id)
@@ -252,11 +260,11 @@ func TestCreate(t *testing.T) {
assert.Nil(t, id)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestUpdate(t *testing.T) {
readerWriter := &mocks.PerformerReaderWriter{}
db := mocks.NewDatabase()
performer := models.Performer{
Name: performerName,
@@ -267,7 +275,8 @@ func TestUpdate(t *testing.T) {
}
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Performer,
TagWriter: db.Tag,
performer: performer,
}
@@ -275,7 +284,7 @@ func TestUpdate(t *testing.T) {
// id needs to be set for the mock input
performer.ID = performerID
readerWriter.On("Update", testCtx, &performer).Return(nil).Once()
db.Performer.On("Update", testCtx, &performer).Return(nil).Once()
err := i.Update(testCtx, performerID)
assert.Nil(t, err)
@@ -284,10 +293,10 @@ func TestUpdate(t *testing.T) {
// need to set id separately
performerErr.ID = errImageID
readerWriter.On("Update", testCtx, &performerErr).Return(errUpdate).Once()
db.Performer.On("Update", testCtx, &performerErr).Return(errUpdate).Once()
err = i.Update(testCtx, errImageID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
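
Taken together, these hunks converge on one test setup: a single mocks.NewDatabase() per test, expectations registered on the individual store mocks, and one db.AssertExpectations(t) at the end instead of per-mock assertions. A minimal sketch of the pattern, assuming the test file's existing imports and testCtx fixture (the test name and fixture values below are illustrative, not part of the commit):

func TestImporterCreateSketch(t *testing.T) {
	db := mocks.NewDatabase()

	i := Importer{
		ReaderWriter: db.Performer, // the store mocks satisfy the importer's narrow interfaces
		TagWriter:    db.Tag,
		performer:    models.Performer{Name: "sketch performer"}, // hypothetical fixture
	}

	// Expectations are registered on the per-store mocks...
	db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).
		Run(func(args mock.Arguments) {
			args.Get(1).(*models.Performer).ID = 1 // simulate the ID assigned on insert
		}).
		Return(nil).Once()

	id, err := i.Create(testCtx)
	assert.Nil(t, err)
	assert.Equal(t, 1, *id)

	// ...and verified once for the whole mock database.
	db.AssertExpectations(t)
}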

View File

@@ -186,17 +186,17 @@ var scenarios = []basicTestScenario{
}
func TestToJSON(t *testing.T) {
mockSceneReader := &mocks.SceneReaderWriter{}
db := mocks.NewDatabase()
imageErr := errors.New("error getting image")
mockSceneReader.On("GetCover", testCtx, sceneID).Return(imageBytes, nil).Once()
mockSceneReader.On("GetCover", testCtx, noImageID).Return(nil, nil).Once()
mockSceneReader.On("GetCover", testCtx, errImageID).Return(nil, imageErr).Once()
db.Scene.On("GetCover", testCtx, sceneID).Return(imageBytes, nil).Once()
db.Scene.On("GetCover", testCtx, noImageID).Return(nil, nil).Once()
db.Scene.On("GetCover", testCtx, errImageID).Return(nil, imageErr).Once()
for i, s := range scenarios {
scene := s.input
json, err := ToBasicJSON(testCtx, mockSceneReader, &scene)
json, err := ToBasicJSON(testCtx, db.Scene, &scene)
switch {
case !s.err && err != nil:
@@ -208,7 +208,7 @@ func TestToJSON(t *testing.T) {
}
}
mockSceneReader.AssertExpectations(t)
db.AssertExpectations(t)
}
func createStudioScene(studioID int) models.Scene {
@@ -242,19 +242,19 @@ var getStudioScenarios = []stringTestScenario{
}
func TestGetStudioName(t *testing.T) {
mockStudioReader := &mocks.StudioReaderWriter{}
db := mocks.NewDatabase()
studioErr := errors.New("error getting image")
mockStudioReader.On("Find", testCtx, studioID).Return(&models.Studio{
db.Studio.On("Find", testCtx, studioID).Return(&models.Studio{
Name: studioName,
}, nil).Once()
mockStudioReader.On("Find", testCtx, missingStudioID).Return(nil, nil).Once()
mockStudioReader.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once()
db.Studio.On("Find", testCtx, missingStudioID).Return(nil, nil).Once()
db.Studio.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once()
for i, s := range getStudioScenarios {
scene := s.input
json, err := GetStudioName(testCtx, mockStudioReader, &scene)
json, err := GetStudioName(testCtx, db.Studio, &scene)
switch {
case !s.err && err != nil:
@@ -266,7 +266,7 @@ func TestGetStudioName(t *testing.T) {
}
}
mockStudioReader.AssertExpectations(t)
db.AssertExpectations(t)
}
type stringSliceTestScenario struct {
@@ -305,17 +305,17 @@ func getTags(names []string) []*models.Tag {
}
func TestGetTagNames(t *testing.T) {
mockTagReader := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
tagErr := errors.New("error getting tag")
mockTagReader.On("FindBySceneID", testCtx, sceneID).Return(getTags(names), nil).Once()
mockTagReader.On("FindBySceneID", testCtx, noTagsID).Return(nil, nil).Once()
mockTagReader.On("FindBySceneID", testCtx, errTagsID).Return(nil, tagErr).Once()
db.Tag.On("FindBySceneID", testCtx, sceneID).Return(getTags(names), nil).Once()
db.Tag.On("FindBySceneID", testCtx, noTagsID).Return(nil, nil).Once()
db.Tag.On("FindBySceneID", testCtx, errTagsID).Return(nil, tagErr).Once()
for i, s := range getTagNamesScenarios {
scene := s.input
json, err := GetTagNames(testCtx, mockTagReader, &scene)
json, err := GetTagNames(testCtx, db.Tag, &scene)
switch {
case !s.err && err != nil:
@@ -327,7 +327,7 @@ func TestGetTagNames(t *testing.T) {
}
}
mockTagReader.AssertExpectations(t)
db.AssertExpectations(t)
}
type sceneMoviesTestScenario struct {
@@ -391,20 +391,21 @@ var getSceneMoviesJSONScenarios = []sceneMoviesTestScenario{
}
func TestGetSceneMoviesJSON(t *testing.T) {
mockMovieReader := &mocks.MovieReaderWriter{}
db := mocks.NewDatabase()
movieErr := errors.New("error getting movie")
mockMovieReader.On("Find", testCtx, validMovie1).Return(&models.Movie{
db.Movie.On("Find", testCtx, validMovie1).Return(&models.Movie{
Name: movie1Name,
}, nil).Once()
mockMovieReader.On("Find", testCtx, validMovie2).Return(&models.Movie{
db.Movie.On("Find", testCtx, validMovie2).Return(&models.Movie{
Name: movie2Name,
}, nil).Once()
mockMovieReader.On("Find", testCtx, invalidMovie).Return(nil, movieErr).Once()
db.Movie.On("Find", testCtx, invalidMovie).Return(nil, movieErr).Once()
for i, s := range getSceneMoviesJSONScenarios {
scene := s.input
json, err := GetSceneMoviesJSON(testCtx, mockMovieReader, &scene)
json, err := GetSceneMoviesJSON(testCtx, db.Movie, &scene)
switch {
case !s.err && err != nil:
@@ -416,7 +417,7 @@ func TestGetSceneMoviesJSON(t *testing.T) {
}
}
mockMovieReader.AssertExpectations(t)
db.AssertExpectations(t)
}
const (
@@ -542,27 +543,26 @@ var invalidMarkers2 = []*models.SceneMarker{
}
func TestGetSceneMarkersJSON(t *testing.T) {
mockTagReader := &mocks.TagReaderWriter{}
mockMarkerReader := &mocks.SceneMarkerReaderWriter{}
db := mocks.NewDatabase()
markersErr := errors.New("error getting scene markers")
tagErr := errors.New("error getting tags")
mockMarkerReader.On("FindBySceneID", testCtx, sceneID).Return(validMarkers, nil).Once()
mockMarkerReader.On("FindBySceneID", testCtx, noMarkersID).Return(nil, nil).Once()
mockMarkerReader.On("FindBySceneID", testCtx, errMarkersID).Return(nil, markersErr).Once()
mockMarkerReader.On("FindBySceneID", testCtx, errFindPrimaryTagID).Return(invalidMarkers1, nil).Once()
mockMarkerReader.On("FindBySceneID", testCtx, errFindByMarkerID).Return(invalidMarkers2, nil).Once()
db.SceneMarker.On("FindBySceneID", testCtx, sceneID).Return(validMarkers, nil).Once()
db.SceneMarker.On("FindBySceneID", testCtx, noMarkersID).Return(nil, nil).Once()
db.SceneMarker.On("FindBySceneID", testCtx, errMarkersID).Return(nil, markersErr).Once()
db.SceneMarker.On("FindBySceneID", testCtx, errFindPrimaryTagID).Return(invalidMarkers1, nil).Once()
db.SceneMarker.On("FindBySceneID", testCtx, errFindByMarkerID).Return(invalidMarkers2, nil).Once()
mockTagReader.On("Find", testCtx, validTagID1).Return(&models.Tag{
db.Tag.On("Find", testCtx, validTagID1).Return(&models.Tag{
Name: validTagName1,
}, nil)
mockTagReader.On("Find", testCtx, validTagID2).Return(&models.Tag{
db.Tag.On("Find", testCtx, validTagID2).Return(&models.Tag{
Name: validTagName2,
}, nil)
mockTagReader.On("Find", testCtx, invalidTagID).Return(nil, tagErr)
db.Tag.On("Find", testCtx, invalidTagID).Return(nil, tagErr)
mockTagReader.On("FindBySceneMarkerID", testCtx, validMarkerID1).Return([]*models.Tag{
db.Tag.On("FindBySceneMarkerID", testCtx, validMarkerID1).Return([]*models.Tag{
{
Name: validTagName1,
},
@@ -570,16 +570,16 @@ func TestGetSceneMarkersJSON(t *testing.T) {
Name: validTagName2,
},
}, nil)
mockTagReader.On("FindBySceneMarkerID", testCtx, validMarkerID2).Return([]*models.Tag{
db.Tag.On("FindBySceneMarkerID", testCtx, validMarkerID2).Return([]*models.Tag{
{
Name: validTagName2,
},
}, nil)
mockTagReader.On("FindBySceneMarkerID", testCtx, invalidMarkerID2).Return(nil, tagErr).Once()
db.Tag.On("FindBySceneMarkerID", testCtx, invalidMarkerID2).Return(nil, tagErr).Once()
for i, s := range getSceneMarkersJSONScenarios {
scene := s.input
json, err := GetSceneMarkersJSON(testCtx, mockMarkerReader, mockTagReader, &scene)
json, err := GetSceneMarkersJSON(testCtx, db.SceneMarker, db.Tag, &scene)
switch {
case !s.err && err != nil:
@@ -591,5 +591,5 @@ func TestGetSceneMarkersJSON(t *testing.T) {
}
}
mockTagReader.AssertExpectations(t)
db.AssertExpectations(t)
}

View File

@@ -410,17 +410,19 @@ type FilenameParser struct {
ParserInput models.SceneParserInput
Filter *models.FindFilterType
whitespaceRE *regexp.Regexp
repository FilenameParserRepository
performerCache map[string]*models.Performer
studioCache map[string]*models.Studio
movieCache map[string]*models.Movie
tagCache map[string]*models.Tag
}
func NewFilenameParser(filter *models.FindFilterType, config models.SceneParserInput) *FilenameParser {
func NewFilenameParser(filter *models.FindFilterType, config models.SceneParserInput, repo FilenameParserRepository) *FilenameParser {
p := &FilenameParser{
Pattern: *filter.Q,
ParserInput: config,
Filter: filter,
repository: repo,
}
p.performerCache = make(map[string]*models.Performer)
@@ -457,7 +459,17 @@ type FilenameParserRepository struct {
Tag models.TagQueryer
}
func (p *FilenameParser) Parse(ctx context.Context, repo FilenameParserRepository) ([]*models.SceneParserResult, int, error) {
func NewFilenameParserRepository(repo models.Repository) FilenameParserRepository {
return FilenameParserRepository{
Scene: repo.Scene,
Performer: repo.Performer,
Studio: repo.Studio,
Movie: repo.Movie,
Tag: repo.Tag,
}
}
func (p *FilenameParser) Parse(ctx context.Context) ([]*models.SceneParserResult, int, error) {
// perform the query to find the scenes
mapper, err := newParseMapper(p.Pattern, p.ParserInput.IgnoreWords)
@@ -479,17 +491,17 @@ func (p *FilenameParser) Parse(ctx context.Context, repo FilenameParserRepositor
p.Filter.Q = nil
scenes, total, err := QueryWithCount(ctx, repo.Scene, sceneFilter, p.Filter)
scenes, total, err := QueryWithCount(ctx, p.repository.Scene, sceneFilter, p.Filter)
if err != nil {
return nil, 0, err
}
ret := p.parseScenes(ctx, repo, scenes, mapper)
ret := p.parseScenes(ctx, scenes, mapper)
return ret, total, nil
}
func (p *FilenameParser) parseScenes(ctx context.Context, repo FilenameParserRepository, scenes []*models.Scene, mapper *parseMapper) []*models.SceneParserResult {
func (p *FilenameParser) parseScenes(ctx context.Context, scenes []*models.Scene, mapper *parseMapper) []*models.SceneParserResult {
var ret []*models.SceneParserResult
for _, scene := range scenes {
sceneHolder := mapper.parse(scene)
@@ -498,7 +510,7 @@ func (p *FilenameParser) parseScenes(ctx context.Context, repo FilenameParserRep
r := &models.SceneParserResult{
Scene: scene,
}
p.setParserResult(ctx, repo, *sceneHolder, r)
p.setParserResult(ctx, *sceneHolder, r)
ret = append(ret, r)
}
@@ -671,7 +683,7 @@ func (p *FilenameParser) setMovies(ctx context.Context, qb MovieNameFinder, h sc
}
}
func (p *FilenameParser) setParserResult(ctx context.Context, repo FilenameParserRepository, h sceneHolder, result *models.SceneParserResult) {
func (p *FilenameParser) setParserResult(ctx context.Context, h sceneHolder, result *models.SceneParserResult) {
if h.result.Title != "" {
title := h.result.Title
title = p.replaceWhitespaceCharacters(title)
@@ -692,15 +704,17 @@ func (p *FilenameParser) setParserResult(ctx context.Context, repo FilenameParse
result.Rating = h.result.Rating
}
r := p.repository
if len(h.performers) > 0 {
p.setPerformers(ctx, repo.Performer, h, result)
p.setPerformers(ctx, r.Performer, h, result)
}
if len(h.tags) > 0 {
p.setTags(ctx, repo.Tag, h, result)
p.setTags(ctx, r.Tag, h, result)
}
p.setStudio(ctx, repo.Studio, h, result)
p.setStudio(ctx, r.Studio, h, result)
if len(h.movies) > 0 {
p.setMovies(ctx, repo.Movie, h, result)
p.setMovies(ctx, r.Movie, h, result)
}
}
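
The parser now owns its repository: NewFilenameParser captures a FilenameParserRepository at construction, and Parse, parseScenes and setParserResult stop threading a repo argument through every call. A rough sketch of the resulting call site, assuming a models.Repository from the database layer and that these helpers live in the scene package (caller shape and variable names are illustrative):

parserRepo := scene.NewFilenameParserRepository(repo) // repo: a models.Repository carrying the TxnManager and stores
p := scene.NewFilenameParser(filter, parserInput, parserRepo)

// Parse no longer receives the repository; it uses the one captured above.
results, total, err := p.Parse(ctx)
if err != nil {
	return nil, 0, err
}
return results, total, nil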

View File

@@ -56,20 +56,19 @@ func TestImporterPreImport(t *testing.T) {
}
func TestImporterPreImportWithStudio(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{}
testCtx := context.Background()
db := mocks.NewDatabase()
i := Importer{
StudioWriter: studioReaderWriter,
StudioWriter: db.Studio,
Input: jsonschema.Scene{
Studio: existingStudioName,
},
}
studioReaderWriter.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
db.Studio.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
ID: existingStudioID,
}, nil).Once()
studioReaderWriter.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
db.Studio.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
err := i.PreImport(testCtx)
assert.Nil(t, err)
@@ -79,22 +78,22 @@ func TestImporterPreImportWithStudio(t *testing.T) {
err = i.PreImport(testCtx)
assert.NotNil(t, err)
studioReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingStudio(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
StudioWriter: studioReaderWriter,
StudioWriter: db.Studio,
Input: jsonschema.Scene{
Studio: missingStudioName,
},
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.Studio)
s.ID = existingStudioID
}).Return(nil)
@@ -111,32 +110,34 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, existingStudioID, *i.scene.StudioID)
studioReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
StudioWriter: studioReaderWriter,
StudioWriter: db.Studio,
Input: jsonschema.Scene{
Studio: missingStudioName,
},
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
db.Studio.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)
db.AssertExpectations(t)
}
func TestImporterPreImportWithPerformer(t *testing.T) {
performerReaderWriter := &mocks.PerformerReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
PerformerWriter: performerReaderWriter,
PerformerWriter: db.Performer,
MissingRefBehaviour: models.ImportMissingRefEnumFail,
Input: jsonschema.Scene{
Performers: []string{
@@ -145,13 +146,13 @@ func TestImporterPreImportWithPerformer(t *testing.T) {
},
}
performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{
db.Performer.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{
{
ID: existingPerformerID,
Name: existingPerformerName,
},
}, nil).Once()
performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once()
db.Performer.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once()
err := i.PreImport(testCtx)
assert.Nil(t, err)
@@ -161,14 +162,14 @@ func TestImporterPreImportWithPerformer(t *testing.T) {
err = i.PreImport(testCtx)
assert.NotNil(t, err)
performerReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingPerformer(t *testing.T) {
performerReaderWriter := &mocks.PerformerReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
PerformerWriter: performerReaderWriter,
PerformerWriter: db.Performer,
Input: jsonschema.Scene{
Performers: []string{
missingPerformerName,
@@ -177,8 +178,8 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3)
performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) {
db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3)
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) {
p := args.Get(1).(*models.Performer)
p.ID = existingPerformerID
}).Return(nil)
@@ -195,14 +196,14 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, []int{existingPerformerID}, i.scene.PerformerIDs.List())
performerReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
performerReaderWriter := &mocks.PerformerReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
PerformerWriter: performerReaderWriter,
PerformerWriter: db.Performer,
Input: jsonschema.Scene{
Performers: []string{
missingPerformerName,
@@ -211,19 +212,20 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once()
performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error"))
db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once()
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMovie(t *testing.T) {
movieReaderWriter := &mocks.MovieReaderWriter{}
testCtx := context.Background()
db := mocks.NewDatabase()
i := Importer{
MovieWriter: movieReaderWriter,
MovieWriter: db.Movie,
MissingRefBehaviour: models.ImportMissingRefEnumFail,
Input: jsonschema.Scene{
Movies: []jsonschema.SceneMovie{
@@ -235,11 +237,11 @@ func TestImporterPreImportWithMovie(t *testing.T) {
},
}
movieReaderWriter.On("FindByName", testCtx, existingMovieName, false).Return(&models.Movie{
db.Movie.On("FindByName", testCtx, existingMovieName, false).Return(&models.Movie{
ID: existingMovieID,
Name: existingMovieName,
}, nil).Once()
movieReaderWriter.On("FindByName", testCtx, existingMovieErr, false).Return(nil, errors.New("FindByName error")).Once()
db.Movie.On("FindByName", testCtx, existingMovieErr, false).Return(nil, errors.New("FindByName error")).Once()
err := i.PreImport(testCtx)
assert.Nil(t, err)
@@ -249,15 +251,14 @@ func TestImporterPreImportWithMovie(t *testing.T) {
err = i.PreImport(testCtx)
assert.NotNil(t, err)
movieReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingMovie(t *testing.T) {
movieReaderWriter := &mocks.MovieReaderWriter{}
testCtx := context.Background()
db := mocks.NewDatabase()
i := Importer{
MovieWriter: movieReaderWriter,
MovieWriter: db.Movie,
Input: jsonschema.Scene{
Movies: []jsonschema.SceneMovie{
{
@@ -268,8 +269,8 @@ func TestImporterPreImportWithMissingMovie(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
movieReaderWriter.On("FindByName", testCtx, missingMovieName, false).Return(nil, nil).Times(3)
movieReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Movie")).Run(func(args mock.Arguments) {
db.Movie.On("FindByName", testCtx, missingMovieName, false).Return(nil, nil).Times(3)
db.Movie.On("Create", testCtx, mock.AnythingOfType("*models.Movie")).Run(func(args mock.Arguments) {
m := args.Get(1).(*models.Movie)
m.ID = existingMovieID
}).Return(nil)
@@ -286,14 +287,14 @@ func TestImporterPreImportWithMissingMovie(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, existingMovieID, i.scene.Movies.List()[0].MovieID)
movieReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingMovieCreateErr(t *testing.T) {
movieReaderWriter := &mocks.MovieReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
MovieWriter: movieReaderWriter,
MovieWriter: db.Movie,
Input: jsonschema.Scene{
Movies: []jsonschema.SceneMovie{
{
@@ -304,18 +305,20 @@ func TestImporterPreImportWithMissingMovieCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
movieReaderWriter.On("FindByName", testCtx, missingMovieName, false).Return(nil, nil).Once()
movieReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Movie")).Return(errors.New("Create error"))
db.Movie.On("FindByName", testCtx, missingMovieName, false).Return(nil, nil).Once()
db.Movie.On("Create", testCtx, mock.AnythingOfType("*models.Movie")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)
db.AssertExpectations(t)
}
func TestImporterPreImportWithTag(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
TagWriter: tagReaderWriter,
TagWriter: db.Tag,
MissingRefBehaviour: models.ImportMissingRefEnumFail,
Input: jsonschema.Scene{
Tags: []string{
@@ -324,13 +327,13 @@ func TestImporterPreImportWithTag(t *testing.T) {
},
}
tagReaderWriter.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
db.Tag.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
{
ID: existingTagID,
Name: existingTagName,
},
}, nil).Once()
tagReaderWriter.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
db.Tag.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
err := i.PreImport(testCtx)
assert.Nil(t, err)
@@ -340,14 +343,14 @@ func TestImporterPreImportWithTag(t *testing.T) {
err = i.PreImport(testCtx)
assert.NotNil(t, err)
tagReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingTag(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
TagWriter: tagReaderWriter,
TagWriter: db.Tag,
Input: jsonschema.Scene{
Tags: []string{
missingTagName,
@@ -356,8 +359,8 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
t := args.Get(1).(*models.Tag)
t.ID = existingTagID
}).Return(nil)
@@ -374,14 +377,14 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, []int{existingTagID}, i.scene.TagIDs.List())
tagReaderWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
TagWriter: tagReaderWriter,
TagWriter: db.Tag,
Input: jsonschema.Scene{
Tags: []string{
missingTagName,
@@ -390,9 +393,11 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)
db.AssertExpectations(t)
}
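
The missing-reference tests above expect FindByNames three times because PreImport runs once per behaviour. Read through the expectations, the three behaviours amount to roughly the following (a hedged reading of the tests, not text from the commit):

i.MissingRefBehaviour = models.ImportMissingRefEnumFail
assert.NotNil(t, i.PreImport(testCtx)) // reference not found -> error

i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore
assert.Nil(t, i.PreImport(testCtx)) // reference not found -> skipped, no error

i.MissingRefBehaviour = models.ImportMissingRefEnumCreate
assert.Nil(t, i.PreImport(testCtx)) // reference not found -> Create is called and the new ID attached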

View File

@@ -1,7 +1,6 @@
package scene
import (
"context"
"errors"
"strconv"
"testing"
@@ -105,8 +104,6 @@ func TestUpdater_Update(t *testing.T) {
tagID
)
ctx := context.Background()
performerIDs := []int{performerID}
tagIDs := []int{tagID}
stashID := "stashID"
@@ -119,14 +116,15 @@ func TestUpdater_Update(t *testing.T) {
updateErr := errors.New("error updating")
qb := mocks.SceneReaderWriter{}
qb.On("UpdatePartial", ctx, mock.MatchedBy(func(id int) bool {
db := mocks.NewDatabase()
db.Scene.On("UpdatePartial", testCtx, mock.MatchedBy(func(id int) bool {
return id != badUpdateID
}), mock.Anything).Return(validScene, nil)
qb.On("UpdatePartial", ctx, badUpdateID, mock.Anything).Return(nil, updateErr)
db.Scene.On("UpdatePartial", testCtx, badUpdateID, mock.Anything).Return(nil, updateErr)
qb.On("UpdateCover", ctx, sceneID, cover).Return(nil).Once()
qb.On("UpdateCover", ctx, badCoverID, cover).Return(updateErr).Once()
db.Scene.On("UpdateCover", testCtx, sceneID, cover).Return(nil).Once()
db.Scene.On("UpdateCover", testCtx, badCoverID, cover).Return(updateErr).Once()
tests := []struct {
name string
@@ -204,7 +202,7 @@ func TestUpdater_Update(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.u.Update(ctx, &qb)
got, err := tt.u.Update(testCtx, db.Scene)
if (err != nil) != tt.wantErr {
t.Errorf("Updater.Update() error = %v, wantErr %v", err, tt.wantErr)
return
@@ -215,7 +213,7 @@ func TestUpdater_Update(t *testing.T) {
})
}
qb.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestUpdateSet_UpdateInput(t *testing.T) {

View File

@@ -18,7 +18,6 @@ const (
)
type autotagScraper struct {
// repository models.Repository
txnManager txn.Manager
performerReader models.PerformerAutoTagQueryer
studioReader models.StudioAutoTagQueryer
@@ -208,9 +207,9 @@ func (s autotagScraper) spec() Scraper {
}
}
func getAutoTagScraper(txnManager txn.Manager, repo Repository, globalConfig GlobalConfig) scraper {
func getAutoTagScraper(repo Repository, globalConfig GlobalConfig) scraper {
base := autotagScraper{
txnManager: txnManager,
txnManager: repo.TxnManager,
performerReader: repo.PerformerFinder,
studioReader: repo.StudioFinder,
tagReader: repo.TagFinder,

View File

@@ -77,6 +77,8 @@ type GalleryFinder interface {
}
type Repository struct {
TxnManager models.TxnManager
SceneFinder SceneFinder
GalleryFinder GalleryFinder
TagFinder TagFinder
@@ -85,12 +87,27 @@ type Repository struct {
StudioFinder StudioFinder
}
func NewRepository(repo models.Repository) Repository {
return Repository{
TxnManager: repo.TxnManager,
SceneFinder: repo.Scene,
GalleryFinder: repo.Gallery,
TagFinder: repo.Tag,
PerformerFinder: repo.Performer,
MovieFinder: repo.Movie,
StudioFinder: repo.Studio,
}
}
func (r *Repository) WithReadTxn(ctx context.Context, fn txn.TxnFunc) error {
return txn.WithReadTxn(ctx, r.TxnManager, fn)
}
// Cache stores the database of scrapers
type Cache struct {
client *http.Client
scrapers map[string]scraper // Scraper ID -> Scraper
globalConfig GlobalConfig
txnManager txn.Manager
repository Repository
}
@@ -122,14 +139,13 @@ func newClient(gc GlobalConfig) *http.Client {
//
// Scraper configurations are loaded from yml files in the provided scrapers
// directory and any subdirectories.
func NewCache(globalConfig GlobalConfig, txnManager txn.Manager, repo Repository) (*Cache, error) {
func NewCache(globalConfig GlobalConfig, repo Repository) (*Cache, error) {
// HTTP Client setup
client := newClient(globalConfig)
ret := &Cache{
client: client,
globalConfig: globalConfig,
txnManager: txnManager,
repository: repo,
}
@@ -148,7 +164,7 @@ func (c *Cache) loadScrapers() (map[string]scraper, error) {
// Add built-in scrapers
freeOnes := getFreeonesScraper(c.globalConfig)
autoTag := getAutoTagScraper(c.txnManager, c.repository, c.globalConfig)
autoTag := getAutoTagScraper(c.repository, c.globalConfig)
scrapers[freeOnes.spec().ID] = freeOnes
scrapers[autoTag.spec().ID] = autoTag
@@ -369,9 +385,12 @@ func (c Cache) ScrapeID(ctx context.Context, scraperID string, id int, ty Scrape
func (c Cache) getScene(ctx context.Context, sceneID int) (*models.Scene, error) {
var ret *models.Scene
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
r := c.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
qb := r.SceneFinder
var err error
ret, err = c.repository.SceneFinder.Find(ctx, sceneID)
ret, err = qb.Find(ctx, sceneID)
if err != nil {
return err
}
@@ -380,7 +399,7 @@ func (c Cache) getScene(ctx context.Context, sceneID int) (*models.Scene, error)
return fmt.Errorf("scene with id %d not found", sceneID)
}
return ret.LoadURLs(ctx, c.repository.SceneFinder)
return ret.LoadURLs(ctx, qb)
}); err != nil {
return nil, err
}
@@ -389,9 +408,12 @@ func (c Cache) getScene(ctx context.Context, sceneID int) (*models.Scene, error)
func (c Cache) getGallery(ctx context.Context, galleryID int) (*models.Gallery, error) {
var ret *models.Gallery
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
r := c.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
qb := r.GalleryFinder
var err error
ret, err = c.repository.GalleryFinder.Find(ctx, galleryID)
ret, err = qb.Find(ctx, galleryID)
if err != nil {
return err
}
@@ -400,12 +422,12 @@ func (c Cache) getGallery(ctx context.Context, galleryID int) (*models.Gallery,
return fmt.Errorf("gallery with id %d not found", galleryID)
}
err = ret.LoadFiles(ctx, c.repository.GalleryFinder)
err = ret.LoadFiles(ctx, qb)
if err != nil {
return err
}
return ret.LoadURLs(ctx, c.repository.GalleryFinder)
return ret.LoadURLs(ctx, qb)
}); err != nil {
return nil, err
}
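
With the transaction manager folded into the scraper Repository, cache construction drops its separate txn.Manager argument. A sketch of the wiring from the caller's side, assuming the aggregated repository comes from the database accessor renamed further down and that globalConfig is an existing scraper GlobalConfig value:

repo := scraper.NewRepository(db.Repository()) // narrow models.Repository to the scraper-facing subset
cache, err := scraper.NewCache(globalConfig, repo)
if err != nil {
	return err
}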

View File

@@ -6,7 +6,6 @@ import (
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/match"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/txn"
)
// postScrape handles post-processing of scraped content. If the content
@@ -46,8 +45,9 @@ func (c Cache) postScrape(ctx context.Context, content ScrapedContent) (ScrapedC
}
func (c Cache) postScrapePerformer(ctx context.Context, p models.ScrapedPerformer) (ScrapedContent, error) {
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
tqb := c.repository.TagFinder
r := c.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
tqb := r.TagFinder
tags, err := postProcessTags(ctx, tqb, p.Tags)
if err != nil {
@@ -72,8 +72,9 @@ func (c Cache) postScrapePerformer(ctx context.Context, p models.ScrapedPerforme
func (c Cache) postScrapeMovie(ctx context.Context, m models.ScrapedMovie) (ScrapedContent, error) {
if m.Studio != nil {
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
return match.ScrapedStudio(ctx, c.repository.StudioFinder, m.Studio, nil)
r := c.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
return match.ScrapedStudio(ctx, r.StudioFinder, m.Studio, nil)
}); err != nil {
return nil, err
}
@@ -113,11 +114,12 @@ func (c Cache) postScrapeScene(ctx context.Context, scene ScrapedScene) (Scraped
scene.URLs = []string{*scene.URL}
}
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
pqb := c.repository.PerformerFinder
mqb := c.repository.MovieFinder
tqb := c.repository.TagFinder
sqb := c.repository.StudioFinder
r := c.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
pqb := r.PerformerFinder
mqb := r.MovieFinder
tqb := r.TagFinder
sqb := r.StudioFinder
for _, p := range scene.Performers {
if p == nil {
@@ -175,10 +177,11 @@ func (c Cache) postScrapeGallery(ctx context.Context, g ScrapedGallery) (Scraped
g.URLs = []string{*g.URL}
}
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
pqb := c.repository.PerformerFinder
tqb := c.repository.TagFinder
sqb := c.repository.StudioFinder
r := c.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
pqb := r.PerformerFinder
tqb := r.TagFinder
sqb := r.StudioFinder
for _, p := range g.Performers {
err := match.ScrapedPerformer(ctx, pqb, p, nil)

View File

@@ -56,22 +56,37 @@ type TagFinder interface {
}
type Repository struct {
TxnManager models.TxnManager
Scene SceneReader
Performer PerformerReader
Tag TagFinder
Studio StudioReader
}
func NewRepository(repo models.Repository) Repository {
return Repository{
TxnManager: repo.TxnManager,
Scene: repo.Scene,
Performer: repo.Performer,
Tag: repo.Tag,
Studio: repo.Studio,
}
}
func (r *Repository) WithReadTxn(ctx context.Context, fn txn.TxnFunc) error {
return txn.WithReadTxn(ctx, r.TxnManager, fn)
}
// Client represents the client interface to a stash-box server instance.
type Client struct {
client *graphql.Client
txnManager txn.Manager
repository Repository
box models.StashBox
}
// NewClient returns a new instance of a stash-box client.
func NewClient(box models.StashBox, txnManager txn.Manager, repo Repository) *Client {
func NewClient(box models.StashBox, repo Repository) *Client {
authHeader := func(req *http.Request) {
req.Header.Set("ApiKey", box.APIKey)
}
@@ -82,7 +97,6 @@ func NewClient(box models.StashBox, txnManager txn.Manager, repo Repository) *Cl
return &Client{
client: client,
txnManager: txnManager,
repository: repo,
box: box,
}
@@ -129,8 +143,9 @@ func (c Client) FindStashBoxSceneByFingerprints(ctx context.Context, sceneID int
func (c Client) FindStashBoxScenesByFingerprints(ctx context.Context, ids []int) ([][]*scraper.ScrapedScene, error) {
var fingerprints [][]*graphql.FingerprintQueryInput
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
qb := c.repository.Scene
r := c.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
qb := r.Scene
for _, sceneID := range ids {
scene, err := qb.Find(ctx, sceneID)
@@ -142,7 +157,7 @@ func (c Client) FindStashBoxScenesByFingerprints(ctx context.Context, ids []int)
return fmt.Errorf("scene with id %d not found", sceneID)
}
if err := scene.LoadFiles(ctx, c.repository.Scene); err != nil {
if err := scene.LoadFiles(ctx, r.Scene); err != nil {
return err
}
@@ -243,8 +258,9 @@ func (c Client) SubmitStashBoxFingerprints(ctx context.Context, sceneIDs []strin
var fingerprints []graphql.FingerprintSubmission
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
qb := c.repository.Scene
r := c.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
qb := r.Scene
for _, sceneID := range ids {
scene, err := qb.Find(ctx, sceneID)
@@ -382,9 +398,9 @@ func (c Client) FindStashBoxPerformersByNames(ctx context.Context, performerIDs
}
var performers []*models.Performer
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
qb := c.repository.Performer
r := c.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
qb := r.Performer
for _, performerID := range ids {
performer, err := qb.Find(ctx, performerID)
@@ -417,8 +433,9 @@ func (c Client) FindStashBoxPerformersByPerformerNames(ctx context.Context, perf
var performers []*models.Performer
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
qb := c.repository.Performer
r := c.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
qb := r.Performer
for _, performerID := range ids {
performer, err := qb.Find(ctx, performerID)
@@ -739,14 +756,15 @@ func (c Client) sceneFragmentToScrapedScene(ctx context.Context, s *graphql.Scen
ss.URL = &s.Urls[0].URL
}
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
pqb := c.repository.Performer
tqb := c.repository.Tag
r := c.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
pqb := r.Performer
tqb := r.Tag
if s.Studio != nil {
ss.Studio = studioFragmentToScrapedStudio(*s.Studio)
err := match.ScrapedStudio(ctx, c.repository.Studio, ss.Studio, &c.box.Endpoint)
err := match.ScrapedStudio(ctx, r.Studio, ss.Studio, &c.box.Endpoint)
if err != nil {
return err
}
@@ -761,7 +779,7 @@ func (c Client) sceneFragmentToScrapedScene(ctx context.Context, s *graphql.Scen
if parentStudio.FindStudio != nil {
ss.Studio.Parent = studioFragmentToScrapedStudio(*parentStudio.FindStudio)
err = match.ScrapedStudio(ctx, c.repository.Studio, ss.Studio.Parent, &c.box.Endpoint)
err = match.ScrapedStudio(ctx, r.Studio, ss.Studio.Parent, &c.box.Endpoint)
if err != nil {
return err
}
@@ -809,8 +827,9 @@ func (c Client) FindStashBoxPerformerByID(ctx context.Context, id string) (*mode
ret := performerFragmentToScrapedPerformer(*performer.FindPerformer)
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
err := match.ScrapedPerformer(ctx, c.repository.Performer, ret, &c.box.Endpoint)
r := c.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
err := match.ScrapedPerformer(ctx, r.Performer, ret, &c.box.Endpoint)
return err
}); err != nil {
return nil, err
@@ -836,8 +855,9 @@ func (c Client) FindStashBoxPerformerByName(ctx context.Context, name string) (*
return nil, nil
}
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
err := match.ScrapedPerformer(ctx, c.repository.Performer, ret, &c.box.Endpoint)
r := c.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
err := match.ScrapedPerformer(ctx, r.Performer, ret, &c.box.Endpoint)
return err
}); err != nil {
return nil, err
@@ -864,10 +884,11 @@ func (c Client) FindStashBoxStudio(ctx context.Context, query string) (*models.S
var ret *models.ScrapedStudio
if studio.FindStudio != nil {
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
r := c.repository
if err := r.WithReadTxn(ctx, func(ctx context.Context) error {
ret = studioFragmentToScrapedStudio(*studio.FindStudio)
err = match.ScrapedStudio(ctx, c.repository.Studio, ret, &c.box.Endpoint)
err = match.ScrapedStudio(ctx, r.Studio, ret, &c.box.Endpoint)
if err != nil {
return err
}
@@ -881,7 +902,7 @@ func (c Client) FindStashBoxStudio(ctx context.Context, query string) (*models.S
if parentStudio.FindStudio != nil {
ret.Parent = studioFragmentToScrapedStudio(*parentStudio.FindStudio)
err = match.ScrapedStudio(ctx, c.repository.Studio, ret.Parent, &c.box.Endpoint)
err = match.ScrapedStudio(ctx, r.Studio, ret.Parent, &c.box.Endpoint)
if err != nil {
return err
}
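
The stash-box client follows the same shape: NewClient loses the txn.Manager parameter and the Repository gains its own WithReadTxn helper. A hedged sketch of a caller (package qualifier and variable names assumed):

repo := stashbox.NewRepository(db.Repository())
client := stashbox.NewClient(box, repo) // box: the models.StashBox endpoint config

scenes, err := client.FindStashBoxScenesByFingerprints(ctx, sceneIDs)
if err != nil {
	return nil, err
}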

View File

@@ -1176,7 +1176,7 @@ func makeImage(i int) *models.Image {
}
func createImages(ctx context.Context, n int) error {
qb := db.TxnRepository().Image
qb := db.Image
fqb := db.File
for i := 0; i < n; i++ {
@@ -1273,7 +1273,7 @@ func makeGallery(i int, includeScenes bool) *models.Gallery {
}
func createGalleries(ctx context.Context, n int) error {
gqb := db.TxnRepository().Gallery
gqb := db.Gallery
fqb := db.File
for i := 0; i < n; i++ {

View File

@@ -123,7 +123,7 @@ func (db *Database) IsLocked(err error) bool {
return false
}
func (db *Database) TxnRepository() models.Repository {
func (db *Database) Repository() models.Repository {
return models.Repository{
TxnManager: db,
File: db.File,

View File

@@ -1,7 +1,6 @@
package studio
import (
"context"
"errors"
"github.com/stashapp/stash/pkg/models"
@@ -162,27 +161,26 @@ func initTestTable() {
func TestToJSON(t *testing.T) {
initTestTable()
ctx := context.Background()
mockStudioReader := &mocks.StudioReaderWriter{}
db := mocks.NewDatabase()
imageErr := errors.New("error getting image")
mockStudioReader.On("GetImage", ctx, studioID).Return(imageBytes, nil).Once()
mockStudioReader.On("GetImage", ctx, noImageID).Return(nil, nil).Once()
mockStudioReader.On("GetImage", ctx, errImageID).Return(nil, imageErr).Once()
mockStudioReader.On("GetImage", ctx, missingParentStudioID).Return(imageBytes, nil).Maybe()
mockStudioReader.On("GetImage", ctx, errStudioID).Return(imageBytes, nil).Maybe()
db.Studio.On("GetImage", testCtx, studioID).Return(imageBytes, nil).Once()
db.Studio.On("GetImage", testCtx, noImageID).Return(nil, nil).Once()
db.Studio.On("GetImage", testCtx, errImageID).Return(nil, imageErr).Once()
db.Studio.On("GetImage", testCtx, missingParentStudioID).Return(imageBytes, nil).Maybe()
db.Studio.On("GetImage", testCtx, errStudioID).Return(imageBytes, nil).Maybe()
parentStudioErr := errors.New("error getting parent studio")
mockStudioReader.On("Find", ctx, parentStudioID).Return(&parentStudio, nil)
mockStudioReader.On("Find", ctx, missingStudioID).Return(nil, nil)
mockStudioReader.On("Find", ctx, errParentStudioID).Return(nil, parentStudioErr)
db.Studio.On("Find", testCtx, parentStudioID).Return(&parentStudio, nil)
db.Studio.On("Find", testCtx, missingStudioID).Return(nil, nil)
db.Studio.On("Find", testCtx, errParentStudioID).Return(nil, parentStudioErr)
for i, s := range scenarios {
studio := s.input
json, err := ToJSON(ctx, mockStudioReader, &studio)
json, err := ToJSON(testCtx, db.Studio, &studio)
switch {
case !s.err && err != nil:
@@ -194,5 +192,5 @@ func TestToJSON(t *testing.T) {
}
}
mockStudioReader.AssertExpectations(t)
db.AssertExpectations(t)
}

View File

@@ -25,6 +25,8 @@ const (
missingParentStudioName = "existingParentStudioName"
)
var testCtx = context.Background()
func TestImporterName(t *testing.T) {
i := Importer{
Input: jsonschema.Studio{
@@ -43,22 +45,21 @@ func TestImporterPreImport(t *testing.T) {
IgnoreAutoTag: autoTagIgnored,
},
}
ctx := context.Background()
err := i.PreImport(ctx)
err := i.PreImport(testCtx)
assert.NotNil(t, err)
i.Input.Image = image
err = i.PreImport(ctx)
err = i.PreImport(testCtx)
assert.Nil(t, err)
i.Input = *createFullJSONStudio(studioName, image, []string{"alias"})
i.Input.ParentStudio = ""
err = i.PreImport(ctx)
err = i.PreImport(testCtx)
assert.Nil(t, err)
expectedStudio := createFullStudio(0, 0)
@@ -67,11 +68,10 @@ func TestImporterPreImport(t *testing.T) {
}
func TestImporterPreImportWithParent(t *testing.T) {
readerWriter := &mocks.StudioReaderWriter{}
ctx := context.Background()
db := mocks.NewDatabase()
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Studio,
Input: jsonschema.Studio{
Name: studioName,
Image: image,
@@ -79,28 +79,27 @@ func TestImporterPreImportWithParent(t *testing.T) {
},
}
readerWriter.On("FindByName", ctx, existingParentStudioName, false).Return(&models.Studio{
db.Studio.On("FindByName", testCtx, existingParentStudioName, false).Return(&models.Studio{
ID: existingStudioID,
}, nil).Once()
readerWriter.On("FindByName", ctx, existingParentStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
db.Studio.On("FindByName", testCtx, existingParentStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
err := i.PreImport(ctx)
err := i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, existingStudioID, *i.studio.ParentID)
i.Input.ParentStudio = existingParentStudioErr
err = i.PreImport(ctx)
err = i.PreImport(testCtx)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingParent(t *testing.T) {
readerWriter := &mocks.StudioReaderWriter{}
ctx := context.Background()
db := mocks.NewDatabase()
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Studio,
Input: jsonschema.Studio{
Name: studioName,
Image: image,
@@ -109,33 +108,32 @@ func TestImporterPreImportWithMissingParent(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
readerWriter.On("FindByName", ctx, missingParentStudioName, false).Return(nil, nil).Times(3)
readerWriter.On("Create", ctx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
db.Studio.On("FindByName", testCtx, missingParentStudioName, false).Return(nil, nil).Times(3)
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.Studio)
s.ID = existingStudioID
}).Return(nil)
err := i.PreImport(ctx)
err := i.PreImport(testCtx)
assert.NotNil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore
err = i.PreImport(ctx)
err = i.PreImport(testCtx)
assert.Nil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumCreate
err = i.PreImport(ctx)
err = i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, existingStudioID, *i.studio.ParentID)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPreImportWithMissingParentCreateErr(t *testing.T) {
readerWriter := &mocks.StudioReaderWriter{}
ctx := context.Background()
db := mocks.NewDatabase()
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Studio,
Input: jsonschema.Studio{
Name: studioName,
Image: image,
@@ -144,19 +142,20 @@ func TestImporterPreImportWithMissingParentCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
readerWriter.On("FindByName", ctx, missingParentStudioName, false).Return(nil, nil).Once()
readerWriter.On("Create", ctx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
db.Studio.On("FindByName", testCtx, missingParentStudioName, false).Return(nil, nil).Once()
db.Studio.On("Create", testCtx, mock.AnythingOfType("*models.Studio")).Return(errors.New("Create error"))
err := i.PreImport(ctx)
err := i.PreImport(testCtx)
assert.NotNil(t, err)
db.AssertExpectations(t)
}
func TestImporterPostImport(t *testing.T) {
readerWriter := &mocks.StudioReaderWriter{}
ctx := context.Background()
db := mocks.NewDatabase()
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Studio,
Input: jsonschema.Studio{
Aliases: []string{"alias"},
},
@@ -165,56 +164,54 @@ func TestImporterPostImport(t *testing.T) {
updateStudioImageErr := errors.New("UpdateImage error")
readerWriter.On("UpdateImage", ctx, studioID, imageBytes).Return(nil).Once()
readerWriter.On("UpdateImage", ctx, errImageID, imageBytes).Return(updateStudioImageErr).Once()
db.Studio.On("UpdateImage", testCtx, studioID, imageBytes).Return(nil).Once()
db.Studio.On("UpdateImage", testCtx, errImageID, imageBytes).Return(updateStudioImageErr).Once()
err := i.PostImport(ctx, studioID)
err := i.PostImport(testCtx, studioID)
assert.Nil(t, err)
err = i.PostImport(ctx, errImageID)
err = i.PostImport(testCtx, errImageID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterFindExistingID(t *testing.T) {
readerWriter := &mocks.StudioReaderWriter{}
ctx := context.Background()
db := mocks.NewDatabase()
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Studio,
Input: jsonschema.Studio{
Name: studioName,
},
}
errFindByName := errors.New("FindByName error")
readerWriter.On("FindByName", ctx, studioName, false).Return(nil, nil).Once()
readerWriter.On("FindByName", ctx, existingStudioName, false).Return(&models.Studio{
db.Studio.On("FindByName", testCtx, studioName, false).Return(nil, nil).Once()
db.Studio.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
ID: existingStudioID,
}, nil).Once()
readerWriter.On("FindByName", ctx, studioNameErr, false).Return(nil, errFindByName).Once()
db.Studio.On("FindByName", testCtx, studioNameErr, false).Return(nil, errFindByName).Once()
id, err := i.FindExistingID(ctx)
id, err := i.FindExistingID(testCtx)
assert.Nil(t, id)
assert.Nil(t, err)
i.Input.Name = existingStudioName
id, err = i.FindExistingID(ctx)
id, err = i.FindExistingID(testCtx)
assert.Equal(t, existingStudioID, *id)
assert.Nil(t, err)
i.Input.Name = studioNameErr
id, err = i.FindExistingID(ctx)
id, err = i.FindExistingID(testCtx)
assert.Nil(t, id)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestCreate(t *testing.T) {
readerWriter := &mocks.StudioReaderWriter{}
ctx := context.Background()
db := mocks.NewDatabase()
studio := models.Studio{
Name: studioName,
@@ -225,32 +222,31 @@ func TestCreate(t *testing.T) {
}
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Studio,
studio: studio,
}
errCreate := errors.New("Create error")
readerWriter.On("Create", ctx, &studio).Run(func(args mock.Arguments) {
db.Studio.On("Create", testCtx, &studio).Run(func(args mock.Arguments) {
s := args.Get(1).(*models.Studio)
s.ID = studioID
}).Return(nil).Once()
readerWriter.On("Create", ctx, &studioErr).Return(errCreate).Once()
db.Studio.On("Create", testCtx, &studioErr).Return(errCreate).Once()
id, err := i.Create(ctx)
id, err := i.Create(testCtx)
assert.Equal(t, studioID, *id)
assert.Nil(t, err)
i.studio = studioErr
id, err = i.Create(ctx)
id, err = i.Create(testCtx)
assert.Nil(t, id)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestUpdate(t *testing.T) {
readerWriter := &mocks.StudioReaderWriter{}
ctx := context.Background()
db := mocks.NewDatabase()
studio := models.Studio{
Name: studioName,
@@ -261,7 +257,7 @@ func TestUpdate(t *testing.T) {
}
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Studio,
studio: studio,
}
@@ -269,19 +265,19 @@ func TestUpdate(t *testing.T) {
// id needs to be set for the mock input
studio.ID = studioID
readerWriter.On("Update", ctx, &studio).Return(nil).Once()
db.Studio.On("Update", testCtx, &studio).Return(nil).Once()
err := i.Update(ctx, studioID)
err := i.Update(testCtx, studioID)
assert.Nil(t, err)
i.studio = studioErr
// need to set id separately
studioErr.ID = errImageID
readerWriter.On("Update", ctx, &studioErr).Return(errUpdate).Once()
db.Studio.On("Update", testCtx, &studioErr).Return(errUpdate).Once()
err = i.Update(ctx, errImageID)
err = i.Update(testCtx, errImageID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}

View File

@@ -1,7 +1,6 @@
package tag
import (
"context"
"errors"
"github.com/stashapp/stash/pkg/models"
@@ -109,35 +108,34 @@ func initTestTable() {
func TestToJSON(t *testing.T) {
initTestTable()
ctx := context.Background()
mockTagReader := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
imageErr := errors.New("error getting image")
aliasErr := errors.New("error getting aliases")
parentsErr := errors.New("error getting parents")
mockTagReader.On("GetAliases", ctx, tagID).Return([]string{"alias"}, nil).Once()
mockTagReader.On("GetAliases", ctx, noImageID).Return(nil, nil).Once()
mockTagReader.On("GetAliases", ctx, errImageID).Return(nil, nil).Once()
mockTagReader.On("GetAliases", ctx, errAliasID).Return(nil, aliasErr).Once()
mockTagReader.On("GetAliases", ctx, withParentsID).Return(nil, nil).Once()
mockTagReader.On("GetAliases", ctx, errParentsID).Return(nil, nil).Once()
db.Tag.On("GetAliases", testCtx, tagID).Return([]string{"alias"}, nil).Once()
db.Tag.On("GetAliases", testCtx, noImageID).Return(nil, nil).Once()
db.Tag.On("GetAliases", testCtx, errImageID).Return(nil, nil).Once()
db.Tag.On("GetAliases", testCtx, errAliasID).Return(nil, aliasErr).Once()
db.Tag.On("GetAliases", testCtx, withParentsID).Return(nil, nil).Once()
db.Tag.On("GetAliases", testCtx, errParentsID).Return(nil, nil).Once()
mockTagReader.On("GetImage", ctx, tagID).Return(imageBytes, nil).Once()
mockTagReader.On("GetImage", ctx, noImageID).Return(nil, nil).Once()
mockTagReader.On("GetImage", ctx, errImageID).Return(nil, imageErr).Once()
mockTagReader.On("GetImage", ctx, withParentsID).Return(imageBytes, nil).Once()
mockTagReader.On("GetImage", ctx, errParentsID).Return(nil, nil).Once()
db.Tag.On("GetImage", testCtx, tagID).Return(imageBytes, nil).Once()
db.Tag.On("GetImage", testCtx, noImageID).Return(nil, nil).Once()
db.Tag.On("GetImage", testCtx, errImageID).Return(nil, imageErr).Once()
db.Tag.On("GetImage", testCtx, withParentsID).Return(imageBytes, nil).Once()
db.Tag.On("GetImage", testCtx, errParentsID).Return(nil, nil).Once()
mockTagReader.On("FindByChildTagID", ctx, tagID).Return(nil, nil).Once()
mockTagReader.On("FindByChildTagID", ctx, noImageID).Return(nil, nil).Once()
mockTagReader.On("FindByChildTagID", ctx, withParentsID).Return([]*models.Tag{{Name: "parent"}}, nil).Once()
mockTagReader.On("FindByChildTagID", ctx, errParentsID).Return(nil, parentsErr).Once()
mockTagReader.On("FindByChildTagID", ctx, errImageID).Return(nil, nil).Once()
db.Tag.On("FindByChildTagID", testCtx, tagID).Return(nil, nil).Once()
db.Tag.On("FindByChildTagID", testCtx, noImageID).Return(nil, nil).Once()
db.Tag.On("FindByChildTagID", testCtx, withParentsID).Return([]*models.Tag{{Name: "parent"}}, nil).Once()
db.Tag.On("FindByChildTagID", testCtx, errParentsID).Return(nil, parentsErr).Once()
db.Tag.On("FindByChildTagID", testCtx, errImageID).Return(nil, nil).Once()
for i, s := range scenarios {
tag := s.tag
json, err := ToJSON(ctx, mockTagReader, &tag)
json, err := ToJSON(testCtx, db.Tag, &tag)
switch {
case !s.err && err != nil:
@@ -149,5 +147,5 @@ func TestToJSON(t *testing.T) {
}
}
mockTagReader.AssertExpectations(t)
db.AssertExpectations(t)
}

View File

@@ -58,10 +58,10 @@ func TestImporterPreImport(t *testing.T) {
}
func TestImporterPostImport(t *testing.T) {
readerWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Tag,
Input: jsonschema.Tag{
Aliases: []string{"alias"},
},
@@ -72,23 +72,23 @@ func TestImporterPostImport(t *testing.T) {
updateTagAliasErr := errors.New("UpdateAlias error")
updateTagParentsErr := errors.New("UpdateParentTags error")
readerWriter.On("UpdateAliases", testCtx, tagID, i.Input.Aliases).Return(nil).Once()
readerWriter.On("UpdateAliases", testCtx, errAliasID, i.Input.Aliases).Return(updateTagAliasErr).Once()
readerWriter.On("UpdateAliases", testCtx, withParentsID, i.Input.Aliases).Return(nil).Once()
readerWriter.On("UpdateAliases", testCtx, errParentsID, i.Input.Aliases).Return(nil).Once()
db.Tag.On("UpdateAliases", testCtx, tagID, i.Input.Aliases).Return(nil).Once()
db.Tag.On("UpdateAliases", testCtx, errAliasID, i.Input.Aliases).Return(updateTagAliasErr).Once()
db.Tag.On("UpdateAliases", testCtx, withParentsID, i.Input.Aliases).Return(nil).Once()
db.Tag.On("UpdateAliases", testCtx, errParentsID, i.Input.Aliases).Return(nil).Once()
readerWriter.On("UpdateImage", testCtx, tagID, imageBytes).Return(nil).Once()
readerWriter.On("UpdateImage", testCtx, errAliasID, imageBytes).Return(nil).Once()
readerWriter.On("UpdateImage", testCtx, errImageID, imageBytes).Return(updateTagImageErr).Once()
readerWriter.On("UpdateImage", testCtx, withParentsID, imageBytes).Return(nil).Once()
readerWriter.On("UpdateImage", testCtx, errParentsID, imageBytes).Return(nil).Once()
db.Tag.On("UpdateImage", testCtx, tagID, imageBytes).Return(nil).Once()
db.Tag.On("UpdateImage", testCtx, errAliasID, imageBytes).Return(nil).Once()
db.Tag.On("UpdateImage", testCtx, errImageID, imageBytes).Return(updateTagImageErr).Once()
db.Tag.On("UpdateImage", testCtx, withParentsID, imageBytes).Return(nil).Once()
db.Tag.On("UpdateImage", testCtx, errParentsID, imageBytes).Return(nil).Once()
var parentTags []int
readerWriter.On("UpdateParentTags", testCtx, tagID, parentTags).Return(nil).Once()
readerWriter.On("UpdateParentTags", testCtx, withParentsID, []int{100}).Return(nil).Once()
readerWriter.On("UpdateParentTags", testCtx, errParentsID, []int{100}).Return(updateTagParentsErr).Once()
db.Tag.On("UpdateParentTags", testCtx, tagID, parentTags).Return(nil).Once()
db.Tag.On("UpdateParentTags", testCtx, withParentsID, []int{100}).Return(nil).Once()
db.Tag.On("UpdateParentTags", testCtx, errParentsID, []int{100}).Return(updateTagParentsErr).Once()
readerWriter.On("FindByName", testCtx, "Parent", false).Return(&models.Tag{ID: 100}, nil)
db.Tag.On("FindByName", testCtx, "Parent", false).Return(&models.Tag{ID: 100}, nil)
err := i.PostImport(testCtx, tagID)
assert.Nil(t, err)
@@ -106,14 +106,14 @@ func TestImporterPostImport(t *testing.T) {
err = i.PostImport(testCtx, errParentsID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestImporterPostImportParentMissing(t *testing.T) {
readerWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Tag,
Input: jsonschema.Tag{},
imageData: imageBytes,
}
@@ -133,33 +133,33 @@ func TestImporterPostImportParentMissing(t *testing.T) {
var emptyParents []int
readerWriter.On("UpdateImage", testCtx, mock.Anything, mock.Anything).Return(nil)
readerWriter.On("UpdateAliases", testCtx, mock.Anything, mock.Anything).Return(nil)
db.Tag.On("UpdateImage", testCtx, mock.Anything, mock.Anything).Return(nil)
db.Tag.On("UpdateAliases", testCtx, mock.Anything, mock.Anything).Return(nil)
readerWriter.On("FindByName", testCtx, "Create", false).Return(nil, nil).Once()
readerWriter.On("FindByName", testCtx, "CreateError", false).Return(nil, nil).Once()
readerWriter.On("FindByName", testCtx, "CreateFindError", false).Return(nil, findError).Once()
readerWriter.On("FindByName", testCtx, "CreateFound", false).Return(&models.Tag{ID: 101}, nil).Once()
readerWriter.On("FindByName", testCtx, "Fail", false).Return(nil, nil).Once()
readerWriter.On("FindByName", testCtx, "FailFindError", false).Return(nil, findError)
readerWriter.On("FindByName", testCtx, "FailFound", false).Return(&models.Tag{ID: 102}, nil).Once()
readerWriter.On("FindByName", testCtx, "Ignore", false).Return(nil, nil).Once()
readerWriter.On("FindByName", testCtx, "IgnoreFindError", false).Return(nil, findError)
readerWriter.On("FindByName", testCtx, "IgnoreFound", false).Return(&models.Tag{ID: 103}, nil).Once()
db.Tag.On("FindByName", testCtx, "Create", false).Return(nil, nil).Once()
db.Tag.On("FindByName", testCtx, "CreateError", false).Return(nil, nil).Once()
db.Tag.On("FindByName", testCtx, "CreateFindError", false).Return(nil, findError).Once()
db.Tag.On("FindByName", testCtx, "CreateFound", false).Return(&models.Tag{ID: 101}, nil).Once()
db.Tag.On("FindByName", testCtx, "Fail", false).Return(nil, nil).Once()
db.Tag.On("FindByName", testCtx, "FailFindError", false).Return(nil, findError)
db.Tag.On("FindByName", testCtx, "FailFound", false).Return(&models.Tag{ID: 102}, nil).Once()
db.Tag.On("FindByName", testCtx, "Ignore", false).Return(nil, nil).Once()
db.Tag.On("FindByName", testCtx, "IgnoreFindError", false).Return(nil, findError)
db.Tag.On("FindByName", testCtx, "IgnoreFound", false).Return(&models.Tag{ID: 103}, nil).Once()
readerWriter.On("UpdateParentTags", testCtx, createID, []int{100}).Return(nil).Once()
readerWriter.On("UpdateParentTags", testCtx, createFoundID, []int{101}).Return(nil).Once()
readerWriter.On("UpdateParentTags", testCtx, failFoundID, []int{102}).Return(nil).Once()
readerWriter.On("UpdateParentTags", testCtx, ignoreID, emptyParents).Return(nil).Once()
readerWriter.On("UpdateParentTags", testCtx, ignoreFoundID, []int{103}).Return(nil).Once()
db.Tag.On("UpdateParentTags", testCtx, createID, []int{100}).Return(nil).Once()
db.Tag.On("UpdateParentTags", testCtx, createFoundID, []int{101}).Return(nil).Once()
db.Tag.On("UpdateParentTags", testCtx, failFoundID, []int{102}).Return(nil).Once()
db.Tag.On("UpdateParentTags", testCtx, ignoreID, emptyParents).Return(nil).Once()
db.Tag.On("UpdateParentTags", testCtx, ignoreFoundID, []int{103}).Return(nil).Once()
readerWriter.On("Create", testCtx, mock.MatchedBy(func(t *models.Tag) bool {
db.Tag.On("Create", testCtx, mock.MatchedBy(func(t *models.Tag) bool {
return t.Name == "Create"
})).Run(func(args mock.Arguments) {
t := args.Get(1).(*models.Tag)
t.ID = 100
}).Return(nil).Once()
readerWriter.On("Create", testCtx, mock.MatchedBy(func(t *models.Tag) bool {
db.Tag.On("Create", testCtx, mock.MatchedBy(func(t *models.Tag) bool {
return t.Name == "CreateError"
})).Return(errors.New("failed creating parent")).Once()
@@ -206,25 +206,25 @@ func TestImporterPostImportParentMissing(t *testing.T) {
err = i.PostImport(testCtx, ignoreFoundID)
assert.Nil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
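The Run callbacks in the parent-creation expectations above are how the mock stands in for the database assigning an ID on Create: the callback mutates the tag the importer passed in, so that later expectations such as UpdateParentTags can match on the newly assigned ID. Reduced to its essentials, and using only calls that already appear in this diff, the pattern looks like this:

db := mocks.NewDatabase()

// Expect a Create for a tag named "Create" and simulate the store writing
// back the generated primary key before returning.
db.Tag.On("Create", testCtx, mock.MatchedBy(func(t *models.Tag) bool {
    return t.Name == "Create"
})).Run(func(args mock.Arguments) {
    created := args.Get(1).(*models.Tag)
    created.ID = 100 // the ID the fake database hands out
}).Return(nil).Once()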
func TestImporterFindExistingID(t *testing.T) {
readerWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Tag,
Input: jsonschema.Tag{
Name: tagName,
},
}
errFindByName := errors.New("FindByName error")
readerWriter.On("FindByName", testCtx, tagName, false).Return(nil, nil).Once()
readerWriter.On("FindByName", testCtx, existingTagName, false).Return(&models.Tag{
db.Tag.On("FindByName", testCtx, tagName, false).Return(nil, nil).Once()
db.Tag.On("FindByName", testCtx, existingTagName, false).Return(&models.Tag{
ID: existingTagID,
}, nil).Once()
readerWriter.On("FindByName", testCtx, tagNameErr, false).Return(nil, errFindByName).Once()
db.Tag.On("FindByName", testCtx, tagNameErr, false).Return(nil, errFindByName).Once()
id, err := i.FindExistingID(testCtx)
assert.Nil(t, id)
@@ -240,11 +240,11 @@ func TestImporterFindExistingID(t *testing.T) {
assert.Nil(t, id)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestCreate(t *testing.T) {
readerWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
tag := models.Tag{
Name: tagName,
@@ -255,16 +255,16 @@ func TestCreate(t *testing.T) {
}
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Tag,
tag: tag,
}
errCreate := errors.New("Create error")
readerWriter.On("Create", testCtx, &tag).Run(func(args mock.Arguments) {
db.Tag.On("Create", testCtx, &tag).Run(func(args mock.Arguments) {
t := args.Get(1).(*models.Tag)
t.ID = tagID
}).Return(nil).Once()
readerWriter.On("Create", testCtx, &tagErr).Return(errCreate).Once()
db.Tag.On("Create", testCtx, &tagErr).Return(errCreate).Once()
id, err := i.Create(testCtx)
assert.Equal(t, tagID, *id)
@@ -275,11 +275,11 @@ func TestCreate(t *testing.T) {
assert.Nil(t, id)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}
func TestUpdate(t *testing.T) {
readerWriter := &mocks.TagReaderWriter{}
db := mocks.NewDatabase()
tag := models.Tag{
Name: tagName,
@@ -290,7 +290,7 @@ func TestUpdate(t *testing.T) {
}
i := Importer{
ReaderWriter: readerWriter,
ReaderWriter: db.Tag,
tag: tag,
}
@@ -298,7 +298,7 @@ func TestUpdate(t *testing.T) {
// id needs to be set for the mock input
tag.ID = tagID
readerWriter.On("Update", testCtx, &tag).Return(nil).Once()
db.Tag.On("Update", testCtx, &tag).Return(nil).Once()
err := i.Update(testCtx, tagID)
assert.Nil(t, err)
@@ -307,10 +307,10 @@ func TestUpdate(t *testing.T) {
// need to set id separately
tagErr.ID = errImageID
readerWriter.On("Update", testCtx, &tagErr).Return(errUpdate).Once()
db.Tag.On("Update", testCtx, &tagErr).Return(errUpdate).Once()
err = i.Update(testCtx, errImageID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
db.AssertExpectations(t)
}

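Taken together, the importer tests in this file now share one shape: build the fixture, point ReaderWriter at db.Tag, register expectations on the embedded mock, and finish with a single AssertExpectations call. A condensed, hypothetical test in the same package, reusing the fixtures seen above (testCtx, tagID, imageBytes), illustrates that shape; the function name and the exact expectation set are illustrative only:

func TestImporterPostImportShape(t *testing.T) {
    db := mocks.NewDatabase()

    i := Importer{
        ReaderWriter: db.Tag,
        Input:        jsonschema.Tag{Aliases: []string{"alias"}},
        imageData:    imageBytes,
    }

    // Expectations are registered directly on the embedded tag mock.
    var parentTags []int
    db.Tag.On("UpdateAliases", testCtx, tagID, i.Input.Aliases).Return(nil).Once()
    db.Tag.On("UpdateImage", testCtx, tagID, imageBytes).Return(nil).Once()
    db.Tag.On("UpdateParentTags", testCtx, tagID, parentTags).Return(nil).Once()

    assert.Nil(t, i.PostImport(testCtx, tagID))

    // One call verifies every mock held by the fixture.
    db.AssertExpectations(t)
}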
View File

@@ -219,8 +219,7 @@ func TestEnsureHierarchy(t *testing.T) {
}
func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents, queryChildren bool) {
mockTagReader := &mocks.TagReaderWriter{}
ctx := context.Background()
db := mocks.NewDatabase()
var parentIDs, childIDs []int
find := make(map[int]*models.Tag)
@@ -247,15 +246,15 @@ func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents,
if queryParents {
parentIDs = nil
mockTagReader.On("FindByChildTagID", ctx, tc.id).Return(tc.parents, nil).Once()
db.Tag.On("FindByChildTagID", testCtx, tc.id).Return(tc.parents, nil).Once()
}
if queryChildren {
childIDs = nil
mockTagReader.On("FindByParentTagID", ctx, tc.id).Return(tc.children, nil).Once()
db.Tag.On("FindByParentTagID", testCtx, tc.id).Return(tc.children, nil).Once()
}
mockTagReader.On("FindAllAncestors", ctx, mock.AnythingOfType("int"), []int(nil)).Return(func(ctx context.Context, tagID int, excludeIDs []int) []*models.TagPath {
db.Tag.On("FindAllAncestors", testCtx, mock.AnythingOfType("int"), []int(nil)).Return(func(ctx context.Context, tagID int, excludeIDs []int) []*models.TagPath {
return tc.onFindAllAncestors
}, func(ctx context.Context, tagID int, excludeIDs []int) error {
if tc.onFindAllAncestors != nil {
@@ -264,7 +263,7 @@ func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents,
return fmt.Errorf("undefined ancestors for: %d", tagID)
}).Maybe()
mockTagReader.On("FindAllDescendants", ctx, mock.AnythingOfType("int"), []int(nil)).Return(func(ctx context.Context, tagID int, excludeIDs []int) []*models.TagPath {
db.Tag.On("FindAllDescendants", testCtx, mock.AnythingOfType("int"), []int(nil)).Return(func(ctx context.Context, tagID int, excludeIDs []int) []*models.TagPath {
return tc.onFindAllDescendants
}, func(ctx context.Context, tagID int, excludeIDs []int) error {
if tc.onFindAllDescendants != nil {
@@ -273,7 +272,7 @@ func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents,
return fmt.Errorf("undefined descendants for: %d", tagID)
}).Maybe()
res := ValidateHierarchy(ctx, testUniqueHierarchyTags[tc.id], parentIDs, childIDs, mockTagReader)
res := ValidateHierarchy(testCtx, testUniqueHierarchyTags[tc.id], parentIDs, childIDs, db.Tag)
assert := assert.New(t)
@@ -285,5 +284,5 @@ func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents,
assert.Nil(res)
}
mockTagReader.AssertExpectations(t)
db.AssertExpectations(t)
}
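The hierarchy test also drops its local ctx := context.Background() in favour of the testCtx that the other tag tests already use. The shared definition is outside this diff; presumably it is a package-level variable along these lines:

package tag

import "context"

// testCtx is shared by the package's tests, replacing the per-test
// ctx := context.Background() removed above.
var testCtx = context.Background()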