From 7b5bd80515b64646d8154df23a17893fd6083a2f Mon Sep 17 00:00:00 2001
From: WithoutPants <53250216+WithoutPants@users.noreply.github.com>
Date: Mon, 25 Apr 2022 15:55:05 +1000
Subject: [PATCH 01/30] Separate graphql API from rest of the system (#2503)
* Move graphql generated files to api
* Refactor identify options
* Remove models.StashBoxes
* Move ScraperSource to scraper package
* Rename field strategy enums
* Rename identify.TaskOptions to Options
---
.gitignore | 2 +-
gqlgen.yml | 128 +++++++++---
internal/api/resolver.go | 44 ++---
internal/api/resolver_model_image.go | 4 +-
internal/api/resolver_model_scene.go | 10 +-
internal/api/resolver_mutation_configure.go | 20 +-
internal/api/resolver_mutation_dlna.go | 9 +-
internal/api/resolver_mutation_gallery.go | 14 +-
internal/api/resolver_mutation_image.go | 14 +-
internal/api/resolver_mutation_metadata.go | 18 +-
internal/api/resolver_mutation_movie.go | 8 +-
internal/api/resolver_mutation_performer.go | 8 +-
internal/api/resolver_mutation_plugin.go | 4 +-
.../api/resolver_mutation_saved_filter.go | 6 +-
internal/api/resolver_mutation_scene.go | 28 +--
internal/api/resolver_mutation_stash_box.go | 8 +-
internal/api/resolver_mutation_studio.go | 6 +-
internal/api/resolver_mutation_tag.go | 8 +-
internal/api/resolver_mutation_tag_test.go | 6 +-
internal/api/resolver_query_configuration.go | 34 ++--
internal/api/resolver_query_dlna.go | 4 +-
internal/api/resolver_query_find_gallery.go | 4 +-
internal/api/resolver_query_find_image.go | 4 +-
internal/api/resolver_query_find_movie.go | 4 +-
internal/api/resolver_query_find_performer.go | 4 +-
internal/api/resolver_query_find_scene.go | 14 +-
.../api/resolver_query_find_scene_marker.go | 4 +-
internal/api/resolver_query_find_studio.go | 4 +-
internal/api/resolver_query_find_tag.go | 4 +-
internal/api/resolver_query_job.go | 13 +-
internal/api/resolver_query_logs.go | 7 +-
internal/api/resolver_query_metadata.go | 3 +-
internal/api/resolver_query_plugin.go | 6 +-
internal/api/resolver_query_scene.go | 2 +-
internal/api/resolver_query_scraper.go | 96 ++++-----
internal/api/resolver_subscription_job.go | 15 +-
internal/api/resolver_subscription_logging.go | 27 ++-
internal/api/scraped_content.go | 37 ++--
internal/api/server.go | 3 +-
internal/dlna/service.go | 18 +-
internal/dlna/whitelist.go | 7 +-
internal/identify/identify.go | 25 +--
internal/identify/identify_test.go | 89 ++++-----
internal/identify/options.go | 92 +++++++++
internal/identify/scene.go | 14 +-
internal/identify/scene_test.go | 101 +++++-----
internal/manager/config/config.go | 64 +++---
internal/manager/config/tasks.go | 27 +++
internal/manager/config/ui.go | 106 ++++++++++
internal/manager/filename_parser.go | 50 +++--
internal/manager/import.go | 52 ++++-
internal/manager/manager.go | 83 +++++++-
internal/manager/manager_tasks.go | 63 +++++-
internal/manager/scene.go | 26 ++-
internal/manager/task_autotag.go | 4 +-
internal/manager/task_clean.go | 6 +-
internal/manager/task_export.go | 20 +-
internal/manager/task_generate.go | 40 +++-
internal/manager/task_identify.go | 50 ++++-
internal/manager/task_import.go | 13 +-
internal/manager/task_plugin.go | 4 +-
internal/manager/task_scan.go | 22 +--
pkg/models/extension_resolution.go | 46 -----
pkg/models/filter.go | 108 +++++++++++
...xtension_find_filter.go => find_filter.go} | 56 ++++++
pkg/models/gallery.go | 66 +++++++
pkg/models/generate.go | 91 +++++++++
pkg/models/image.go | 49 +++++
pkg/models/import.go | 50 +++++
pkg/models/model_file.go | 49 ++++-
pkg/models/model_saved_filter.go | 59 ++++++
pkg/models/model_scene.go | 24 +++
pkg/models/model_scraped_item.go | 72 +++++++
pkg/models/movie.go | 18 ++
pkg/models/performer.go | 124 ++++++++++++
pkg/models/resolution.go | 183 ++++++++++++++++++
pkg/models/scene.go | 78 ++++++++
pkg/models/scene_marker.go | 17 ++
pkg/models/stash_box.go | 46 +----
pkg/models/stash_ids.go | 5 +
pkg/models/studio.go | 28 +++
pkg/models/tag.go | 32 +++
pkg/plugin/args.go | 24 ++-
pkg/plugin/config.go | 17 +-
pkg/plugin/convert.go | 5 +-
pkg/plugin/hooks.go | 7 +
pkg/plugin/plugins.go | 22 ++-
pkg/plugin/task.go | 6 +
pkg/scraper/action.go | 10 +-
pkg/scraper/autotag.go | 32 +--
pkg/scraper/cache.go | 20 +-
pkg/scraper/config.go | 51 +++--
pkg/scraper/gallery.go | 22 +++
pkg/scraper/group.go | 32 +--
pkg/scraper/image.go | 2 +-
pkg/scraper/json.go | 24 +--
pkg/scraper/mapped.go | 18 +-
pkg/scraper/movie.go | 12 ++
pkg/scraper/performer.go | 27 +++
pkg/scraper/postprocessing.go | 18 +-
pkg/scraper/query_url.go | 2 +-
pkg/scraper/scene.go | 32 +++
pkg/scraper/scraper.go | 150 ++++++++++++--
pkg/scraper/script.go | 44 ++---
pkg/scraper/stash.go | 22 +--
pkg/scraper/stashbox/models.go | 8 +
pkg/scraper/stashbox/stash_box.go | 31 +--
pkg/scraper/xpath.go | 24 +--
pkg/scraper/xpath_test.go | 2 +-
109 files changed, 2684 insertions(+), 791 deletions(-)
create mode 100644 internal/identify/options.go
create mode 100644 internal/manager/config/tasks.go
create mode 100644 internal/manager/config/ui.go
delete mode 100644 pkg/models/extension_resolution.go
create mode 100644 pkg/models/filter.go
rename pkg/models/{extension_find_filter.go => find_filter.go} (57%)
create mode 100644 pkg/models/generate.go
create mode 100644 pkg/models/import.go
create mode 100644 pkg/models/resolution.go
create mode 100644 pkg/scraper/gallery.go
create mode 100644 pkg/scraper/movie.go
create mode 100644 pkg/scraper/performer.go
create mode 100644 pkg/scraper/scene.go
create mode 100644 pkg/scraper/stashbox/models.go
diff --git a/.gitignore b/.gitignore
index 4002f4168..34494cd8d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -16,7 +16,7 @@
*.out
# GraphQL generated output
-pkg/models/generated_*.go
+internal/api/generated_*.go
ui/v2.5/src/core/generated-*.tsx
####
diff --git a/gqlgen.yml b/gqlgen.yml
index 8762177d3..150d86bbc 100644
--- a/gqlgen.yml
+++ b/gqlgen.yml
@@ -4,46 +4,110 @@ schema:
- "graphql/schema/types/*.graphql"
- "graphql/schema/*.graphql"
exec:
- filename: pkg/models/generated_exec.go
+ filename: internal/api/generated_exec.go
model:
- filename: pkg/models/generated_models.go
+ filename: internal/api/generated_models.go
resolver:
filename: internal/api/resolver.go
type: Resolver
struct_tag: gqlgen
+autobind:
+ - github.com/stashapp/stash/pkg/models
+ - github.com/stashapp/stash/pkg/plugin
+ - github.com/stashapp/stash/pkg/scraper
+ - github.com/stashapp/stash/internal/identify
+ - github.com/stashapp/stash/internal/dlna
+ - github.com/stashapp/stash/pkg/scraper/stashbox
+
models:
+ # autobind on config causes generation issues
# Scalars
Timestamp:
model: github.com/stashapp/stash/pkg/models.Timestamp
- # Objects
- Gallery:
- model: github.com/stashapp/stash/pkg/models.Gallery
- Image:
- model: github.com/stashapp/stash/pkg/models.Image
- ImageFileType:
- model: github.com/stashapp/stash/pkg/models.ImageFileType
- Performer:
- model: github.com/stashapp/stash/pkg/models.Performer
- Scene:
- model: github.com/stashapp/stash/pkg/models.Scene
- SceneMarker:
- model: github.com/stashapp/stash/pkg/models.SceneMarker
- ScrapedItem:
- model: github.com/stashapp/stash/pkg/models.ScrapedItem
- Studio:
- model: github.com/stashapp/stash/pkg/models.Studio
- Movie:
- model: github.com/stashapp/stash/pkg/models.Movie
- Tag:
- model: github.com/stashapp/stash/pkg/models.Tag
- SceneFileType:
- model: github.com/stashapp/stash/pkg/models.SceneFileType
- SavedFilter:
- model: github.com/stashapp/stash/pkg/models.SavedFilter
- StashID:
- model: github.com/stashapp/stash/pkg/models.StashID
- SceneCaption:
- model: github.com/stashapp/stash/pkg/models.SceneCaption
-
+ StashConfig:
+ model: github.com/stashapp/stash/internal/manager/config.StashConfig
+ StashConfigInput:
+ model: github.com/stashapp/stash/internal/manager/config.StashConfigInput
+ StashBoxInput:
+ model: github.com/stashapp/stash/internal/manager/config.StashBoxInput
+ ConfigImageLightboxResult:
+ model: github.com/stashapp/stash/internal/manager/config.ConfigImageLightboxResult
+ ImageLightboxDisplayMode:
+ model: github.com/stashapp/stash/internal/manager/config.ImageLightboxDisplayMode
+ ImageLightboxScrollMode:
+ model: github.com/stashapp/stash/internal/manager/config.ImageLightboxScrollMode
+ ConfigDisableDropdownCreate:
+ model: github.com/stashapp/stash/internal/manager/config.ConfigDisableDropdownCreate
+ ScanMetadataOptions:
+ model: github.com/stashapp/stash/internal/manager/config.ScanMetadataOptions
+ AutoTagMetadataOptions:
+ model: github.com/stashapp/stash/internal/manager/config.AutoTagMetadataOptions
+ SceneParserInput:
+ model: github.com/stashapp/stash/internal/manager.SceneParserInput
+ SceneParserResult:
+ model: github.com/stashapp/stash/internal/manager.SceneParserResult
+ SceneMovieID:
+ model: github.com/stashapp/stash/internal/manager.SceneMovieID
+ SystemStatus:
+ model: github.com/stashapp/stash/internal/manager.SystemStatus
+ SystemStatusEnum:
+ model: github.com/stashapp/stash/internal/manager.SystemStatusEnum
+ ImportDuplicateEnum:
+ model: github.com/stashapp/stash/internal/manager.ImportDuplicateEnum
+ SetupInput:
+ model: github.com/stashapp/stash/internal/manager.SetupInput
+ MigrateInput:
+ model: github.com/stashapp/stash/internal/manager.MigrateInput
+ ScanMetadataInput:
+ model: github.com/stashapp/stash/internal/manager.ScanMetadataInput
+ GenerateMetadataInput:
+ model: github.com/stashapp/stash/internal/manager.GenerateMetadataInput
+ GeneratePreviewOptionsInput:
+ model: github.com/stashapp/stash/internal/manager.GeneratePreviewOptionsInput
+ AutoTagMetadataInput:
+ model: github.com/stashapp/stash/internal/manager.AutoTagMetadataInput
+ CleanMetadataInput:
+ model: github.com/stashapp/stash/internal/manager.CleanMetadataInput
+ StashBoxBatchPerformerTagInput:
+ model: github.com/stashapp/stash/internal/manager.StashBoxBatchPerformerTagInput
+ SceneStreamEndpoint:
+ model: github.com/stashapp/stash/internal/manager.SceneStreamEndpoint
+ ExportObjectTypeInput:
+ model: github.com/stashapp/stash/internal/manager.ExportObjectTypeInput
+ ExportObjectsInput:
+ model: github.com/stashapp/stash/internal/manager.ExportObjectsInput
+ ImportObjectsInput:
+ model: github.com/stashapp/stash/internal/manager.ImportObjectsInput
+ ScanMetaDataFilterInput:
+ model: github.com/stashapp/stash/internal/manager.ScanMetaDataFilterInput
+ # renamed types
+ DLNAStatus:
+ model: github.com/stashapp/stash/internal/dlna.Status
+ DLNAIP:
+ model: github.com/stashapp/stash/internal/dlna.Dlnaip
+ IdentifySource:
+ model: github.com/stashapp/stash/internal/identify.Source
+ IdentifyMetadataTaskOptions:
+ model: github.com/stashapp/stash/internal/identify.Options
+ IdentifyMetadataInput:
+ model: github.com/stashapp/stash/internal/identify.Options
+ IdentifyMetadataOptions:
+ model: github.com/stashapp/stash/internal/identify.MetadataOptions
+ IdentifyFieldOptions:
+ model: github.com/stashapp/stash/internal/identify.FieldOptions
+ IdentifyFieldStrategy:
+ model: github.com/stashapp/stash/internal/identify.FieldStrategy
+ ScraperSource:
+ model: github.com/stashapp/stash/pkg/scraper.Source
+ # rebind inputs to types
+ IdentifySourceInput:
+ model: github.com/stashapp/stash/internal/identify.Source
+ IdentifyFieldOptionsInput:
+ model: github.com/stashapp/stash/internal/identify.FieldOptions
+ IdentifyMetadataOptionsInput:
+ model: github.com/stashapp/stash/internal/identify.MetadataOptions
+ ScraperSourceInput:
+ model: github.com/stashapp/stash/pkg/scraper.Source
+
diff --git a/internal/api/resolver.go b/internal/api/resolver.go
index 3dbdd9fa0..5880f1e3c 100644
--- a/internal/api/resolver.go
+++ b/internal/api/resolver.go
@@ -38,37 +38,37 @@ func (r *Resolver) scraperCache() *scraper.Cache {
return manager.GetInstance().ScraperCache
}
-func (r *Resolver) Gallery() models.GalleryResolver {
+func (r *Resolver) Gallery() GalleryResolver {
return &galleryResolver{r}
}
-func (r *Resolver) Mutation() models.MutationResolver {
+func (r *Resolver) Mutation() MutationResolver {
return &mutationResolver{r}
}
-func (r *Resolver) Performer() models.PerformerResolver {
+func (r *Resolver) Performer() PerformerResolver {
return &performerResolver{r}
}
-func (r *Resolver) Query() models.QueryResolver {
+func (r *Resolver) Query() QueryResolver {
return &queryResolver{r}
}
-func (r *Resolver) Scene() models.SceneResolver {
+func (r *Resolver) Scene() SceneResolver {
return &sceneResolver{r}
}
-func (r *Resolver) Image() models.ImageResolver {
+func (r *Resolver) Image() ImageResolver {
return &imageResolver{r}
}
-func (r *Resolver) SceneMarker() models.SceneMarkerResolver {
+func (r *Resolver) SceneMarker() SceneMarkerResolver {
return &sceneMarkerResolver{r}
}
-func (r *Resolver) Studio() models.StudioResolver {
+func (r *Resolver) Studio() StudioResolver {
return &studioResolver{r}
}
-func (r *Resolver) Movie() models.MovieResolver {
+func (r *Resolver) Movie() MovieResolver {
return &movieResolver{r}
}
-func (r *Resolver) Subscription() models.SubscriptionResolver {
+func (r *Resolver) Subscription() SubscriptionResolver {
return &subscriptionResolver{r}
}
-func (r *Resolver) Tag() models.TagResolver {
+func (r *Resolver) Tag() TagResolver {
return &tagResolver{r}
}
@@ -125,8 +125,8 @@ func (r *queryResolver) MarkerStrings(ctx context.Context, q *string, sort *stri
return ret, nil
}
-func (r *queryResolver) Stats(ctx context.Context) (*models.StatsResultType, error) {
- var ret models.StatsResultType
+func (r *queryResolver) Stats(ctx context.Context) (*StatsResultType, error) {
+ var ret StatsResultType
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
scenesQB := repo.Scene()
imageQB := repo.Image()
@@ -146,7 +146,7 @@ func (r *queryResolver) Stats(ctx context.Context) (*models.StatsResultType, err
moviesCount, _ := moviesQB.Count()
tagsCount, _ := tagsQB.Count()
- ret = models.StatsResultType{
+ ret = StatsResultType{
SceneCount: scenesCount,
ScenesSize: scenesSize,
ScenesDuration: scenesDuration,
@@ -167,10 +167,10 @@ func (r *queryResolver) Stats(ctx context.Context) (*models.StatsResultType, err
return &ret, nil
}
-func (r *queryResolver) Version(ctx context.Context) (*models.Version, error) {
+func (r *queryResolver) Version(ctx context.Context) (*Version, error) {
version, hash, buildtime := GetVersion()
- return &models.Version{
+ return &Version{
Version: &version,
Hash: hash,
BuildTime: buildtime,
@@ -178,7 +178,7 @@ func (r *queryResolver) Version(ctx context.Context) (*models.Version, error) {
}
// Latestversion returns the latest git shorthash commit.
-func (r *queryResolver) Latestversion(ctx context.Context) (*models.ShortVersion, error) {
+func (r *queryResolver) Latestversion(ctx context.Context) (*ShortVersion, error) {
ver, url, err := GetLatestVersion(ctx, true)
if err == nil {
logger.Infof("Retrieved latest hash: %s", ver)
@@ -186,21 +186,21 @@ func (r *queryResolver) Latestversion(ctx context.Context) (*models.ShortVersion
logger.Errorf("Error while retrieving latest hash: %s", err)
}
- return &models.ShortVersion{
+ return &ShortVersion{
Shorthash: ver,
URL: url,
}, err
}
// Get scene marker tags which show up under the video.
-func (r *queryResolver) SceneMarkerTags(ctx context.Context, scene_id string) ([]*models.SceneMarkerTag, error) {
+func (r *queryResolver) SceneMarkerTags(ctx context.Context, scene_id string) ([]*SceneMarkerTag, error) {
sceneID, err := strconv.Atoi(scene_id)
if err != nil {
return nil, err
}
var keys []int
- tags := make(map[int]*models.SceneMarkerTag)
+ tags := make(map[int]*SceneMarkerTag)
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
sceneMarkers, err := repo.SceneMarker().FindBySceneID(sceneID)
@@ -216,7 +216,7 @@ func (r *queryResolver) SceneMarkerTags(ctx context.Context, scene_id string) ([
}
_, hasKey := tags[markerPrimaryTag.ID]
if !hasKey {
- sceneMarkerTag := &models.SceneMarkerTag{Tag: markerPrimaryTag}
+ sceneMarkerTag := &SceneMarkerTag{Tag: markerPrimaryTag}
tags[markerPrimaryTag.ID] = sceneMarkerTag
keys = append(keys, markerPrimaryTag.ID)
}
@@ -235,7 +235,7 @@ func (r *queryResolver) SceneMarkerTags(ctx context.Context, scene_id string) ([
return a.SceneMarkers[0].Seconds < b.SceneMarkers[0].Seconds
})
- var result []*models.SceneMarkerTag
+ var result []*SceneMarkerTag
for _, key := range keys {
result = append(result, tags[key])
}
diff --git a/internal/api/resolver_model_image.go b/internal/api/resolver_model_image.go
index a77437dfb..0f47b6c3e 100644
--- a/internal/api/resolver_model_image.go
+++ b/internal/api/resolver_model_image.go
@@ -33,12 +33,12 @@ func (r *imageResolver) File(ctx context.Context, obj *models.Image) (*models.Im
}, nil
}
-func (r *imageResolver) Paths(ctx context.Context, obj *models.Image) (*models.ImagePathsType, error) {
+func (r *imageResolver) Paths(ctx context.Context, obj *models.Image) (*ImagePathsType, error) {
baseURL, _ := ctx.Value(BaseURLCtxKey).(string)
builder := urlbuilders.NewImageURLBuilder(baseURL, obj)
thumbnailPath := builder.GetThumbnailURL()
imagePath := builder.GetImageURL()
- return &models.ImagePathsType{
+ return &ImagePathsType{
Image: &imagePath,
Thumbnail: &thumbnailPath,
}, nil
diff --git a/internal/api/resolver_model_scene.go b/internal/api/resolver_model_scene.go
index 4fd583d5b..0b81beaa5 100644
--- a/internal/api/resolver_model_scene.go
+++ b/internal/api/resolver_model_scene.go
@@ -85,7 +85,7 @@ func (r *sceneResolver) File(ctx context.Context, obj *models.Scene) (*models.Sc
}, nil
}
-func (r *sceneResolver) Paths(ctx context.Context, obj *models.Scene) (*models.ScenePathsType, error) {
+func (r *sceneResolver) Paths(ctx context.Context, obj *models.Scene) (*ScenePathsType, error) {
baseURL, _ := ctx.Value(BaseURLCtxKey).(string)
config := manager.GetInstance().Config
builder := urlbuilders.NewSceneURLBuilder(baseURL, obj.ID)
@@ -101,7 +101,7 @@ func (r *sceneResolver) Paths(ctx context.Context, obj *models.Scene) (*models.S
captionBasePath := builder.GetCaptionURL()
interactiveHeatmap := builder.GetInteractiveHeatmapURL()
- return &models.ScenePathsType{
+ return &ScenePathsType{
Screenshot: &screenshotPath,
Preview: &previewPath,
Stream: &streamPath,
@@ -163,7 +163,7 @@ func (r *sceneResolver) Studio(ctx context.Context, obj *models.Scene) (ret *mod
return ret, nil
}
-func (r *sceneResolver) Movies(ctx context.Context, obj *models.Scene) (ret []*models.SceneMovie, err error) {
+func (r *sceneResolver) Movies(ctx context.Context, obj *models.Scene) (ret []*SceneMovie, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
qb := repo.Scene()
mqb := repo.Movie()
@@ -180,7 +180,7 @@ func (r *sceneResolver) Movies(ctx context.Context, obj *models.Scene) (ret []*m
}
sceneIdx := sm.SceneIndex
- sceneMovie := &models.SceneMovie{
+ sceneMovie := &SceneMovie{
Movie: movie,
}
@@ -252,7 +252,7 @@ func (r *sceneResolver) FileModTime(ctx context.Context, obj *models.Scene) (*ti
return &obj.FileModTime.Timestamp, nil
}
-func (r *sceneResolver) SceneStreams(ctx context.Context, obj *models.Scene) ([]*models.SceneStreamEndpoint, error) {
+func (r *sceneResolver) SceneStreams(ctx context.Context, obj *models.Scene) ([]*manager.SceneStreamEndpoint, error) {
config := manager.GetInstance().Config
baseURL, _ := ctx.Value(BaseURLCtxKey).(string)
diff --git a/internal/api/resolver_mutation_configure.go b/internal/api/resolver_mutation_configure.go
index 7413c413b..48b1c6b9f 100644
--- a/internal/api/resolver_mutation_configure.go
+++ b/internal/api/resolver_mutation_configure.go
@@ -15,17 +15,17 @@ import (
var ErrOverriddenConfig = errors.New("cannot set overridden value")
-func (r *mutationResolver) Setup(ctx context.Context, input models.SetupInput) (bool, error) {
+func (r *mutationResolver) Setup(ctx context.Context, input manager.SetupInput) (bool, error) {
err := manager.GetInstance().Setup(ctx, input)
return err == nil, err
}
-func (r *mutationResolver) Migrate(ctx context.Context, input models.MigrateInput) (bool, error) {
+func (r *mutationResolver) Migrate(ctx context.Context, input manager.MigrateInput) (bool, error) {
err := manager.GetInstance().Migrate(ctx, input)
return err == nil, err
}
-func (r *mutationResolver) ConfigureGeneral(ctx context.Context, input models.ConfigGeneralInput) (*models.ConfigGeneralResult, error) {
+func (r *mutationResolver) ConfigureGeneral(ctx context.Context, input ConfigGeneralInput) (*ConfigGeneralResult, error) {
c := config.GetInstance()
existingPaths := c.GetStashPaths()
@@ -281,7 +281,7 @@ func (r *mutationResolver) ConfigureGeneral(ctx context.Context, input models.Co
return makeConfigGeneralResult(), nil
}
-func (r *mutationResolver) ConfigureInterface(ctx context.Context, input models.ConfigInterfaceInput) (*models.ConfigInterfaceResult, error) {
+func (r *mutationResolver) ConfigureInterface(ctx context.Context, input ConfigInterfaceInput) (*ConfigInterfaceResult, error) {
c := config.GetInstance()
setBool := func(key string, v *bool) {
@@ -338,10 +338,10 @@ func (r *mutationResolver) ConfigureInterface(ctx context.Context, input models.
c.Set(config.ImageLightboxSlideshowDelay, *options.SlideshowDelay)
}
- setString(config.ImageLightboxDisplayMode, (*string)(options.DisplayMode))
+ setString(config.ImageLightboxDisplayModeKey, (*string)(options.DisplayMode))
setBool(config.ImageLightboxScaleUp, options.ScaleUp)
setBool(config.ImageLightboxResetZoomOnNav, options.ResetZoomOnNav)
- setString(config.ImageLightboxScrollMode, (*string)(options.ScrollMode))
+ setString(config.ImageLightboxScrollModeKey, (*string)(options.ScrollMode))
if options.ScrollAttemptsBeforeChange != nil {
c.Set(config.ImageLightboxScrollAttemptsBeforeChange, *options.ScrollAttemptsBeforeChange)
@@ -376,7 +376,7 @@ func (r *mutationResolver) ConfigureInterface(ctx context.Context, input models.
return makeConfigInterfaceResult(), nil
}
-func (r *mutationResolver) ConfigureDlna(ctx context.Context, input models.ConfigDLNAInput) (*models.ConfigDLNAResult, error) {
+func (r *mutationResolver) ConfigureDlna(ctx context.Context, input ConfigDLNAInput) (*ConfigDLNAResult, error) {
c := config.GetInstance()
if input.ServerName != nil {
@@ -413,7 +413,7 @@ func (r *mutationResolver) ConfigureDlna(ctx context.Context, input models.Confi
return makeConfigDLNAResult(), nil
}
-func (r *mutationResolver) ConfigureScraping(ctx context.Context, input models.ConfigScrapingInput) (*models.ConfigScrapingResult, error) {
+func (r *mutationResolver) ConfigureScraping(ctx context.Context, input ConfigScrapingInput) (*ConfigScrapingResult, error) {
c := config.GetInstance()
refreshScraperCache := false
@@ -445,7 +445,7 @@ func (r *mutationResolver) ConfigureScraping(ctx context.Context, input models.C
return makeConfigScrapingResult(), nil
}
-func (r *mutationResolver) ConfigureDefaults(ctx context.Context, input models.ConfigDefaultSettingsInput) (*models.ConfigDefaultSettingsResult, error) {
+func (r *mutationResolver) ConfigureDefaults(ctx context.Context, input ConfigDefaultSettingsInput) (*ConfigDefaultSettingsResult, error) {
c := config.GetInstance()
if input.Identify != nil {
@@ -479,7 +479,7 @@ func (r *mutationResolver) ConfigureDefaults(ctx context.Context, input models.C
return makeConfigDefaultsResult(), nil
}
-func (r *mutationResolver) GenerateAPIKey(ctx context.Context, input models.GenerateAPIKeyInput) (string, error) {
+func (r *mutationResolver) GenerateAPIKey(ctx context.Context, input GenerateAPIKeyInput) (string, error) {
c := config.GetInstance()
var newAPIKey string
diff --git a/internal/api/resolver_mutation_dlna.go b/internal/api/resolver_mutation_dlna.go
index 6f43a9e6f..cb62afac9 100644
--- a/internal/api/resolver_mutation_dlna.go
+++ b/internal/api/resolver_mutation_dlna.go
@@ -5,10 +5,9 @@ import (
"time"
"github.com/stashapp/stash/internal/manager"
- "github.com/stashapp/stash/pkg/models"
)
-func (r *mutationResolver) EnableDlna(ctx context.Context, input models.EnableDLNAInput) (bool, error) {
+func (r *mutationResolver) EnableDlna(ctx context.Context, input EnableDLNAInput) (bool, error) {
err := manager.GetInstance().DLNAService.Start(parseMinutes(input.Duration))
if err != nil {
return false, err
@@ -16,17 +15,17 @@ func (r *mutationResolver) EnableDlna(ctx context.Context, input models.EnableDL
return true, nil
}
-func (r *mutationResolver) DisableDlna(ctx context.Context, input models.DisableDLNAInput) (bool, error) {
+func (r *mutationResolver) DisableDlna(ctx context.Context, input DisableDLNAInput) (bool, error) {
manager.GetInstance().DLNAService.Stop(parseMinutes(input.Duration))
return true, nil
}
-func (r *mutationResolver) AddTempDlnaip(ctx context.Context, input models.AddTempDLNAIPInput) (bool, error) {
+func (r *mutationResolver) AddTempDlnaip(ctx context.Context, input AddTempDLNAIPInput) (bool, error) {
manager.GetInstance().DLNAService.AddTempDLNAIP(input.Address, parseMinutes(input.Duration))
return true, nil
}
-func (r *mutationResolver) RemoveTempDlnaip(ctx context.Context, input models.RemoveTempDLNAIPInput) (bool, error) {
+func (r *mutationResolver) RemoveTempDlnaip(ctx context.Context, input RemoveTempDLNAIPInput) (bool, error) {
ret := manager.GetInstance().DLNAService.RemoveTempDLNAIP(input.Address)
return ret, nil
}
diff --git a/internal/api/resolver_mutation_gallery.go b/internal/api/resolver_mutation_gallery.go
index cd04a0313..816d6aba3 100644
--- a/internal/api/resolver_mutation_gallery.go
+++ b/internal/api/resolver_mutation_gallery.go
@@ -31,7 +31,7 @@ func (r *mutationResolver) getGallery(ctx context.Context, id int) (ret *models.
return ret, nil
}
-func (r *mutationResolver) GalleryCreate(ctx context.Context, input models.GalleryCreateInput) (*models.Gallery, error) {
+func (r *mutationResolver) GalleryCreate(ctx context.Context, input GalleryCreateInput) (*models.Gallery, error) {
// name must be provided
if input.Title == "" {
return nil, errors.New("title must not be empty")
@@ -273,7 +273,7 @@ func (r *mutationResolver) galleryUpdate(input models.GalleryUpdateInput, transl
return gallery, nil
}
-func (r *mutationResolver) BulkGalleryUpdate(ctx context.Context, input models.BulkGalleryUpdateInput) ([]*models.Gallery, error) {
+func (r *mutationResolver) BulkGalleryUpdate(ctx context.Context, input BulkGalleryUpdateInput) ([]*models.Gallery, error) {
// Populate gallery from the input
updatedTime := time.Now()
@@ -367,7 +367,7 @@ func (r *mutationResolver) BulkGalleryUpdate(ctx context.Context, input models.B
return newRet, nil
}
-func adjustGalleryPerformerIDs(qb models.GalleryReader, galleryID int, ids models.BulkUpdateIds) (ret []int, err error) {
+func adjustGalleryPerformerIDs(qb models.GalleryReader, galleryID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = qb.GetPerformerIDs(galleryID)
if err != nil {
return nil, err
@@ -376,7 +376,7 @@ func adjustGalleryPerformerIDs(qb models.GalleryReader, galleryID int, ids model
return adjustIDs(ret, ids), nil
}
-func adjustGalleryTagIDs(qb models.GalleryReader, galleryID int, ids models.BulkUpdateIds) (ret []int, err error) {
+func adjustGalleryTagIDs(qb models.GalleryReader, galleryID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = qb.GetTagIDs(galleryID)
if err != nil {
return nil, err
@@ -385,7 +385,7 @@ func adjustGalleryTagIDs(qb models.GalleryReader, galleryID int, ids models.Bulk
return adjustIDs(ret, ids), nil
}
-func adjustGallerySceneIDs(qb models.GalleryReader, galleryID int, ids models.BulkUpdateIds) (ret []int, err error) {
+func adjustGallerySceneIDs(qb models.GalleryReader, galleryID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = qb.GetSceneIDs(galleryID)
if err != nil {
return nil, err
@@ -526,7 +526,7 @@ func isStashPath(path string) bool {
return false
}
-func (r *mutationResolver) AddGalleryImages(ctx context.Context, input models.GalleryAddInput) (bool, error) {
+func (r *mutationResolver) AddGalleryImages(ctx context.Context, input GalleryAddInput) (bool, error) {
galleryID, err := strconv.Atoi(input.GalleryID)
if err != nil {
return false, err
@@ -566,7 +566,7 @@ func (r *mutationResolver) AddGalleryImages(ctx context.Context, input models.Ga
return true, nil
}
-func (r *mutationResolver) RemoveGalleryImages(ctx context.Context, input models.GalleryRemoveInput) (bool, error) {
+func (r *mutationResolver) RemoveGalleryImages(ctx context.Context, input GalleryRemoveInput) (bool, error) {
galleryID, err := strconv.Atoi(input.GalleryID)
if err != nil {
return false, err
diff --git a/internal/api/resolver_mutation_image.go b/internal/api/resolver_mutation_image.go
index 4c05c0ee6..2df016be2 100644
--- a/internal/api/resolver_mutation_image.go
+++ b/internal/api/resolver_mutation_image.go
@@ -26,7 +26,7 @@ func (r *mutationResolver) getImage(ctx context.Context, id int) (ret *models.Im
return ret, nil
}
-func (r *mutationResolver) ImageUpdate(ctx context.Context, input models.ImageUpdateInput) (ret *models.Image, err error) {
+func (r *mutationResolver) ImageUpdate(ctx context.Context, input ImageUpdateInput) (ret *models.Image, err error) {
translator := changesetTranslator{
inputMap: getUpdateInputMap(ctx),
}
@@ -44,7 +44,7 @@ func (r *mutationResolver) ImageUpdate(ctx context.Context, input models.ImageUp
return r.getImage(ctx, ret.ID)
}
-func (r *mutationResolver) ImagesUpdate(ctx context.Context, input []*models.ImageUpdateInput) (ret []*models.Image, err error) {
+func (r *mutationResolver) ImagesUpdate(ctx context.Context, input []*ImageUpdateInput) (ret []*models.Image, err error) {
inputMaps := getUpdateInputMaps(ctx)
// Start the transaction and save the image
@@ -86,7 +86,7 @@ func (r *mutationResolver) ImagesUpdate(ctx context.Context, input []*models.Ima
return newRet, nil
}
-func (r *mutationResolver) imageUpdate(input models.ImageUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Image, error) {
+func (r *mutationResolver) imageUpdate(input ImageUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Image, error) {
// Populate image from the input
imageID, err := strconv.Atoi(input.ID)
if err != nil {
@@ -157,7 +157,7 @@ func (r *mutationResolver) updateImageTags(qb models.ImageReaderWriter, imageID
return qb.UpdateTags(imageID, ids)
}
-func (r *mutationResolver) BulkImageUpdate(ctx context.Context, input models.BulkImageUpdateInput) (ret []*models.Image, err error) {
+func (r *mutationResolver) BulkImageUpdate(ctx context.Context, input BulkImageUpdateInput) (ret []*models.Image, err error) {
imageIDs, err := stringslice.StringSliceToIntSlice(input.Ids)
if err != nil {
return nil, err
@@ -251,7 +251,7 @@ func (r *mutationResolver) BulkImageUpdate(ctx context.Context, input models.Bul
return newRet, nil
}
-func adjustImageGalleryIDs(qb models.ImageReader, imageID int, ids models.BulkUpdateIds) (ret []int, err error) {
+func adjustImageGalleryIDs(qb models.ImageReader, imageID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = qb.GetGalleryIDs(imageID)
if err != nil {
return nil, err
@@ -260,7 +260,7 @@ func adjustImageGalleryIDs(qb models.ImageReader, imageID int, ids models.BulkUp
return adjustIDs(ret, ids), nil
}
-func adjustImagePerformerIDs(qb models.ImageReader, imageID int, ids models.BulkUpdateIds) (ret []int, err error) {
+func adjustImagePerformerIDs(qb models.ImageReader, imageID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = qb.GetPerformerIDs(imageID)
if err != nil {
return nil, err
@@ -269,7 +269,7 @@ func adjustImagePerformerIDs(qb models.ImageReader, imageID int, ids models.Bulk
return adjustIDs(ret, ids), nil
}
-func adjustImageTagIDs(qb models.ImageReader, imageID int, ids models.BulkUpdateIds) (ret []int, err error) {
+func adjustImageTagIDs(qb models.ImageReader, imageID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = qb.GetTagIDs(imageID)
if err != nil {
return nil, err
diff --git a/internal/api/resolver_mutation_metadata.go b/internal/api/resolver_mutation_metadata.go
index b8154614f..0d1794e58 100644
--- a/internal/api/resolver_mutation_metadata.go
+++ b/internal/api/resolver_mutation_metadata.go
@@ -9,15 +9,15 @@ import (
"sync"
"time"
+ "github.com/stashapp/stash/internal/identify"
"github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/internal/manager/config"
"github.com/stashapp/stash/pkg/database"
"github.com/stashapp/stash/pkg/fsutil"
"github.com/stashapp/stash/pkg/logger"
- "github.com/stashapp/stash/pkg/models"
)
-func (r *mutationResolver) MetadataScan(ctx context.Context, input models.ScanMetadataInput) (string, error) {
+func (r *mutationResolver) MetadataScan(ctx context.Context, input manager.ScanMetadataInput) (string, error) {
jobID, err := manager.GetInstance().Scan(ctx, input)
if err != nil {
@@ -36,7 +36,7 @@ func (r *mutationResolver) MetadataImport(ctx context.Context) (string, error) {
return strconv.Itoa(jobID), nil
}
-func (r *mutationResolver) ImportObjects(ctx context.Context, input models.ImportObjectsInput) (string, error) {
+func (r *mutationResolver) ImportObjects(ctx context.Context, input manager.ImportObjectsInput) (string, error) {
t, err := manager.CreateImportTask(config.GetInstance().GetVideoFileNamingAlgorithm(), input)
if err != nil {
return "", err
@@ -56,7 +56,7 @@ func (r *mutationResolver) MetadataExport(ctx context.Context) (string, error) {
return strconv.Itoa(jobID), nil
}
-func (r *mutationResolver) ExportObjects(ctx context.Context, input models.ExportObjectsInput) (*string, error) {
+func (r *mutationResolver) ExportObjects(ctx context.Context, input manager.ExportObjectsInput) (*string, error) {
t := manager.CreateExportTask(config.GetInstance().GetVideoFileNamingAlgorithm(), input)
var wg sync.WaitGroup
@@ -75,7 +75,7 @@ func (r *mutationResolver) ExportObjects(ctx context.Context, input models.Expor
return nil, nil
}
-func (r *mutationResolver) MetadataGenerate(ctx context.Context, input models.GenerateMetadataInput) (string, error) {
+func (r *mutationResolver) MetadataGenerate(ctx context.Context, input manager.GenerateMetadataInput) (string, error) {
jobID, err := manager.GetInstance().Generate(ctx, input)
if err != nil {
@@ -85,19 +85,19 @@ func (r *mutationResolver) MetadataGenerate(ctx context.Context, input models.Ge
return strconv.Itoa(jobID), nil
}
-func (r *mutationResolver) MetadataAutoTag(ctx context.Context, input models.AutoTagMetadataInput) (string, error) {
+func (r *mutationResolver) MetadataAutoTag(ctx context.Context, input manager.AutoTagMetadataInput) (string, error) {
jobID := manager.GetInstance().AutoTag(ctx, input)
return strconv.Itoa(jobID), nil
}
-func (r *mutationResolver) MetadataIdentify(ctx context.Context, input models.IdentifyMetadataInput) (string, error) {
+func (r *mutationResolver) MetadataIdentify(ctx context.Context, input identify.Options) (string, error) {
t := manager.CreateIdentifyJob(input)
jobID := manager.GetInstance().JobManager.Add(ctx, "Identifying...", t)
return strconv.Itoa(jobID), nil
}
-func (r *mutationResolver) MetadataClean(ctx context.Context, input models.CleanMetadataInput) (string, error) {
+func (r *mutationResolver) MetadataClean(ctx context.Context, input manager.CleanMetadataInput) (string, error) {
jobID := manager.GetInstance().Clean(ctx, input)
return strconv.Itoa(jobID), nil
}
@@ -107,7 +107,7 @@ func (r *mutationResolver) MigrateHashNaming(ctx context.Context) (string, error
return strconv.Itoa(jobID), nil
}
-func (r *mutationResolver) BackupDatabase(ctx context.Context, input models.BackupDatabaseInput) (*string, error) {
+func (r *mutationResolver) BackupDatabase(ctx context.Context, input BackupDatabaseInput) (*string, error) {
// if download is true, then backup to temporary file and return a link
download := input.Download != nil && *input.Download
mgr := manager.GetInstance()
diff --git a/internal/api/resolver_mutation_movie.go b/internal/api/resolver_mutation_movie.go
index 325d022d3..da6dfdfe7 100644
--- a/internal/api/resolver_mutation_movie.go
+++ b/internal/api/resolver_mutation_movie.go
@@ -25,7 +25,7 @@ func (r *mutationResolver) getMovie(ctx context.Context, id int) (ret *models.Mo
return ret, nil
}
-func (r *mutationResolver) MovieCreate(ctx context.Context, input models.MovieCreateInput) (*models.Movie, error) {
+func (r *mutationResolver) MovieCreate(ctx context.Context, input MovieCreateInput) (*models.Movie, error) {
// generate checksum from movie name rather than image
checksum := md5.FromString(input.Name)
@@ -123,7 +123,7 @@ func (r *mutationResolver) MovieCreate(ctx context.Context, input models.MovieCr
return r.getMovie(ctx, movie.ID)
}
-func (r *mutationResolver) MovieUpdate(ctx context.Context, input models.MovieUpdateInput) (*models.Movie, error) {
+func (r *mutationResolver) MovieUpdate(ctx context.Context, input MovieUpdateInput) (*models.Movie, error) {
// Populate movie from the input
movieID, err := strconv.Atoi(input.ID)
if err != nil {
@@ -223,7 +223,7 @@ func (r *mutationResolver) MovieUpdate(ctx context.Context, input models.MovieUp
return r.getMovie(ctx, movie.ID)
}
-func (r *mutationResolver) BulkMovieUpdate(ctx context.Context, input models.BulkMovieUpdateInput) ([]*models.Movie, error) {
+func (r *mutationResolver) BulkMovieUpdate(ctx context.Context, input BulkMovieUpdateInput) ([]*models.Movie, error) {
movieIDs, err := stringslice.StringSliceToIntSlice(input.Ids)
if err != nil {
return nil, err
@@ -288,7 +288,7 @@ func (r *mutationResolver) BulkMovieUpdate(ctx context.Context, input models.Bul
return newRet, nil
}
-func (r *mutationResolver) MovieDestroy(ctx context.Context, input models.MovieDestroyInput) (bool, error) {
+func (r *mutationResolver) MovieDestroy(ctx context.Context, input MovieDestroyInput) (bool, error) {
id, err := strconv.Atoi(input.ID)
if err != nil {
return false, err
diff --git a/internal/api/resolver_mutation_performer.go b/internal/api/resolver_mutation_performer.go
index cf6335659..02c942484 100644
--- a/internal/api/resolver_mutation_performer.go
+++ b/internal/api/resolver_mutation_performer.go
@@ -26,7 +26,7 @@ func (r *mutationResolver) getPerformer(ctx context.Context, id int) (ret *model
return ret, nil
}
-func (r *mutationResolver) PerformerCreate(ctx context.Context, input models.PerformerCreateInput) (*models.Performer, error) {
+func (r *mutationResolver) PerformerCreate(ctx context.Context, input PerformerCreateInput) (*models.Performer, error) {
// generate checksum from performer name rather than image
checksum := md5.FromString(input.Name)
@@ -167,7 +167,7 @@ func (r *mutationResolver) PerformerCreate(ctx context.Context, input models.Per
return r.getPerformer(ctx, performer.ID)
}
-func (r *mutationResolver) PerformerUpdate(ctx context.Context, input models.PerformerUpdateInput) (*models.Performer, error) {
+func (r *mutationResolver) PerformerUpdate(ctx context.Context, input PerformerUpdateInput) (*models.Performer, error) {
// Populate performer from the input
performerID, _ := strconv.Atoi(input.ID)
updatedPerformer := models.PerformerPartial{
@@ -298,7 +298,7 @@ func (r *mutationResolver) updatePerformerTags(qb models.PerformerReaderWriter,
return qb.UpdateTags(performerID, ids)
}
-func (r *mutationResolver) BulkPerformerUpdate(ctx context.Context, input models.BulkPerformerUpdateInput) ([]*models.Performer, error) {
+func (r *mutationResolver) BulkPerformerUpdate(ctx context.Context, input BulkPerformerUpdateInput) ([]*models.Performer, error) {
performerIDs, err := stringslice.StringSliceToIntSlice(input.Ids)
if err != nil {
return nil, err
@@ -409,7 +409,7 @@ func (r *mutationResolver) BulkPerformerUpdate(ctx context.Context, input models
return newRet, nil
}
-func (r *mutationResolver) PerformerDestroy(ctx context.Context, input models.PerformerDestroyInput) (bool, error) {
+func (r *mutationResolver) PerformerDestroy(ctx context.Context, input PerformerDestroyInput) (bool, error) {
id, err := strconv.Atoi(input.ID)
if err != nil {
return false, err
diff --git a/internal/api/resolver_mutation_plugin.go b/internal/api/resolver_mutation_plugin.go
index 48a2a29e2..58ad359b7 100644
--- a/internal/api/resolver_mutation_plugin.go
+++ b/internal/api/resolver_mutation_plugin.go
@@ -5,10 +5,10 @@ import (
"github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/pkg/logger"
- "github.com/stashapp/stash/pkg/models"
+ "github.com/stashapp/stash/pkg/plugin"
)
-func (r *mutationResolver) RunPluginTask(ctx context.Context, pluginID string, taskName string, args []*models.PluginArgInput) (string, error) {
+func (r *mutationResolver) RunPluginTask(ctx context.Context, pluginID string, taskName string, args []*plugin.PluginArgInput) (string, error) {
m := manager.GetInstance()
m.RunPluginTask(ctx, pluginID, taskName, args)
return "todo", nil
diff --git a/internal/api/resolver_mutation_saved_filter.go b/internal/api/resolver_mutation_saved_filter.go
index f8467cb5e..bf1b4106e 100644
--- a/internal/api/resolver_mutation_saved_filter.go
+++ b/internal/api/resolver_mutation_saved_filter.go
@@ -9,7 +9,7 @@ import (
"github.com/stashapp/stash/pkg/models"
)
-func (r *mutationResolver) SaveFilter(ctx context.Context, input models.SaveFilterInput) (ret *models.SavedFilter, err error) {
+func (r *mutationResolver) SaveFilter(ctx context.Context, input SaveFilterInput) (ret *models.SavedFilter, err error) {
if strings.TrimSpace(input.Name) == "" {
return nil, errors.New("name must be non-empty")
}
@@ -42,7 +42,7 @@ func (r *mutationResolver) SaveFilter(ctx context.Context, input models.SaveFilt
return ret, err
}
-func (r *mutationResolver) DestroySavedFilter(ctx context.Context, input models.DestroyFilterInput) (bool, error) {
+func (r *mutationResolver) DestroySavedFilter(ctx context.Context, input DestroyFilterInput) (bool, error) {
id, err := strconv.Atoi(input.ID)
if err != nil {
return false, err
@@ -57,7 +57,7 @@ func (r *mutationResolver) DestroySavedFilter(ctx context.Context, input models.
return true, nil
}
-func (r *mutationResolver) SetDefaultFilter(ctx context.Context, input models.SetDefaultFilterInput) (bool, error) {
+func (r *mutationResolver) SetDefaultFilter(ctx context.Context, input SetDefaultFilterInput) (bool, error) {
if err := r.withTxn(ctx, func(repo models.Repository) error {
qb := repo.SavedFilter()
diff --git a/internal/api/resolver_mutation_scene.go b/internal/api/resolver_mutation_scene.go
index 3d8a6b238..fdaf1d64d 100644
--- a/internal/api/resolver_mutation_scene.go
+++ b/internal/api/resolver_mutation_scene.go
@@ -232,7 +232,7 @@ func (r *mutationResolver) updateSceneGalleries(qb models.SceneReaderWriter, sce
return qb.UpdateGalleries(sceneID, ids)
}
-func (r *mutationResolver) BulkSceneUpdate(ctx context.Context, input models.BulkSceneUpdateInput) ([]*models.Scene, error) {
+func (r *mutationResolver) BulkSceneUpdate(ctx context.Context, input BulkSceneUpdateInput) ([]*models.Scene, error) {
sceneIDs, err := stringslice.StringSliceToIntSlice(input.Ids)
if err != nil {
return nil, err
@@ -343,9 +343,9 @@ func (r *mutationResolver) BulkSceneUpdate(ctx context.Context, input models.Bul
return newRet, nil
}
-func adjustIDs(existingIDs []int, updateIDs models.BulkUpdateIds) []int {
+func adjustIDs(existingIDs []int, updateIDs BulkUpdateIds) []int {
// if we are setting the ids, just return the ids
- if updateIDs.Mode == models.BulkUpdateIDModeSet {
+ if updateIDs.Mode == BulkUpdateIDModeSet {
existingIDs = []int{}
for _, idStr := range updateIDs.Ids {
id, _ := strconv.Atoi(idStr)
@@ -362,7 +362,7 @@ func adjustIDs(existingIDs []int, updateIDs models.BulkUpdateIds) []int {
foundExisting := false
for idx, existingID := range existingIDs {
if existingID == id {
- if updateIDs.Mode == models.BulkUpdateIDModeRemove {
+ if updateIDs.Mode == BulkUpdateIDModeRemove {
// remove from the list
existingIDs = append(existingIDs[:idx], existingIDs[idx+1:]...)
}
@@ -372,7 +372,7 @@ func adjustIDs(existingIDs []int, updateIDs models.BulkUpdateIds) []int {
}
}
- if !foundExisting && updateIDs.Mode != models.BulkUpdateIDModeRemove {
+ if !foundExisting && updateIDs.Mode != BulkUpdateIDModeRemove {
existingIDs = append(existingIDs, id)
}
}
@@ -380,7 +380,7 @@ func adjustIDs(existingIDs []int, updateIDs models.BulkUpdateIds) []int {
return existingIDs
}
-func adjustScenePerformerIDs(qb models.SceneReader, sceneID int, ids models.BulkUpdateIds) (ret []int, err error) {
+func adjustScenePerformerIDs(qb models.SceneReader, sceneID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = qb.GetPerformerIDs(sceneID)
if err != nil {
return nil, err
@@ -393,7 +393,7 @@ type tagIDsGetter interface {
GetTagIDs(id int) ([]int, error)
}
-func adjustTagIDs(qb tagIDsGetter, sceneID int, ids models.BulkUpdateIds) (ret []int, err error) {
+func adjustTagIDs(qb tagIDsGetter, sceneID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = qb.GetTagIDs(sceneID)
if err != nil {
return nil, err
@@ -402,7 +402,7 @@ func adjustTagIDs(qb tagIDsGetter, sceneID int, ids models.BulkUpdateIds) (ret [
return adjustIDs(ret, ids), nil
}
-func adjustSceneGalleryIDs(qb models.SceneReader, sceneID int, ids models.BulkUpdateIds) (ret []int, err error) {
+func adjustSceneGalleryIDs(qb models.SceneReader, sceneID int, ids BulkUpdateIds) (ret []int, err error) {
ret, err = qb.GetGalleryIDs(sceneID)
if err != nil {
return nil, err
@@ -411,14 +411,14 @@ func adjustSceneGalleryIDs(qb models.SceneReader, sceneID int, ids models.BulkUp
return adjustIDs(ret, ids), nil
}
-func adjustSceneMovieIDs(qb models.SceneReader, sceneID int, updateIDs models.BulkUpdateIds) ([]models.MoviesScenes, error) {
+func adjustSceneMovieIDs(qb models.SceneReader, sceneID int, updateIDs BulkUpdateIds) ([]models.MoviesScenes, error) {
existingMovies, err := qb.GetMovies(sceneID)
if err != nil {
return nil, err
}
// if we are setting the ids, just return the ids
- if updateIDs.Mode == models.BulkUpdateIDModeSet {
+ if updateIDs.Mode == BulkUpdateIDModeSet {
existingMovies = []models.MoviesScenes{}
for _, idStr := range updateIDs.Ids {
id, _ := strconv.Atoi(idStr)
@@ -435,7 +435,7 @@ func adjustSceneMovieIDs(qb models.SceneReader, sceneID int, updateIDs models.Bu
foundExisting := false
for idx, existingMovie := range existingMovies {
if existingMovie.MovieID == id {
- if updateIDs.Mode == models.BulkUpdateIDModeRemove {
+ if updateIDs.Mode == BulkUpdateIDModeRemove {
// remove from the list
existingMovies = append(existingMovies[:idx], existingMovies[idx+1:]...)
}
@@ -445,7 +445,7 @@ func adjustSceneMovieIDs(qb models.SceneReader, sceneID int, updateIDs models.Bu
}
}
- if !foundExisting && updateIDs.Mode != models.BulkUpdateIDModeRemove {
+ if !foundExisting && updateIDs.Mode != BulkUpdateIDModeRemove {
existingMovies = append(existingMovies, models.MoviesScenes{MovieID: id})
}
}
@@ -574,7 +574,7 @@ func (r *mutationResolver) getSceneMarker(ctx context.Context, id int) (ret *mod
return ret, nil
}
-func (r *mutationResolver) SceneMarkerCreate(ctx context.Context, input models.SceneMarkerCreateInput) (*models.SceneMarker, error) {
+func (r *mutationResolver) SceneMarkerCreate(ctx context.Context, input SceneMarkerCreateInput) (*models.SceneMarker, error) {
primaryTagID, err := strconv.Atoi(input.PrimaryTagID)
if err != nil {
return nil, err
@@ -609,7 +609,7 @@ func (r *mutationResolver) SceneMarkerCreate(ctx context.Context, input models.S
return r.getSceneMarker(ctx, ret.ID)
}
-func (r *mutationResolver) SceneMarkerUpdate(ctx context.Context, input models.SceneMarkerUpdateInput) (*models.SceneMarker, error) {
+func (r *mutationResolver) SceneMarkerUpdate(ctx context.Context, input SceneMarkerUpdateInput) (*models.SceneMarker, error) {
// Populate scene marker from the input
sceneMarkerID, err := strconv.Atoi(input.ID)
if err != nil {
diff --git a/internal/api/resolver_mutation_stash_box.go b/internal/api/resolver_mutation_stash_box.go
index 2b6d259ad..3d9163317 100644
--- a/internal/api/resolver_mutation_stash_box.go
+++ b/internal/api/resolver_mutation_stash_box.go
@@ -11,7 +11,7 @@ import (
"github.com/stashapp/stash/pkg/scraper/stashbox"
)
-func (r *mutationResolver) SubmitStashBoxFingerprints(ctx context.Context, input models.StashBoxFingerprintSubmissionInput) (bool, error) {
+func (r *mutationResolver) SubmitStashBoxFingerprints(ctx context.Context, input StashBoxFingerprintSubmissionInput) (bool, error) {
boxes := config.GetInstance().GetStashBoxes()
if input.StashBoxIndex < 0 || input.StashBoxIndex >= len(boxes) {
@@ -23,12 +23,12 @@ func (r *mutationResolver) SubmitStashBoxFingerprints(ctx context.Context, input
return client.SubmitStashBoxFingerprints(ctx, input.SceneIds, boxes[input.StashBoxIndex].Endpoint)
}
-func (r *mutationResolver) StashBoxBatchPerformerTag(ctx context.Context, input models.StashBoxBatchPerformerTagInput) (string, error) {
+func (r *mutationResolver) StashBoxBatchPerformerTag(ctx context.Context, input manager.StashBoxBatchPerformerTagInput) (string, error) {
jobID := manager.GetInstance().StashBoxBatchPerformerTag(ctx, input)
return strconv.Itoa(jobID), nil
}
-func (r *mutationResolver) SubmitStashBoxSceneDraft(ctx context.Context, input models.StashBoxDraftSubmissionInput) (*string, error) {
+func (r *mutationResolver) SubmitStashBoxSceneDraft(ctx context.Context, input StashBoxDraftSubmissionInput) (*string, error) {
boxes := config.GetInstance().GetStashBoxes()
if input.StashBoxIndex < 0 || input.StashBoxIndex >= len(boxes) {
@@ -58,7 +58,7 @@ func (r *mutationResolver) SubmitStashBoxSceneDraft(ctx context.Context, input m
return res, err
}
-func (r *mutationResolver) SubmitStashBoxPerformerDraft(ctx context.Context, input models.StashBoxDraftSubmissionInput) (*string, error) {
+func (r *mutationResolver) SubmitStashBoxPerformerDraft(ctx context.Context, input StashBoxDraftSubmissionInput) (*string, error) {
boxes := config.GetInstance().GetStashBoxes()
if input.StashBoxIndex < 0 || input.StashBoxIndex >= len(boxes) {
diff --git a/internal/api/resolver_mutation_studio.go b/internal/api/resolver_mutation_studio.go
index 23f2a9bdc..8fcfd3b53 100644
--- a/internal/api/resolver_mutation_studio.go
+++ b/internal/api/resolver_mutation_studio.go
@@ -27,7 +27,7 @@ func (r *mutationResolver) getStudio(ctx context.Context, id int) (ret *models.S
return ret, nil
}
-func (r *mutationResolver) StudioCreate(ctx context.Context, input models.StudioCreateInput) (*models.Studio, error) {
+func (r *mutationResolver) StudioCreate(ctx context.Context, input StudioCreateInput) (*models.Studio, error) {
// generate checksum from studio name rather than image
checksum := md5.FromString(input.Name)
@@ -115,7 +115,7 @@ func (r *mutationResolver) StudioCreate(ctx context.Context, input models.Studio
return r.getStudio(ctx, s.ID)
}
-func (r *mutationResolver) StudioUpdate(ctx context.Context, input models.StudioUpdateInput) (*models.Studio, error) {
+func (r *mutationResolver) StudioUpdate(ctx context.Context, input StudioUpdateInput) (*models.Studio, error) {
// Populate studio from the input
studioID, err := strconv.Atoi(input.ID)
if err != nil {
@@ -207,7 +207,7 @@ func (r *mutationResolver) StudioUpdate(ctx context.Context, input models.Studio
return r.getStudio(ctx, s.ID)
}
-func (r *mutationResolver) StudioDestroy(ctx context.Context, input models.StudioDestroyInput) (bool, error) {
+func (r *mutationResolver) StudioDestroy(ctx context.Context, input StudioDestroyInput) (bool, error) {
id, err := strconv.Atoi(input.ID)
if err != nil {
return false, err
diff --git a/internal/api/resolver_mutation_tag.go b/internal/api/resolver_mutation_tag.go
index 680479b8d..cff553d72 100644
--- a/internal/api/resolver_mutation_tag.go
+++ b/internal/api/resolver_mutation_tag.go
@@ -25,7 +25,7 @@ func (r *mutationResolver) getTag(ctx context.Context, id int) (ret *models.Tag,
return ret, nil
}
-func (r *mutationResolver) TagCreate(ctx context.Context, input models.TagCreateInput) (*models.Tag, error) {
+func (r *mutationResolver) TagCreate(ctx context.Context, input TagCreateInput) (*models.Tag, error) {
// Populate a new tag from the input
currentTime := time.Now()
newTag := models.Tag{
@@ -127,7 +127,7 @@ func (r *mutationResolver) TagCreate(ctx context.Context, input models.TagCreate
return r.getTag(ctx, t.ID)
}
-func (r *mutationResolver) TagUpdate(ctx context.Context, input models.TagUpdateInput) (*models.Tag, error) {
+func (r *mutationResolver) TagUpdate(ctx context.Context, input TagUpdateInput) (*models.Tag, error) {
// Populate tag from the input
tagID, err := strconv.Atoi(input.ID)
if err != nil {
@@ -252,7 +252,7 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input models.TagUpdate
return r.getTag(ctx, t.ID)
}
-func (r *mutationResolver) TagDestroy(ctx context.Context, input models.TagDestroyInput) (bool, error) {
+func (r *mutationResolver) TagDestroy(ctx context.Context, input TagDestroyInput) (bool, error) {
tagID, err := strconv.Atoi(input.ID)
if err != nil {
return false, err
@@ -295,7 +295,7 @@ func (r *mutationResolver) TagsDestroy(ctx context.Context, tagIDs []string) (bo
return true, nil
}
-func (r *mutationResolver) TagsMerge(ctx context.Context, input models.TagsMergeInput) (*models.Tag, error) {
+func (r *mutationResolver) TagsMerge(ctx context.Context, input TagsMergeInput) (*models.Tag, error) {
source, err := stringslice.StringSliceToIntSlice(input.Source)
if err != nil {
return nil, err
diff --git a/internal/api/resolver_mutation_tag_test.go b/internal/api/resolver_mutation_tag_test.go
index 9329f6b7d..c8d43f5f8 100644
--- a/internal/api/resolver_mutation_tag_test.go
+++ b/internal/api/resolver_mutation_tag_test.go
@@ -73,13 +73,13 @@ func TestTagCreate(t *testing.T) {
expectedErr := errors.New("TagCreate error")
tagRW.On("Create", mock.AnythingOfType("models.Tag")).Return(nil, expectedErr)
- _, err := r.Mutation().TagCreate(context.TODO(), models.TagCreateInput{
+ _, err := r.Mutation().TagCreate(context.TODO(), TagCreateInput{
Name: existingTagName,
})
assert.NotNil(t, err)
- _, err = r.Mutation().TagCreate(context.TODO(), models.TagCreateInput{
+ _, err = r.Mutation().TagCreate(context.TODO(), TagCreateInput{
Name: errTagName,
})
@@ -98,7 +98,7 @@ func TestTagCreate(t *testing.T) {
tagRW.On("Create", mock.AnythingOfType("models.Tag")).Return(newTag, nil)
tagRW.On("Find", newTagID).Return(newTag, nil)
- tag, err := r.Mutation().TagCreate(context.TODO(), models.TagCreateInput{
+ tag, err := r.Mutation().TagCreate(context.TODO(), TagCreateInput{
Name: tagName,
})
diff --git a/internal/api/resolver_query_configuration.go b/internal/api/resolver_query_configuration.go
index d0852ff13..39c463260 100644
--- a/internal/api/resolver_query_configuration.go
+++ b/internal/api/resolver_query_configuration.go
@@ -13,13 +13,13 @@ import (
"golang.org/x/text/collate"
)
-func (r *queryResolver) Configuration(ctx context.Context) (*models.ConfigResult, error) {
+func (r *queryResolver) Configuration(ctx context.Context) (*ConfigResult, error) {
return makeConfigResult(), nil
}
-func (r *queryResolver) Directory(ctx context.Context, path, locale *string) (*models.Directory, error) {
+func (r *queryResolver) Directory(ctx context.Context, path, locale *string) (*Directory, error) {
- directory := &models.Directory{}
+ directory := &Directory{}
var err error
col := newCollator(locale, collate.IgnoreCase, collate.Numeric)
@@ -59,8 +59,8 @@ func getParent(path string) *string {
}
}
-func makeConfigResult() *models.ConfigResult {
- return &models.ConfigResult{
+func makeConfigResult() *ConfigResult {
+ return &ConfigResult{
General: makeConfigGeneralResult(),
Interface: makeConfigInterfaceResult(),
Dlna: makeConfigDLNAResult(),
@@ -70,7 +70,7 @@ func makeConfigResult() *models.ConfigResult {
}
}
-func makeConfigGeneralResult() *models.ConfigGeneralResult {
+func makeConfigGeneralResult() *ConfigGeneralResult {
config := config.GetInstance()
logFile := config.GetLogFile()
@@ -82,7 +82,7 @@ func makeConfigGeneralResult() *models.ConfigGeneralResult {
scraperUserAgent := config.GetScraperUserAgent()
scraperCDPPath := config.GetScraperCDPPath()
- return &models.ConfigGeneralResult{
+ return &ConfigGeneralResult{
Stashes: config.GetStashPaths(),
DatabasePath: config.GetDatabasePath(),
GeneratedPath: config.GetGeneratedPath(),
@@ -125,7 +125,7 @@ func makeConfigGeneralResult() *models.ConfigGeneralResult {
}
}
-func makeConfigInterfaceResult() *models.ConfigInterfaceResult {
+func makeConfigInterfaceResult() *ConfigInterfaceResult {
config := config.GetInstance()
menuItems := config.GetMenuItems()
soundOnPreview := config.GetSoundOnPreview()
@@ -149,7 +149,7 @@ func makeConfigInterfaceResult() *models.ConfigInterfaceResult {
// FIXME - misnamed output field means we have redundant fields
disableDropdownCreate := config.GetDisableDropdownCreate()
- return &models.ConfigInterfaceResult{
+ return &ConfigInterfaceResult{
MenuItems: menuItems,
SoundOnPreview: &soundOnPreview,
WallShowTitle: &wallShowTitle,
@@ -177,10 +177,10 @@ func makeConfigInterfaceResult() *models.ConfigInterfaceResult {
}
}
-func makeConfigDLNAResult() *models.ConfigDLNAResult {
+func makeConfigDLNAResult() *ConfigDLNAResult {
config := config.GetInstance()
- return &models.ConfigDLNAResult{
+ return &ConfigDLNAResult{
ServerName: config.GetDLNAServerName(),
Enabled: config.GetDLNADefaultEnabled(),
WhitelistedIPs: config.GetDLNADefaultIPWhitelist(),
@@ -188,13 +188,13 @@ func makeConfigDLNAResult() *models.ConfigDLNAResult {
}
}
-func makeConfigScrapingResult() *models.ConfigScrapingResult {
+func makeConfigScrapingResult() *ConfigScrapingResult {
config := config.GetInstance()
scraperUserAgent := config.GetScraperUserAgent()
scraperCDPPath := config.GetScraperCDPPath()
- return &models.ConfigScrapingResult{
+ return &ConfigScrapingResult{
ScraperUserAgent: &scraperUserAgent,
ScraperCertCheck: config.GetScraperCertCheck(),
ScraperCDPPath: &scraperCDPPath,
@@ -202,12 +202,12 @@ func makeConfigScrapingResult() *models.ConfigScrapingResult {
}
}
-func makeConfigDefaultsResult() *models.ConfigDefaultSettingsResult {
+func makeConfigDefaultsResult() *ConfigDefaultSettingsResult {
config := config.GetInstance()
deleteFileDefault := config.GetDeleteFileDefault()
deleteGeneratedDefault := config.GetDeleteGeneratedDefault()
- return &models.ConfigDefaultSettingsResult{
+ return &ConfigDefaultSettingsResult{
Identify: config.GetDefaultIdentifySettings(),
Scan: config.GetDefaultScanSettings(),
AutoTag: config.GetDefaultAutoTagSettings(),
@@ -221,7 +221,7 @@ func makeConfigUIResult() map[string]interface{} {
return config.GetInstance().GetUIConfiguration()
}
-func (r *queryResolver) ValidateStashBoxCredentials(ctx context.Context, input models.StashBoxInput) (*models.StashBoxValidationResult, error) {
+func (r *queryResolver) ValidateStashBoxCredentials(ctx context.Context, input config.StashBoxInput) (*StashBoxValidationResult, error) {
client := stashbox.NewClient(models.StashBox{Endpoint: input.Endpoint, APIKey: input.APIKey}, r.txnManager)
user, err := client.GetUser(ctx)
@@ -248,7 +248,7 @@ func (r *queryResolver) ValidateStashBoxCredentials(ctx context.Context, input m
}
}
- result := models.StashBoxValidationResult{
+ result := StashBoxValidationResult{
Valid: valid,
Status: status,
}
diff --git a/internal/api/resolver_query_dlna.go b/internal/api/resolver_query_dlna.go
index 8d616f463..e620c526d 100644
--- a/internal/api/resolver_query_dlna.go
+++ b/internal/api/resolver_query_dlna.go
@@ -3,10 +3,10 @@ package api
import (
"context"
+ "github.com/stashapp/stash/internal/dlna"
"github.com/stashapp/stash/internal/manager"
- "github.com/stashapp/stash/pkg/models"
)
-func (r *queryResolver) DlnaStatus(ctx context.Context) (*models.DLNAStatus, error) {
+func (r *queryResolver) DlnaStatus(ctx context.Context) (*dlna.Status, error) {
return manager.GetInstance().DLNAService.Status(), nil
}
diff --git a/internal/api/resolver_query_find_gallery.go b/internal/api/resolver_query_find_gallery.go
index c00332a4f..bd6d07c0f 100644
--- a/internal/api/resolver_query_find_gallery.go
+++ b/internal/api/resolver_query_find_gallery.go
@@ -23,14 +23,14 @@ func (r *queryResolver) FindGallery(ctx context.Context, id string) (ret *models
return ret, nil
}
-func (r *queryResolver) FindGalleries(ctx context.Context, galleryFilter *models.GalleryFilterType, filter *models.FindFilterType) (ret *models.FindGalleriesResultType, err error) {
+func (r *queryResolver) FindGalleries(ctx context.Context, galleryFilter *models.GalleryFilterType, filter *models.FindFilterType) (ret *FindGalleriesResultType, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
galleries, total, err := repo.Gallery().Query(galleryFilter, filter)
if err != nil {
return err
}
- ret = &models.FindGalleriesResultType{
+ ret = &FindGalleriesResultType{
Count: total,
Galleries: galleries,
}
diff --git a/internal/api/resolver_query_find_image.go b/internal/api/resolver_query_find_image.go
index d26a8b081..0f01b42a1 100644
--- a/internal/api/resolver_query_find_image.go
+++ b/internal/api/resolver_query_find_image.go
@@ -38,7 +38,7 @@ func (r *queryResolver) FindImage(ctx context.Context, id *string, checksum *str
return image, nil
}
-func (r *queryResolver) FindImages(ctx context.Context, imageFilter *models.ImageFilterType, imageIds []int, filter *models.FindFilterType) (ret *models.FindImagesResultType, err error) {
+func (r *queryResolver) FindImages(ctx context.Context, imageFilter *models.ImageFilterType, imageIds []int, filter *models.FindFilterType) (ret *FindImagesResultType, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
qb := repo.Image()
@@ -62,7 +62,7 @@ func (r *queryResolver) FindImages(ctx context.Context, imageFilter *models.Imag
return err
}
- ret = &models.FindImagesResultType{
+ ret = &FindImagesResultType{
Count: result.Count,
Images: images,
Megapixels: result.Megapixels,
diff --git a/internal/api/resolver_query_find_movie.go b/internal/api/resolver_query_find_movie.go
index 1a66c2461..16a0bbe3b 100644
--- a/internal/api/resolver_query_find_movie.go
+++ b/internal/api/resolver_query_find_movie.go
@@ -23,14 +23,14 @@ func (r *queryResolver) FindMovie(ctx context.Context, id string) (ret *models.M
return ret, nil
}
-func (r *queryResolver) FindMovies(ctx context.Context, movieFilter *models.MovieFilterType, filter *models.FindFilterType) (ret *models.FindMoviesResultType, err error) {
+func (r *queryResolver) FindMovies(ctx context.Context, movieFilter *models.MovieFilterType, filter *models.FindFilterType) (ret *FindMoviesResultType, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
movies, total, err := repo.Movie().Query(movieFilter, filter)
if err != nil {
return err
}
- ret = &models.FindMoviesResultType{
+ ret = &FindMoviesResultType{
Count: total,
Movies: movies,
}
diff --git a/internal/api/resolver_query_find_performer.go b/internal/api/resolver_query_find_performer.go
index 32cc46891..5d7bb2716 100644
--- a/internal/api/resolver_query_find_performer.go
+++ b/internal/api/resolver_query_find_performer.go
@@ -23,14 +23,14 @@ func (r *queryResolver) FindPerformer(ctx context.Context, id string) (ret *mode
return ret, nil
}
-func (r *queryResolver) FindPerformers(ctx context.Context, performerFilter *models.PerformerFilterType, filter *models.FindFilterType) (ret *models.FindPerformersResultType, err error) {
+func (r *queryResolver) FindPerformers(ctx context.Context, performerFilter *models.PerformerFilterType, filter *models.FindFilterType) (ret *FindPerformersResultType, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
performers, total, err := repo.Performer().Query(performerFilter, filter)
if err != nil {
return err
}
- ret = &models.FindPerformersResultType{
+ ret = &FindPerformersResultType{
Count: total,
Performers: performers,
}
diff --git a/internal/api/resolver_query_find_scene.go b/internal/api/resolver_query_find_scene.go
index 180e25a32..486ca744a 100644
--- a/internal/api/resolver_query_find_scene.go
+++ b/internal/api/resolver_query_find_scene.go
@@ -36,7 +36,7 @@ func (r *queryResolver) FindScene(ctx context.Context, id *string, checksum *str
return scene, nil
}
-func (r *queryResolver) FindSceneByHash(ctx context.Context, input models.SceneHashInput) (*models.Scene, error) {
+func (r *queryResolver) FindSceneByHash(ctx context.Context, input SceneHashInput) (*models.Scene, error) {
var scene *models.Scene
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
@@ -64,7 +64,7 @@ func (r *queryResolver) FindSceneByHash(ctx context.Context, input models.SceneH
return scene, nil
}
-func (r *queryResolver) FindScenes(ctx context.Context, sceneFilter *models.SceneFilterType, sceneIDs []int, filter *models.FindFilterType) (ret *models.FindScenesResultType, err error) {
+func (r *queryResolver) FindScenes(ctx context.Context, sceneFilter *models.SceneFilterType, sceneIDs []int, filter *models.FindFilterType) (ret *FindScenesResultType, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
var scenes []*models.Scene
var err error
@@ -101,7 +101,7 @@ func (r *queryResolver) FindScenes(ctx context.Context, sceneFilter *models.Scen
return err
}
- ret = &models.FindScenesResultType{
+ ret = &FindScenesResultType{
Count: result.Count,
Scenes: scenes,
Duration: result.TotalDuration,
@@ -116,7 +116,7 @@ func (r *queryResolver) FindScenes(ctx context.Context, sceneFilter *models.Scen
return ret, nil
}
-func (r *queryResolver) FindScenesByPathRegex(ctx context.Context, filter *models.FindFilterType) (ret *models.FindScenesResultType, err error) {
+func (r *queryResolver) FindScenesByPathRegex(ctx context.Context, filter *models.FindFilterType) (ret *FindScenesResultType, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
sceneFilter := &models.SceneFilterType{}
@@ -156,7 +156,7 @@ func (r *queryResolver) FindScenesByPathRegex(ctx context.Context, filter *model
return err
}
- ret = &models.FindScenesResultType{
+ ret = &FindScenesResultType{
Count: result.Count,
Scenes: scenes,
Duration: result.TotalDuration,
@@ -171,7 +171,7 @@ func (r *queryResolver) FindScenesByPathRegex(ctx context.Context, filter *model
return ret, nil
}
-func (r *queryResolver) ParseSceneFilenames(ctx context.Context, filter *models.FindFilterType, config models.SceneParserInput) (ret *models.SceneParserResultType, err error) {
+func (r *queryResolver) ParseSceneFilenames(ctx context.Context, filter *models.FindFilterType, config manager.SceneParserInput) (ret *SceneParserResultType, err error) {
parser := manager.NewSceneFilenameParser(filter, config)
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
@@ -181,7 +181,7 @@ func (r *queryResolver) ParseSceneFilenames(ctx context.Context, filter *models.
return err
}
- ret = &models.SceneParserResultType{
+ ret = &SceneParserResultType{
Count: count,
Results: result,
}
diff --git a/internal/api/resolver_query_find_scene_marker.go b/internal/api/resolver_query_find_scene_marker.go
index a3d6b6058..7b0f1ee12 100644
--- a/internal/api/resolver_query_find_scene_marker.go
+++ b/internal/api/resolver_query_find_scene_marker.go
@@ -6,13 +6,13 @@ import (
"github.com/stashapp/stash/pkg/models"
)
-func (r *queryResolver) FindSceneMarkers(ctx context.Context, sceneMarkerFilter *models.SceneMarkerFilterType, filter *models.FindFilterType) (ret *models.FindSceneMarkersResultType, err error) {
+func (r *queryResolver) FindSceneMarkers(ctx context.Context, sceneMarkerFilter *models.SceneMarkerFilterType, filter *models.FindFilterType) (ret *FindSceneMarkersResultType, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
sceneMarkers, total, err := repo.SceneMarker().Query(sceneMarkerFilter, filter)
if err != nil {
return err
}
- ret = &models.FindSceneMarkersResultType{
+ ret = &FindSceneMarkersResultType{
Count: total,
SceneMarkers: sceneMarkers,
}
diff --git a/internal/api/resolver_query_find_studio.go b/internal/api/resolver_query_find_studio.go
index 71677cb35..24591e053 100644
--- a/internal/api/resolver_query_find_studio.go
+++ b/internal/api/resolver_query_find_studio.go
@@ -24,14 +24,14 @@ func (r *queryResolver) FindStudio(ctx context.Context, id string) (ret *models.
return ret, nil
}
-func (r *queryResolver) FindStudios(ctx context.Context, studioFilter *models.StudioFilterType, filter *models.FindFilterType) (ret *models.FindStudiosResultType, err error) {
+func (r *queryResolver) FindStudios(ctx context.Context, studioFilter *models.StudioFilterType, filter *models.FindFilterType) (ret *FindStudiosResultType, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
studios, total, err := repo.Studio().Query(studioFilter, filter)
if err != nil {
return err
}
- ret = &models.FindStudiosResultType{
+ ret = &FindStudiosResultType{
Count: total,
Studios: studios,
}
diff --git a/internal/api/resolver_query_find_tag.go b/internal/api/resolver_query_find_tag.go
index e44366361..21aff9f4b 100644
--- a/internal/api/resolver_query_find_tag.go
+++ b/internal/api/resolver_query_find_tag.go
@@ -23,14 +23,14 @@ func (r *queryResolver) FindTag(ctx context.Context, id string) (ret *models.Tag
return ret, nil
}
-func (r *queryResolver) FindTags(ctx context.Context, tagFilter *models.TagFilterType, filter *models.FindFilterType) (ret *models.FindTagsResultType, err error) {
+func (r *queryResolver) FindTags(ctx context.Context, tagFilter *models.TagFilterType, filter *models.FindFilterType) (ret *FindTagsResultType, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
tags, total, err := repo.Tag().Query(tagFilter, filter)
if err != nil {
return err
}
- ret = &models.FindTagsResultType{
+ ret = &FindTagsResultType{
Count: total,
Tags: tags,
}
diff --git a/internal/api/resolver_query_job.go b/internal/api/resolver_query_job.go
index 06f090190..aaa671013 100644
--- a/internal/api/resolver_query_job.go
+++ b/internal/api/resolver_query_job.go
@@ -6,13 +6,12 @@ import (
"github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/pkg/job"
- "github.com/stashapp/stash/pkg/models"
)
-func (r *queryResolver) JobQueue(ctx context.Context) ([]*models.Job, error) {
+func (r *queryResolver) JobQueue(ctx context.Context) ([]*Job, error) {
queue := manager.GetInstance().JobManager.GetQueue()
- var ret []*models.Job
+ var ret []*Job
for _, j := range queue {
ret = append(ret, jobToJobModel(j))
}
@@ -20,7 +19,7 @@ func (r *queryResolver) JobQueue(ctx context.Context) ([]*models.Job, error) {
return ret, nil
}
-func (r *queryResolver) FindJob(ctx context.Context, input models.FindJobInput) (*models.Job, error) {
+func (r *queryResolver) FindJob(ctx context.Context, input FindJobInput) (*Job, error) {
jobID, err := strconv.Atoi(input.ID)
if err != nil {
return nil, err
@@ -33,10 +32,10 @@ func (r *queryResolver) FindJob(ctx context.Context, input models.FindJobInput)
return jobToJobModel(*j), nil
}
-func jobToJobModel(j job.Job) *models.Job {
- ret := &models.Job{
+func jobToJobModel(j job.Job) *Job {
+ ret := &Job{
ID: strconv.Itoa(j.ID),
- Status: models.JobStatus(j.Status),
+ Status: JobStatus(j.Status),
Description: j.Description,
SubTasks: j.Details,
StartTime: j.StartTime,
diff --git a/internal/api/resolver_query_logs.go b/internal/api/resolver_query_logs.go
index c6ca6d9fd..e0cee9a84 100644
--- a/internal/api/resolver_query_logs.go
+++ b/internal/api/resolver_query_logs.go
@@ -4,16 +4,15 @@ import (
"context"
"github.com/stashapp/stash/internal/manager"
- "github.com/stashapp/stash/pkg/models"
)
-func (r *queryResolver) Logs(ctx context.Context) ([]*models.LogEntry, error) {
+func (r *queryResolver) Logs(ctx context.Context) ([]*LogEntry, error) {
logger := manager.GetInstance().Logger
logCache := logger.GetLogCache()
- ret := make([]*models.LogEntry, len(logCache))
+ ret := make([]*LogEntry, len(logCache))
for i, entry := range logCache {
- ret[i] = &models.LogEntry{
+ ret[i] = &LogEntry{
Time: entry.Time,
Level: getLogLevel(entry.Type),
Message: entry.Message,
diff --git a/internal/api/resolver_query_metadata.go b/internal/api/resolver_query_metadata.go
index d96beb407..b9189b102 100644
--- a/internal/api/resolver_query_metadata.go
+++ b/internal/api/resolver_query_metadata.go
@@ -4,9 +4,8 @@ import (
"context"
"github.com/stashapp/stash/internal/manager"
- "github.com/stashapp/stash/pkg/models"
)
-func (r *queryResolver) SystemStatus(ctx context.Context) (*models.SystemStatus, error) {
+func (r *queryResolver) SystemStatus(ctx context.Context) (*manager.SystemStatus, error) {
return manager.GetInstance().GetSystemStatus(), nil
}
diff --git a/internal/api/resolver_query_plugin.go b/internal/api/resolver_query_plugin.go
index 87c93dfcc..61463c5df 100644
--- a/internal/api/resolver_query_plugin.go
+++ b/internal/api/resolver_query_plugin.go
@@ -4,13 +4,13 @@ import (
"context"
"github.com/stashapp/stash/internal/manager"
- "github.com/stashapp/stash/pkg/models"
+ "github.com/stashapp/stash/pkg/plugin"
)
-func (r *queryResolver) Plugins(ctx context.Context) ([]*models.Plugin, error) {
+func (r *queryResolver) Plugins(ctx context.Context) ([]*plugin.Plugin, error) {
return manager.GetInstance().PluginCache.ListPlugins(), nil
}
-func (r *queryResolver) PluginTasks(ctx context.Context) ([]*models.PluginTask, error) {
+func (r *queryResolver) PluginTasks(ctx context.Context) ([]*plugin.PluginTask, error) {
return manager.GetInstance().PluginCache.ListPluginTasks(), nil
}
diff --git a/internal/api/resolver_query_scene.go b/internal/api/resolver_query_scene.go
index 3d593d911..7ff77d2a8 100644
--- a/internal/api/resolver_query_scene.go
+++ b/internal/api/resolver_query_scene.go
@@ -11,7 +11,7 @@ import (
"github.com/stashapp/stash/pkg/models"
)
-func (r *queryResolver) SceneStreams(ctx context.Context, id *string) ([]*models.SceneStreamEndpoint, error) {
+func (r *queryResolver) SceneStreams(ctx context.Context, id *string) ([]*manager.SceneStreamEndpoint, error) {
// find the scene
var scene *models.Scene
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
diff --git a/internal/api/resolver_query_scraper.go b/internal/api/resolver_query_scraper.go
index 8fd6c345a..92c5200ce 100644
--- a/internal/api/resolver_query_scraper.go
+++ b/internal/api/resolver_query_scraper.go
@@ -17,13 +17,13 @@ import (
"github.com/stashapp/stash/pkg/sliceutil/stringslice"
)
-func (r *queryResolver) ScrapeURL(ctx context.Context, url string, ty models.ScrapeContentType) (models.ScrapedContent, error) {
+func (r *queryResolver) ScrapeURL(ctx context.Context, url string, ty scraper.ScrapeContentType) (scraper.ScrapedContent, error) {
return r.scraperCache().ScrapeURL(ctx, url, ty)
}
// deprecated
func (r *queryResolver) ScrapeFreeonesPerformerList(ctx context.Context, query string) ([]string, error) {
- content, err := r.scraperCache().ScrapeName(ctx, scraper.FreeonesScraperID, query, models.ScrapeContentTypePerformer)
+ content, err := r.scraperCache().ScrapeName(ctx, scraper.FreeonesScraperID, query, scraper.ScrapeContentTypePerformer)
if err != nil {
return nil, err
@@ -44,24 +44,24 @@ func (r *queryResolver) ScrapeFreeonesPerformerList(ctx context.Context, query s
return ret, nil
}
-func (r *queryResolver) ListScrapers(ctx context.Context, types []models.ScrapeContentType) ([]*models.Scraper, error) {
+func (r *queryResolver) ListScrapers(ctx context.Context, types []scraper.ScrapeContentType) ([]*scraper.Scraper, error) {
return r.scraperCache().ListScrapers(types), nil
}
-func (r *queryResolver) ListPerformerScrapers(ctx context.Context) ([]*models.Scraper, error) {
- return r.scraperCache().ListScrapers([]models.ScrapeContentType{models.ScrapeContentTypePerformer}), nil
+func (r *queryResolver) ListPerformerScrapers(ctx context.Context) ([]*scraper.Scraper, error) {
+ return r.scraperCache().ListScrapers([]scraper.ScrapeContentType{scraper.ScrapeContentTypePerformer}), nil
}
-func (r *queryResolver) ListSceneScrapers(ctx context.Context) ([]*models.Scraper, error) {
- return r.scraperCache().ListScrapers([]models.ScrapeContentType{models.ScrapeContentTypeScene}), nil
+func (r *queryResolver) ListSceneScrapers(ctx context.Context) ([]*scraper.Scraper, error) {
+ return r.scraperCache().ListScrapers([]scraper.ScrapeContentType{scraper.ScrapeContentTypeScene}), nil
}
-func (r *queryResolver) ListGalleryScrapers(ctx context.Context) ([]*models.Scraper, error) {
- return r.scraperCache().ListScrapers([]models.ScrapeContentType{models.ScrapeContentTypeGallery}), nil
+func (r *queryResolver) ListGalleryScrapers(ctx context.Context) ([]*scraper.Scraper, error) {
+ return r.scraperCache().ListScrapers([]scraper.ScrapeContentType{scraper.ScrapeContentTypeGallery}), nil
}
-func (r *queryResolver) ListMovieScrapers(ctx context.Context) ([]*models.Scraper, error) {
- return r.scraperCache().ListScrapers([]models.ScrapeContentType{models.ScrapeContentTypeMovie}), nil
+func (r *queryResolver) ListMovieScrapers(ctx context.Context) ([]*scraper.Scraper, error) {
+ return r.scraperCache().ListScrapers([]scraper.ScrapeContentType{scraper.ScrapeContentTypeMovie}), nil
}
func (r *queryResolver) ScrapePerformerList(ctx context.Context, scraperID string, query string) ([]*models.ScrapedPerformer, error) {
@@ -69,7 +69,7 @@ func (r *queryResolver) ScrapePerformerList(ctx context.Context, scraperID strin
return nil, nil
}
- content, err := r.scraperCache().ScrapeName(ctx, scraperID, query, models.ScrapeContentTypePerformer)
+ content, err := r.scraperCache().ScrapeName(ctx, scraperID, query, scraper.ScrapeContentTypePerformer)
if err != nil {
return nil, err
}
@@ -77,7 +77,7 @@ func (r *queryResolver) ScrapePerformerList(ctx context.Context, scraperID strin
return marshalScrapedPerformers(content)
}
-func (r *queryResolver) ScrapePerformer(ctx context.Context, scraperID string, scrapedPerformer models.ScrapedPerformerInput) (*models.ScrapedPerformer, error) {
+func (r *queryResolver) ScrapePerformer(ctx context.Context, scraperID string, scrapedPerformer scraper.ScrapedPerformerInput) (*models.ScrapedPerformer, error) {
content, err := r.scraperCache().ScrapeFragment(ctx, scraperID, scraper.Input{Performer: &scrapedPerformer})
if err != nil {
return nil, err
@@ -86,7 +86,7 @@ func (r *queryResolver) ScrapePerformer(ctx context.Context, scraperID string, s
}
func (r *queryResolver) ScrapePerformerURL(ctx context.Context, url string) (*models.ScrapedPerformer, error) {
- content, err := r.scraperCache().ScrapeURL(ctx, url, models.ScrapeContentTypePerformer)
+ content, err := r.scraperCache().ScrapeURL(ctx, url, scraper.ScrapeContentTypePerformer)
if err != nil {
return nil, err
}
@@ -94,12 +94,12 @@ func (r *queryResolver) ScrapePerformerURL(ctx context.Context, url string) (*mo
return marshalScrapedPerformer(content)
}
-func (r *queryResolver) ScrapeSceneQuery(ctx context.Context, scraperID string, query string) ([]*models.ScrapedScene, error) {
+func (r *queryResolver) ScrapeSceneQuery(ctx context.Context, scraperID string, query string) ([]*scraper.ScrapedScene, error) {
if query == "" {
return nil, nil
}
- content, err := r.scraperCache().ScrapeName(ctx, scraperID, query, models.ScrapeContentTypeScene)
+ content, err := r.scraperCache().ScrapeName(ctx, scraperID, query, scraper.ScrapeContentTypeScene)
if err != nil {
return nil, err
}
@@ -113,13 +113,13 @@ func (r *queryResolver) ScrapeSceneQuery(ctx context.Context, scraperID string,
return ret, nil
}
-func (r *queryResolver) ScrapeScene(ctx context.Context, scraperID string, scene models.SceneUpdateInput) (*models.ScrapedScene, error) {
+func (r *queryResolver) ScrapeScene(ctx context.Context, scraperID string, scene models.SceneUpdateInput) (*scraper.ScrapedScene, error) {
id, err := strconv.Atoi(scene.ID)
if err != nil {
return nil, fmt.Errorf("%w: scene.ID is not an integer: '%s'", ErrInput, scene.ID)
}
- content, err := r.scraperCache().ScrapeID(ctx, scraperID, id, models.ScrapeContentTypeScene)
+ content, err := r.scraperCache().ScrapeID(ctx, scraperID, id, scraper.ScrapeContentTypeScene)
if err != nil {
return nil, err
}
@@ -129,13 +129,13 @@ func (r *queryResolver) ScrapeScene(ctx context.Context, scraperID string, scene
return nil, err
}
- filterSceneTags([]*models.ScrapedScene{ret})
+ filterSceneTags([]*scraper.ScrapedScene{ret})
return ret, nil
}
// filterSceneTags removes tags matching excluded tag patterns from the provided scraped scenes
-func filterSceneTags(scenes []*models.ScrapedScene) {
+func filterSceneTags(scenes []*scraper.ScrapedScene) {
excludePatterns := manager.GetInstance().Config.GetScraperExcludeTagPatterns()
var excludeRegexps []*regexp.Regexp
@@ -179,8 +179,8 @@ func filterSceneTags(scenes []*models.ScrapedScene) {
}
}
-func (r *queryResolver) ScrapeSceneURL(ctx context.Context, url string) (*models.ScrapedScene, error) {
- content, err := r.scraperCache().ScrapeURL(ctx, url, models.ScrapeContentTypeScene)
+func (r *queryResolver) ScrapeSceneURL(ctx context.Context, url string) (*scraper.ScrapedScene, error) {
+ content, err := r.scraperCache().ScrapeURL(ctx, url, scraper.ScrapeContentTypeScene)
if err != nil {
return nil, err
}
@@ -190,18 +190,18 @@ func (r *queryResolver) ScrapeSceneURL(ctx context.Context, url string) (*models
return nil, err
}
- filterSceneTags([]*models.ScrapedScene{ret})
+ filterSceneTags([]*scraper.ScrapedScene{ret})
return ret, nil
}
-func (r *queryResolver) ScrapeGallery(ctx context.Context, scraperID string, gallery models.GalleryUpdateInput) (*models.ScrapedGallery, error) {
+func (r *queryResolver) ScrapeGallery(ctx context.Context, scraperID string, gallery models.GalleryUpdateInput) (*scraper.ScrapedGallery, error) {
id, err := strconv.Atoi(gallery.ID)
if err != nil {
return nil, fmt.Errorf("%w: gallery id is not an integer: '%s'", ErrInput, gallery.ID)
}
- content, err := r.scraperCache().ScrapeID(ctx, scraperID, id, models.ScrapeContentTypeGallery)
+ content, err := r.scraperCache().ScrapeID(ctx, scraperID, id, scraper.ScrapeContentTypeGallery)
if err != nil {
return nil, err
}
@@ -209,8 +209,8 @@ func (r *queryResolver) ScrapeGallery(ctx context.Context, scraperID string, gal
return marshalScrapedGallery(content)
}
-func (r *queryResolver) ScrapeGalleryURL(ctx context.Context, url string) (*models.ScrapedGallery, error) {
- content, err := r.scraperCache().ScrapeURL(ctx, url, models.ScrapeContentTypeGallery)
+func (r *queryResolver) ScrapeGalleryURL(ctx context.Context, url string) (*scraper.ScrapedGallery, error) {
+ content, err := r.scraperCache().ScrapeURL(ctx, url, scraper.ScrapeContentTypeGallery)
if err != nil {
return nil, err
}
@@ -219,7 +219,7 @@ func (r *queryResolver) ScrapeGalleryURL(ctx context.Context, url string) (*mode
}
func (r *queryResolver) ScrapeMovieURL(ctx context.Context, url string) (*models.ScrapedMovie, error) {
- content, err := r.scraperCache().ScrapeURL(ctx, url, models.ScrapeContentTypeMovie)
+ content, err := r.scraperCache().ScrapeURL(ctx, url, scraper.ScrapeContentTypeMovie)
if err != nil {
return nil, err
}
@@ -237,8 +237,8 @@ func (r *queryResolver) getStashBoxClient(index int) (*stashbox.Client, error) {
return stashbox.NewClient(*boxes[index], r.txnManager), nil
}
-func (r *queryResolver) ScrapeSingleScene(ctx context.Context, source models.ScraperSourceInput, input models.ScrapeSingleSceneInput) ([]*models.ScrapedScene, error) {
- var ret []*models.ScrapedScene
+func (r *queryResolver) ScrapeSingleScene(ctx context.Context, source scraper.Source, input ScrapeSingleSceneInput) ([]*scraper.ScrapedScene, error) {
+ var ret []*scraper.ScrapedScene
var sceneID int
if input.SceneID != nil {
@@ -252,22 +252,22 @@ func (r *queryResolver) ScrapeSingleScene(ctx context.Context, source models.Scr
switch {
case source.ScraperID != nil:
var err error
- var c models.ScrapedContent
- var content []models.ScrapedContent
+ var c scraper.ScrapedContent
+ var content []scraper.ScrapedContent
switch {
case input.SceneID != nil:
- c, err = r.scraperCache().ScrapeID(ctx, *source.ScraperID, sceneID, models.ScrapeContentTypeScene)
+ c, err = r.scraperCache().ScrapeID(ctx, *source.ScraperID, sceneID, scraper.ScrapeContentTypeScene)
if c != nil {
- content = []models.ScrapedContent{c}
+ content = []scraper.ScrapedContent{c}
}
case input.SceneInput != nil:
c, err = r.scraperCache().ScrapeFragment(ctx, *source.ScraperID, scraper.Input{Scene: input.SceneInput})
if c != nil {
- content = []models.ScrapedContent{c}
+ content = []scraper.ScrapedContent{c}
}
case input.Query != nil:
- content, err = r.scraperCache().ScrapeName(ctx, *source.ScraperID, *input.Query, models.ScrapeContentTypeScene)
+ content, err = r.scraperCache().ScrapeName(ctx, *source.ScraperID, *input.Query, scraper.ScrapeContentTypeScene)
default:
err = fmt.Errorf("%w: scene_id, scene_input, or query must be set", ErrInput)
}
@@ -307,7 +307,7 @@ func (r *queryResolver) ScrapeSingleScene(ctx context.Context, source models.Scr
return ret, nil
}
-func (r *queryResolver) ScrapeMultiScenes(ctx context.Context, source models.ScraperSourceInput, input models.ScrapeMultiScenesInput) ([][]*models.ScrapedScene, error) {
+func (r *queryResolver) ScrapeMultiScenes(ctx context.Context, source scraper.Source, input ScrapeMultiScenesInput) ([][]*scraper.ScrapedScene, error) {
if source.ScraperID != nil {
return nil, ErrNotImplemented
} else if source.StashBoxIndex != nil {
@@ -327,7 +327,7 @@ func (r *queryResolver) ScrapeMultiScenes(ctx context.Context, source models.Scr
return nil, errors.New("scraper_id or stash_box_index must be set")
}
-func (r *queryResolver) ScrapeSinglePerformer(ctx context.Context, source models.ScraperSourceInput, input models.ScrapeSinglePerformerInput) ([]*models.ScrapedPerformer, error) {
+func (r *queryResolver) ScrapeSinglePerformer(ctx context.Context, source scraper.Source, input ScrapeSinglePerformerInput) ([]*models.ScrapedPerformer, error) {
if source.ScraperID != nil {
if input.PerformerInput != nil {
performer, err := r.scraperCache().ScrapeFragment(ctx, *source.ScraperID, scraper.Input{Performer: input.PerformerInput})
@@ -335,11 +335,11 @@ func (r *queryResolver) ScrapeSinglePerformer(ctx context.Context, source models
return nil, err
}
- return marshalScrapedPerformers([]models.ScrapedContent{performer})
+ return marshalScrapedPerformers([]scraper.ScrapedContent{performer})
}
if input.Query != nil {
- content, err := r.scraperCache().ScrapeName(ctx, *source.ScraperID, *input.Query, models.ScrapeContentTypePerformer)
+ content, err := r.scraperCache().ScrapeName(ctx, *source.ScraperID, *input.Query, scraper.ScrapeContentTypePerformer)
if err != nil {
return nil, err
}
@@ -354,7 +354,7 @@ func (r *queryResolver) ScrapeSinglePerformer(ctx context.Context, source models
return nil, err
}
- var ret []*models.StashBoxPerformerQueryResult
+ var ret []*stashbox.StashBoxPerformerQueryResult
switch {
case input.PerformerID != nil:
ret, err = client.FindStashBoxPerformersByNames(ctx, []string{*input.PerformerID})
@@ -378,7 +378,7 @@ func (r *queryResolver) ScrapeSinglePerformer(ctx context.Context, source models
return nil, errors.New("scraper_id or stash_box_index must be set")
}
-func (r *queryResolver) ScrapeMultiPerformers(ctx context.Context, source models.ScraperSourceInput, input models.ScrapeMultiPerformersInput) ([][]*models.ScrapedPerformer, error) {
+func (r *queryResolver) ScrapeMultiPerformers(ctx context.Context, source scraper.Source, input ScrapeMultiPerformersInput) ([][]*models.ScrapedPerformer, error) {
if source.ScraperID != nil {
return nil, ErrNotImplemented
} else if source.StashBoxIndex != nil {
@@ -393,7 +393,7 @@ func (r *queryResolver) ScrapeMultiPerformers(ctx context.Context, source models
return nil, errors.New("scraper_id or stash_box_index must be set")
}
-func (r *queryResolver) ScrapeSingleGallery(ctx context.Context, source models.ScraperSourceInput, input models.ScrapeSingleGalleryInput) ([]*models.ScrapedGallery, error) {
+func (r *queryResolver) ScrapeSingleGallery(ctx context.Context, source scraper.Source, input ScrapeSingleGalleryInput) ([]*scraper.ScrapedGallery, error) {
if source.StashBoxIndex != nil {
return nil, ErrNotSupported
}
@@ -402,7 +402,7 @@ func (r *queryResolver) ScrapeSingleGallery(ctx context.Context, source models.S
return nil, fmt.Errorf("%w: scraper_id must be set", ErrInput)
}
- var c models.ScrapedContent
+ var c scraper.ScrapedContent
switch {
case input.GalleryID != nil:
@@ -410,22 +410,22 @@ func (r *queryResolver) ScrapeSingleGallery(ctx context.Context, source models.S
if err != nil {
return nil, fmt.Errorf("%w: gallery id is not an integer: '%s'", ErrInput, *input.GalleryID)
}
- c, err = r.scraperCache().ScrapeID(ctx, *source.ScraperID, galleryID, models.ScrapeContentTypeGallery)
+ c, err = r.scraperCache().ScrapeID(ctx, *source.ScraperID, galleryID, scraper.ScrapeContentTypeGallery)
if err != nil {
return nil, err
}
- return marshalScrapedGalleries([]models.ScrapedContent{c})
+ return marshalScrapedGalleries([]scraper.ScrapedContent{c})
case input.GalleryInput != nil:
c, err := r.scraperCache().ScrapeFragment(ctx, *source.ScraperID, scraper.Input{Gallery: input.GalleryInput})
if err != nil {
return nil, err
}
- return marshalScrapedGalleries([]models.ScrapedContent{c})
+ return marshalScrapedGalleries([]scraper.ScrapedContent{c})
default:
return nil, ErrNotImplemented
}
}
-func (r *queryResolver) ScrapeSingleMovie(ctx context.Context, source models.ScraperSourceInput, input models.ScrapeSingleMovieInput) ([]*models.ScrapedMovie, error) {
+func (r *queryResolver) ScrapeSingleMovie(ctx context.Context, source scraper.Source, input ScrapeSingleMovieInput) ([]*models.ScrapedMovie, error) {
return nil, ErrNotSupported
}
diff --git a/internal/api/resolver_subscription_job.go b/internal/api/resolver_subscription_job.go
index 8e2da6654..84ebee400 100644
--- a/internal/api/resolver_subscription_job.go
+++ b/internal/api/resolver_subscription_job.go
@@ -5,18 +5,17 @@ import (
"github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/pkg/job"
- "github.com/stashapp/stash/pkg/models"
)
-func makeJobStatusUpdate(t models.JobStatusUpdateType, j job.Job) *models.JobStatusUpdate {
- return &models.JobStatusUpdate{
+func makeJobStatusUpdate(t JobStatusUpdateType, j job.Job) *JobStatusUpdate {
+ return &JobStatusUpdate{
Type: t,
Job: jobToJobModel(j),
}
}
-func (r *subscriptionResolver) JobsSubscribe(ctx context.Context) (<-chan *models.JobStatusUpdate, error) {
- msg := make(chan *models.JobStatusUpdate, 100)
+func (r *subscriptionResolver) JobsSubscribe(ctx context.Context) (<-chan *JobStatusUpdate, error) {
+ msg := make(chan *JobStatusUpdate, 100)
subscription := manager.GetInstance().JobManager.Subscribe(ctx)
@@ -24,11 +23,11 @@ func (r *subscriptionResolver) JobsSubscribe(ctx context.Context) (<-chan *model
for {
select {
case j := <-subscription.NewJob:
- msg <- makeJobStatusUpdate(models.JobStatusUpdateTypeAdd, j)
+ msg <- makeJobStatusUpdate(JobStatusUpdateTypeAdd, j)
case j := <-subscription.RemovedJob:
- msg <- makeJobStatusUpdate(models.JobStatusUpdateTypeRemove, j)
+ msg <- makeJobStatusUpdate(JobStatusUpdateTypeRemove, j)
case j := <-subscription.UpdatedJob:
- msg <- makeJobStatusUpdate(models.JobStatusUpdateTypeUpdate, j)
+ msg <- makeJobStatusUpdate(JobStatusUpdateTypeUpdate, j)
case <-ctx.Done():
close(msg)
return
diff --git a/internal/api/resolver_subscription_logging.go b/internal/api/resolver_subscription_logging.go
index 17241e12b..423fa88af 100644
--- a/internal/api/resolver_subscription_logging.go
+++ b/internal/api/resolver_subscription_logging.go
@@ -5,33 +5,32 @@ import (
"github.com/stashapp/stash/internal/log"
"github.com/stashapp/stash/internal/manager"
- "github.com/stashapp/stash/pkg/models"
)
-func getLogLevel(logType string) models.LogLevel {
+func getLogLevel(logType string) LogLevel {
switch logType {
case "progress":
- return models.LogLevelProgress
+ return LogLevelProgress
case "trace":
- return models.LogLevelTrace
+ return LogLevelTrace
case "debug":
- return models.LogLevelDebug
+ return LogLevelDebug
case "info":
- return models.LogLevelInfo
+ return LogLevelInfo
case "warn":
- return models.LogLevelWarning
+ return LogLevelWarning
case "error":
- return models.LogLevelError
+ return LogLevelError
default:
- return models.LogLevelDebug
+ return LogLevelDebug
}
}
-func logEntriesFromLogItems(logItems []log.LogItem) []*models.LogEntry {
- ret := make([]*models.LogEntry, len(logItems))
+func logEntriesFromLogItems(logItems []log.LogItem) []*LogEntry {
+ ret := make([]*LogEntry, len(logItems))
for i, entry := range logItems {
- ret[i] = &models.LogEntry{
+ ret[i] = &LogEntry{
Time: entry.Time,
Level: getLogLevel(entry.Type),
Message: entry.Message,
@@ -41,8 +40,8 @@ func logEntriesFromLogItems(logItems []log.LogItem) []*models.LogEntry {
return ret
}
-func (r *subscriptionResolver) LoggingSubscribe(ctx context.Context) (<-chan []*models.LogEntry, error) {
- ret := make(chan []*models.LogEntry, 100)
+func (r *subscriptionResolver) LoggingSubscribe(ctx context.Context) (<-chan []*LogEntry, error) {
+ ret := make(chan []*LogEntry, 100)
stop := make(chan int, 1)
logger := manager.GetInstance().Logger
logSub := logger.SubscribeToLog(stop)
diff --git a/internal/api/scraped_content.go b/internal/api/scraped_content.go
index 46664724a..6d9003892 100644
--- a/internal/api/scraped_content.go
+++ b/internal/api/scraped_content.go
@@ -4,12 +4,13 @@ import (
"fmt"
"github.com/stashapp/stash/pkg/models"
+ "github.com/stashapp/stash/pkg/scraper"
)
// marshalScrapedScenes converts ScrapedContent into ScrapedScene. If conversion fails, an
// error is returned to the caller.
-func marshalScrapedScenes(content []models.ScrapedContent) ([]*models.ScrapedScene, error) {
- var ret []*models.ScrapedScene
+func marshalScrapedScenes(content []scraper.ScrapedContent) ([]*scraper.ScrapedScene, error) {
+ var ret []*scraper.ScrapedScene
for _, c := range content {
if c == nil {
// graphql schema requires scenes to be non-nil
@@ -17,9 +18,9 @@ func marshalScrapedScenes(content []models.ScrapedContent) ([]*models.ScrapedSce
}
switch s := c.(type) {
- case *models.ScrapedScene:
+ case *scraper.ScrapedScene:
ret = append(ret, s)
- case models.ScrapedScene:
+ case scraper.ScrapedScene:
ret = append(ret, &s)
default:
return nil, fmt.Errorf("%w: cannot turn ScrapedContent into ScrapedScene", models.ErrConversion)
@@ -31,7 +32,7 @@ func marshalScrapedScenes(content []models.ScrapedContent) ([]*models.ScrapedSce
// marshalScrapedPerformers converts ScrapedContent into ScrapedPerformer. If conversion
// fails, an error is returned to the caller.
-func marshalScrapedPerformers(content []models.ScrapedContent) ([]*models.ScrapedPerformer, error) {
+func marshalScrapedPerformers(content []scraper.ScrapedContent) ([]*models.ScrapedPerformer, error) {
var ret []*models.ScrapedPerformer
for _, c := range content {
if c == nil {
@@ -54,8 +55,8 @@ func marshalScrapedPerformers(content []models.ScrapedContent) ([]*models.Scrape
// marshalScrapedGalleries converts ScrapedContent into ScrapedGallery. If
// conversion fails, an error is returned.
-func marshalScrapedGalleries(content []models.ScrapedContent) ([]*models.ScrapedGallery, error) {
- var ret []*models.ScrapedGallery
+func marshalScrapedGalleries(content []scraper.ScrapedContent) ([]*scraper.ScrapedGallery, error) {
+ var ret []*scraper.ScrapedGallery
for _, c := range content {
if c == nil {
// graphql schema requires galleries to be non-nil
@@ -63,9 +64,9 @@ func marshalScrapedGalleries(content []models.ScrapedContent) ([]*models.Scraped
}
switch g := c.(type) {
- case *models.ScrapedGallery:
+ case *scraper.ScrapedGallery:
ret = append(ret, g)
- case models.ScrapedGallery:
+ case scraper.ScrapedGallery:
ret = append(ret, &g)
default:
return nil, fmt.Errorf("%w: cannot turn ScrapedContent into ScrapedGallery", models.ErrConversion)
@@ -77,7 +78,7 @@ func marshalScrapedGalleries(content []models.ScrapedContent) ([]*models.Scraped
// marshalScrapedMovies converts ScrapedContent into ScrapedMovie. If conversion
// fails, an error is returned.
-func marshalScrapedMovies(content []models.ScrapedContent) ([]*models.ScrapedMovie, error) {
+func marshalScrapedMovies(content []scraper.ScrapedContent) ([]*models.ScrapedMovie, error) {
var ret []*models.ScrapedMovie
for _, c := range content {
if c == nil {
@@ -99,8 +100,8 @@ func marshalScrapedMovies(content []models.ScrapedContent) ([]*models.ScrapedMov
}
// marshalScrapedPerformer will marshal a single performer
-func marshalScrapedPerformer(content models.ScrapedContent) (*models.ScrapedPerformer, error) {
- p, err := marshalScrapedPerformers([]models.ScrapedContent{content})
+func marshalScrapedPerformer(content scraper.ScrapedContent) (*models.ScrapedPerformer, error) {
+ p, err := marshalScrapedPerformers([]scraper.ScrapedContent{content})
if err != nil {
return nil, err
}
@@ -109,8 +110,8 @@ func marshalScrapedPerformer(content models.ScrapedContent) (*models.ScrapedPerf
}
// marshalScrapedScene will marshal a single scraped scene
-func marshalScrapedScene(content models.ScrapedContent) (*models.ScrapedScene, error) {
- s, err := marshalScrapedScenes([]models.ScrapedContent{content})
+func marshalScrapedScene(content scraper.ScrapedContent) (*scraper.ScrapedScene, error) {
+ s, err := marshalScrapedScenes([]scraper.ScrapedContent{content})
if err != nil {
return nil, err
}
@@ -119,8 +120,8 @@ func marshalScrapedScene(content models.ScrapedContent) (*models.ScrapedScene, e
}
// marshalScrapedGallery will marshal a single scraped gallery
-func marshalScrapedGallery(content models.ScrapedContent) (*models.ScrapedGallery, error) {
- g, err := marshalScrapedGalleries([]models.ScrapedContent{content})
+func marshalScrapedGallery(content scraper.ScrapedContent) (*scraper.ScrapedGallery, error) {
+ g, err := marshalScrapedGalleries([]scraper.ScrapedContent{content})
if err != nil {
return nil, err
}
@@ -129,8 +130,8 @@ func marshalScrapedGallery(content models.ScrapedContent) (*models.ScrapedGaller
}
// marshalScrapedMovie will marshal a single scraped movie
-func marshalScrapedMovie(content models.ScrapedContent) (*models.ScrapedMovie, error) {
- m, err := marshalScrapedMovies([]models.ScrapedContent{content})
+func marshalScrapedMovie(content scraper.ScrapedContent) (*models.ScrapedMovie, error) {
+ m, err := marshalScrapedMovies([]scraper.ScrapedContent{content})
if err != nil {
return nil, err
}
diff --git a/internal/api/server.go b/internal/api/server.go
index c89e20d48..95ce7f775 100644
--- a/internal/api/server.go
+++ b/internal/api/server.go
@@ -30,7 +30,6 @@ import (
"github.com/stashapp/stash/internal/manager/config"
"github.com/stashapp/stash/pkg/fsutil"
"github.com/stashapp/stash/pkg/logger"
- "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/ui"
)
@@ -81,7 +80,7 @@ func Start() error {
hookExecutor: pluginCache,
}
- gqlSrv := gqlHandler.New(models.NewExecutableSchema(models.Config{Resolvers: resolver}))
+ gqlSrv := gqlHandler.New(NewExecutableSchema(Config{Resolvers: resolver}))
gqlSrv.SetRecoverFunc(recoverFunc)
gqlSrv.AddTransport(gqlTransport.Websocket{
Upgrader: websocket.Upgrader{
diff --git a/internal/dlna/service.go b/internal/dlna/service.go
index 6c14e8ee5..961e3f230 100644
--- a/internal/dlna/service.go
+++ b/internal/dlna/service.go
@@ -12,6 +12,20 @@ import (
"github.com/stashapp/stash/pkg/models"
)
+type Status struct {
+ Running bool `json:"running"`
+ // If not currently running, time until it will be started. If running, time until it will be stopped
+ Until *time.Time `json:"until"`
+ RecentIPAddresses []string `json:"recentIPAddresses"`
+ AllowedIPAddresses []*Dlnaip `json:"allowedIPAddresses"`
+}
+
+type Dlnaip struct {
+ IPAddress string `json:"ipAddress"`
+ // Time until IP will be no longer allowed/disallowed
+ Until *time.Time `json:"until"`
+}
+
type dmsConfig struct {
Path string
IfNames []string
@@ -273,11 +287,11 @@ func (s *Service) IsRunning() bool {
return s.running
}
-func (s *Service) Status() *models.DLNAStatus {
+func (s *Service) Status() *Status {
s.mutex.Lock()
defer s.mutex.Unlock()
- ret := &models.DLNAStatus{
+ ret := &Status{
Running: s.running,
RecentIPAddresses: s.ipWhitelistMgr.getRecent(),
AllowedIPAddresses: s.ipWhitelistMgr.getTempAllowed(),
diff --git a/internal/dlna/whitelist.go b/internal/dlna/whitelist.go
index 447e68702..4bf14b20e 100644
--- a/internal/dlna/whitelist.go
+++ b/internal/dlna/whitelist.go
@@ -4,7 +4,6 @@ import (
"sync"
"time"
- "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/sliceutil/stringslice"
)
@@ -59,11 +58,11 @@ func (m *ipWhitelistManager) getRecent() []string {
return m.recentIPAddresses
}
-func (m *ipWhitelistManager) getTempAllowed() []*models.Dlnaip {
+func (m *ipWhitelistManager) getTempAllowed() []*Dlnaip {
m.mutex.Lock()
defer m.mutex.Unlock()
- var ret []*models.Dlnaip
+ var ret []*Dlnaip
now := time.Now()
removeExpired := false
@@ -73,7 +72,7 @@ func (m *ipWhitelistManager) getTempAllowed() []*models.Dlnaip {
continue
}
- ret = append(ret, &models.Dlnaip{
+ ret = append(ret, &Dlnaip{
IPAddress: a.pattern,
Until: a.until,
})
diff --git a/internal/identify/identify.go b/internal/identify/identify.go
index d57670465..4d0d0afa2 100644
--- a/internal/identify/identify.go
+++ b/internal/identify/identify.go
@@ -8,11 +8,12 @@ import (
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene"
+ "github.com/stashapp/stash/pkg/scraper"
"github.com/stashapp/stash/pkg/utils"
)
type SceneScraper interface {
- ScrapeScene(ctx context.Context, sceneID int) (*models.ScrapedScene, error)
+ ScrapeScene(ctx context.Context, sceneID int) (*scraper.ScrapedScene, error)
}
type SceneUpdatePostHookExecutor interface {
@@ -21,13 +22,13 @@ type SceneUpdatePostHookExecutor interface {
type ScraperSource struct {
Name string
- Options *models.IdentifyMetadataOptionsInput
+ Options *MetadataOptions
Scraper SceneScraper
RemoteSite string
}
type SceneIdentifier struct {
- DefaultOptions *models.IdentifyMetadataOptionsInput
+ DefaultOptions *MetadataOptions
Sources []ScraperSource
ScreenshotSetter scene.ScreenshotSetter
SceneUpdatePostHookExecutor SceneUpdatePostHookExecutor
@@ -53,7 +54,7 @@ func (t *SceneIdentifier) Identify(ctx context.Context, txnManager models.Transa
}
type scrapeResult struct {
- result *models.ScrapedScene
+ result *scraper.ScrapedScene
source ScraperSource
}
@@ -84,7 +85,7 @@ func (t *SceneIdentifier) getSceneUpdater(ctx context.Context, s *models.Scene,
ID: s.ID,
}
- options := []models.IdentifyMetadataOptionsInput{}
+ options := []MetadataOptions{}
if result.source.Options != nil {
options = append(options, *result.source.Options)
}
@@ -208,9 +209,9 @@ func (t *SceneIdentifier) modifyScene(ctx context.Context, txnManager models.Tra
return nil
}
-func getFieldOptions(options []models.IdentifyMetadataOptionsInput) map[string]*models.IdentifyFieldOptionsInput {
+func getFieldOptions(options []MetadataOptions) map[string]*FieldOptions {
// prefer source-specific field strategies, then the defaults
- ret := make(map[string]*models.IdentifyFieldOptionsInput)
+ ret := make(map[string]*FieldOptions)
for _, oo := range options {
for _, f := range oo.FieldOptions {
if _, found := ret[f.Field]; !found {
@@ -222,7 +223,7 @@ func getFieldOptions(options []models.IdentifyMetadataOptionsInput) map[string]*
return ret
}
-func getScenePartial(scene *models.Scene, scraped *models.ScrapedScene, fieldOptions map[string]*models.IdentifyFieldOptionsInput, setOrganized bool) models.ScenePartial {
+func getScenePartial(scene *models.Scene, scraped *scraper.ScrapedScene, fieldOptions map[string]*FieldOptions, setOrganized bool) models.ScenePartial {
partial := models.ScenePartial{
ID: scene.ID,
}
@@ -259,17 +260,17 @@ func getScenePartial(scene *models.Scene, scraped *models.ScrapedScene, fieldOpt
return partial
}
-func shouldSetSingleValueField(strategy *models.IdentifyFieldOptionsInput, hasExistingValue bool) bool {
+func shouldSetSingleValueField(strategy *FieldOptions, hasExistingValue bool) bool {
// if unset then default to MERGE
- fs := models.IdentifyFieldStrategyMerge
+ fs := FieldStrategyMerge
if strategy != nil && strategy.Strategy.IsValid() {
fs = strategy.Strategy
}
- if fs == models.IdentifyFieldStrategyIgnore {
+ if fs == FieldStrategyIgnore {
return false
}
- return !hasExistingValue || fs == models.IdentifyFieldStrategyOverwrite
+ return !hasExistingValue || fs == FieldStrategyOverwrite
}
diff --git a/internal/identify/identify_test.go b/internal/identify/identify_test.go
index 3a36015f4..c5588c78a 100644
--- a/internal/identify/identify_test.go
+++ b/internal/identify/identify_test.go
@@ -8,16 +8,17 @@ import (
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/mocks"
+ "github.com/stashapp/stash/pkg/scraper"
"github.com/stashapp/stash/pkg/sliceutil/intslice"
"github.com/stretchr/testify/mock"
)
type mockSceneScraper struct {
errIDs []int
- results map[int]*models.ScrapedScene
+ results map[int]*scraper.ScrapedScene
}
-func (s mockSceneScraper) ScrapeScene(ctx context.Context, sceneID int) (*models.ScrapedScene, error) {
+func (s mockSceneScraper) ScrapeScene(ctx context.Context, sceneID int) (*scraper.ScrapedScene, error) {
if intslice.IntInclude(s.errIDs, sceneID) {
return nil, errors.New("scrape scene error")
}
@@ -42,12 +43,12 @@ func TestSceneIdentifier_Identify(t *testing.T) {
var scrapedTitle = "scrapedTitle"
- defaultOptions := &models.IdentifyMetadataOptionsInput{}
+ defaultOptions := &MetadataOptions{}
sources := []ScraperSource{
{
Scraper: mockSceneScraper{
errIDs: []int{errID1},
- results: map[int]*models.ScrapedScene{
+ results: map[int]*scraper.ScrapedScene{
found1ID: {
Title: &scrapedTitle,
},
@@ -57,7 +58,7 @@ func TestSceneIdentifier_Identify(t *testing.T) {
{
Scraper: mockSceneScraper{
errIDs: []int{errID2},
- results: map[int]*models.ScrapedScene{
+ results: map[int]*scraper.ScrapedScene{
found2ID: {
Title: &scrapedTitle,
},
@@ -150,7 +151,7 @@ func TestSceneIdentifier_modifyScene(t *testing.T) {
args{
&models.Scene{},
&scrapeResult{
- result: &models.ScrapedScene{},
+ result: &scraper.ScrapedScene{},
},
},
false,
@@ -173,55 +174,55 @@ func Test_getFieldOptions(t *testing.T) {
)
type args struct {
- options []models.IdentifyMetadataOptionsInput
+ options []MetadataOptions
}
tests := []struct {
name string
args args
- want map[string]*models.IdentifyFieldOptionsInput
+ want map[string]*FieldOptions
}{
{
"simple",
args{
- []models.IdentifyMetadataOptionsInput{
+ []MetadataOptions{
{
- FieldOptions: []*models.IdentifyFieldOptionsInput{
+ FieldOptions: []*FieldOptions{
{
Field: inFirst,
- Strategy: models.IdentifyFieldStrategyIgnore,
+ Strategy: FieldStrategyIgnore,
},
{
Field: inBoth,
- Strategy: models.IdentifyFieldStrategyIgnore,
+ Strategy: FieldStrategyIgnore,
},
},
},
{
- FieldOptions: []*models.IdentifyFieldOptionsInput{
+ FieldOptions: []*FieldOptions{
{
Field: inSecond,
- Strategy: models.IdentifyFieldStrategyMerge,
+ Strategy: FieldStrategyMerge,
},
{
Field: inBoth,
- Strategy: models.IdentifyFieldStrategyMerge,
+ Strategy: FieldStrategyMerge,
},
},
},
},
},
- map[string]*models.IdentifyFieldOptionsInput{
+ map[string]*FieldOptions{
inFirst: {
Field: inFirst,
- Strategy: models.IdentifyFieldStrategyIgnore,
+ Strategy: FieldStrategyIgnore,
},
inSecond: {
Field: inSecond,
- Strategy: models.IdentifyFieldStrategyMerge,
+ Strategy: FieldStrategyMerge,
},
inBoth: {
Field: inBoth,
- Strategy: models.IdentifyFieldStrategyIgnore,
+ Strategy: FieldStrategyIgnore,
},
},
},
@@ -275,22 +276,22 @@ func Test_getScenePartial(t *testing.T) {
URL: models.NullStringPtr(scrapedURL),
}
- scrapedScene := &models.ScrapedScene{
+ scrapedScene := &scraper.ScrapedScene{
Title: &scrapedTitle,
Date: &scrapedDate,
Details: &scrapedDetails,
URL: &scrapedURL,
}
- scrapedUnchangedScene := &models.ScrapedScene{
+ scrapedUnchangedScene := &scraper.ScrapedScene{
Title: &originalTitle,
Date: &originalDate,
Details: &originalDetails,
URL: &originalURL,
}
- makeFieldOptions := func(input *models.IdentifyFieldOptionsInput) map[string]*models.IdentifyFieldOptionsInput {
- return map[string]*models.IdentifyFieldOptionsInput{
+ makeFieldOptions := func(input *FieldOptions) map[string]*FieldOptions {
+ return map[string]*FieldOptions{
"title": input,
"date": input,
"details": input,
@@ -298,22 +299,22 @@ func Test_getScenePartial(t *testing.T) {
}
}
- overwriteAll := makeFieldOptions(&models.IdentifyFieldOptionsInput{
- Strategy: models.IdentifyFieldStrategyOverwrite,
+ overwriteAll := makeFieldOptions(&FieldOptions{
+ Strategy: FieldStrategyOverwrite,
})
- ignoreAll := makeFieldOptions(&models.IdentifyFieldOptionsInput{
- Strategy: models.IdentifyFieldStrategyIgnore,
+ ignoreAll := makeFieldOptions(&FieldOptions{
+ Strategy: FieldStrategyIgnore,
})
- mergeAll := makeFieldOptions(&models.IdentifyFieldOptionsInput{
- Strategy: models.IdentifyFieldStrategyMerge,
+ mergeAll := makeFieldOptions(&FieldOptions{
+ Strategy: FieldStrategyMerge,
})
setOrganised := true
type args struct {
scene *models.Scene
- scraped *models.ScrapedScene
- fieldOptions map[string]*models.IdentifyFieldOptionsInput
+ scraped *scraper.ScrapedScene
+ fieldOptions map[string]*FieldOptions
setOrganized bool
}
tests := []struct {
@@ -407,7 +408,7 @@ func Test_shouldSetSingleValueField(t *testing.T) {
const invalid = "invalid"
type args struct {
- strategy *models.IdentifyFieldOptionsInput
+ strategy *FieldOptions
hasExistingValue bool
}
tests := []struct {
@@ -418,8 +419,8 @@ func Test_shouldSetSingleValueField(t *testing.T) {
{
"ignore",
args{
- &models.IdentifyFieldOptionsInput{
- Strategy: models.IdentifyFieldStrategyIgnore,
+ &FieldOptions{
+ Strategy: FieldStrategyIgnore,
},
false,
},
@@ -428,8 +429,8 @@ func Test_shouldSetSingleValueField(t *testing.T) {
{
"merge existing",
args{
- &models.IdentifyFieldOptionsInput{
- Strategy: models.IdentifyFieldStrategyMerge,
+ &FieldOptions{
+ Strategy: FieldStrategyMerge,
},
true,
},
@@ -438,8 +439,8 @@ func Test_shouldSetSingleValueField(t *testing.T) {
{
"merge absent",
args{
- &models.IdentifyFieldOptionsInput{
- Strategy: models.IdentifyFieldStrategyMerge,
+ &FieldOptions{
+ Strategy: FieldStrategyMerge,
},
false,
},
@@ -448,8 +449,8 @@ func Test_shouldSetSingleValueField(t *testing.T) {
{
"overwrite",
args{
- &models.IdentifyFieldOptionsInput{
- Strategy: models.IdentifyFieldStrategyOverwrite,
+ &FieldOptions{
+ Strategy: FieldStrategyOverwrite,
},
true,
},
@@ -458,7 +459,7 @@ func Test_shouldSetSingleValueField(t *testing.T) {
{
"nil (merge) existing",
args{
- &models.IdentifyFieldOptionsInput{},
+ &FieldOptions{},
true,
},
false,
@@ -466,7 +467,7 @@ func Test_shouldSetSingleValueField(t *testing.T) {
{
"nil (merge) absent",
args{
- &models.IdentifyFieldOptionsInput{},
+ &FieldOptions{},
false,
},
true,
@@ -474,7 +475,7 @@ func Test_shouldSetSingleValueField(t *testing.T) {
{
"invalid (merge) existing",
args{
- &models.IdentifyFieldOptionsInput{
+ &FieldOptions{
Strategy: invalid,
},
true,
@@ -484,7 +485,7 @@ func Test_shouldSetSingleValueField(t *testing.T) {
{
"invalid (merge) absent",
args{
- &models.IdentifyFieldOptionsInput{
+ &FieldOptions{
Strategy: invalid,
},
false,
diff --git a/internal/identify/options.go b/internal/identify/options.go
new file mode 100644
index 000000000..84530e5fc
--- /dev/null
+++ b/internal/identify/options.go
@@ -0,0 +1,92 @@
+package identify
+
+import (
+ "fmt"
+ "io"
+ "strconv"
+
+ "github.com/stashapp/stash/pkg/scraper"
+)
+
+type Source struct {
+ Source *scraper.Source `json:"source"`
+ // Options defined for a source override the defaults
+ Options *MetadataOptions `json:"options"`
+}
+
+type Options struct {
+ // An ordered list of sources to identify items with. Only the first source that finds a match is used.
+ Sources []*Source `json:"sources"`
+ // Options defined here override the configured defaults
+ Options *MetadataOptions `json:"options"`
+ // scene ids to identify
+ SceneIDs []string `json:"sceneIDs"`
+ // paths of scenes to identify - ignored if scene ids are set
+ Paths []string `json:"paths"`
+}
+
+type MetadataOptions struct {
+ // any fields missing from here are defaulted to MERGE and createMissing false
+ FieldOptions []*FieldOptions `json:"fieldOptions"`
+ // defaults to true if not provided
+ SetCoverImage *bool `json:"setCoverImage"`
+ SetOrganized *bool `json:"setOrganized"`
+ // defaults to true if not provided
+ IncludeMalePerformers *bool `json:"includeMalePerformers"`
+}
+
+type FieldOptions struct {
+ Field string `json:"field"`
+ Strategy FieldStrategy `json:"strategy"`
+ // creates missing objects if needed - only applicable for performers, tags and studios
+ CreateMissing *bool `json:"createMissing"`
+}
+
+type FieldStrategy string
+
+const (
+ // Never sets the field value
+ FieldStrategyIgnore FieldStrategy = "IGNORE"
+ // For multi-value fields, merge with existing.
+ // For single-value fields, ignore if already set
+ FieldStrategyMerge FieldStrategy = "MERGE"
+ // Always replaces the value if a value is found.
+ // For multi-value fields, any existing values are removed and replaced with the
+ // scraped values.
+ FieldStrategyOverwrite FieldStrategy = "OVERWRITE"
+)
+
+var AllFieldStrategy = []FieldStrategy{
+ FieldStrategyIgnore,
+ FieldStrategyMerge,
+ FieldStrategyOverwrite,
+}
+
+func (e FieldStrategy) IsValid() bool {
+ switch e {
+ case FieldStrategyIgnore, FieldStrategyMerge, FieldStrategyOverwrite:
+ return true
+ }
+ return false
+}
+
+func (e FieldStrategy) String() string {
+ return string(e)
+}
+
+func (e *FieldStrategy) UnmarshalGQL(v interface{}) error {
+ str, ok := v.(string)
+ if !ok {
+ return fmt.Errorf("enums must be strings")
+ }
+
+ *e = FieldStrategy(str)
+ if !e.IsValid() {
+ return fmt.Errorf("%s is not a valid IdentifyFieldStrategy", str)
+ }
+ return nil
+}
+
+func (e FieldStrategy) MarshalGQL(w io.Writer) {
+ fmt.Fprint(w, strconv.Quote(e.String()))
+}
diff --git a/internal/identify/scene.go b/internal/identify/scene.go
index 166755451..3e6fd4a38 100644
--- a/internal/identify/scene.go
+++ b/internal/identify/scene.go
@@ -18,7 +18,7 @@ type sceneRelationships struct {
repo models.Repository
scene *models.Scene
result *scrapeResult
- fieldOptions map[string]*models.IdentifyFieldOptionsInput
+ fieldOptions map[string]*FieldOptions
}
func (g sceneRelationships) studio() (*int64, error) {
@@ -61,7 +61,7 @@ func (g sceneRelationships) performers(ignoreMale bool) ([]int, error) {
}
createMissing := fieldStrategy != nil && utils.IsTrue(fieldStrategy.CreateMissing)
- strategy := models.IdentifyFieldStrategyMerge
+ strategy := FieldStrategyMerge
if fieldStrategy != nil {
strategy = fieldStrategy.Strategy
}
@@ -75,7 +75,7 @@ func (g sceneRelationships) performers(ignoreMale bool) ([]int, error) {
return nil, fmt.Errorf("error getting scene performers: %w", err)
}
- if strategy == models.IdentifyFieldStrategyMerge {
+ if strategy == FieldStrategyMerge {
// add to existing
performerIDs = originalPerformerIDs
}
@@ -115,7 +115,7 @@ func (g sceneRelationships) tags() ([]int, error) {
}
createMissing := fieldStrategy != nil && utils.IsTrue(fieldStrategy.CreateMissing)
- strategy := models.IdentifyFieldStrategyMerge
+ strategy := FieldStrategyMerge
if fieldStrategy != nil {
strategy = fieldStrategy.Strategy
}
@@ -126,7 +126,7 @@ func (g sceneRelationships) tags() ([]int, error) {
return nil, fmt.Errorf("error getting scene tags: %w", err)
}
- if strategy == models.IdentifyFieldStrategyMerge {
+ if strategy == FieldStrategyMerge {
// add to existing
tagIDs = originalTagIDs
}
@@ -176,7 +176,7 @@ func (g sceneRelationships) stashIDs() ([]models.StashID, error) {
return nil, nil
}
- strategy := models.IdentifyFieldStrategyMerge
+ strategy := FieldStrategyMerge
if fieldStrategy != nil {
strategy = fieldStrategy.Strategy
}
@@ -193,7 +193,7 @@ func (g sceneRelationships) stashIDs() ([]models.StashID, error) {
originalStashIDs = append(originalStashIDs, *stashID)
}
- if strategy == models.IdentifyFieldStrategyMerge {
+ if strategy == FieldStrategyMerge {
// add to existing
stashIDs = originalStashIDs
}
diff --git a/internal/identify/scene_test.go b/internal/identify/scene_test.go
index f0ba7da17..2487e6808 100644
--- a/internal/identify/scene_test.go
+++ b/internal/identify/scene_test.go
@@ -9,6 +9,7 @@ import (
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/mocks"
+ "github.com/stashapp/stash/pkg/scraper"
"github.com/stashapp/stash/pkg/utils"
"github.com/stretchr/testify/mock"
)
@@ -19,8 +20,8 @@ func Test_sceneRelationships_studio(t *testing.T) {
invalidStoredID := "invalidStoredID"
createMissing := true
- defaultOptions := &models.IdentifyFieldOptionsInput{
- Strategy: models.IdentifyFieldStrategyMerge,
+ defaultOptions := &FieldOptions{
+ Strategy: FieldStrategyMerge,
}
repo := mocks.NewTransactionManager()
@@ -30,13 +31,13 @@ func Test_sceneRelationships_studio(t *testing.T) {
tr := sceneRelationships{
repo: repo,
- fieldOptions: make(map[string]*models.IdentifyFieldOptionsInput),
+ fieldOptions: make(map[string]*FieldOptions),
}
tests := []struct {
name string
scene *models.Scene
- fieldOptions *models.IdentifyFieldOptionsInput
+ fieldOptions *FieldOptions
result *models.ScrapedStudio
want *int64
wantErr bool
@@ -52,8 +53,8 @@ func Test_sceneRelationships_studio(t *testing.T) {
{
"ignore",
&models.Scene{},
- &models.IdentifyFieldOptionsInput{
- Strategy: models.IdentifyFieldStrategyIgnore,
+ &FieldOptions{
+ Strategy: FieldStrategyIgnore,
},
&models.ScrapedStudio{
StoredID: &validStoredID,
@@ -104,8 +105,8 @@ func Test_sceneRelationships_studio(t *testing.T) {
{
"create missing",
&models.Scene{},
- &models.IdentifyFieldOptionsInput{
- Strategy: models.IdentifyFieldStrategyMerge,
+ &FieldOptions{
+ Strategy: FieldStrategyMerge,
CreateMissing: &createMissing,
},
&models.ScrapedStudio{},
@@ -118,7 +119,7 @@ func Test_sceneRelationships_studio(t *testing.T) {
tr.scene = tt.scene
tr.fieldOptions["studio"] = tt.fieldOptions
tr.result = &scrapeResult{
- result: &models.ScrapedScene{
+ result: &scraper.ScrapedScene{
Studio: tt.result,
},
}
@@ -151,8 +152,8 @@ func Test_sceneRelationships_performers(t *testing.T) {
female := models.GenderEnumFemale.String()
male := models.GenderEnumMale.String()
- defaultOptions := &models.IdentifyFieldOptionsInput{
- Strategy: models.IdentifyFieldStrategyMerge,
+ defaultOptions := &FieldOptions{
+ Strategy: FieldStrategyMerge,
}
repo := mocks.NewTransactionManager()
@@ -162,13 +163,13 @@ func Test_sceneRelationships_performers(t *testing.T) {
tr := sceneRelationships{
repo: repo,
- fieldOptions: make(map[string]*models.IdentifyFieldOptionsInput),
+ fieldOptions: make(map[string]*FieldOptions),
}
tests := []struct {
name string
sceneID int
- fieldOptions *models.IdentifyFieldOptionsInput
+ fieldOptions *FieldOptions
scraped []*models.ScrapedPerformer
ignoreMale bool
want []int
@@ -177,8 +178,8 @@ func Test_sceneRelationships_performers(t *testing.T) {
{
"ignore",
sceneID,
- &models.IdentifyFieldOptionsInput{
- Strategy: models.IdentifyFieldStrategyIgnore,
+ &FieldOptions{
+ Strategy: FieldStrategyIgnore,
},
[]*models.ScrapedPerformer{
{
@@ -255,8 +256,8 @@ func Test_sceneRelationships_performers(t *testing.T) {
{
"overwrite",
sceneWithPerformerID,
- &models.IdentifyFieldOptionsInput{
- Strategy: models.IdentifyFieldStrategyOverwrite,
+ &FieldOptions{
+ Strategy: FieldStrategyOverwrite,
},
[]*models.ScrapedPerformer{
{
@@ -271,8 +272,8 @@ func Test_sceneRelationships_performers(t *testing.T) {
{
"ignore male (not male)",
sceneWithPerformerID,
- &models.IdentifyFieldOptionsInput{
- Strategy: models.IdentifyFieldStrategyOverwrite,
+ &FieldOptions{
+ Strategy: FieldStrategyOverwrite,
},
[]*models.ScrapedPerformer{
{
@@ -288,8 +289,8 @@ func Test_sceneRelationships_performers(t *testing.T) {
{
"error getting tag ID",
sceneID,
- &models.IdentifyFieldOptionsInput{
- Strategy: models.IdentifyFieldStrategyOverwrite,
+ &FieldOptions{
+ Strategy: FieldStrategyOverwrite,
CreateMissing: &createMissing,
},
[]*models.ScrapedPerformer{
@@ -310,7 +311,7 @@ func Test_sceneRelationships_performers(t *testing.T) {
}
tr.fieldOptions["performers"] = tt.fieldOptions
tr.result = &scrapeResult{
- result: &models.ScrapedScene{
+ result: &scraper.ScrapedScene{
Performers: tt.scraped,
},
}
@@ -342,8 +343,8 @@ func Test_sceneRelationships_tags(t *testing.T) {
validName := "validName"
invalidName := "invalidName"
- defaultOptions := &models.IdentifyFieldOptionsInput{
- Strategy: models.IdentifyFieldStrategyMerge,
+ defaultOptions := &FieldOptions{
+ Strategy: FieldStrategyMerge,
}
repo := mocks.NewTransactionManager()
@@ -362,13 +363,13 @@ func Test_sceneRelationships_tags(t *testing.T) {
tr := sceneRelationships{
repo: repo,
- fieldOptions: make(map[string]*models.IdentifyFieldOptionsInput),
+ fieldOptions: make(map[string]*FieldOptions),
}
tests := []struct {
name string
sceneID int
- fieldOptions *models.IdentifyFieldOptionsInput
+ fieldOptions *FieldOptions
scraped []*models.ScrapedTag
want []int
wantErr bool
@@ -376,8 +377,8 @@ func Test_sceneRelationships_tags(t *testing.T) {
{
"ignore",
sceneID,
- &models.IdentifyFieldOptionsInput{
- Strategy: models.IdentifyFieldStrategyIgnore,
+ &FieldOptions{
+ Strategy: FieldStrategyIgnore,
},
[]*models.ScrapedTag{
{
@@ -434,8 +435,8 @@ func Test_sceneRelationships_tags(t *testing.T) {
{
"overwrite",
sceneWithTagID,
- &models.IdentifyFieldOptionsInput{
- Strategy: models.IdentifyFieldStrategyOverwrite,
+ &FieldOptions{
+ Strategy: FieldStrategyOverwrite,
},
[]*models.ScrapedTag{
{
@@ -449,8 +450,8 @@ func Test_sceneRelationships_tags(t *testing.T) {
{
"error getting tag ID",
sceneID,
- &models.IdentifyFieldOptionsInput{
- Strategy: models.IdentifyFieldStrategyOverwrite,
+ &FieldOptions{
+ Strategy: FieldStrategyOverwrite,
},
[]*models.ScrapedTag{
{
@@ -464,8 +465,8 @@ func Test_sceneRelationships_tags(t *testing.T) {
{
"create missing",
sceneID,
- &models.IdentifyFieldOptionsInput{
- Strategy: models.IdentifyFieldStrategyOverwrite,
+ &FieldOptions{
+ Strategy: FieldStrategyOverwrite,
CreateMissing: &createMissing,
},
[]*models.ScrapedTag{
@@ -479,8 +480,8 @@ func Test_sceneRelationships_tags(t *testing.T) {
{
"error creating",
sceneID,
- &models.IdentifyFieldOptionsInput{
- Strategy: models.IdentifyFieldStrategyOverwrite,
+ &FieldOptions{
+ Strategy: FieldStrategyOverwrite,
CreateMissing: &createMissing,
},
[]*models.ScrapedTag{
@@ -499,7 +500,7 @@ func Test_sceneRelationships_tags(t *testing.T) {
}
tr.fieldOptions["tags"] = tt.fieldOptions
tr.result = &scrapeResult{
- result: &models.ScrapedScene{
+ result: &scraper.ScrapedScene{
Tags: tt.scraped,
},
}
@@ -529,8 +530,8 @@ func Test_sceneRelationships_stashIDs(t *testing.T) {
remoteSiteID := "remoteSiteID"
newRemoteSiteID := "newRemoteSiteID"
- defaultOptions := &models.IdentifyFieldOptionsInput{
- Strategy: models.IdentifyFieldStrategyMerge,
+ defaultOptions := &FieldOptions{
+ Strategy: FieldStrategyMerge,
}
repo := mocks.NewTransactionManager()
@@ -545,13 +546,13 @@ func Test_sceneRelationships_stashIDs(t *testing.T) {
tr := sceneRelationships{
repo: repo,
- fieldOptions: make(map[string]*models.IdentifyFieldOptionsInput),
+ fieldOptions: make(map[string]*FieldOptions),
}
tests := []struct {
name string
sceneID int
- fieldOptions *models.IdentifyFieldOptionsInput
+ fieldOptions *FieldOptions
endpoint string
remoteSiteID *string
want []models.StashID
@@ -560,8 +561,8 @@ func Test_sceneRelationships_stashIDs(t *testing.T) {
{
"ignore",
sceneID,
- &models.IdentifyFieldOptionsInput{
- Strategy: models.IdentifyFieldStrategyIgnore,
+ &FieldOptions{
+ Strategy: FieldStrategyIgnore,
},
newEndpoint,
&remoteSiteID,
@@ -639,8 +640,8 @@ func Test_sceneRelationships_stashIDs(t *testing.T) {
{
"overwrite",
sceneWithStashID,
- &models.IdentifyFieldOptionsInput{
- Strategy: models.IdentifyFieldStrategyOverwrite,
+ &FieldOptions{
+ Strategy: FieldStrategyOverwrite,
},
newEndpoint,
&newRemoteSiteID,
@@ -655,8 +656,8 @@ func Test_sceneRelationships_stashIDs(t *testing.T) {
{
"overwrite same",
sceneWithStashID,
- &models.IdentifyFieldOptionsInput{
- Strategy: models.IdentifyFieldStrategyOverwrite,
+ &FieldOptions{
+ Strategy: FieldStrategyOverwrite,
},
existingEndpoint,
&remoteSiteID,
@@ -674,7 +675,7 @@ func Test_sceneRelationships_stashIDs(t *testing.T) {
source: ScraperSource{
RemoteSite: tt.endpoint,
},
- result: &models.ScrapedScene{
+ result: &scraper.ScrapedScene{
RemoteSiteID: tt.remoteSiteID,
},
}
@@ -712,7 +713,7 @@ func Test_sceneRelationships_cover(t *testing.T) {
tr := sceneRelationships{
repo: repo,
- fieldOptions: make(map[string]*models.IdentifyFieldOptionsInput),
+ fieldOptions: make(map[string]*FieldOptions),
}
tests := []struct {
@@ -764,7 +765,7 @@ func Test_sceneRelationships_cover(t *testing.T) {
ID: tt.sceneID,
}
tr.result = &scrapeResult{
- result: &models.ScrapedScene{
+ result: &scraper.ScrapedScene{
Image: tt.image,
},
}
diff --git a/internal/manager/config/config.go b/internal/manager/config/config.go
index 8b3725080..90fa2e742 100644
--- a/internal/manager/config/config.go
+++ b/internal/manager/config/config.go
@@ -15,6 +15,7 @@ import (
"github.com/spf13/viper"
+ "github.com/stashapp/stash/internal/identify"
"github.com/stashapp/stash/pkg/fsutil"
"github.com/stashapp/stash/pkg/hash"
"github.com/stashapp/stash/pkg/logger"
@@ -147,10 +148,10 @@ const (
// Image lightbox options
legacyImageLightboxSlideshowDelay = "slideshow_delay"
ImageLightboxSlideshowDelay = "image_lightbox.slideshow_delay"
- ImageLightboxDisplayMode = "image_lightbox.display_mode"
+ ImageLightboxDisplayModeKey = "image_lightbox.display_mode"
ImageLightboxScaleUp = "image_lightbox.scale_up"
ImageLightboxResetZoomOnNav = "image_lightbox.reset_zoom_on_nav"
- ImageLightboxScrollMode = "image_lightbox.scroll_mode"
+ ImageLightboxScrollModeKey = "image_lightbox.scroll_mode"
ImageLightboxScrollAttemptsBeforeChange = "image_lightbox.scroll_attempts_before_change"
UI = "ui"
@@ -460,14 +461,27 @@ func (i *Instance) getStringMapString(key string) map[string]string {
return i.viper(key).GetStringMapString(key)
}
+type StashConfig struct {
+ Path string `json:"path"`
+ ExcludeVideo bool `json:"excludeVideo"`
+ ExcludeImage bool `json:"excludeImage"`
+}
+
+// Stash configuration details
+type StashConfigInput struct {
+ Path string `json:"path"`
+ ExcludeVideo bool `json:"excludeVideo"`
+ ExcludeImage bool `json:"excludeImage"`
+}
+
// GetStathPaths returns the configured stash library paths.
// Works opposite to the usual case - it will return the override
// value only if the main value is not set.
-func (i *Instance) GetStashPaths() []*models.StashConfig {
+func (i *Instance) GetStashPaths() []*StashConfig {
i.RLock()
defer i.RUnlock()
- var ret []*models.StashConfig
+ var ret []*StashConfig
v := i.main
if !v.IsSet(Stash) {
@@ -479,7 +493,7 @@ func (i *Instance) GetStashPaths() []*models.StashConfig {
ss := v.GetStringSlice(Stash)
ret = nil
for _, path := range ss {
- toAdd := &models.StashConfig{
+ toAdd := &StashConfig{
Path: path,
}
ret = append(ret, toAdd)
@@ -610,8 +624,8 @@ func (i *Instance) GetScraperExcludeTagPatterns() []string {
return i.getStringSlice(ScraperExcludeTagPatterns)
}
-func (i *Instance) GetStashBoxes() models.StashBoxes {
- var boxes models.StashBoxes
+func (i *Instance) GetStashBoxes() []*models.StashBox {
+ var boxes []*models.StashBox
if err := i.unmarshalKey(StashBoxes, &boxes); err != nil {
logger.Warnf("error in unmarshalkey: %v", err)
}
@@ -797,7 +811,13 @@ func (i *Instance) ValidateCredentials(username string, password string) bool {
var stashBoxRe = regexp.MustCompile("^http.*graphql$")
-func (i *Instance) ValidateStashBoxes(boxes []*models.StashBoxInput) error {
+type StashBoxInput struct {
+ Endpoint string `json:"endpoint"`
+ APIKey string `json:"api_key"`
+ Name string `json:"name"`
+}
+
+func (i *Instance) ValidateStashBoxes(boxes []*StashBoxInput) error {
isMulti := len(boxes) > 1
for _, box := range boxes {
@@ -933,18 +953,18 @@ func (i *Instance) getSlideshowDelay() int {
return ret
}
-func (i *Instance) GetImageLightboxOptions() models.ConfigImageLightboxResult {
+func (i *Instance) GetImageLightboxOptions() ConfigImageLightboxResult {
i.RLock()
defer i.RUnlock()
delay := i.getSlideshowDelay()
- ret := models.ConfigImageLightboxResult{
+ ret := ConfigImageLightboxResult{
SlideshowDelay: &delay,
}
- if v := i.viperWith(ImageLightboxDisplayMode); v != nil {
- mode := models.ImageLightboxDisplayMode(v.GetString(ImageLightboxDisplayMode))
+ if v := i.viperWith(ImageLightboxDisplayModeKey); v != nil {
+ mode := ImageLightboxDisplayMode(v.GetString(ImageLightboxDisplayModeKey))
ret.DisplayMode = &mode
}
if v := i.viperWith(ImageLightboxScaleUp); v != nil {
@@ -955,8 +975,8 @@ func (i *Instance) GetImageLightboxOptions() models.ConfigImageLightboxResult {
value := v.GetBool(ImageLightboxResetZoomOnNav)
ret.ResetZoomOnNav = &value
}
- if v := i.viperWith(ImageLightboxScrollMode); v != nil {
- mode := models.ImageLightboxScrollMode(v.GetString(ImageLightboxScrollMode))
+ if v := i.viperWith(ImageLightboxScrollModeKey); v != nil {
+ mode := ImageLightboxScrollMode(v.GetString(ImageLightboxScrollModeKey))
ret.ScrollMode = &mode
}
if v := i.viperWith(ImageLightboxScrollAttemptsBeforeChange); v != nil {
@@ -966,8 +986,8 @@ func (i *Instance) GetImageLightboxOptions() models.ConfigImageLightboxResult {
return ret
}
-func (i *Instance) GetDisableDropdownCreate() *models.ConfigDisableDropdownCreate {
- return &models.ConfigDisableDropdownCreate{
+func (i *Instance) GetDisableDropdownCreate() *ConfigDisableDropdownCreate {
+ return &ConfigDisableDropdownCreate{
Performer: i.getBool(DisableDropdownCreatePerformer),
Studio: i.getBool(DisableDropdownCreateStudio),
Tag: i.getBool(DisableDropdownCreateTag),
@@ -1056,13 +1076,13 @@ func (i *Instance) GetDeleteGeneratedDefault() bool {
// GetDefaultIdentifySettings returns the default Identify task settings.
// Returns nil if the settings could not be unmarshalled, or if it
// has not been set.
-func (i *Instance) GetDefaultIdentifySettings() *models.IdentifyMetadataTaskOptions {
+func (i *Instance) GetDefaultIdentifySettings() *identify.Options {
i.RLock()
defer i.RUnlock()
v := i.viper(DefaultIdentifySettings)
if v.IsSet(DefaultIdentifySettings) {
- var ret models.IdentifyMetadataTaskOptions
+ var ret identify.Options
if err := v.UnmarshalKey(DefaultIdentifySettings, &ret); err != nil {
return nil
}
@@ -1075,13 +1095,13 @@ func (i *Instance) GetDefaultIdentifySettings() *models.IdentifyMetadataTaskOpti
// GetDefaultScanSettings returns the default Scan task settings.
// Returns nil if the settings could not be unmarshalled, or if it
// has not been set.
-func (i *Instance) GetDefaultScanSettings() *models.ScanMetadataOptions {
+func (i *Instance) GetDefaultScanSettings() *ScanMetadataOptions {
i.RLock()
defer i.RUnlock()
v := i.viper(DefaultScanSettings)
if v.IsSet(DefaultScanSettings) {
- var ret models.ScanMetadataOptions
+ var ret ScanMetadataOptions
if err := v.UnmarshalKey(DefaultScanSettings, &ret); err != nil {
return nil
}
@@ -1094,13 +1114,13 @@ func (i *Instance) GetDefaultScanSettings() *models.ScanMetadataOptions {
// GetDefaultAutoTagSettings returns the default Scan task settings.
// Returns nil if the settings could not be unmarshalled, or if it
// has not been set.
-func (i *Instance) GetDefaultAutoTagSettings() *models.AutoTagMetadataOptions {
+func (i *Instance) GetDefaultAutoTagSettings() *AutoTagMetadataOptions {
i.RLock()
defer i.RUnlock()
v := i.viper(DefaultAutoTagSettings)
if v.IsSet(DefaultAutoTagSettings) {
- var ret models.AutoTagMetadataOptions
+ var ret AutoTagMetadataOptions
if err := v.UnmarshalKey(DefaultAutoTagSettings, &ret); err != nil {
return nil
}
diff --git a/internal/manager/config/tasks.go b/internal/manager/config/tasks.go
new file mode 100644
index 000000000..e455ea3e1
--- /dev/null
+++ b/internal/manager/config/tasks.go
@@ -0,0 +1,27 @@
+package config
+
+type ScanMetadataOptions struct {
+ // Set name, date, details from metadata (if present)
+ UseFileMetadata bool `json:"useFileMetadata"`
+ // Strip file extension from title
+ StripFileExtension bool `json:"stripFileExtension"`
+ // Generate previews during scan
+ ScanGeneratePreviews bool `json:"scanGeneratePreviews"`
+ // Generate image previews during scan
+ ScanGenerateImagePreviews bool `json:"scanGenerateImagePreviews"`
+ // Generate sprites during scan
+ ScanGenerateSprites bool `json:"scanGenerateSprites"`
+ // Generate phashes during scan
+ ScanGeneratePhashes bool `json:"scanGeneratePhashes"`
+ // Generate image thumbnails during scan
+ ScanGenerateThumbnails bool `json:"scanGenerateThumbnails"`
+}
+
+type AutoTagMetadataOptions struct {
+ // IDs of performers to tag files with, or "*" for all
+ Performers []string `json:"performers"`
+ // IDs of studios to tag files with, or "*" for all
+ Studios []string `json:"studios"`
+ // IDs of tags to tag files with, or "*" for all
+ Tags []string `json:"tags"`
+}
diff --git a/internal/manager/config/ui.go b/internal/manager/config/ui.go
new file mode 100644
index 000000000..a2744a741
--- /dev/null
+++ b/internal/manager/config/ui.go
@@ -0,0 +1,106 @@
+package config
+
+import (
+ "fmt"
+ "io"
+ "strconv"
+)
+
+type ConfigImageLightboxResult struct {
+ SlideshowDelay *int `json:"slideshowDelay"`
+ DisplayMode *ImageLightboxDisplayMode `json:"displayMode"`
+ ScaleUp *bool `json:"scaleUp"`
+ ResetZoomOnNav *bool `json:"resetZoomOnNav"`
+ ScrollMode *ImageLightboxScrollMode `json:"scrollMode"`
+ ScrollAttemptsBeforeChange int `json:"scrollAttemptsBeforeChange"`
+}
+
+type ImageLightboxDisplayMode string
+
+const (
+ ImageLightboxDisplayModeOriginal ImageLightboxDisplayMode = "ORIGINAL"
+ ImageLightboxDisplayModeFitXy ImageLightboxDisplayMode = "FIT_XY"
+ ImageLightboxDisplayModeFitX ImageLightboxDisplayMode = "FIT_X"
+)
+
+var AllImageLightboxDisplayMode = []ImageLightboxDisplayMode{
+ ImageLightboxDisplayModeOriginal,
+ ImageLightboxDisplayModeFitXy,
+ ImageLightboxDisplayModeFitX,
+}
+
+func (e ImageLightboxDisplayMode) IsValid() bool {
+ switch e {
+ case ImageLightboxDisplayModeOriginal, ImageLightboxDisplayModeFitXy, ImageLightboxDisplayModeFitX:
+ return true
+ }
+ return false
+}
+
+func (e ImageLightboxDisplayMode) String() string {
+ return string(e)
+}
+
+func (e *ImageLightboxDisplayMode) UnmarshalGQL(v interface{}) error {
+ str, ok := v.(string)
+ if !ok {
+ return fmt.Errorf("enums must be strings")
+ }
+
+ *e = ImageLightboxDisplayMode(str)
+ if !e.IsValid() {
+ return fmt.Errorf("%s is not a valid ImageLightboxDisplayMode", str)
+ }
+ return nil
+}
+
+func (e ImageLightboxDisplayMode) MarshalGQL(w io.Writer) {
+ fmt.Fprint(w, strconv.Quote(e.String()))
+}
+
+type ImageLightboxScrollMode string
+
+const (
+ ImageLightboxScrollModeZoom ImageLightboxScrollMode = "ZOOM"
+ ImageLightboxScrollModePanY ImageLightboxScrollMode = "PAN_Y"
+)
+
+var AllImageLightboxScrollMode = []ImageLightboxScrollMode{
+ ImageLightboxScrollModeZoom,
+ ImageLightboxScrollModePanY,
+}
+
+func (e ImageLightboxScrollMode) IsValid() bool {
+ switch e {
+ case ImageLightboxScrollModeZoom, ImageLightboxScrollModePanY:
+ return true
+ }
+ return false
+}
+
+func (e ImageLightboxScrollMode) String() string {
+ return string(e)
+}
+
+func (e *ImageLightboxScrollMode) UnmarshalGQL(v interface{}) error {
+ str, ok := v.(string)
+ if !ok {
+ return fmt.Errorf("enums must be strings")
+ }
+
+ *e = ImageLightboxScrollMode(str)
+ if !e.IsValid() {
+ return fmt.Errorf("%s is not a valid ImageLightboxScrollMode", str)
+ }
+ return nil
+}
+
+func (e ImageLightboxScrollMode) MarshalGQL(w io.Writer) {
+ fmt.Fprint(w, strconv.Quote(e.String()))
+}
+
+type ConfigDisableDropdownCreate struct {
+ Performer bool `json:"performer"`
+ Tag bool `json:"tag"`
+ Studio bool `json:"studio"`
+}
diff --git a/internal/manager/filename_parser.go b/internal/manager/filename_parser.go
index e82cfa1ed..3bd856e69 100644
--- a/internal/manager/filename_parser.go
+++ b/internal/manager/filename_parser.go
@@ -16,6 +16,32 @@ import (
"github.com/stashapp/stash/pkg/tag"
)
+type SceneParserInput struct {
+ IgnoreWords []string `json:"ignoreWords"`
+ WhitespaceCharacters *string `json:"whitespaceCharacters"`
+ CapitalizeTitle *bool `json:"capitalizeTitle"`
+ IgnoreOrganized *bool `json:"ignoreOrganized"`
+}
+
+type SceneParserResult struct {
+ Scene *models.Scene `json:"scene"`
+ Title *string `json:"title"`
+ Details *string `json:"details"`
+ URL *string `json:"url"`
+ Date *string `json:"date"`
+ Rating *int `json:"rating"`
+ StudioID *string `json:"studio_id"`
+ GalleryIds []string `json:"gallery_ids"`
+ PerformerIds []string `json:"performer_ids"`
+ Movies []*SceneMovieID `json:"movies"`
+ TagIds []string `json:"tag_ids"`
+}
+
+type SceneMovieID struct {
+ MovieID string `json:"movie_id"`
+ SceneIndex *string `json:"scene_index"`
+}
+
type parserField struct {
field string
fieldRegex *regexp.Regexp
@@ -402,7 +428,7 @@ func (m parseMapper) parse(scene *models.Scene) *sceneHolder {
type SceneFilenameParser struct {
Pattern string
- ParserInput models.SceneParserInput
+ ParserInput SceneParserInput
Filter *models.FindFilterType
whitespaceRE *regexp.Regexp
performerCache map[string]*models.Performer
@@ -411,7 +437,7 @@ type SceneFilenameParser struct {
tagCache map[string]*models.Tag
}
-func NewSceneFilenameParser(filter *models.FindFilterType, config models.SceneParserInput) *SceneFilenameParser {
+func NewSceneFilenameParser(filter *models.FindFilterType, config SceneParserInput) *SceneFilenameParser {
p := &SceneFilenameParser{
Pattern: *filter.Q,
ParserInput: config,
@@ -444,7 +470,7 @@ func (p *SceneFilenameParser) initWhiteSpaceRegex() {
}
}
-func (p *SceneFilenameParser) Parse(repo models.ReaderRepository) ([]*models.SceneParserResult, int, error) {
+func (p *SceneFilenameParser) Parse(repo models.ReaderRepository) ([]*SceneParserResult, int, error) {
// perform the query to find the scenes
mapper, err := newParseMapper(p.Pattern, p.ParserInput.IgnoreWords)
@@ -476,13 +502,13 @@ func (p *SceneFilenameParser) Parse(repo models.ReaderRepository) ([]*models.Sce
return ret, total, nil
}
-func (p *SceneFilenameParser) parseScenes(repo models.ReaderRepository, scenes []*models.Scene, mapper *parseMapper) []*models.SceneParserResult {
- var ret []*models.SceneParserResult
+func (p *SceneFilenameParser) parseScenes(repo models.ReaderRepository, scenes []*models.Scene, mapper *parseMapper) []*SceneParserResult {
+ var ret []*SceneParserResult
for _, scene := range scenes {
sceneHolder := mapper.parse(scene)
if sceneHolder != nil {
- r := &models.SceneParserResult{
+ r := &SceneParserResult{
Scene: scene,
}
p.setParserResult(repo, *sceneHolder, r)
@@ -589,7 +615,7 @@ func (p *SceneFilenameParser) queryTag(qb models.TagReader, tagName string) *mod
return ret
}
-func (p *SceneFilenameParser) setPerformers(qb models.PerformerReader, h sceneHolder, result *models.SceneParserResult) {
+func (p *SceneFilenameParser) setPerformers(qb models.PerformerReader, h sceneHolder, result *SceneParserResult) {
// query for each performer
performersSet := make(map[int]bool)
for _, performerName := range h.performers {
@@ -605,7 +631,7 @@ func (p *SceneFilenameParser) setPerformers(qb models.PerformerReader, h sceneHo
}
}
-func (p *SceneFilenameParser) setTags(qb models.TagReader, h sceneHolder, result *models.SceneParserResult) {
+func (p *SceneFilenameParser) setTags(qb models.TagReader, h sceneHolder, result *SceneParserResult) {
// query for each performer
tagsSet := make(map[int]bool)
for _, tagName := range h.tags {
@@ -621,7 +647,7 @@ func (p *SceneFilenameParser) setTags(qb models.TagReader, h sceneHolder, result
}
}
-func (p *SceneFilenameParser) setStudio(qb models.StudioReader, h sceneHolder, result *models.SceneParserResult) {
+func (p *SceneFilenameParser) setStudio(qb models.StudioReader, h sceneHolder, result *SceneParserResult) {
// query for each performer
if h.studio != "" {
studio := p.queryStudio(qb, h.studio)
@@ -632,7 +658,7 @@ func (p *SceneFilenameParser) setStudio(qb models.StudioReader, h sceneHolder, r
}
}
-func (p *SceneFilenameParser) setMovies(qb models.MovieReader, h sceneHolder, result *models.SceneParserResult) {
+func (p *SceneFilenameParser) setMovies(qb models.MovieReader, h sceneHolder, result *SceneParserResult) {
// query for each movie
moviesSet := make(map[int]bool)
for _, movieName := range h.movies {
@@ -640,7 +666,7 @@ func (p *SceneFilenameParser) setMovies(qb models.MovieReader, h sceneHolder, re
movie := p.queryMovie(qb, movieName)
if movie != nil {
if _, found := moviesSet[movie.ID]; !found {
- result.Movies = append(result.Movies, &models.SceneMovieID{
+ result.Movies = append(result.Movies, &SceneMovieID{
MovieID: strconv.Itoa(movie.ID),
})
moviesSet[movie.ID] = true
@@ -650,7 +676,7 @@ func (p *SceneFilenameParser) setMovies(qb models.MovieReader, h sceneHolder, re
}
}
-func (p *SceneFilenameParser) setParserResult(repo models.ReaderRepository, h sceneHolder, result *models.SceneParserResult) {
+func (p *SceneFilenameParser) setParserResult(repo models.ReaderRepository, h sceneHolder, result *SceneParserResult) {
if h.result.Title.Valid {
title := h.result.Title.String
title = p.replaceWhitespaceCharacters(title)
diff --git a/internal/manager/import.go b/internal/manager/import.go
index c2f2820c7..8e3140577 100644
--- a/internal/manager/import.go
+++ b/internal/manager/import.go
@@ -2,11 +2,55 @@ package manager
import (
"fmt"
+ "io"
+ "strconv"
"github.com/stashapp/stash/pkg/logger"
- "github.com/stashapp/stash/pkg/models"
)
+type ImportDuplicateEnum string
+
+const (
+ ImportDuplicateEnumIgnore ImportDuplicateEnum = "IGNORE"
+ ImportDuplicateEnumOverwrite ImportDuplicateEnum = "OVERWRITE"
+ ImportDuplicateEnumFail ImportDuplicateEnum = "FAIL"
+)
+
+var AllImportDuplicateEnum = []ImportDuplicateEnum{
+ ImportDuplicateEnumIgnore,
+ ImportDuplicateEnumOverwrite,
+ ImportDuplicateEnumFail,
+}
+
+func (e ImportDuplicateEnum) IsValid() bool {
+ switch e {
+ case ImportDuplicateEnumIgnore, ImportDuplicateEnumOverwrite, ImportDuplicateEnumFail:
+ return true
+ }
+ return false
+}
+
+func (e ImportDuplicateEnum) String() string {
+ return string(e)
+}
+
+func (e *ImportDuplicateEnum) UnmarshalGQL(v interface{}) error {
+ str, ok := v.(string)
+ if !ok {
+ return fmt.Errorf("enums must be strings")
+ }
+
+ *e = ImportDuplicateEnum(str)
+ if !e.IsValid() {
+ return fmt.Errorf("%s is not a valid ImportDuplicateEnum", str)
+ }
+ return nil
+}
+
+func (e ImportDuplicateEnum) MarshalGQL(w io.Writer) {
+ fmt.Fprint(w, strconv.Quote(e.String()))
+}
+
type importer interface {
PreImport() error
PostImport(id int) error
@@ -16,7 +60,7 @@ type importer interface {
Update(id int) error
}
-func performImport(i importer, duplicateBehaviour models.ImportDuplicateEnum) error {
+func performImport(i importer, duplicateBehaviour ImportDuplicateEnum) error {
if err := i.PreImport(); err != nil {
return err
}
@@ -31,9 +75,9 @@ func performImport(i importer, duplicateBehaviour models.ImportDuplicateEnum) er
var id int
if existing != nil {
- if duplicateBehaviour == models.ImportDuplicateEnumFail {
+ if duplicateBehaviour == ImportDuplicateEnumFail {
return fmt.Errorf("existing object with name '%s'", name)
- } else if duplicateBehaviour == models.ImportDuplicateEnumIgnore {
+ } else if duplicateBehaviour == ImportDuplicateEnumIgnore {
logger.Info("Skipping existing object")
return nil
}
diff --git a/internal/manager/manager.go b/internal/manager/manager.go
index ae7655d54..56221327b 100644
--- a/internal/manager/manager.go
+++ b/internal/manager/manager.go
@@ -4,9 +4,11 @@ import (
"context"
"errors"
"fmt"
+ "io"
"os"
"path/filepath"
"runtime/pprof"
+ "strconv"
"strings"
"sync"
"time"
@@ -30,6 +32,67 @@ import (
"github.com/stashapp/stash/ui"
)
+type SystemStatus struct {
+ DatabaseSchema *int `json:"databaseSchema"`
+ DatabasePath *string `json:"databasePath"`
+ ConfigPath *string `json:"configPath"`
+ AppSchema int `json:"appSchema"`
+ Status SystemStatusEnum `json:"status"`
+}
+
+type SystemStatusEnum string
+
+const (
+ SystemStatusEnumSetup SystemStatusEnum = "SETUP"
+ SystemStatusEnumNeedsMigration SystemStatusEnum = "NEEDS_MIGRATION"
+ SystemStatusEnumOk SystemStatusEnum = "OK"
+)
+
+var AllSystemStatusEnum = []SystemStatusEnum{
+ SystemStatusEnumSetup,
+ SystemStatusEnumNeedsMigration,
+ SystemStatusEnumOk,
+}
+
+func (e SystemStatusEnum) IsValid() bool {
+ switch e {
+ case SystemStatusEnumSetup, SystemStatusEnumNeedsMigration, SystemStatusEnumOk:
+ return true
+ }
+ return false
+}
+
+func (e SystemStatusEnum) String() string {
+ return string(e)
+}
+
+func (e *SystemStatusEnum) UnmarshalGQL(v interface{}) error {
+ str, ok := v.(string)
+ if !ok {
+ return fmt.Errorf("enums must be strings")
+ }
+
+ *e = SystemStatusEnum(str)
+ if !e.IsValid() {
+ return fmt.Errorf("%s is not a valid SystemStatusEnum", str)
+ }
+ return nil
+}
+
+func (e SystemStatusEnum) MarshalGQL(w io.Writer) {
+ fmt.Fprint(w, strconv.Quote(e.String()))
+}
+
+type SetupInput struct {
+ // Empty to indicate $HOME/.stash/config.yml default
+ ConfigLocation string `json:"configLocation"`
+ Stashes []*config.StashConfigInput `json:"stashes"`
+ // Empty to indicate default
+ DatabaseFile string `json:"databaseFile"`
+ // Empty to indicate default
+ GeneratedLocation string `json:"generatedLocation"`
+}
+
type Manager struct {
Config *config.Instance
Logger *log.Logger
@@ -354,7 +417,7 @@ func (s *Manager) RefreshScraperCache() {
s.ScraperCache = s.initScraperCache()
}
-func setSetupDefaults(input *models.SetupInput) {
+func setSetupDefaults(input *SetupInput) {
if input.ConfigLocation == "" {
input.ConfigLocation = filepath.Join(fsutil.GetHomeDirectory(), ".stash", "config.yml")
}
@@ -369,7 +432,7 @@ func setSetupDefaults(input *models.SetupInput) {
}
}
-func (s *Manager) Setup(ctx context.Context, input models.SetupInput) error {
+func (s *Manager) Setup(ctx context.Context, input SetupInput) error {
setSetupDefaults(&input)
c := s.Config
@@ -433,7 +496,11 @@ func (s *Manager) validateFFMPEG() error {
return nil
}
-func (s *Manager) Migrate(ctx context.Context, input models.MigrateInput) error {
+type MigrateInput struct {
+ BackupPath string `json:"backupPath"`
+}
+
+func (s *Manager) Migrate(ctx context.Context, input MigrateInput) error {
// always backup so that we can roll back to the previous version if
// migration fails
backupPath := input.BackupPath
@@ -473,20 +540,20 @@ func (s *Manager) Migrate(ctx context.Context, input models.MigrateInput) error
return nil
}
-func (s *Manager) GetSystemStatus() *models.SystemStatus {
- status := models.SystemStatusEnumOk
+func (s *Manager) GetSystemStatus() *SystemStatus {
+ status := SystemStatusEnumOk
dbSchema := int(database.Version())
dbPath := database.DatabasePath()
appSchema := int(database.AppSchemaVersion())
configFile := s.Config.GetConfigFile()
if s.Config.IsNewSystem() {
- status = models.SystemStatusEnumSetup
+ status = SystemStatusEnumSetup
} else if dbSchema < appSchema {
- status = models.SystemStatusEnumNeedsMigration
+ status = SystemStatusEnumNeedsMigration
}
- return &models.SystemStatus{
+ return &SystemStatus{
DatabaseSchema: &dbSchema,
DatabasePath: &dbPath,
AppSchema: appSchema,
diff --git a/internal/manager/manager_tasks.go b/internal/manager/manager_tasks.go
index 209e89ceb..3178f7846 100644
--- a/internal/manager/manager_tasks.go
+++ b/internal/manager/manager_tasks.go
@@ -6,6 +6,7 @@ import (
"fmt"
"strconv"
"sync"
+ "time"
"github.com/stashapp/stash/internal/manager/config"
"github.com/stashapp/stash/pkg/fsutil"
@@ -34,12 +35,12 @@ func isImage(pathname string) bool {
return fsutil.MatchExtension(pathname, imgExt)
}
-func getScanPaths(inputPaths []string) []*models.StashConfig {
+func getScanPaths(inputPaths []string) []*config.StashConfig {
if len(inputPaths) == 0 {
return config.GetInstance().GetStashPaths()
}
- var ret []*models.StashConfig
+ var ret []*config.StashConfig
for _, p := range inputPaths {
s := getStashFromDirPath(p)
if s == nil {
@@ -62,7 +63,22 @@ func (s *Manager) ScanSubscribe(ctx context.Context) <-chan bool {
return s.scanSubs.subscribe(ctx)
}
-func (s *Manager) Scan(ctx context.Context, input models.ScanMetadataInput) (int, error) {
+type ScanMetadataInput struct {
+ Paths []string `json:"paths"`
+
+ config.ScanMetadataOptions
+
+ // Filter options for the scan
+ Filter *ScanMetaDataFilterInput `json:"filter"`
+}
+
+// Filter options for meta data scannning
+type ScanMetaDataFilterInput struct {
+ // If set, files with a modification time before this time point are ignored by the scan
+ MinModTime *time.Time `json:"minModTime"`
+}
+
+func (s *Manager) Scan(ctx context.Context, input ScanMetadataInput) (int, error) {
if err := s.validateFFMPEG(); err != nil {
return 0, err
}
@@ -88,7 +104,7 @@ func (s *Manager) Import(ctx context.Context) (int, error) {
txnManager: s.TxnManager,
BaseDir: metadataPath,
Reset: true,
- DuplicateBehaviour: models.ImportDuplicateEnumFail,
+ DuplicateBehaviour: ImportDuplicateEnumFail,
MissingRefBehaviour: models.ImportMissingRefEnumFail,
fileNamingAlgorithm: config.GetVideoFileNamingAlgorithm(),
}
@@ -131,7 +147,7 @@ func (s *Manager) RunSingleTask(ctx context.Context, t Task) int {
return s.JobManager.Add(ctx, t.GetDescription(), j)
}
-func (s *Manager) Generate(ctx context.Context, input models.GenerateMetadataInput) (int, error) {
+func (s *Manager) Generate(ctx context.Context, input GenerateMetadataInput) (int, error) {
if err := s.validateFFMPEG(); err != nil {
return 0, err
}
@@ -193,7 +209,18 @@ func (s *Manager) generateScreenshot(ctx context.Context, sceneId string, at *fl
return s.JobManager.Add(ctx, fmt.Sprintf("Generating screenshot for scene id %s", sceneId), j)
}
-func (s *Manager) AutoTag(ctx context.Context, input models.AutoTagMetadataInput) int {
+type AutoTagMetadataInput struct {
+ // Paths to tag, null for all files
+ Paths []string `json:"paths"`
+ // IDs of performers to tag files with, or "*" for all
+ Performers []string `json:"performers"`
+ // IDs of studios to tag files with, or "*" for all
+ Studios []string `json:"studios"`
+ // IDs of tags to tag files with, or "*" for all
+ Tags []string `json:"tags"`
+}
+
+func (s *Manager) AutoTag(ctx context.Context, input AutoTagMetadataInput) int {
j := autoTagJob{
txnManager: s.TxnManager,
input: input,
@@ -202,7 +229,13 @@ func (s *Manager) AutoTag(ctx context.Context, input models.AutoTagMetadataInput
return s.JobManager.Add(ctx, "Auto-tagging...", &j)
}
-func (s *Manager) Clean(ctx context.Context, input models.CleanMetadataInput) int {
+type CleanMetadataInput struct {
+ Paths []string `json:"paths"`
+ // Do a dry run. Don't delete any files
+ DryRun bool `json:"dryRun"`
+}
+
+func (s *Manager) Clean(ctx context.Context, input CleanMetadataInput) int {
j := cleanJob{
txnManager: s.TxnManager,
input: input,
@@ -260,7 +293,21 @@ func (s *Manager) MigrateHash(ctx context.Context) int {
return s.JobManager.Add(ctx, "Migrating scene hashes...", j)
}
-func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input models.StashBoxBatchPerformerTagInput) int {
+// If neither performer_ids nor performer_names are set, tag all performers
+type StashBoxBatchPerformerTagInput struct {
+ // Stash endpoint to use for the performer tagging
+ Endpoint int `json:"endpoint"`
+ // Fields to exclude when executing the performer tagging
+ ExcludeFields []string `json:"exclude_fields"`
+ // Refresh performers already tagged by StashBox if true. Only tag performers with no StashBox tagging if false
+ Refresh bool `json:"refresh"`
+ // If set, only tag these performer ids
+ PerformerIds []string `json:"performer_ids"`
+ // If set, only tag these performer names
+ PerformerNames []string `json:"performer_names"`
+}
+
+func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxBatchPerformerTagInput) int {
j := job.MakeJobExec(func(ctx context.Context, progress *job.Progress) {
logger.Infof("Initiating stash-box batch performer tag")
diff --git a/internal/manager/scene.go b/internal/manager/scene.go
index e0aee54ea..564d3dcb1 100644
--- a/internal/manager/scene.go
+++ b/internal/manager/scene.go
@@ -50,20 +50,26 @@ func includeSceneStreamPath(scene *models.Scene, streamingResolution models.Stre
return int64(maxStreamingResolution.GetMinResolution()) >= minResolution
}
-func makeStreamEndpoint(streamURL string, streamingResolution models.StreamingResolutionEnum, mimeType, label string) *models.SceneStreamEndpoint {
- return &models.SceneStreamEndpoint{
+type SceneStreamEndpoint struct {
+ URL string `json:"url"`
+ MimeType *string `json:"mime_type"`
+ Label *string `json:"label"`
+}
+
+func makeStreamEndpoint(streamURL string, streamingResolution models.StreamingResolutionEnum, mimeType, label string) *SceneStreamEndpoint {
+ return &SceneStreamEndpoint{
URL: fmt.Sprintf("%s?resolution=%s", streamURL, streamingResolution.String()),
MimeType: &mimeType,
Label: &label,
}
}
-func GetSceneStreamPaths(scene *models.Scene, directStreamURL string, maxStreamingTranscodeSize models.StreamingResolutionEnum) ([]*models.SceneStreamEndpoint, error) {
+func GetSceneStreamPaths(scene *models.Scene, directStreamURL string, maxStreamingTranscodeSize models.StreamingResolutionEnum) ([]*SceneStreamEndpoint, error) {
if scene == nil {
return nil, fmt.Errorf("nil scene")
}
- var ret []*models.SceneStreamEndpoint
+ var ret []*SceneStreamEndpoint
mimeWebm := ffmpeg.MimeWebm
mimeHLS := ffmpeg.MimeHLS
mimeMp4 := ffmpeg.MimeMp4
@@ -82,7 +88,7 @@ func GetSceneStreamPaths(scene *models.Scene, directStreamURL string, maxStreami
if HasTranscode(scene, config.GetInstance().GetVideoFileNamingAlgorithm()) || ffmpeg.IsValidAudioForContainer(audioCodec, container) {
label := "Direct stream"
- ret = append(ret, &models.SceneStreamEndpoint{
+ ret = append(ret, &SceneStreamEndpoint{
URL: directStreamURL,
MimeType: &mimeMp4,
Label: &label,
@@ -92,7 +98,7 @@ func GetSceneStreamPaths(scene *models.Scene, directStreamURL string, maxStreami
// only add mkv stream endpoint if the scene container is an mkv already
if container == ffmpeg.Matroska {
label := "mkv"
- ret = append(ret, &models.SceneStreamEndpoint{
+ ret = append(ret, &SceneStreamEndpoint{
URL: directStreamURL + ".mkv",
// set mkv to mp4 to trick the client, since many clients won't try mkv
MimeType: &mimeMp4,
@@ -115,8 +121,8 @@ func GetSceneStreamPaths(scene *models.Scene, directStreamURL string, maxStreami
mp4LabelStandard := "MP4 Standard (480p)" // "STANDARD"
mp4LabelLow := "MP4 Low (240p)" // "LOW"
- var webmStreams []*models.SceneStreamEndpoint
- var mp4Streams []*models.SceneStreamEndpoint
+ var webmStreams []*SceneStreamEndpoint
+ var mp4Streams []*SceneStreamEndpoint
webmURL := directStreamURL + ".webm"
mp4URL := directStreamURL + ".mp4"
@@ -149,7 +155,7 @@ func GetSceneStreamPaths(scene *models.Scene, directStreamURL string, maxStreami
ret = append(ret, webmStreams...)
ret = append(ret, mp4Streams...)
- defaultStreams := []*models.SceneStreamEndpoint{
+ defaultStreams := []*SceneStreamEndpoint{
{
URL: directStreamURL + ".webm",
MimeType: &mimeWebm,
@@ -159,7 +165,7 @@ func GetSceneStreamPaths(scene *models.Scene, directStreamURL string, maxStreami
ret = append(ret, defaultStreams...)
- hls := models.SceneStreamEndpoint{
+ hls := SceneStreamEndpoint{
URL: directStreamURL + ".m3u8",
MimeType: &mimeHLS,
Label: &labelHLS,
diff --git a/internal/manager/task_autotag.go b/internal/manager/task_autotag.go
index 369d31754..a912dcedc 100644
--- a/internal/manager/task_autotag.go
+++ b/internal/manager/task_autotag.go
@@ -20,7 +20,7 @@ import (
type autoTagJob struct {
txnManager models.TransactionManager
- input models.AutoTagMetadataInput
+ input AutoTagMetadataInput
cache match.Cache
}
@@ -40,7 +40,7 @@ func (j *autoTagJob) Execute(ctx context.Context, progress *job.Progress) {
logger.Infof("Finished autotag after %s", time.Since(begin).String())
}
-func (j *autoTagJob) isFileBasedAutoTag(input models.AutoTagMetadataInput) bool {
+func (j *autoTagJob) isFileBasedAutoTag(input AutoTagMetadataInput) bool {
const wildcard = "*"
performerIds := input.Performers
studioIds := input.Studios
diff --git a/internal/manager/task_clean.go b/internal/manager/task_clean.go
index 2853e612c..2cbd168fb 100644
--- a/internal/manager/task_clean.go
+++ b/internal/manager/task_clean.go
@@ -19,7 +19,7 @@ import (
type cleanJob struct {
txnManager models.TransactionManager
- input models.CleanMetadataInput
+ input CleanMetadataInput
scanSubs *subscriptionManager
}
@@ -488,7 +488,7 @@ func (j *cleanJob) deleteImage(ctx context.Context, imageID int) {
}, nil)
}
-func getStashFromPath(pathToCheck string) *models.StashConfig {
+func getStashFromPath(pathToCheck string) *config.StashConfig {
for _, s := range config.GetInstance().GetStashPaths() {
if fsutil.IsPathInDir(s.Path, filepath.Dir(pathToCheck)) {
return s
@@ -497,7 +497,7 @@ func getStashFromPath(pathToCheck string) *models.StashConfig {
return nil
}
-func getStashFromDirPath(pathToCheck string) *models.StashConfig {
+func getStashFromDirPath(pathToCheck string) *config.StashConfig {
for _, s := range config.GetInstance().GetStashPaths() {
if fsutil.IsPathInDir(s.Path, pathToCheck) {
return s
diff --git a/internal/manager/task_export.go b/internal/manager/task_export.go
index 32141f3ea..512b7e42f 100644
--- a/internal/manager/task_export.go
+++ b/internal/manager/task_export.go
@@ -54,12 +54,28 @@ type ExportTask struct {
DownloadHash string
}
+type ExportObjectTypeInput struct {
+ Ids []string `json:"ids"`
+ All *bool `json:"all"`
+}
+
+type ExportObjectsInput struct {
+ Scenes *ExportObjectTypeInput `json:"scenes"`
+ Images *ExportObjectTypeInput `json:"images"`
+ Studios *ExportObjectTypeInput `json:"studios"`
+ Performers *ExportObjectTypeInput `json:"performers"`
+ Tags *ExportObjectTypeInput `json:"tags"`
+ Movies *ExportObjectTypeInput `json:"movies"`
+ Galleries *ExportObjectTypeInput `json:"galleries"`
+ IncludeDependencies *bool `json:"includeDependencies"`
+}
+
type exportSpec struct {
IDs []int
all bool
}
-func newExportSpec(input *models.ExportObjectTypeInput) *exportSpec {
+func newExportSpec(input *ExportObjectTypeInput) *exportSpec {
if input == nil {
return &exportSpec{}
}
@@ -77,7 +93,7 @@ func newExportSpec(input *models.ExportObjectTypeInput) *exportSpec {
return ret
}
-func CreateExportTask(a models.HashAlgorithm, input models.ExportObjectsInput) *ExportTask {
+func CreateExportTask(a models.HashAlgorithm, input ExportObjectsInput) *ExportTask {
includeDeps := false
if input.IncludeDependencies != nil {
includeDeps = *input.IncludeDependencies
diff --git a/internal/manager/task_generate.go b/internal/manager/task_generate.go
index 3addf7bec..bc385ab22 100644
--- a/internal/manager/task_generate.go
+++ b/internal/manager/task_generate.go
@@ -17,11 +17,45 @@ import (
"github.com/stashapp/stash/pkg/utils"
)
+type GenerateMetadataInput struct {
+ Sprites *bool `json:"sprites"`
+ Previews *bool `json:"previews"`
+ ImagePreviews *bool `json:"imagePreviews"`
+ PreviewOptions *GeneratePreviewOptionsInput `json:"previewOptions"`
+ Markers *bool `json:"markers"`
+ MarkerImagePreviews *bool `json:"markerImagePreviews"`
+ MarkerScreenshots *bool `json:"markerScreenshots"`
+ Transcodes *bool `json:"transcodes"`
+ // Generate transcodes even if not required
+ ForceTranscodes *bool `json:"forceTranscodes"`
+ Phashes *bool `json:"phashes"`
+ InteractiveHeatmapsSpeeds *bool `json:"interactiveHeatmapsSpeeds"`
+ // scene ids to generate for
+ SceneIDs []string `json:"sceneIDs"`
+ // marker ids to generate for
+ MarkerIDs []string `json:"markerIDs"`
+ // overwrite existing media
+ Overwrite *bool `json:"overwrite"`
+}
+
+type GeneratePreviewOptionsInput struct {
+ // Number of segments in a preview file
+ PreviewSegments *int `json:"previewSegments"`
+ // Preview segment duration, in seconds
+ PreviewSegmentDuration *float64 `json:"previewSegmentDuration"`
+ // Duration of start of video to exclude when generating previews
+ PreviewExcludeStart *string `json:"previewExcludeStart"`
+ // Duration of end of video to exclude when generating previews
+ PreviewExcludeEnd *string `json:"previewExcludeEnd"`
+ // Preset when generating preview
+ PreviewPreset *models.PreviewPreset `json:"previewPreset"`
+}
+
const generateQueueSize = 200000
type GenerateJob struct {
txnManager models.TransactionManager
- input models.GenerateMetadataInput
+ input GenerateMetadataInput
overwrite bool
fileNamingAlgo models.HashAlgorithm
@@ -194,7 +228,7 @@ func (j *GenerateJob) queueTasks(ctx context.Context, g *generate.Generator, que
return totals
}
-func getGeneratePreviewOptions(optionsInput models.GeneratePreviewOptionsInput) generate.PreviewOptions {
+func getGeneratePreviewOptions(optionsInput GeneratePreviewOptionsInput) generate.PreviewOptions {
config := config.GetInstance()
ret := generate.PreviewOptions{
@@ -246,7 +280,7 @@ func (j *GenerateJob) queueSceneJobs(ctx context.Context, g *generate.Generator,
generatePreviewOptions := j.input.PreviewOptions
if generatePreviewOptions == nil {
- generatePreviewOptions = &models.GeneratePreviewOptionsInput{}
+ generatePreviewOptions = &GeneratePreviewOptionsInput{}
}
options := getGeneratePreviewOptions(*generatePreviewOptions)
diff --git a/internal/manager/task_identify.go b/internal/manager/task_identify.go
index 678d0c7b3..457c59dd1 100644
--- a/internal/manager/task_identify.go
+++ b/internal/manager/task_identify.go
@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
+ "strings"
"github.com/stashapp/stash/internal/identify"
"github.com/stashapp/stash/pkg/job"
@@ -20,13 +21,13 @@ var ErrInput = errors.New("invalid request input")
type IdentifyJob struct {
txnManager models.TransactionManager
postHookExecutor identify.SceneUpdatePostHookExecutor
- input models.IdentifyMetadataInput
+ input identify.Options
- stashBoxes models.StashBoxes
+ stashBoxes []*models.StashBox
progress *job.Progress
}
-func CreateIdentifyJob(input models.IdentifyMetadataInput) *IdentifyJob {
+func CreateIdentifyJob(input identify.Options) *IdentifyJob {
return &IdentifyJob{
txnManager: instance.TxnManager,
postHookExecutor: instance.PluginCache,
@@ -192,7 +193,7 @@ func (j *IdentifyJob) getSources() ([]identify.ScraperSource, error) {
return ret, nil
}
-func (j *IdentifyJob) getStashBox(src *models.ScraperSourceInput) (*models.StashBox, error) {
+func (j *IdentifyJob) getStashBox(src *scraper.Source) (*models.StashBox, error) {
if src.ScraperID != nil {
return nil, nil
}
@@ -202,7 +203,38 @@ func (j *IdentifyJob) getStashBox(src *models.ScraperSourceInput) (*models.Stash
return nil, fmt.Errorf("%w: stash_box_index or stash_box_endpoint or scraper_id must be set", ErrInput)
}
- return j.stashBoxes.ResolveStashBox(*src)
+ return resolveStashBox(j.stashBoxes, *src)
+}
+
+func resolveStashBox(sb []*models.StashBox, source scraper.Source) (*models.StashBox, error) {
+ if source.StashBoxIndex != nil {
+ index := source.StashBoxIndex
+ if *index < 0 || *index >= len(sb) {
+ return nil, fmt.Errorf("%w: invalid stash_box_index: %d", models.ErrScraperSource, index)
+ }
+
+ return sb[*index], nil
+ }
+
+ if source.StashBoxEndpoint != nil {
+ var ret *models.StashBox
+ endpoint := *source.StashBoxEndpoint
+ for _, b := range sb {
+ if strings.EqualFold(endpoint, b.Endpoint) {
+ ret = b
+ }
+ }
+
+ if ret == nil {
+ return nil, fmt.Errorf(`%w: stash-box with endpoint "%s"`, models.ErrNotFound, endpoint)
+ }
+
+ return ret, nil
+ }
+
+ // neither stash-box inputs were provided, so assume it is a scraper
+
+ return nil, nil
}
type stashboxSource struct {
@@ -210,7 +242,7 @@ type stashboxSource struct {
endpoint string
}
-func (s stashboxSource) ScrapeScene(ctx context.Context, sceneID int) (*models.ScrapedScene, error) {
+func (s stashboxSource) ScrapeScene(ctx context.Context, sceneID int) (*scraper.ScrapedScene, error) {
results, err := s.FindStashBoxSceneByFingerprints(ctx, sceneID)
if err != nil {
return nil, fmt.Errorf("error querying stash-box using scene ID %d: %w", sceneID, err)
@@ -232,8 +264,8 @@ type scraperSource struct {
scraperID string
}
-func (s scraperSource) ScrapeScene(ctx context.Context, sceneID int) (*models.ScrapedScene, error) {
- content, err := s.cache.ScrapeID(ctx, s.scraperID, sceneID, models.ScrapeContentTypeScene)
+func (s scraperSource) ScrapeScene(ctx context.Context, sceneID int) (*scraper.ScrapedScene, error) {
+ content, err := s.cache.ScrapeID(ctx, s.scraperID, sceneID, scraper.ScrapeContentTypeScene)
if err != nil {
return nil, err
}
@@ -243,7 +275,7 @@ func (s scraperSource) ScrapeScene(ctx context.Context, sceneID int) (*models.Sc
return nil, nil
}
- if scene, ok := content.(models.ScrapedScene); ok {
+ if scene, ok := content.(scraper.ScrapedScene); ok {
return &scene, nil
}
diff --git a/internal/manager/task_import.go b/internal/manager/task_import.go
index ce6e22366..8175bfc59 100644
--- a/internal/manager/task_import.go
+++ b/internal/manager/task_import.go
@@ -11,6 +11,7 @@ import (
"path/filepath"
"time"
+ "github.com/99designs/gqlgen/graphql"
"github.com/stashapp/stash/internal/manager/config"
"github.com/stashapp/stash/pkg/database"
"github.com/stashapp/stash/pkg/fsutil"
@@ -35,7 +36,7 @@ type ImportTask struct {
BaseDir string
TmpZip string
Reset bool
- DuplicateBehaviour models.ImportDuplicateEnum
+ DuplicateBehaviour ImportDuplicateEnum
MissingRefBehaviour models.ImportMissingRefEnum
mappings *jsonschema.Mappings
@@ -43,7 +44,13 @@ type ImportTask struct {
fileNamingAlgorithm models.HashAlgorithm
}
-func CreateImportTask(a models.HashAlgorithm, input models.ImportObjectsInput) (*ImportTask, error) {
+type ImportObjectsInput struct {
+ File graphql.Upload `json:"file"`
+ DuplicateBehaviour ImportDuplicateEnum `json:"duplicateBehaviour"`
+ MissingRefBehaviour models.ImportMissingRefEnum `json:"missingRefBehaviour"`
+}
+
+func CreateImportTask(a models.HashAlgorithm, input ImportObjectsInput) (*ImportTask, error) {
baseDir, err := instance.Paths.Generated.TempDir("import")
if err != nil {
logger.Errorf("error creating temporary directory for import: %s", err.Error())
@@ -101,7 +108,7 @@ func (t *ImportTask) Start(ctx context.Context) {
// set default behaviour if not provided
if !t.DuplicateBehaviour.IsValid() {
- t.DuplicateBehaviour = models.ImportDuplicateEnumFail
+ t.DuplicateBehaviour = ImportDuplicateEnumFail
}
if !t.MissingRefBehaviour.IsValid() {
t.MissingRefBehaviour = models.ImportMissingRefEnumFail
diff --git a/internal/manager/task_plugin.go b/internal/manager/task_plugin.go
index 31023445d..78cd5db02 100644
--- a/internal/manager/task_plugin.go
+++ b/internal/manager/task_plugin.go
@@ -6,10 +6,10 @@ import (
"github.com/stashapp/stash/pkg/job"
"github.com/stashapp/stash/pkg/logger"
- "github.com/stashapp/stash/pkg/models"
+ "github.com/stashapp/stash/pkg/plugin"
)
-func (s *Manager) RunPluginTask(ctx context.Context, pluginID string, taskName string, args []*models.PluginArgInput) int {
+func (s *Manager) RunPluginTask(ctx context.Context, pluginID string, taskName string, args []*plugin.PluginArgInput) int {
j := job.MakeJobExec(func(jobCtx context.Context, progress *job.Progress) {
pluginProgress := make(chan float64)
task, err := s.PluginCache.CreateTask(ctx, pluginID, taskName, args, pluginProgress)
diff --git a/internal/manager/task_scan.go b/internal/manager/task_scan.go
index 05ffe168e..042ff80ac 100644
--- a/internal/manager/task_scan.go
+++ b/internal/manager/task_scan.go
@@ -25,7 +25,7 @@ const scanQueueSize = 200000
type ScanJob struct {
txnManager models.TransactionManager
- input models.ScanMetadataInput
+ input ScanMetadataInput
subscriptions *subscriptionManager
}
@@ -88,15 +88,15 @@ func (j *ScanJob) Execute(ctx context.Context, progress *job.Progress) {
task := ScanTask{
TxnManager: j.txnManager,
file: file.FSFile(f.path, f.info),
- UseFileMetadata: utils.IsTrue(input.UseFileMetadata),
- StripFileExtension: utils.IsTrue(input.StripFileExtension),
+ UseFileMetadata: input.UseFileMetadata,
+ StripFileExtension: input.StripFileExtension,
fileNamingAlgorithm: fileNamingAlgo,
calculateMD5: calculateMD5,
- GeneratePreview: utils.IsTrue(input.ScanGeneratePreviews),
- GenerateImagePreview: utils.IsTrue(input.ScanGenerateImagePreviews),
- GenerateSprite: utils.IsTrue(input.ScanGenerateSprites),
- GeneratePhash: utils.IsTrue(input.ScanGeneratePhashes),
- GenerateThumbnails: utils.IsTrue(input.ScanGenerateThumbnails),
+ GeneratePreview: input.ScanGeneratePreviews,
+ GenerateImagePreview: input.ScanGenerateImagePreviews,
+ GenerateSprite: input.ScanGenerateSprites,
+ GeneratePhash: input.ScanGeneratePhashes,
+ GenerateThumbnails: input.ScanGenerateThumbnails,
progress: progress,
CaseSensitiveFs: f.caseSensitiveFs,
mutexManager: mutexManager,
@@ -145,7 +145,7 @@ func (j *ScanJob) Execute(ctx context.Context, progress *job.Progress) {
j.subscriptions.notify()
}
-func (j *ScanJob) queueFiles(ctx context.Context, paths []*models.StashConfig, scanQueue chan<- scanFile, parallelTasks int) (total int, newFiles int) {
+func (j *ScanJob) queueFiles(ctx context.Context, paths []*config.StashConfig, scanQueue chan<- scanFile, parallelTasks int) (total int, newFiles int) {
defer close(scanQueue)
var minModTime time.Time
@@ -322,7 +322,7 @@ func (t *ScanTask) Start(ctx context.Context) {
iwg.Add()
go t.progress.ExecuteTask(fmt.Sprintf("Generating preview for %s", path), func() {
- options := getGeneratePreviewOptions(models.GeneratePreviewOptionsInput{})
+ options := getGeneratePreviewOptions(GeneratePreviewOptionsInput{})
const overwrite = false
g := &generate.Generator{
@@ -349,7 +349,7 @@ func (t *ScanTask) Start(ctx context.Context) {
iwg.Wait()
}
-func walkFilesToScan(s *models.StashConfig, f filepath.WalkFunc) error {
+func walkFilesToScan(s *config.StashConfig, f filepath.WalkFunc) error {
config := config.GetInstance()
vidExt := config.GetVideoExtensions()
imgExt := config.GetImageExtensions()
diff --git a/pkg/models/extension_resolution.go b/pkg/models/extension_resolution.go
deleted file mode 100644
index a52d4a784..000000000
--- a/pkg/models/extension_resolution.go
+++ /dev/null
@@ -1,46 +0,0 @@
-package models
-
-type ResolutionRange struct {
- min, max int
-}
-
-var resolutionRanges = map[ResolutionEnum]ResolutionRange{
- ResolutionEnumVeryLow: {144, 239},
- ResolutionEnumLow: {240, 359},
- ResolutionEnumR360p: {360, 479},
- ResolutionEnumStandard: {480, 539},
- ResolutionEnumWebHd: {540, 719},
- ResolutionEnumStandardHd: {720, 1079},
- ResolutionEnumFullHd: {1080, 1439},
- ResolutionEnumQuadHd: {1440, 1919},
- ResolutionEnumVrHd: {1920, 2159},
- ResolutionEnumFourK: {2160, 2879},
- ResolutionEnumFiveK: {2880, 3383},
- ResolutionEnumSixK: {3384, 4319},
- ResolutionEnumEightK: {4320, 8639},
-}
-
-// GetMaxResolution returns the maximum width or height that media must be
-// to qualify as this resolution.
-func (r *ResolutionEnum) GetMaxResolution() int {
- return resolutionRanges[*r].max
-}
-
-// GetMinResolution returns the minimum width or height that media must be
-// to qualify as this resolution.
-func (r ResolutionEnum) GetMinResolution() int {
- return resolutionRanges[r].min
-}
-
-var streamingResolutionMax = map[StreamingResolutionEnum]int{
- StreamingResolutionEnumLow: resolutionRanges[ResolutionEnumLow].min,
- StreamingResolutionEnumStandard: resolutionRanges[ResolutionEnumStandard].min,
- StreamingResolutionEnumStandardHd: resolutionRanges[ResolutionEnumStandardHd].min,
- StreamingResolutionEnumFullHd: resolutionRanges[ResolutionEnumFullHd].min,
- StreamingResolutionEnumFourK: resolutionRanges[ResolutionEnumFourK].min,
- StreamingResolutionEnumOriginal: 0,
-}
-
-func (r StreamingResolutionEnum) GetMaxResolution() int {
- return streamingResolutionMax[r]
-}
diff --git a/pkg/models/filter.go b/pkg/models/filter.go
new file mode 100644
index 000000000..57bee72df
--- /dev/null
+++ b/pkg/models/filter.go
@@ -0,0 +1,108 @@
+package models
+
+import (
+ "fmt"
+ "io"
+ "strconv"
+)
+
+type CriterionModifier string
+
+const (
+ // =
+ CriterionModifierEquals CriterionModifier = "EQUALS"
+ // !=
+ CriterionModifierNotEquals CriterionModifier = "NOT_EQUALS"
+ // >
+ CriterionModifierGreaterThan CriterionModifier = "GREATER_THAN"
+ // <
+ CriterionModifierLessThan CriterionModifier = "LESS_THAN"
+ // IS NULL
+ CriterionModifierIsNull CriterionModifier = "IS_NULL"
+ // IS NOT NULL
+ CriterionModifierNotNull CriterionModifier = "NOT_NULL"
+ // INCLUDES ALL
+ CriterionModifierIncludesAll CriterionModifier = "INCLUDES_ALL"
+ CriterionModifierIncludes CriterionModifier = "INCLUDES"
+ CriterionModifierExcludes CriterionModifier = "EXCLUDES"
+ // MATCHES REGEX
+ CriterionModifierMatchesRegex CriterionModifier = "MATCHES_REGEX"
+ // NOT MATCHES REGEX
+ CriterionModifierNotMatchesRegex CriterionModifier = "NOT_MATCHES_REGEX"
+ // >= AND <=
+ CriterionModifierBetween CriterionModifier = "BETWEEN"
+ // < OR >
+ CriterionModifierNotBetween CriterionModifier = "NOT_BETWEEN"
+)
+
+var AllCriterionModifier = []CriterionModifier{
+ CriterionModifierEquals,
+ CriterionModifierNotEquals,
+ CriterionModifierGreaterThan,
+ CriterionModifierLessThan,
+ CriterionModifierIsNull,
+ CriterionModifierNotNull,
+ CriterionModifierIncludesAll,
+ CriterionModifierIncludes,
+ CriterionModifierExcludes,
+ CriterionModifierMatchesRegex,
+ CriterionModifierNotMatchesRegex,
+ CriterionModifierBetween,
+ CriterionModifierNotBetween,
+}
+
+func (e CriterionModifier) IsValid() bool {
+ switch e {
+ case CriterionModifierEquals, CriterionModifierNotEquals, CriterionModifierGreaterThan, CriterionModifierLessThan, CriterionModifierIsNull, CriterionModifierNotNull, CriterionModifierIncludesAll, CriterionModifierIncludes, CriterionModifierExcludes, CriterionModifierMatchesRegex, CriterionModifierNotMatchesRegex, CriterionModifierBetween, CriterionModifierNotBetween:
+ return true
+ }
+ return false
+}
+
+func (e CriterionModifier) String() string {
+ return string(e)
+}
+
+func (e *CriterionModifier) UnmarshalGQL(v interface{}) error {
+ str, ok := v.(string)
+ if !ok {
+ return fmt.Errorf("enums must be strings")
+ }
+
+ *e = CriterionModifier(str)
+ if !e.IsValid() {
+ return fmt.Errorf("%s is not a valid CriterionModifier", str)
+ }
+ return nil
+}
+
+func (e CriterionModifier) MarshalGQL(w io.Writer) {
+ fmt.Fprint(w, strconv.Quote(e.String()))
+}
+
+type StringCriterionInput struct {
+ Value string `json:"value"`
+ Modifier CriterionModifier `json:"modifier"`
+}
+
+type IntCriterionInput struct {
+ Value int `json:"value"`
+ Value2 *int `json:"value2"`
+ Modifier CriterionModifier `json:"modifier"`
+}
+
+type ResolutionCriterionInput struct {
+ Value ResolutionEnum `json:"value"`
+ Modifier CriterionModifier `json:"modifier"`
+}
+
+type HierarchicalMultiCriterionInput struct {
+ Value []string `json:"value"`
+ Modifier CriterionModifier `json:"modifier"`
+ Depth *int `json:"depth"`
+}
+
+type MultiCriterionInput struct {
+ Value []string `json:"value"`
+ Modifier CriterionModifier `json:"modifier"`
+}
diff --git a/pkg/models/extension_find_filter.go b/pkg/models/find_filter.go
similarity index 57%
rename from pkg/models/extension_find_filter.go
rename to pkg/models/find_filter.go
index 1a6fb15ed..f684ca065 100644
--- a/pkg/models/extension_find_filter.go
+++ b/pkg/models/find_filter.go
@@ -1,9 +1,65 @@
package models
+import (
+ "fmt"
+ "io"
+ "strconv"
+)
+
// PerPageAll is the value used for perPage to indicate all results should be
// returned.
const PerPageAll = -1
+type SortDirectionEnum string
+
+const (
+ SortDirectionEnumAsc SortDirectionEnum = "ASC"
+ SortDirectionEnumDesc SortDirectionEnum = "DESC"
+)
+
+var AllSortDirectionEnum = []SortDirectionEnum{
+ SortDirectionEnumAsc,
+ SortDirectionEnumDesc,
+}
+
+func (e SortDirectionEnum) IsValid() bool {
+ switch e {
+ case SortDirectionEnumAsc, SortDirectionEnumDesc:
+ return true
+ }
+ return false
+}
+
+func (e SortDirectionEnum) String() string {
+ return string(e)
+}
+
+func (e *SortDirectionEnum) UnmarshalGQL(v interface{}) error {
+ str, ok := v.(string)
+ if !ok {
+ return fmt.Errorf("enums must be strings")
+ }
+
+ *e = SortDirectionEnum(str)
+ if !e.IsValid() {
+ return fmt.Errorf("%s is not a valid SortDirectionEnum", str)
+ }
+ return nil
+}
+
+func (e SortDirectionEnum) MarshalGQL(w io.Writer) {
+ fmt.Fprint(w, strconv.Quote(e.String()))
+}
+
+type FindFilterType struct {
+ Q *string `json:"q"`
+ Page *int `json:"page"`
+ // use per_page = -1 to indicate all results. Defaults to 25.
+ PerPage *int `json:"per_page"`
+ Sort *string `json:"sort"`
+ Direction *SortDirectionEnum `json:"direction"`
+}
+
func (ff FindFilterType) GetSort(defaultSort string) string {
var sort string
if ff.Sort == nil {
diff --git a/pkg/models/gallery.go b/pkg/models/gallery.go
index 75fcfc896..3c85e7193 100644
--- a/pkg/models/gallery.go
+++ b/pkg/models/gallery.go
@@ -1,5 +1,71 @@
package models
+type GalleryFilterType struct {
+ And *GalleryFilterType `json:"AND"`
+ Or *GalleryFilterType `json:"OR"`
+ Not *GalleryFilterType `json:"NOT"`
+ Title *StringCriterionInput `json:"title"`
+ Details *StringCriterionInput `json:"details"`
+ // Filter by file checksum
+ Checksum *StringCriterionInput `json:"checksum"`
+ // Filter by path
+ Path *StringCriterionInput `json:"path"`
+ // Filter to only include galleries missing this property
+ IsMissing *string `json:"is_missing"`
+ // Filter to include/exclude galleries that were created from zip
+ IsZip *bool `json:"is_zip"`
+ // Filter by rating
+ Rating *IntCriterionInput `json:"rating"`
+ // Filter by organized
+ Organized *bool `json:"organized"`
+ // Filter by average image resolution
+ AverageResolution *ResolutionCriterionInput `json:"average_resolution"`
+ // Filter to only include galleries with this studio
+ Studios *HierarchicalMultiCriterionInput `json:"studios"`
+ // Filter to only include galleries with these tags
+ Tags *HierarchicalMultiCriterionInput `json:"tags"`
+ // Filter by tag count
+ TagCount *IntCriterionInput `json:"tag_count"`
+ // Filter to only include galleries with performers with these tags
+ PerformerTags *HierarchicalMultiCriterionInput `json:"performer_tags"`
+ // Filter to only include galleries with these performers
+ Performers *MultiCriterionInput `json:"performers"`
+ // Filter by performer count
+ PerformerCount *IntCriterionInput `json:"performer_count"`
+ // Filter galleries that have performers that have been favorited
+ PerformerFavorite *bool `json:"performer_favorite"`
+ // Filter galleries by performer age at time of gallery
+ PerformerAge *IntCriterionInput `json:"performer_age"`
+ // Filter by number of images in this gallery
+ ImageCount *IntCriterionInput `json:"image_count"`
+ // Filter by url
+ URL *StringCriterionInput `json:"url"`
+}
+
+type GalleryUpdateInput struct {
+ ClientMutationID *string `json:"clientMutationId"`
+ ID string `json:"id"`
+ Title *string `json:"title"`
+ URL *string `json:"url"`
+ Date *string `json:"date"`
+ Details *string `json:"details"`
+ Rating *int `json:"rating"`
+ Organized *bool `json:"organized"`
+ SceneIds []string `json:"scene_ids"`
+ StudioID *string `json:"studio_id"`
+ TagIds []string `json:"tag_ids"`
+ PerformerIds []string `json:"performer_ids"`
+}
+
+type GalleryDestroyInput struct {
+ Ids []string `json:"ids"`
+ // If true, then the zip file will be deleted if the gallery is zip-file-based.
+ // If gallery is folder-based, then any files not associated with other
+ // galleries will be deleted, along with the folder, if it is not empty.
+ DeleteFile *bool `json:"delete_file"`
+ DeleteGenerated *bool `json:"delete_generated"`
+}
+
type GalleryReader interface {
Find(id int) (*Gallery, error)
FindMany(ids []int) ([]*Gallery, error)
diff --git a/pkg/models/generate.go b/pkg/models/generate.go
new file mode 100644
index 000000000..85685e078
--- /dev/null
+++ b/pkg/models/generate.go
@@ -0,0 +1,91 @@
+package models
+
+import (
+ "fmt"
+ "io"
+ "strconv"
+)
+
+type GenerateMetadataOptions struct {
+ Sprites *bool `json:"sprites"`
+ Previews *bool `json:"previews"`
+ ImagePreviews *bool `json:"imagePreviews"`
+ PreviewOptions *GeneratePreviewOptions `json:"previewOptions"`
+ Markers *bool `json:"markers"`
+ MarkerImagePreviews *bool `json:"markerImagePreviews"`
+ MarkerScreenshots *bool `json:"markerScreenshots"`
+ Transcodes *bool `json:"transcodes"`
+ Phashes *bool `json:"phashes"`
+ InteractiveHeatmapsSpeeds *bool `json:"interactiveHeatmapsSpeeds"`
+}
+
+type GeneratePreviewOptions struct {
+ // Number of segments in a preview file
+ PreviewSegments *int `json:"previewSegments"`
+ // Preview segment duration, in seconds
+ PreviewSegmentDuration *float64 `json:"previewSegmentDuration"`
+ // Duration of start of video to exclude when generating previews
+ PreviewExcludeStart *string `json:"previewExcludeStart"`
+ // Duration of end of video to exclude when generating previews
+ PreviewExcludeEnd *string `json:"previewExcludeEnd"`
+ // Preset when generating preview
+ PreviewPreset *PreviewPreset `json:"previewPreset"`
+}
+
+type PreviewPreset string
+
+const (
+ // X264_ULTRAFAST
+ PreviewPresetUltrafast PreviewPreset = "ultrafast"
+ // X264_VERYFAST
+ PreviewPresetVeryfast PreviewPreset = "veryfast"
+ // X264_FAST
+ PreviewPresetFast PreviewPreset = "fast"
+ // X264_MEDIUM
+ PreviewPresetMedium PreviewPreset = "medium"
+ // X264_SLOW
+ PreviewPresetSlow PreviewPreset = "slow"
+ // X264_SLOWER
+ PreviewPresetSlower PreviewPreset = "slower"
+ // X264_VERYSLOW
+ PreviewPresetVeryslow PreviewPreset = "veryslow"
+)
+
+var AllPreviewPreset = []PreviewPreset{
+ PreviewPresetUltrafast,
+ PreviewPresetVeryfast,
+ PreviewPresetFast,
+ PreviewPresetMedium,
+ PreviewPresetSlow,
+ PreviewPresetSlower,
+ PreviewPresetVeryslow,
+}
+
+func (e PreviewPreset) IsValid() bool {
+ switch e {
+ case PreviewPresetUltrafast, PreviewPresetVeryfast, PreviewPresetFast, PreviewPresetMedium, PreviewPresetSlow, PreviewPresetSlower, PreviewPresetVeryslow:
+ return true
+ }
+ return false
+}
+
+func (e PreviewPreset) String() string {
+ return string(e)
+}
+
+func (e *PreviewPreset) UnmarshalGQL(v interface{}) error {
+ str, ok := v.(string)
+ if !ok {
+ return fmt.Errorf("enums must be strings")
+ }
+
+ *e = PreviewPreset(str)
+ if !e.IsValid() {
+ return fmt.Errorf("%s is not a valid PreviewPreset", str)
+ }
+ return nil
+}
+
+func (e PreviewPreset) MarshalGQL(w io.Writer) {
+ fmt.Fprint(w, strconv.Quote(e.String()))
+}
diff --git a/pkg/models/image.go b/pkg/models/image.go
index bae3c043f..4b28c4c36 100644
--- a/pkg/models/image.go
+++ b/pkg/models/image.go
@@ -1,5 +1,54 @@
package models
+type ImageFilterType struct {
+ And *ImageFilterType `json:"AND"`
+ Or *ImageFilterType `json:"OR"`
+ Not *ImageFilterType `json:"NOT"`
+ Title *StringCriterionInput `json:"title"`
+ // Filter by file checksum
+ Checksum *StringCriterionInput `json:"checksum"`
+ // Filter by path
+ Path *StringCriterionInput `json:"path"`
+ // Filter by rating
+ Rating *IntCriterionInput `json:"rating"`
+ // Filter by organized
+ Organized *bool `json:"organized"`
+ // Filter by o-counter
+ OCounter *IntCriterionInput `json:"o_counter"`
+ // Filter by resolution
+ Resolution *ResolutionCriterionInput `json:"resolution"`
+ // Filter to only include images missing this property
+ IsMissing *string `json:"is_missing"`
+ // Filter to only include images with this studio
+ Studios *HierarchicalMultiCriterionInput `json:"studios"`
+ // Filter to only include images with these tags
+ Tags *HierarchicalMultiCriterionInput `json:"tags"`
+ // Filter by tag count
+ TagCount *IntCriterionInput `json:"tag_count"`
+ // Filter to only include images with performers with these tags
+ PerformerTags *HierarchicalMultiCriterionInput `json:"performer_tags"`
+ // Filter to only include images with these performers
+ Performers *MultiCriterionInput `json:"performers"`
+ // Filter by performer count
+ PerformerCount *IntCriterionInput `json:"performer_count"`
+ // Filter images that have performers that have been favorited
+ PerformerFavorite *bool `json:"performer_favorite"`
+ // Filter to only include images with these galleries
+ Galleries *MultiCriterionInput `json:"galleries"`
+}
+
+type ImageDestroyInput struct {
+ ID string `json:"id"`
+ DeleteFile *bool `json:"delete_file"`
+ DeleteGenerated *bool `json:"delete_generated"`
+}
+
+type ImagesDestroyInput struct {
+ Ids []string `json:"ids"`
+ DeleteFile *bool `json:"delete_file"`
+ DeleteGenerated *bool `json:"delete_generated"`
+}
+
type ImageQueryOptions struct {
QueryOptions
ImageFilter *ImageFilterType
diff --git a/pkg/models/import.go b/pkg/models/import.go
new file mode 100644
index 000000000..164f1b528
--- /dev/null
+++ b/pkg/models/import.go
@@ -0,0 +1,50 @@
+package models
+
+import (
+ "fmt"
+ "io"
+ "strconv"
+)
+
+type ImportMissingRefEnum string
+
+const (
+ ImportMissingRefEnumIgnore ImportMissingRefEnum = "IGNORE"
+ ImportMissingRefEnumFail ImportMissingRefEnum = "FAIL"
+ ImportMissingRefEnumCreate ImportMissingRefEnum = "CREATE"
+)
+
+var AllImportMissingRefEnum = []ImportMissingRefEnum{
+ ImportMissingRefEnumIgnore,
+ ImportMissingRefEnumFail,
+ ImportMissingRefEnumCreate,
+}
+
+func (e ImportMissingRefEnum) IsValid() bool {
+ switch e {
+ case ImportMissingRefEnumIgnore, ImportMissingRefEnumFail, ImportMissingRefEnumCreate:
+ return true
+ }
+ return false
+}
+
+func (e ImportMissingRefEnum) String() string {
+ return string(e)
+}
+
+func (e *ImportMissingRefEnum) UnmarshalGQL(v interface{}) error {
+ str, ok := v.(string)
+ if !ok {
+ return fmt.Errorf("enums must be strings")
+ }
+
+ *e = ImportMissingRefEnum(str)
+ if !e.IsValid() {
+ return fmt.Errorf("%s is not a valid ImportMissingRefEnum", str)
+ }
+ return nil
+}
+
+func (e ImportMissingRefEnum) MarshalGQL(w io.Writer) {
+ fmt.Fprint(w, strconv.Quote(e.String()))
+}
diff --git a/pkg/models/model_file.go b/pkg/models/model_file.go
index 21fd51bab..2a1d31900 100644
--- a/pkg/models/model_file.go
+++ b/pkg/models/model_file.go
@@ -1,6 +1,53 @@
package models
-import "time"
+import (
+ "fmt"
+ "io"
+ "strconv"
+ "time"
+)
+
+type HashAlgorithm string
+
+const (
+ HashAlgorithmMd5 HashAlgorithm = "MD5"
+ // oshash
+ HashAlgorithmOshash HashAlgorithm = "OSHASH"
+)
+
+var AllHashAlgorithm = []HashAlgorithm{
+ HashAlgorithmMd5,
+ HashAlgorithmOshash,
+}
+
+func (e HashAlgorithm) IsValid() bool {
+ switch e {
+ case HashAlgorithmMd5, HashAlgorithmOshash:
+ return true
+ }
+ return false
+}
+
+func (e HashAlgorithm) String() string {
+ return string(e)
+}
+
+func (e *HashAlgorithm) UnmarshalGQL(v interface{}) error {
+ str, ok := v.(string)
+ if !ok {
+ return fmt.Errorf("enums must be strings")
+ }
+
+ *e = HashAlgorithm(str)
+ if !e.IsValid() {
+ return fmt.Errorf("%s is not a valid HashAlgorithm", str)
+ }
+ return nil
+}
+
+func (e HashAlgorithm) MarshalGQL(w io.Writer) {
+ fmt.Fprint(w, strconv.Quote(e.String()))
+}
type File struct {
Checksum string `db:"checksum" json:"checksum"`
diff --git a/pkg/models/model_saved_filter.go b/pkg/models/model_saved_filter.go
index 5acd6d8f8..618e9fe30 100644
--- a/pkg/models/model_saved_filter.go
+++ b/pkg/models/model_saved_filter.go
@@ -1,5 +1,64 @@
package models
+import (
+ "fmt"
+ "io"
+ "strconv"
+)
+
+type FilterMode string
+
+const (
+ FilterModeScenes FilterMode = "SCENES"
+ FilterModePerformers FilterMode = "PERFORMERS"
+ FilterModeStudios FilterMode = "STUDIOS"
+ FilterModeGalleries FilterMode = "GALLERIES"
+ FilterModeSceneMarkers FilterMode = "SCENE_MARKERS"
+ FilterModeMovies FilterMode = "MOVIES"
+ FilterModeTags FilterMode = "TAGS"
+ FilterModeImages FilterMode = "IMAGES"
+)
+
+var AllFilterMode = []FilterMode{
+ FilterModeScenes,
+ FilterModePerformers,
+ FilterModeStudios,
+ FilterModeGalleries,
+ FilterModeSceneMarkers,
+ FilterModeMovies,
+ FilterModeTags,
+ FilterModeImages,
+}
+
+func (e FilterMode) IsValid() bool {
+ switch e {
+ case FilterModeScenes, FilterModePerformers, FilterModeStudios, FilterModeGalleries, FilterModeSceneMarkers, FilterModeMovies, FilterModeTags, FilterModeImages:
+ return true
+ }
+ return false
+}
+
+func (e FilterMode) String() string {
+ return string(e)
+}
+
+func (e *FilterMode) UnmarshalGQL(v interface{}) error {
+ str, ok := v.(string)
+ if !ok {
+ return fmt.Errorf("enums must be strings")
+ }
+
+ *e = FilterMode(str)
+ if !e.IsValid() {
+ return fmt.Errorf("%s is not a valid FilterMode", str)
+ }
+ return nil
+}
+
+func (e FilterMode) MarshalGQL(w io.Writer) {
+ fmt.Fprint(w, strconv.Quote(e.String()))
+}
+
type SavedFilter struct {
ID int `db:"id" json:"id"`
Mode FilterMode `db:"mode" json:"mode"`
diff --git a/pkg/models/model_scene.go b/pkg/models/model_scene.go
index 7e1066ed3..649e78788 100644
--- a/pkg/models/model_scene.go
+++ b/pkg/models/model_scene.go
@@ -122,6 +122,30 @@ type ScenePartial struct {
InteractiveSpeed *sql.NullInt64 `db:"interactive_speed" json:"interactive_speed"`
}
+type SceneMovieInput struct {
+ MovieID string `json:"movie_id"`
+ SceneIndex *int `json:"scene_index"`
+}
+
+type SceneUpdateInput struct {
+ ClientMutationID *string `json:"clientMutationId"`
+ ID string `json:"id"`
+ Title *string `json:"title"`
+ Details *string `json:"details"`
+ URL *string `json:"url"`
+ Date *string `json:"date"`
+ Rating *int `json:"rating"`
+ Organized *bool `json:"organized"`
+ StudioID *string `json:"studio_id"`
+ GalleryIds []string `json:"gallery_ids"`
+ PerformerIds []string `json:"performer_ids"`
+ Movies []*SceneMovieInput `json:"movies"`
+ TagIds []string `json:"tag_ids"`
+ // This should be a URL or a base64 encoded data URL
+ CoverImage *string `json:"cover_image"`
+ StashIds []*StashIDInput `json:"stash_ids"`
+}
+
// UpdateInput constructs a SceneUpdateInput using the populated fields in the ScenePartial object.
func (s ScenePartial) UpdateInput() SceneUpdateInput {
boolPtrCopy := func(v *bool) *bool {
diff --git a/pkg/models/model_scraped_item.go b/pkg/models/model_scraped_item.go
index 4035163b7..9563fbfd2 100644
--- a/pkg/models/model_scraped_item.go
+++ b/pkg/models/model_scraped_item.go
@@ -4,6 +4,78 @@ import (
"database/sql"
)
+type ScrapedStudio struct {
+ // Set if studio matched
+ StoredID *string `json:"stored_id"`
+ Name string `json:"name"`
+ URL *string `json:"url"`
+ Image *string `json:"image"`
+ RemoteSiteID *string `json:"remote_site_id"`
+}
+
+func (ScrapedStudio) IsScrapedContent() {}
+
+// A performer from a scraping operation...
+type ScrapedPerformer struct {
+ // Set if performer matched
+ StoredID *string `json:"stored_id"`
+ Name *string `json:"name"`
+ Gender *string `json:"gender"`
+ URL *string `json:"url"`
+ Twitter *string `json:"twitter"`
+ Instagram *string `json:"instagram"`
+ Birthdate *string `json:"birthdate"`
+ Ethnicity *string `json:"ethnicity"`
+ Country *string `json:"country"`
+ EyeColor *string `json:"eye_color"`
+ Height *string `json:"height"`
+ Measurements *string `json:"measurements"`
+ FakeTits *string `json:"fake_tits"`
+ CareerLength *string `json:"career_length"`
+ Tattoos *string `json:"tattoos"`
+ Piercings *string `json:"piercings"`
+ Aliases *string `json:"aliases"`
+ Tags []*ScrapedTag `json:"tags"`
+ // This should be a base64 encoded data URL
+ Image *string `json:"image"`
+ Images []string `json:"images"`
+ Details *string `json:"details"`
+ DeathDate *string `json:"death_date"`
+ HairColor *string `json:"hair_color"`
+ Weight *string `json:"weight"`
+ RemoteSiteID *string `json:"remote_site_id"`
+}
+
+func (ScrapedPerformer) IsScrapedContent() {}
+
+type ScrapedTag struct {
+ // Set if tag matched
+ StoredID *string `json:"stored_id"`
+ Name string `json:"name"`
+}
+
+func (ScrapedTag) IsScrapedContent() {}
+
+// A movie from a scraping operation...
+type ScrapedMovie struct {
+ StoredID *string `json:"stored_id"`
+ Name *string `json:"name"`
+ Aliases *string `json:"aliases"`
+ Duration *string `json:"duration"`
+ Date *string `json:"date"`
+ Rating *string `json:"rating"`
+ Director *string `json:"director"`
+ URL *string `json:"url"`
+ Synopsis *string `json:"synopsis"`
+ Studio *ScrapedStudio `json:"studio"`
+ // This should be a base64 encoded data URL
+ FrontImage *string `json:"front_image"`
+ // This should be a base64 encoded data URL
+ BackImage *string `json:"back_image"`
+}
+
+func (ScrapedMovie) IsScrapedContent() {}
+
type ScrapedItem struct {
ID int `db:"id" json:"id"`
Title sql.NullString `db:"title" json:"title"`
diff --git a/pkg/models/movie.go b/pkg/models/movie.go
index 3d11e1e51..8b68217ce 100644
--- a/pkg/models/movie.go
+++ b/pkg/models/movie.go
@@ -1,5 +1,23 @@
package models
+type MovieFilterType struct {
+ Name *StringCriterionInput `json:"name"`
+ Director *StringCriterionInput `json:"director"`
+ Synopsis *StringCriterionInput `json:"synopsis"`
+ // Filter by duration (in seconds)
+ Duration *IntCriterionInput `json:"duration"`
+ // Filter by rating
+ Rating *IntCriterionInput `json:"rating"`
+ // Filter to only include movies with this studio
+ Studios *HierarchicalMultiCriterionInput `json:"studios"`
+ // Filter to only include movies missing this property
+ IsMissing *string `json:"is_missing"`
+ // Filter by url
+ URL *StringCriterionInput `json:"url"`
+ // Filter to only include movies where performer appears in a scene
+ Performers *MultiCriterionInput `json:"performers"`
+}
+
type MovieReader interface {
Find(id int) (*Movie, error)
FindMany(ids []int) ([]*Movie, error)
diff --git a/pkg/models/performer.go b/pkg/models/performer.go
index 04173b47e..d50c360a4 100644
--- a/pkg/models/performer.go
+++ b/pkg/models/performer.go
@@ -1,5 +1,129 @@
package models
+import (
+ "fmt"
+ "io"
+ "strconv"
+)
+
+type GenderEnum string
+
+const (
+ GenderEnumMale GenderEnum = "MALE"
+ GenderEnumFemale GenderEnum = "FEMALE"
+ GenderEnumTransgenderMale GenderEnum = "TRANSGENDER_MALE"
+ GenderEnumTransgenderFemale GenderEnum = "TRANSGENDER_FEMALE"
+ GenderEnumIntersex GenderEnum = "INTERSEX"
+ GenderEnumNonBinary GenderEnum = "NON_BINARY"
+)
+
+var AllGenderEnum = []GenderEnum{
+ GenderEnumMale,
+ GenderEnumFemale,
+ GenderEnumTransgenderMale,
+ GenderEnumTransgenderFemale,
+ GenderEnumIntersex,
+ GenderEnumNonBinary,
+}
+
+func (e GenderEnum) IsValid() bool {
+ switch e {
+ case GenderEnumMale, GenderEnumFemale, GenderEnumTransgenderMale, GenderEnumTransgenderFemale, GenderEnumIntersex, GenderEnumNonBinary:
+ return true
+ }
+ return false
+}
+
+func (e GenderEnum) String() string {
+ return string(e)
+}
+
+func (e *GenderEnum) UnmarshalGQL(v interface{}) error {
+ str, ok := v.(string)
+ if !ok {
+ return fmt.Errorf("enums must be strings")
+ }
+
+ *e = GenderEnum(str)
+ if !e.IsValid() {
+ return fmt.Errorf("%s is not a valid GenderEnum", str)
+ }
+ return nil
+}
+
+func (e GenderEnum) MarshalGQL(w io.Writer) {
+ fmt.Fprint(w, strconv.Quote(e.String()))
+}
+
+type GenderCriterionInput struct {
+ Value *GenderEnum `json:"value"`
+ Modifier CriterionModifier `json:"modifier"`
+}
+
+type PerformerFilterType struct {
+ And *PerformerFilterType `json:"AND"`
+ Or *PerformerFilterType `json:"OR"`
+ Not *PerformerFilterType `json:"NOT"`
+ Name *StringCriterionInput `json:"name"`
+ Details *StringCriterionInput `json:"details"`
+ // Filter by favorite
+ FilterFavorites *bool `json:"filter_favorites"`
+ // Filter by birth year
+ BirthYear *IntCriterionInput `json:"birth_year"`
+ // Filter by age
+ Age *IntCriterionInput `json:"age"`
+ // Filter by ethnicity
+ Ethnicity *StringCriterionInput `json:"ethnicity"`
+ // Filter by country
+ Country *StringCriterionInput `json:"country"`
+ // Filter by eye color
+ EyeColor *StringCriterionInput `json:"eye_color"`
+ // Filter by height
+ Height *StringCriterionInput `json:"height"`
+ // Filter by measurements
+ Measurements *StringCriterionInput `json:"measurements"`
+ // Filter by fake tits value
+ FakeTits *StringCriterionInput `json:"fake_tits"`
+ // Filter by career length
+ CareerLength *StringCriterionInput `json:"career_length"`
+ // Filter by tattoos
+ Tattoos *StringCriterionInput `json:"tattoos"`
+ // Filter by piercings
+ Piercings *StringCriterionInput `json:"piercings"`
+ // Filter by aliases
+ Aliases *StringCriterionInput `json:"aliases"`
+ // Filter by gender
+ Gender *GenderCriterionInput `json:"gender"`
+ // Filter to only include performers missing this property
+ IsMissing *string `json:"is_missing"`
+ // Filter to only include performers with these tags
+ Tags *HierarchicalMultiCriterionInput `json:"tags"`
+ // Filter by tag count
+ TagCount *IntCriterionInput `json:"tag_count"`
+ // Filter by scene count
+ SceneCount *IntCriterionInput `json:"scene_count"`
+ // Filter by image count
+ ImageCount *IntCriterionInput `json:"image_count"`
+ // Filter by gallery count
+ GalleryCount *IntCriterionInput `json:"gallery_count"`
+ // Filter by StashID
+ StashID *StringCriterionInput `json:"stash_id"`
+ // Filter by rating
+ Rating *IntCriterionInput `json:"rating"`
+ // Filter by url
+ URL *StringCriterionInput `json:"url"`
+ // Filter by hair color
+ HairColor *StringCriterionInput `json:"hair_color"`
+ // Filter by weight
+ Weight *IntCriterionInput `json:"weight"`
+ // Filter by death year
+ DeathYear *IntCriterionInput `json:"death_year"`
+ // Filter by studios where performer appears in scene/image/gallery
+ Studios *HierarchicalMultiCriterionInput `json:"studios"`
+ // Filter by autotag ignore value
+ IgnoreAutoTag *bool `json:"ignore_auto_tag"`
+}
+
type PerformerReader interface {
Find(id int) (*Performer, error)
FindMany(ids []int) ([]*Performer, error)
diff --git a/pkg/models/resolution.go b/pkg/models/resolution.go
new file mode 100644
index 000000000..6b955797a
--- /dev/null
+++ b/pkg/models/resolution.go
@@ -0,0 +1,183 @@
+package models
+
+import (
+ "fmt"
+ "io"
+ "strconv"
+)
+
+type ResolutionRange struct {
+ min, max int
+}
+
+var resolutionRanges = map[ResolutionEnum]ResolutionRange{
+ ResolutionEnum("VERY_LOW"): {144, 239},
+ ResolutionEnum("LOW"): {240, 359},
+ ResolutionEnum("R360P"): {360, 479},
+ ResolutionEnum("STANDARD"): {480, 539},
+ ResolutionEnum("WEB_HD"): {540, 719},
+ ResolutionEnum("STANDARD_HD"): {720, 1079},
+ ResolutionEnum("FULL_HD"): {1080, 1439},
+ ResolutionEnum("QUAD_HD"): {1440, 1919},
+ ResolutionEnum("VR_HD"): {1920, 2159},
+ ResolutionEnum("FOUR_K"): {2160, 2879},
+ ResolutionEnum("FIVE_K"): {2880, 3383},
+ ResolutionEnum("SIX_K"): {3384, 4319},
+ ResolutionEnum("EIGHT_K"): {4320, 8639},
+}
+
+type ResolutionEnum string
+
+const (
+ // 144p
+ ResolutionEnumVeryLow ResolutionEnum = "VERY_LOW"
+ // 240p
+ ResolutionEnumLow ResolutionEnum = "LOW"
+ // 360p
+ ResolutionEnumR360p ResolutionEnum = "R360P"
+ // 480p
+ ResolutionEnumStandard ResolutionEnum = "STANDARD"
+ // 540p
+ ResolutionEnumWebHd ResolutionEnum = "WEB_HD"
+ // 720p
+ ResolutionEnumStandardHd ResolutionEnum = "STANDARD_HD"
+ // 1080p
+ ResolutionEnumFullHd ResolutionEnum = "FULL_HD"
+ // 1440p
+ ResolutionEnumQuadHd ResolutionEnum = "QUAD_HD"
+ // 1920p
+ ResolutionEnumVrHd ResolutionEnum = "VR_HD"
+ // 4k
+ ResolutionEnumFourK ResolutionEnum = "FOUR_K"
+ // 5k
+ ResolutionEnumFiveK ResolutionEnum = "FIVE_K"
+ // 6k
+ ResolutionEnumSixK ResolutionEnum = "SIX_K"
+ // 8k
+ ResolutionEnumEightK ResolutionEnum = "EIGHT_K"
+)
+
+var AllResolutionEnum = []ResolutionEnum{
+ ResolutionEnumVeryLow,
+ ResolutionEnumLow,
+ ResolutionEnumR360p,
+ ResolutionEnumStandard,
+ ResolutionEnumWebHd,
+ ResolutionEnumStandardHd,
+ ResolutionEnumFullHd,
+ ResolutionEnumQuadHd,
+ ResolutionEnumVrHd,
+ ResolutionEnumFourK,
+ ResolutionEnumFiveK,
+ ResolutionEnumSixK,
+ ResolutionEnumEightK,
+}
+
+func (e ResolutionEnum) IsValid() bool {
+ switch e {
+ case ResolutionEnumVeryLow, ResolutionEnumLow, ResolutionEnumR360p, ResolutionEnumStandard, ResolutionEnumWebHd, ResolutionEnumStandardHd, ResolutionEnumFullHd, ResolutionEnumQuadHd, ResolutionEnumVrHd, ResolutionEnumFourK, ResolutionEnumFiveK, ResolutionEnumSixK, ResolutionEnumEightK:
+ return true
+ }
+ return false
+}
+
+func (e ResolutionEnum) String() string {
+ return string(e)
+}
+
+func (e *ResolutionEnum) UnmarshalGQL(v interface{}) error {
+ str, ok := v.(string)
+ if !ok {
+ return fmt.Errorf("enums must be strings")
+ }
+
+ *e = ResolutionEnum(str)
+ if !e.IsValid() {
+ return fmt.Errorf("%s is not a valid ResolutionEnum", str)
+ }
+ return nil
+}
+
+func (e ResolutionEnum) MarshalGQL(w io.Writer) {
+ fmt.Fprint(w, strconv.Quote(e.String()))
+}
+
+// GetMaxResolution returns the maximum width or height that media must be
+// to qualify as this resolution.
+func (e *ResolutionEnum) GetMaxResolution() int {
+ return resolutionRanges[*e].max
+}
+
+// GetMinResolution returns the minimum width or height that media must be
+// to qualify as this resolution.
+func (e *ResolutionEnum) GetMinResolution() int {
+ return resolutionRanges[*e].min
+}
+
+type StreamingResolutionEnum string
+
+const (
+ // 240p
+ StreamingResolutionEnumLow StreamingResolutionEnum = "LOW"
+ // 480p
+ StreamingResolutionEnumStandard StreamingResolutionEnum = "STANDARD"
+ // 720p
+ StreamingResolutionEnumStandardHd StreamingResolutionEnum = "STANDARD_HD"
+ // 1080p
+ StreamingResolutionEnumFullHd StreamingResolutionEnum = "FULL_HD"
+ // 4k
+ StreamingResolutionEnumFourK StreamingResolutionEnum = "FOUR_K"
+ // Original
+ StreamingResolutionEnumOriginal StreamingResolutionEnum = "ORIGINAL"
+)
+
+var AllStreamingResolutionEnum = []StreamingResolutionEnum{
+ StreamingResolutionEnumLow,
+ StreamingResolutionEnumStandard,
+ StreamingResolutionEnumStandardHd,
+ StreamingResolutionEnumFullHd,
+ StreamingResolutionEnumFourK,
+ StreamingResolutionEnumOriginal,
+}
+
+func (e StreamingResolutionEnum) IsValid() bool {
+ switch e {
+ case StreamingResolutionEnumLow, StreamingResolutionEnumStandard, StreamingResolutionEnumStandardHd, StreamingResolutionEnumFullHd, StreamingResolutionEnumFourK, StreamingResolutionEnumOriginal:
+ return true
+ }
+ return false
+}
+
+func (e StreamingResolutionEnum) String() string {
+ return string(e)
+}
+
+func (e *StreamingResolutionEnum) UnmarshalGQL(v interface{}) error {
+ str, ok := v.(string)
+ if !ok {
+ return fmt.Errorf("enums must be strings")
+ }
+
+ *e = StreamingResolutionEnum(str)
+ if !e.IsValid() {
+ return fmt.Errorf("%s is not a valid StreamingResolutionEnum", str)
+ }
+ return nil
+}
+
+func (e StreamingResolutionEnum) MarshalGQL(w io.Writer) {
+ fmt.Fprint(w, strconv.Quote(e.String()))
+}
+
+var streamingResolutionMax = map[StreamingResolutionEnum]int{
+ StreamingResolutionEnumLow: resolutionRanges[ResolutionEnumLow].min,
+ StreamingResolutionEnumStandard: resolutionRanges[ResolutionEnumStandard].min,
+ StreamingResolutionEnumStandardHd: resolutionRanges[ResolutionEnumStandardHd].min,
+ StreamingResolutionEnumFullHd: resolutionRanges[ResolutionEnumFullHd].min,
+ StreamingResolutionEnumFourK: resolutionRanges[ResolutionEnumFourK].min,
+ StreamingResolutionEnumOriginal: 0,
+}
+
+func (e StreamingResolutionEnum) GetMaxResolution() int {
+ return streamingResolutionMax[e]
+}
diff --git a/pkg/models/scene.go b/pkg/models/scene.go
index 86a131a0f..6940abe84 100644
--- a/pkg/models/scene.go
+++ b/pkg/models/scene.go
@@ -1,5 +1,71 @@
package models
+type PHashDuplicationCriterionInput struct {
+ Duplicated *bool `json:"duplicated"`
+ // Currently unimplemented
+ Distance *int `json:"distance"`
+}
+
+type SceneFilterType struct {
+ And *SceneFilterType `json:"AND"`
+ Or *SceneFilterType `json:"OR"`
+ Not *SceneFilterType `json:"NOT"`
+ Title *StringCriterionInput `json:"title"`
+ Details *StringCriterionInput `json:"details"`
+ // Filter by file oshash
+ Oshash *StringCriterionInput `json:"oshash"`
+ // Filter by file checksum
+ Checksum *StringCriterionInput `json:"checksum"`
+ // Filter by file phash
+ Phash *StringCriterionInput `json:"phash"`
+ // Filter by path
+ Path *StringCriterionInput `json:"path"`
+ // Filter by rating
+ Rating *IntCriterionInput `json:"rating"`
+ // Filter by organized
+ Organized *bool `json:"organized"`
+ // Filter by o-counter
+ OCounter *IntCriterionInput `json:"o_counter"`
+ // Filter Scenes that have an exact phash match available
+ Duplicated *PHashDuplicationCriterionInput `json:"duplicated"`
+ // Filter by resolution
+ Resolution *ResolutionCriterionInput `json:"resolution"`
+ // Filter by duration (in seconds)
+ Duration *IntCriterionInput `json:"duration"`
+ // Filter to only include scenes which have markers. `true` or `false`
+ HasMarkers *string `json:"has_markers"`
+ // Filter to only include scenes missing this property
+ IsMissing *string `json:"is_missing"`
+ // Filter to only include scenes with this studio
+ Studios *HierarchicalMultiCriterionInput `json:"studios"`
+ // Filter to only include scenes with this movie
+ Movies *MultiCriterionInput `json:"movies"`
+ // Filter to only include scenes with these tags
+ Tags *HierarchicalMultiCriterionInput `json:"tags"`
+ // Filter by tag count
+ TagCount *IntCriterionInput `json:"tag_count"`
+ // Filter to only include scenes with performers with these tags
+ PerformerTags *HierarchicalMultiCriterionInput `json:"performer_tags"`
+ // Filter scenes that have performers that have been favorited
+ PerformerFavorite *bool `json:"performer_favorite"`
+ // Filter scenes by performer age at time of scene
+ PerformerAge *IntCriterionInput `json:"performer_age"`
+ // Filter to only include scenes with these performers
+ Performers *MultiCriterionInput `json:"performers"`
+ // Filter by performer count
+ PerformerCount *IntCriterionInput `json:"performer_count"`
+ // Filter by StashID
+ StashID *StringCriterionInput `json:"stash_id"`
+ // Filter by url
+ URL *StringCriterionInput `json:"url"`
+ // Filter by interactive
+ Interactive *bool `json:"interactive"`
+ // Filter by InteractiveSpeed
+ InteractiveSpeed *IntCriterionInput `json:"interactive_speed"`
+
+ Captions *StringCriterionInput `json:"captions"`
+}
+
type SceneQueryOptions struct {
QueryOptions
SceneFilter *SceneFilterType
@@ -18,6 +84,18 @@ type SceneQueryResult struct {
resolveErr error
}
+type SceneDestroyInput struct {
+ ID string `json:"id"`
+ DeleteFile *bool `json:"delete_file"`
+ DeleteGenerated *bool `json:"delete_generated"`
+}
+
+type ScenesDestroyInput struct {
+ Ids []string `json:"ids"`
+ DeleteFile *bool `json:"delete_file"`
+ DeleteGenerated *bool `json:"delete_generated"`
+}
+
func NewSceneQueryResult(finder SceneFinder) *SceneQueryResult {
return &SceneQueryResult{
finder: finder,
diff --git a/pkg/models/scene_marker.go b/pkg/models/scene_marker.go
index 49f169098..ac7501672 100644
--- a/pkg/models/scene_marker.go
+++ b/pkg/models/scene_marker.go
@@ -1,5 +1,22 @@
package models
+type SceneMarkerFilterType struct {
+ // Filter to only include scene markers with this tag
+ TagID *string `json:"tag_id"`
+ // Filter to only include scene markers with these tags
+ Tags *HierarchicalMultiCriterionInput `json:"tags"`
+ // Filter to only include scene markers attached to a scene with these tags
+ SceneTags *HierarchicalMultiCriterionInput `json:"scene_tags"`
+ // Filter to only include scene markers with these performers
+ Performers *MultiCriterionInput `json:"performers"`
+}
+
+type MarkerStringsResultType struct {
+ Count int `json:"count"`
+ ID string `json:"id"`
+ Title string `json:"title"`
+}
+
type SceneMarkerReader interface {
Find(id int) (*SceneMarker, error)
FindMany(ids []int) ([]*SceneMarker, error)
diff --git a/pkg/models/stash_box.go b/pkg/models/stash_box.go
index 3e981484b..9c9d7ed2e 100644
--- a/pkg/models/stash_box.go
+++ b/pkg/models/stash_box.go
@@ -1,39 +1,13 @@
package models
-import (
- "fmt"
- "strings"
-)
-
-type StashBoxes []*StashBox
-
-func (sb StashBoxes) ResolveStashBox(source ScraperSourceInput) (*StashBox, error) {
- if source.StashBoxIndex != nil {
- index := source.StashBoxIndex
- if *index < 0 || *index >= len(sb) {
- return nil, fmt.Errorf("%w: invalid stash_box_index: %d", ErrScraperSource, index)
- }
-
- return sb[*index], nil
- }
-
- if source.StashBoxEndpoint != nil {
- var ret *StashBox
- endpoint := *source.StashBoxEndpoint
- for _, b := range sb {
- if strings.EqualFold(endpoint, b.Endpoint) {
- ret = b
- }
- }
-
- if ret == nil {
- return nil, fmt.Errorf(`%w: stash-box with endpoint "%s"`, ErrNotFound, endpoint)
- }
-
- return ret, nil
- }
-
- // neither stash-box inputs were provided, so assume it is a scraper
-
- return nil, nil
+type StashBoxFingerprint struct {
+ Algorithm string `json:"algorithm"`
+ Hash string `json:"hash"`
+ Duration int `json:"duration"`
+}
+
+type StashBox struct {
+ Endpoint string `json:"endpoint"`
+ APIKey string `json:"api_key"`
+ Name string `json:"name"`
}
diff --git a/pkg/models/stash_ids.go b/pkg/models/stash_ids.go
index fb20522a5..0a7e1edd9 100644
--- a/pkg/models/stash_ids.go
+++ b/pkg/models/stash_ids.go
@@ -1,5 +1,10 @@
package models
+type StashIDInput struct {
+ Endpoint string `json:"endpoint"`
+ StashID string `json:"stash_id"`
+}
+
func StashIDsFromInput(i []*StashIDInput) []StashID {
var ret []StashID
for _, stashID := range i {
diff --git a/pkg/models/studio.go b/pkg/models/studio.go
index e5d6bfb19..fb9f7c415 100644
--- a/pkg/models/studio.go
+++ b/pkg/models/studio.go
@@ -1,5 +1,33 @@
package models
+type StudioFilterType struct {
+ And *StudioFilterType `json:"AND"`
+ Or *StudioFilterType `json:"OR"`
+ Not *StudioFilterType `json:"NOT"`
+ Name *StringCriterionInput `json:"name"`
+ Details *StringCriterionInput `json:"details"`
+ // Filter to only include studios with this parent studio
+ Parents *MultiCriterionInput `json:"parents"`
+ // Filter by StashID
+ StashID *StringCriterionInput `json:"stash_id"`
+ // Filter to only include studios missing this property
+ IsMissing *string `json:"is_missing"`
+ // Filter by rating
+ Rating *IntCriterionInput `json:"rating"`
+ // Filter by scene count
+ SceneCount *IntCriterionInput `json:"scene_count"`
+ // Filter by image count
+ ImageCount *IntCriterionInput `json:"image_count"`
+ // Filter by gallery count
+ GalleryCount *IntCriterionInput `json:"gallery_count"`
+ // Filter by url
+ URL *StringCriterionInput `json:"url"`
+ // Filter by studio aliases
+ Aliases *StringCriterionInput `json:"aliases"`
+ // Filter by autotag ignore value
+ IgnoreAutoTag *bool `json:"ignore_auto_tag"`
+}
+
type StudioReader interface {
Find(id int) (*Studio, error)
FindMany(ids []int) ([]*Studio, error)
diff --git a/pkg/models/tag.go b/pkg/models/tag.go
index 747e7a08e..a7f7518bf 100644
--- a/pkg/models/tag.go
+++ b/pkg/models/tag.go
@@ -1,5 +1,37 @@
package models
+type TagFilterType struct {
+ And *TagFilterType `json:"AND"`
+ Or *TagFilterType `json:"OR"`
+ Not *TagFilterType `json:"NOT"`
+ // Filter by tag name
+ Name *StringCriterionInput `json:"name"`
+ // Filter by tag aliases
+ Aliases *StringCriterionInput `json:"aliases"`
+ // Filter to only include tags missing this property
+ IsMissing *string `json:"is_missing"`
+ // Filter by number of scenes with this tag
+ SceneCount *IntCriterionInput `json:"scene_count"`
+ // Filter by number of images with this tag
+ ImageCount *IntCriterionInput `json:"image_count"`
+ // Filter by number of galleries with this tag
+ GalleryCount *IntCriterionInput `json:"gallery_count"`
+ // Filter by number of performers with this tag
+ PerformerCount *IntCriterionInput `json:"performer_count"`
+ // Filter by number of markers with this tag
+ MarkerCount *IntCriterionInput `json:"marker_count"`
+ // Filter by parent tags
+ Parents *HierarchicalMultiCriterionInput `json:"parents"`
+ // Filter by child tags
+ Children *HierarchicalMultiCriterionInput `json:"children"`
+ // Filter by number of parent tags the tag has
+ ParentCount *IntCriterionInput `json:"parent_count"`
+ // Filter by number f child tags the tag has
+ ChildCount *IntCriterionInput `json:"child_count"`
+ // Filter by autotag ignore value
+ IgnoreAutoTag *bool `json:"ignore_auto_tag"`
+}
+
type TagReader interface {
Find(id int) (*Tag, error)
FindMany(ids []int) ([]*Tag, error)
diff --git a/pkg/plugin/args.go b/pkg/plugin/args.go
index adcdf007f..d85a624c7 100644
--- a/pkg/plugin/args.go
+++ b/pkg/plugin/args.go
@@ -1,10 +1,20 @@
package plugin
-import (
- "github.com/stashapp/stash/pkg/models"
-)
+type PluginArgInput struct {
+ Key string `json:"key"`
+ Value *PluginValueInput `json:"value"`
+}
-func findArg(args []*models.PluginArgInput, name string) *models.PluginArgInput {
+type PluginValueInput struct {
+ Str *string `json:"str"`
+ I *int `json:"i"`
+ B *bool `json:"b"`
+ F *float64 `json:"f"`
+ O []*PluginArgInput `json:"o"`
+ A []*PluginValueInput `json:"a"`
+}
+
+func findArg(args []*PluginArgInput, name string) *PluginArgInput {
for _, v := range args {
if v.Key == name {
return v
@@ -14,13 +24,13 @@ func findArg(args []*models.PluginArgInput, name string) *models.PluginArgInput
return nil
}
-func applyDefaultArgs(args []*models.PluginArgInput, defaultArgs map[string]string) []*models.PluginArgInput {
+func applyDefaultArgs(args []*PluginArgInput, defaultArgs map[string]string) []*PluginArgInput {
for k, v := range defaultArgs {
if arg := findArg(args, k); arg == nil {
v := v // Copy v, because it's being exported out of the loop
- args = append(args, &models.PluginArgInput{
+ args = append(args, &PluginArgInput{
Key: k,
- Value: &models.PluginValueInput{
+ Value: &PluginValueInput{
Str: &v,
},
})
diff --git a/pkg/plugin/config.go b/pkg/plugin/config.go
index 05501b4e2..2a00c3ced 100644
--- a/pkg/plugin/config.go
+++ b/pkg/plugin/config.go
@@ -8,7 +8,6 @@ import (
"path/filepath"
"strings"
- "github.com/stashapp/stash/pkg/models"
"gopkg.in/yaml.v2"
)
@@ -59,11 +58,11 @@ type Config struct {
Hooks []*HookConfig `yaml:"hooks"`
}
-func (c Config) getPluginTasks(includePlugin bool) []*models.PluginTask {
- var ret []*models.PluginTask
+func (c Config) getPluginTasks(includePlugin bool) []*PluginTask {
+ var ret []*PluginTask
for _, o := range c.Tasks {
- task := &models.PluginTask{
+ task := &PluginTask{
Name: o.Name,
Description: &o.Description,
}
@@ -77,11 +76,11 @@ func (c Config) getPluginTasks(includePlugin bool) []*models.PluginTask {
return ret
}
-func (c Config) getPluginHooks(includePlugin bool) []*models.PluginHook {
- var ret []*models.PluginHook
+func (c Config) getPluginHooks(includePlugin bool) []*PluginHook {
+ var ret []*PluginHook
for _, o := range c.Hooks {
- hook := &models.PluginHook{
+ hook := &PluginHook{
Name: o.Name,
Description: &o.Description,
Hooks: convertHooks(o.TriggeredBy),
@@ -113,8 +112,8 @@ func (c Config) getName() string {
return c.id
}
-func (c Config) toPlugin() *models.Plugin {
- return &models.Plugin{
+func (c Config) toPlugin() *Plugin {
+ return &Plugin{
ID: c.id,
Name: c.getName(),
Description: c.Description,
diff --git a/pkg/plugin/convert.go b/pkg/plugin/convert.go
index 989008d60..7aeb95983 100644
--- a/pkg/plugin/convert.go
+++ b/pkg/plugin/convert.go
@@ -1,11 +1,10 @@
package plugin
import (
- "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin/common"
)
-func toPluginArgs(args []*models.PluginArgInput) common.ArgsMap {
+func toPluginArgs(args []*PluginArgInput) common.ArgsMap {
ret := make(common.ArgsMap)
for _, a := range args {
ret[a.Key] = toPluginArgValue(a.Value)
@@ -14,7 +13,7 @@ func toPluginArgs(args []*models.PluginArgInput) common.ArgsMap {
return ret
}
-func toPluginArgValue(arg *models.PluginValueInput) common.PluginArgValue {
+func toPluginArgValue(arg *PluginValueInput) common.PluginArgValue {
if arg == nil {
return nil
}
diff --git a/pkg/plugin/hooks.go b/pkg/plugin/hooks.go
index f95eb1620..a60e44e6c 100644
--- a/pkg/plugin/hooks.go
+++ b/pkg/plugin/hooks.go
@@ -5,6 +5,13 @@ import (
"github.com/stashapp/stash/pkg/plugin/common"
)
+type PluginHook struct {
+ Name string `json:"name"`
+ Description *string `json:"description"`
+ Hooks []string `json:"hooks"`
+ Plugin *Plugin `json:"plugin"`
+}
+
type HookTriggerEnum string
// Scan-related hooks are current disabled until post-hook execution is
diff --git a/pkg/plugin/plugins.go b/pkg/plugin/plugins.go
index 85fde229b..c198b42a4 100644
--- a/pkg/plugin/plugins.go
+++ b/pkg/plugin/plugins.go
@@ -22,6 +22,16 @@ import (
"github.com/stashapp/stash/pkg/sliceutil/stringslice"
)
+type Plugin struct {
+ ID string `json:"id"`
+ Name string `json:"name"`
+ Description *string `json:"description"`
+ URL *string `json:"url"`
+ Version *string `json:"version"`
+ Tasks []*PluginTask `json:"tasks"`
+ Hooks []*PluginHook `json:"hooks"`
+}
+
type ServerConfig interface {
GetHost() string
GetPort() int
@@ -103,8 +113,8 @@ func loadPlugins(path string) ([]Config, error) {
}
// ListPlugins returns plugin details for all of the loaded plugins.
-func (c Cache) ListPlugins() []*models.Plugin {
- var ret []*models.Plugin
+func (c Cache) ListPlugins() []*Plugin {
+ var ret []*Plugin
for _, s := range c.plugins {
ret = append(ret, s.toPlugin())
}
@@ -113,8 +123,8 @@ func (c Cache) ListPlugins() []*models.Plugin {
}
// ListPluginTasks returns all runnable plugin tasks in all loaded plugins.
-func (c Cache) ListPluginTasks() []*models.PluginTask {
- var ret []*models.PluginTask
+func (c Cache) ListPluginTasks() []*PluginTask {
+ var ret []*PluginTask
for _, s := range c.plugins {
ret = append(ret, s.getPluginTasks(true)...)
}
@@ -122,7 +132,7 @@ func (c Cache) ListPluginTasks() []*models.PluginTask {
return ret
}
-func buildPluginInput(plugin *Config, operation *OperationConfig, serverConnection common.StashServerConnection, args []*models.PluginArgInput) common.PluginInput {
+func buildPluginInput(plugin *Config, operation *OperationConfig, serverConnection common.StashServerConnection, args []*PluginArgInput) common.PluginInput {
args = applyDefaultArgs(args, operation.DefaultArgs)
serverConnection.PluginDir = plugin.getConfigPath()
return common.PluginInput{
@@ -152,7 +162,7 @@ func (c Cache) makeServerConnection(ctx context.Context) common.StashServerConne
// CreateTask runs the plugin operation for the pluginID and operation
// name provided. Returns an error if the plugin or the operation could not be
// resolved.
-func (c Cache) CreateTask(ctx context.Context, pluginID string, operationName string, args []*models.PluginArgInput, progress chan float64) (Task, error) {
+func (c Cache) CreateTask(ctx context.Context, pluginID string, operationName string, args []*PluginArgInput, progress chan float64) (Task, error) {
serverConnection := c.makeServerConnection(ctx)
// find the plugin and operation
diff --git a/pkg/plugin/task.go b/pkg/plugin/task.go
index e80c96ff7..58b4b2eba 100644
--- a/pkg/plugin/task.go
+++ b/pkg/plugin/task.go
@@ -6,6 +6,12 @@ import (
"github.com/stashapp/stash/pkg/plugin/common"
)
+type PluginTask struct {
+ Name string `json:"name"`
+ Description *string `json:"description"`
+ Plugin *Plugin `json:"plugin"`
+}
+
// Task is the interface that handles management of a single plugin task.
type Task interface {
// Start starts the plugin task. Returns an error if task could not be
diff --git a/pkg/scraper/action.go b/pkg/scraper/action.go
index 3f80cee29..e674cb2a2 100644
--- a/pkg/scraper/action.go
+++ b/pkg/scraper/action.go
@@ -25,12 +25,12 @@ func (e scraperAction) IsValid() bool {
}
type scraperActionImpl interface {
- scrapeByURL(ctx context.Context, url string, ty models.ScrapeContentType) (models.ScrapedContent, error)
- scrapeByName(ctx context.Context, name string, ty models.ScrapeContentType) ([]models.ScrapedContent, error)
- scrapeByFragment(ctx context.Context, input Input) (models.ScrapedContent, error)
+ scrapeByURL(ctx context.Context, url string, ty ScrapeContentType) (ScrapedContent, error)
+ scrapeByName(ctx context.Context, name string, ty ScrapeContentType) ([]ScrapedContent, error)
+ scrapeByFragment(ctx context.Context, input Input) (ScrapedContent, error)
- scrapeSceneByScene(ctx context.Context, scene *models.Scene) (*models.ScrapedScene, error)
- scrapeGalleryByGallery(ctx context.Context, gallery *models.Gallery) (*models.ScrapedGallery, error)
+ scrapeSceneByScene(ctx context.Context, scene *models.Scene) (*ScrapedScene, error)
+ scrapeGalleryByGallery(ctx context.Context, gallery *models.Gallery) (*ScrapedGallery, error)
}
func (c config) getScraper(scraper scraperTypeConfig, client *http.Client, txnManager models.TransactionManager, globalConfig GlobalConfig) scraperActionImpl {
diff --git a/pkg/scraper/autotag.go b/pkg/scraper/autotag.go
index 4a86d8df2..52f3ce4af 100644
--- a/pkg/scraper/autotag.go
+++ b/pkg/scraper/autotag.go
@@ -83,8 +83,8 @@ func autotagMatchTags(path string, tagReader models.TagReader, trimExt bool) ([]
return ret, nil
}
-func (s autotagScraper) viaScene(ctx context.Context, _client *http.Client, scene *models.Scene) (*models.ScrapedScene, error) {
- var ret *models.ScrapedScene
+func (s autotagScraper) viaScene(ctx context.Context, _client *http.Client, scene *models.Scene) (*ScrapedScene, error) {
+ var ret *ScrapedScene
const trimExt = false
// populate performers, studio and tags based on scene path
@@ -105,7 +105,7 @@ func (s autotagScraper) viaScene(ctx context.Context, _client *http.Client, scen
}
if len(performers) > 0 || studio != nil || len(tags) > 0 {
- ret = &models.ScrapedScene{
+ ret = &ScrapedScene{
Performers: performers,
Studio: studio,
Tags: tags,
@@ -120,7 +120,7 @@ func (s autotagScraper) viaScene(ctx context.Context, _client *http.Client, scen
return ret, nil
}
-func (s autotagScraper) viaGallery(ctx context.Context, _client *http.Client, gallery *models.Gallery) (*models.ScrapedGallery, error) {
+func (s autotagScraper) viaGallery(ctx context.Context, _client *http.Client, gallery *models.Gallery) (*ScrapedGallery, error) {
if !gallery.Path.Valid {
// not valid for non-path-based galleries
return nil, nil
@@ -129,7 +129,7 @@ func (s autotagScraper) viaGallery(ctx context.Context, _client *http.Client, ga
// only trim extension if gallery is file-based
trimExt := gallery.Zip
- var ret *models.ScrapedGallery
+ var ret *ScrapedGallery
// populate performers, studio and tags based on scene path
if err := s.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
@@ -149,7 +149,7 @@ func (s autotagScraper) viaGallery(ctx context.Context, _client *http.Client, ga
}
if len(performers) > 0 || studio != nil || len(tags) > 0 {
- ret = &models.ScrapedGallery{
+ ret = &ScrapedGallery{
Performers: performers,
Studio: studio,
Tags: tags,
@@ -164,33 +164,33 @@ func (s autotagScraper) viaGallery(ctx context.Context, _client *http.Client, ga
return ret, nil
}
-func (s autotagScraper) supports(ty models.ScrapeContentType) bool {
+func (s autotagScraper) supports(ty ScrapeContentType) bool {
switch ty {
- case models.ScrapeContentTypeScene:
+ case ScrapeContentTypeScene:
return true
- case models.ScrapeContentTypeGallery:
+ case ScrapeContentTypeGallery:
return true
}
return false
}
-func (s autotagScraper) supportsURL(url string, ty models.ScrapeContentType) bool {
+func (s autotagScraper) supportsURL(url string, ty ScrapeContentType) bool {
return false
}
-func (s autotagScraper) spec() models.Scraper {
- supportedScrapes := []models.ScrapeType{
- models.ScrapeTypeFragment,
+func (s autotagScraper) spec() Scraper {
+ supportedScrapes := []ScrapeType{
+ ScrapeTypeFragment,
}
- return models.Scraper{
+ return Scraper{
ID: autoTagScraperID,
Name: autoTagScraperName,
- Scene: &models.ScraperSpec{
+ Scene: &ScraperSpec{
SupportedScrapes: supportedScrapes,
},
- Gallery: &models.ScraperSpec{
+ Gallery: &ScraperSpec{
SupportedScrapes: supportedScrapes,
},
}
diff --git a/pkg/scraper/cache.go b/pkg/scraper/cache.go
index 2190dcb03..5357f8b94 100644
--- a/pkg/scraper/cache.go
+++ b/pkg/scraper/cache.go
@@ -148,8 +148,8 @@ func (c *Cache) ReloadScrapers() error {
// ListScrapers lists scrapers matching one of the given types.
// Returns a list of scrapers, sorted by their ID.
-func (c Cache) ListScrapers(tys []models.ScrapeContentType) []*models.Scraper {
- var ret []*models.Scraper
+func (c Cache) ListScrapers(tys []ScrapeContentType) []*Scraper {
+ var ret []*Scraper
for _, s := range c.scrapers {
for _, t := range tys {
if s.supports(t) {
@@ -168,7 +168,7 @@ func (c Cache) ListScrapers(tys []models.ScrapeContentType) []*models.Scraper {
}
// GetScraper returns the scraper matching the provided id.
-func (c Cache) GetScraper(scraperID string) *models.Scraper {
+func (c Cache) GetScraper(scraperID string) *Scraper {
s := c.findScraper(scraperID)
if s != nil {
spec := s.spec()
@@ -187,7 +187,7 @@ func (c Cache) findScraper(scraperID string) scraper {
return nil
}
-func (c Cache) ScrapeName(ctx context.Context, id, query string, ty models.ScrapeContentType) ([]models.ScrapedContent, error) {
+func (c Cache) ScrapeName(ctx context.Context, id, query string, ty ScrapeContentType) ([]ScrapedContent, error) {
// find scraper with the provided id
s := c.findScraper(id)
if s == nil {
@@ -206,7 +206,7 @@ func (c Cache) ScrapeName(ctx context.Context, id, query string, ty models.Scrap
}
// ScrapeFragment uses the given fragment input to scrape
-func (c Cache) ScrapeFragment(ctx context.Context, id string, input Input) (models.ScrapedContent, error) {
+func (c Cache) ScrapeFragment(ctx context.Context, id string, input Input) (ScrapedContent, error) {
s := c.findScraper(id)
if s == nil {
return nil, fmt.Errorf("%w: id %s", ErrNotFound, id)
@@ -228,7 +228,7 @@ func (c Cache) ScrapeFragment(ctx context.Context, id string, input Input) (mode
// ScrapeURL scrapes a given url for the given content. Searches the scraper cache
// and picks the first scraper capable of scraping the given url into the desired
// content. Returns the scraped content or an error if the scrape fails.
-func (c Cache) ScrapeURL(ctx context.Context, url string, ty models.ScrapeContentType) (models.ScrapedContent, error) {
+func (c Cache) ScrapeURL(ctx context.Context, url string, ty ScrapeContentType) (ScrapedContent, error) {
for _, s := range c.scrapers {
if s.supportsURL(url, ty) {
ul, ok := s.(urlScraper)
@@ -251,7 +251,7 @@ func (c Cache) ScrapeURL(ctx context.Context, url string, ty models.ScrapeConten
return nil, nil
}
-func (c Cache) ScrapeID(ctx context.Context, scraperID string, id int, ty models.ScrapeContentType) (models.ScrapedContent, error) {
+func (c Cache) ScrapeID(ctx context.Context, scraperID string, id int, ty ScrapeContentType) (ScrapedContent, error) {
s := c.findScraper(scraperID)
if s == nil {
return nil, fmt.Errorf("%w: id %s", ErrNotFound, scraperID)
@@ -261,9 +261,9 @@ func (c Cache) ScrapeID(ctx context.Context, scraperID string, id int, ty models
return nil, fmt.Errorf("%w: cannot use scraper %s to scrape %v content", ErrNotSupported, scraperID, ty)
}
- var ret models.ScrapedContent
+ var ret ScrapedContent
switch ty {
- case models.ScrapeContentTypeScene:
+ case ScrapeContentTypeScene:
ss, ok := s.(sceneScraper)
if !ok {
return nil, fmt.Errorf("%w: cannot use scraper %s as a scene scraper", ErrNotSupported, scraperID)
@@ -284,7 +284,7 @@ func (c Cache) ScrapeID(ctx context.Context, scraperID string, id int, ty models
if scraped != nil {
ret = scraped
}
- case models.ScrapeContentTypeGallery:
+ case ScrapeContentTypeGallery:
gs, ok := s.(galleryScraper)
if !ok {
return nil, fmt.Errorf("%w: cannot use scraper %s as a gallery scraper", ErrNotSupported, scraperID)
diff --git a/pkg/scraper/config.go b/pkg/scraper/config.go
index 4782fb47b..3a0aadf51 100644
--- a/pkg/scraper/config.go
+++ b/pkg/scraper/config.go
@@ -8,7 +8,6 @@ import (
"path/filepath"
"strings"
- "github.com/stashapp/stash/pkg/models"
"gopkg.in/yaml.v2"
)
@@ -233,21 +232,21 @@ func loadConfigFromYAMLFile(path string) (*config, error) {
return ret, nil
}
-func (c config) spec() models.Scraper {
- ret := models.Scraper{
+func (c config) spec() Scraper {
+ ret := Scraper{
ID: c.ID,
Name: c.Name,
}
- performer := models.ScraperSpec{}
+ performer := ScraperSpec{}
if c.PerformerByName != nil {
- performer.SupportedScrapes = append(performer.SupportedScrapes, models.ScrapeTypeName)
+ performer.SupportedScrapes = append(performer.SupportedScrapes, ScrapeTypeName)
}
if c.PerformerByFragment != nil {
- performer.SupportedScrapes = append(performer.SupportedScrapes, models.ScrapeTypeFragment)
+ performer.SupportedScrapes = append(performer.SupportedScrapes, ScrapeTypeFragment)
}
if len(c.PerformerByURL) > 0 {
- performer.SupportedScrapes = append(performer.SupportedScrapes, models.ScrapeTypeURL)
+ performer.SupportedScrapes = append(performer.SupportedScrapes, ScrapeTypeURL)
for _, v := range c.PerformerByURL {
performer.Urls = append(performer.Urls, v.URL...)
}
@@ -257,15 +256,15 @@ func (c config) spec() models.Scraper {
ret.Performer = &performer
}
- scene := models.ScraperSpec{}
+ scene := ScraperSpec{}
if c.SceneByFragment != nil {
- scene.SupportedScrapes = append(scene.SupportedScrapes, models.ScrapeTypeFragment)
+ scene.SupportedScrapes = append(scene.SupportedScrapes, ScrapeTypeFragment)
}
if c.SceneByName != nil && c.SceneByQueryFragment != nil {
- scene.SupportedScrapes = append(scene.SupportedScrapes, models.ScrapeTypeName)
+ scene.SupportedScrapes = append(scene.SupportedScrapes, ScrapeTypeName)
}
if len(c.SceneByURL) > 0 {
- scene.SupportedScrapes = append(scene.SupportedScrapes, models.ScrapeTypeURL)
+ scene.SupportedScrapes = append(scene.SupportedScrapes, ScrapeTypeURL)
for _, v := range c.SceneByURL {
scene.Urls = append(scene.Urls, v.URL...)
}
@@ -275,12 +274,12 @@ func (c config) spec() models.Scraper {
ret.Scene = &scene
}
- gallery := models.ScraperSpec{}
+ gallery := ScraperSpec{}
if c.GalleryByFragment != nil {
- gallery.SupportedScrapes = append(gallery.SupportedScrapes, models.ScrapeTypeFragment)
+ gallery.SupportedScrapes = append(gallery.SupportedScrapes, ScrapeTypeFragment)
}
if len(c.GalleryByURL) > 0 {
- gallery.SupportedScrapes = append(gallery.SupportedScrapes, models.ScrapeTypeURL)
+ gallery.SupportedScrapes = append(gallery.SupportedScrapes, ScrapeTypeURL)
for _, v := range c.GalleryByURL {
gallery.Urls = append(gallery.Urls, v.URL...)
}
@@ -290,9 +289,9 @@ func (c config) spec() models.Scraper {
ret.Gallery = &gallery
}
- movie := models.ScraperSpec{}
+ movie := ScraperSpec{}
if len(c.MovieByURL) > 0 {
- movie.SupportedScrapes = append(movie.SupportedScrapes, models.ScrapeTypeURL)
+ movie.SupportedScrapes = append(movie.SupportedScrapes, ScrapeTypeURL)
for _, v := range c.MovieByURL {
movie.Urls = append(movie.Urls, v.URL...)
}
@@ -305,42 +304,42 @@ func (c config) spec() models.Scraper {
return ret
}
-func (c config) supports(ty models.ScrapeContentType) bool {
+func (c config) supports(ty ScrapeContentType) bool {
switch ty {
- case models.ScrapeContentTypePerformer:
+ case ScrapeContentTypePerformer:
return c.PerformerByName != nil || c.PerformerByFragment != nil || len(c.PerformerByURL) > 0
- case models.ScrapeContentTypeScene:
+ case ScrapeContentTypeScene:
return (c.SceneByName != nil && c.SceneByQueryFragment != nil) || c.SceneByFragment != nil || len(c.SceneByURL) > 0
- case models.ScrapeContentTypeGallery:
+ case ScrapeContentTypeGallery:
return c.GalleryByFragment != nil || len(c.GalleryByURL) > 0
- case models.ScrapeContentTypeMovie:
+ case ScrapeContentTypeMovie:
return len(c.MovieByURL) > 0
}
panic("Unhandled ScrapeContentType")
}
-func (c config) matchesURL(url string, ty models.ScrapeContentType) bool {
+func (c config) matchesURL(url string, ty ScrapeContentType) bool {
switch ty {
- case models.ScrapeContentTypePerformer:
+ case ScrapeContentTypePerformer:
for _, scraper := range c.PerformerByURL {
if scraper.matchesURL(url) {
return true
}
}
- case models.ScrapeContentTypeScene:
+ case ScrapeContentTypeScene:
for _, scraper := range c.SceneByURL {
if scraper.matchesURL(url) {
return true
}
}
- case models.ScrapeContentTypeGallery:
+ case ScrapeContentTypeGallery:
for _, scraper := range c.GalleryByURL {
if scraper.matchesURL(url) {
return true
}
}
- case models.ScrapeContentTypeMovie:
+ case ScrapeContentTypeMovie:
for _, scraper := range c.MovieByURL {
if scraper.matchesURL(url) {
return true
diff --git a/pkg/scraper/gallery.go b/pkg/scraper/gallery.go
new file mode 100644
index 000000000..db2c98755
--- /dev/null
+++ b/pkg/scraper/gallery.go
@@ -0,0 +1,22 @@
+package scraper
+
+import "github.com/stashapp/stash/pkg/models"
+
+type ScrapedGallery struct {
+ Title *string `json:"title"`
+ Details *string `json:"details"`
+ URL *string `json:"url"`
+ Date *string `json:"date"`
+ Studio *models.ScrapedStudio `json:"studio"`
+ Tags []*models.ScrapedTag `json:"tags"`
+ Performers []*models.ScrapedPerformer `json:"performers"`
+}
+
+func (ScrapedGallery) IsScrapedContent() {}
+
+type ScrapedGalleryInput struct {
+ Title *string `json:"title"`
+ Details *string `json:"details"`
+ URL *string `json:"url"`
+ Date *string `json:"date"`
+}
diff --git a/pkg/scraper/group.go b/pkg/scraper/group.go
index 7a3620118..b423883c4 100644
--- a/pkg/scraper/group.go
+++ b/pkg/scraper/group.go
@@ -23,7 +23,7 @@ func newGroupScraper(c config, txnManager models.TransactionManager, globalConfi
}
}
-func (g group) spec() models.Scraper {
+func (g group) spec() Scraper {
return g.config.spec()
}
@@ -42,14 +42,14 @@ func (g group) fragmentScraper(input Input) *scraperTypeConfig {
return nil
}
-func (g group) viaFragment(ctx context.Context, client *http.Client, input Input) (models.ScrapedContent, error) {
+func (g group) viaFragment(ctx context.Context, client *http.Client, input Input) (ScrapedContent, error) {
stc := g.fragmentScraper(input)
if stc == nil {
// If there's no performer fragment scraper in the group, we try to use
// the URL scraper. Check if there's an URL in the input, and then shift
// to an URL scrape if it's present.
if input.Performer != nil && input.Performer.URL != nil && *input.Performer.URL != "" {
- return g.viaURL(ctx, client, *input.Performer.URL, models.ScrapeContentTypePerformer)
+ return g.viaURL(ctx, client, *input.Performer.URL, ScrapeContentTypePerformer)
}
return nil, ErrNotSupported
@@ -59,7 +59,7 @@ func (g group) viaFragment(ctx context.Context, client *http.Client, input Input
return s.scrapeByFragment(ctx, input)
}
-func (g group) viaScene(ctx context.Context, client *http.Client, scene *models.Scene) (*models.ScrapedScene, error) {
+func (g group) viaScene(ctx context.Context, client *http.Client, scene *models.Scene) (*ScrapedScene, error) {
if g.config.SceneByFragment == nil {
return nil, ErrNotSupported
}
@@ -68,7 +68,7 @@ func (g group) viaScene(ctx context.Context, client *http.Client, scene *models.
return s.scrapeSceneByScene(ctx, scene)
}
-func (g group) viaGallery(ctx context.Context, client *http.Client, gallery *models.Gallery) (*models.ScrapedGallery, error) {
+func (g group) viaGallery(ctx context.Context, client *http.Client, gallery *models.Gallery) (*ScrapedGallery, error) {
if g.config.GalleryByFragment == nil {
return nil, ErrNotSupported
}
@@ -77,22 +77,22 @@ func (g group) viaGallery(ctx context.Context, client *http.Client, gallery *mod
return s.scrapeGalleryByGallery(ctx, gallery)
}
-func loadUrlCandidates(c config, ty models.ScrapeContentType) []*scrapeByURLConfig {
+func loadUrlCandidates(c config, ty ScrapeContentType) []*scrapeByURLConfig {
switch ty {
- case models.ScrapeContentTypePerformer:
+ case ScrapeContentTypePerformer:
return c.PerformerByURL
- case models.ScrapeContentTypeScene:
+ case ScrapeContentTypeScene:
return c.SceneByURL
- case models.ScrapeContentTypeMovie:
+ case ScrapeContentTypeMovie:
return c.MovieByURL
- case models.ScrapeContentTypeGallery:
+ case ScrapeContentTypeGallery:
return c.GalleryByURL
}
panic("loadUrlCandidates: unreachable")
}
-func (g group) viaURL(ctx context.Context, client *http.Client, url string, ty models.ScrapeContentType) (models.ScrapedContent, error) {
+func (g group) viaURL(ctx context.Context, client *http.Client, url string, ty ScrapeContentType) (ScrapedContent, error) {
candidates := loadUrlCandidates(g.config, ty)
for _, scraper := range candidates {
if scraper.matchesURL(url) {
@@ -111,16 +111,16 @@ func (g group) viaURL(ctx context.Context, client *http.Client, url string, ty m
return nil, nil
}
-func (g group) viaName(ctx context.Context, client *http.Client, name string, ty models.ScrapeContentType) ([]models.ScrapedContent, error) {
+func (g group) viaName(ctx context.Context, client *http.Client, name string, ty ScrapeContentType) ([]ScrapedContent, error) {
switch ty {
- case models.ScrapeContentTypePerformer:
+ case ScrapeContentTypePerformer:
if g.config.PerformerByName == nil {
break
}
s := g.config.getScraper(*g.config.PerformerByName, client, g.txnManager, g.globalConf)
return s.scrapeByName(ctx, name, ty)
- case models.ScrapeContentTypeScene:
+ case ScrapeContentTypeScene:
if g.config.SceneByName == nil {
break
}
@@ -132,10 +132,10 @@ func (g group) viaName(ctx context.Context, client *http.Client, name string, ty
return nil, fmt.Errorf("%w: cannot load %v by name", ErrNotSupported, ty)
}
-func (g group) supports(ty models.ScrapeContentType) bool {
+func (g group) supports(ty ScrapeContentType) bool {
return g.config.supports(ty)
}
-func (g group) supportsURL(url string, ty models.ScrapeContentType) bool {
+func (g group) supportsURL(url string, ty ScrapeContentType) bool {
return g.config.matchesURL(url, ty)
}
diff --git a/pkg/scraper/image.go b/pkg/scraper/image.go
index 9b48f9f4f..5757bc9b3 100644
--- a/pkg/scraper/image.go
+++ b/pkg/scraper/image.go
@@ -29,7 +29,7 @@ func setPerformerImage(ctx context.Context, client *http.Client, p *models.Scrap
return nil
}
-func setSceneImage(ctx context.Context, client *http.Client, s *models.ScrapedScene, globalConfig GlobalConfig) error {
+func setSceneImage(ctx context.Context, client *http.Client, s *ScrapedScene, globalConfig GlobalConfig) error {
// don't try to get the image if it doesn't appear to be a URL
if s.Image == nil || !strings.HasPrefix(*s.Image, "http") {
// nothing to do
diff --git a/pkg/scraper/json.go b/pkg/scraper/json.go
index cca0556d0..dbcba38ef 100644
--- a/pkg/scraper/json.go
+++ b/pkg/scraper/json.go
@@ -75,7 +75,7 @@ func (s *jsonScraper) loadURL(ctx context.Context, url string) (string, error) {
return docStr, err
}
-func (s *jsonScraper) scrapeByURL(ctx context.Context, url string, ty models.ScrapeContentType) (models.ScrapedContent, error) {
+func (s *jsonScraper) scrapeByURL(ctx context.Context, url string, ty ScrapeContentType) (ScrapedContent, error) {
u := replaceURL(url, s.scraper) // allow a URL Replace for url-queries
doc, scraper, err := s.scrapeURL(ctx, u)
if err != nil {
@@ -84,20 +84,20 @@ func (s *jsonScraper) scrapeByURL(ctx context.Context, url string, ty models.Scr
q := s.getJsonQuery(doc)
switch ty {
- case models.ScrapeContentTypePerformer:
+ case ScrapeContentTypePerformer:
return scraper.scrapePerformer(ctx, q)
- case models.ScrapeContentTypeScene:
+ case ScrapeContentTypeScene:
return scraper.scrapeScene(ctx, q)
- case models.ScrapeContentTypeGallery:
+ case ScrapeContentTypeGallery:
return scraper.scrapeGallery(ctx, q)
- case models.ScrapeContentTypeMovie:
+ case ScrapeContentTypeMovie:
return scraper.scrapeMovie(ctx, q)
}
return nil, ErrNotSupported
}
-func (s *jsonScraper) scrapeByName(ctx context.Context, name string, ty models.ScrapeContentType) ([]models.ScrapedContent, error) {
+func (s *jsonScraper) scrapeByName(ctx context.Context, name string, ty ScrapeContentType) ([]ScrapedContent, error) {
scraper := s.getJsonScraper()
if scraper == nil {
@@ -121,9 +121,9 @@ func (s *jsonScraper) scrapeByName(ctx context.Context, name string, ty models.S
q := s.getJsonQuery(doc)
q.setType(SearchQuery)
- var content []models.ScrapedContent
+ var content []ScrapedContent
switch ty {
- case models.ScrapeContentTypePerformer:
+ case ScrapeContentTypePerformer:
performers, err := scraper.scrapePerformers(ctx, q)
if err != nil {
return nil, err
@@ -134,7 +134,7 @@ func (s *jsonScraper) scrapeByName(ctx context.Context, name string, ty models.S
}
return content, nil
- case models.ScrapeContentTypeScene:
+ case ScrapeContentTypeScene:
scenes, err := scraper.scrapeScenes(ctx, q)
if err != nil {
return nil, err
@@ -150,7 +150,7 @@ func (s *jsonScraper) scrapeByName(ctx context.Context, name string, ty models.S
return nil, ErrNotSupported
}
-func (s *jsonScraper) scrapeSceneByScene(ctx context.Context, scene *models.Scene) (*models.ScrapedScene, error) {
+func (s *jsonScraper) scrapeSceneByScene(ctx context.Context, scene *models.Scene) (*ScrapedScene, error) {
// construct the URL
queryURL := queryURLParametersFromScene(scene)
if s.scraper.QueryURLReplacements != nil {
@@ -174,7 +174,7 @@ func (s *jsonScraper) scrapeSceneByScene(ctx context.Context, scene *models.Scen
return scraper.scrapeScene(ctx, q)
}
-func (s *jsonScraper) scrapeByFragment(ctx context.Context, input Input) (models.ScrapedContent, error) {
+func (s *jsonScraper) scrapeByFragment(ctx context.Context, input Input) (ScrapedContent, error) {
switch {
case input.Gallery != nil:
return nil, fmt.Errorf("%w: cannot use a json scraper as a gallery fragment scraper", ErrNotSupported)
@@ -209,7 +209,7 @@ func (s *jsonScraper) scrapeByFragment(ctx context.Context, input Input) (models
return scraper.scrapeScene(ctx, q)
}
-func (s *jsonScraper) scrapeGalleryByGallery(ctx context.Context, gallery *models.Gallery) (*models.ScrapedGallery, error) {
+func (s *jsonScraper) scrapeGalleryByGallery(ctx context.Context, gallery *models.Gallery) (*ScrapedGallery, error) {
// construct the URL
queryURL := queryURLParametersFromGallery(gallery)
if s.scraper.QueryURLReplacements != nil {
diff --git a/pkg/scraper/mapped.go b/pkg/scraper/mapped.go
index 4753bb182..e10d1ed65 100644
--- a/pkg/scraper/mapped.go
+++ b/pkg/scraper/mapped.go
@@ -809,8 +809,8 @@ func (s mappedScraper) scrapePerformers(ctx context.Context, q mappedQuery) ([]*
return ret, nil
}
-func (s mappedScraper) processScene(ctx context.Context, q mappedQuery, r mappedResult) *models.ScrapedScene {
- var ret models.ScrapedScene
+func (s mappedScraper) processScene(ctx context.Context, q mappedQuery, r mappedResult) *ScrapedScene {
+ var ret ScrapedScene
sceneScraperConfig := s.Scene
@@ -884,8 +884,8 @@ func (s mappedScraper) processScene(ctx context.Context, q mappedQuery, r mapped
return &ret
}
-func (s mappedScraper) scrapeScenes(ctx context.Context, q mappedQuery) ([]*models.ScrapedScene, error) {
- var ret []*models.ScrapedScene
+func (s mappedScraper) scrapeScenes(ctx context.Context, q mappedQuery) ([]*ScrapedScene, error) {
+ var ret []*ScrapedScene
sceneScraperConfig := s.Scene
sceneMap := sceneScraperConfig.mappedConfig
@@ -903,8 +903,8 @@ func (s mappedScraper) scrapeScenes(ctx context.Context, q mappedQuery) ([]*mode
return ret, nil
}
-func (s mappedScraper) scrapeScene(ctx context.Context, q mappedQuery) (*models.ScrapedScene, error) {
- var ret *models.ScrapedScene
+func (s mappedScraper) scrapeScene(ctx context.Context, q mappedQuery) (*ScrapedScene, error) {
+ var ret *ScrapedScene
sceneScraperConfig := s.Scene
sceneMap := sceneScraperConfig.mappedConfig
@@ -921,8 +921,8 @@ func (s mappedScraper) scrapeScene(ctx context.Context, q mappedQuery) (*models.
return ret, nil
}
-func (s mappedScraper) scrapeGallery(ctx context.Context, q mappedQuery) (*models.ScrapedGallery, error) {
- var ret *models.ScrapedGallery
+func (s mappedScraper) scrapeGallery(ctx context.Context, q mappedQuery) (*ScrapedGallery, error) {
+ var ret *ScrapedGallery
galleryScraperConfig := s.Gallery
galleryMap := galleryScraperConfig.mappedConfig
@@ -937,7 +937,7 @@ func (s mappedScraper) scrapeGallery(ctx context.Context, q mappedQuery) (*model
logger.Debug(`Processing gallery:`)
results := galleryMap.process(ctx, q, s.Common)
if len(results) > 0 {
- ret = &models.ScrapedGallery{}
+ ret = &ScrapedGallery{}
results[0].apply(ret)
diff --git a/pkg/scraper/movie.go b/pkg/scraper/movie.go
new file mode 100644
index 000000000..4416b6199
--- /dev/null
+++ b/pkg/scraper/movie.go
@@ -0,0 +1,12 @@
+package scraper
+
+type ScrapedMovieInput struct {
+ Name *string `json:"name"`
+ Aliases *string `json:"aliases"`
+ Duration *string `json:"duration"`
+ Date *string `json:"date"`
+ Rating *string `json:"rating"`
+ Director *string `json:"director"`
+ URL *string `json:"url"`
+ Synopsis *string `json:"synopsis"`
+}
diff --git a/pkg/scraper/performer.go b/pkg/scraper/performer.go
new file mode 100644
index 000000000..f97250736
--- /dev/null
+++ b/pkg/scraper/performer.go
@@ -0,0 +1,27 @@
+package scraper
+
+type ScrapedPerformerInput struct {
+ // Set if performer matched
+ StoredID *string `json:"stored_id"`
+ Name *string `json:"name"`
+ Gender *string `json:"gender"`
+ URL *string `json:"url"`
+ Twitter *string `json:"twitter"`
+ Instagram *string `json:"instagram"`
+ Birthdate *string `json:"birthdate"`
+ Ethnicity *string `json:"ethnicity"`
+ Country *string `json:"country"`
+ EyeColor *string `json:"eye_color"`
+ Height *string `json:"height"`
+ Measurements *string `json:"measurements"`
+ FakeTits *string `json:"fake_tits"`
+ CareerLength *string `json:"career_length"`
+ Tattoos *string `json:"tattoos"`
+ Piercings *string `json:"piercings"`
+ Aliases *string `json:"aliases"`
+ Details *string `json:"details"`
+ DeathDate *string `json:"death_date"`
+ HairColor *string `json:"hair_color"`
+ Weight *string `json:"weight"`
+ RemoteSiteID *string `json:"remote_site_id"`
+}
diff --git a/pkg/scraper/postprocessing.go b/pkg/scraper/postprocessing.go
index 731769310..ded3bc816 100644
--- a/pkg/scraper/postprocessing.go
+++ b/pkg/scraper/postprocessing.go
@@ -11,7 +11,7 @@ import (
// postScrape handles post-processing of scraped content. If the content
// requires post-processing, this function fans out to the given content
// type and post-processes it.
-func (c Cache) postScrape(ctx context.Context, content models.ScrapedContent) (models.ScrapedContent, error) {
+func (c Cache) postScrape(ctx context.Context, content ScrapedContent) (ScrapedContent, error) {
// Analyze the concrete type, call the right post-processing function
switch v := content.(type) {
case *models.ScrapedPerformer:
@@ -20,17 +20,17 @@ func (c Cache) postScrape(ctx context.Context, content models.ScrapedContent) (m
}
case models.ScrapedPerformer:
return c.postScrapePerformer(ctx, v)
- case *models.ScrapedScene:
+ case *ScrapedScene:
if v != nil {
return c.postScrapeScene(ctx, *v)
}
- case models.ScrapedScene:
+ case ScrapedScene:
return c.postScrapeScene(ctx, v)
- case *models.ScrapedGallery:
+ case *ScrapedGallery:
if v != nil {
return c.postScrapeGallery(ctx, *v)
}
- case models.ScrapedGallery:
+ case ScrapedGallery:
return c.postScrapeGallery(ctx, v)
case *models.ScrapedMovie:
if v != nil {
@@ -44,7 +44,7 @@ func (c Cache) postScrape(ctx context.Context, content models.ScrapedContent) (m
return content, nil
}
-func (c Cache) postScrapePerformer(ctx context.Context, p models.ScrapedPerformer) (models.ScrapedContent, error) {
+func (c Cache) postScrapePerformer(ctx context.Context, p models.ScrapedPerformer) (ScrapedContent, error) {
if err := c.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
tqb := r.Tag()
@@ -67,7 +67,7 @@ func (c Cache) postScrapePerformer(ctx context.Context, p models.ScrapedPerforme
return p, nil
}
-func (c Cache) postScrapeMovie(ctx context.Context, m models.ScrapedMovie) (models.ScrapedContent, error) {
+func (c Cache) postScrapeMovie(ctx context.Context, m models.ScrapedMovie) (ScrapedContent, error) {
if m.Studio != nil {
if err := c.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
return match.ScrapedStudio(r.Studio(), m.Studio, nil)
@@ -105,7 +105,7 @@ func (c Cache) postScrapeScenePerformer(ctx context.Context, p models.ScrapedPer
return nil
}
-func (c Cache) postScrapeScene(ctx context.Context, scene models.ScrapedScene) (models.ScrapedContent, error) {
+func (c Cache) postScrapeScene(ctx context.Context, scene ScrapedScene) (ScrapedContent, error) {
if err := c.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
pqb := r.Performer()
mqb := r.Movie()
@@ -159,7 +159,7 @@ func (c Cache) postScrapeScene(ctx context.Context, scene models.ScrapedScene) (
return scene, nil
}
-func (c Cache) postScrapeGallery(ctx context.Context, g models.ScrapedGallery) (models.ScrapedContent, error) {
+func (c Cache) postScrapeGallery(ctx context.Context, g ScrapedGallery) (ScrapedContent, error) {
if err := c.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
pqb := r.Performer()
tqb := r.Tag()
diff --git a/pkg/scraper/query_url.go b/pkg/scraper/query_url.go
index 2826e15e4..398e86f9d 100644
--- a/pkg/scraper/query_url.go
+++ b/pkg/scraper/query_url.go
@@ -21,7 +21,7 @@ func queryURLParametersFromScene(scene *models.Scene) queryURLParameters {
return ret
}
-func queryURLParametersFromScrapedScene(scene models.ScrapedSceneInput) queryURLParameters {
+func queryURLParametersFromScrapedScene(scene ScrapedSceneInput) queryURLParameters {
ret := make(queryURLParameters)
setField := func(field string, value *string) {
diff --git a/pkg/scraper/scene.go b/pkg/scraper/scene.go
new file mode 100644
index 000000000..9b5a60191
--- /dev/null
+++ b/pkg/scraper/scene.go
@@ -0,0 +1,32 @@
+package scraper
+
+import (
+ "github.com/stashapp/stash/pkg/models"
+)
+
+type ScrapedScene struct {
+ Title *string `json:"title"`
+ Details *string `json:"details"`
+ URL *string `json:"url"`
+ Date *string `json:"date"`
+ // This should be a base64 encoded data URL
+ Image *string `json:"image"`
+ File *models.SceneFileType `json:"file"`
+ Studio *models.ScrapedStudio `json:"studio"`
+ Tags []*models.ScrapedTag `json:"tags"`
+ Performers []*models.ScrapedPerformer `json:"performers"`
+ Movies []*models.ScrapedMovie `json:"movies"`
+ RemoteSiteID *string `json:"remote_site_id"`
+ Duration *int `json:"duration"`
+ Fingerprints []*models.StashBoxFingerprint `json:"fingerprints"`
+}
+
+func (ScrapedScene) IsScrapedContent() {}
+
+type ScrapedSceneInput struct {
+ Title *string `json:"title"`
+ Details *string `json:"details"`
+ URL *string `json:"url"`
+ Date *string `json:"date"`
+ RemoteSiteID *string `json:"remote_site_id"`
+}
diff --git a/pkg/scraper/scraper.go b/pkg/scraper/scraper.go
index 3a2fcc054..67569c18b 100644
--- a/pkg/scraper/scraper.go
+++ b/pkg/scraper/scraper.go
@@ -3,11 +3,139 @@ package scraper
import (
"context"
"errors"
+ "fmt"
+ "io"
"net/http"
+ "strconv"
"github.com/stashapp/stash/pkg/models"
)
+type Source struct {
+ // Index of the configured stash-box instance to use. Should be unset if scraper_id is set
+ StashBoxIndex *int `json:"stash_box_index"`
+ // Stash-box endpoint
+ StashBoxEndpoint *string `json:"stash_box_endpoint"`
+ // Scraper ID to scrape with. Should be unset if stash_box_index is set
+ ScraperID *string `json:"scraper_id"`
+}
+
+// Scraped Content is the forming union over the different scrapers
+type ScrapedContent interface {
+ IsScrapedContent()
+}
+
+// Type of the content a scraper generates
+type ScrapeContentType string
+
+const (
+ ScrapeContentTypeGallery ScrapeContentType = "GALLERY"
+ ScrapeContentTypeMovie ScrapeContentType = "MOVIE"
+ ScrapeContentTypePerformer ScrapeContentType = "PERFORMER"
+ ScrapeContentTypeScene ScrapeContentType = "SCENE"
+)
+
+var AllScrapeContentType = []ScrapeContentType{
+ ScrapeContentTypeGallery,
+ ScrapeContentTypeMovie,
+ ScrapeContentTypePerformer,
+ ScrapeContentTypeScene,
+}
+
+func (e ScrapeContentType) IsValid() bool {
+ switch e {
+ case ScrapeContentTypeGallery, ScrapeContentTypeMovie, ScrapeContentTypePerformer, ScrapeContentTypeScene:
+ return true
+ }
+ return false
+}
+
+func (e ScrapeContentType) String() string {
+ return string(e)
+}
+
+func (e *ScrapeContentType) UnmarshalGQL(v interface{}) error {
+ str, ok := v.(string)
+ if !ok {
+ return fmt.Errorf("enums must be strings")
+ }
+
+ *e = ScrapeContentType(str)
+ if !e.IsValid() {
+ return fmt.Errorf("%s is not a valid ScrapeContentType", str)
+ }
+ return nil
+}
+
+func (e ScrapeContentType) MarshalGQL(w io.Writer) {
+ fmt.Fprint(w, strconv.Quote(e.String()))
+}
+
+type Scraper struct {
+ ID string `json:"id"`
+ Name string `json:"name"`
+ // Details for performer scraper
+ Performer *ScraperSpec `json:"performer"`
+ // Details for scene scraper
+ Scene *ScraperSpec `json:"scene"`
+ // Details for gallery scraper
+ Gallery *ScraperSpec `json:"gallery"`
+ // Details for movie scraper
+ Movie *ScraperSpec `json:"movie"`
+}
+
+type ScraperSpec struct {
+ // URLs matching these can be scraped with
+ Urls []string `json:"urls"`
+ SupportedScrapes []ScrapeType `json:"supported_scrapes"`
+}
+
+type ScrapeType string
+
+const (
+ // From text query
+ ScrapeTypeName ScrapeType = "NAME"
+ // From existing object
+ ScrapeTypeFragment ScrapeType = "FRAGMENT"
+ // From URL
+ ScrapeTypeURL ScrapeType = "URL"
+)
+
+var AllScrapeType = []ScrapeType{
+ ScrapeTypeName,
+ ScrapeTypeFragment,
+ ScrapeTypeURL,
+}
+
+func (e ScrapeType) IsValid() bool {
+ switch e {
+ case ScrapeTypeName, ScrapeTypeFragment, ScrapeTypeURL:
+ return true
+ }
+ return false
+}
+
+func (e ScrapeType) String() string {
+ return string(e)
+}
+
+func (e *ScrapeType) UnmarshalGQL(v interface{}) error {
+ str, ok := v.(string)
+ if !ok {
+ return fmt.Errorf("enums must be strings")
+ }
+
+ *e = ScrapeType(str)
+ if !e.IsValid() {
+ return fmt.Errorf("%s is not a valid ScrapeType", str)
+ }
+ return nil
+}
+
+func (e ScrapeType) MarshalGQL(w io.Writer) {
+ fmt.Fprint(w, strconv.Quote(e.String()))
+}
+
var (
// ErrMaxRedirects is returned if the max number of HTTP redirects are reached.
ErrMaxRedirects = errors.New("maximum number of HTTP redirects reached")
@@ -24,9 +152,9 @@ var (
// The system expects one of these to be set, and the remaining to be
// set to nil.
type Input struct {
- Performer *models.ScrapedPerformerInput
- Scene *models.ScrapedSceneInput
- Gallery *models.ScrapedGalleryInput
+ Performer *ScrapedPerformerInput
+ Scene *ScrapedSceneInput
+ Gallery *ScrapedGalleryInput
}
// simple type definitions that can help customize
@@ -41,32 +169,32 @@ const (
// scraper is the generic interface to the scraper subsystems
type scraper interface {
// spec returns the scraper specification, suitable for graphql
- spec() models.Scraper
+ spec() Scraper
// supports tests if the scraper supports a given content type
- supports(models.ScrapeContentType) bool
+ supports(ScrapeContentType) bool
// supportsURL tests if the scraper supports scrapes of a given url, producing a given content type
- supportsURL(url string, ty models.ScrapeContentType) bool
+ supportsURL(url string, ty ScrapeContentType) bool
}
// urlScraper is the interface of scrapers supporting url loads
type urlScraper interface {
scraper
- viaURL(ctx context.Context, client *http.Client, url string, ty models.ScrapeContentType) (models.ScrapedContent, error)
+ viaURL(ctx context.Context, client *http.Client, url string, ty ScrapeContentType) (ScrapedContent, error)
}
// nameScraper is the interface of scrapers supporting name loads
type nameScraper interface {
scraper
- viaName(ctx context.Context, client *http.Client, name string, ty models.ScrapeContentType) ([]models.ScrapedContent, error)
+ viaName(ctx context.Context, client *http.Client, name string, ty ScrapeContentType) ([]ScrapedContent, error)
}
// fragmentScraper is the interface of scrapers supporting fragment loads
type fragmentScraper interface {
scraper
- viaFragment(ctx context.Context, client *http.Client, input Input) (models.ScrapedContent, error)
+ viaFragment(ctx context.Context, client *http.Client, input Input) (ScrapedContent, error)
}
// sceneScraper is a scraper which supports scene scrapes with
@@ -74,7 +202,7 @@ type fragmentScraper interface {
type sceneScraper interface {
scraper
- viaScene(ctx context.Context, client *http.Client, scene *models.Scene) (*models.ScrapedScene, error)
+ viaScene(ctx context.Context, client *http.Client, scene *models.Scene) (*ScrapedScene, error)
}
// galleryScraper is a scraper which supports gallery scrapes with
@@ -82,5 +210,5 @@ type sceneScraper interface {
type galleryScraper interface {
scraper
- viaGallery(ctx context.Context, client *http.Client, gallery *models.Gallery) (*models.ScrapedGallery, error)
+ viaGallery(ctx context.Context, client *http.Client, gallery *models.Gallery) (*ScrapedGallery, error)
}
diff --git a/pkg/scraper/script.go b/pkg/scraper/script.go
index 4a9074fc2..5d6e25acf 100644
--- a/pkg/scraper/script.go
+++ b/pkg/scraper/script.go
@@ -125,13 +125,13 @@ func (s *scriptScraper) runScraperScript(ctx context.Context, inString string, o
return nil
}
-func (s *scriptScraper) scrapeByName(ctx context.Context, name string, ty models.ScrapeContentType) ([]models.ScrapedContent, error) {
+func (s *scriptScraper) scrapeByName(ctx context.Context, name string, ty ScrapeContentType) ([]ScrapedContent, error) {
input := `{"name": "` + name + `"}`
- var ret []models.ScrapedContent
+ var ret []ScrapedContent
var err error
switch ty {
- case models.ScrapeContentTypePerformer:
+ case ScrapeContentTypePerformer:
var performers []models.ScrapedPerformer
err = s.runScraperScript(ctx, input, &performers)
if err == nil {
@@ -140,8 +140,8 @@ func (s *scriptScraper) scrapeByName(ctx context.Context, name string, ty models
ret = append(ret, &v)
}
}
- case models.ScrapeContentTypeScene:
- var scenes []models.ScrapedScene
+ case ScrapeContentTypeScene:
+ var scenes []ScrapedScene
err = s.runScraperScript(ctx, input, &scenes)
if err == nil {
for _, s := range scenes {
@@ -156,20 +156,20 @@ func (s *scriptScraper) scrapeByName(ctx context.Context, name string, ty models
return ret, err
}
-func (s *scriptScraper) scrapeByFragment(ctx context.Context, input Input) (models.ScrapedContent, error) {
+func (s *scriptScraper) scrapeByFragment(ctx context.Context, input Input) (ScrapedContent, error) {
var inString []byte
var err error
- var ty models.ScrapeContentType
+ var ty ScrapeContentType
switch {
case input.Performer != nil:
inString, err = json.Marshal(*input.Performer)
- ty = models.ScrapeContentTypePerformer
+ ty = ScrapeContentTypePerformer
case input.Gallery != nil:
inString, err = json.Marshal(*input.Gallery)
- ty = models.ScrapeContentTypeGallery
+ ty = ScrapeContentTypeGallery
case input.Scene != nil:
inString, err = json.Marshal(*input.Scene)
- ty = models.ScrapeContentTypeScene
+ ty = ScrapeContentTypeScene
}
if err != nil {
@@ -179,25 +179,25 @@ func (s *scriptScraper) scrapeByFragment(ctx context.Context, input Input) (mode
return s.scrape(ctx, string(inString), ty)
}
-func (s *scriptScraper) scrapeByURL(ctx context.Context, url string, ty models.ScrapeContentType) (models.ScrapedContent, error) {
+func (s *scriptScraper) scrapeByURL(ctx context.Context, url string, ty ScrapeContentType) (ScrapedContent, error) {
return s.scrape(ctx, `{"url": "`+url+`"}`, ty)
}
-func (s *scriptScraper) scrape(ctx context.Context, input string, ty models.ScrapeContentType) (models.ScrapedContent, error) {
+func (s *scriptScraper) scrape(ctx context.Context, input string, ty ScrapeContentType) (ScrapedContent, error) {
switch ty {
- case models.ScrapeContentTypePerformer:
+ case ScrapeContentTypePerformer:
var performer *models.ScrapedPerformer
err := s.runScraperScript(ctx, input, &performer)
return performer, err
- case models.ScrapeContentTypeGallery:
- var gallery *models.ScrapedGallery
+ case ScrapeContentTypeGallery:
+ var gallery *ScrapedGallery
err := s.runScraperScript(ctx, input, &gallery)
return gallery, err
- case models.ScrapeContentTypeScene:
- var scene *models.ScrapedScene
+ case ScrapeContentTypeScene:
+ var scene *ScrapedScene
err := s.runScraperScript(ctx, input, &scene)
return scene, err
- case models.ScrapeContentTypeMovie:
+ case ScrapeContentTypeMovie:
var movie *models.ScrapedMovie
err := s.runScraperScript(ctx, input, &movie)
return movie, err
@@ -206,28 +206,28 @@ func (s *scriptScraper) scrape(ctx context.Context, input string, ty models.Scra
return nil, ErrNotSupported
}
-func (s *scriptScraper) scrapeSceneByScene(ctx context.Context, scene *models.Scene) (*models.ScrapedScene, error) {
+func (s *scriptScraper) scrapeSceneByScene(ctx context.Context, scene *models.Scene) (*ScrapedScene, error) {
inString, err := json.Marshal(sceneToUpdateInput(scene))
if err != nil {
return nil, err
}
- var ret *models.ScrapedScene
+ var ret *ScrapedScene
err = s.runScraperScript(ctx, string(inString), &ret)
return ret, err
}
-func (s *scriptScraper) scrapeGalleryByGallery(ctx context.Context, gallery *models.Gallery) (*models.ScrapedGallery, error) {
+func (s *scriptScraper) scrapeGalleryByGallery(ctx context.Context, gallery *models.Gallery) (*ScrapedGallery, error) {
inString, err := json.Marshal(galleryToUpdateInput(gallery))
if err != nil {
return nil, err
}
- var ret *models.ScrapedGallery
+ var ret *ScrapedGallery
err = s.runScraperScript(ctx, string(inString), &ret)
diff --git a/pkg/scraper/stash.go b/pkg/scraper/stash.go
index 8193f2a67..b6e0e7696 100644
--- a/pkg/scraper/stash.go
+++ b/pkg/scraper/stash.go
@@ -83,7 +83,7 @@ type scrapedPerformerStash struct {
Weight *string `graphql:"weight" json:"weight"`
}
-func (s *stashScraper) scrapeByFragment(ctx context.Context, input Input) (models.ScrapedContent, error) {
+func (s *stashScraper) scrapeByFragment(ctx context.Context, input Input) (ScrapedContent, error) {
if input.Gallery != nil || input.Scene != nil {
return nil, fmt.Errorf("%w: using stash scraper as a fragment scraper", ErrNotSupported)
}
@@ -138,8 +138,8 @@ type stashFindSceneNamesResultType struct {
Scenes []*scrapedSceneStash `graphql:"scenes"`
}
-func (s *stashScraper) scrapedStashSceneToScrapedScene(ctx context.Context, scene *scrapedSceneStash) (*models.ScrapedScene, error) {
- ret := models.ScrapedScene{}
+func (s *stashScraper) scrapedStashSceneToScrapedScene(ctx context.Context, scene *scrapedSceneStash) (*ScrapedScene, error) {
+ ret := ScrapedScene{}
err := copier.Copy(&ret, scene)
if err != nil {
return nil, err
@@ -154,7 +154,7 @@ func (s *stashScraper) scrapedStashSceneToScrapedScene(ctx context.Context, scen
return &ret, nil
}
-func (s *stashScraper) scrapeByName(ctx context.Context, name string, ty models.ScrapeContentType) ([]models.ScrapedContent, error) {
+func (s *stashScraper) scrapeByName(ctx context.Context, name string, ty ScrapeContentType) ([]ScrapedContent, error) {
client := s.getStashClient()
page := 1
@@ -168,9 +168,9 @@ func (s *stashScraper) scrapeByName(ctx context.Context, name string, ty models.
},
}
- var ret []models.ScrapedContent
+ var ret []ScrapedContent
switch ty {
- case models.ScrapeContentTypeScene:
+ case ScrapeContentTypeScene:
var q struct {
FindScenes stashFindSceneNamesResultType `graphql:"findScenes(filter: $f)"`
}
@@ -189,7 +189,7 @@ func (s *stashScraper) scrapeByName(ctx context.Context, name string, ty models.
}
return ret, nil
- case models.ScrapeContentTypePerformer:
+ case ScrapeContentTypePerformer:
var q struct {
FindPerformers stashFindPerformerNamesResultType `graphql:"findPerformers(filter: $f)"`
}
@@ -221,7 +221,7 @@ type scrapedSceneStash struct {
Performers []*scrapedPerformerStash `graphql:"performers" json:"performers"`
}
-func (s *stashScraper) scrapeSceneByScene(ctx context.Context, scene *models.Scene) (*models.ScrapedScene, error) {
+func (s *stashScraper) scrapeSceneByScene(ctx context.Context, scene *models.Scene) (*ScrapedScene, error) {
// query by MD5
var q struct {
FindScene *scrapedSceneStash `graphql:"findSceneByHash(input: $c)"`
@@ -273,7 +273,7 @@ type scrapedGalleryStash struct {
Performers []*scrapedPerformerStash `graphql:"performers" json:"performers"`
}
-func (s *stashScraper) scrapeGalleryByGallery(ctx context.Context, gallery *models.Gallery) (*models.ScrapedGallery, error) {
+func (s *stashScraper) scrapeGalleryByGallery(ctx context.Context, gallery *models.Gallery) (*ScrapedGallery, error) {
var q struct {
FindGallery *scrapedGalleryStash `graphql:"findGalleryByHash(input: $c)"`
}
@@ -296,7 +296,7 @@ func (s *stashScraper) scrapeGalleryByGallery(ctx context.Context, gallery *mode
}
// need to copy back to a scraped scene
- ret := models.ScrapedGallery{}
+ ret := ScrapedGallery{}
if err := copier.Copy(&ret, q.FindGallery); err != nil {
return nil, err
}
@@ -304,7 +304,7 @@ func (s *stashScraper) scrapeGalleryByGallery(ctx context.Context, gallery *mode
return &ret, nil
}
-func (s *stashScraper) scrapeByURL(_ context.Context, _ string, _ models.ScrapeContentType) (models.ScrapedContent, error) {
+func (s *stashScraper) scrapeByURL(_ context.Context, _ string, _ ScrapeContentType) (ScrapedContent, error) {
return nil, ErrNotSupported
}
diff --git a/pkg/scraper/stashbox/models.go b/pkg/scraper/stashbox/models.go
new file mode 100644
index 000000000..60ef5b028
--- /dev/null
+++ b/pkg/scraper/stashbox/models.go
@@ -0,0 +1,8 @@
+package stashbox
+
+import "github.com/stashapp/stash/pkg/models"
+
+type StashBoxPerformerQueryResult struct {
+ Query string `json:"query"`
+ Results []*models.ScrapedPerformer `json:"results"`
+}
diff --git a/pkg/scraper/stashbox/stash_box.go b/pkg/scraper/stashbox/stash_box.go
index abcadc0f2..923cd4482 100644
--- a/pkg/scraper/stashbox/stash_box.go
+++ b/pkg/scraper/stashbox/stash_box.go
@@ -21,6 +21,7 @@ import (
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/match"
"github.com/stashapp/stash/pkg/models"
+ "github.com/stashapp/stash/pkg/scraper"
"github.com/stashapp/stash/pkg/scraper/stashbox/graphql"
"github.com/stashapp/stash/pkg/sliceutil/stringslice"
"github.com/stashapp/stash/pkg/utils"
@@ -55,7 +56,7 @@ func (c Client) getHTTPClient() *http.Client {
}
// QueryStashBoxScene queries stash-box for scenes using a query string.
-func (c Client) QueryStashBoxScene(ctx context.Context, queryStr string) ([]*models.ScrapedScene, error) {
+func (c Client) QueryStashBoxScene(ctx context.Context, queryStr string) ([]*scraper.ScrapedScene, error) {
scenes, err := c.client.SearchScene(ctx, queryStr)
if err != nil {
return nil, err
@@ -63,7 +64,7 @@ func (c Client) QueryStashBoxScene(ctx context.Context, queryStr string) ([]*mod
sceneFragments := scenes.SearchScene
- var ret []*models.ScrapedScene
+ var ret []*scraper.ScrapedScene
for _, s := range sceneFragments {
ss, err := c.sceneFragmentToScrapedScene(ctx, s)
if err != nil {
@@ -77,7 +78,7 @@ func (c Client) QueryStashBoxScene(ctx context.Context, queryStr string) ([]*mod
// FindStashBoxScenesByFingerprints queries stash-box for a scene using the
// scene's MD5/OSHASH checksum, or PHash.
-func (c Client) FindStashBoxSceneByFingerprints(ctx context.Context, sceneID int) ([]*models.ScrapedScene, error) {
+func (c Client) FindStashBoxSceneByFingerprints(ctx context.Context, sceneID int) ([]*scraper.ScrapedScene, error) {
res, err := c.FindStashBoxScenesByFingerprints(ctx, []int{sceneID})
if len(res) > 0 {
return res[0], err
@@ -88,7 +89,7 @@ func (c Client) FindStashBoxSceneByFingerprints(ctx context.Context, sceneID int
// FindStashBoxScenesByFingerprints queries stash-box for scenes using every
// scene's MD5/OSHASH checksum, or PHash, and returns results in the same order
// as the input slice.
-func (c Client) FindStashBoxScenesByFingerprints(ctx context.Context, ids []int) ([][]*models.ScrapedScene, error) {
+func (c Client) FindStashBoxScenesByFingerprints(ctx context.Context, ids []int) ([][]*scraper.ScrapedScene, error) {
var fingerprints [][]*graphql.FingerprintQueryInput
if err := c.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
@@ -139,8 +140,8 @@ func (c Client) FindStashBoxScenesByFingerprints(ctx context.Context, ids []int)
return c.findStashBoxScenesByFingerprints(ctx, fingerprints)
}
-func (c Client) findStashBoxScenesByFingerprints(ctx context.Context, scenes [][]*graphql.FingerprintQueryInput) ([][]*models.ScrapedScene, error) {
- var ret [][]*models.ScrapedScene
+func (c Client) findStashBoxScenesByFingerprints(ctx context.Context, scenes [][]*graphql.FingerprintQueryInput) ([][]*scraper.ScrapedScene, error) {
+ var ret [][]*scraper.ScrapedScene
for i := 0; i < len(scenes); i += 40 {
end := i + 40
if end > len(scenes) {
@@ -153,7 +154,7 @@ func (c Client) findStashBoxScenesByFingerprints(ctx context.Context, scenes [][
}
for _, sceneFragments := range scenes.FindScenesBySceneFingerprints {
- var sceneResults []*models.ScrapedScene
+ var sceneResults []*scraper.ScrapedScene
for _, scene := range sceneFragments {
ss, err := c.sceneFragmentToScrapedScene(ctx, scene)
if err != nil {
@@ -260,10 +261,10 @@ func (c Client) submitStashBoxFingerprints(ctx context.Context, fingerprints []g
}
// QueryStashBoxPerformer queries stash-box for performers using a query string.
-func (c Client) QueryStashBoxPerformer(ctx context.Context, queryStr string) ([]*models.StashBoxPerformerQueryResult, error) {
+func (c Client) QueryStashBoxPerformer(ctx context.Context, queryStr string) ([]*StashBoxPerformerQueryResult, error) {
performers, err := c.queryStashBoxPerformer(ctx, queryStr)
- res := []*models.StashBoxPerformerQueryResult{
+ res := []*StashBoxPerformerQueryResult{
{
Query: queryStr,
Results: performers,
@@ -298,7 +299,7 @@ func (c Client) queryStashBoxPerformer(ctx context.Context, queryStr string) ([]
}
// FindStashBoxPerformersByNames queries stash-box for performers by name
-func (c Client) FindStashBoxPerformersByNames(ctx context.Context, performerIDs []string) ([]*models.StashBoxPerformerQueryResult, error) {
+func (c Client) FindStashBoxPerformersByNames(ctx context.Context, performerIDs []string) ([]*StashBoxPerformerQueryResult, error) {
ids, err := stringslice.StringSliceToIntSlice(performerIDs)
if err != nil {
return nil, err
@@ -376,8 +377,8 @@ func (c Client) FindStashBoxPerformersByPerformerNames(ctx context.Context, perf
return ret, nil
}
-func (c Client) findStashBoxPerformersByNames(ctx context.Context, performers []*models.Performer) ([]*models.StashBoxPerformerQueryResult, error) {
- var ret []*models.StashBoxPerformerQueryResult
+func (c Client) findStashBoxPerformersByNames(ctx context.Context, performers []*models.Performer) ([]*StashBoxPerformerQueryResult, error) {
+ var ret []*StashBoxPerformerQueryResult
for _, performer := range performers {
if performer.Name.Valid {
performerResults, err := c.queryStashBoxPerformer(ctx, performer.Name.String)
@@ -385,7 +386,7 @@ func (c Client) findStashBoxPerformersByNames(ctx context.Context, performers []
return nil, err
}
- result := models.StashBoxPerformerQueryResult{
+ result := StashBoxPerformerQueryResult{
Query: strconv.Itoa(performer.ID),
Results: performerResults,
}
@@ -601,9 +602,9 @@ func getFingerprints(scene *graphql.SceneFragment) []*models.StashBoxFingerprint
return fingerprints
}
-func (c Client) sceneFragmentToScrapedScene(ctx context.Context, s *graphql.SceneFragment) (*models.ScrapedScene, error) {
+func (c Client) sceneFragmentToScrapedScene(ctx context.Context, s *graphql.SceneFragment) (*scraper.ScrapedScene, error) {
stashID := s.ID
- ss := &models.ScrapedScene{
+ ss := &scraper.ScrapedScene{
Title: s.Title,
Date: s.Date,
Details: s.Details,
diff --git a/pkg/scraper/xpath.go b/pkg/scraper/xpath.go
index 79300d30b..092968b5e 100644
--- a/pkg/scraper/xpath.go
+++ b/pkg/scraper/xpath.go
@@ -56,7 +56,7 @@ func (s *xpathScraper) scrapeURL(ctx context.Context, url string) (*html.Node, *
return doc, scraper, nil
}
-func (s *xpathScraper) scrapeByURL(ctx context.Context, url string, ty models.ScrapeContentType) (models.ScrapedContent, error) {
+func (s *xpathScraper) scrapeByURL(ctx context.Context, url string, ty ScrapeContentType) (ScrapedContent, error) {
u := replaceURL(url, s.scraper) // allow a URL Replace for performer by URL queries
doc, scraper, err := s.scrapeURL(ctx, u)
if err != nil {
@@ -65,20 +65,20 @@ func (s *xpathScraper) scrapeByURL(ctx context.Context, url string, ty models.Sc
q := s.getXPathQuery(doc)
switch ty {
- case models.ScrapeContentTypePerformer:
+ case ScrapeContentTypePerformer:
return scraper.scrapePerformer(ctx, q)
- case models.ScrapeContentTypeScene:
+ case ScrapeContentTypeScene:
return scraper.scrapeScene(ctx, q)
- case models.ScrapeContentTypeGallery:
+ case ScrapeContentTypeGallery:
return scraper.scrapeGallery(ctx, q)
- case models.ScrapeContentTypeMovie:
+ case ScrapeContentTypeMovie:
return scraper.scrapeMovie(ctx, q)
}
return nil, ErrNotSupported
}
-func (s *xpathScraper) scrapeByName(ctx context.Context, name string, ty models.ScrapeContentType) ([]models.ScrapedContent, error) {
+func (s *xpathScraper) scrapeByName(ctx context.Context, name string, ty ScrapeContentType) ([]ScrapedContent, error) {
scraper := s.getXpathScraper()
if scraper == nil {
@@ -102,9 +102,9 @@ func (s *xpathScraper) scrapeByName(ctx context.Context, name string, ty models.
q := s.getXPathQuery(doc)
q.setType(SearchQuery)
- var content []models.ScrapedContent
+ var content []ScrapedContent
switch ty {
- case models.ScrapeContentTypePerformer:
+ case ScrapeContentTypePerformer:
performers, err := scraper.scrapePerformers(ctx, q)
if err != nil {
return nil, err
@@ -114,7 +114,7 @@ func (s *xpathScraper) scrapeByName(ctx context.Context, name string, ty models.
}
return content, nil
- case models.ScrapeContentTypeScene:
+ case ScrapeContentTypeScene:
scenes, err := scraper.scrapeScenes(ctx, q)
if err != nil {
return nil, err
@@ -129,7 +129,7 @@ func (s *xpathScraper) scrapeByName(ctx context.Context, name string, ty models.
return nil, ErrNotSupported
}
-func (s *xpathScraper) scrapeSceneByScene(ctx context.Context, scene *models.Scene) (*models.ScrapedScene, error) {
+func (s *xpathScraper) scrapeSceneByScene(ctx context.Context, scene *models.Scene) (*ScrapedScene, error) {
// construct the URL
queryURL := queryURLParametersFromScene(scene)
if s.scraper.QueryURLReplacements != nil {
@@ -153,7 +153,7 @@ func (s *xpathScraper) scrapeSceneByScene(ctx context.Context, scene *models.Sce
return scraper.scrapeScene(ctx, q)
}
-func (s *xpathScraper) scrapeByFragment(ctx context.Context, input Input) (models.ScrapedContent, error) {
+func (s *xpathScraper) scrapeByFragment(ctx context.Context, input Input) (ScrapedContent, error) {
switch {
case input.Gallery != nil:
return nil, fmt.Errorf("%w: cannot use an xpath scraper as a gallery fragment scraper", ErrNotSupported)
@@ -188,7 +188,7 @@ func (s *xpathScraper) scrapeByFragment(ctx context.Context, input Input) (model
return scraper.scrapeScene(ctx, q)
}
-func (s *xpathScraper) scrapeGalleryByGallery(ctx context.Context, gallery *models.Gallery) (*models.ScrapedGallery, error) {
+func (s *xpathScraper) scrapeGalleryByGallery(ctx context.Context, gallery *models.Gallery) (*ScrapedGallery, error) {
// construct the URL
queryURL := queryURLParametersFromGallery(gallery)
if s.scraper.QueryURLReplacements != nil {
diff --git a/pkg/scraper/xpath_test.go b/pkg/scraper/xpath_test.go
index 782b753f0..9aef91a23 100644
--- a/pkg/scraper/xpath_test.go
+++ b/pkg/scraper/xpath_test.go
@@ -890,7 +890,7 @@ xPathScrapers:
if !ok {
t.Error("couldn't convert scraper into url scraper")
}
- content, err := us.viaURL(ctx, client, ts.URL, models.ScrapeContentTypePerformer)
+ content, err := us.viaURL(ctx, client, ts.URL, ScrapeContentTypePerformer)
if err != nil {
t.Errorf("Error scraping performer: %s", err.Error())
From 964b559309193d1fb679b450c9fead1f239ad3bd Mon Sep 17 00:00:00 2001
From: WithoutPants <53250216+WithoutPants@users.noreply.github.com>
Date: Thu, 19 May 2022 17:49:32 +1000
Subject: [PATCH 02/30] Restructure data layer (#2532)
* Add new txn manager interface
* Add txn management to sqlite
* Rename get to getByID
* Add contexts to repository methods
* Update query builders
* Add context to reader writer interfaces
* Use repository in resolver
* Tighten interfaces
* Tighten interfaces in dlna
* Tighten interfaces in match package
* Tighten interfaces in scraper package
* Tighten interfaces in scan code
* Tighten interfaces on autotag package
* Remove ReaderWriter usage
* Merge database package into sqlite
---
internal/api/resolver.go | 70 +--
internal/api/resolver_model_gallery.go | 28 +-
internal/api/resolver_model_image.go | 16 +-
internal/api/resolver_model_movie.go | 16 +-
internal/api/resolver_model_performer.go | 32 +-
internal/api/resolver_model_scene.go | 38 +-
internal/api/resolver_model_scene_marker.go | 12 +-
internal/api/resolver_model_studio.go | 40 +-
internal/api/resolver_model_tag.go | 32 +-
internal/api/resolver_mutation_configure.go | 4 +-
internal/api/resolver_mutation_gallery.go | 137 +++--
internal/api/resolver_mutation_image.go | 100 +--
internal/api/resolver_mutation_metadata.go | 4 +-
internal/api/resolver_mutation_movie.go | 44 +-
internal/api/resolver_mutation_performer.go | 58 +-
.../api/resolver_mutation_saved_filter.go | 20 +-
internal/api/resolver_mutation_scene.go | 152 ++---
internal/api/resolver_mutation_stash_box.go | 28 +-
internal/api/resolver_mutation_studio.go | 46 +-
internal/api/resolver_mutation_tag.go | 74 +--
internal/api/resolver_mutation_tag_test.go | 44 +-
internal/api/resolver_query_configuration.go | 2 +-
internal/api/resolver_query_find_gallery.go | 8 +-
internal/api/resolver_query_find_image.go | 16 +-
internal/api/resolver_query_find_movie.go | 12 +-
internal/api/resolver_query_find_performer.go | 12 +-
.../api/resolver_query_find_saved_filter.go | 14 +-
internal/api/resolver_query_find_scene.go | 44 +-
.../api/resolver_query_find_scene_marker.go | 4 +-
internal/api/resolver_query_find_studio.go | 12 +-
internal/api/resolver_query_find_tag.go | 12 +-
internal/api/resolver_query_scene.go | 4 +-
internal/api/resolver_query_scraper.go | 2 +-
internal/api/routes_image.go | 21 +-
internal/api/routes_movie.go | 27 +-
internal/api/routes_performer.go | 22 +-
internal/api/routes_scene.go | 72 ++-
internal/api/routes_studio.go | 23 +-
internal/api/routes_tag.go | 23 +-
internal/api/server.go | 21 +-
internal/autotag/gallery.go | 20 +-
internal/autotag/gallery_test.go | 53 +-
internal/autotag/image.go | 20 +-
internal/autotag/image_test.go | 50 +-
internal/autotag/integration_test.go | 231 +++----
internal/autotag/performer.go | 35 +-
internal/autotag/performer_test.go | 24 +-
internal/autotag/scene.go | 20 +-
internal/autotag/scene_test.go | 50 +-
internal/autotag/studio.go | 58 +-
internal/autotag/studio_test.go | 30 +-
internal/autotag/tag.go | 35 +-
internal/autotag/tag_test.go | 30 +-
internal/autotag/tagger.go | 29 +-
internal/dlna/cds.go | 32 +-
internal/dlna/cds_test.go | 4 +-
internal/dlna/dms.go | 37 +-
internal/dlna/paging.go | 9 +-
internal/dlna/service.go | 16 +-
internal/identify/identify.go | 37 +-
internal/identify/identify_test.go | 18 +-
internal/identify/performer.go | 16 +-
internal/identify/performer_test.go | 18 +-
internal/identify/scene.go | 52 +-
internal/identify/scene_test.go | 60 +-
internal/identify/studio.go | 12 +-
internal/identify/studio_test.go | 13 +-
internal/manager/checksum.go | 59 +-
internal/manager/filename_parser.go | 73 ++-
internal/manager/import.go | 23 +-
internal/manager/manager.go | 68 ++-
internal/manager/manager_tasks.go | 39 +-
internal/manager/post_migrate.go | 2 +-
internal/manager/running_streams.go | 12 +-
internal/manager/studio.go | 6 +-
internal/manager/task_autotag.go | 134 ++--
internal/manager/task_clean.go | 65 +-
internal/manager/task_export.go | 226 +++----
internal/manager/task_generate.go | 14 +-
...task_generate_interactive_heatmap_speed.go | 12 +-
internal/manager/task_generate_markers.go | 14 +-
internal/manager/task_generate_phash.go | 8 +-
internal/manager/task_generate_screenshot.go | 10 +-
internal/manager/task_identify.go | 29 +-
internal/manager/task_import.go | 114 ++--
internal/manager/task_scan.go | 13 +-
internal/manager/task_scan_gallery.go | 29 +-
internal/manager/task_scan_image.go | 28 +-
internal/manager/task_scan_scene.go | 16 +-
internal/manager/task_stash_box_tag.go | 31 +-
pkg/database/transaction.go | 40 --
pkg/gallery/export.go | 7 +-
pkg/gallery/export_test.go | 8 +-
pkg/gallery/import.go | 74 ++-
pkg/gallery/import_test.go | 117 ++--
pkg/gallery/query.go | 25 +-
pkg/gallery/scan.go | 26 +-
pkg/gallery/update.go | 43 +-
pkg/image/delete.go | 8 +-
pkg/image/export.go | 7 +-
pkg/image/export_test.go | 8 +-
pkg/image/import.go | 94 +--
pkg/image/import_test.go | 147 ++---
pkg/image/query.go | 29 +-
pkg/image/scan.go | 29 +-
pkg/image/update.go | 32 +-
pkg/match/cache.go | 18 +-
pkg/match/path.go | 66 +-
pkg/match/scraped.go | 39 +-
pkg/models/gallery.go | 50 +-
pkg/models/image.go | 54 +-
pkg/models/mocks/GalleryReaderWriter.go | 318 +++++-----
pkg/models/mocks/ImageReaderWriter.go | 322 +++++-----
pkg/models/mocks/MovieReaderWriter.go | 260 ++++----
pkg/models/mocks/PerformerReaderWriter.go | 336 +++++-----
pkg/models/mocks/SavedFilterReaderWriter.go | 124 ++--
pkg/models/mocks/SceneMarkerReaderWriter.go | 166 ++---
pkg/models/mocks/SceneReaderWriter.go | 578 +++++++++---------
pkg/models/mocks/ScrapedItemReaderWriter.go | 30 +-
pkg/models/mocks/StudioReaderWriter.go | 280 ++++-----
pkg/models/mocks/TagReaderWriter.go | 384 ++++++------
pkg/models/mocks/query.go | 12 +-
pkg/models/mocks/transaction.go | 188 +-----
pkg/models/movie.go | 40 +-
pkg/models/performer.go | 51 +-
pkg/models/repository.go | 48 +-
pkg/models/saved_filter.go | 20 +-
pkg/models/scene.go | 94 +--
pkg/models/scene_marker.go | 26 +-
pkg/models/scraped.go | 9 +-
pkg/models/studio.go | 44 +-
pkg/models/tag.go | 60 +-
pkg/models/transaction.go | 86 ---
pkg/movie/export.go | 15 +-
pkg/movie/export_test.go | 32 +-
pkg/movie/import.go | 42 +-
pkg/movie/import_test.go | 71 +--
pkg/movie/update.go | 12 +
pkg/performer/export.go | 12 +-
pkg/performer/export_test.go | 12 +-
pkg/performer/import.go | 52 +-
pkg/performer/import_test.go | 75 +--
pkg/performer/update.go | 12 +
pkg/scene/delete.go | 25 +-
pkg/scene/export.go | 70 ++-
pkg/scene/export_test.go | 68 +--
pkg/scene/import.go | 113 ++--
pkg/scene/import_test.go | 202 +++---
pkg/scene/marker_import.go | 39 +-
pkg/scene/marker_import_test.go | 52 +-
pkg/scene/query.go | 22 +-
pkg/scene/scan.go | 52 +-
pkg/scene/update.go | 74 ++-
pkg/scene/update_test.go | 25 +-
pkg/scraper/action.go | 8 +-
pkg/scraper/autotag.go | 45 +-
pkg/scraper/cache.go | 97 ++-
pkg/scraper/freeones.go | 5 +-
pkg/scraper/group.go | 16 +-
pkg/scraper/json.go | 4 +-
pkg/scraper/postprocessing.go | 54 +-
pkg/scraper/stash.go | 28 +-
pkg/scraper/stashbox/stash_box.go | 108 ++--
pkg/scraper/xpath.go | 4 +-
pkg/scraper/xpath_test.go | 2 +-
pkg/{database => sqlite}/custom_migrations.go | 20 +-
pkg/{database => sqlite}/database.go | 218 ++++---
pkg/sqlite/filter.go | 41 +-
pkg/sqlite/filter_internal_test.go | 27 +-
pkg/{database => sqlite}/functions.go | 2 +-
pkg/sqlite/gallery.go | 198 +++---
pkg/sqlite/gallery_test.go | 234 +++----
pkg/sqlite/image.go | 194 +++---
pkg/sqlite/image_test.go | 264 ++++----
.../migrations/10_image_tables.up.sql | 0
.../migrations/11_tag_image.up.sql | 0
.../migrations/12_oshash.up.sql | 0
.../migrations/13_images.up.sql | 0
.../migrations/14_stash_box_ids.up.sql | 0
.../migrations/15_file_mod_time.up.sql | 0
.../migrations/16_organized_flag.up.sql | 0
.../migrations/17_reset_scene_size.up.sql | 0
.../migrations/18_scene_galleries.up.sql | 0
.../migrations/19_performer_tags.up.sql | 0
.../migrations/1_initial.down.sql | 0
.../migrations/1_initial.up.sql | 0
.../migrations/20_phash.up.sql | 0
.../21_performers_studios_details.up.sql | 0
.../22_performers_studios_rating.up.sql | 0
.../migrations/23_scenes_interactive.up.sql | 0
.../migrations/24_tag_aliases.up.sql | 0
.../migrations/25_saved_filters.up.sql | 0
.../migrations/26_tag_hierarchy.up.sql | 0
.../migrations/27_studio_aliases.up.sql | 0
.../migrations/28_images_indexes.up.sql | 0
.../migrations/29_interactive_speed.up.sql | 0
.../migrations/2_cover_image.up.sql | 0
.../migrations/30_ignore_autotag.up..sql | 0
.../migrations/31_scenes_captions.up.sql | 0
.../migrations/3_o_counter.up.sql | 0
.../migrations/4_movie.up.sql | 0
.../migrations/5_performer_gender.down.sql | 0
.../migrations/5_performer_gender.up.sql | 0
.../migrations/6_scenes_format.up.sql | 0
.../7_performer_optimization.up.sql | 0
.../migrations/8_movie_fix.up.sql | 0
.../migrations/9_studios_parent_studio.up.sql | 0
pkg/sqlite/movies.go | 132 ++--
pkg/sqlite/movies_test.go | 78 +--
pkg/sqlite/performer.go | 208 ++++---
pkg/sqlite/performer_test.go | 198 +++---
pkg/sqlite/query.go | 13 +-
pkg/{database => sqlite}/regex.go | 2 +-
pkg/sqlite/repository.go | 141 ++---
pkg/sqlite/saved_filter.go | 56 +-
pkg/sqlite/saved_filter_test.go | 40 +-
pkg/sqlite/scene.go | 314 +++++-----
pkg/sqlite/scene_marker.go | 101 ++-
pkg/sqlite/scene_marker_test.go | 40 +-
pkg/sqlite/scene_test.go | 396 ++++++------
pkg/sqlite/scraped_item.go | 40 +-
pkg/sqlite/setup_test.go | 207 +++----
pkg/sqlite/sql.go | 5 +-
pkg/sqlite/stash_id_test.go | 29 +-
pkg/sqlite/studio.go | 162 +++--
pkg/sqlite/studio_test.go | 196 +++---
pkg/sqlite/tag.go | 222 ++++---
pkg/sqlite/tag_test.go | 226 +++----
pkg/sqlite/transaction.go | 228 ++-----
pkg/sqlite/tx.go | 55 ++
pkg/studio/export.go | 18 +-
pkg/studio/export_test.go | 38 +-
pkg/studio/import.go | 46 +-
pkg/studio/import_test.go | 85 +--
pkg/studio/query.go | 22 +-
pkg/studio/update.go | 16 +-
pkg/tag/export.go | 15 +-
pkg/tag/export_test.go | 34 +-
pkg/tag/import.go | 46 +-
pkg/tag/import_test.go | 129 ++--
pkg/tag/query.go | 22 +-
pkg/tag/update.go | 39 +-
pkg/tag/update_test.go | 16 +-
pkg/txn/transaction.go | 38 ++
244 files changed, 7377 insertions(+), 6699 deletions(-)
delete mode 100644 pkg/database/transaction.go
delete mode 100644 pkg/models/transaction.go
create mode 100644 pkg/movie/update.go
create mode 100644 pkg/performer/update.go
rename pkg/{database => sqlite}/custom_migrations.go (81%)
rename pkg/{database => sqlite}/database.go (54%)
rename pkg/{database => sqlite}/functions.go (96%)
rename pkg/{database => sqlite}/migrations/10_image_tables.up.sql (100%)
rename pkg/{database => sqlite}/migrations/11_tag_image.up.sql (100%)
rename pkg/{database => sqlite}/migrations/12_oshash.up.sql (100%)
rename pkg/{database => sqlite}/migrations/13_images.up.sql (100%)
rename pkg/{database => sqlite}/migrations/14_stash_box_ids.up.sql (100%)
rename pkg/{database => sqlite}/migrations/15_file_mod_time.up.sql (100%)
rename pkg/{database => sqlite}/migrations/16_organized_flag.up.sql (100%)
rename pkg/{database => sqlite}/migrations/17_reset_scene_size.up.sql (100%)
rename pkg/{database => sqlite}/migrations/18_scene_galleries.up.sql (100%)
rename pkg/{database => sqlite}/migrations/19_performer_tags.up.sql (100%)
rename pkg/{database => sqlite}/migrations/1_initial.down.sql (100%)
rename pkg/{database => sqlite}/migrations/1_initial.up.sql (100%)
rename pkg/{database => sqlite}/migrations/20_phash.up.sql (100%)
rename pkg/{database => sqlite}/migrations/21_performers_studios_details.up.sql (100%)
rename pkg/{database => sqlite}/migrations/22_performers_studios_rating.up.sql (100%)
rename pkg/{database => sqlite}/migrations/23_scenes_interactive.up.sql (100%)
rename pkg/{database => sqlite}/migrations/24_tag_aliases.up.sql (100%)
rename pkg/{database => sqlite}/migrations/25_saved_filters.up.sql (100%)
rename pkg/{database => sqlite}/migrations/26_tag_hierarchy.up.sql (100%)
rename pkg/{database => sqlite}/migrations/27_studio_aliases.up.sql (100%)
rename pkg/{database => sqlite}/migrations/28_images_indexes.up.sql (100%)
rename pkg/{database => sqlite}/migrations/29_interactive_speed.up.sql (100%)
rename pkg/{database => sqlite}/migrations/2_cover_image.up.sql (100%)
rename pkg/{database => sqlite}/migrations/30_ignore_autotag.up..sql (100%)
rename pkg/{database => sqlite}/migrations/31_scenes_captions.up.sql (100%)
rename pkg/{database => sqlite}/migrations/3_o_counter.up.sql (100%)
rename pkg/{database => sqlite}/migrations/4_movie.up.sql (100%)
rename pkg/{database => sqlite}/migrations/5_performer_gender.down.sql (100%)
rename pkg/{database => sqlite}/migrations/5_performer_gender.up.sql (100%)
rename pkg/{database => sqlite}/migrations/6_scenes_format.up.sql (100%)
rename pkg/{database => sqlite}/migrations/7_performer_optimization.up.sql (100%)
rename pkg/{database => sqlite}/migrations/8_movie_fix.up.sql (100%)
rename pkg/{database => sqlite}/migrations/9_studios_parent_studio.up.sql (100%)
rename pkg/{database => sqlite}/regex.go (98%)
create mode 100644 pkg/sqlite/tx.go
create mode 100644 pkg/txn/transaction.go
diff --git a/internal/api/resolver.go b/internal/api/resolver.go
index 5880f1e3c..e6289a218 100644
--- a/internal/api/resolver.go
+++ b/internal/api/resolver.go
@@ -11,6 +11,7 @@ import (
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin"
"github.com/stashapp/stash/pkg/scraper"
+ "github.com/stashapp/stash/pkg/txn"
)
var (
@@ -30,7 +31,9 @@ type hookExecutor interface {
}
type Resolver struct {
- txnManager models.TransactionManager
+ txnManager txn.Manager
+ repository models.Repository
+
hookExecutor hookExecutor
}
@@ -85,17 +88,13 @@ type studioResolver struct{ *Resolver }
type movieResolver struct{ *Resolver }
type tagResolver struct{ *Resolver }
-func (r *Resolver) withTxn(ctx context.Context, fn func(r models.Repository) error) error {
- return r.txnManager.WithTxn(ctx, fn)
-}
-
-func (r *Resolver) withReadTxn(ctx context.Context, fn func(r models.ReaderRepository) error) error {
- return r.txnManager.WithReadTxn(ctx, fn)
+func (r *Resolver) withTxn(ctx context.Context, fn func(ctx context.Context) error) error {
+ return txn.WithTxn(ctx, r.txnManager, fn)
}
func (r *queryResolver) MarkerWall(ctx context.Context, q *string) (ret []*models.SceneMarker, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.SceneMarker().Wall(q)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.SceneMarker.Wall(ctx, q)
return err
}); err != nil {
return nil, err
@@ -104,8 +103,8 @@ func (r *queryResolver) MarkerWall(ctx context.Context, q *string) (ret []*model
}
func (r *queryResolver) SceneWall(ctx context.Context, q *string) (ret []*models.Scene, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Scene().Wall(q)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Scene.Wall(ctx, q)
return err
}); err != nil {
return nil, err
@@ -115,8 +114,8 @@ func (r *queryResolver) SceneWall(ctx context.Context, q *string) (ret []*models
}
func (r *queryResolver) MarkerStrings(ctx context.Context, q *string, sort *string) (ret []*models.MarkerStringsResultType, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.SceneMarker().GetMarkerStrings(q, sort)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.SceneMarker.GetMarkerStrings(ctx, q, sort)
return err
}); err != nil {
return nil, err
@@ -127,24 +126,25 @@ func (r *queryResolver) MarkerStrings(ctx context.Context, q *string, sort *stri
func (r *queryResolver) Stats(ctx context.Context) (*StatsResultType, error) {
var ret StatsResultType
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- scenesQB := repo.Scene()
- imageQB := repo.Image()
- galleryQB := repo.Gallery()
- studiosQB := repo.Studio()
- performersQB := repo.Performer()
- moviesQB := repo.Movie()
- tagsQB := repo.Tag()
- scenesCount, _ := scenesQB.Count()
- scenesSize, _ := scenesQB.Size()
- scenesDuration, _ := scenesQB.Duration()
- imageCount, _ := imageQB.Count()
- imageSize, _ := imageQB.Size()
- galleryCount, _ := galleryQB.Count()
- performersCount, _ := performersQB.Count()
- studiosCount, _ := studiosQB.Count()
- moviesCount, _ := moviesQB.Count()
- tagsCount, _ := tagsQB.Count()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ repo := r.repository
+ scenesQB := repo.Scene
+ imageQB := repo.Image
+ galleryQB := repo.Gallery
+ studiosQB := repo.Studio
+ performersQB := repo.Performer
+ moviesQB := repo.Movie
+ tagsQB := repo.Tag
+ scenesCount, _ := scenesQB.Count(ctx)
+ scenesSize, _ := scenesQB.Size(ctx)
+ scenesDuration, _ := scenesQB.Duration(ctx)
+ imageCount, _ := imageQB.Count(ctx)
+ imageSize, _ := imageQB.Size(ctx)
+ galleryCount, _ := galleryQB.Count(ctx)
+ performersCount, _ := performersQB.Count(ctx)
+ studiosCount, _ := studiosQB.Count(ctx)
+ moviesCount, _ := moviesQB.Count(ctx)
+ tagsCount, _ := tagsQB.Count(ctx)
ret = StatsResultType{
SceneCount: scenesCount,
@@ -202,15 +202,15 @@ func (r *queryResolver) SceneMarkerTags(ctx context.Context, scene_id string) ([
var keys []int
tags := make(map[int]*SceneMarkerTag)
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- sceneMarkers, err := repo.SceneMarker().FindBySceneID(sceneID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ sceneMarkers, err := r.repository.SceneMarker.FindBySceneID(ctx, sceneID)
if err != nil {
return err
}
- tqb := repo.Tag()
+ tqb := r.repository.Tag
for _, sceneMarker := range sceneMarkers {
- markerPrimaryTag, err := tqb.Find(sceneMarker.PrimaryTagID)
+ markerPrimaryTag, err := tqb.Find(ctx, sceneMarker.PrimaryTagID)
if err != nil {
return err
}
diff --git a/internal/api/resolver_model_gallery.go b/internal/api/resolver_model_gallery.go
index 911cb8fe1..8e1d98dd4 100644
--- a/internal/api/resolver_model_gallery.go
+++ b/internal/api/resolver_model_gallery.go
@@ -24,12 +24,12 @@ func (r *galleryResolver) Title(ctx context.Context, obj *models.Gallery) (*stri
}
func (r *galleryResolver) Images(ctx context.Context, obj *models.Gallery) (ret []*models.Image, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
var err error
// #2376 - sort images by path
// doing this via Query is really slow, so stick with FindByGalleryID
- ret, err = repo.Image().FindByGalleryID(obj.ID)
+ ret, err = r.repository.Image.FindByGalleryID(ctx, obj.ID)
if err != nil {
return err
}
@@ -43,9 +43,9 @@ func (r *galleryResolver) Images(ctx context.Context, obj *models.Gallery) (ret
}
func (r *galleryResolver) Cover(ctx context.Context, obj *models.Gallery) (ret *models.Image, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
// doing this via Query is really slow, so stick with FindByGalleryID
- imgs, err := repo.Image().FindByGalleryID(obj.ID)
+ imgs, err := r.repository.Image.FindByGalleryID(ctx, obj.ID)
if err != nil {
return err
}
@@ -100,9 +100,9 @@ func (r *galleryResolver) Rating(ctx context.Context, obj *models.Gallery) (*int
}
func (r *galleryResolver) Scenes(ctx context.Context, obj *models.Gallery) (ret []*models.Scene, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
var err error
- ret, err = repo.Scene().FindByGalleryID(obj.ID)
+ ret, err = r.repository.Scene.FindByGalleryID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@@ -116,9 +116,9 @@ func (r *galleryResolver) Studio(ctx context.Context, obj *models.Gallery) (ret
return nil, nil
}
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
var err error
- ret, err = repo.Studio().Find(int(obj.StudioID.Int64))
+ ret, err = r.repository.Studio.Find(ctx, int(obj.StudioID.Int64))
return err
}); err != nil {
return nil, err
@@ -128,9 +128,9 @@ func (r *galleryResolver) Studio(ctx context.Context, obj *models.Gallery) (ret
}
func (r *galleryResolver) Tags(ctx context.Context, obj *models.Gallery) (ret []*models.Tag, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
var err error
- ret, err = repo.Tag().FindByGalleryID(obj.ID)
+ ret, err = r.repository.Tag.FindByGalleryID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@@ -140,9 +140,9 @@ func (r *galleryResolver) Tags(ctx context.Context, obj *models.Gallery) (ret []
}
func (r *galleryResolver) Performers(ctx context.Context, obj *models.Gallery) (ret []*models.Performer, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
var err error
- ret, err = repo.Performer().FindByGalleryID(obj.ID)
+ ret, err = r.repository.Performer.FindByGalleryID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@@ -152,9 +152,9 @@ func (r *galleryResolver) Performers(ctx context.Context, obj *models.Gallery) (
}
func (r *galleryResolver) ImageCount(ctx context.Context, obj *models.Gallery) (ret int, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
var err error
- ret, err = repo.Image().CountByGalleryID(obj.ID)
+ ret, err = r.repository.Image.CountByGalleryID(ctx, obj.ID)
return err
}); err != nil {
return 0, err
diff --git a/internal/api/resolver_model_image.go b/internal/api/resolver_model_image.go
index 0f47b6c3e..8c7222b97 100644
--- a/internal/api/resolver_model_image.go
+++ b/internal/api/resolver_model_image.go
@@ -45,9 +45,9 @@ func (r *imageResolver) Paths(ctx context.Context, obj *models.Image) (*ImagePat
}
func (r *imageResolver) Galleries(ctx context.Context, obj *models.Image) (ret []*models.Gallery, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
var err error
- ret, err = repo.Gallery().FindByImageID(obj.ID)
+ ret, err = r.repository.Gallery.FindByImageID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@@ -61,8 +61,8 @@ func (r *imageResolver) Studio(ctx context.Context, obj *models.Image) (ret *mod
return nil, nil
}
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Studio().Find(int(obj.StudioID.Int64))
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Studio.Find(ctx, int(obj.StudioID.Int64))
return err
}); err != nil {
return nil, err
@@ -72,8 +72,8 @@ func (r *imageResolver) Studio(ctx context.Context, obj *models.Image) (ret *mod
}
func (r *imageResolver) Tags(ctx context.Context, obj *models.Image) (ret []*models.Tag, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Tag().FindByImageID(obj.ID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Tag.FindByImageID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@@ -83,8 +83,8 @@ func (r *imageResolver) Tags(ctx context.Context, obj *models.Image) (ret []*mod
}
func (r *imageResolver) Performers(ctx context.Context, obj *models.Image) (ret []*models.Performer, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Performer().FindByImageID(obj.ID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Performer.FindByImageID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
diff --git a/internal/api/resolver_model_movie.go b/internal/api/resolver_model_movie.go
index 57e979828..9085eee82 100644
--- a/internal/api/resolver_model_movie.go
+++ b/internal/api/resolver_model_movie.go
@@ -56,8 +56,8 @@ func (r *movieResolver) Rating(ctx context.Context, obj *models.Movie) (*int, er
func (r *movieResolver) Studio(ctx context.Context, obj *models.Movie) (ret *models.Studio, err error) {
if obj.StudioID.Valid {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Studio().Find(int(obj.StudioID.Int64))
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Studio.Find(ctx, int(obj.StudioID.Int64))
return err
}); err != nil {
return nil, err
@@ -92,9 +92,9 @@ func (r *movieResolver) FrontImagePath(ctx context.Context, obj *models.Movie) (
func (r *movieResolver) BackImagePath(ctx context.Context, obj *models.Movie) (*string, error) {
// don't return any thing if there is no back image
var img []byte
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
var err error
- img, err = repo.Movie().GetBackImage(obj.ID)
+ img, err = r.repository.Movie.GetBackImage(ctx, obj.ID)
if err != nil {
return err
}
@@ -115,8 +115,8 @@ func (r *movieResolver) BackImagePath(ctx context.Context, obj *models.Movie) (*
func (r *movieResolver) SceneCount(ctx context.Context, obj *models.Movie) (ret *int, err error) {
var res int
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- res, err = repo.Scene().CountByMovieID(obj.ID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ res, err = r.repository.Scene.CountByMovieID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@@ -126,9 +126,9 @@ func (r *movieResolver) SceneCount(ctx context.Context, obj *models.Movie) (ret
}
func (r *movieResolver) Scenes(ctx context.Context, obj *models.Movie) (ret []*models.Scene, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
var err error
- ret, err = repo.Scene().FindByMovieID(obj.ID)
+ ret, err = r.repository.Scene.FindByMovieID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
diff --git a/internal/api/resolver_model_performer.go b/internal/api/resolver_model_performer.go
index 364c346ca..3a44a1cc0 100644
--- a/internal/api/resolver_model_performer.go
+++ b/internal/api/resolver_model_performer.go
@@ -142,8 +142,8 @@ func (r *performerResolver) ImagePath(ctx context.Context, obj *models.Performer
}
func (r *performerResolver) Tags(ctx context.Context, obj *models.Performer) (ret []*models.Tag, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Tag().FindByPerformerID(obj.ID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Tag.FindByPerformerID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@@ -154,8 +154,8 @@ func (r *performerResolver) Tags(ctx context.Context, obj *models.Performer) (re
func (r *performerResolver) SceneCount(ctx context.Context, obj *models.Performer) (ret *int, err error) {
var res int
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- res, err = repo.Scene().CountByPerformerID(obj.ID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ res, err = r.repository.Scene.CountByPerformerID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@@ -166,8 +166,8 @@ func (r *performerResolver) SceneCount(ctx context.Context, obj *models.Performe
func (r *performerResolver) ImageCount(ctx context.Context, obj *models.Performer) (ret *int, err error) {
var res int
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- res, err = image.CountByPerformerID(repo.Image(), obj.ID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ res, err = image.CountByPerformerID(ctx, r.repository.Image, obj.ID)
return err
}); err != nil {
return nil, err
@@ -178,8 +178,8 @@ func (r *performerResolver) ImageCount(ctx context.Context, obj *models.Performe
func (r *performerResolver) GalleryCount(ctx context.Context, obj *models.Performer) (ret *int, err error) {
var res int
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- res, err = gallery.CountByPerformerID(repo.Gallery(), obj.ID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ res, err = gallery.CountByPerformerID(ctx, r.repository.Gallery, obj.ID)
return err
}); err != nil {
return nil, err
@@ -189,8 +189,8 @@ func (r *performerResolver) GalleryCount(ctx context.Context, obj *models.Perfor
}
func (r *performerResolver) Scenes(ctx context.Context, obj *models.Performer) (ret []*models.Scene, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Scene().FindByPerformerID(obj.ID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Scene.FindByPerformerID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@@ -200,8 +200,8 @@ func (r *performerResolver) Scenes(ctx context.Context, obj *models.Performer) (
}
func (r *performerResolver) StashIds(ctx context.Context, obj *models.Performer) (ret []*models.StashID, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Performer().GetStashIDs(obj.ID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Performer.GetStashIDs(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@@ -256,8 +256,8 @@ func (r *performerResolver) UpdatedAt(ctx context.Context, obj *models.Performer
}
func (r *performerResolver) Movies(ctx context.Context, obj *models.Performer) (ret []*models.Movie, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Movie().FindByPerformerID(obj.ID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Movie.FindByPerformerID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@@ -268,8 +268,8 @@ func (r *performerResolver) Movies(ctx context.Context, obj *models.Performer) (
func (r *performerResolver) MovieCount(ctx context.Context, obj *models.Performer) (ret *int, err error) {
var res int
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- res, err = repo.Movie().CountByPerformerID(obj.ID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ res, err = r.repository.Movie.CountByPerformerID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
diff --git a/internal/api/resolver_model_scene.go b/internal/api/resolver_model_scene.go
index 0b81beaa5..d9c783ac8 100644
--- a/internal/api/resolver_model_scene.go
+++ b/internal/api/resolver_model_scene.go
@@ -116,8 +116,8 @@ func (r *sceneResolver) Paths(ctx context.Context, obj *models.Scene) (*ScenePat
}
func (r *sceneResolver) SceneMarkers(ctx context.Context, obj *models.Scene) (ret []*models.SceneMarker, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.SceneMarker().FindBySceneID(obj.ID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.SceneMarker.FindBySceneID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@@ -127,8 +127,8 @@ func (r *sceneResolver) SceneMarkers(ctx context.Context, obj *models.Scene) (re
}
func (r *sceneResolver) Captions(ctx context.Context, obj *models.Scene) (ret []*models.SceneCaption, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Scene().GetCaptions(obj.ID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Scene.GetCaptions(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@@ -138,8 +138,8 @@ func (r *sceneResolver) Captions(ctx context.Context, obj *models.Scene) (ret []
}
func (r *sceneResolver) Galleries(ctx context.Context, obj *models.Scene) (ret []*models.Gallery, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Gallery().FindBySceneID(obj.ID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Gallery.FindBySceneID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@@ -153,8 +153,8 @@ func (r *sceneResolver) Studio(ctx context.Context, obj *models.Scene) (ret *mod
return nil, nil
}
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Studio().Find(int(obj.StudioID.Int64))
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Studio.Find(ctx, int(obj.StudioID.Int64))
return err
}); err != nil {
return nil, err
@@ -164,17 +164,17 @@ func (r *sceneResolver) Studio(ctx context.Context, obj *models.Scene) (ret *mod
}
func (r *sceneResolver) Movies(ctx context.Context, obj *models.Scene) (ret []*SceneMovie, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- qb := repo.Scene()
- mqb := repo.Movie()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Scene
+ mqb := r.repository.Movie
- sceneMovies, err := qb.GetMovies(obj.ID)
+ sceneMovies, err := qb.GetMovies(ctx, obj.ID)
if err != nil {
return err
}
for _, sm := range sceneMovies {
- movie, err := mqb.Find(sm.MovieID)
+ movie, err := mqb.Find(ctx, sm.MovieID)
if err != nil {
return err
}
@@ -200,8 +200,8 @@ func (r *sceneResolver) Movies(ctx context.Context, obj *models.Scene) (ret []*S
}
func (r *sceneResolver) Tags(ctx context.Context, obj *models.Scene) (ret []*models.Tag, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Tag().FindBySceneID(obj.ID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Tag.FindBySceneID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@@ -211,8 +211,8 @@ func (r *sceneResolver) Tags(ctx context.Context, obj *models.Scene) (ret []*mod
}
func (r *sceneResolver) Performers(ctx context.Context, obj *models.Scene) (ret []*models.Performer, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Performer().FindBySceneID(obj.ID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Performer.FindBySceneID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@@ -222,8 +222,8 @@ func (r *sceneResolver) Performers(ctx context.Context, obj *models.Scene) (ret
}
func (r *sceneResolver) StashIds(ctx context.Context, obj *models.Scene) (ret []*models.StashID, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Scene().GetStashIDs(obj.ID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Scene.GetStashIDs(ctx, obj.ID)
return err
}); err != nil {
return nil, err
diff --git a/internal/api/resolver_model_scene_marker.go b/internal/api/resolver_model_scene_marker.go
index 64d418bd1..7a4d01be1 100644
--- a/internal/api/resolver_model_scene_marker.go
+++ b/internal/api/resolver_model_scene_marker.go
@@ -13,9 +13,9 @@ func (r *sceneMarkerResolver) Scene(ctx context.Context, obj *models.SceneMarker
panic("Invalid scene id")
}
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
sceneID := int(obj.SceneID.Int64)
- ret, err = repo.Scene().Find(sceneID)
+ ret, err = r.repository.Scene.Find(ctx, sceneID)
return err
}); err != nil {
return nil, err
@@ -25,8 +25,8 @@ func (r *sceneMarkerResolver) Scene(ctx context.Context, obj *models.SceneMarker
}
func (r *sceneMarkerResolver) PrimaryTag(ctx context.Context, obj *models.SceneMarker) (ret *models.Tag, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Tag().Find(obj.PrimaryTagID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Tag.Find(ctx, obj.PrimaryTagID)
return err
}); err != nil {
return nil, err
@@ -36,8 +36,8 @@ func (r *sceneMarkerResolver) PrimaryTag(ctx context.Context, obj *models.SceneM
}
func (r *sceneMarkerResolver) Tags(ctx context.Context, obj *models.SceneMarker) (ret []*models.Tag, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Tag().FindBySceneMarkerID(obj.ID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Tag.FindBySceneMarkerID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
diff --git a/internal/api/resolver_model_studio.go b/internal/api/resolver_model_studio.go
index d2b6b44c1..a7fc56442 100644
--- a/internal/api/resolver_model_studio.go
+++ b/internal/api/resolver_model_studio.go
@@ -29,9 +29,9 @@ func (r *studioResolver) ImagePath(ctx context.Context, obj *models.Studio) (*st
imagePath := urlbuilders.NewStudioURLBuilder(baseURL, obj).GetStudioImageURL()
var hasImage bool
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
var err error
- hasImage, err = repo.Studio().HasImage(obj.ID)
+ hasImage, err = r.repository.Studio.HasImage(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@@ -46,8 +46,8 @@ func (r *studioResolver) ImagePath(ctx context.Context, obj *models.Studio) (*st
}
func (r *studioResolver) Aliases(ctx context.Context, obj *models.Studio) (ret []string, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Studio().GetAliases(obj.ID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Studio.GetAliases(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@@ -58,8 +58,8 @@ func (r *studioResolver) Aliases(ctx context.Context, obj *models.Studio) (ret [
func (r *studioResolver) SceneCount(ctx context.Context, obj *models.Studio) (ret *int, err error) {
var res int
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- res, err = repo.Scene().CountByStudioID(obj.ID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ res, err = r.repository.Scene.CountByStudioID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@@ -70,8 +70,8 @@ func (r *studioResolver) SceneCount(ctx context.Context, obj *models.Studio) (re
func (r *studioResolver) ImageCount(ctx context.Context, obj *models.Studio) (ret *int, err error) {
var res int
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- res, err = image.CountByStudioID(repo.Image(), obj.ID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ res, err = image.CountByStudioID(ctx, r.repository.Image, obj.ID)
return err
}); err != nil {
return nil, err
@@ -82,8 +82,8 @@ func (r *studioResolver) ImageCount(ctx context.Context, obj *models.Studio) (re
func (r *studioResolver) GalleryCount(ctx context.Context, obj *models.Studio) (ret *int, err error) {
var res int
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- res, err = gallery.CountByStudioID(repo.Gallery(), obj.ID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ res, err = gallery.CountByStudioID(ctx, r.repository.Gallery, obj.ID)
return err
}); err != nil {
return nil, err
@@ -97,8 +97,8 @@ func (r *studioResolver) ParentStudio(ctx context.Context, obj *models.Studio) (
return nil, nil
}
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Studio().Find(int(obj.ParentID.Int64))
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Studio.Find(ctx, int(obj.ParentID.Int64))
return err
}); err != nil {
return nil, err
@@ -108,8 +108,8 @@ func (r *studioResolver) ParentStudio(ctx context.Context, obj *models.Studio) (
}
func (r *studioResolver) ChildStudios(ctx context.Context, obj *models.Studio) (ret []*models.Studio, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Studio().FindChildren(obj.ID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Studio.FindChildren(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@@ -119,8 +119,8 @@ func (r *studioResolver) ChildStudios(ctx context.Context, obj *models.Studio) (
}
func (r *studioResolver) StashIds(ctx context.Context, obj *models.Studio) (ret []*models.StashID, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Studio().GetStashIDs(obj.ID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Studio.GetStashIDs(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@@ -153,8 +153,8 @@ func (r *studioResolver) UpdatedAt(ctx context.Context, obj *models.Studio) (*ti
}
func (r *studioResolver) Movies(ctx context.Context, obj *models.Studio) (ret []*models.Movie, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Movie().FindByStudioID(obj.ID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Movie.FindByStudioID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@@ -165,8 +165,8 @@ func (r *studioResolver) Movies(ctx context.Context, obj *models.Studio) (ret []
func (r *studioResolver) MovieCount(ctx context.Context, obj *models.Studio) (ret *int, err error) {
var res int
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- res, err = repo.Movie().CountByStudioID(obj.ID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ res, err = r.repository.Movie.CountByStudioID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
diff --git a/internal/api/resolver_model_tag.go b/internal/api/resolver_model_tag.go
index cb406e5fc..3592dd959 100644
--- a/internal/api/resolver_model_tag.go
+++ b/internal/api/resolver_model_tag.go
@@ -11,8 +11,8 @@ import (
)
func (r *tagResolver) Parents(ctx context.Context, obj *models.Tag) (ret []*models.Tag, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Tag().FindByChildTagID(obj.ID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Tag.FindByChildTagID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@@ -22,8 +22,8 @@ func (r *tagResolver) Parents(ctx context.Context, obj *models.Tag) (ret []*mode
}
func (r *tagResolver) Children(ctx context.Context, obj *models.Tag) (ret []*models.Tag, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Tag().FindByParentTagID(obj.ID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Tag.FindByParentTagID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@@ -33,8 +33,8 @@ func (r *tagResolver) Children(ctx context.Context, obj *models.Tag) (ret []*mod
}
func (r *tagResolver) Aliases(ctx context.Context, obj *models.Tag) (ret []string, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Tag().GetAliases(obj.ID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Tag.GetAliases(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@@ -45,8 +45,8 @@ func (r *tagResolver) Aliases(ctx context.Context, obj *models.Tag) (ret []strin
func (r *tagResolver) SceneCount(ctx context.Context, obj *models.Tag) (ret *int, err error) {
var count int
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- count, err = repo.Scene().CountByTagID(obj.ID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ count, err = r.repository.Scene.CountByTagID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@@ -57,8 +57,8 @@ func (r *tagResolver) SceneCount(ctx context.Context, obj *models.Tag) (ret *int
func (r *tagResolver) SceneMarkerCount(ctx context.Context, obj *models.Tag) (ret *int, err error) {
var count int
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- count, err = repo.SceneMarker().CountByTagID(obj.ID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ count, err = r.repository.SceneMarker.CountByTagID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
@@ -69,8 +69,8 @@ func (r *tagResolver) SceneMarkerCount(ctx context.Context, obj *models.Tag) (re
func (r *tagResolver) ImageCount(ctx context.Context, obj *models.Tag) (ret *int, err error) {
var res int
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- res, err = image.CountByTagID(repo.Image(), obj.ID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ res, err = image.CountByTagID(ctx, r.repository.Image, obj.ID)
return err
}); err != nil {
return nil, err
@@ -81,8 +81,8 @@ func (r *tagResolver) ImageCount(ctx context.Context, obj *models.Tag) (ret *int
func (r *tagResolver) GalleryCount(ctx context.Context, obj *models.Tag) (ret *int, err error) {
var res int
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- res, err = gallery.CountByTagID(repo.Gallery(), obj.ID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ res, err = gallery.CountByTagID(ctx, r.repository.Gallery, obj.ID)
return err
}); err != nil {
return nil, err
@@ -93,8 +93,8 @@ func (r *tagResolver) GalleryCount(ctx context.Context, obj *models.Tag) (ret *i
func (r *tagResolver) PerformerCount(ctx context.Context, obj *models.Tag) (ret *int, err error) {
var count int
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- count, err = repo.Performer().CountByTagID(obj.ID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ count, err = r.repository.Performer.CountByTagID(ctx, obj.ID)
return err
}); err != nil {
return nil, err
diff --git a/internal/api/resolver_mutation_configure.go b/internal/api/resolver_mutation_configure.go
index 48b1c6b9f..484903b54 100644
--- a/internal/api/resolver_mutation_configure.go
+++ b/internal/api/resolver_mutation_configure.go
@@ -132,7 +132,9 @@ func (r *mutationResolver) ConfigureGeneral(ctx context.Context, input ConfigGen
}
// validate changing VideoFileNamingAlgorithm
- if err := manager.ValidateVideoFileNamingAlgorithm(r.txnManager, *input.VideoFileNamingAlgorithm); err != nil {
+ if err := r.withTxn(context.TODO(), func(ctx context.Context) error {
+ return manager.ValidateVideoFileNamingAlgorithm(ctx, r.repository.Scene, *input.VideoFileNamingAlgorithm)
+ }); err != nil {
return makeConfigGeneralResult(), err
}
diff --git a/internal/api/resolver_mutation_gallery.go b/internal/api/resolver_mutation_gallery.go
index 816d6aba3..04a320365 100644
--- a/internal/api/resolver_mutation_gallery.go
+++ b/internal/api/resolver_mutation_gallery.go
@@ -11,6 +11,7 @@ import (
"github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/pkg/file"
+ "github.com/stashapp/stash/pkg/gallery"
"github.com/stashapp/stash/pkg/hash/md5"
"github.com/stashapp/stash/pkg/image"
"github.com/stashapp/stash/pkg/models"
@@ -21,8 +22,8 @@ import (
)
func (r *mutationResolver) getGallery(ctx context.Context, id int) (ret *models.Gallery, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Gallery().Find(id)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Gallery.Find(ctx, id)
return err
}); err != nil {
return nil, err
@@ -80,26 +81,26 @@ func (r *mutationResolver) GalleryCreate(ctx context.Context, input GalleryCreat
// Start the transaction and save the gallery
var gallery *models.Gallery
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.Gallery()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Gallery
var err error
- gallery, err = qb.Create(newGallery)
+ gallery, err = qb.Create(ctx, newGallery)
if err != nil {
return err
}
// Save the performers
- if err := r.updateGalleryPerformers(qb, gallery.ID, input.PerformerIds); err != nil {
+ if err := r.updateGalleryPerformers(ctx, qb, gallery.ID, input.PerformerIds); err != nil {
return err
}
// Save the tags
- if err := r.updateGalleryTags(qb, gallery.ID, input.TagIds); err != nil {
+ if err := r.updateGalleryTags(ctx, qb, gallery.ID, input.TagIds); err != nil {
return err
}
// Save the scenes
- if err := r.updateGalleryScenes(qb, gallery.ID, input.SceneIds); err != nil {
+ if err := r.updateGalleryScenes(ctx, qb, gallery.ID, input.SceneIds); err != nil {
return err
}
@@ -112,28 +113,32 @@ func (r *mutationResolver) GalleryCreate(ctx context.Context, input GalleryCreat
return r.getGallery(ctx, gallery.ID)
}
-func (r *mutationResolver) updateGalleryPerformers(qb models.GalleryReaderWriter, galleryID int, performerIDs []string) error {
+func (r *mutationResolver) updateGalleryPerformers(ctx context.Context, qb gallery.PerformerUpdater, galleryID int, performerIDs []string) error {
ids, err := stringslice.StringSliceToIntSlice(performerIDs)
if err != nil {
return err
}
- return qb.UpdatePerformers(galleryID, ids)
+ return qb.UpdatePerformers(ctx, galleryID, ids)
}
-func (r *mutationResolver) updateGalleryTags(qb models.GalleryReaderWriter, galleryID int, tagIDs []string) error {
+func (r *mutationResolver) updateGalleryTags(ctx context.Context, qb gallery.TagUpdater, galleryID int, tagIDs []string) error {
ids, err := stringslice.StringSliceToIntSlice(tagIDs)
if err != nil {
return err
}
- return qb.UpdateTags(galleryID, ids)
+ return qb.UpdateTags(ctx, galleryID, ids)
}
-func (r *mutationResolver) updateGalleryScenes(qb models.GalleryReaderWriter, galleryID int, sceneIDs []string) error {
+type GallerySceneUpdater interface {
+ UpdateScenes(ctx context.Context, galleryID int, sceneIDs []int) error
+}
+
+func (r *mutationResolver) updateGalleryScenes(ctx context.Context, qb GallerySceneUpdater, galleryID int, sceneIDs []string) error {
ids, err := stringslice.StringSliceToIntSlice(sceneIDs)
if err != nil {
return err
}
- return qb.UpdateScenes(galleryID, ids)
+ return qb.UpdateScenes(ctx, galleryID, ids)
}
func (r *mutationResolver) GalleryUpdate(ctx context.Context, input models.GalleryUpdateInput) (ret *models.Gallery, err error) {
@@ -142,8 +147,8 @@ func (r *mutationResolver) GalleryUpdate(ctx context.Context, input models.Galle
}
// Start the transaction and save the gallery
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- ret, err = r.galleryUpdate(input, translator, repo)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.galleryUpdate(ctx, input, translator)
return err
}); err != nil {
return nil, err
@@ -158,13 +163,13 @@ func (r *mutationResolver) GalleriesUpdate(ctx context.Context, input []*models.
inputMaps := getUpdateInputMaps(ctx)
// Start the transaction and save the gallery
- if err := r.withTxn(ctx, func(repo models.Repository) error {
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
for i, gallery := range input {
translator := changesetTranslator{
inputMap: inputMaps[i],
}
- thisGallery, err := r.galleryUpdate(*gallery, translator, repo)
+ thisGallery, err := r.galleryUpdate(ctx, *gallery, translator)
if err != nil {
return err
}
@@ -196,8 +201,8 @@ func (r *mutationResolver) GalleriesUpdate(ctx context.Context, input []*models.
return newRet, nil
}
-func (r *mutationResolver) galleryUpdate(input models.GalleryUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Gallery, error) {
- qb := repo.Gallery()
+func (r *mutationResolver) galleryUpdate(ctx context.Context, input models.GalleryUpdateInput, translator changesetTranslator) (*models.Gallery, error) {
+ qb := r.repository.Gallery
// Populate gallery from the input
galleryID, err := strconv.Atoi(input.ID)
@@ -205,7 +210,7 @@ func (r *mutationResolver) galleryUpdate(input models.GalleryUpdateInput, transl
return nil, err
}
- originalGallery, err := qb.Find(galleryID)
+ originalGallery, err := qb.Find(ctx, galleryID)
if err != nil {
return nil, err
}
@@ -244,28 +249,28 @@ func (r *mutationResolver) galleryUpdate(input models.GalleryUpdateInput, transl
// gallery scene is set from the scene only
- gallery, err := qb.UpdatePartial(updatedGallery)
+ gallery, err := qb.UpdatePartial(ctx, updatedGallery)
if err != nil {
return nil, err
}
// Save the performers
if translator.hasField("performer_ids") {
- if err := r.updateGalleryPerformers(qb, galleryID, input.PerformerIds); err != nil {
+ if err := r.updateGalleryPerformers(ctx, qb, galleryID, input.PerformerIds); err != nil {
return nil, err
}
}
// Save the tags
if translator.hasField("tag_ids") {
- if err := r.updateGalleryTags(qb, galleryID, input.TagIds); err != nil {
+ if err := r.updateGalleryTags(ctx, qb, galleryID, input.TagIds); err != nil {
return nil, err
}
}
// Save the scenes
if translator.hasField("scene_ids") {
- if err := r.updateGalleryScenes(qb, galleryID, input.SceneIds); err != nil {
+ if err := r.updateGalleryScenes(ctx, qb, galleryID, input.SceneIds); err != nil {
return nil, err
}
}
@@ -295,14 +300,14 @@ func (r *mutationResolver) BulkGalleryUpdate(ctx context.Context, input BulkGall
ret := []*models.Gallery{}
// Start the transaction and save the galleries
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.Gallery()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Gallery
for _, galleryIDStr := range input.Ids {
galleryID, _ := strconv.Atoi(galleryIDStr)
updatedGallery.ID = galleryID
- gallery, err := qb.UpdatePartial(updatedGallery)
+ gallery, err := qb.UpdatePartial(ctx, updatedGallery)
if err != nil {
return err
}
@@ -311,36 +316,36 @@ func (r *mutationResolver) BulkGalleryUpdate(ctx context.Context, input BulkGall
// Save the performers
if translator.hasField("performer_ids") {
- performerIDs, err := adjustGalleryPerformerIDs(qb, galleryID, *input.PerformerIds)
+ performerIDs, err := adjustGalleryPerformerIDs(ctx, qb, galleryID, *input.PerformerIds)
if err != nil {
return err
}
- if err := qb.UpdatePerformers(galleryID, performerIDs); err != nil {
+ if err := qb.UpdatePerformers(ctx, galleryID, performerIDs); err != nil {
return err
}
}
// Save the tags
if translator.hasField("tag_ids") {
- tagIDs, err := adjustGalleryTagIDs(qb, galleryID, *input.TagIds)
+ tagIDs, err := adjustGalleryTagIDs(ctx, qb, galleryID, *input.TagIds)
if err != nil {
return err
}
- if err := qb.UpdateTags(galleryID, tagIDs); err != nil {
+ if err := qb.UpdateTags(ctx, galleryID, tagIDs); err != nil {
return err
}
}
// Save the scenes
if translator.hasField("scene_ids") {
- sceneIDs, err := adjustGallerySceneIDs(qb, galleryID, *input.SceneIds)
+ sceneIDs, err := adjustGallerySceneIDs(ctx, qb, galleryID, *input.SceneIds)
if err != nil {
return err
}
- if err := qb.UpdateScenes(galleryID, sceneIDs); err != nil {
+ if err := qb.UpdateScenes(ctx, galleryID, sceneIDs); err != nil {
return err
}
}
@@ -367,8 +372,20 @@ func (r *mutationResolver) BulkGalleryUpdate(ctx context.Context, input BulkGall
return newRet, nil
}
-func adjustGalleryPerformerIDs(qb models.GalleryReader, galleryID int, ids BulkUpdateIds) (ret []int, err error) {
- ret, err = qb.GetPerformerIDs(galleryID)
+type GalleryPerformerGetter interface {
+ GetPerformerIDs(ctx context.Context, galleryID int) ([]int, error)
+}
+
+type GalleryTagGetter interface {
+ GetTagIDs(ctx context.Context, galleryID int) ([]int, error)
+}
+
+type GallerySceneGetter interface {
+ GetSceneIDs(ctx context.Context, galleryID int) ([]int, error)
+}
+
+func adjustGalleryPerformerIDs(ctx context.Context, qb GalleryPerformerGetter, galleryID int, ids BulkUpdateIds) (ret []int, err error) {
+ ret, err = qb.GetPerformerIDs(ctx, galleryID)
if err != nil {
return nil, err
}
@@ -376,8 +393,8 @@ func adjustGalleryPerformerIDs(qb models.GalleryReader, galleryID int, ids BulkU
return adjustIDs(ret, ids), nil
}
-func adjustGalleryTagIDs(qb models.GalleryReader, galleryID int, ids BulkUpdateIds) (ret []int, err error) {
- ret, err = qb.GetTagIDs(galleryID)
+func adjustGalleryTagIDs(ctx context.Context, qb GalleryTagGetter, galleryID int, ids BulkUpdateIds) (ret []int, err error) {
+ ret, err = qb.GetTagIDs(ctx, galleryID)
if err != nil {
return nil, err
}
@@ -385,8 +402,8 @@ func adjustGalleryTagIDs(qb models.GalleryReader, galleryID int, ids BulkUpdateI
return adjustIDs(ret, ids), nil
}
-func adjustGallerySceneIDs(qb models.GalleryReader, galleryID int, ids BulkUpdateIds) (ret []int, err error) {
- ret, err = qb.GetSceneIDs(galleryID)
+func adjustGallerySceneIDs(ctx context.Context, qb GallerySceneGetter, galleryID int, ids BulkUpdateIds) (ret []int, err error) {
+ ret, err = qb.GetSceneIDs(ctx, galleryID)
if err != nil {
return nil, err
}
@@ -410,12 +427,12 @@ func (r *mutationResolver) GalleryDestroy(ctx context.Context, input models.Gall
deleteGenerated := utils.IsTrue(input.DeleteGenerated)
deleteFile := utils.IsTrue(input.DeleteFile)
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.Gallery()
- iqb := repo.Image()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Gallery
+ iqb := r.repository.Image
for _, id := range galleryIDs {
- gallery, err := qb.Find(id)
+ gallery, err := qb.Find(ctx, id)
if err != nil {
return err
}
@@ -428,13 +445,13 @@ func (r *mutationResolver) GalleryDestroy(ctx context.Context, input models.Gall
// if this is a zip-based gallery, delete the images as well first
if gallery.Zip {
- imgs, err := iqb.FindByGalleryID(id)
+ imgs, err := iqb.FindByGalleryID(ctx, id)
if err != nil {
return err
}
for _, img := range imgs {
- if err := image.Destroy(img, iqb, fileDeleter, deleteGenerated, false); err != nil {
+ if err := image.Destroy(ctx, img, iqb, fileDeleter, deleteGenerated, false); err != nil {
return err
}
@@ -448,19 +465,19 @@ func (r *mutationResolver) GalleryDestroy(ctx context.Context, input models.Gall
}
} else if deleteFile {
// Delete image if it is only attached to this gallery
- imgs, err := iqb.FindByGalleryID(id)
+ imgs, err := iqb.FindByGalleryID(ctx, id)
if err != nil {
return err
}
for _, img := range imgs {
- imgGalleries, err := qb.FindByImageID(img.ID)
+ imgGalleries, err := qb.FindByImageID(ctx, img.ID)
if err != nil {
return err
}
if len(imgGalleries) == 1 {
- if err := image.Destroy(img, iqb, fileDeleter, deleteGenerated, deleteFile); err != nil {
+ if err := image.Destroy(ctx, img, iqb, fileDeleter, deleteGenerated, deleteFile); err != nil {
return err
}
@@ -472,7 +489,7 @@ func (r *mutationResolver) GalleryDestroy(ctx context.Context, input models.Gall
// don't do this with the file deleter
}
- if err := qb.Destroy(id); err != nil {
+ if err := qb.Destroy(ctx, id); err != nil {
return err
}
}
@@ -537,9 +554,9 @@ func (r *mutationResolver) AddGalleryImages(ctx context.Context, input GalleryAd
return false, err
}
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.Gallery()
- gallery, err := qb.Find(galleryID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Gallery
+ gallery, err := qb.Find(ctx, galleryID)
if err != nil {
return err
}
@@ -552,13 +569,13 @@ func (r *mutationResolver) AddGalleryImages(ctx context.Context, input GalleryAd
return errors.New("cannot modify zip gallery images")
}
- newIDs, err := qb.GetImageIDs(galleryID)
+ newIDs, err := qb.GetImageIDs(ctx, galleryID)
if err != nil {
return err
}
newIDs = intslice.IntAppendUniques(newIDs, imageIDs)
- return qb.UpdateImages(galleryID, newIDs)
+ return qb.UpdateImages(ctx, galleryID, newIDs)
}); err != nil {
return false, err
}
@@ -577,9 +594,9 @@ func (r *mutationResolver) RemoveGalleryImages(ctx context.Context, input Galler
return false, err
}
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.Gallery()
- gallery, err := qb.Find(galleryID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Gallery
+ gallery, err := qb.Find(ctx, galleryID)
if err != nil {
return err
}
@@ -592,13 +609,13 @@ func (r *mutationResolver) RemoveGalleryImages(ctx context.Context, input Galler
return errors.New("cannot modify zip gallery images")
}
- newIDs, err := qb.GetImageIDs(galleryID)
+ newIDs, err := qb.GetImageIDs(ctx, galleryID)
if err != nil {
return err
}
newIDs = intslice.IntExclude(newIDs, imageIDs)
- return qb.UpdateImages(galleryID, newIDs)
+ return qb.UpdateImages(ctx, galleryID, newIDs)
}); err != nil {
return false, err
}
diff --git a/internal/api/resolver_mutation_image.go b/internal/api/resolver_mutation_image.go
index 2df016be2..72d98696a 100644
--- a/internal/api/resolver_mutation_image.go
+++ b/internal/api/resolver_mutation_image.go
@@ -16,8 +16,8 @@ import (
)
func (r *mutationResolver) getImage(ctx context.Context, id int) (ret *models.Image, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Image().Find(id)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Image.Find(ctx, id)
return err
}); err != nil {
return nil, err
@@ -32,8 +32,8 @@ func (r *mutationResolver) ImageUpdate(ctx context.Context, input ImageUpdateInp
}
// Start the transaction and save the image
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- ret, err = r.imageUpdate(input, translator, repo)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.imageUpdate(ctx, input, translator)
return err
}); err != nil {
return nil, err
@@ -48,13 +48,13 @@ func (r *mutationResolver) ImagesUpdate(ctx context.Context, input []*ImageUpdat
inputMaps := getUpdateInputMaps(ctx)
// Start the transaction and save the image
- if err := r.withTxn(ctx, func(repo models.Repository) error {
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
for i, image := range input {
translator := changesetTranslator{
inputMap: inputMaps[i],
}
- thisImage, err := r.imageUpdate(*image, translator, repo)
+ thisImage, err := r.imageUpdate(ctx, *image, translator)
if err != nil {
return err
}
@@ -86,7 +86,7 @@ func (r *mutationResolver) ImagesUpdate(ctx context.Context, input []*ImageUpdat
return newRet, nil
}
-func (r *mutationResolver) imageUpdate(input ImageUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Image, error) {
+func (r *mutationResolver) imageUpdate(ctx context.Context, input ImageUpdateInput, translator changesetTranslator) (*models.Image, error) {
// Populate image from the input
imageID, err := strconv.Atoi(input.ID)
if err != nil {
@@ -104,28 +104,28 @@ func (r *mutationResolver) imageUpdate(input ImageUpdateInput, translator change
updatedImage.StudioID = translator.nullInt64FromString(input.StudioID, "studio_id")
updatedImage.Organized = input.Organized
- qb := repo.Image()
- image, err := qb.Update(updatedImage)
+ qb := r.repository.Image
+ image, err := qb.Update(ctx, updatedImage)
if err != nil {
return nil, err
}
if translator.hasField("gallery_ids") {
- if err := r.updateImageGalleries(qb, imageID, input.GalleryIds); err != nil {
+ if err := r.updateImageGalleries(ctx, imageID, input.GalleryIds); err != nil {
return nil, err
}
}
// Save the performers
if translator.hasField("performer_ids") {
- if err := r.updateImagePerformers(qb, imageID, input.PerformerIds); err != nil {
+ if err := r.updateImagePerformers(ctx, imageID, input.PerformerIds); err != nil {
return nil, err
}
}
// Save the tags
if translator.hasField("tag_ids") {
- if err := r.updateImageTags(qb, imageID, input.TagIds); err != nil {
+ if err := r.updateImageTags(ctx, imageID, input.TagIds); err != nil {
return nil, err
}
}
@@ -133,28 +133,28 @@ func (r *mutationResolver) imageUpdate(input ImageUpdateInput, translator change
return image, nil
}
-func (r *mutationResolver) updateImageGalleries(qb models.ImageReaderWriter, imageID int, galleryIDs []string) error {
+func (r *mutationResolver) updateImageGalleries(ctx context.Context, imageID int, galleryIDs []string) error {
ids, err := stringslice.StringSliceToIntSlice(galleryIDs)
if err != nil {
return err
}
- return qb.UpdateGalleries(imageID, ids)
+ return r.repository.Image.UpdateGalleries(ctx, imageID, ids)
}
-func (r *mutationResolver) updateImagePerformers(qb models.ImageReaderWriter, imageID int, performerIDs []string) error {
+func (r *mutationResolver) updateImagePerformers(ctx context.Context, imageID int, performerIDs []string) error {
ids, err := stringslice.StringSliceToIntSlice(performerIDs)
if err != nil {
return err
}
- return qb.UpdatePerformers(imageID, ids)
+ return r.repository.Image.UpdatePerformers(ctx, imageID, ids)
}
-func (r *mutationResolver) updateImageTags(qb models.ImageReaderWriter, imageID int, tagsIDs []string) error {
+func (r *mutationResolver) updateImageTags(ctx context.Context, imageID int, tagsIDs []string) error {
ids, err := stringslice.StringSliceToIntSlice(tagsIDs)
if err != nil {
return err
}
- return qb.UpdateTags(imageID, ids)
+ return r.repository.Image.UpdateTags(ctx, imageID, ids)
}
func (r *mutationResolver) BulkImageUpdate(ctx context.Context, input BulkImageUpdateInput) (ret []*models.Image, err error) {
@@ -180,13 +180,13 @@ func (r *mutationResolver) BulkImageUpdate(ctx context.Context, input BulkImageU
updatedImage.Organized = input.Organized
// Start the transaction and save the image marker
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.Image()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Image
for _, imageID := range imageIDs {
updatedImage.ID = imageID
- image, err := qb.Update(updatedImage)
+ image, err := qb.Update(ctx, updatedImage)
if err != nil {
return err
}
@@ -195,36 +195,36 @@ func (r *mutationResolver) BulkImageUpdate(ctx context.Context, input BulkImageU
// Save the galleries
if translator.hasField("gallery_ids") {
- galleryIDs, err := adjustImageGalleryIDs(qb, imageID, *input.GalleryIds)
+ galleryIDs, err := r.adjustImageGalleryIDs(ctx, imageID, *input.GalleryIds)
if err != nil {
return err
}
- if err := qb.UpdateGalleries(imageID, galleryIDs); err != nil {
+ if err := qb.UpdateGalleries(ctx, imageID, galleryIDs); err != nil {
return err
}
}
// Save the performers
if translator.hasField("performer_ids") {
- performerIDs, err := adjustImagePerformerIDs(qb, imageID, *input.PerformerIds)
+ performerIDs, err := r.adjustImagePerformerIDs(ctx, imageID, *input.PerformerIds)
if err != nil {
return err
}
- if err := qb.UpdatePerformers(imageID, performerIDs); err != nil {
+ if err := qb.UpdatePerformers(ctx, imageID, performerIDs); err != nil {
return err
}
}
// Save the tags
if translator.hasField("tag_ids") {
- tagIDs, err := adjustImageTagIDs(qb, imageID, *input.TagIds)
+ tagIDs, err := r.adjustImageTagIDs(ctx, imageID, *input.TagIds)
if err != nil {
return err
}
- if err := qb.UpdateTags(imageID, tagIDs); err != nil {
+ if err := qb.UpdateTags(ctx, imageID, tagIDs); err != nil {
return err
}
}
@@ -251,8 +251,8 @@ func (r *mutationResolver) BulkImageUpdate(ctx context.Context, input BulkImageU
return newRet, nil
}
-func adjustImageGalleryIDs(qb models.ImageReader, imageID int, ids BulkUpdateIds) (ret []int, err error) {
- ret, err = qb.GetGalleryIDs(imageID)
+func (r *mutationResolver) adjustImageGalleryIDs(ctx context.Context, imageID int, ids BulkUpdateIds) (ret []int, err error) {
+ ret, err = r.repository.Image.GetGalleryIDs(ctx, imageID)
if err != nil {
return nil, err
}
@@ -260,8 +260,8 @@ func adjustImageGalleryIDs(qb models.ImageReader, imageID int, ids BulkUpdateIds
return adjustIDs(ret, ids), nil
}
-func adjustImagePerformerIDs(qb models.ImageReader, imageID int, ids BulkUpdateIds) (ret []int, err error) {
- ret, err = qb.GetPerformerIDs(imageID)
+func (r *mutationResolver) adjustImagePerformerIDs(ctx context.Context, imageID int, ids BulkUpdateIds) (ret []int, err error) {
+ ret, err = r.repository.Image.GetPerformerIDs(ctx, imageID)
if err != nil {
return nil, err
}
@@ -269,8 +269,8 @@ func adjustImagePerformerIDs(qb models.ImageReader, imageID int, ids BulkUpdateI
return adjustIDs(ret, ids), nil
}
-func adjustImageTagIDs(qb models.ImageReader, imageID int, ids BulkUpdateIds) (ret []int, err error) {
- ret, err = qb.GetTagIDs(imageID)
+func (r *mutationResolver) adjustImageTagIDs(ctx context.Context, imageID int, ids BulkUpdateIds) (ret []int, err error) {
+ ret, err = r.repository.Image.GetTagIDs(ctx, imageID)
if err != nil {
return nil, err
}
@@ -289,10 +289,10 @@ func (r *mutationResolver) ImageDestroy(ctx context.Context, input models.ImageD
Deleter: *file.NewDeleter(),
Paths: manager.GetInstance().Paths,
}
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.Image()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Image
- i, err = qb.Find(imageID)
+ i, err = r.repository.Image.Find(ctx, imageID)
if err != nil {
return err
}
@@ -301,7 +301,7 @@ func (r *mutationResolver) ImageDestroy(ctx context.Context, input models.ImageD
return fmt.Errorf("image with id %d not found", imageID)
}
- return image.Destroy(i, qb, fileDeleter, utils.IsTrue(input.DeleteGenerated), utils.IsTrue(input.DeleteFile))
+ return image.Destroy(ctx, i, qb, fileDeleter, utils.IsTrue(input.DeleteGenerated), utils.IsTrue(input.DeleteFile))
}); err != nil {
fileDeleter.Rollback()
return false, err
@@ -331,12 +331,12 @@ func (r *mutationResolver) ImagesDestroy(ctx context.Context, input models.Image
Deleter: *file.NewDeleter(),
Paths: manager.GetInstance().Paths,
}
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.Image()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Image
for _, imageID := range imageIDs {
- i, err := qb.Find(imageID)
+ i, err := qb.Find(ctx, imageID)
if err != nil {
return err
}
@@ -347,7 +347,7 @@ func (r *mutationResolver) ImagesDestroy(ctx context.Context, input models.Image
images = append(images, i)
- if err := image.Destroy(i, qb, fileDeleter, utils.IsTrue(input.DeleteGenerated), utils.IsTrue(input.DeleteFile)); err != nil {
+ if err := image.Destroy(ctx, i, qb, fileDeleter, utils.IsTrue(input.DeleteGenerated), utils.IsTrue(input.DeleteFile)); err != nil {
return err
}
}
@@ -379,10 +379,10 @@ func (r *mutationResolver) ImageIncrementO(ctx context.Context, id string) (ret
return 0, err
}
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.Image()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Image
- ret, err = qb.IncrementOCounter(imageID)
+ ret, err = qb.IncrementOCounter(ctx, imageID)
return err
}); err != nil {
return 0, err
@@ -397,10 +397,10 @@ func (r *mutationResolver) ImageDecrementO(ctx context.Context, id string) (ret
return 0, err
}
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.Image()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Image
- ret, err = qb.DecrementOCounter(imageID)
+ ret, err = qb.DecrementOCounter(ctx, imageID)
return err
}); err != nil {
return 0, err
@@ -415,10 +415,10 @@ func (r *mutationResolver) ImageResetO(ctx context.Context, id string) (ret int,
return 0, err
}
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.Image()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Image
- ret, err = qb.ResetOCounter(imageID)
+ ret, err = qb.ResetOCounter(ctx, imageID)
return err
}); err != nil {
return 0, err
diff --git a/internal/api/resolver_mutation_metadata.go b/internal/api/resolver_mutation_metadata.go
index 0d1794e58..ff8635536 100644
--- a/internal/api/resolver_mutation_metadata.go
+++ b/internal/api/resolver_mutation_metadata.go
@@ -12,7 +12,6 @@ import (
"github.com/stashapp/stash/internal/identify"
"github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/internal/manager/config"
- "github.com/stashapp/stash/pkg/database"
"github.com/stashapp/stash/pkg/fsutil"
"github.com/stashapp/stash/pkg/logger"
)
@@ -111,6 +110,7 @@ func (r *mutationResolver) BackupDatabase(ctx context.Context, input BackupDatab
// if download is true, then backup to temporary file and return a link
download := input.Download != nil && *input.Download
mgr := manager.GetInstance()
+ database := mgr.Database
var backupPath string
if download {
if err := fsutil.EnsureDir(mgr.Paths.Generated.Downloads); err != nil {
@@ -127,7 +127,7 @@ func (r *mutationResolver) BackupDatabase(ctx context.Context, input BackupDatab
backupPath = database.DatabaseBackupPath()
}
- err := database.Backup(database.DB, backupPath)
+ err := database.Backup(backupPath)
if err != nil {
return nil, err
}
diff --git a/internal/api/resolver_mutation_movie.go b/internal/api/resolver_mutation_movie.go
index da6dfdfe7..0a22350b6 100644
--- a/internal/api/resolver_mutation_movie.go
+++ b/internal/api/resolver_mutation_movie.go
@@ -15,8 +15,8 @@ import (
)
func (r *mutationResolver) getMovie(ctx context.Context, id int) (ret *models.Movie, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Movie().Find(id)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Movie.Find(ctx, id)
return err
}); err != nil {
return nil, err
@@ -100,16 +100,16 @@ func (r *mutationResolver) MovieCreate(ctx context.Context, input MovieCreateInp
// Start the transaction and save the movie
var movie *models.Movie
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.Movie()
- movie, err = qb.Create(newMovie)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Movie
+ movie, err = qb.Create(ctx, newMovie)
if err != nil {
return err
}
// update image table
if len(frontimageData) > 0 {
- if err := qb.UpdateImages(movie.ID, frontimageData, backimageData); err != nil {
+ if err := qb.UpdateImages(ctx, movie.ID, frontimageData, backimageData); err != nil {
return err
}
}
@@ -174,9 +174,9 @@ func (r *mutationResolver) MovieUpdate(ctx context.Context, input MovieUpdateInp
// Start the transaction and save the movie
var movie *models.Movie
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.Movie()
- movie, err = qb.Update(updatedMovie)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Movie
+ movie, err = qb.Update(ctx, updatedMovie)
if err != nil {
return err
}
@@ -184,13 +184,13 @@ func (r *mutationResolver) MovieUpdate(ctx context.Context, input MovieUpdateInp
// update image table
if frontImageIncluded || backImageIncluded {
if !frontImageIncluded {
- frontimageData, err = qb.GetFrontImage(updatedMovie.ID)
+ frontimageData, err = qb.GetFrontImage(ctx, updatedMovie.ID)
if err != nil {
return err
}
}
if !backImageIncluded {
- backimageData, err = qb.GetBackImage(updatedMovie.ID)
+ backimageData, err = qb.GetBackImage(ctx, updatedMovie.ID)
if err != nil {
return err
}
@@ -198,7 +198,7 @@ func (r *mutationResolver) MovieUpdate(ctx context.Context, input MovieUpdateInp
if len(frontimageData) == 0 && len(backimageData) == 0 {
// both images are being nulled. Destroy them.
- if err := qb.DestroyImages(movie.ID); err != nil {
+ if err := qb.DestroyImages(ctx, movie.ID); err != nil {
return err
}
} else {
@@ -208,7 +208,7 @@ func (r *mutationResolver) MovieUpdate(ctx context.Context, input MovieUpdateInp
frontimageData, _ = utils.ProcessImageInput(ctx, models.DefaultMovieImage)
}
- if err := qb.UpdateImages(movie.ID, frontimageData, backimageData); err != nil {
+ if err := qb.UpdateImages(ctx, movie.ID, frontimageData, backimageData); err != nil {
return err
}
}
@@ -245,13 +245,13 @@ func (r *mutationResolver) BulkMovieUpdate(ctx context.Context, input BulkMovieU
ret := []*models.Movie{}
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.Movie()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Movie
for _, movieID := range movieIDs {
updatedMovie.ID = movieID
- existing, err := qb.Find(movieID)
+ existing, err := qb.Find(ctx, movieID)
if err != nil {
return err
}
@@ -260,7 +260,7 @@ func (r *mutationResolver) BulkMovieUpdate(ctx context.Context, input BulkMovieU
return fmt.Errorf("movie with id %d not found", movieID)
}
- movie, err := qb.Update(updatedMovie)
+ movie, err := qb.Update(ctx, updatedMovie)
if err != nil {
return err
}
@@ -294,8 +294,8 @@ func (r *mutationResolver) MovieDestroy(ctx context.Context, input MovieDestroyI
return false, err
}
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- return repo.Movie().Destroy(id)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ return r.repository.Movie.Destroy(ctx, id)
}); err != nil {
return false, err
}
@@ -311,10 +311,10 @@ func (r *mutationResolver) MoviesDestroy(ctx context.Context, movieIDs []string)
return false, err
}
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.Movie()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Movie
for _, id := range ids {
- if err := qb.Destroy(id); err != nil {
+ if err := qb.Destroy(ctx, id); err != nil {
return err
}
}
diff --git a/internal/api/resolver_mutation_performer.go b/internal/api/resolver_mutation_performer.go
index 02c942484..a5fd19dea 100644
--- a/internal/api/resolver_mutation_performer.go
+++ b/internal/api/resolver_mutation_performer.go
@@ -16,8 +16,8 @@ import (
)
func (r *mutationResolver) getPerformer(ctx context.Context, id int) (ret *models.Performer, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Performer().Find(id)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Performer.Find(ctx, id)
return err
}); err != nil {
return nil, err
@@ -129,23 +129,23 @@ func (r *mutationResolver) PerformerCreate(ctx context.Context, input PerformerC
// Start the transaction and save the performer
var performer *models.Performer
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.Performer()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Performer
- performer, err = qb.Create(newPerformer)
+ performer, err = qb.Create(ctx, newPerformer)
if err != nil {
return err
}
if len(input.TagIds) > 0 {
- if err := r.updatePerformerTags(qb, performer.ID, input.TagIds); err != nil {
+ if err := r.updatePerformerTags(ctx, performer.ID, input.TagIds); err != nil {
return err
}
}
// update image table
if len(imageData) > 0 {
- if err := qb.UpdateImage(performer.ID, imageData); err != nil {
+ if err := qb.UpdateImage(ctx, performer.ID, imageData); err != nil {
return err
}
}
@@ -153,7 +153,7 @@ func (r *mutationResolver) PerformerCreate(ctx context.Context, input PerformerC
// Save the stash_ids
if input.StashIds != nil {
stashIDJoins := models.StashIDsFromInput(input.StashIds)
- if err := qb.UpdateStashIDs(performer.ID, stashIDJoins); err != nil {
+ if err := qb.UpdateStashIDs(ctx, performer.ID, stashIDJoins); err != nil {
return err
}
}
@@ -230,11 +230,11 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input PerformerU
// Start the transaction and save the p
var p *models.Performer
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.Performer()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Performer
// need to get existing performer
- existing, err := qb.Find(updatedPerformer.ID)
+ existing, err := qb.Find(ctx, updatedPerformer.ID)
if err != nil {
return err
}
@@ -249,26 +249,26 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input PerformerU
}
}
- p, err = qb.Update(updatedPerformer)
+ p, err = qb.Update(ctx, updatedPerformer)
if err != nil {
return err
}
// Save the tags
if translator.hasField("tag_ids") {
- if err := r.updatePerformerTags(qb, p.ID, input.TagIds); err != nil {
+ if err := r.updatePerformerTags(ctx, p.ID, input.TagIds); err != nil {
return err
}
}
// update image table
if len(imageData) > 0 {
- if err := qb.UpdateImage(p.ID, imageData); err != nil {
+ if err := qb.UpdateImage(ctx, p.ID, imageData); err != nil {
return err
}
} else if imageIncluded {
// must be unsetting
- if err := qb.DestroyImage(p.ID); err != nil {
+ if err := qb.DestroyImage(ctx, p.ID); err != nil {
return err
}
}
@@ -276,7 +276,7 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input PerformerU
// Save the stash_ids
if translator.hasField("stash_ids") {
stashIDJoins := models.StashIDsFromInput(input.StashIds)
- if err := qb.UpdateStashIDs(performerID, stashIDJoins); err != nil {
+ if err := qb.UpdateStashIDs(ctx, performerID, stashIDJoins); err != nil {
return err
}
}
@@ -290,12 +290,12 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input PerformerU
return r.getPerformer(ctx, p.ID)
}
-func (r *mutationResolver) updatePerformerTags(qb models.PerformerReaderWriter, performerID int, tagsIDs []string) error {
+func (r *mutationResolver) updatePerformerTags(ctx context.Context, performerID int, tagsIDs []string) error {
ids, err := stringslice.StringSliceToIntSlice(tagsIDs)
if err != nil {
return err
}
- return qb.UpdateTags(performerID, ids)
+ return r.repository.Performer.UpdateTags(ctx, performerID, ids)
}
func (r *mutationResolver) BulkPerformerUpdate(ctx context.Context, input BulkPerformerUpdateInput) ([]*models.Performer, error) {
@@ -348,14 +348,14 @@ func (r *mutationResolver) BulkPerformerUpdate(ctx context.Context, input BulkPe
ret := []*models.Performer{}
// Start the transaction and save the scene marker
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.Performer()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Performer
for _, performerID := range performerIDs {
updatedPerformer.ID = performerID
// need to get existing performer
- existing, err := qb.Find(performerID)
+ existing, err := qb.Find(ctx, performerID)
if err != nil {
return err
}
@@ -368,7 +368,7 @@ func (r *mutationResolver) BulkPerformerUpdate(ctx context.Context, input BulkPe
return err
}
- performer, err := qb.Update(updatedPerformer)
+ performer, err := qb.Update(ctx, updatedPerformer)
if err != nil {
return err
}
@@ -377,12 +377,12 @@ func (r *mutationResolver) BulkPerformerUpdate(ctx context.Context, input BulkPe
// Save the tags
if translator.hasField("tag_ids") {
- tagIDs, err := adjustTagIDs(qb, performerID, *input.TagIds)
+ tagIDs, err := adjustTagIDs(ctx, qb, performerID, *input.TagIds)
if err != nil {
return err
}
- if err := qb.UpdateTags(performerID, tagIDs); err != nil {
+ if err := qb.UpdateTags(ctx, performerID, tagIDs); err != nil {
return err
}
}
@@ -415,8 +415,8 @@ func (r *mutationResolver) PerformerDestroy(ctx context.Context, input Performer
return false, err
}
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- return repo.Performer().Destroy(id)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ return r.repository.Performer.Destroy(ctx, id)
}); err != nil {
return false, err
}
@@ -432,10 +432,10 @@ func (r *mutationResolver) PerformersDestroy(ctx context.Context, performerIDs [
return false, err
}
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.Performer()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Performer
for _, id := range ids {
- if err := qb.Destroy(id); err != nil {
+ if err := qb.Destroy(ctx, id); err != nil {
return err
}
}
diff --git a/internal/api/resolver_mutation_saved_filter.go b/internal/api/resolver_mutation_saved_filter.go
index bf1b4106e..a995060ea 100644
--- a/internal/api/resolver_mutation_saved_filter.go
+++ b/internal/api/resolver_mutation_saved_filter.go
@@ -23,17 +23,17 @@ func (r *mutationResolver) SaveFilter(ctx context.Context, input SaveFilterInput
id = &idv
}
- if err := r.withTxn(ctx, func(repo models.Repository) error {
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
f := models.SavedFilter{
Mode: input.Mode,
Name: input.Name,
Filter: input.Filter,
}
if id == nil {
- ret, err = repo.SavedFilter().Create(f)
+ ret, err = r.repository.SavedFilter.Create(ctx, f)
} else {
f.ID = *id
- ret, err = repo.SavedFilter().Update(f)
+ ret, err = r.repository.SavedFilter.Update(ctx, f)
}
return err
}); err != nil {
@@ -48,8 +48,8 @@ func (r *mutationResolver) DestroySavedFilter(ctx context.Context, input Destroy
return false, err
}
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- return repo.SavedFilter().Destroy(id)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ return r.repository.SavedFilter.Destroy(ctx, id)
}); err != nil {
return false, err
}
@@ -58,24 +58,24 @@ func (r *mutationResolver) DestroySavedFilter(ctx context.Context, input Destroy
}
func (r *mutationResolver) SetDefaultFilter(ctx context.Context, input SetDefaultFilterInput) (bool, error) {
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.SavedFilter()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.SavedFilter
if input.Filter == nil {
// clearing
- def, err := qb.FindDefault(input.Mode)
+ def, err := qb.FindDefault(ctx, input.Mode)
if err != nil {
return err
}
if def != nil {
- return qb.Destroy(def.ID)
+ return qb.Destroy(ctx, def.ID)
}
return nil
}
- _, err := qb.SetDefault(models.SavedFilter{
+ _, err := qb.SetDefault(ctx, models.SavedFilter{
Mode: input.Mode,
Filter: *input.Filter,
})
diff --git a/internal/api/resolver_mutation_scene.go b/internal/api/resolver_mutation_scene.go
index fdaf1d64d..1a9901062 100644
--- a/internal/api/resolver_mutation_scene.go
+++ b/internal/api/resolver_mutation_scene.go
@@ -19,8 +19,8 @@ import (
)
func (r *mutationResolver) getScene(ctx context.Context, id int) (ret *models.Scene, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Scene().Find(id)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Scene.Find(ctx, id)
return err
}); err != nil {
return nil, err
@@ -35,8 +35,8 @@ func (r *mutationResolver) SceneUpdate(ctx context.Context, input models.SceneUp
}
// Start the transaction and save the scene
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- ret, err = r.sceneUpdate(ctx, input, translator, repo)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.sceneUpdate(ctx, input, translator)
return err
}); err != nil {
return nil, err
@@ -50,13 +50,13 @@ func (r *mutationResolver) ScenesUpdate(ctx context.Context, input []*models.Sce
inputMaps := getUpdateInputMaps(ctx)
// Start the transaction and save the scene
- if err := r.withTxn(ctx, func(repo models.Repository) error {
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
for i, scene := range input {
translator := changesetTranslator{
inputMap: inputMaps[i],
}
- thisScene, err := r.sceneUpdate(ctx, *scene, translator, repo)
+ thisScene, err := r.sceneUpdate(ctx, *scene, translator)
ret = append(ret, thisScene)
if err != nil {
@@ -89,7 +89,7 @@ func (r *mutationResolver) ScenesUpdate(ctx context.Context, input []*models.Sce
return newRet, nil
}
-func (r *mutationResolver) sceneUpdate(ctx context.Context, input models.SceneUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Scene, error) {
+func (r *mutationResolver) sceneUpdate(ctx context.Context, input models.SceneUpdateInput, translator changesetTranslator) (*models.Scene, error) {
// Populate scene from the input
sceneID, err := strconv.Atoi(input.ID)
if err != nil {
@@ -122,43 +122,43 @@ func (r *mutationResolver) sceneUpdate(ctx context.Context, input models.SceneUp
// update the cover after updating the scene
}
- qb := repo.Scene()
- s, err := qb.Update(updatedScene)
+ qb := r.repository.Scene
+ s, err := qb.Update(ctx, updatedScene)
if err != nil {
return nil, err
}
// update cover table
if len(coverImageData) > 0 {
- if err := qb.UpdateCover(sceneID, coverImageData); err != nil {
+ if err := qb.UpdateCover(ctx, sceneID, coverImageData); err != nil {
return nil, err
}
}
// Save the performers
if translator.hasField("performer_ids") {
- if err := r.updateScenePerformers(qb, sceneID, input.PerformerIds); err != nil {
+ if err := r.updateScenePerformers(ctx, sceneID, input.PerformerIds); err != nil {
return nil, err
}
}
// Save the movies
if translator.hasField("movies") {
- if err := r.updateSceneMovies(qb, sceneID, input.Movies); err != nil {
+ if err := r.updateSceneMovies(ctx, sceneID, input.Movies); err != nil {
return nil, err
}
}
// Save the tags
if translator.hasField("tag_ids") {
- if err := r.updateSceneTags(qb, sceneID, input.TagIds); err != nil {
+ if err := r.updateSceneTags(ctx, sceneID, input.TagIds); err != nil {
return nil, err
}
}
// Save the galleries
if translator.hasField("gallery_ids") {
- if err := r.updateSceneGalleries(qb, sceneID, input.GalleryIds); err != nil {
+ if err := r.updateSceneGalleries(ctx, sceneID, input.GalleryIds); err != nil {
return nil, err
}
}
@@ -166,7 +166,7 @@ func (r *mutationResolver) sceneUpdate(ctx context.Context, input models.SceneUp
// Save the stash_ids
if translator.hasField("stash_ids") {
stashIDJoins := models.StashIDsFromInput(input.StashIds)
- if err := qb.UpdateStashIDs(sceneID, stashIDJoins); err != nil {
+ if err := qb.UpdateStashIDs(ctx, sceneID, stashIDJoins); err != nil {
return nil, err
}
}
@@ -182,15 +182,15 @@ func (r *mutationResolver) sceneUpdate(ctx context.Context, input models.SceneUp
return s, nil
}
-func (r *mutationResolver) updateScenePerformers(qb models.SceneReaderWriter, sceneID int, performerIDs []string) error {
+func (r *mutationResolver) updateScenePerformers(ctx context.Context, sceneID int, performerIDs []string) error {
ids, err := stringslice.StringSliceToIntSlice(performerIDs)
if err != nil {
return err
}
- return qb.UpdatePerformers(sceneID, ids)
+ return r.repository.Scene.UpdatePerformers(ctx, sceneID, ids)
}
-func (r *mutationResolver) updateSceneMovies(qb models.SceneReaderWriter, sceneID int, movies []*models.SceneMovieInput) error {
+func (r *mutationResolver) updateSceneMovies(ctx context.Context, sceneID int, movies []*models.SceneMovieInput) error {
var movieJoins []models.MoviesScenes
for _, movie := range movies {
@@ -213,23 +213,23 @@ func (r *mutationResolver) updateSceneMovies(qb models.SceneReaderWriter, sceneI
movieJoins = append(movieJoins, movieJoin)
}
- return qb.UpdateMovies(sceneID, movieJoins)
+ return r.repository.Scene.UpdateMovies(ctx, sceneID, movieJoins)
}
-func (r *mutationResolver) updateSceneTags(qb models.SceneReaderWriter, sceneID int, tagsIDs []string) error {
+func (r *mutationResolver) updateSceneTags(ctx context.Context, sceneID int, tagsIDs []string) error {
ids, err := stringslice.StringSliceToIntSlice(tagsIDs)
if err != nil {
return err
}
- return qb.UpdateTags(sceneID, ids)
+ return r.repository.Scene.UpdateTags(ctx, sceneID, ids)
}
-func (r *mutationResolver) updateSceneGalleries(qb models.SceneReaderWriter, sceneID int, galleryIDs []string) error {
+func (r *mutationResolver) updateSceneGalleries(ctx context.Context, sceneID int, galleryIDs []string) error {
ids, err := stringslice.StringSliceToIntSlice(galleryIDs)
if err != nil {
return err
}
- return qb.UpdateGalleries(sceneID, ids)
+ return r.repository.Scene.UpdateGalleries(ctx, sceneID, ids)
}
func (r *mutationResolver) BulkSceneUpdate(ctx context.Context, input BulkSceneUpdateInput) ([]*models.Scene, error) {
@@ -260,13 +260,13 @@ func (r *mutationResolver) BulkSceneUpdate(ctx context.Context, input BulkSceneU
ret := []*models.Scene{}
// Start the transaction and save the scene marker
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.Scene()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Scene
for _, sceneID := range sceneIDs {
updatedScene.ID = sceneID
- scene, err := qb.Update(updatedScene)
+ scene, err := qb.Update(ctx, updatedScene)
if err != nil {
return err
}
@@ -275,48 +275,48 @@ func (r *mutationResolver) BulkSceneUpdate(ctx context.Context, input BulkSceneU
// Save the performers
if translator.hasField("performer_ids") {
- performerIDs, err := adjustScenePerformerIDs(qb, sceneID, *input.PerformerIds)
+ performerIDs, err := r.adjustScenePerformerIDs(ctx, sceneID, *input.PerformerIds)
if err != nil {
return err
}
- if err := qb.UpdatePerformers(sceneID, performerIDs); err != nil {
+ if err := qb.UpdatePerformers(ctx, sceneID, performerIDs); err != nil {
return err
}
}
// Save the tags
if translator.hasField("tag_ids") {
- tagIDs, err := adjustTagIDs(qb, sceneID, *input.TagIds)
+ tagIDs, err := adjustTagIDs(ctx, qb, sceneID, *input.TagIds)
if err != nil {
return err
}
- if err := qb.UpdateTags(sceneID, tagIDs); err != nil {
+ if err := qb.UpdateTags(ctx, sceneID, tagIDs); err != nil {
return err
}
}
// Save the galleries
if translator.hasField("gallery_ids") {
- galleryIDs, err := adjustSceneGalleryIDs(qb, sceneID, *input.GalleryIds)
+ galleryIDs, err := r.adjustSceneGalleryIDs(ctx, sceneID, *input.GalleryIds)
if err != nil {
return err
}
- if err := qb.UpdateGalleries(sceneID, galleryIDs); err != nil {
+ if err := qb.UpdateGalleries(ctx, sceneID, galleryIDs); err != nil {
return err
}
}
// Save the movies
if translator.hasField("movie_ids") {
- movies, err := adjustSceneMovieIDs(qb, sceneID, *input.MovieIds)
+ movies, err := r.adjustSceneMovieIDs(ctx, sceneID, *input.MovieIds)
if err != nil {
return err
}
- if err := qb.UpdateMovies(sceneID, movies); err != nil {
+ if err := qb.UpdateMovies(ctx, sceneID, movies); err != nil {
return err
}
}
@@ -380,8 +380,8 @@ func adjustIDs(existingIDs []int, updateIDs BulkUpdateIds) []int {
return existingIDs
}
-func adjustScenePerformerIDs(qb models.SceneReader, sceneID int, ids BulkUpdateIds) (ret []int, err error) {
- ret, err = qb.GetPerformerIDs(sceneID)
+func (r *mutationResolver) adjustScenePerformerIDs(ctx context.Context, sceneID int, ids BulkUpdateIds) (ret []int, err error) {
+ ret, err = r.repository.Scene.GetPerformerIDs(ctx, sceneID)
if err != nil {
return nil, err
}
@@ -390,11 +390,11 @@ func adjustScenePerformerIDs(qb models.SceneReader, sceneID int, ids BulkUpdateI
}
type tagIDsGetter interface {
- GetTagIDs(id int) ([]int, error)
+ GetTagIDs(ctx context.Context, id int) ([]int, error)
}
-func adjustTagIDs(qb tagIDsGetter, sceneID int, ids BulkUpdateIds) (ret []int, err error) {
- ret, err = qb.GetTagIDs(sceneID)
+func adjustTagIDs(ctx context.Context, qb tagIDsGetter, sceneID int, ids BulkUpdateIds) (ret []int, err error) {
+ ret, err = qb.GetTagIDs(ctx, sceneID)
if err != nil {
return nil, err
}
@@ -402,8 +402,8 @@ func adjustTagIDs(qb tagIDsGetter, sceneID int, ids BulkUpdateIds) (ret []int, e
return adjustIDs(ret, ids), nil
}
-func adjustSceneGalleryIDs(qb models.SceneReader, sceneID int, ids BulkUpdateIds) (ret []int, err error) {
- ret, err = qb.GetGalleryIDs(sceneID)
+func (r *mutationResolver) adjustSceneGalleryIDs(ctx context.Context, sceneID int, ids BulkUpdateIds) (ret []int, err error) {
+ ret, err = r.repository.Scene.GetGalleryIDs(ctx, sceneID)
if err != nil {
return nil, err
}
@@ -411,8 +411,8 @@ func adjustSceneGalleryIDs(qb models.SceneReader, sceneID int, ids BulkUpdateIds
return adjustIDs(ret, ids), nil
}
-func adjustSceneMovieIDs(qb models.SceneReader, sceneID int, updateIDs BulkUpdateIds) ([]models.MoviesScenes, error) {
- existingMovies, err := qb.GetMovies(sceneID)
+func (r *mutationResolver) adjustSceneMovieIDs(ctx context.Context, sceneID int, updateIDs BulkUpdateIds) ([]models.MoviesScenes, error) {
+ existingMovies, err := r.repository.Scene.GetMovies(ctx, sceneID)
if err != nil {
return nil, err
}
@@ -471,10 +471,10 @@ func (r *mutationResolver) SceneDestroy(ctx context.Context, input models.SceneD
deleteGenerated := utils.IsTrue(input.DeleteGenerated)
deleteFile := utils.IsTrue(input.DeleteFile)
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.Scene()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Scene
var err error
- s, err = qb.Find(sceneID)
+ s, err = qb.Find(ctx, sceneID)
if err != nil {
return err
}
@@ -486,7 +486,7 @@ func (r *mutationResolver) SceneDestroy(ctx context.Context, input models.SceneD
// kill any running encoders
manager.KillRunningStreams(s, fileNamingAlgo)
- return scene.Destroy(s, repo, fileDeleter, deleteGenerated, deleteFile)
+ return scene.Destroy(ctx, s, r.repository.Scene, r.repository.SceneMarker, fileDeleter, deleteGenerated, deleteFile)
}); err != nil {
fileDeleter.Rollback()
return false, err
@@ -519,13 +519,13 @@ func (r *mutationResolver) ScenesDestroy(ctx context.Context, input models.Scene
deleteGenerated := utils.IsTrue(input.DeleteGenerated)
deleteFile := utils.IsTrue(input.DeleteFile)
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.Scene()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Scene
for _, id := range input.Ids {
sceneID, _ := strconv.Atoi(id)
- s, err := qb.Find(sceneID)
+ s, err := qb.Find(ctx, sceneID)
if err != nil {
return err
}
@@ -536,7 +536,7 @@ func (r *mutationResolver) ScenesDestroy(ctx context.Context, input models.Scene
// kill any running encoders
manager.KillRunningStreams(s, fileNamingAlgo)
- if err := scene.Destroy(s, repo, fileDeleter, deleteGenerated, deleteFile); err != nil {
+ if err := scene.Destroy(ctx, s, r.repository.Scene, r.repository.SceneMarker, fileDeleter, deleteGenerated, deleteFile); err != nil {
return err
}
}
@@ -564,8 +564,8 @@ func (r *mutationResolver) ScenesDestroy(ctx context.Context, input models.Scene
}
func (r *mutationResolver) getSceneMarker(ctx context.Context, id int) (ret *models.SceneMarker, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.SceneMarker().Find(id)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.SceneMarker.Find(ctx, id)
return err
}); err != nil {
return nil, err
@@ -666,11 +666,11 @@ func (r *mutationResolver) SceneMarkerDestroy(ctx context.Context, id string) (b
Paths: manager.GetInstance().Paths,
}
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.SceneMarker()
- sqb := repo.Scene()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.SceneMarker
+ sqb := r.repository.Scene
- marker, err := qb.Find(markerID)
+ marker, err := qb.Find(ctx, markerID)
if err != nil {
return err
@@ -680,12 +680,12 @@ func (r *mutationResolver) SceneMarkerDestroy(ctx context.Context, id string) (b
return fmt.Errorf("scene marker with id %d not found", markerID)
}
- s, err := sqb.Find(int(marker.SceneID.Int64))
+ s, err := sqb.Find(ctx, int(marker.SceneID.Int64))
if err != nil {
return err
}
- return scene.DestroyMarker(s, marker, qb, fileDeleter)
+ return scene.DestroyMarker(ctx, s, marker, qb, fileDeleter)
}); err != nil {
fileDeleter.Rollback()
return false, err
@@ -713,26 +713,26 @@ func (r *mutationResolver) changeMarker(ctx context.Context, changeType int, cha
}
// Start the transaction and save the scene marker
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.SceneMarker()
- sqb := repo.Scene()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.SceneMarker
+ sqb := r.repository.Scene
var err error
switch changeType {
case create:
- sceneMarker, err = qb.Create(changedMarker)
+ sceneMarker, err = qb.Create(ctx, changedMarker)
case update:
// check to see if timestamp was changed
- existingMarker, err = qb.Find(changedMarker.ID)
+ existingMarker, err = qb.Find(ctx, changedMarker.ID)
if err != nil {
return err
}
- sceneMarker, err = qb.Update(changedMarker)
+ sceneMarker, err = qb.Update(ctx, changedMarker)
if err != nil {
return err
}
- s, err = sqb.Find(int(existingMarker.SceneID.Int64))
+ s, err = sqb.Find(ctx, int(existingMarker.SceneID.Int64))
}
if err != nil {
return err
@@ -749,7 +749,7 @@ func (r *mutationResolver) changeMarker(ctx context.Context, changeType int, cha
// Save the marker tags
// If this tag is the primary tag, then let's not add it.
tagIDs = intslice.IntExclude(tagIDs, []int{changedMarker.PrimaryTagID})
- return qb.UpdateTags(sceneMarker.ID, tagIDs)
+ return qb.UpdateTags(ctx, sceneMarker.ID, tagIDs)
}); err != nil {
fileDeleter.Rollback()
return nil, err
@@ -766,10 +766,10 @@ func (r *mutationResolver) SceneIncrementO(ctx context.Context, id string) (ret
return 0, err
}
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.Scene()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Scene
- ret, err = qb.IncrementOCounter(sceneID)
+ ret, err = qb.IncrementOCounter(ctx, sceneID)
return err
}); err != nil {
return 0, err
@@ -784,10 +784,10 @@ func (r *mutationResolver) SceneDecrementO(ctx context.Context, id string) (ret
return 0, err
}
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.Scene()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Scene
- ret, err = qb.DecrementOCounter(sceneID)
+ ret, err = qb.DecrementOCounter(ctx, sceneID)
return err
}); err != nil {
return 0, err
@@ -802,10 +802,10 @@ func (r *mutationResolver) SceneResetO(ctx context.Context, id string) (ret int,
return 0, err
}
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.Scene()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Scene
- ret, err = qb.ResetOCounter(sceneID)
+ ret, err = qb.ResetOCounter(ctx, sceneID)
return err
}); err != nil {
return 0, err
diff --git a/internal/api/resolver_mutation_stash_box.go b/internal/api/resolver_mutation_stash_box.go
index 3d9163317..95e300f9d 100644
--- a/internal/api/resolver_mutation_stash_box.go
+++ b/internal/api/resolver_mutation_stash_box.go
@@ -7,10 +7,18 @@ import (
"github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/internal/manager/config"
- "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scraper/stashbox"
)
+func (r *Resolver) stashboxRepository() stashbox.Repository {
+ return stashbox.Repository{
+ Scene: r.repository.Scene,
+ Performer: r.repository.Performer,
+ Tag: r.repository.Tag,
+ Studio: r.repository.Studio,
+ }
+}
+
func (r *mutationResolver) SubmitStashBoxFingerprints(ctx context.Context, input StashBoxFingerprintSubmissionInput) (bool, error) {
boxes := config.GetInstance().GetStashBoxes()
@@ -18,7 +26,7 @@ func (r *mutationResolver) SubmitStashBoxFingerprints(ctx context.Context, input
return false, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex)
}
- client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager)
+ client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager, r.stashboxRepository())
return client.SubmitStashBoxFingerprints(ctx, input.SceneIds, boxes[input.StashBoxIndex].Endpoint)
}
@@ -35,7 +43,7 @@ func (r *mutationResolver) SubmitStashBoxSceneDraft(ctx context.Context, input S
return nil, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex)
}
- client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager)
+ client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager, r.stashboxRepository())
id, err := strconv.Atoi(input.ID)
if err != nil {
@@ -43,9 +51,9 @@ func (r *mutationResolver) SubmitStashBoxSceneDraft(ctx context.Context, input S
}
var res *string
- err = r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- qb := repo.Scene()
- scene, err := qb.Find(id)
+ err = r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Scene
+ scene, err := qb.Find(ctx, id)
if err != nil {
return err
}
@@ -65,7 +73,7 @@ func (r *mutationResolver) SubmitStashBoxPerformerDraft(ctx context.Context, inp
return nil, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex)
}
- client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager)
+ client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager, r.stashboxRepository())
id, err := strconv.Atoi(input.ID)
if err != nil {
@@ -73,9 +81,9 @@ func (r *mutationResolver) SubmitStashBoxPerformerDraft(ctx context.Context, inp
}
var res *string
- err = r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- qb := repo.Performer()
- performer, err := qb.Find(id)
+ err = r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Performer
+ performer, err := qb.Find(ctx, id)
if err != nil {
return err
}
diff --git a/internal/api/resolver_mutation_studio.go b/internal/api/resolver_mutation_studio.go
index 8fcfd3b53..fde747e3e 100644
--- a/internal/api/resolver_mutation_studio.go
+++ b/internal/api/resolver_mutation_studio.go
@@ -17,8 +17,8 @@ import (
)
func (r *mutationResolver) getStudio(ctx context.Context, id int) (ret *models.Studio, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Studio().Find(id)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Studio.Find(ctx, id)
return err
}); err != nil {
return nil, err
@@ -72,18 +72,18 @@ func (r *mutationResolver) StudioCreate(ctx context.Context, input StudioCreateI
// Start the transaction and save the studio
var s *models.Studio
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.Studio()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Studio
var err error
- s, err = qb.Create(newStudio)
+ s, err = qb.Create(ctx, newStudio)
if err != nil {
return err
}
// update image table
if len(imageData) > 0 {
- if err := qb.UpdateImage(s.ID, imageData); err != nil {
+ if err := qb.UpdateImage(ctx, s.ID, imageData); err != nil {
return err
}
}
@@ -91,17 +91,17 @@ func (r *mutationResolver) StudioCreate(ctx context.Context, input StudioCreateI
// Save the stash_ids
if input.StashIds != nil {
stashIDJoins := models.StashIDsFromInput(input.StashIds)
- if err := qb.UpdateStashIDs(s.ID, stashIDJoins); err != nil {
+ if err := qb.UpdateStashIDs(ctx, s.ID, stashIDJoins); err != nil {
return err
}
}
if len(input.Aliases) > 0 {
- if err := studio.EnsureAliasesUnique(s.ID, input.Aliases, qb); err != nil {
+ if err := studio.EnsureAliasesUnique(ctx, s.ID, input.Aliases, qb); err != nil {
return err
}
- if err := qb.UpdateAliases(s.ID, input.Aliases); err != nil {
+ if err := qb.UpdateAliases(ctx, s.ID, input.Aliases); err != nil {
return err
}
}
@@ -155,27 +155,27 @@ func (r *mutationResolver) StudioUpdate(ctx context.Context, input StudioUpdateI
// Start the transaction and save the studio
var s *models.Studio
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.Studio()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Studio
- if err := manager.ValidateModifyStudio(updatedStudio, qb); err != nil {
+ if err := manager.ValidateModifyStudio(ctx, updatedStudio, qb); err != nil {
return err
}
var err error
- s, err = qb.Update(updatedStudio)
+ s, err = qb.Update(ctx, updatedStudio)
if err != nil {
return err
}
// update image table
if len(imageData) > 0 {
- if err := qb.UpdateImage(s.ID, imageData); err != nil {
+ if err := qb.UpdateImage(ctx, s.ID, imageData); err != nil {
return err
}
} else if imageIncluded {
// must be unsetting
- if err := qb.DestroyImage(s.ID); err != nil {
+ if err := qb.DestroyImage(ctx, s.ID); err != nil {
return err
}
}
@@ -183,17 +183,17 @@ func (r *mutationResolver) StudioUpdate(ctx context.Context, input StudioUpdateI
// Save the stash_ids
if translator.hasField("stash_ids") {
stashIDJoins := models.StashIDsFromInput(input.StashIds)
- if err := qb.UpdateStashIDs(studioID, stashIDJoins); err != nil {
+ if err := qb.UpdateStashIDs(ctx, studioID, stashIDJoins); err != nil {
return err
}
}
if translator.hasField("aliases") {
- if err := studio.EnsureAliasesUnique(studioID, input.Aliases, qb); err != nil {
+ if err := studio.EnsureAliasesUnique(ctx, studioID, input.Aliases, qb); err != nil {
return err
}
- if err := qb.UpdateAliases(studioID, input.Aliases); err != nil {
+ if err := qb.UpdateAliases(ctx, studioID, input.Aliases); err != nil {
return err
}
}
@@ -213,8 +213,8 @@ func (r *mutationResolver) StudioDestroy(ctx context.Context, input StudioDestro
return false, err
}
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- return repo.Studio().Destroy(id)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ return r.repository.Studio.Destroy(ctx, id)
}); err != nil {
return false, err
}
@@ -230,10 +230,10 @@ func (r *mutationResolver) StudiosDestroy(ctx context.Context, studioIDs []strin
return false, err
}
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.Studio()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Studio
for _, id := range ids {
- if err := qb.Destroy(id); err != nil {
+ if err := qb.Destroy(ctx, id); err != nil {
return err
}
}
diff --git a/internal/api/resolver_mutation_tag.go b/internal/api/resolver_mutation_tag.go
index cff553d72..f5befeba7 100644
--- a/internal/api/resolver_mutation_tag.go
+++ b/internal/api/resolver_mutation_tag.go
@@ -15,8 +15,8 @@ import (
)
func (r *mutationResolver) getTag(ctx context.Context, id int) (ret *models.Tag, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Tag().Find(id)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Tag.Find(ctx, id)
return err
}); err != nil {
return nil, err
@@ -68,44 +68,44 @@ func (r *mutationResolver) TagCreate(ctx context.Context, input TagCreateInput)
// Start the transaction and save the tag
var t *models.Tag
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.Tag()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Tag
// ensure name is unique
- if err := tag.EnsureTagNameUnique(0, newTag.Name, qb); err != nil {
+ if err := tag.EnsureTagNameUnique(ctx, 0, newTag.Name, qb); err != nil {
return err
}
- t, err = qb.Create(newTag)
+ t, err = qb.Create(ctx, newTag)
if err != nil {
return err
}
// update image table
if len(imageData) > 0 {
- if err := qb.UpdateImage(t.ID, imageData); err != nil {
+ if err := qb.UpdateImage(ctx, t.ID, imageData); err != nil {
return err
}
}
if len(input.Aliases) > 0 {
- if err := tag.EnsureAliasesUnique(t.ID, input.Aliases, qb); err != nil {
+ if err := tag.EnsureAliasesUnique(ctx, t.ID, input.Aliases, qb); err != nil {
return err
}
- if err := qb.UpdateAliases(t.ID, input.Aliases); err != nil {
+ if err := qb.UpdateAliases(ctx, t.ID, input.Aliases); err != nil {
return err
}
}
if len(parentIDs) > 0 {
- if err := qb.UpdateParentTags(t.ID, parentIDs); err != nil {
+ if err := qb.UpdateParentTags(ctx, t.ID, parentIDs); err != nil {
return err
}
}
if len(childIDs) > 0 {
- if err := qb.UpdateChildTags(t.ID, childIDs); err != nil {
+ if err := qb.UpdateChildTags(ctx, t.ID, childIDs); err != nil {
return err
}
}
@@ -113,7 +113,7 @@ func (r *mutationResolver) TagCreate(ctx context.Context, input TagCreateInput)
// FIXME: This should be called before any changes are made, but
// requires a rewrite of ValidateHierarchy.
if len(parentIDs) > 0 || len(childIDs) > 0 {
- if err := tag.ValidateHierarchy(t, parentIDs, childIDs, qb); err != nil {
+ if err := tag.ValidateHierarchy(ctx, t, parentIDs, childIDs, qb); err != nil {
return err
}
}
@@ -168,11 +168,11 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input TagUpdateInput)
// Start the transaction and save the tag
var t *models.Tag
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.Tag()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Tag
// ensure name is unique
- t, err = qb.Find(tagID)
+ t, err = qb.Find(ctx, tagID)
if err != nil {
return err
}
@@ -188,48 +188,48 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input TagUpdateInput)
}
if input.Name != nil && t.Name != *input.Name {
- if err := tag.EnsureTagNameUnique(tagID, *input.Name, qb); err != nil {
+ if err := tag.EnsureTagNameUnique(ctx, tagID, *input.Name, qb); err != nil {
return err
}
updatedTag.Name = input.Name
}
- t, err = qb.Update(updatedTag)
+ t, err = qb.Update(ctx, updatedTag)
if err != nil {
return err
}
// update image table
if len(imageData) > 0 {
- if err := qb.UpdateImage(tagID, imageData); err != nil {
+ if err := qb.UpdateImage(ctx, tagID, imageData); err != nil {
return err
}
} else if imageIncluded {
// must be unsetting
- if err := qb.DestroyImage(tagID); err != nil {
+ if err := qb.DestroyImage(ctx, tagID); err != nil {
return err
}
}
if translator.hasField("aliases") {
- if err := tag.EnsureAliasesUnique(tagID, input.Aliases, qb); err != nil {
+ if err := tag.EnsureAliasesUnique(ctx, tagID, input.Aliases, qb); err != nil {
return err
}
- if err := qb.UpdateAliases(tagID, input.Aliases); err != nil {
+ if err := qb.UpdateAliases(ctx, tagID, input.Aliases); err != nil {
return err
}
}
if parentIDs != nil {
- if err := qb.UpdateParentTags(tagID, parentIDs); err != nil {
+ if err := qb.UpdateParentTags(ctx, tagID, parentIDs); err != nil {
return err
}
}
if childIDs != nil {
- if err := qb.UpdateChildTags(tagID, childIDs); err != nil {
+ if err := qb.UpdateChildTags(ctx, tagID, childIDs); err != nil {
return err
}
}
@@ -237,7 +237,7 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input TagUpdateInput)
// FIXME: This should be called before any changes are made, but
// requires a rewrite of ValidateHierarchy.
if parentIDs != nil || childIDs != nil {
- if err := tag.ValidateHierarchy(t, parentIDs, childIDs, qb); err != nil {
+ if err := tag.ValidateHierarchy(ctx, t, parentIDs, childIDs, qb); err != nil {
logger.Errorf("Error saving tag: %s", err)
return err
}
@@ -258,8 +258,8 @@ func (r *mutationResolver) TagDestroy(ctx context.Context, input TagDestroyInput
return false, err
}
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- return repo.Tag().Destroy(tagID)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ return r.repository.Tag.Destroy(ctx, tagID)
}); err != nil {
return false, err
}
@@ -275,10 +275,10 @@ func (r *mutationResolver) TagsDestroy(ctx context.Context, tagIDs []string) (bo
return false, err
}
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.Tag()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Tag
for _, id := range ids {
- if err := qb.Destroy(id); err != nil {
+ if err := qb.Destroy(ctx, id); err != nil {
return err
}
}
@@ -311,11 +311,11 @@ func (r *mutationResolver) TagsMerge(ctx context.Context, input TagsMergeInput)
}
var t *models.Tag
- if err := r.withTxn(ctx, func(repo models.Repository) error {
- qb := repo.Tag()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Tag
var err error
- t, err = qb.Find(destination)
+ t, err = qb.Find(ctx, destination)
if err != nil {
return err
}
@@ -324,25 +324,25 @@ func (r *mutationResolver) TagsMerge(ctx context.Context, input TagsMergeInput)
return fmt.Errorf("Tag with ID %d not found", destination)
}
- parents, children, err := tag.MergeHierarchy(destination, source, qb)
+ parents, children, err := tag.MergeHierarchy(ctx, destination, source, qb)
if err != nil {
return err
}
- if err = qb.Merge(source, destination); err != nil {
+ if err = qb.Merge(ctx, source, destination); err != nil {
return err
}
- err = qb.UpdateParentTags(destination, parents)
+ err = qb.UpdateParentTags(ctx, destination, parents)
if err != nil {
return err
}
- err = qb.UpdateChildTags(destination, children)
+ err = qb.UpdateChildTags(ctx, destination, children)
if err != nil {
return err
}
- err = tag.ValidateHierarchy(t, parents, children, qb)
+ err = tag.ValidateHierarchy(ctx, t, parents, children, qb)
if err != nil {
logger.Errorf("Error merging tag: %s", err)
return err
diff --git a/internal/api/resolver_mutation_tag_test.go b/internal/api/resolver_mutation_tag_test.go
index c8d43f5f8..91b87794d 100644
--- a/internal/api/resolver_mutation_tag_test.go
+++ b/internal/api/resolver_mutation_tag_test.go
@@ -16,17 +16,23 @@ import (
// TODO - move this into a common area
func newResolver() *Resolver {
return &Resolver{
- txnManager: mocks.NewTransactionManager(),
+ txnManager: &mocks.TxnManager{},
+ repository: mocks.NewTxnRepository(),
hookExecutor: &mockHookExecutor{},
}
}
-const tagName = "tagName"
-const errTagName = "errTagName"
+const (
+ tagName = "tagName"
+ errTagName = "errTagName"
-const existingTagID = 1
-const existingTagName = "existingTagName"
-const newTagID = 2
+ existingTagID = 1
+ existingTagName = "existingTagName"
+
+ newTagID = 2
+)
+
+var testCtx = context.Background()
type mockHookExecutor struct{}
@@ -36,7 +42,7 @@ func (*mockHookExecutor) ExecutePostHooks(ctx context.Context, id int, hookType
func TestTagCreate(t *testing.T) {
r := newResolver()
- tagRW := r.txnManager.(*mocks.TransactionManager).Tag().(*mocks.TagReaderWriter)
+ tagRW := r.repository.Tag.(*mocks.TagReaderWriter)
pp := 1
findFilter := &models.FindFilterType{
@@ -61,25 +67,25 @@ func TestTagCreate(t *testing.T) {
}
}
- tagRW.On("Query", tagFilterForName(existingTagName), findFilter).Return([]*models.Tag{
+ tagRW.On("Query", testCtx, tagFilterForName(existingTagName), findFilter).Return([]*models.Tag{
{
ID: existingTagID,
Name: existingTagName,
},
}, 1, nil).Once()
- tagRW.On("Query", tagFilterForName(errTagName), findFilter).Return(nil, 0, nil).Once()
- tagRW.On("Query", tagFilterForAlias(errTagName), findFilter).Return(nil, 0, nil).Once()
+ tagRW.On("Query", testCtx, tagFilterForName(errTagName), findFilter).Return(nil, 0, nil).Once()
+ tagRW.On("Query", testCtx, tagFilterForAlias(errTagName), findFilter).Return(nil, 0, nil).Once()
expectedErr := errors.New("TagCreate error")
- tagRW.On("Create", mock.AnythingOfType("models.Tag")).Return(nil, expectedErr)
+ tagRW.On("Create", testCtx, mock.AnythingOfType("models.Tag")).Return(nil, expectedErr)
- _, err := r.Mutation().TagCreate(context.TODO(), TagCreateInput{
+ _, err := r.Mutation().TagCreate(testCtx, TagCreateInput{
Name: existingTagName,
})
assert.NotNil(t, err)
- _, err = r.Mutation().TagCreate(context.TODO(), TagCreateInput{
+ _, err = r.Mutation().TagCreate(testCtx, TagCreateInput{
Name: errTagName,
})
@@ -87,18 +93,18 @@ func TestTagCreate(t *testing.T) {
tagRW.AssertExpectations(t)
r = newResolver()
- tagRW = r.txnManager.(*mocks.TransactionManager).Tag().(*mocks.TagReaderWriter)
+ tagRW = r.repository.Tag.(*mocks.TagReaderWriter)
- tagRW.On("Query", tagFilterForName(tagName), findFilter).Return(nil, 0, nil).Once()
- tagRW.On("Query", tagFilterForAlias(tagName), findFilter).Return(nil, 0, nil).Once()
+ tagRW.On("Query", testCtx, tagFilterForName(tagName), findFilter).Return(nil, 0, nil).Once()
+ tagRW.On("Query", testCtx, tagFilterForAlias(tagName), findFilter).Return(nil, 0, nil).Once()
newTag := &models.Tag{
ID: newTagID,
Name: tagName,
}
- tagRW.On("Create", mock.AnythingOfType("models.Tag")).Return(newTag, nil)
- tagRW.On("Find", newTagID).Return(newTag, nil)
+ tagRW.On("Create", testCtx, mock.AnythingOfType("models.Tag")).Return(newTag, nil)
+ tagRW.On("Find", testCtx, newTagID).Return(newTag, nil)
- tag, err := r.Mutation().TagCreate(context.TODO(), TagCreateInput{
+ tag, err := r.Mutation().TagCreate(testCtx, TagCreateInput{
Name: tagName,
})
diff --git a/internal/api/resolver_query_configuration.go b/internal/api/resolver_query_configuration.go
index 39c463260..f3469de97 100644
--- a/internal/api/resolver_query_configuration.go
+++ b/internal/api/resolver_query_configuration.go
@@ -222,7 +222,7 @@ func makeConfigUIResult() map[string]interface{} {
}
func (r *queryResolver) ValidateStashBoxCredentials(ctx context.Context, input config.StashBoxInput) (*StashBoxValidationResult, error) {
- client := stashbox.NewClient(models.StashBox{Endpoint: input.Endpoint, APIKey: input.APIKey}, r.txnManager)
+ client := stashbox.NewClient(models.StashBox{Endpoint: input.Endpoint, APIKey: input.APIKey}, r.txnManager, r.stashboxRepository())
user, err := client.GetUser(ctx)
valid := user != nil && user.Me != nil
diff --git a/internal/api/resolver_query_find_gallery.go b/internal/api/resolver_query_find_gallery.go
index bd6d07c0f..ee12471d1 100644
--- a/internal/api/resolver_query_find_gallery.go
+++ b/internal/api/resolver_query_find_gallery.go
@@ -13,8 +13,8 @@ func (r *queryResolver) FindGallery(ctx context.Context, id string) (ret *models
return nil, err
}
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Gallery().Find(idInt)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Gallery.Find(ctx, idInt)
return err
}); err != nil {
return nil, err
@@ -24,8 +24,8 @@ func (r *queryResolver) FindGallery(ctx context.Context, id string) (ret *models
}
func (r *queryResolver) FindGalleries(ctx context.Context, galleryFilter *models.GalleryFilterType, filter *models.FindFilterType) (ret *FindGalleriesResultType, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- galleries, total, err := repo.Gallery().Query(galleryFilter, filter)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ galleries, total, err := r.repository.Gallery.Query(ctx, galleryFilter, filter)
if err != nil {
return err
}
diff --git a/internal/api/resolver_query_find_image.go b/internal/api/resolver_query_find_image.go
index 0f01b42a1..f1269dce8 100644
--- a/internal/api/resolver_query_find_image.go
+++ b/internal/api/resolver_query_find_image.go
@@ -12,8 +12,8 @@ import (
func (r *queryResolver) FindImage(ctx context.Context, id *string, checksum *string) (*models.Image, error) {
var image *models.Image
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- qb := repo.Image()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Image
var err error
if id != nil {
@@ -22,12 +22,12 @@ func (r *queryResolver) FindImage(ctx context.Context, id *string, checksum *str
return err
}
- image, err = qb.Find(idInt)
+ image, err = qb.Find(ctx, idInt)
if err != nil {
return err
}
} else if checksum != nil {
- image, err = qb.FindByChecksum(*checksum)
+ image, err = qb.FindByChecksum(ctx, *checksum)
}
return err
@@ -39,12 +39,12 @@ func (r *queryResolver) FindImage(ctx context.Context, id *string, checksum *str
}
func (r *queryResolver) FindImages(ctx context.Context, imageFilter *models.ImageFilterType, imageIds []int, filter *models.FindFilterType) (ret *FindImagesResultType, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- qb := repo.Image()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Image
fields := graphql.CollectAllFields(ctx)
- result, err := qb.Query(models.ImageQueryOptions{
+ result, err := qb.Query(ctx, models.ImageQueryOptions{
QueryOptions: models.QueryOptions{
FindFilter: filter,
Count: stringslice.StrInclude(fields, "count"),
@@ -57,7 +57,7 @@ func (r *queryResolver) FindImages(ctx context.Context, imageFilter *models.Imag
return err
}
- images, err := result.Resolve()
+ images, err := result.Resolve(ctx)
if err != nil {
return err
}
diff --git a/internal/api/resolver_query_find_movie.go b/internal/api/resolver_query_find_movie.go
index 16a0bbe3b..7505c7f36 100644
--- a/internal/api/resolver_query_find_movie.go
+++ b/internal/api/resolver_query_find_movie.go
@@ -13,8 +13,8 @@ func (r *queryResolver) FindMovie(ctx context.Context, id string) (ret *models.M
return nil, err
}
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Movie().Find(idInt)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Movie.Find(ctx, idInt)
return err
}); err != nil {
return nil, err
@@ -24,8 +24,8 @@ func (r *queryResolver) FindMovie(ctx context.Context, id string) (ret *models.M
}
func (r *queryResolver) FindMovies(ctx context.Context, movieFilter *models.MovieFilterType, filter *models.FindFilterType) (ret *FindMoviesResultType, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- movies, total, err := repo.Movie().Query(movieFilter, filter)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ movies, total, err := r.repository.Movie.Query(ctx, movieFilter, filter)
if err != nil {
return err
}
@@ -44,8 +44,8 @@ func (r *queryResolver) FindMovies(ctx context.Context, movieFilter *models.Movi
}
func (r *queryResolver) AllMovies(ctx context.Context) (ret []*models.Movie, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Movie().All()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Movie.All(ctx)
return err
}); err != nil {
return nil, err
diff --git a/internal/api/resolver_query_find_performer.go b/internal/api/resolver_query_find_performer.go
index 5d7bb2716..4314b0f69 100644
--- a/internal/api/resolver_query_find_performer.go
+++ b/internal/api/resolver_query_find_performer.go
@@ -13,8 +13,8 @@ func (r *queryResolver) FindPerformer(ctx context.Context, id string) (ret *mode
return nil, err
}
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Performer().Find(idInt)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Performer.Find(ctx, idInt)
return err
}); err != nil {
return nil, err
@@ -24,8 +24,8 @@ func (r *queryResolver) FindPerformer(ctx context.Context, id string) (ret *mode
}
func (r *queryResolver) FindPerformers(ctx context.Context, performerFilter *models.PerformerFilterType, filter *models.FindFilterType) (ret *FindPerformersResultType, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- performers, total, err := repo.Performer().Query(performerFilter, filter)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ performers, total, err := r.repository.Performer.Query(ctx, performerFilter, filter)
if err != nil {
return err
}
@@ -43,8 +43,8 @@ func (r *queryResolver) FindPerformers(ctx context.Context, performerFilter *mod
}
func (r *queryResolver) AllPerformers(ctx context.Context) (ret []*models.Performer, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Performer().All()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Performer.All(ctx)
return err
}); err != nil {
return nil, err
diff --git a/internal/api/resolver_query_find_saved_filter.go b/internal/api/resolver_query_find_saved_filter.go
index a28ef2f59..7b934f581 100644
--- a/internal/api/resolver_query_find_saved_filter.go
+++ b/internal/api/resolver_query_find_saved_filter.go
@@ -13,8 +13,8 @@ func (r *queryResolver) FindSavedFilter(ctx context.Context, id string) (ret *mo
return nil, err
}
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.SavedFilter().Find(idInt)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.SavedFilter.Find(ctx, idInt)
return err
}); err != nil {
return nil, err
@@ -23,11 +23,11 @@ func (r *queryResolver) FindSavedFilter(ctx context.Context, id string) (ret *mo
}
func (r *queryResolver) FindSavedFilters(ctx context.Context, mode *models.FilterMode) (ret []*models.SavedFilter, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
if mode != nil {
- ret, err = repo.SavedFilter().FindByMode(*mode)
+ ret, err = r.repository.SavedFilter.FindByMode(ctx, *mode)
} else {
- ret, err = repo.SavedFilter().All()
+ ret, err = r.repository.SavedFilter.All(ctx)
}
return err
}); err != nil {
@@ -37,8 +37,8 @@ func (r *queryResolver) FindSavedFilters(ctx context.Context, mode *models.Filte
}
func (r *queryResolver) FindDefaultFilter(ctx context.Context, mode models.FilterMode) (ret *models.SavedFilter, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.SavedFilter().FindDefault(mode)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.SavedFilter.FindDefault(ctx, mode)
return err
}); err != nil {
return nil, err
diff --git a/internal/api/resolver_query_find_scene.go b/internal/api/resolver_query_find_scene.go
index 486ca744a..823e86503 100644
--- a/internal/api/resolver_query_find_scene.go
+++ b/internal/api/resolver_query_find_scene.go
@@ -12,20 +12,20 @@ import (
func (r *queryResolver) FindScene(ctx context.Context, id *string, checksum *string) (*models.Scene, error) {
var scene *models.Scene
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- qb := repo.Scene()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Scene
var err error
if id != nil {
idInt, err := strconv.Atoi(*id)
if err != nil {
return err
}
- scene, err = qb.Find(idInt)
+ scene, err = qb.Find(ctx, idInt)
if err != nil {
return err
}
} else if checksum != nil {
- scene, err = qb.FindByChecksum(*checksum)
+ scene, err = qb.FindByChecksum(ctx, *checksum)
}
return err
@@ -39,18 +39,18 @@ func (r *queryResolver) FindScene(ctx context.Context, id *string, checksum *str
func (r *queryResolver) FindSceneByHash(ctx context.Context, input SceneHashInput) (*models.Scene, error) {
var scene *models.Scene
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- qb := repo.Scene()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ qb := r.repository.Scene
var err error
if input.Checksum != nil {
- scene, err = qb.FindByChecksum(*input.Checksum)
+ scene, err = qb.FindByChecksum(ctx, *input.Checksum)
if err != nil {
return err
}
}
if scene == nil && input.Oshash != nil {
- scene, err = qb.FindByOSHash(*input.Oshash)
+ scene, err = qb.FindByOSHash(ctx, *input.Oshash)
if err != nil {
return err
}
@@ -65,7 +65,7 @@ func (r *queryResolver) FindSceneByHash(ctx context.Context, input SceneHashInpu
}
func (r *queryResolver) FindScenes(ctx context.Context, sceneFilter *models.SceneFilterType, sceneIDs []int, filter *models.FindFilterType) (ret *FindScenesResultType, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
var scenes []*models.Scene
var err error
@@ -73,7 +73,7 @@ func (r *queryResolver) FindScenes(ctx context.Context, sceneFilter *models.Scen
result := &models.SceneQueryResult{}
if len(sceneIDs) > 0 {
- scenes, err = repo.Scene().FindMany(sceneIDs)
+ scenes, err = r.repository.Scene.FindMany(ctx, sceneIDs)
if err == nil {
result.Count = len(scenes)
for _, s := range scenes {
@@ -83,7 +83,7 @@ func (r *queryResolver) FindScenes(ctx context.Context, sceneFilter *models.Scen
}
}
} else {
- result, err = repo.Scene().Query(models.SceneQueryOptions{
+ result, err = r.repository.Scene.Query(ctx, models.SceneQueryOptions{
QueryOptions: models.QueryOptions{
FindFilter: filter,
Count: stringslice.StrInclude(fields, "count"),
@@ -93,7 +93,7 @@ func (r *queryResolver) FindScenes(ctx context.Context, sceneFilter *models.Scen
TotalSize: stringslice.StrInclude(fields, "filesize"),
})
if err == nil {
- scenes, err = result.Resolve()
+ scenes, err = result.Resolve(ctx)
}
}
@@ -117,7 +117,7 @@ func (r *queryResolver) FindScenes(ctx context.Context, sceneFilter *models.Scen
}
func (r *queryResolver) FindScenesByPathRegex(ctx context.Context, filter *models.FindFilterType) (ret *FindScenesResultType, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
sceneFilter := &models.SceneFilterType{}
@@ -138,7 +138,7 @@ func (r *queryResolver) FindScenesByPathRegex(ctx context.Context, filter *model
fields := graphql.CollectAllFields(ctx)
- result, err := repo.Scene().Query(models.SceneQueryOptions{
+ result, err := r.repository.Scene.Query(ctx, models.SceneQueryOptions{
QueryOptions: models.QueryOptions{
FindFilter: queryFilter,
Count: stringslice.StrInclude(fields, "count"),
@@ -151,7 +151,7 @@ func (r *queryResolver) FindScenesByPathRegex(ctx context.Context, filter *model
return err
}
- scenes, err := result.Resolve()
+ scenes, err := result.Resolve(ctx)
if err != nil {
return err
}
@@ -174,8 +174,14 @@ func (r *queryResolver) FindScenesByPathRegex(ctx context.Context, filter *model
func (r *queryResolver) ParseSceneFilenames(ctx context.Context, filter *models.FindFilterType, config manager.SceneParserInput) (ret *SceneParserResultType, err error) {
parser := manager.NewSceneFilenameParser(filter, config)
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- result, count, err := parser.Parse(repo)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ result, count, err := parser.Parse(ctx, manager.SceneFilenameParserRepository{
+ Scene: r.repository.Scene,
+ Performer: r.repository.Performer,
+ Studio: r.repository.Studio,
+ Movie: r.repository.Movie,
+ Tag: r.repository.Tag,
+ })
if err != nil {
return err
@@ -199,8 +205,8 @@ func (r *queryResolver) FindDuplicateScenes(ctx context.Context, distance *int)
if distance != nil {
dist = *distance
}
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Scene().FindDuplicates(dist)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Scene.FindDuplicates(ctx, dist)
return err
}); err != nil {
return nil, err
diff --git a/internal/api/resolver_query_find_scene_marker.go b/internal/api/resolver_query_find_scene_marker.go
index 7b0f1ee12..03b9e261a 100644
--- a/internal/api/resolver_query_find_scene_marker.go
+++ b/internal/api/resolver_query_find_scene_marker.go
@@ -7,8 +7,8 @@ import (
)
func (r *queryResolver) FindSceneMarkers(ctx context.Context, sceneMarkerFilter *models.SceneMarkerFilterType, filter *models.FindFilterType) (ret *FindSceneMarkersResultType, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- sceneMarkers, total, err := repo.SceneMarker().Query(sceneMarkerFilter, filter)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ sceneMarkers, total, err := r.repository.SceneMarker.Query(ctx, sceneMarkerFilter, filter)
if err != nil {
return err
}
diff --git a/internal/api/resolver_query_find_studio.go b/internal/api/resolver_query_find_studio.go
index 24591e053..0bd17b9ad 100644
--- a/internal/api/resolver_query_find_studio.go
+++ b/internal/api/resolver_query_find_studio.go
@@ -13,9 +13,9 @@ func (r *queryResolver) FindStudio(ctx context.Context, id string) (ret *models.
return nil, err
}
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
var err error
- ret, err = repo.Studio().Find(idInt)
+ ret, err = r.repository.Studio.Find(ctx, idInt)
return err
}); err != nil {
return nil, err
@@ -25,8 +25,8 @@ func (r *queryResolver) FindStudio(ctx context.Context, id string) (ret *models.
}
func (r *queryResolver) FindStudios(ctx context.Context, studioFilter *models.StudioFilterType, filter *models.FindFilterType) (ret *FindStudiosResultType, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- studios, total, err := repo.Studio().Query(studioFilter, filter)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ studios, total, err := r.repository.Studio.Query(ctx, studioFilter, filter)
if err != nil {
return err
}
@@ -45,8 +45,8 @@ func (r *queryResolver) FindStudios(ctx context.Context, studioFilter *models.St
}
func (r *queryResolver) AllStudios(ctx context.Context) (ret []*models.Studio, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Studio().All()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Studio.All(ctx)
return err
}); err != nil {
return nil, err
diff --git a/internal/api/resolver_query_find_tag.go b/internal/api/resolver_query_find_tag.go
index 21aff9f4b..77bd57f98 100644
--- a/internal/api/resolver_query_find_tag.go
+++ b/internal/api/resolver_query_find_tag.go
@@ -13,8 +13,8 @@ func (r *queryResolver) FindTag(ctx context.Context, id string) (ret *models.Tag
return nil, err
}
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Tag().Find(idInt)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Tag.Find(ctx, idInt)
return err
}); err != nil {
return nil, err
@@ -24,8 +24,8 @@ func (r *queryResolver) FindTag(ctx context.Context, id string) (ret *models.Tag
}
func (r *queryResolver) FindTags(ctx context.Context, tagFilter *models.TagFilterType, filter *models.FindFilterType) (ret *FindTagsResultType, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- tags, total, err := repo.Tag().Query(tagFilter, filter)
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ tags, total, err := r.repository.Tag.Query(ctx, tagFilter, filter)
if err != nil {
return err
}
@@ -44,8 +44,8 @@ func (r *queryResolver) FindTags(ctx context.Context, tagFilter *models.TagFilte
}
func (r *queryResolver) AllTags(ctx context.Context) (ret []*models.Tag, err error) {
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
- ret, err = repo.Tag().All()
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
+ ret, err = r.repository.Tag.All(ctx)
return err
}); err != nil {
return nil, err
diff --git a/internal/api/resolver_query_scene.go b/internal/api/resolver_query_scene.go
index 7ff77d2a8..c1ba0edca 100644
--- a/internal/api/resolver_query_scene.go
+++ b/internal/api/resolver_query_scene.go
@@ -14,10 +14,10 @@ import (
func (r *queryResolver) SceneStreams(ctx context.Context, id *string) ([]*manager.SceneStreamEndpoint, error) {
// find the scene
var scene *models.Scene
- if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+ if err := r.withTxn(ctx, func(ctx context.Context) error {
idInt, _ := strconv.Atoi(*id)
var err error
- scene, err = repo.Scene().Find(idInt)
+ scene, err = r.repository.Scene.Find(ctx, idInt)
return err
}); err != nil {
return nil, err
diff --git a/internal/api/resolver_query_scraper.go b/internal/api/resolver_query_scraper.go
index 92c5200ce..85f47ee2c 100644
--- a/internal/api/resolver_query_scraper.go
+++ b/internal/api/resolver_query_scraper.go
@@ -234,7 +234,7 @@ func (r *queryResolver) getStashBoxClient(index int) (*stashbox.Client, error) {
return nil, fmt.Errorf("%w: invalid stash_box_index %d", ErrInput, index)
}
- return stashbox.NewClient(*boxes[index], r.txnManager), nil
+ return stashbox.NewClient(*boxes[index], r.txnManager, r.stashboxRepository()), nil
}
func (r *queryResolver) ScrapeSingleScene(ctx context.Context, source scraper.Source, input ScrapeSingleSceneInput) ([]*scraper.ScrapedScene, error) {
diff --git a/internal/api/routes_image.go b/internal/api/routes_image.go
index 8ba2e50d5..d66ccf7cc 100644
--- a/internal/api/routes_image.go
+++ b/internal/api/routes_image.go
@@ -13,17 +13,24 @@ import (
"github.com/stashapp/stash/pkg/image"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
+ "github.com/stashapp/stash/pkg/txn"
)
+type ImageFinder interface {
+ Find(ctx context.Context, id int) (*models.Image, error)
+ FindByChecksum(ctx context.Context, checksum string) (*models.Image, error)
+}
+
type imageRoutes struct {
- txnManager models.TransactionManager
+ txnManager txn.Manager
+ imageFinder ImageFinder
}
func (rs imageRoutes) Routes() chi.Router {
r := chi.NewRouter()
r.Route("/{imageId}", func(r chi.Router) {
- r.Use(ImageCtx)
+ r.Use(rs.ImageCtx)
r.Get("/image", rs.Image)
r.Get("/thumbnail", rs.Thumbnail)
@@ -85,18 +92,18 @@ func (rs imageRoutes) Image(w http.ResponseWriter, r *http.Request) {
// endregion
-func ImageCtx(next http.Handler) http.Handler {
+func (rs imageRoutes) ImageCtx(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
imageIdentifierQueryParam := chi.URLParam(r, "imageId")
imageID, _ := strconv.Atoi(imageIdentifierQueryParam)
var image *models.Image
- readTxnErr := manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
- qb := repo.Image()
+ readTxnErr := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
+ qb := rs.imageFinder
if imageID == 0 {
- image, _ = qb.FindByChecksum(imageIdentifierQueryParam)
+ image, _ = qb.FindByChecksum(ctx, imageIdentifierQueryParam)
} else {
- image, _ = qb.Find(imageID)
+ image, _ = qb.Find(ctx, imageID)
}
return nil
diff --git a/internal/api/routes_movie.go b/internal/api/routes_movie.go
index 439b1e4d3..8fbccdb53 100644
--- a/internal/api/routes_movie.go
+++ b/internal/api/routes_movie.go
@@ -6,21 +6,28 @@ import (
"strconv"
"github.com/go-chi/chi"
- "github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
+ "github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils"
)
+type MovieFinder interface {
+ GetFrontImage(ctx context.Context, movieID int) ([]byte, error)
+ GetBackImage(ctx context.Context, movieID int) ([]byte, error)
+ Find(ctx context.Context, id int) (*models.Movie, error)
+}
+
type movieRoutes struct {
- txnManager models.TransactionManager
+ txnManager txn.Manager
+ movieFinder MovieFinder
}
func (rs movieRoutes) Routes() chi.Router {
r := chi.NewRouter()
r.Route("/{movieId}", func(r chi.Router) {
- r.Use(MovieCtx)
+ r.Use(rs.MovieCtx)
r.Get("/frontimage", rs.FrontImage)
r.Get("/backimage", rs.BackImage)
})
@@ -33,8 +40,8 @@ func (rs movieRoutes) FrontImage(w http.ResponseWriter, r *http.Request) {
defaultParam := r.URL.Query().Get("default")
var image []byte
if defaultParam != "true" {
- err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
- image, _ = repo.Movie().GetFrontImage(movie.ID)
+ err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
+ image, _ = rs.movieFinder.GetFrontImage(ctx, movie.ID)
return nil
})
if err != nil {
@@ -56,8 +63,8 @@ func (rs movieRoutes) BackImage(w http.ResponseWriter, r *http.Request) {
defaultParam := r.URL.Query().Get("default")
var image []byte
if defaultParam != "true" {
- err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
- image, _ = repo.Movie().GetBackImage(movie.ID)
+ err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
+ image, _ = rs.movieFinder.GetBackImage(ctx, movie.ID)
return nil
})
if err != nil {
@@ -74,7 +81,7 @@ func (rs movieRoutes) BackImage(w http.ResponseWriter, r *http.Request) {
}
}
-func MovieCtx(next http.Handler) http.Handler {
+func (rs movieRoutes) MovieCtx(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
movieID, err := strconv.Atoi(chi.URLParam(r, "movieId"))
if err != nil {
@@ -83,9 +90,9 @@ func MovieCtx(next http.Handler) http.Handler {
}
var movie *models.Movie
- if err := manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
+ if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
var err error
- movie, err = repo.Movie().Find(movieID)
+ movie, err = rs.movieFinder.Find(ctx, movieID)
return err
}); err != nil {
http.Error(w, http.StatusText(404), 404)
diff --git a/internal/api/routes_performer.go b/internal/api/routes_performer.go
index e5c0bb862..15ad3c743 100644
--- a/internal/api/routes_performer.go
+++ b/internal/api/routes_performer.go
@@ -6,22 +6,28 @@ import (
"strconv"
"github.com/go-chi/chi"
- "github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/internal/manager/config"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
+ "github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils"
)
+type PerformerFinder interface {
+ Find(ctx context.Context, id int) (*models.Performer, error)
+ GetImage(ctx context.Context, performerID int) ([]byte, error)
+}
+
type performerRoutes struct {
- txnManager models.TransactionManager
+ txnManager txn.Manager
+ performerFinder PerformerFinder
}
func (rs performerRoutes) Routes() chi.Router {
r := chi.NewRouter()
r.Route("/{performerId}", func(r chi.Router) {
- r.Use(PerformerCtx)
+ r.Use(rs.PerformerCtx)
r.Get("/image", rs.Image)
})
@@ -34,8 +40,8 @@ func (rs performerRoutes) Image(w http.ResponseWriter, r *http.Request) {
var image []byte
if defaultParam != "true" {
- readTxnErr := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
- image, _ = repo.Performer().GetImage(performer.ID)
+ readTxnErr := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
+ image, _ = rs.performerFinder.GetImage(ctx, performer.ID)
return nil
})
if readTxnErr != nil {
@@ -52,7 +58,7 @@ func (rs performerRoutes) Image(w http.ResponseWriter, r *http.Request) {
}
}
-func PerformerCtx(next http.Handler) http.Handler {
+func (rs performerRoutes) PerformerCtx(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
performerID, err := strconv.Atoi(chi.URLParam(r, "performerId"))
if err != nil {
@@ -61,9 +67,9 @@ func PerformerCtx(next http.Handler) http.Handler {
}
var performer *models.Performer
- if err := manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
+ if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
var err error
- performer, err = repo.Performer().Find(performerID)
+ performer, err = rs.performerFinder.Find(ctx, performerID)
return err
}); err != nil {
http.Error(w, http.StatusText(404), 404)
diff --git a/internal/api/routes_scene.go b/internal/api/routes_scene.go
index 3612da72d..069e90876 100644
--- a/internal/api/routes_scene.go
+++ b/internal/api/routes_scene.go
@@ -15,18 +15,36 @@ import (
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene"
+ "github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils"
)
+type SceneFinder interface {
+ manager.SceneCoverGetter
+
+ scene.IDFinder
+ FindByChecksum(ctx context.Context, checksum string) (*models.Scene, error)
+ FindByOSHash(ctx context.Context, oshash string) (*models.Scene, error)
+ GetCaptions(ctx context.Context, sceneID int) ([]*models.SceneCaption, error)
+}
+
+type SceneMarkerFinder interface {
+ Find(ctx context.Context, id int) (*models.SceneMarker, error)
+ FindBySceneID(ctx context.Context, sceneID int) ([]*models.SceneMarker, error)
+}
+
type sceneRoutes struct {
- txnManager models.TransactionManager
+ txnManager txn.Manager
+ sceneFinder SceneFinder
+ sceneMarkerFinder SceneMarkerFinder
+ tagFinder scene.MarkerTagFinder
}
func (rs sceneRoutes) Routes() chi.Router {
r := chi.NewRouter()
r.Route("/{sceneId}", func(r chi.Router) {
- r.Use(SceneCtx)
+ r.Use(rs.SceneCtx)
// streaming endpoints
r.Get("/stream", rs.StreamDirect)
@@ -48,8 +66,8 @@ func (rs sceneRoutes) Routes() chi.Router {
r.Get("/scene_marker/{sceneMarkerId}/preview", rs.SceneMarkerPreview)
r.Get("/scene_marker/{sceneMarkerId}/screenshot", rs.SceneMarkerScreenshot)
})
- r.With(SceneCtx).Get("/{sceneId}_thumbs.vtt", rs.VttThumbs)
- r.With(SceneCtx).Get("/{sceneId}_sprite.jpg", rs.VttSprite)
+ r.With(rs.SceneCtx).Get("/{sceneId}_thumbs.vtt", rs.VttThumbs)
+ r.With(rs.SceneCtx).Get("/{sceneId}_sprite.jpg", rs.VttSprite)
return r
}
@@ -60,7 +78,8 @@ func (rs sceneRoutes) StreamDirect(w http.ResponseWriter, r *http.Request) {
scene := r.Context().Value(sceneKey).(*models.Scene)
ss := manager.SceneServer{
- TXNManager: rs.txnManager,
+ TxnManager: rs.txnManager,
+ SceneCoverGetter: rs.sceneFinder,
}
ss.StreamSceneDirect(scene, w, r)
}
@@ -190,7 +209,8 @@ func (rs sceneRoutes) Screenshot(w http.ResponseWriter, r *http.Request) {
scene := r.Context().Value(sceneKey).(*models.Scene)
ss := manager.SceneServer{
- TXNManager: rs.txnManager,
+ TxnManager: rs.txnManager,
+ SceneCoverGetter: rs.sceneFinder,
}
ss.ServeScreenshot(scene, w, r)
}
@@ -221,16 +241,16 @@ func (rs sceneRoutes) getChapterVttTitle(ctx context.Context, marker *models.Sce
}
var ret string
- if err := rs.txnManager.WithReadTxn(ctx, func(repo models.ReaderRepository) error {
- qb := repo.Tag()
- primaryTag, err := qb.Find(marker.PrimaryTagID)
+ if err := txn.WithTxn(ctx, rs.txnManager, func(ctx context.Context) error {
+ qb := rs.tagFinder
+ primaryTag, err := qb.Find(ctx, marker.PrimaryTagID)
if err != nil {
return err
}
ret = primaryTag.Name
- tags, err := qb.FindBySceneMarkerID(marker.ID)
+ tags, err := qb.FindBySceneMarkerID(ctx, marker.ID)
if err != nil {
return err
}
@@ -250,9 +270,9 @@ func (rs sceneRoutes) getChapterVttTitle(ctx context.Context, marker *models.Sce
func (rs sceneRoutes) ChapterVtt(w http.ResponseWriter, r *http.Request) {
scene := r.Context().Value(sceneKey).(*models.Scene)
var sceneMarkers []*models.SceneMarker
- if err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
+ if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
var err error
- sceneMarkers, err = repo.SceneMarker().FindBySceneID(scene.ID)
+ sceneMarkers, err = rs.sceneMarkerFinder.FindBySceneID(ctx, scene.ID)
return err
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
@@ -289,9 +309,9 @@ func (rs sceneRoutes) InteractiveHeatmap(w http.ResponseWriter, r *http.Request)
func (rs sceneRoutes) Caption(w http.ResponseWriter, r *http.Request, lang string, ext string) {
s := r.Context().Value(sceneKey).(*models.Scene)
- if err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
+ if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
var err error
- captions, err := repo.Scene().GetCaptions(s.ID)
+ captions, err := rs.sceneFinder.GetCaptions(ctx, s.ID)
for _, caption := range captions {
if lang == caption.LanguageCode && ext == caption.CaptionType {
sub, err := scene.ReadSubs(caption.Path(s.Path))
@@ -344,9 +364,9 @@ func (rs sceneRoutes) SceneMarkerStream(w http.ResponseWriter, r *http.Request)
scene := r.Context().Value(sceneKey).(*models.Scene)
sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId"))
var sceneMarker *models.SceneMarker
- if err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
+ if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
var err error
- sceneMarker, err = repo.SceneMarker().Find(sceneMarkerID)
+ sceneMarker, err = rs.sceneMarkerFinder.Find(ctx, sceneMarkerID)
return err
}); err != nil {
logger.Warnf("Error when getting scene marker for stream: %s", err.Error())
@@ -367,9 +387,9 @@ func (rs sceneRoutes) SceneMarkerPreview(w http.ResponseWriter, r *http.Request)
scene := r.Context().Value(sceneKey).(*models.Scene)
sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId"))
var sceneMarker *models.SceneMarker
- if err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
+ if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
var err error
- sceneMarker, err = repo.SceneMarker().Find(sceneMarkerID)
+ sceneMarker, err = rs.sceneMarkerFinder.Find(ctx, sceneMarkerID)
return err
}); err != nil {
logger.Warnf("Error when getting scene marker for stream: %s", err.Error())
@@ -400,9 +420,9 @@ func (rs sceneRoutes) SceneMarkerScreenshot(w http.ResponseWriter, r *http.Reque
scene := r.Context().Value(sceneKey).(*models.Scene)
sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId"))
var sceneMarker *models.SceneMarker
- if err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
+ if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
var err error
- sceneMarker, err = repo.SceneMarker().Find(sceneMarkerID)
+ sceneMarker, err = rs.sceneMarkerFinder.Find(ctx, sceneMarkerID)
return err
}); err != nil {
logger.Warnf("Error when getting scene marker for stream: %s", err.Error())
@@ -431,23 +451,23 @@ func (rs sceneRoutes) SceneMarkerScreenshot(w http.ResponseWriter, r *http.Reque
// endregion
-func SceneCtx(next http.Handler) http.Handler {
+func (rs sceneRoutes) SceneCtx(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sceneIdentifierQueryParam := chi.URLParam(r, "sceneId")
sceneID, _ := strconv.Atoi(sceneIdentifierQueryParam)
var scene *models.Scene
- readTxnErr := manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
- qb := repo.Scene()
+ readTxnErr := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
+ qb := rs.sceneFinder
if sceneID == 0 {
// determine checksum/os by the length of the query param
if len(sceneIdentifierQueryParam) == 32 {
- scene, _ = qb.FindByChecksum(sceneIdentifierQueryParam)
+ scene, _ = qb.FindByChecksum(ctx, sceneIdentifierQueryParam)
} else {
- scene, _ = qb.FindByOSHash(sceneIdentifierQueryParam)
+ scene, _ = qb.FindByOSHash(ctx, sceneIdentifierQueryParam)
}
} else {
- scene, _ = qb.Find(sceneID)
+ scene, _ = qb.Find(ctx, sceneID)
}
return nil
diff --git a/internal/api/routes_studio.go b/internal/api/routes_studio.go
index 18f78b30c..e26499f04 100644
--- a/internal/api/routes_studio.go
+++ b/internal/api/routes_studio.go
@@ -8,21 +8,28 @@ import (
"syscall"
"github.com/go-chi/chi"
- "github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
+ "github.com/stashapp/stash/pkg/studio"
+ "github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils"
)
+type StudioFinder interface {
+ studio.Finder
+ GetImage(ctx context.Context, studioID int) ([]byte, error)
+}
+
type studioRoutes struct {
- txnManager models.TransactionManager
+ txnManager txn.Manager
+ studioFinder StudioFinder
}
func (rs studioRoutes) Routes() chi.Router {
r := chi.NewRouter()
r.Route("/{studioId}", func(r chi.Router) {
- r.Use(StudioCtx)
+ r.Use(rs.StudioCtx)
r.Get("/image", rs.Image)
})
@@ -35,8 +42,8 @@ func (rs studioRoutes) Image(w http.ResponseWriter, r *http.Request) {
var image []byte
if defaultParam != "true" {
- err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
- image, _ = repo.Studio().GetImage(studio.ID)
+ err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
+ image, _ = rs.studioFinder.GetImage(ctx, studio.ID)
return nil
})
if err != nil {
@@ -58,7 +65,7 @@ func (rs studioRoutes) Image(w http.ResponseWriter, r *http.Request) {
}
}
-func StudioCtx(next http.Handler) http.Handler {
+func (rs studioRoutes) StudioCtx(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
studioID, err := strconv.Atoi(chi.URLParam(r, "studioId"))
if err != nil {
@@ -67,9 +74,9 @@ func StudioCtx(next http.Handler) http.Handler {
}
var studio *models.Studio
- if err := manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
+ if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
var err error
- studio, err = repo.Studio().Find(studioID)
+ studio, err = rs.studioFinder.Find(ctx, studioID)
return err
}); err != nil {
http.Error(w, http.StatusText(404), 404)
diff --git a/internal/api/routes_tag.go b/internal/api/routes_tag.go
index 8ffdc62c9..69c573cb2 100644
--- a/internal/api/routes_tag.go
+++ b/internal/api/routes_tag.go
@@ -6,21 +6,28 @@ import (
"strconv"
"github.com/go-chi/chi"
- "github.com/stashapp/stash/internal/manager"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
+ "github.com/stashapp/stash/pkg/tag"
+ "github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils"
)
+type TagFinder interface {
+ tag.Finder
+ GetImage(ctx context.Context, tagID int) ([]byte, error)
+}
+
type tagRoutes struct {
- txnManager models.TransactionManager
+ txnManager txn.Manager
+ tagFinder TagFinder
}
func (rs tagRoutes) Routes() chi.Router {
r := chi.NewRouter()
r.Route("/{tagId}", func(r chi.Router) {
- r.Use(TagCtx)
+ r.Use(rs.TagCtx)
r.Get("/image", rs.Image)
})
@@ -33,8 +40,8 @@ func (rs tagRoutes) Image(w http.ResponseWriter, r *http.Request) {
var image []byte
if defaultParam != "true" {
- err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
- image, _ = repo.Tag().GetImage(tag.ID)
+ err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
+ image, _ = rs.tagFinder.GetImage(ctx, tag.ID)
return nil
})
if err != nil {
@@ -51,7 +58,7 @@ func (rs tagRoutes) Image(w http.ResponseWriter, r *http.Request) {
}
}
-func TagCtx(next http.Handler) http.Handler {
+func (rs tagRoutes) TagCtx(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tagID, err := strconv.Atoi(chi.URLParam(r, "tagId"))
if err != nil {
@@ -60,9 +67,9 @@ func TagCtx(next http.Handler) http.Handler {
}
var tag *models.Tag
- if err := manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
+ if err := txn.WithTxn(r.Context(), rs.txnManager, func(ctx context.Context) error {
var err error
- tag, err = repo.Tag().Find(tagID)
+ tag, err = rs.tagFinder.Find(ctx, tagID)
return err
}); err != nil {
http.Error(w, http.StatusText(404), 404)
diff --git a/internal/api/server.go b/internal/api/server.go
index 95ce7f775..48bf87764 100644
--- a/internal/api/server.go
+++ b/internal/api/server.go
@@ -73,10 +73,11 @@ func Start() error {
return errors.New(message)
}
- txnManager := manager.GetInstance().TxnManager
+ txnManager := manager.GetInstance().Repository
pluginCache := manager.GetInstance().PluginCache
resolver := &Resolver{
txnManager: txnManager,
+ repository: txnManager,
hookExecutor: pluginCache,
}
@@ -118,22 +119,30 @@ func Start() error {
r.Get(loginEndPoint, getLoginHandler(loginUIBox))
r.Mount("/performer", performerRoutes{
- txnManager: txnManager,
+ txnManager: txnManager,
+ performerFinder: txnManager.Performer,
}.Routes())
r.Mount("/scene", sceneRoutes{
- txnManager: txnManager,
+ txnManager: txnManager,
+ sceneFinder: txnManager.Scene,
+ sceneMarkerFinder: txnManager.SceneMarker,
+ tagFinder: txnManager.Tag,
}.Routes())
r.Mount("/image", imageRoutes{
- txnManager: txnManager,
+ txnManager: txnManager,
+ imageFinder: txnManager.Image,
}.Routes())
r.Mount("/studio", studioRoutes{
- txnManager: txnManager,
+ txnManager: txnManager,
+ studioFinder: txnManager.Studio,
}.Routes())
r.Mount("/movie", movieRoutes{
- txnManager: txnManager,
+ txnManager: txnManager,
+ movieFinder: txnManager.Movie,
}.Routes())
r.Mount("/tag", tagRoutes{
txnManager: txnManager,
+ tagFinder: txnManager.Tag,
}.Routes())
r.Mount("/downloads", downloadsRoutes{}.Routes())
diff --git a/internal/autotag/gallery.go b/internal/autotag/gallery.go
index 3bdfd3c15..7f90c7e76 100644
--- a/internal/autotag/gallery.go
+++ b/internal/autotag/gallery.go
@@ -1,6 +1,8 @@
package autotag
import (
+ "context"
+
"github.com/stashapp/stash/pkg/gallery"
"github.com/stashapp/stash/pkg/match"
"github.com/stashapp/stash/pkg/models"
@@ -21,18 +23,18 @@ func getGalleryFileTagger(s *models.Gallery, cache *match.Cache) tagger {
}
// GalleryPerformers tags the provided gallery with performers whose name matches the gallery's path.
-func GalleryPerformers(s *models.Gallery, rw models.GalleryReaderWriter, performerReader models.PerformerReader, cache *match.Cache) error {
+func GalleryPerformers(ctx context.Context, s *models.Gallery, rw gallery.PerformerUpdater, performerReader match.PerformerAutoTagQueryer, cache *match.Cache) error {
t := getGalleryFileTagger(s, cache)
- return t.tagPerformers(performerReader, func(subjectID, otherID int) (bool, error) {
- return gallery.AddPerformer(rw, subjectID, otherID)
+ return t.tagPerformers(ctx, performerReader, func(subjectID, otherID int) (bool, error) {
+ return gallery.AddPerformer(ctx, rw, subjectID, otherID)
})
}
// GalleryStudios tags the provided gallery with the first studio whose name matches the gallery's path.
//
// Gallerys will not be tagged if studio is already set.
-func GalleryStudios(s *models.Gallery, rw models.GalleryReaderWriter, studioReader models.StudioReader, cache *match.Cache) error {
+func GalleryStudios(ctx context.Context, s *models.Gallery, rw GalleryFinderUpdater, studioReader match.StudioAutoTagQueryer, cache *match.Cache) error {
if s.StudioID.Valid {
// don't modify
return nil
@@ -40,16 +42,16 @@ func GalleryStudios(s *models.Gallery, rw models.GalleryReaderWriter, studioRead
t := getGalleryFileTagger(s, cache)
- return t.tagStudios(studioReader, func(subjectID, otherID int) (bool, error) {
- return addGalleryStudio(rw, subjectID, otherID)
+ return t.tagStudios(ctx, studioReader, func(subjectID, otherID int) (bool, error) {
+ return addGalleryStudio(ctx, rw, subjectID, otherID)
})
}
// GalleryTags tags the provided gallery with tags whose name matches the gallery's path.
-func GalleryTags(s *models.Gallery, rw models.GalleryReaderWriter, tagReader models.TagReader, cache *match.Cache) error {
+func GalleryTags(ctx context.Context, s *models.Gallery, rw gallery.TagUpdater, tagReader match.TagAutoTagQueryer, cache *match.Cache) error {
t := getGalleryFileTagger(s, cache)
- return t.tagTags(tagReader, func(subjectID, otherID int) (bool, error) {
- return gallery.AddTag(rw, subjectID, otherID)
+ return t.tagTags(ctx, tagReader, func(subjectID, otherID int) (bool, error) {
+ return gallery.AddTag(ctx, rw, subjectID, otherID)
})
}
diff --git a/internal/autotag/gallery_test.go b/internal/autotag/gallery_test.go
index 6d744400a..a50dc8ac4 100644
--- a/internal/autotag/gallery_test.go
+++ b/internal/autotag/gallery_test.go
@@ -1,6 +1,7 @@
package autotag
import (
+ "context"
"testing"
"github.com/stashapp/stash/pkg/models"
@@ -11,6 +12,8 @@ import (
const galleryExt = "zip"
+var testCtx = context.Background()
+
func TestGalleryPerformers(t *testing.T) {
t.Parallel()
@@ -37,19 +40,19 @@ func TestGalleryPerformers(t *testing.T) {
mockPerformerReader := &mocks.PerformerReaderWriter{}
mockGalleryReader := &mocks.GalleryReaderWriter{}
- mockPerformerReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil)
- mockPerformerReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once()
+ mockPerformerReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
+ mockPerformerReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once()
if test.Matches {
- mockGalleryReader.On("GetPerformerIDs", galleryID).Return(nil, nil).Once()
- mockGalleryReader.On("UpdatePerformers", galleryID, []int{performerID}).Return(nil).Once()
+ mockGalleryReader.On("GetPerformerIDs", testCtx, galleryID).Return(nil, nil).Once()
+ mockGalleryReader.On("UpdatePerformers", testCtx, galleryID, []int{performerID}).Return(nil).Once()
}
gallery := models.Gallery{
ID: galleryID,
Path: models.NullString(test.Path),
}
- err := GalleryPerformers(&gallery, mockGalleryReader, mockPerformerReader, nil)
+ err := GalleryPerformers(testCtx, &gallery, mockGalleryReader, mockPerformerReader, nil)
assert.Nil(err)
mockPerformerReader.AssertExpectations(t)
@@ -81,9 +84,9 @@ func TestGalleryStudios(t *testing.T) {
doTest := func(mockStudioReader *mocks.StudioReaderWriter, mockGalleryReader *mocks.GalleryReaderWriter, test pathTestTable) {
if test.Matches {
- mockGalleryReader.On("Find", galleryID).Return(&models.Gallery{}, nil).Once()
+ mockGalleryReader.On("Find", testCtx, galleryID).Return(&models.Gallery{}, nil).Once()
expectedStudioID := models.NullInt64(studioID)
- mockGalleryReader.On("UpdatePartial", models.GalleryPartial{
+ mockGalleryReader.On("UpdatePartial", testCtx, models.GalleryPartial{
ID: galleryID,
StudioID: &expectedStudioID,
}).Return(nil, nil).Once()
@@ -93,7 +96,7 @@ func TestGalleryStudios(t *testing.T) {
ID: galleryID,
Path: models.NullString(test.Path),
}
- err := GalleryStudios(&gallery, mockGalleryReader, mockStudioReader, nil)
+ err := GalleryStudios(testCtx, &gallery, mockGalleryReader, mockStudioReader, nil)
assert.Nil(err)
mockStudioReader.AssertExpectations(t)
@@ -104,9 +107,9 @@ func TestGalleryStudios(t *testing.T) {
mockStudioReader := &mocks.StudioReaderWriter{}
mockGalleryReader := &mocks.GalleryReaderWriter{}
- mockStudioReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil)
- mockStudioReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
- mockStudioReader.On("GetAliases", mock.Anything).Return([]string{}, nil).Maybe()
+ mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
+ mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
+ mockStudioReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
doTest(mockStudioReader, mockGalleryReader, test)
}
@@ -119,12 +122,12 @@ func TestGalleryStudios(t *testing.T) {
mockStudioReader := &mocks.StudioReaderWriter{}
mockGalleryReader := &mocks.GalleryReaderWriter{}
- mockStudioReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil)
- mockStudioReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
- mockStudioReader.On("GetAliases", studioID).Return([]string{
+ mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
+ mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
+ mockStudioReader.On("GetAliases", testCtx, studioID).Return([]string{
studioName,
}, nil).Once()
- mockStudioReader.On("GetAliases", reversedStudioID).Return([]string{}, nil).Once()
+ mockStudioReader.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once()
doTest(mockStudioReader, mockGalleryReader, test)
}
@@ -154,15 +157,15 @@ func TestGalleryTags(t *testing.T) {
doTest := func(mockTagReader *mocks.TagReaderWriter, mockGalleryReader *mocks.GalleryReaderWriter, test pathTestTable) {
if test.Matches {
- mockGalleryReader.On("GetTagIDs", galleryID).Return(nil, nil).Once()
- mockGalleryReader.On("UpdateTags", galleryID, []int{tagID}).Return(nil).Once()
+ mockGalleryReader.On("GetTagIDs", testCtx, galleryID).Return(nil, nil).Once()
+ mockGalleryReader.On("UpdateTags", testCtx, galleryID, []int{tagID}).Return(nil).Once()
}
gallery := models.Gallery{
ID: galleryID,
Path: models.NullString(test.Path),
}
- err := GalleryTags(&gallery, mockGalleryReader, mockTagReader, nil)
+ err := GalleryTags(testCtx, &gallery, mockGalleryReader, mockTagReader, nil)
assert.Nil(err)
mockTagReader.AssertExpectations(t)
@@ -173,9 +176,9 @@ func TestGalleryTags(t *testing.T) {
mockTagReader := &mocks.TagReaderWriter{}
mockGalleryReader := &mocks.GalleryReaderWriter{}
- mockTagReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil)
- mockTagReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
- mockTagReader.On("GetAliases", mock.Anything).Return([]string{}, nil).Maybe()
+ mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
+ mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
+ mockTagReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
doTest(mockTagReader, mockGalleryReader, test)
}
@@ -187,12 +190,12 @@ func TestGalleryTags(t *testing.T) {
mockTagReader := &mocks.TagReaderWriter{}
mockGalleryReader := &mocks.GalleryReaderWriter{}
- mockTagReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil)
- mockTagReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
- mockTagReader.On("GetAliases", tagID).Return([]string{
+ mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
+ mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
+ mockTagReader.On("GetAliases", testCtx, tagID).Return([]string{
tagName,
}, nil).Once()
- mockTagReader.On("GetAliases", reversedTagID).Return([]string{}, nil).Once()
+ mockTagReader.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once()
doTest(mockTagReader, mockGalleryReader, test)
}
diff --git a/internal/autotag/image.go b/internal/autotag/image.go
index 516f30181..17d0d1816 100644
--- a/internal/autotag/image.go
+++ b/internal/autotag/image.go
@@ -1,6 +1,8 @@
package autotag
import (
+ "context"
+
"github.com/stashapp/stash/pkg/image"
"github.com/stashapp/stash/pkg/match"
"github.com/stashapp/stash/pkg/models"
@@ -17,18 +19,18 @@ func getImageFileTagger(s *models.Image, cache *match.Cache) tagger {
}
// ImagePerformers tags the provided image with performers whose name matches the image's path.
-func ImagePerformers(s *models.Image, rw models.ImageReaderWriter, performerReader models.PerformerReader, cache *match.Cache) error {
+func ImagePerformers(ctx context.Context, s *models.Image, rw image.PerformerUpdater, performerReader match.PerformerAutoTagQueryer, cache *match.Cache) error {
t := getImageFileTagger(s, cache)
- return t.tagPerformers(performerReader, func(subjectID, otherID int) (bool, error) {
- return image.AddPerformer(rw, subjectID, otherID)
+ return t.tagPerformers(ctx, performerReader, func(subjectID, otherID int) (bool, error) {
+ return image.AddPerformer(ctx, rw, subjectID, otherID)
})
}
// ImageStudios tags the provided image with the first studio whose name matches the image's path.
//
// Images will not be tagged if studio is already set.
-func ImageStudios(s *models.Image, rw models.ImageReaderWriter, studioReader models.StudioReader, cache *match.Cache) error {
+func ImageStudios(ctx context.Context, s *models.Image, rw ImageFinderUpdater, studioReader match.StudioAutoTagQueryer, cache *match.Cache) error {
if s.StudioID.Valid {
// don't modify
return nil
@@ -36,16 +38,16 @@ func ImageStudios(s *models.Image, rw models.ImageReaderWriter, studioReader mod
t := getImageFileTagger(s, cache)
- return t.tagStudios(studioReader, func(subjectID, otherID int) (bool, error) {
- return addImageStudio(rw, subjectID, otherID)
+ return t.tagStudios(ctx, studioReader, func(subjectID, otherID int) (bool, error) {
+ return addImageStudio(ctx, rw, subjectID, otherID)
})
}
// ImageTags tags the provided image with tags whose name matches the image's path.
-func ImageTags(s *models.Image, rw models.ImageReaderWriter, tagReader models.TagReader, cache *match.Cache) error {
+func ImageTags(ctx context.Context, s *models.Image, rw image.TagUpdater, tagReader match.TagAutoTagQueryer, cache *match.Cache) error {
t := getImageFileTagger(s, cache)
- return t.tagTags(tagReader, func(subjectID, otherID int) (bool, error) {
- return image.AddTag(rw, subjectID, otherID)
+ return t.tagTags(ctx, tagReader, func(subjectID, otherID int) (bool, error) {
+ return image.AddTag(ctx, rw, subjectID, otherID)
})
}
diff --git a/internal/autotag/image_test.go b/internal/autotag/image_test.go
index 130ce51af..67eedb689 100644
--- a/internal/autotag/image_test.go
+++ b/internal/autotag/image_test.go
@@ -37,19 +37,19 @@ func TestImagePerformers(t *testing.T) {
mockPerformerReader := &mocks.PerformerReaderWriter{}
mockImageReader := &mocks.ImageReaderWriter{}
- mockPerformerReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil)
- mockPerformerReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once()
+ mockPerformerReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
+ mockPerformerReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once()
if test.Matches {
- mockImageReader.On("GetPerformerIDs", imageID).Return(nil, nil).Once()
- mockImageReader.On("UpdatePerformers", imageID, []int{performerID}).Return(nil).Once()
+ mockImageReader.On("GetPerformerIDs", testCtx, imageID).Return(nil, nil).Once()
+ mockImageReader.On("UpdatePerformers", testCtx, imageID, []int{performerID}).Return(nil).Once()
}
image := models.Image{
ID: imageID,
Path: test.Path,
}
- err := ImagePerformers(&image, mockImageReader, mockPerformerReader, nil)
+ err := ImagePerformers(testCtx, &image, mockImageReader, mockPerformerReader, nil)
assert.Nil(err)
mockPerformerReader.AssertExpectations(t)
@@ -81,9 +81,9 @@ func TestImageStudios(t *testing.T) {
doTest := func(mockStudioReader *mocks.StudioReaderWriter, mockImageReader *mocks.ImageReaderWriter, test pathTestTable) {
if test.Matches {
- mockImageReader.On("Find", imageID).Return(&models.Image{}, nil).Once()
+ mockImageReader.On("Find", testCtx, imageID).Return(&models.Image{}, nil).Once()
expectedStudioID := models.NullInt64(studioID)
- mockImageReader.On("Update", models.ImagePartial{
+ mockImageReader.On("Update", testCtx, models.ImagePartial{
ID: imageID,
StudioID: &expectedStudioID,
}).Return(nil, nil).Once()
@@ -93,7 +93,7 @@ func TestImageStudios(t *testing.T) {
ID: imageID,
Path: test.Path,
}
- err := ImageStudios(&image, mockImageReader, mockStudioReader, nil)
+ err := ImageStudios(testCtx, &image, mockImageReader, mockStudioReader, nil)
assert.Nil(err)
mockStudioReader.AssertExpectations(t)
@@ -104,9 +104,9 @@ func TestImageStudios(t *testing.T) {
mockStudioReader := &mocks.StudioReaderWriter{}
mockImageReader := &mocks.ImageReaderWriter{}
- mockStudioReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil)
- mockStudioReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
- mockStudioReader.On("GetAliases", mock.Anything).Return([]string{}, nil).Maybe()
+ mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
+ mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
+ mockStudioReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
doTest(mockStudioReader, mockImageReader, test)
}
@@ -119,12 +119,12 @@ func TestImageStudios(t *testing.T) {
mockStudioReader := &mocks.StudioReaderWriter{}
mockImageReader := &mocks.ImageReaderWriter{}
- mockStudioReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil)
- mockStudioReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
- mockStudioReader.On("GetAliases", studioID).Return([]string{
+ mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
+ mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
+ mockStudioReader.On("GetAliases", testCtx, studioID).Return([]string{
studioName,
}, nil).Once()
- mockStudioReader.On("GetAliases", reversedStudioID).Return([]string{}, nil).Once()
+ mockStudioReader.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once()
doTest(mockStudioReader, mockImageReader, test)
}
@@ -154,15 +154,15 @@ func TestImageTags(t *testing.T) {
doTest := func(mockTagReader *mocks.TagReaderWriter, mockImageReader *mocks.ImageReaderWriter, test pathTestTable) {
if test.Matches {
- mockImageReader.On("GetTagIDs", imageID).Return(nil, nil).Once()
- mockImageReader.On("UpdateTags", imageID, []int{tagID}).Return(nil).Once()
+ mockImageReader.On("GetTagIDs", testCtx, imageID).Return(nil, nil).Once()
+ mockImageReader.On("UpdateTags", testCtx, imageID, []int{tagID}).Return(nil).Once()
}
image := models.Image{
ID: imageID,
Path: test.Path,
}
- err := ImageTags(&image, mockImageReader, mockTagReader, nil)
+ err := ImageTags(testCtx, &image, mockImageReader, mockTagReader, nil)
assert.Nil(err)
mockTagReader.AssertExpectations(t)
@@ -173,9 +173,9 @@ func TestImageTags(t *testing.T) {
mockTagReader := &mocks.TagReaderWriter{}
mockImageReader := &mocks.ImageReaderWriter{}
- mockTagReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil)
- mockTagReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
- mockTagReader.On("GetAliases", mock.Anything).Return([]string{}, nil).Maybe()
+ mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
+ mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
+ mockTagReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
doTest(mockTagReader, mockImageReader, test)
}
@@ -188,12 +188,12 @@ func TestImageTags(t *testing.T) {
mockTagReader := &mocks.TagReaderWriter{}
mockImageReader := &mocks.ImageReaderWriter{}
- mockTagReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil)
- mockTagReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
- mockTagReader.On("GetAliases", tagID).Return([]string{
+ mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
+ mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
+ mockTagReader.On("GetAliases", testCtx, tagID).Return([]string{
tagName,
}, nil).Once()
- mockTagReader.On("GetAliases", reversedTagID).Return([]string{}, nil).Once()
+ mockTagReader.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once()
doTest(mockTagReader, mockImageReader, test)
}
diff --git a/internal/autotag/integration_test.go b/internal/autotag/integration_test.go
index 9ca176d4b..5465d20c8 100644
--- a/internal/autotag/integration_test.go
+++ b/internal/autotag/integration_test.go
@@ -10,10 +10,10 @@ import (
"os"
"testing"
- "github.com/stashapp/stash/pkg/database"
"github.com/stashapp/stash/pkg/hash/md5"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/sqlite"
+ "github.com/stashapp/stash/pkg/txn"
_ "github.com/golang-migrate/migrate/v4/database/sqlite3"
_ "github.com/golang-migrate/migrate/v4/source/file"
@@ -28,8 +28,11 @@ const existingStudioGalleryName = testName + ".dontChangeStudio.mp4"
var existingStudioID int
+var db *sqlite.Database
+var r models.Repository
+
func testTeardown(databaseFile string) {
- err := database.DB.Close()
+ err := db.Close()
if err != nil {
panic(err)
@@ -50,10 +53,13 @@ func runTests(m *testing.M) int {
f.Close()
databaseFile := f.Name()
- if err := database.Initialize(databaseFile); err != nil {
+ db = &sqlite.Database{}
+ if err := db.Open(databaseFile); err != nil {
panic(fmt.Sprintf("Could not initialize database: %s", err.Error()))
}
+ r = db.TxnRepository()
+
// defer close and delete the database
defer testTeardown(databaseFile)
@@ -71,7 +77,7 @@ func TestMain(m *testing.M) {
os.Exit(ret)
}
-func createPerformer(pqb models.PerformerWriter) error {
+func createPerformer(ctx context.Context, pqb models.PerformerWriter) error {
// create the performer
performer := models.Performer{
Checksum: testName,
@@ -79,7 +85,7 @@ func createPerformer(pqb models.PerformerWriter) error {
Favorite: sql.NullBool{Valid: true, Bool: false},
}
- _, err := pqb.Create(performer)
+ _, err := pqb.Create(ctx, performer)
if err != nil {
return err
}
@@ -87,23 +93,23 @@ func createPerformer(pqb models.PerformerWriter) error {
return nil
}
-func createStudio(qb models.StudioWriter, name string) (*models.Studio, error) {
+func createStudio(ctx context.Context, qb models.StudioWriter, name string) (*models.Studio, error) {
// create the studio
studio := models.Studio{
Checksum: name,
Name: sql.NullString{Valid: true, String: name},
}
- return qb.Create(studio)
+ return qb.Create(ctx, studio)
}
-func createTag(qb models.TagWriter) error {
+func createTag(ctx context.Context, qb models.TagWriter) error {
// create the studio
tag := models.Tag{
Name: testName,
}
- _, err := qb.Create(tag)
+ _, err := qb.Create(ctx, tag)
if err != nil {
return err
}
@@ -111,18 +117,18 @@ func createTag(qb models.TagWriter) error {
return nil
}
-func createScenes(sqb models.SceneReaderWriter) error {
+func createScenes(ctx context.Context, sqb models.SceneReaderWriter) error {
// create the scenes
scenePatterns, falseScenePatterns := generateTestPaths(testName, sceneExt)
for _, fn := range scenePatterns {
- err := createScene(sqb, makeScene(fn, true))
+ err := createScene(ctx, sqb, makeScene(fn, true))
if err != nil {
return err
}
}
for _, fn := range falseScenePatterns {
- err := createScene(sqb, makeScene(fn, false))
+ err := createScene(ctx, sqb, makeScene(fn, false))
if err != nil {
return err
}
@@ -132,7 +138,7 @@ func createScenes(sqb models.SceneReaderWriter) error {
for _, fn := range scenePatterns {
s := makeScene("organized"+fn, false)
s.Organized = true
- err := createScene(sqb, s)
+ err := createScene(ctx, sqb, s)
if err != nil {
return err
}
@@ -141,7 +147,7 @@ func createScenes(sqb models.SceneReaderWriter) error {
// create scene with existing studio io
studioScene := makeScene(existingStudioSceneName, true)
studioScene.StudioID = sql.NullInt64{Valid: true, Int64: int64(existingStudioID)}
- err := createScene(sqb, studioScene)
+ err := createScene(ctx, sqb, studioScene)
if err != nil {
return err
}
@@ -163,8 +169,8 @@ func makeScene(name string, expectedResult bool) *models.Scene {
return scene
}
-func createScene(sqb models.SceneWriter, scene *models.Scene) error {
- _, err := sqb.Create(*scene)
+func createScene(ctx context.Context, sqb models.SceneWriter, scene *models.Scene) error {
+ _, err := sqb.Create(ctx, *scene)
if err != nil {
return fmt.Errorf("Failed to create scene with name '%s': %s", scene.Path, err.Error())
@@ -173,18 +179,18 @@ func createScene(sqb models.SceneWriter, scene *models.Scene) error {
return nil
}
-func createImages(sqb models.ImageReaderWriter) error {
+func createImages(ctx context.Context, sqb models.ImageReaderWriter) error {
// create the images
imagePatterns, falseImagePatterns := generateTestPaths(testName, imageExt)
for _, fn := range imagePatterns {
- err := createImage(sqb, makeImage(fn, true))
+ err := createImage(ctx, sqb, makeImage(fn, true))
if err != nil {
return err
}
}
for _, fn := range falseImagePatterns {
- err := createImage(sqb, makeImage(fn, false))
+ err := createImage(ctx, sqb, makeImage(fn, false))
if err != nil {
return err
}
@@ -194,7 +200,7 @@ func createImages(sqb models.ImageReaderWriter) error {
for _, fn := range imagePatterns {
s := makeImage("organized"+fn, false)
s.Organized = true
- err := createImage(sqb, s)
+ err := createImage(ctx, sqb, s)
if err != nil {
return err
}
@@ -203,7 +209,7 @@ func createImages(sqb models.ImageReaderWriter) error {
// create image with existing studio io
studioImage := makeImage(existingStudioImageName, true)
studioImage.StudioID = sql.NullInt64{Valid: true, Int64: int64(existingStudioID)}
- err := createImage(sqb, studioImage)
+ err := createImage(ctx, sqb, studioImage)
if err != nil {
return err
}
@@ -225,8 +231,8 @@ func makeImage(name string, expectedResult bool) *models.Image {
return image
}
-func createImage(sqb models.ImageWriter, image *models.Image) error {
- _, err := sqb.Create(*image)
+func createImage(ctx context.Context, sqb models.ImageWriter, image *models.Image) error {
+ _, err := sqb.Create(ctx, *image)
if err != nil {
return fmt.Errorf("Failed to create image with name '%s': %s", image.Path, err.Error())
@@ -235,18 +241,18 @@ func createImage(sqb models.ImageWriter, image *models.Image) error {
return nil
}
-func createGalleries(sqb models.GalleryReaderWriter) error {
+func createGalleries(ctx context.Context, sqb models.GalleryReaderWriter) error {
// create the galleries
galleryPatterns, falseGalleryPatterns := generateTestPaths(testName, galleryExt)
for _, fn := range galleryPatterns {
- err := createGallery(sqb, makeGallery(fn, true))
+ err := createGallery(ctx, sqb, makeGallery(fn, true))
if err != nil {
return err
}
}
for _, fn := range falseGalleryPatterns {
- err := createGallery(sqb, makeGallery(fn, false))
+ err := createGallery(ctx, sqb, makeGallery(fn, false))
if err != nil {
return err
}
@@ -256,7 +262,7 @@ func createGalleries(sqb models.GalleryReaderWriter) error {
for _, fn := range galleryPatterns {
s := makeGallery("organized"+fn, false)
s.Organized = true
- err := createGallery(sqb, s)
+ err := createGallery(ctx, sqb, s)
if err != nil {
return err
}
@@ -265,7 +271,7 @@ func createGalleries(sqb models.GalleryReaderWriter) error {
// create gallery with existing studio io
studioGallery := makeGallery(existingStudioGalleryName, true)
studioGallery.StudioID = sql.NullInt64{Valid: true, Int64: int64(existingStudioID)}
- err := createGallery(sqb, studioGallery)
+ err := createGallery(ctx, sqb, studioGallery)
if err != nil {
return err
}
@@ -287,8 +293,8 @@ func makeGallery(name string, expectedResult bool) *models.Gallery {
return gallery
}
-func createGallery(sqb models.GalleryWriter, gallery *models.Gallery) error {
- _, err := sqb.Create(*gallery)
+func createGallery(ctx context.Context, sqb models.GalleryWriter, gallery *models.Gallery) error {
+ _, err := sqb.Create(ctx, *gallery)
if err != nil {
return fmt.Errorf("Failed to create gallery with name '%s': %s", gallery.Path.String, err.Error())
@@ -297,47 +303,46 @@ func createGallery(sqb models.GalleryWriter, gallery *models.Gallery) error {
return nil
}
-func withTxn(f func(r models.Repository) error) error {
- t := sqlite.NewTransactionManager()
- return t.WithTxn(context.TODO(), f)
+func withTxn(f func(ctx context.Context) error) error {
+ return txn.WithTxn(context.TODO(), db, f)
}
func populateDB() error {
- if err := withTxn(func(r models.Repository) error {
- err := createPerformer(r.Performer())
+ if err := withTxn(func(ctx context.Context) error {
+ err := createPerformer(ctx, r.Performer)
if err != nil {
return err
}
- _, err = createStudio(r.Studio(), testName)
+ _, err = createStudio(ctx, r.Studio, testName)
if err != nil {
return err
}
// create existing studio
- existingStudio, err := createStudio(r.Studio(), existingStudioName)
+ existingStudio, err := createStudio(ctx, r.Studio, existingStudioName)
if err != nil {
return err
}
existingStudioID = existingStudio.ID
- err = createTag(r.Tag())
+ err = createTag(ctx, r.Tag)
if err != nil {
return err
}
- err = createScenes(r.Scene())
+ err = createScenes(ctx, r.Scene)
if err != nil {
return err
}
- err = createImages(r.Image())
+ err = createImages(ctx, r.Image)
if err != nil {
return err
}
- err = createGalleries(r.Gallery())
+ err = createGalleries(ctx, r.Gallery)
if err != nil {
return err
}
@@ -352,9 +357,9 @@ func populateDB() error {
func TestParsePerformerScenes(t *testing.T) {
var performers []*models.Performer
- if err := withTxn(func(r models.Repository) error {
+ if err := withTxn(func(ctx context.Context) error {
var err error
- performers, err = r.Performer().All()
+ performers, err = r.Performer.All(ctx)
return err
}); err != nil {
t.Errorf("Error getting performer: %s", err)
@@ -362,24 +367,24 @@ func TestParsePerformerScenes(t *testing.T) {
}
for _, p := range performers {
- if err := withTxn(func(r models.Repository) error {
- return PerformerScenes(p, nil, r.Scene(), nil)
+ if err := withTxn(func(ctx context.Context) error {
+ return PerformerScenes(ctx, p, nil, r.Scene, nil)
}); err != nil {
t.Errorf("Error auto-tagging performers: %s", err)
}
}
// verify that scenes were tagged correctly
- withTxn(func(r models.Repository) error {
- pqb := r.Performer()
+ withTxn(func(ctx context.Context) error {
+ pqb := r.Performer
- scenes, err := r.Scene().All()
+ scenes, err := r.Scene.All(ctx)
if err != nil {
t.Error(err.Error())
}
for _, scene := range scenes {
- performers, err := pqb.FindBySceneID(scene.ID)
+ performers, err := pqb.FindBySceneID(ctx, scene.ID)
if err != nil {
t.Errorf("Error getting scene performers: %s", err.Error())
@@ -399,9 +404,9 @@ func TestParsePerformerScenes(t *testing.T) {
func TestParseStudioScenes(t *testing.T) {
var studios []*models.Studio
- if err := withTxn(func(r models.Repository) error {
+ if err := withTxn(func(ctx context.Context) error {
var err error
- studios, err = r.Studio().All()
+ studios, err = r.Studio.All(ctx)
return err
}); err != nil {
t.Errorf("Error getting studio: %s", err)
@@ -409,21 +414,21 @@ func TestParseStudioScenes(t *testing.T) {
}
for _, s := range studios {
- if err := withTxn(func(r models.Repository) error {
- aliases, err := r.Studio().GetAliases(s.ID)
+ if err := withTxn(func(ctx context.Context) error {
+ aliases, err := r.Studio.GetAliases(ctx, s.ID)
if err != nil {
return err
}
- return StudioScenes(s, nil, aliases, r.Scene(), nil)
+ return StudioScenes(ctx, s, nil, aliases, r.Scene, nil)
}); err != nil {
t.Errorf("Error auto-tagging performers: %s", err)
}
}
// verify that scenes were tagged correctly
- withTxn(func(r models.Repository) error {
- scenes, err := r.Scene().All()
+ withTxn(func(ctx context.Context) error {
+ scenes, err := r.Scene.All(ctx)
if err != nil {
t.Error(err.Error())
}
@@ -455,9 +460,9 @@ func TestParseStudioScenes(t *testing.T) {
func TestParseTagScenes(t *testing.T) {
var tags []*models.Tag
- if err := withTxn(func(r models.Repository) error {
+ if err := withTxn(func(ctx context.Context) error {
var err error
- tags, err = r.Tag().All()
+ tags, err = r.Tag.All(ctx)
return err
}); err != nil {
t.Errorf("Error getting performer: %s", err)
@@ -465,29 +470,29 @@ func TestParseTagScenes(t *testing.T) {
}
for _, s := range tags {
- if err := withTxn(func(r models.Repository) error {
- aliases, err := r.Tag().GetAliases(s.ID)
+ if err := withTxn(func(ctx context.Context) error {
+ aliases, err := r.Tag.GetAliases(ctx, s.ID)
if err != nil {
return err
}
- return TagScenes(s, nil, aliases, r.Scene(), nil)
+ return TagScenes(ctx, s, nil, aliases, r.Scene, nil)
}); err != nil {
t.Errorf("Error auto-tagging performers: %s", err)
}
}
// verify that scenes were tagged correctly
- withTxn(func(r models.Repository) error {
- scenes, err := r.Scene().All()
+ withTxn(func(ctx context.Context) error {
+ scenes, err := r.Scene.All(ctx)
if err != nil {
t.Error(err.Error())
}
- tqb := r.Tag()
+ tqb := r.Tag
for _, scene := range scenes {
- tags, err := tqb.FindBySceneID(scene.ID)
+ tags, err := tqb.FindBySceneID(ctx, scene.ID)
if err != nil {
t.Errorf("Error getting scene tags: %s", err.Error())
@@ -507,9 +512,9 @@ func TestParseTagScenes(t *testing.T) {
func TestParsePerformerImages(t *testing.T) {
var performers []*models.Performer
- if err := withTxn(func(r models.Repository) error {
+ if err := withTxn(func(ctx context.Context) error {
var err error
- performers, err = r.Performer().All()
+ performers, err = r.Performer.All(ctx)
return err
}); err != nil {
t.Errorf("Error getting performer: %s", err)
@@ -517,24 +522,24 @@ func TestParsePerformerImages(t *testing.T) {
}
for _, p := range performers {
- if err := withTxn(func(r models.Repository) error {
- return PerformerImages(p, nil, r.Image(), nil)
+ if err := withTxn(func(ctx context.Context) error {
+ return PerformerImages(ctx, p, nil, r.Image, nil)
}); err != nil {
t.Errorf("Error auto-tagging performers: %s", err)
}
}
// verify that images were tagged correctly
- withTxn(func(r models.Repository) error {
- pqb := r.Performer()
+ withTxn(func(ctx context.Context) error {
+ pqb := r.Performer
- images, err := r.Image().All()
+ images, err := r.Image.All(ctx)
if err != nil {
t.Error(err.Error())
}
for _, image := range images {
- performers, err := pqb.FindByImageID(image.ID)
+ performers, err := pqb.FindByImageID(ctx, image.ID)
if err != nil {
t.Errorf("Error getting image performers: %s", err.Error())
@@ -554,9 +559,9 @@ func TestParsePerformerImages(t *testing.T) {
func TestParseStudioImages(t *testing.T) {
var studios []*models.Studio
- if err := withTxn(func(r models.Repository) error {
+ if err := withTxn(func(ctx context.Context) error {
var err error
- studios, err = r.Studio().All()
+ studios, err = r.Studio.All(ctx)
return err
}); err != nil {
t.Errorf("Error getting studio: %s", err)
@@ -564,21 +569,21 @@ func TestParseStudioImages(t *testing.T) {
}
for _, s := range studios {
- if err := withTxn(func(r models.Repository) error {
- aliases, err := r.Studio().GetAliases(s.ID)
+ if err := withTxn(func(ctx context.Context) error {
+ aliases, err := r.Studio.GetAliases(ctx, s.ID)
if err != nil {
return err
}
- return StudioImages(s, nil, aliases, r.Image(), nil)
+ return StudioImages(ctx, s, nil, aliases, r.Image, nil)
}); err != nil {
t.Errorf("Error auto-tagging performers: %s", err)
}
}
// verify that images were tagged correctly
- withTxn(func(r models.Repository) error {
- images, err := r.Image().All()
+ withTxn(func(ctx context.Context) error {
+ images, err := r.Image.All(ctx)
if err != nil {
t.Error(err.Error())
}
@@ -610,9 +615,9 @@ func TestParseStudioImages(t *testing.T) {
func TestParseTagImages(t *testing.T) {
var tags []*models.Tag
- if err := withTxn(func(r models.Repository) error {
+ if err := withTxn(func(ctx context.Context) error {
var err error
- tags, err = r.Tag().All()
+ tags, err = r.Tag.All(ctx)
return err
}); err != nil {
t.Errorf("Error getting performer: %s", err)
@@ -620,29 +625,29 @@ func TestParseTagImages(t *testing.T) {
}
for _, s := range tags {
- if err := withTxn(func(r models.Repository) error {
- aliases, err := r.Tag().GetAliases(s.ID)
+ if err := withTxn(func(ctx context.Context) error {
+ aliases, err := r.Tag.GetAliases(ctx, s.ID)
if err != nil {
return err
}
- return TagImages(s, nil, aliases, r.Image(), nil)
+ return TagImages(ctx, s, nil, aliases, r.Image, nil)
}); err != nil {
t.Errorf("Error auto-tagging performers: %s", err)
}
}
// verify that images were tagged correctly
- withTxn(func(r models.Repository) error {
- images, err := r.Image().All()
+ withTxn(func(ctx context.Context) error {
+ images, err := r.Image.All(ctx)
if err != nil {
t.Error(err.Error())
}
- tqb := r.Tag()
+ tqb := r.Tag
for _, image := range images {
- tags, err := tqb.FindByImageID(image.ID)
+ tags, err := tqb.FindByImageID(ctx, image.ID)
if err != nil {
t.Errorf("Error getting image tags: %s", err.Error())
@@ -662,9 +667,9 @@ func TestParseTagImages(t *testing.T) {
func TestParsePerformerGalleries(t *testing.T) {
var performers []*models.Performer
- if err := withTxn(func(r models.Repository) error {
+ if err := withTxn(func(ctx context.Context) error {
var err error
- performers, err = r.Performer().All()
+ performers, err = r.Performer.All(ctx)
return err
}); err != nil {
t.Errorf("Error getting performer: %s", err)
@@ -672,24 +677,24 @@ func TestParsePerformerGalleries(t *testing.T) {
}
for _, p := range performers {
- if err := withTxn(func(r models.Repository) error {
- return PerformerGalleries(p, nil, r.Gallery(), nil)
+ if err := withTxn(func(ctx context.Context) error {
+ return PerformerGalleries(ctx, p, nil, r.Gallery, nil)
}); err != nil {
t.Errorf("Error auto-tagging performers: %s", err)
}
}
// verify that galleries were tagged correctly
- withTxn(func(r models.Repository) error {
- pqb := r.Performer()
+ withTxn(func(ctx context.Context) error {
+ pqb := r.Performer
- galleries, err := r.Gallery().All()
+ galleries, err := r.Gallery.All(ctx)
if err != nil {
t.Error(err.Error())
}
for _, gallery := range galleries {
- performers, err := pqb.FindByGalleryID(gallery.ID)
+ performers, err := pqb.FindByGalleryID(ctx, gallery.ID)
if err != nil {
t.Errorf("Error getting gallery performers: %s", err.Error())
@@ -709,9 +714,9 @@ func TestParsePerformerGalleries(t *testing.T) {
func TestParseStudioGalleries(t *testing.T) {
var studios []*models.Studio
- if err := withTxn(func(r models.Repository) error {
+ if err := withTxn(func(ctx context.Context) error {
var err error
- studios, err = r.Studio().All()
+ studios, err = r.Studio.All(ctx)
return err
}); err != nil {
t.Errorf("Error getting studio: %s", err)
@@ -719,21 +724,21 @@ func TestParseStudioGalleries(t *testing.T) {
}
for _, s := range studios {
- if err := withTxn(func(r models.Repository) error {
- aliases, err := r.Studio().GetAliases(s.ID)
+ if err := withTxn(func(ctx context.Context) error {
+ aliases, err := r.Studio.GetAliases(ctx, s.ID)
if err != nil {
return err
}
- return StudioGalleries(s, nil, aliases, r.Gallery(), nil)
+ return StudioGalleries(ctx, s, nil, aliases, r.Gallery, nil)
}); err != nil {
t.Errorf("Error auto-tagging performers: %s", err)
}
}
// verify that galleries were tagged correctly
- withTxn(func(r models.Repository) error {
- galleries, err := r.Gallery().All()
+ withTxn(func(ctx context.Context) error {
+ galleries, err := r.Gallery.All(ctx)
if err != nil {
t.Error(err.Error())
}
@@ -765,9 +770,9 @@ func TestParseStudioGalleries(t *testing.T) {
func TestParseTagGalleries(t *testing.T) {
var tags []*models.Tag
- if err := withTxn(func(r models.Repository) error {
+ if err := withTxn(func(ctx context.Context) error {
var err error
- tags, err = r.Tag().All()
+ tags, err = r.Tag.All(ctx)
return err
}); err != nil {
t.Errorf("Error getting performer: %s", err)
@@ -775,29 +780,29 @@ func TestParseTagGalleries(t *testing.T) {
}
for _, s := range tags {
- if err := withTxn(func(r models.Repository) error {
- aliases, err := r.Tag().GetAliases(s.ID)
+ if err := withTxn(func(ctx context.Context) error {
+ aliases, err := r.Tag.GetAliases(ctx, s.ID)
if err != nil {
return err
}
- return TagGalleries(s, nil, aliases, r.Gallery(), nil)
+ return TagGalleries(ctx, s, nil, aliases, r.Gallery, nil)
}); err != nil {
t.Errorf("Error auto-tagging performers: %s", err)
}
}
// verify that galleries were tagged correctly
- withTxn(func(r models.Repository) error {
- galleries, err := r.Gallery().All()
+ withTxn(func(ctx context.Context) error {
+ galleries, err := r.Gallery.All(ctx)
if err != nil {
t.Error(err.Error())
}
- tqb := r.Tag()
+ tqb := r.Tag
for _, gallery := range galleries {
- tags, err := tqb.FindByGalleryID(gallery.ID)
+ tags, err := tqb.FindByGalleryID(ctx, gallery.ID)
if err != nil {
t.Errorf("Error getting gallery tags: %s", err.Error())
diff --git a/internal/autotag/performer.go b/internal/autotag/performer.go
index a6c89466a..ea42667e3 100644
--- a/internal/autotag/performer.go
+++ b/internal/autotag/performer.go
@@ -1,6 +1,8 @@
package autotag
import (
+ "context"
+
"github.com/stashapp/stash/pkg/gallery"
"github.com/stashapp/stash/pkg/image"
"github.com/stashapp/stash/pkg/match"
@@ -8,6 +10,21 @@ import (
"github.com/stashapp/stash/pkg/scene"
)
+type SceneQueryPerformerUpdater interface {
+ scene.Queryer
+ scene.PerformerUpdater
+}
+
+type ImageQueryPerformerUpdater interface {
+ image.Queryer
+ image.PerformerUpdater
+}
+
+type GalleryQueryPerformerUpdater interface {
+ gallery.Queryer
+ gallery.PerformerUpdater
+}
+
func getPerformerTagger(p *models.Performer, cache *match.Cache) tagger {
return tagger{
ID: p.ID,
@@ -18,28 +35,28 @@ func getPerformerTagger(p *models.Performer, cache *match.Cache) tagger {
}
// PerformerScenes searches for scenes whose path matches the provided performer name and tags the scene with the performer.
-func PerformerScenes(p *models.Performer, paths []string, rw models.SceneReaderWriter, cache *match.Cache) error {
+func PerformerScenes(ctx context.Context, p *models.Performer, paths []string, rw SceneQueryPerformerUpdater, cache *match.Cache) error {
t := getPerformerTagger(p, cache)
- return t.tagScenes(paths, rw, func(subjectID, otherID int) (bool, error) {
- return scene.AddPerformer(rw, otherID, subjectID)
+ return t.tagScenes(ctx, paths, rw, func(subjectID, otherID int) (bool, error) {
+ return scene.AddPerformer(ctx, rw, otherID, subjectID)
})
}
// PerformerImages searches for images whose path matches the provided performer name and tags the image with the performer.
-func PerformerImages(p *models.Performer, paths []string, rw models.ImageReaderWriter, cache *match.Cache) error {
+func PerformerImages(ctx context.Context, p *models.Performer, paths []string, rw ImageQueryPerformerUpdater, cache *match.Cache) error {
t := getPerformerTagger(p, cache)
- return t.tagImages(paths, rw, func(subjectID, otherID int) (bool, error) {
- return image.AddPerformer(rw, otherID, subjectID)
+ return t.tagImages(ctx, paths, rw, func(subjectID, otherID int) (bool, error) {
+ return image.AddPerformer(ctx, rw, otherID, subjectID)
})
}
// PerformerGalleries searches for galleries whose path matches the provided performer name and tags the gallery with the performer.
-func PerformerGalleries(p *models.Performer, paths []string, rw models.GalleryReaderWriter, cache *match.Cache) error {
+func PerformerGalleries(ctx context.Context, p *models.Performer, paths []string, rw GalleryQueryPerformerUpdater, cache *match.Cache) error {
t := getPerformerTagger(p, cache)
- return t.tagGalleries(paths, rw, func(subjectID, otherID int) (bool, error) {
- return gallery.AddPerformer(rw, otherID, subjectID)
+ return t.tagGalleries(ctx, paths, rw, func(subjectID, otherID int) (bool, error) {
+ return gallery.AddPerformer(ctx, rw, otherID, subjectID)
})
}
diff --git a/internal/autotag/performer_test.go b/internal/autotag/performer_test.go
index 31befd76a..54a98958a 100644
--- a/internal/autotag/performer_test.go
+++ b/internal/autotag/performer_test.go
@@ -72,16 +72,16 @@ func testPerformerScenes(t *testing.T, performerName, expectedRegex string) {
PerPage: &perPage,
}
- mockSceneReader.On("Query", scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false)).
+ mockSceneReader.On("Query", testCtx, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false)).
Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once()
for i := range matchingPaths {
sceneID := i + 1
- mockSceneReader.On("GetPerformerIDs", sceneID).Return(nil, nil).Once()
- mockSceneReader.On("UpdatePerformers", sceneID, []int{performerID}).Return(nil).Once()
+ mockSceneReader.On("GetPerformerIDs", testCtx, sceneID).Return(nil, nil).Once()
+ mockSceneReader.On("UpdatePerformers", testCtx, sceneID, []int{performerID}).Return(nil).Once()
}
- err := PerformerScenes(&performer, nil, mockSceneReader, nil)
+ err := PerformerScenes(testCtx, &performer, nil, mockSceneReader, nil)
assert := assert.New(t)
@@ -147,16 +147,16 @@ func testPerformerImages(t *testing.T, performerName, expectedRegex string) {
PerPage: &perPage,
}
- mockImageReader.On("Query", image.QueryOptions(expectedImageFilter, expectedFindFilter, false)).
+ mockImageReader.On("Query", testCtx, image.QueryOptions(expectedImageFilter, expectedFindFilter, false)).
Return(mocks.ImageQueryResult(images, len(images)), nil).Once()
for i := range matchingPaths {
imageID := i + 1
- mockImageReader.On("GetPerformerIDs", imageID).Return(nil, nil).Once()
- mockImageReader.On("UpdatePerformers", imageID, []int{performerID}).Return(nil).Once()
+ mockImageReader.On("GetPerformerIDs", testCtx, imageID).Return(nil, nil).Once()
+ mockImageReader.On("UpdatePerformers", testCtx, imageID, []int{performerID}).Return(nil).Once()
}
- err := PerformerImages(&performer, nil, mockImageReader, nil)
+ err := PerformerImages(testCtx, &performer, nil, mockImageReader, nil)
assert := assert.New(t)
@@ -222,15 +222,15 @@ func testPerformerGalleries(t *testing.T, performerName, expectedRegex string) {
PerPage: &perPage,
}
- mockGalleryReader.On("Query", expectedGalleryFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once()
+ mockGalleryReader.On("Query", testCtx, expectedGalleryFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once()
for i := range matchingPaths {
galleryID := i + 1
- mockGalleryReader.On("GetPerformerIDs", galleryID).Return(nil, nil).Once()
- mockGalleryReader.On("UpdatePerformers", galleryID, []int{performerID}).Return(nil).Once()
+ mockGalleryReader.On("GetPerformerIDs", testCtx, galleryID).Return(nil, nil).Once()
+ mockGalleryReader.On("UpdatePerformers", testCtx, galleryID, []int{performerID}).Return(nil).Once()
}
- err := PerformerGalleries(&performer, nil, mockGalleryReader, nil)
+ err := PerformerGalleries(testCtx, &performer, nil, mockGalleryReader, nil)
assert := assert.New(t)
diff --git a/internal/autotag/scene.go b/internal/autotag/scene.go
index cfdcaf393..6c6aeb875 100644
--- a/internal/autotag/scene.go
+++ b/internal/autotag/scene.go
@@ -1,6 +1,8 @@
package autotag
import (
+ "context"
+
"github.com/stashapp/stash/pkg/match"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene"
@@ -17,18 +19,18 @@ func getSceneFileTagger(s *models.Scene, cache *match.Cache) tagger {
}
// ScenePerformers tags the provided scene with performers whose name matches the scene's path.
-func ScenePerformers(s *models.Scene, rw models.SceneReaderWriter, performerReader models.PerformerReader, cache *match.Cache) error {
+func ScenePerformers(ctx context.Context, s *models.Scene, rw scene.PerformerUpdater, performerReader match.PerformerAutoTagQueryer, cache *match.Cache) error {
t := getSceneFileTagger(s, cache)
- return t.tagPerformers(performerReader, func(subjectID, otherID int) (bool, error) {
- return scene.AddPerformer(rw, subjectID, otherID)
+ return t.tagPerformers(ctx, performerReader, func(subjectID, otherID int) (bool, error) {
+ return scene.AddPerformer(ctx, rw, subjectID, otherID)
})
}
// SceneStudios tags the provided scene with the first studio whose name matches the scene's path.
//
// Scenes will not be tagged if studio is already set.
-func SceneStudios(s *models.Scene, rw models.SceneReaderWriter, studioReader models.StudioReader, cache *match.Cache) error {
+func SceneStudios(ctx context.Context, s *models.Scene, rw SceneFinderUpdater, studioReader match.StudioAutoTagQueryer, cache *match.Cache) error {
if s.StudioID.Valid {
// don't modify
return nil
@@ -36,16 +38,16 @@ func SceneStudios(s *models.Scene, rw models.SceneReaderWriter, studioReader mod
t := getSceneFileTagger(s, cache)
- return t.tagStudios(studioReader, func(subjectID, otherID int) (bool, error) {
- return addSceneStudio(rw, subjectID, otherID)
+ return t.tagStudios(ctx, studioReader, func(subjectID, otherID int) (bool, error) {
+ return addSceneStudio(ctx, rw, subjectID, otherID)
})
}
// SceneTags tags the provided scene with tags whose name matches the scene's path.
-func SceneTags(s *models.Scene, rw models.SceneReaderWriter, tagReader models.TagReader, cache *match.Cache) error {
+func SceneTags(ctx context.Context, s *models.Scene, rw scene.TagUpdater, tagReader match.TagAutoTagQueryer, cache *match.Cache) error {
t := getSceneFileTagger(s, cache)
- return t.tagTags(tagReader, func(subjectID, otherID int) (bool, error) {
- return scene.AddTag(rw, subjectID, otherID)
+ return t.tagTags(ctx, tagReader, func(subjectID, otherID int) (bool, error) {
+ return scene.AddTag(ctx, rw, subjectID, otherID)
})
}
diff --git a/internal/autotag/scene_test.go b/internal/autotag/scene_test.go
index 578b9e7f6..6e66482fc 100644
--- a/internal/autotag/scene_test.go
+++ b/internal/autotag/scene_test.go
@@ -173,19 +173,19 @@ func TestScenePerformers(t *testing.T) {
mockPerformerReader := &mocks.PerformerReaderWriter{}
mockSceneReader := &mocks.SceneReaderWriter{}
- mockPerformerReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil)
- mockPerformerReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once()
+ mockPerformerReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
+ mockPerformerReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Performer{&performer, &reversedPerformer}, nil).Once()
if test.Matches {
- mockSceneReader.On("GetPerformerIDs", sceneID).Return(nil, nil).Once()
- mockSceneReader.On("UpdatePerformers", sceneID, []int{performerID}).Return(nil).Once()
+ mockSceneReader.On("GetPerformerIDs", testCtx, sceneID).Return(nil, nil).Once()
+ mockSceneReader.On("UpdatePerformers", testCtx, sceneID, []int{performerID}).Return(nil).Once()
}
scene := models.Scene{
ID: sceneID,
Path: test.Path,
}
- err := ScenePerformers(&scene, mockSceneReader, mockPerformerReader, nil)
+ err := ScenePerformers(testCtx, &scene, mockSceneReader, mockPerformerReader, nil)
assert.Nil(err)
mockPerformerReader.AssertExpectations(t)
@@ -217,9 +217,9 @@ func TestSceneStudios(t *testing.T) {
doTest := func(mockStudioReader *mocks.StudioReaderWriter, mockSceneReader *mocks.SceneReaderWriter, test pathTestTable) {
if test.Matches {
- mockSceneReader.On("Find", sceneID).Return(&models.Scene{}, nil).Once()
+ mockSceneReader.On("Find", testCtx, sceneID).Return(&models.Scene{}, nil).Once()
expectedStudioID := models.NullInt64(studioID)
- mockSceneReader.On("Update", models.ScenePartial{
+ mockSceneReader.On("Update", testCtx, models.ScenePartial{
ID: sceneID,
StudioID: &expectedStudioID,
}).Return(nil, nil).Once()
@@ -229,7 +229,7 @@ func TestSceneStudios(t *testing.T) {
ID: sceneID,
Path: test.Path,
}
- err := SceneStudios(&scene, mockSceneReader, mockStudioReader, nil)
+ err := SceneStudios(testCtx, &scene, mockSceneReader, mockStudioReader, nil)
assert.Nil(err)
mockStudioReader.AssertExpectations(t)
@@ -240,9 +240,9 @@ func TestSceneStudios(t *testing.T) {
mockStudioReader := &mocks.StudioReaderWriter{}
mockSceneReader := &mocks.SceneReaderWriter{}
- mockStudioReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil)
- mockStudioReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
- mockStudioReader.On("GetAliases", mock.Anything).Return([]string{}, nil).Maybe()
+ mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
+ mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
+ mockStudioReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
doTest(mockStudioReader, mockSceneReader, test)
}
@@ -255,12 +255,12 @@ func TestSceneStudios(t *testing.T) {
mockStudioReader := &mocks.StudioReaderWriter{}
mockSceneReader := &mocks.SceneReaderWriter{}
- mockStudioReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil)
- mockStudioReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
- mockStudioReader.On("GetAliases", studioID).Return([]string{
+ mockStudioReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
+ mockStudioReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Studio{&studio, &reversedStudio}, nil).Once()
+ mockStudioReader.On("GetAliases", testCtx, studioID).Return([]string{
studioName,
}, nil).Once()
- mockStudioReader.On("GetAliases", reversedStudioID).Return([]string{}, nil).Once()
+ mockStudioReader.On("GetAliases", testCtx, reversedStudioID).Return([]string{}, nil).Once()
doTest(mockStudioReader, mockSceneReader, test)
}
@@ -290,15 +290,15 @@ func TestSceneTags(t *testing.T) {
doTest := func(mockTagReader *mocks.TagReaderWriter, mockSceneReader *mocks.SceneReaderWriter, test pathTestTable) {
if test.Matches {
- mockSceneReader.On("GetTagIDs", sceneID).Return(nil, nil).Once()
- mockSceneReader.On("UpdateTags", sceneID, []int{tagID}).Return(nil).Once()
+ mockSceneReader.On("GetTagIDs", testCtx, sceneID).Return(nil, nil).Once()
+ mockSceneReader.On("UpdateTags", testCtx, sceneID, []int{tagID}).Return(nil).Once()
}
scene := models.Scene{
ID: sceneID,
Path: test.Path,
}
- err := SceneTags(&scene, mockSceneReader, mockTagReader, nil)
+ err := SceneTags(testCtx, &scene, mockSceneReader, mockTagReader, nil)
assert.Nil(err)
mockTagReader.AssertExpectations(t)
@@ -309,9 +309,9 @@ func TestSceneTags(t *testing.T) {
mockTagReader := &mocks.TagReaderWriter{}
mockSceneReader := &mocks.SceneReaderWriter{}
- mockTagReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil)
- mockTagReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
- mockTagReader.On("GetAliases", mock.Anything).Return([]string{}, nil).Maybe()
+ mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
+ mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
+ mockTagReader.On("GetAliases", testCtx, mock.Anything).Return([]string{}, nil).Maybe()
doTest(mockTagReader, mockSceneReader, test)
}
@@ -324,12 +324,12 @@ func TestSceneTags(t *testing.T) {
mockTagReader := &mocks.TagReaderWriter{}
mockSceneReader := &mocks.SceneReaderWriter{}
- mockTagReader.On("Query", mock.Anything, mock.Anything).Return(nil, 0, nil)
- mockTagReader.On("QueryForAutoTag", mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
- mockTagReader.On("GetAliases", tagID).Return([]string{
+ mockTagReader.On("Query", testCtx, mock.Anything, mock.Anything).Return(nil, 0, nil)
+ mockTagReader.On("QueryForAutoTag", testCtx, mock.Anything).Return([]*models.Tag{&tag, &reversedTag}, nil).Once()
+ mockTagReader.On("GetAliases", testCtx, tagID).Return([]string{
tagName,
}, nil).Once()
- mockTagReader.On("GetAliases", reversedTagID).Return([]string{}, nil).Once()
+ mockTagReader.On("GetAliases", testCtx, reversedTagID).Return([]string{}, nil).Once()
doTest(mockTagReader, mockSceneReader, test)
}
diff --git a/internal/autotag/studio.go b/internal/autotag/studio.go
index 4a02e7305..79cb22586 100644
--- a/internal/autotag/studio.go
+++ b/internal/autotag/studio.go
@@ -1,15 +1,19 @@
package autotag
import (
+ "context"
"database/sql"
+ "github.com/stashapp/stash/pkg/gallery"
+ "github.com/stashapp/stash/pkg/image"
"github.com/stashapp/stash/pkg/match"
"github.com/stashapp/stash/pkg/models"
+ "github.com/stashapp/stash/pkg/scene"
)
-func addSceneStudio(sceneWriter models.SceneReaderWriter, sceneID, studioID int) (bool, error) {
+func addSceneStudio(ctx context.Context, sceneWriter SceneFinderUpdater, sceneID, studioID int) (bool, error) {
// don't set if already set
- scene, err := sceneWriter.Find(sceneID)
+ scene, err := sceneWriter.Find(ctx, sceneID)
if err != nil {
return false, err
}
@@ -25,15 +29,15 @@ func addSceneStudio(sceneWriter models.SceneReaderWriter, sceneID, studioID int)
StudioID: &s,
}
- if _, err := sceneWriter.Update(scenePartial); err != nil {
+ if _, err := sceneWriter.Update(ctx, scenePartial); err != nil {
return false, err
}
return true, nil
}
-func addImageStudio(imageWriter models.ImageReaderWriter, imageID, studioID int) (bool, error) {
+func addImageStudio(ctx context.Context, imageWriter ImageFinderUpdater, imageID, studioID int) (bool, error) {
// don't set if already set
- image, err := imageWriter.Find(imageID)
+ image, err := imageWriter.Find(ctx, imageID)
if err != nil {
return false, err
}
@@ -49,15 +53,15 @@ func addImageStudio(imageWriter models.ImageReaderWriter, imageID, studioID int)
StudioID: &s,
}
- if _, err := imageWriter.Update(imagePartial); err != nil {
+ if _, err := imageWriter.Update(ctx, imagePartial); err != nil {
return false, err
}
return true, nil
}
-func addGalleryStudio(galleryWriter models.GalleryReaderWriter, galleryID, studioID int) (bool, error) {
+func addGalleryStudio(ctx context.Context, galleryWriter GalleryFinderUpdater, galleryID, studioID int) (bool, error) {
// don't set if already set
- gallery, err := galleryWriter.Find(galleryID)
+ gallery, err := galleryWriter.Find(ctx, galleryID)
if err != nil {
return false, err
}
@@ -73,7 +77,7 @@ func addGalleryStudio(galleryWriter models.GalleryReaderWriter, galleryID, studi
StudioID: &s,
}
- if _, err := galleryWriter.UpdatePartial(galleryPartial); err != nil {
+ if _, err := galleryWriter.UpdatePartial(ctx, galleryPartial); err != nil {
return false, err
}
return true, nil
@@ -98,13 +102,19 @@ func getStudioTagger(p *models.Studio, aliases []string, cache *match.Cache) []t
return ret
}
+type SceneFinderUpdater interface {
+ scene.Queryer
+ Find(ctx context.Context, id int) (*models.Scene, error)
+ Update(ctx context.Context, updatedScene models.ScenePartial) (*models.Scene, error)
+}
+
// StudioScenes searches for scenes whose path matches the provided studio name and tags the scene with the studio, if studio is not already set on the scene.
-func StudioScenes(p *models.Studio, paths []string, aliases []string, rw models.SceneReaderWriter, cache *match.Cache) error {
+func StudioScenes(ctx context.Context, p *models.Studio, paths []string, aliases []string, rw SceneFinderUpdater, cache *match.Cache) error {
t := getStudioTagger(p, aliases, cache)
for _, tt := range t {
- if err := tt.tagScenes(paths, rw, func(subjectID, otherID int) (bool, error) {
- return addSceneStudio(rw, otherID, subjectID)
+ if err := tt.tagScenes(ctx, paths, rw, func(subjectID, otherID int) (bool, error) {
+ return addSceneStudio(ctx, rw, otherID, subjectID)
}); err != nil {
return err
}
@@ -113,13 +123,19 @@ func StudioScenes(p *models.Studio, paths []string, aliases []string, rw models.
return nil
}
+type ImageFinderUpdater interface {
+ image.Queryer
+ Find(ctx context.Context, id int) (*models.Image, error)
+ Update(ctx context.Context, updatedImage models.ImagePartial) (*models.Image, error)
+}
+
// StudioImages searches for images whose path matches the provided studio name and tags the image with the studio, if studio is not already set on the image.
-func StudioImages(p *models.Studio, paths []string, aliases []string, rw models.ImageReaderWriter, cache *match.Cache) error {
+func StudioImages(ctx context.Context, p *models.Studio, paths []string, aliases []string, rw ImageFinderUpdater, cache *match.Cache) error {
t := getStudioTagger(p, aliases, cache)
for _, tt := range t {
- if err := tt.tagImages(paths, rw, func(subjectID, otherID int) (bool, error) {
- return addImageStudio(rw, otherID, subjectID)
+ if err := tt.tagImages(ctx, paths, rw, func(subjectID, otherID int) (bool, error) {
+ return addImageStudio(ctx, rw, otherID, subjectID)
}); err != nil {
return err
}
@@ -128,13 +144,19 @@ func StudioImages(p *models.Studio, paths []string, aliases []string, rw models.
return nil
}
+type GalleryFinderUpdater interface {
+ gallery.Queryer
+ Find(ctx context.Context, id int) (*models.Gallery, error)
+ UpdatePartial(ctx context.Context, updatedGallery models.GalleryPartial) (*models.Gallery, error)
+}
+
// StudioGalleries searches for galleries whose path matches the provided studio name and tags the gallery with the studio, if studio is not already set on the gallery.
-func StudioGalleries(p *models.Studio, paths []string, aliases []string, rw models.GalleryReaderWriter, cache *match.Cache) error {
+func StudioGalleries(ctx context.Context, p *models.Studio, paths []string, aliases []string, rw GalleryFinderUpdater, cache *match.Cache) error {
t := getStudioTagger(p, aliases, cache)
for _, tt := range t {
- if err := tt.tagGalleries(paths, rw, func(subjectID, otherID int) (bool, error) {
- return addGalleryStudio(rw, otherID, subjectID)
+ if err := tt.tagGalleries(ctx, paths, rw, func(subjectID, otherID int) (bool, error) {
+ return addGalleryStudio(ctx, rw, otherID, subjectID)
}); err != nil {
return err
}
diff --git a/internal/autotag/studio_test.go b/internal/autotag/studio_test.go
index 76d7e7db5..861740612 100644
--- a/internal/autotag/studio_test.go
+++ b/internal/autotag/studio_test.go
@@ -113,7 +113,7 @@ func testStudioScenes(t *testing.T, tc testStudioCase) {
}
// if alias provided, then don't find by name
- onNameQuery := mockSceneReader.On("Query", scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false))
+ onNameQuery := mockSceneReader.On("Query", testCtx, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false))
if aliasName == "" {
onNameQuery.Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once()
@@ -128,21 +128,21 @@ func testStudioScenes(t *testing.T, tc testStudioCase) {
},
}
- mockSceneReader.On("Query", scene.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
+ mockSceneReader.On("Query", testCtx, scene.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once()
}
for i := range matchingPaths {
sceneID := i + 1
- mockSceneReader.On("Find", sceneID).Return(&models.Scene{}, nil).Once()
+ mockSceneReader.On("Find", testCtx, sceneID).Return(&models.Scene{}, nil).Once()
expectedStudioID := models.NullInt64(studioID)
- mockSceneReader.On("Update", models.ScenePartial{
+ mockSceneReader.On("Update", testCtx, models.ScenePartial{
ID: sceneID,
StudioID: &expectedStudioID,
}).Return(nil, nil).Once()
}
- err := StudioScenes(&studio, nil, aliases, mockSceneReader, nil)
+ err := StudioScenes(testCtx, &studio, nil, aliases, mockSceneReader, nil)
assert := assert.New(t)
@@ -206,7 +206,7 @@ func testStudioImages(t *testing.T, tc testStudioCase) {
}
// if alias provided, then don't find by name
- onNameQuery := mockImageReader.On("Query", image.QueryOptions(expectedImageFilter, expectedFindFilter, false))
+ onNameQuery := mockImageReader.On("Query", testCtx, image.QueryOptions(expectedImageFilter, expectedFindFilter, false))
if aliasName == "" {
onNameQuery.Return(mocks.ImageQueryResult(images, len(images)), nil).Once()
} else {
@@ -220,21 +220,21 @@ func testStudioImages(t *testing.T, tc testStudioCase) {
},
}
- mockImageReader.On("Query", image.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
+ mockImageReader.On("Query", testCtx, image.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
Return(mocks.ImageQueryResult(images, len(images)), nil).Once()
}
for i := range matchingPaths {
imageID := i + 1
- mockImageReader.On("Find", imageID).Return(&models.Image{}, nil).Once()
+ mockImageReader.On("Find", testCtx, imageID).Return(&models.Image{}, nil).Once()
expectedStudioID := models.NullInt64(studioID)
- mockImageReader.On("Update", models.ImagePartial{
+ mockImageReader.On("Update", testCtx, models.ImagePartial{
ID: imageID,
StudioID: &expectedStudioID,
}).Return(nil, nil).Once()
}
- err := StudioImages(&studio, nil, aliases, mockImageReader, nil)
+ err := StudioImages(testCtx, &studio, nil, aliases, mockImageReader, nil)
assert := assert.New(t)
@@ -297,7 +297,7 @@ func testStudioGalleries(t *testing.T, tc testStudioCase) {
}
// if alias provided, then don't find by name
- onNameQuery := mockGalleryReader.On("Query", expectedGalleryFilter, expectedFindFilter)
+ onNameQuery := mockGalleryReader.On("Query", testCtx, expectedGalleryFilter, expectedFindFilter)
if aliasName == "" {
onNameQuery.Return(galleries, len(galleries), nil).Once()
} else {
@@ -311,20 +311,20 @@ func testStudioGalleries(t *testing.T, tc testStudioCase) {
},
}
- mockGalleryReader.On("Query", expectedAliasFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once()
+ mockGalleryReader.On("Query", testCtx, expectedAliasFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once()
}
for i := range matchingPaths {
galleryID := i + 1
- mockGalleryReader.On("Find", galleryID).Return(&models.Gallery{}, nil).Once()
+ mockGalleryReader.On("Find", testCtx, galleryID).Return(&models.Gallery{}, nil).Once()
expectedStudioID := models.NullInt64(studioID)
- mockGalleryReader.On("UpdatePartial", models.GalleryPartial{
+ mockGalleryReader.On("UpdatePartial", testCtx, models.GalleryPartial{
ID: galleryID,
StudioID: &expectedStudioID,
}).Return(nil, nil).Once()
}
- err := StudioGalleries(&studio, nil, aliases, mockGalleryReader, nil)
+ err := StudioGalleries(testCtx, &studio, nil, aliases, mockGalleryReader, nil)
assert := assert.New(t)
diff --git a/internal/autotag/tag.go b/internal/autotag/tag.go
index f0d080871..4c66573b3 100644
--- a/internal/autotag/tag.go
+++ b/internal/autotag/tag.go
@@ -1,6 +1,8 @@
package autotag
import (
+ "context"
+
"github.com/stashapp/stash/pkg/gallery"
"github.com/stashapp/stash/pkg/image"
"github.com/stashapp/stash/pkg/match"
@@ -8,6 +10,21 @@ import (
"github.com/stashapp/stash/pkg/scene"
)
+type SceneQueryTagUpdater interface {
+ scene.Queryer
+ scene.TagUpdater
+}
+
+type ImageQueryTagUpdater interface {
+ image.Queryer
+ image.TagUpdater
+}
+
+type GalleryQueryTagUpdater interface {
+ gallery.Queryer
+ gallery.TagUpdater
+}
+
func getTagTaggers(p *models.Tag, aliases []string, cache *match.Cache) []tagger {
ret := []tagger{{
ID: p.ID,
@@ -29,12 +46,12 @@ func getTagTaggers(p *models.Tag, aliases []string, cache *match.Cache) []tagger
}
// TagScenes searches for scenes whose path matches the provided tag name and tags the scene with the tag.
-func TagScenes(p *models.Tag, paths []string, aliases []string, rw models.SceneReaderWriter, cache *match.Cache) error {
+func TagScenes(ctx context.Context, p *models.Tag, paths []string, aliases []string, rw SceneQueryTagUpdater, cache *match.Cache) error {
t := getTagTaggers(p, aliases, cache)
for _, tt := range t {
- if err := tt.tagScenes(paths, rw, func(subjectID, otherID int) (bool, error) {
- return scene.AddTag(rw, otherID, subjectID)
+ if err := tt.tagScenes(ctx, paths, rw, func(subjectID, otherID int) (bool, error) {
+ return scene.AddTag(ctx, rw, otherID, subjectID)
}); err != nil {
return err
}
@@ -43,12 +60,12 @@ func TagScenes(p *models.Tag, paths []string, aliases []string, rw models.SceneR
}
// TagImages searches for images whose path matches the provided tag name and tags the image with the tag.
-func TagImages(p *models.Tag, paths []string, aliases []string, rw models.ImageReaderWriter, cache *match.Cache) error {
+func TagImages(ctx context.Context, p *models.Tag, paths []string, aliases []string, rw ImageQueryTagUpdater, cache *match.Cache) error {
t := getTagTaggers(p, aliases, cache)
for _, tt := range t {
- if err := tt.tagImages(paths, rw, func(subjectID, otherID int) (bool, error) {
- return image.AddTag(rw, otherID, subjectID)
+ if err := tt.tagImages(ctx, paths, rw, func(subjectID, otherID int) (bool, error) {
+ return image.AddTag(ctx, rw, otherID, subjectID)
}); err != nil {
return err
}
@@ -57,12 +74,12 @@ func TagImages(p *models.Tag, paths []string, aliases []string, rw models.ImageR
}
// TagGalleries searches for galleries whose path matches the provided tag name and tags the gallery with the tag.
-func TagGalleries(p *models.Tag, paths []string, aliases []string, rw models.GalleryReaderWriter, cache *match.Cache) error {
+func TagGalleries(ctx context.Context, p *models.Tag, paths []string, aliases []string, rw GalleryQueryTagUpdater, cache *match.Cache) error {
t := getTagTaggers(p, aliases, cache)
for _, tt := range t {
- if err := tt.tagGalleries(paths, rw, func(subjectID, otherID int) (bool, error) {
- return gallery.AddTag(rw, otherID, subjectID)
+ if err := tt.tagGalleries(ctx, paths, rw, func(subjectID, otherID int) (bool, error) {
+ return gallery.AddTag(ctx, rw, otherID, subjectID)
}); err != nil {
return err
}
diff --git a/internal/autotag/tag_test.go b/internal/autotag/tag_test.go
index a1eed1eab..c49f580e3 100644
--- a/internal/autotag/tag_test.go
+++ b/internal/autotag/tag_test.go
@@ -113,7 +113,7 @@ func testTagScenes(t *testing.T, tc testTagCase) {
}
// if alias provided, then don't find by name
- onNameQuery := mockSceneReader.On("Query", scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false))
+ onNameQuery := mockSceneReader.On("Query", testCtx, scene.QueryOptions(expectedSceneFilter, expectedFindFilter, false))
if aliasName == "" {
onNameQuery.Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once()
} else {
@@ -127,17 +127,17 @@ func testTagScenes(t *testing.T, tc testTagCase) {
},
}
- mockSceneReader.On("Query", scene.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
+ mockSceneReader.On("Query", testCtx, scene.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
Return(mocks.SceneQueryResult(scenes, len(scenes)), nil).Once()
}
for i := range matchingPaths {
sceneID := i + 1
- mockSceneReader.On("GetTagIDs", sceneID).Return(nil, nil).Once()
- mockSceneReader.On("UpdateTags", sceneID, []int{tagID}).Return(nil).Once()
+ mockSceneReader.On("GetTagIDs", testCtx, sceneID).Return(nil, nil).Once()
+ mockSceneReader.On("UpdateTags", testCtx, sceneID, []int{tagID}).Return(nil).Once()
}
- err := TagScenes(&tag, nil, aliases, mockSceneReader, nil)
+ err := TagScenes(testCtx, &tag, nil, aliases, mockSceneReader, nil)
assert := assert.New(t)
@@ -201,7 +201,7 @@ func testTagImages(t *testing.T, tc testTagCase) {
}
// if alias provided, then don't find by name
- onNameQuery := mockImageReader.On("Query", image.QueryOptions(expectedImageFilter, expectedFindFilter, false))
+ onNameQuery := mockImageReader.On("Query", testCtx, image.QueryOptions(expectedImageFilter, expectedFindFilter, false))
if aliasName == "" {
onNameQuery.Return(mocks.ImageQueryResult(images, len(images)), nil).Once()
} else {
@@ -215,17 +215,17 @@ func testTagImages(t *testing.T, tc testTagCase) {
},
}
- mockImageReader.On("Query", image.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
+ mockImageReader.On("Query", testCtx, image.QueryOptions(expectedAliasFilter, expectedFindFilter, false)).
Return(mocks.ImageQueryResult(images, len(images)), nil).Once()
}
for i := range matchingPaths {
imageID := i + 1
- mockImageReader.On("GetTagIDs", imageID).Return(nil, nil).Once()
- mockImageReader.On("UpdateTags", imageID, []int{tagID}).Return(nil).Once()
+ mockImageReader.On("GetTagIDs", testCtx, imageID).Return(nil, nil).Once()
+ mockImageReader.On("UpdateTags", testCtx, imageID, []int{tagID}).Return(nil).Once()
}
- err := TagImages(&tag, nil, aliases, mockImageReader, nil)
+ err := TagImages(testCtx, &tag, nil, aliases, mockImageReader, nil)
assert := assert.New(t)
@@ -289,7 +289,7 @@ func testTagGalleries(t *testing.T, tc testTagCase) {
}
// if alias provided, then don't find by name
- onNameQuery := mockGalleryReader.On("Query", expectedGalleryFilter, expectedFindFilter)
+ onNameQuery := mockGalleryReader.On("Query", testCtx, expectedGalleryFilter, expectedFindFilter)
if aliasName == "" {
onNameQuery.Return(galleries, len(galleries), nil).Once()
} else {
@@ -303,16 +303,16 @@ func testTagGalleries(t *testing.T, tc testTagCase) {
},
}
- mockGalleryReader.On("Query", expectedAliasFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once()
+ mockGalleryReader.On("Query", testCtx, expectedAliasFilter, expectedFindFilter).Return(galleries, len(galleries), nil).Once()
}
for i := range matchingPaths {
galleryID := i + 1
- mockGalleryReader.On("GetTagIDs", galleryID).Return(nil, nil).Once()
- mockGalleryReader.On("UpdateTags", galleryID, []int{tagID}).Return(nil).Once()
+ mockGalleryReader.On("GetTagIDs", testCtx, galleryID).Return(nil, nil).Once()
+ mockGalleryReader.On("UpdateTags", testCtx, galleryID, []int{tagID}).Return(nil).Once()
}
- err := TagGalleries(&tag, nil, aliases, mockGalleryReader, nil)
+ err := TagGalleries(testCtx, &tag, nil, aliases, mockGalleryReader, nil)
assert := assert.New(t)
diff --git a/internal/autotag/tagger.go b/internal/autotag/tagger.go
index 4ea1fbc01..dae5cdc07 100644
--- a/internal/autotag/tagger.go
+++ b/internal/autotag/tagger.go
@@ -14,11 +14,14 @@
package autotag
import (
+ "context"
"fmt"
+ "github.com/stashapp/stash/pkg/gallery"
+ "github.com/stashapp/stash/pkg/image"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/match"
- "github.com/stashapp/stash/pkg/models"
+ "github.com/stashapp/stash/pkg/scene"
)
type tagger struct {
@@ -41,8 +44,8 @@ func (t *tagger) addLog(otherType, otherName string) {
logger.Infof("Added %s '%s' to %s '%s'", otherType, otherName, t.Type, t.Name)
}
-func (t *tagger) tagPerformers(performerReader models.PerformerReader, addFunc addLinkFunc) error {
- others, err := match.PathToPerformers(t.Path, performerReader, t.cache, t.trimExt)
+func (t *tagger) tagPerformers(ctx context.Context, performerReader match.PerformerAutoTagQueryer, addFunc addLinkFunc) error {
+ others, err := match.PathToPerformers(ctx, t.Path, performerReader, t.cache, t.trimExt)
if err != nil {
return err
}
@@ -62,8 +65,8 @@ func (t *tagger) tagPerformers(performerReader models.PerformerReader, addFunc a
return nil
}
-func (t *tagger) tagStudios(studioReader models.StudioReader, addFunc addLinkFunc) error {
- studio, err := match.PathToStudio(t.Path, studioReader, t.cache, t.trimExt)
+func (t *tagger) tagStudios(ctx context.Context, studioReader match.StudioAutoTagQueryer, addFunc addLinkFunc) error {
+ studio, err := match.PathToStudio(ctx, t.Path, studioReader, t.cache, t.trimExt)
if err != nil {
return err
}
@@ -83,8 +86,8 @@ func (t *tagger) tagStudios(studioReader models.StudioReader, addFunc addLinkFun
return nil
}
-func (t *tagger) tagTags(tagReader models.TagReader, addFunc addLinkFunc) error {
- others, err := match.PathToTags(t.Path, tagReader, t.cache, t.trimExt)
+func (t *tagger) tagTags(ctx context.Context, tagReader match.TagAutoTagQueryer, addFunc addLinkFunc) error {
+ others, err := match.PathToTags(ctx, t.Path, tagReader, t.cache, t.trimExt)
if err != nil {
return err
}
@@ -104,8 +107,8 @@ func (t *tagger) tagTags(tagReader models.TagReader, addFunc addLinkFunc) error
return nil
}
-func (t *tagger) tagScenes(paths []string, sceneReader models.SceneReader, addFunc addLinkFunc) error {
- others, err := match.PathToScenes(t.Name, paths, sceneReader)
+func (t *tagger) tagScenes(ctx context.Context, paths []string, sceneReader scene.Queryer, addFunc addLinkFunc) error {
+ others, err := match.PathToScenes(ctx, t.Name, paths, sceneReader)
if err != nil {
return err
}
@@ -125,8 +128,8 @@ func (t *tagger) tagScenes(paths []string, sceneReader models.SceneReader, addFu
return nil
}
-func (t *tagger) tagImages(paths []string, imageReader models.ImageReader, addFunc addLinkFunc) error {
- others, err := match.PathToImages(t.Name, paths, imageReader)
+func (t *tagger) tagImages(ctx context.Context, paths []string, imageReader image.Queryer, addFunc addLinkFunc) error {
+ others, err := match.PathToImages(ctx, t.Name, paths, imageReader)
if err != nil {
return err
}
@@ -146,8 +149,8 @@ func (t *tagger) tagImages(paths []string, imageReader models.ImageReader, addFu
return nil
}
-func (t *tagger) tagGalleries(paths []string, galleryReader models.GalleryReader, addFunc addLinkFunc) error {
- others, err := match.PathToGalleries(t.Name, paths, galleryReader)
+func (t *tagger) tagGalleries(ctx context.Context, paths []string, galleryReader gallery.Queryer, addFunc addLinkFunc) error {
+ others, err := match.PathToGalleries(ctx, t.Name, paths, galleryReader)
if err != nil {
return err
}
diff --git a/internal/dlna/cds.go b/internal/dlna/cds.go
index 4544b8759..6faa312b8 100644
--- a/internal/dlna/cds.go
+++ b/internal/dlna/cds.go
@@ -41,6 +41,7 @@ import (
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene"
"github.com/stashapp/stash/pkg/sliceutil/stringslice"
+ "github.com/stashapp/stash/pkg/txn"
)
var pageSize = 100
@@ -56,7 +57,6 @@ type browse struct {
type contentDirectoryService struct {
*Server
upnp.Eventing
- txnManager models.TransactionManager
}
func formatDurationSexagesimal(d time.Duration) string {
@@ -352,8 +352,8 @@ func (me *contentDirectoryService) handleBrowseMetadata(obj object, host string)
} else {
var scene *models.Scene
- if err := me.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
- scene, err = r.Scene().Find(sceneID)
+ if err := txn.WithTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
+ scene, err = me.repository.SceneFinder.Find(ctx, sceneID)
if err != nil {
return err
}
@@ -431,14 +431,14 @@ func getRootObjects() []interface{} {
func (me *contentDirectoryService) getVideos(sceneFilter *models.SceneFilterType, parentID string, host string) []interface{} {
var objs []interface{}
- if err := me.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
+ if err := txn.WithTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
sort := "title"
findFilter := &models.FindFilterType{
PerPage: &pageSize,
Sort: &sort,
}
- scenes, total, err := scene.QueryWithCount(r.Scene(), sceneFilter, findFilter)
+ scenes, total, err := scene.QueryWithCount(ctx, me.repository.SceneFinder, sceneFilter, findFilter)
if err != nil {
return err
}
@@ -449,7 +449,7 @@ func (me *contentDirectoryService) getVideos(sceneFilter *models.SceneFilterType
parentID: parentID,
}
- objs, err = pager.getPages(r, total)
+ objs, err = pager.getPages(ctx, me.repository.SceneFinder, total)
if err != nil {
return err
}
@@ -470,14 +470,14 @@ func (me *contentDirectoryService) getVideos(sceneFilter *models.SceneFilterType
func (me *contentDirectoryService) getPageVideos(sceneFilter *models.SceneFilterType, parentID string, page int, host string) []interface{} {
var objs []interface{}
- if err := me.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
+ if err := txn.WithTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
pager := scenePager{
sceneFilter: sceneFilter,
parentID: parentID,
}
var err error
- objs, err = pager.getPageVideos(r, page, host)
+ objs, err = pager.getPageVideos(ctx, me.repository.SceneFinder, page, host)
if err != nil {
return err
}
@@ -511,8 +511,8 @@ func (me *contentDirectoryService) getAllScenes(host string) []interface{} {
func (me *contentDirectoryService) getStudios() []interface{} {
var objs []interface{}
- if err := me.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
- studios, err := r.Studio().All()
+ if err := txn.WithTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
+ studios, err := me.repository.StudioFinder.All(ctx)
if err != nil {
return err
}
@@ -550,8 +550,8 @@ func (me *contentDirectoryService) getStudioScenes(paths []string, host string)
func (me *contentDirectoryService) getTags() []interface{} {
var objs []interface{}
- if err := me.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
- tags, err := r.Tag().All()
+ if err := txn.WithTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
+ tags, err := me.repository.TagFinder.All(ctx)
if err != nil {
return err
}
@@ -589,8 +589,8 @@ func (me *contentDirectoryService) getTagScenes(paths []string, host string) []i
func (me *contentDirectoryService) getPerformers() []interface{} {
var objs []interface{}
- if err := me.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
- performers, err := r.Performer().All()
+ if err := txn.WithTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
+ performers, err := me.repository.PerformerFinder.All(ctx)
if err != nil {
return err
}
@@ -628,8 +628,8 @@ func (me *contentDirectoryService) getPerformerScenes(paths []string, host strin
func (me *contentDirectoryService) getMovies() []interface{} {
var objs []interface{}
- if err := me.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
- movies, err := r.Movie().All()
+ if err := txn.WithTxn(context.TODO(), me.txnManager, func(ctx context.Context) error {
+ movies, err := me.repository.MovieFinder.All(ctx)
if err != nil {
return err
}
diff --git a/internal/dlna/cds_test.go b/internal/dlna/cds_test.go
index b52ca3b88..592f8f818 100644
--- a/internal/dlna/cds_test.go
+++ b/internal/dlna/cds_test.go
@@ -31,7 +31,6 @@ import (
"strings"
"testing"
- "github.com/stashapp/stash/pkg/models/mocks"
"github.com/stretchr/testify/assert"
)
@@ -59,8 +58,7 @@ func TestRootParentObjectID(t *testing.T) {
func testHandleBrowse(argsXML string) (map[string]string, error) {
cds := contentDirectoryService{
- Server: &Server{},
- txnManager: mocks.NewTransactionManager(),
+ Server: &Server{},
}
r := &http.Request{}
diff --git a/internal/dlna/dms.go b/internal/dlna/dms.go
index a1ea8ceac..d5e7cc84e 100644
--- a/internal/dlna/dms.go
+++ b/internal/dlna/dms.go
@@ -48,8 +48,31 @@ import (
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
+ "github.com/stashapp/stash/pkg/scene"
+ "github.com/stashapp/stash/pkg/txn"
)
+type SceneFinder interface {
+ scene.Queryer
+ scene.IDFinder
+}
+
+type StudioFinder interface {
+ All(ctx context.Context) ([]*models.Studio, error)
+}
+
+type TagFinder interface {
+ All(ctx context.Context) ([]*models.Tag, error)
+}
+
+type PerformerFinder interface {
+ All(ctx context.Context) ([]*models.Performer, error)
+}
+
+type MovieFinder interface {
+ All(ctx context.Context) ([]*models.Movie, error)
+}
+
const (
serverField = "Linux/3.4 DLNADOC/1.50 UPnP/1.0 DMS/1.0"
rootDeviceType = "urn:schemas-upnp-org:device:MediaServer:1"
@@ -249,7 +272,8 @@ type Server struct {
// Time interval between SSPD announces
NotifyInterval time.Duration
- txnManager models.TransactionManager
+ txnManager txn.Manager
+ repository Repository
sceneServer sceneServer
ipWhitelistManager *ipWhitelistManager
}
@@ -415,12 +439,12 @@ func (me *Server) serveIcon(w http.ResponseWriter, r *http.Request) {
}
var scene *models.Scene
- err := me.txnManager.WithReadTxn(r.Context(), func(r models.ReaderRepository) error {
+ err := txn.WithTxn(r.Context(), me.txnManager, func(ctx context.Context) error {
idInt, err := strconv.Atoi(sceneId)
if err != nil {
return nil
}
- scene, _ = r.Scene().Find(idInt)
+ scene, _ = me.repository.SceneFinder.Find(ctx, idInt)
return nil
})
if err != nil {
@@ -555,12 +579,12 @@ func (me *Server) initMux(mux *http.ServeMux) {
mux.HandleFunc(resPath, func(w http.ResponseWriter, r *http.Request) {
sceneId := r.URL.Query().Get("scene")
var scene *models.Scene
- err := me.txnManager.WithReadTxn(r.Context(), func(r models.ReaderRepository) error {
+ err := txn.WithTxn(r.Context(), me.txnManager, func(ctx context.Context) error {
sceneIdInt, err := strconv.Atoi(sceneId)
if err != nil {
return nil
}
- scene, _ = r.Scene().Find(sceneIdInt)
+ scene, _ = me.repository.SceneFinder.Find(ctx, sceneIdInt)
return nil
})
if err != nil {
@@ -595,8 +619,7 @@ func (me *Server) initMux(mux *http.ServeMux) {
func (me *Server) initServices() {
me.services = map[string]UPnPService{
"ContentDirectory": &contentDirectoryService{
- Server: me,
- txnManager: me.txnManager,
+ Server: me,
},
"ConnectionManager": &connectionManagerService{
Server: me,
diff --git a/internal/dlna/paging.go b/internal/dlna/paging.go
index 6f2afda8e..e5f65f96a 100644
--- a/internal/dlna/paging.go
+++ b/internal/dlna/paging.go
@@ -1,6 +1,7 @@
package dlna
import (
+ "context"
"fmt"
"math"
"strconv"
@@ -18,7 +19,7 @@ func (p *scenePager) getPageID(page int) string {
return p.parentID + "/page/" + strconv.Itoa(page)
}
-func (p *scenePager) getPages(r models.ReaderRepository, total int) ([]interface{}, error) {
+func (p *scenePager) getPages(ctx context.Context, r scene.Queryer, total int) ([]interface{}, error) {
var objs []interface{}
// get the first scene of each page to set an appropriate title
@@ -37,7 +38,7 @@ func (p *scenePager) getPages(r models.ReaderRepository, total int) ([]interface
if pages <= 10 || (page-1)%(pages/10) == 0 {
thisPage := ((page - 1) * pageSize) + 1
findFilter.Page = &thisPage
- scenes, err := scene.Query(r.Scene(), p.sceneFilter, findFilter)
+ scenes, err := scene.Query(ctx, r, p.sceneFilter, findFilter)
if err != nil {
return nil, err
}
@@ -58,7 +59,7 @@ func (p *scenePager) getPages(r models.ReaderRepository, total int) ([]interface
return objs, nil
}
-func (p *scenePager) getPageVideos(r models.ReaderRepository, page int, host string) ([]interface{}, error) {
+func (p *scenePager) getPageVideos(ctx context.Context, r SceneFinder, page int, host string) ([]interface{}, error) {
var objs []interface{}
sort := "title"
@@ -68,7 +69,7 @@ func (p *scenePager) getPageVideos(r models.ReaderRepository, page int, host str
Sort: &sort,
}
- scenes, err := scene.Query(r.Scene(), p.sceneFilter, findFilter)
+ scenes, err := scene.Query(ctx, r, p.sceneFilter, findFilter)
if err != nil {
return nil, err
}
diff --git a/internal/dlna/service.go b/internal/dlna/service.go
index 961e3f230..261a2ab62 100644
--- a/internal/dlna/service.go
+++ b/internal/dlna/service.go
@@ -10,8 +10,17 @@ import (
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
+ "github.com/stashapp/stash/pkg/txn"
)
+type Repository struct {
+ SceneFinder SceneFinder
+ StudioFinder StudioFinder
+ TagFinder TagFinder
+ PerformerFinder PerformerFinder
+ MovieFinder MovieFinder
+}
+
type Status struct {
Running bool `json:"running"`
// If not currently running, time until it will be started. If running, time until it will be stopped
@@ -48,7 +57,8 @@ type Config interface {
}
type Service struct {
- txnManager models.TransactionManager
+ txnManager txn.Manager
+ repository Repository
config Config
sceneServer sceneServer
ipWhitelistMgr *ipWhitelistManager
@@ -121,6 +131,7 @@ func (s *Service) init() error {
s.server = &Server{
txnManager: s.txnManager,
sceneServer: s.sceneServer,
+ repository: s.repository,
ipWhitelistManager: s.ipWhitelistMgr,
Interfaces: interfaces,
HTTPConn: func() net.Listener {
@@ -181,9 +192,10 @@ func (s *Service) init() error {
// }
// NewService initialises and returns a new DLNA service.
-func NewService(txnManager models.TransactionManager, cfg Config, sceneServer sceneServer) *Service {
+func NewService(txnManager txn.Manager, repo Repository, cfg Config, sceneServer sceneServer) *Service {
ret := &Service{
txnManager: txnManager,
+ repository: repo,
sceneServer: sceneServer,
config: cfg,
ipWhitelistMgr: &ipWhitelistManager{
diff --git a/internal/identify/identify.go b/internal/identify/identify.go
index 4d0d0afa2..0c34cce96 100644
--- a/internal/identify/identify.go
+++ b/internal/identify/identify.go
@@ -9,6 +9,7 @@ import (
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene"
"github.com/stashapp/stash/pkg/scraper"
+ "github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils"
)
@@ -28,13 +29,18 @@ type ScraperSource struct {
}
type SceneIdentifier struct {
+ SceneReaderUpdater SceneReaderUpdater
+ StudioCreator StudioCreator
+ PerformerCreator PerformerCreator
+ TagCreator TagCreator
+
DefaultOptions *MetadataOptions
Sources []ScraperSource
ScreenshotSetter scene.ScreenshotSetter
SceneUpdatePostHookExecutor SceneUpdatePostHookExecutor
}
-func (t *SceneIdentifier) Identify(ctx context.Context, txnManager models.TransactionManager, scene *models.Scene) error {
+func (t *SceneIdentifier) Identify(ctx context.Context, txnManager txn.Manager, scene *models.Scene) error {
result, err := t.scrapeScene(ctx, scene)
if err != nil {
return err
@@ -80,7 +86,7 @@ func (t *SceneIdentifier) scrapeScene(ctx context.Context, scene *models.Scene)
return nil, nil
}
-func (t *SceneIdentifier) getSceneUpdater(ctx context.Context, s *models.Scene, result *scrapeResult, repo models.Repository) (*scene.UpdateSet, error) {
+func (t *SceneIdentifier) getSceneUpdater(ctx context.Context, s *models.Scene, result *scrapeResult) (*scene.UpdateSet, error) {
ret := &scene.UpdateSet{
ID: s.ID,
}
@@ -106,15 +112,18 @@ func (t *SceneIdentifier) getSceneUpdater(ctx context.Context, s *models.Scene,
scraped := result.result
rel := sceneRelationships{
- repo: repo,
- scene: s,
- result: result,
- fieldOptions: fieldOptions,
+ sceneReader: t.SceneReaderUpdater,
+ studioCreator: t.StudioCreator,
+ performerCreator: t.PerformerCreator,
+ tagCreator: t.TagCreator,
+ scene: s,
+ result: result,
+ fieldOptions: fieldOptions,
}
ret.Partial = getScenePartial(s, scraped, fieldOptions, setOrganized)
- studioID, err := rel.studio()
+ studioID, err := rel.studio(ctx)
if err != nil {
return nil, fmt.Errorf("error getting studio: %w", err)
}
@@ -134,17 +143,17 @@ func (t *SceneIdentifier) getSceneUpdater(ctx context.Context, s *models.Scene,
}
}
- ret.PerformerIDs, err = rel.performers(ignoreMale)
+ ret.PerformerIDs, err = rel.performers(ctx, ignoreMale)
if err != nil {
return nil, err
}
- ret.TagIDs, err = rel.tags()
+ ret.TagIDs, err = rel.tags(ctx)
if err != nil {
return nil, err
}
- ret.StashIDs, err = rel.stashIDs()
+ ret.StashIDs, err = rel.stashIDs(ctx)
if err != nil {
return nil, err
}
@@ -167,11 +176,11 @@ func (t *SceneIdentifier) getSceneUpdater(ctx context.Context, s *models.Scene,
return ret, nil
}
-func (t *SceneIdentifier) modifyScene(ctx context.Context, txnManager models.TransactionManager, s *models.Scene, result *scrapeResult) error {
+func (t *SceneIdentifier) modifyScene(ctx context.Context, txnManager txn.Manager, s *models.Scene, result *scrapeResult) error {
var updater *scene.UpdateSet
- if err := txnManager.WithTxn(ctx, func(repo models.Repository) error {
+ if err := txn.WithTxn(ctx, txnManager, func(ctx context.Context) error {
var err error
- updater, err = t.getSceneUpdater(ctx, s, result, repo)
+ updater, err = t.getSceneUpdater(ctx, s, result)
if err != nil {
return err
}
@@ -182,7 +191,7 @@ func (t *SceneIdentifier) modifyScene(ctx context.Context, txnManager models.Tra
return nil
}
- _, err = updater.Update(repo.Scene(), t.ScreenshotSetter)
+ _, err = updater.Update(ctx, t.SceneReaderUpdater, t.ScreenshotSetter)
if err != nil {
return fmt.Errorf("error updating scene: %w", err)
}
diff --git a/internal/identify/identify_test.go b/internal/identify/identify_test.go
index c5588c78a..88be638df 100644
--- a/internal/identify/identify_test.go
+++ b/internal/identify/identify_test.go
@@ -13,6 +13,8 @@ import (
"github.com/stretchr/testify/mock"
)
+var testCtx = context.Background()
+
type mockSceneScraper struct {
errIDs []int
results map[int]*scraper.ScrapedScene
@@ -70,11 +72,12 @@ func TestSceneIdentifier_Identify(t *testing.T) {
},
}
- repo := mocks.NewTransactionManager()
- repo.Scene().(*mocks.SceneReaderWriter).On("Update", mock.MatchedBy(func(partial models.ScenePartial) bool {
+ mockSceneReaderWriter := &mocks.SceneReaderWriter{}
+
+ mockSceneReaderWriter.On("Update", testCtx, mock.MatchedBy(func(partial models.ScenePartial) bool {
return partial.ID != errUpdateID
})).Return(nil, nil)
- repo.Scene().(*mocks.SceneReaderWriter).On("Update", mock.MatchedBy(func(partial models.ScenePartial) bool {
+ mockSceneReaderWriter.On("Update", testCtx, mock.MatchedBy(func(partial models.ScenePartial) bool {
return partial.ID == errUpdateID
})).Return(nil, errors.New("update error"))
@@ -116,6 +119,7 @@ func TestSceneIdentifier_Identify(t *testing.T) {
}
identifier := SceneIdentifier{
+ SceneReaderUpdater: mockSceneReaderWriter,
DefaultOptions: defaultOptions,
Sources: sources,
SceneUpdatePostHookExecutor: mockHookExecutor{},
@@ -126,7 +130,7 @@ func TestSceneIdentifier_Identify(t *testing.T) {
scene := &models.Scene{
ID: tt.sceneID,
}
- if err := identifier.Identify(context.TODO(), repo, scene); (err != nil) != tt.wantErr {
+ if err := identifier.Identify(testCtx, &mocks.TxnManager{}, scene); (err != nil) != tt.wantErr {
t.Errorf("SceneIdentifier.Identify() error = %v, wantErr %v", err, tt.wantErr)
}
})
@@ -134,7 +138,9 @@ func TestSceneIdentifier_Identify(t *testing.T) {
}
func TestSceneIdentifier_modifyScene(t *testing.T) {
- repo := mocks.NewTransactionManager()
+ repo := models.Repository{
+ TxnManager: &mocks.TxnManager{},
+ }
tr := &SceneIdentifier{}
type args struct {
@@ -159,7 +165,7 @@ func TestSceneIdentifier_modifyScene(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- if err := tr.modifyScene(context.TODO(), repo, tt.args.scene, tt.args.result); (err != nil) != tt.wantErr {
+ if err := tr.modifyScene(testCtx, repo, tt.args.scene, tt.args.result); (err != nil) != tt.wantErr {
t.Errorf("SceneIdentifier.modifyScene() error = %v, wantErr %v", err, tt.wantErr)
}
})
diff --git a/internal/identify/performer.go b/internal/identify/performer.go
index 495c3eb8e..435524cc4 100644
--- a/internal/identify/performer.go
+++ b/internal/identify/performer.go
@@ -1,6 +1,7 @@
package identify
import (
+ "context"
"database/sql"
"fmt"
"strconv"
@@ -10,7 +11,12 @@ import (
"github.com/stashapp/stash/pkg/models"
)
-func getPerformerID(endpoint string, r models.Repository, p *models.ScrapedPerformer, createMissing bool) (*int, error) {
+type PerformerCreator interface {
+ Create(ctx context.Context, newPerformer models.Performer) (*models.Performer, error)
+ UpdateStashIDs(ctx context.Context, performerID int, stashIDs []models.StashID) error
+}
+
+func getPerformerID(ctx context.Context, endpoint string, w PerformerCreator, p *models.ScrapedPerformer, createMissing bool) (*int, error) {
if p.StoredID != nil {
// existing performer, just add it
performerID, err := strconv.Atoi(*p.StoredID)
@@ -20,20 +26,20 @@ func getPerformerID(endpoint string, r models.Repository, p *models.ScrapedPerfo
return &performerID, nil
} else if createMissing && p.Name != nil { // name is mandatory
- return createMissingPerformer(endpoint, r, p)
+ return createMissingPerformer(ctx, endpoint, w, p)
}
return nil, nil
}
-func createMissingPerformer(endpoint string, r models.Repository, p *models.ScrapedPerformer) (*int, error) {
- created, err := r.Performer().Create(scrapedToPerformerInput(p))
+func createMissingPerformer(ctx context.Context, endpoint string, w PerformerCreator, p *models.ScrapedPerformer) (*int, error) {
+ created, err := w.Create(ctx, scrapedToPerformerInput(p))
if err != nil {
return nil, fmt.Errorf("error creating performer: %w", err)
}
if endpoint != "" && p.RemoteSiteID != nil {
- if err := r.Performer().UpdateStashIDs(created.ID, []models.StashID{
+ if err := w.UpdateStashIDs(ctx, created.ID, []models.StashID{
{
Endpoint: endpoint,
StashID: *p.RemoteSiteID,
diff --git a/internal/identify/performer_test.go b/internal/identify/performer_test.go
index ebe8e49fe..eeed8a1e7 100644
--- a/internal/identify/performer_test.go
+++ b/internal/identify/performer_test.go
@@ -23,8 +23,8 @@ func Test_getPerformerID(t *testing.T) {
validStoredID := 1
name := "name"
- repo := mocks.NewTransactionManager()
- repo.PerformerMock().On("Create", mock.Anything).Return(&models.Performer{
+ mockPerformerReaderWriter := mocks.PerformerReaderWriter{}
+ mockPerformerReaderWriter.On("Create", testCtx, mock.Anything).Return(&models.Performer{
ID: validStoredID,
}, nil)
@@ -110,7 +110,7 @@ func Test_getPerformerID(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- got, err := getPerformerID(tt.args.endpoint, repo, tt.args.p, tt.args.createMissing)
+ got, err := getPerformerID(testCtx, tt.args.endpoint, &mockPerformerReaderWriter, tt.args.p, tt.args.createMissing)
if (err != nil) != tt.wantErr {
t.Errorf("getPerformerID() error = %v, wantErr %v", err, tt.wantErr)
return
@@ -131,23 +131,23 @@ func Test_createMissingPerformer(t *testing.T) {
invalidName := "invalidName"
performerID := 1
- repo := mocks.NewTransactionManager()
- repo.PerformerMock().On("Create", mock.MatchedBy(func(p models.Performer) bool {
+ mockPerformerReaderWriter := mocks.PerformerReaderWriter{}
+ mockPerformerReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p models.Performer) bool {
return p.Name.String == validName
})).Return(&models.Performer{
ID: performerID,
}, nil)
- repo.PerformerMock().On("Create", mock.MatchedBy(func(p models.Performer) bool {
+ mockPerformerReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p models.Performer) bool {
return p.Name.String == invalidName
})).Return(nil, errors.New("error creating performer"))
- repo.PerformerMock().On("UpdateStashIDs", performerID, []models.StashID{
+ mockPerformerReaderWriter.On("UpdateStashIDs", testCtx, performerID, []models.StashID{
{
Endpoint: invalidEndpoint,
StashID: remoteSiteID,
},
}).Return(errors.New("error updating stash ids"))
- repo.PerformerMock().On("UpdateStashIDs", performerID, []models.StashID{
+ mockPerformerReaderWriter.On("UpdateStashIDs", testCtx, performerID, []models.StashID{
{
Endpoint: validEndpoint,
StashID: remoteSiteID,
@@ -213,7 +213,7 @@ func Test_createMissingPerformer(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- got, err := createMissingPerformer(tt.args.endpoint, repo, tt.args.p)
+ got, err := createMissingPerformer(testCtx, tt.args.endpoint, &mockPerformerReaderWriter, tt.args.p)
if (err != nil) != tt.wantErr {
t.Errorf("createMissingPerformer() error = %v, wantErr %v", err, tt.wantErr)
return
diff --git a/internal/identify/scene.go b/internal/identify/scene.go
index 3e6fd4a38..4e7f4d3cc 100644
--- a/internal/identify/scene.go
+++ b/internal/identify/scene.go
@@ -9,19 +9,35 @@ import (
"time"
"github.com/stashapp/stash/pkg/models"
+ "github.com/stashapp/stash/pkg/scene"
"github.com/stashapp/stash/pkg/sliceutil"
"github.com/stashapp/stash/pkg/sliceutil/intslice"
"github.com/stashapp/stash/pkg/utils"
)
-type sceneRelationships struct {
- repo models.Repository
- scene *models.Scene
- result *scrapeResult
- fieldOptions map[string]*FieldOptions
+type SceneReaderUpdater interface {
+ GetPerformerIDs(ctx context.Context, sceneID int) ([]int, error)
+ GetTagIDs(ctx context.Context, sceneID int) ([]int, error)
+ GetStashIDs(ctx context.Context, sceneID int) ([]*models.StashID, error)
+ GetCover(ctx context.Context, sceneID int) ([]byte, error)
+ scene.Updater
}
-func (g sceneRelationships) studio() (*int64, error) {
+type TagCreator interface {
+ Create(ctx context.Context, newTag models.Tag) (*models.Tag, error)
+}
+
+type sceneRelationships struct {
+ sceneReader SceneReaderUpdater
+ studioCreator StudioCreator
+ performerCreator PerformerCreator
+ tagCreator TagCreator
+ scene *models.Scene
+ result *scrapeResult
+ fieldOptions map[string]*FieldOptions
+}
+
+func (g sceneRelationships) studio(ctx context.Context) (*int64, error) {
existingID := g.scene.StudioID
fieldStrategy := g.fieldOptions["studio"]
createMissing := fieldStrategy != nil && utils.IsTrue(fieldStrategy.CreateMissing)
@@ -45,13 +61,13 @@ func (g sceneRelationships) studio() (*int64, error) {
return &studioID, nil
}
} else if createMissing {
- return createMissingStudio(endpoint, g.repo, scraped)
+ return createMissingStudio(ctx, endpoint, g.studioCreator, scraped)
}
return nil, nil
}
-func (g sceneRelationships) performers(ignoreMale bool) ([]int, error) {
+func (g sceneRelationships) performers(ctx context.Context, ignoreMale bool) ([]int, error) {
fieldStrategy := g.fieldOptions["performers"]
scraped := g.result.result.Performers
@@ -66,11 +82,10 @@ func (g sceneRelationships) performers(ignoreMale bool) ([]int, error) {
strategy = fieldStrategy.Strategy
}
- repo := g.repo
endpoint := g.result.source.RemoteSite
var performerIDs []int
- originalPerformerIDs, err := repo.Scene().GetPerformerIDs(g.scene.ID)
+ originalPerformerIDs, err := g.sceneReader.GetPerformerIDs(ctx, g.scene.ID)
if err != nil {
return nil, fmt.Errorf("error getting scene performers: %w", err)
}
@@ -85,7 +100,7 @@ func (g sceneRelationships) performers(ignoreMale bool) ([]int, error) {
continue
}
- performerID, err := getPerformerID(endpoint, repo, p, createMissing)
+ performerID, err := getPerformerID(ctx, endpoint, g.performerCreator, p, createMissing)
if err != nil {
return nil, err
}
@@ -103,11 +118,10 @@ func (g sceneRelationships) performers(ignoreMale bool) ([]int, error) {
return performerIDs, nil
}
-func (g sceneRelationships) tags() ([]int, error) {
+func (g sceneRelationships) tags(ctx context.Context) ([]int, error) {
fieldStrategy := g.fieldOptions["tags"]
scraped := g.result.result.Tags
target := g.scene
- r := g.repo
// just check if ignored
if len(scraped) == 0 || !shouldSetSingleValueField(fieldStrategy, false) {
@@ -121,7 +135,7 @@ func (g sceneRelationships) tags() ([]int, error) {
}
var tagIDs []int
- originalTagIDs, err := r.Scene().GetTagIDs(target.ID)
+ originalTagIDs, err := g.sceneReader.GetTagIDs(ctx, target.ID)
if err != nil {
return nil, fmt.Errorf("error getting scene tags: %w", err)
}
@@ -142,7 +156,7 @@ func (g sceneRelationships) tags() ([]int, error) {
tagIDs = intslice.IntAppendUnique(tagIDs, int(tagID))
} else if createMissing {
now := time.Now()
- created, err := r.Tag().Create(models.Tag{
+ created, err := g.tagCreator.Create(ctx, models.Tag{
Name: t.Name,
CreatedAt: models.SQLiteTimestamp{Timestamp: now},
UpdatedAt: models.SQLiteTimestamp{Timestamp: now},
@@ -163,11 +177,10 @@ func (g sceneRelationships) tags() ([]int, error) {
return tagIDs, nil
}
-func (g sceneRelationships) stashIDs() ([]models.StashID, error) {
+func (g sceneRelationships) stashIDs(ctx context.Context) ([]models.StashID, error) {
remoteSiteID := g.result.result.RemoteSiteID
fieldStrategy := g.fieldOptions["stash_ids"]
target := g.scene
- r := g.repo
endpoint := g.result.source.RemoteSite
@@ -183,7 +196,7 @@ func (g sceneRelationships) stashIDs() ([]models.StashID, error) {
var originalStashIDs []models.StashID
var stashIDs []models.StashID
- stashIDPtrs, err := r.Scene().GetStashIDs(target.ID)
+ stashIDPtrs, err := g.sceneReader.GetStashIDs(ctx, target.ID)
if err != nil {
return nil, fmt.Errorf("error getting scene tag: %w", err)
}
@@ -227,14 +240,13 @@ func (g sceneRelationships) stashIDs() ([]models.StashID, error) {
func (g sceneRelationships) cover(ctx context.Context) ([]byte, error) {
scraped := g.result.result.Image
- r := g.repo
if scraped == nil {
return nil, nil
}
// always overwrite if present
- existingCover, err := r.Scene().GetCover(g.scene.ID)
+ existingCover, err := g.sceneReader.GetCover(ctx, g.scene.ID)
if err != nil {
return nil, fmt.Errorf("error getting scene cover: %w", err)
}
diff --git a/internal/identify/scene_test.go b/internal/identify/scene_test.go
index 2487e6808..bdef0c864 100644
--- a/internal/identify/scene_test.go
+++ b/internal/identify/scene_test.go
@@ -24,14 +24,14 @@ func Test_sceneRelationships_studio(t *testing.T) {
Strategy: FieldStrategyMerge,
}
- repo := mocks.NewTransactionManager()
- repo.StudioMock().On("Create", mock.Anything).Return(&models.Studio{
+ mockStudioReaderWriter := &mocks.StudioReaderWriter{}
+ mockStudioReaderWriter.On("Create", testCtx, mock.Anything).Return(&models.Studio{
ID: int(validStoredIDInt),
}, nil)
tr := sceneRelationships{
- repo: repo,
- fieldOptions: make(map[string]*FieldOptions),
+ studioCreator: mockStudioReaderWriter,
+ fieldOptions: make(map[string]*FieldOptions),
}
tests := []struct {
@@ -124,7 +124,7 @@ func Test_sceneRelationships_studio(t *testing.T) {
},
}
- got, err := tr.studio()
+ got, err := tr.studio(testCtx)
if (err != nil) != tt.wantErr {
t.Errorf("sceneRelationships.studio() error = %v, wantErr %v", err, tt.wantErr)
return
@@ -156,13 +156,13 @@ func Test_sceneRelationships_performers(t *testing.T) {
Strategy: FieldStrategyMerge,
}
- repo := mocks.NewTransactionManager()
- repo.SceneMock().On("GetPerformerIDs", sceneID).Return(nil, nil)
- repo.SceneMock().On("GetPerformerIDs", sceneWithPerformerID).Return([]int{existingPerformerID}, nil)
- repo.SceneMock().On("GetPerformerIDs", errSceneID).Return(nil, errors.New("error getting IDs"))
+ mockSceneReaderWriter := &mocks.SceneReaderWriter{}
+ mockSceneReaderWriter.On("GetPerformerIDs", testCtx, sceneID).Return(nil, nil)
+ mockSceneReaderWriter.On("GetPerformerIDs", testCtx, sceneWithPerformerID).Return([]int{existingPerformerID}, nil)
+ mockSceneReaderWriter.On("GetPerformerIDs", testCtx, errSceneID).Return(nil, errors.New("error getting IDs"))
tr := sceneRelationships{
- repo: repo,
+ sceneReader: mockSceneReaderWriter,
fieldOptions: make(map[string]*FieldOptions),
}
@@ -316,7 +316,7 @@ func Test_sceneRelationships_performers(t *testing.T) {
},
}
- got, err := tr.performers(tt.ignoreMale)
+ got, err := tr.performers(testCtx, tt.ignoreMale)
if (err != nil) != tt.wantErr {
t.Errorf("sceneRelationships.performers() error = %v, wantErr %v", err, tt.wantErr)
return
@@ -347,22 +347,24 @@ func Test_sceneRelationships_tags(t *testing.T) {
Strategy: FieldStrategyMerge,
}
- repo := mocks.NewTransactionManager()
- repo.SceneMock().On("GetTagIDs", sceneID).Return(nil, nil)
- repo.SceneMock().On("GetTagIDs", sceneWithTagID).Return([]int{existingID}, nil)
- repo.SceneMock().On("GetTagIDs", errSceneID).Return(nil, errors.New("error getting IDs"))
+ mockSceneReaderWriter := &mocks.SceneReaderWriter{}
+ mockTagReaderWriter := &mocks.TagReaderWriter{}
+ mockSceneReaderWriter.On("GetTagIDs", testCtx, sceneID).Return(nil, nil)
+ mockSceneReaderWriter.On("GetTagIDs", testCtx, sceneWithTagID).Return([]int{existingID}, nil)
+ mockSceneReaderWriter.On("GetTagIDs", testCtx, errSceneID).Return(nil, errors.New("error getting IDs"))
- repo.TagMock().On("Create", mock.MatchedBy(func(p models.Tag) bool {
+ mockTagReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p models.Tag) bool {
return p.Name == validName
})).Return(&models.Tag{
ID: validStoredIDInt,
}, nil)
- repo.TagMock().On("Create", mock.MatchedBy(func(p models.Tag) bool {
+ mockTagReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p models.Tag) bool {
return p.Name == invalidName
})).Return(nil, errors.New("error creating tag"))
tr := sceneRelationships{
- repo: repo,
+ sceneReader: mockSceneReaderWriter,
+ tagCreator: mockTagReaderWriter,
fieldOptions: make(map[string]*FieldOptions),
}
@@ -505,7 +507,7 @@ func Test_sceneRelationships_tags(t *testing.T) {
},
}
- got, err := tr.tags()
+ got, err := tr.tags(testCtx)
if (err != nil) != tt.wantErr {
t.Errorf("sceneRelationships.tags() error = %v, wantErr %v", err, tt.wantErr)
return
@@ -534,18 +536,18 @@ func Test_sceneRelationships_stashIDs(t *testing.T) {
Strategy: FieldStrategyMerge,
}
- repo := mocks.NewTransactionManager()
- repo.SceneMock().On("GetStashIDs", sceneID).Return(nil, nil)
- repo.SceneMock().On("GetStashIDs", sceneWithStashID).Return([]*models.StashID{
+ mockSceneReaderWriter := &mocks.SceneReaderWriter{}
+ mockSceneReaderWriter.On("GetStashIDs", testCtx, sceneID).Return(nil, nil)
+ mockSceneReaderWriter.On("GetStashIDs", testCtx, sceneWithStashID).Return([]*models.StashID{
{
StashID: remoteSiteID,
Endpoint: existingEndpoint,
},
}, nil)
- repo.SceneMock().On("GetStashIDs", errSceneID).Return(nil, errors.New("error getting IDs"))
+ mockSceneReaderWriter.On("GetStashIDs", testCtx, errSceneID).Return(nil, errors.New("error getting IDs"))
tr := sceneRelationships{
- repo: repo,
+ sceneReader: mockSceneReaderWriter,
fieldOptions: make(map[string]*FieldOptions),
}
@@ -680,7 +682,7 @@ func Test_sceneRelationships_stashIDs(t *testing.T) {
},
}
- got, err := tr.stashIDs()
+ got, err := tr.stashIDs(testCtx)
if (err != nil) != tt.wantErr {
t.Errorf("sceneRelationships.stashIDs() error = %v, wantErr %v", err, tt.wantErr)
return
@@ -707,12 +709,12 @@ func Test_sceneRelationships_cover(t *testing.T) {
newDataEncoded := base64Prefix + utils.GetBase64StringFromData(newData)
invalidData := newDataEncoded + "!!!"
- repo := mocks.NewTransactionManager()
- repo.SceneMock().On("GetCover", sceneID).Return(existingData, nil)
- repo.SceneMock().On("GetCover", errSceneID).Return(nil, errors.New("error getting cover"))
+ mockSceneReaderWriter := &mocks.SceneReaderWriter{}
+ mockSceneReaderWriter.On("GetCover", testCtx, sceneID).Return(existingData, nil)
+ mockSceneReaderWriter.On("GetCover", testCtx, errSceneID).Return(nil, errors.New("error getting cover"))
tr := sceneRelationships{
- repo: repo,
+ sceneReader: mockSceneReaderWriter,
fieldOptions: make(map[string]*FieldOptions),
}
diff --git a/internal/identify/studio.go b/internal/identify/studio.go
index 86cb6b737..923a0322a 100644
--- a/internal/identify/studio.go
+++ b/internal/identify/studio.go
@@ -1,6 +1,7 @@
package identify
import (
+ "context"
"database/sql"
"fmt"
"time"
@@ -9,14 +10,19 @@ import (
"github.com/stashapp/stash/pkg/models"
)
-func createMissingStudio(endpoint string, repo models.Repository, studio *models.ScrapedStudio) (*int64, error) {
- created, err := repo.Studio().Create(scrapedToStudioInput(studio))
+type StudioCreator interface {
+ Create(ctx context.Context, newStudio models.Studio) (*models.Studio, error)
+ UpdateStashIDs(ctx context.Context, studioID int, stashIDs []models.StashID) error
+}
+
+func createMissingStudio(ctx context.Context, endpoint string, w StudioCreator, studio *models.ScrapedStudio) (*int64, error) {
+ created, err := w.Create(ctx, scrapedToStudioInput(studio))
if err != nil {
return nil, fmt.Errorf("error creating studio: %w", err)
}
if endpoint != "" && studio.RemoteSiteID != nil {
- if err := repo.Studio().UpdateStashIDs(created.ID, []models.StashID{
+ if err := w.UpdateStashIDs(ctx, created.ID, []models.StashID{
{
Endpoint: endpoint,
StashID: *studio.RemoteSiteID,
diff --git a/internal/identify/studio_test.go b/internal/identify/studio_test.go
index 2ba0b840e..1900259ce 100644
--- a/internal/identify/studio_test.go
+++ b/internal/identify/studio_test.go
@@ -20,23 +20,24 @@ func Test_createMissingStudio(t *testing.T) {
createdID := 1
createdID64 := int64(createdID)
- repo := mocks.NewTransactionManager()
- repo.StudioMock().On("Create", mock.MatchedBy(func(p models.Studio) bool {
+ repo := mocks.NewTxnRepository()
+ mockStudioReaderWriter := repo.Studio.(*mocks.StudioReaderWriter)
+ mockStudioReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p models.Studio) bool {
return p.Name.String == validName
})).Return(&models.Studio{
ID: createdID,
}, nil)
- repo.StudioMock().On("Create", mock.MatchedBy(func(p models.Studio) bool {
+ mockStudioReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p models.Studio) bool {
return p.Name.String == invalidName
})).Return(nil, errors.New("error creating performer"))
- repo.StudioMock().On("UpdateStashIDs", createdID, []models.StashID{
+ mockStudioReaderWriter.On("UpdateStashIDs", testCtx, createdID, []models.StashID{
{
Endpoint: invalidEndpoint,
StashID: remoteSiteID,
},
}).Return(errors.New("error updating stash ids"))
- repo.StudioMock().On("UpdateStashIDs", createdID, []models.StashID{
+ mockStudioReaderWriter.On("UpdateStashIDs", testCtx, createdID, []models.StashID{
{
Endpoint: validEndpoint,
StashID: remoteSiteID,
@@ -102,7 +103,7 @@ func Test_createMissingStudio(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- got, err := createMissingStudio(tt.args.endpoint, repo, tt.args.studio)
+ got, err := createMissingStudio(testCtx, tt.args.endpoint, mockStudioReaderWriter, tt.args.studio)
if (err != nil) != tt.wantErr {
t.Errorf("createMissingStudio() error = %v, wantErr %v", err, tt.wantErr)
return
diff --git a/internal/manager/checksum.go b/internal/manager/checksum.go
index 469f2c47f..53f368913 100644
--- a/internal/manager/checksum.go
+++ b/internal/manager/checksum.go
@@ -7,16 +7,21 @@ import (
"github.com/stashapp/stash/internal/manager/config"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
+ "github.com/stashapp/stash/pkg/txn"
)
-func setInitialMD5Config(ctx context.Context, txnManager models.TransactionManager) {
+type SceneCounter interface {
+ Count(ctx context.Context) (int, error)
+}
+
+func setInitialMD5Config(ctx context.Context, txnManager txn.Manager, counter SceneCounter) {
// if there are no scene files in the database, then default the
// VideoFileNamingAlgorithm config setting to oshash and calculateMD5 to
// false, otherwise set them to true for backwards compatibility purposes
var count int
- if err := txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
+ if err := txn.WithTxn(ctx, txnManager, func(ctx context.Context) error {
var err error
- count, err = r.Scene().Count()
+ count, err = counter.Count(ctx)
return err
}); err != nil {
logger.Errorf("Error while counting scenes: %s", err.Error())
@@ -36,6 +41,11 @@ func setInitialMD5Config(ctx context.Context, txnManager models.TransactionManag
}
}
+type SceneMissingHashCounter interface {
+ CountMissingChecksum(ctx context.Context) (int, error)
+ CountMissingOSHash(ctx context.Context) (int, error)
+}
+
// ValidateVideoFileNamingAlgorithm validates changing the
// VideoFileNamingAlgorithm configuration flag.
//
@@ -44,30 +54,27 @@ func setInitialMD5Config(ctx context.Context, txnManager models.TransactionManag
//
// Likewise, if VideoFileNamingAlgorithm is set to oshash, then this function
// will ensure that all oshash values are set on all scenes.
-func ValidateVideoFileNamingAlgorithm(txnManager models.TransactionManager, newValue models.HashAlgorithm) error {
+func ValidateVideoFileNamingAlgorithm(ctx context.Context, qb SceneMissingHashCounter, newValue models.HashAlgorithm) error {
// if algorithm is being set to MD5, then all checksums must be present
- return txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
- qb := r.Scene()
- if newValue == models.HashAlgorithmMd5 {
- missingMD5, err := qb.CountMissingChecksum()
- if err != nil {
- return err
- }
-
- if missingMD5 > 0 {
- return errors.New("some checksums are missing on scenes. Run Scan with calculateMD5 set to true")
- }
- } else if newValue == models.HashAlgorithmOshash {
- missingOSHash, err := qb.CountMissingOSHash()
- if err != nil {
- return err
- }
-
- if missingOSHash > 0 {
- return errors.New("some oshash values are missing on scenes. Run Scan to populate")
- }
+ if newValue == models.HashAlgorithmMd5 {
+ missingMD5, err := qb.CountMissingChecksum(ctx)
+ if err != nil {
+ return err
}
- return nil
- })
+ if missingMD5 > 0 {
+ return errors.New("some checksums are missing on scenes. Run Scan with calculateMD5 set to true")
+ }
+ } else if newValue == models.HashAlgorithmOshash {
+ missingOSHash, err := qb.CountMissingOSHash(ctx)
+ if err != nil {
+ return err
+ }
+
+ if missingOSHash > 0 {
+ return errors.New("some oshash values are missing on scenes. Run Scan to populate")
+ }
+ }
+
+ return nil
}
diff --git a/internal/manager/filename_parser.go b/internal/manager/filename_parser.go
index 3bd856e69..ecff4ea7a 100644
--- a/internal/manager/filename_parser.go
+++ b/internal/manager/filename_parser.go
@@ -1,6 +1,7 @@
package manager
import (
+ "context"
"database/sql"
"errors"
"path/filepath"
@@ -470,7 +471,15 @@ func (p *SceneFilenameParser) initWhiteSpaceRegex() {
}
}
-func (p *SceneFilenameParser) Parse(repo models.ReaderRepository) ([]*SceneParserResult, int, error) {
+type SceneFilenameParserRepository struct {
+ Scene scene.Queryer
+ Performer PerformerNamesFinder
+ Studio studio.Queryer
+ Movie MovieNameFinder
+ Tag tag.Queryer
+}
+
+func (p *SceneFilenameParser) Parse(ctx context.Context, repo SceneFilenameParserRepository) ([]*SceneParserResult, int, error) {
// perform the query to find the scenes
mapper, err := newParseMapper(p.Pattern, p.ParserInput.IgnoreWords)
@@ -492,17 +501,17 @@ func (p *SceneFilenameParser) Parse(repo models.ReaderRepository) ([]*SceneParse
p.Filter.Q = nil
- scenes, total, err := scene.QueryWithCount(repo.Scene(), sceneFilter, p.Filter)
+ scenes, total, err := scene.QueryWithCount(ctx, repo.Scene, sceneFilter, p.Filter)
if err != nil {
return nil, 0, err
}
- ret := p.parseScenes(repo, scenes, mapper)
+ ret := p.parseScenes(ctx, repo, scenes, mapper)
return ret, total, nil
}
-func (p *SceneFilenameParser) parseScenes(repo models.ReaderRepository, scenes []*models.Scene, mapper *parseMapper) []*SceneParserResult {
+func (p *SceneFilenameParser) parseScenes(ctx context.Context, repo SceneFilenameParserRepository, scenes []*models.Scene, mapper *parseMapper) []*SceneParserResult {
var ret []*SceneParserResult
for _, scene := range scenes {
sceneHolder := mapper.parse(scene)
@@ -511,7 +520,7 @@ func (p *SceneFilenameParser) parseScenes(repo models.ReaderRepository, scenes [
r := &SceneParserResult{
Scene: scene,
}
- p.setParserResult(repo, *sceneHolder, r)
+ p.setParserResult(ctx, repo, *sceneHolder, r)
ret = append(ret, r)
}
@@ -530,7 +539,11 @@ func (p SceneFilenameParser) replaceWhitespaceCharacters(value string) string {
return value
}
-func (p *SceneFilenameParser) queryPerformer(qb models.PerformerReader, performerName string) *models.Performer {
+type PerformerNamesFinder interface {
+ FindByNames(ctx context.Context, names []string, nocase bool) ([]*models.Performer, error)
+}
+
+func (p *SceneFilenameParser) queryPerformer(ctx context.Context, qb PerformerNamesFinder, performerName string) *models.Performer {
// massage the performer name
performerName = delimiterRE.ReplaceAllString(performerName, " ")
@@ -540,7 +553,7 @@ func (p *SceneFilenameParser) queryPerformer(qb models.PerformerReader, performe
}
// perform an exact match and grab the first
- performers, _ := qb.FindByNames([]string{performerName}, true)
+ performers, _ := qb.FindByNames(ctx, []string{performerName}, true)
var ret *models.Performer
if len(performers) > 0 {
@@ -553,7 +566,7 @@ func (p *SceneFilenameParser) queryPerformer(qb models.PerformerReader, performe
return ret
}
-func (p *SceneFilenameParser) queryStudio(qb models.StudioReader, studioName string) *models.Studio {
+func (p *SceneFilenameParser) queryStudio(ctx context.Context, qb studio.Queryer, studioName string) *models.Studio {
// massage the performer name
studioName = delimiterRE.ReplaceAllString(studioName, " ")
@@ -562,11 +575,11 @@ func (p *SceneFilenameParser) queryStudio(qb models.StudioReader, studioName str
return ret
}
- ret, _ := studio.ByName(qb, studioName)
+ ret, _ := studio.ByName(ctx, qb, studioName)
// try to match on alias
if ret == nil {
- ret, _ = studio.ByAlias(qb, studioName)
+ ret, _ = studio.ByAlias(ctx, qb, studioName)
}
// add result to cache
@@ -575,7 +588,11 @@ func (p *SceneFilenameParser) queryStudio(qb models.StudioReader, studioName str
return ret
}
-func (p *SceneFilenameParser) queryMovie(qb models.MovieReader, movieName string) *models.Movie {
+type MovieNameFinder interface {
+ FindByName(ctx context.Context, name string, nocase bool) (*models.Movie, error)
+}
+
+func (p *SceneFilenameParser) queryMovie(ctx context.Context, qb MovieNameFinder, movieName string) *models.Movie {
// massage the movie name
movieName = delimiterRE.ReplaceAllString(movieName, " ")
@@ -584,7 +601,7 @@ func (p *SceneFilenameParser) queryMovie(qb models.MovieReader, movieName string
return ret
}
- ret, _ := qb.FindByName(movieName, true)
+ ret, _ := qb.FindByName(ctx, movieName, true)
// add result to cache
p.movieCache[movieName] = ret
@@ -592,7 +609,7 @@ func (p *SceneFilenameParser) queryMovie(qb models.MovieReader, movieName string
return ret
}
-func (p *SceneFilenameParser) queryTag(qb models.TagReader, tagName string) *models.Tag {
+func (p *SceneFilenameParser) queryTag(ctx context.Context, qb tag.Queryer, tagName string) *models.Tag {
// massage the tag name
tagName = delimiterRE.ReplaceAllString(tagName, " ")
@@ -602,11 +619,11 @@ func (p *SceneFilenameParser) queryTag(qb models.TagReader, tagName string) *mod
}
// match tag name exactly
- ret, _ := tag.ByName(qb, tagName)
+ ret, _ := tag.ByName(ctx, qb, tagName)
// try to match on alias
if ret == nil {
- ret, _ = tag.ByAlias(qb, tagName)
+ ret, _ = tag.ByAlias(ctx, qb, tagName)
}
// add result to cache
@@ -615,12 +632,12 @@ func (p *SceneFilenameParser) queryTag(qb models.TagReader, tagName string) *mod
return ret
}
-func (p *SceneFilenameParser) setPerformers(qb models.PerformerReader, h sceneHolder, result *SceneParserResult) {
+func (p *SceneFilenameParser) setPerformers(ctx context.Context, qb PerformerNamesFinder, h sceneHolder, result *SceneParserResult) {
// query for each performer
performersSet := make(map[int]bool)
for _, performerName := range h.performers {
if performerName != "" {
- performer := p.queryPerformer(qb, performerName)
+ performer := p.queryPerformer(ctx, qb, performerName)
if performer != nil {
if _, found := performersSet[performer.ID]; !found {
result.PerformerIds = append(result.PerformerIds, strconv.Itoa(performer.ID))
@@ -631,12 +648,12 @@ func (p *SceneFilenameParser) setPerformers(qb models.PerformerReader, h sceneHo
}
}
-func (p *SceneFilenameParser) setTags(qb models.TagReader, h sceneHolder, result *SceneParserResult) {
+func (p *SceneFilenameParser) setTags(ctx context.Context, qb tag.Queryer, h sceneHolder, result *SceneParserResult) {
// query for each performer
tagsSet := make(map[int]bool)
for _, tagName := range h.tags {
if tagName != "" {
- tag := p.queryTag(qb, tagName)
+ tag := p.queryTag(ctx, qb, tagName)
if tag != nil {
if _, found := tagsSet[tag.ID]; !found {
result.TagIds = append(result.TagIds, strconv.Itoa(tag.ID))
@@ -647,10 +664,10 @@ func (p *SceneFilenameParser) setTags(qb models.TagReader, h sceneHolder, result
}
}
-func (p *SceneFilenameParser) setStudio(qb models.StudioReader, h sceneHolder, result *SceneParserResult) {
+func (p *SceneFilenameParser) setStudio(ctx context.Context, qb studio.Queryer, h sceneHolder, result *SceneParserResult) {
// query for each performer
if h.studio != "" {
- studio := p.queryStudio(qb, h.studio)
+ studio := p.queryStudio(ctx, qb, h.studio)
if studio != nil {
studioID := strconv.Itoa(studio.ID)
result.StudioID = &studioID
@@ -658,12 +675,12 @@ func (p *SceneFilenameParser) setStudio(qb models.StudioReader, h sceneHolder, r
}
}
-func (p *SceneFilenameParser) setMovies(qb models.MovieReader, h sceneHolder, result *SceneParserResult) {
+func (p *SceneFilenameParser) setMovies(ctx context.Context, qb MovieNameFinder, h sceneHolder, result *SceneParserResult) {
// query for each movie
moviesSet := make(map[int]bool)
for _, movieName := range h.movies {
if movieName != "" {
- movie := p.queryMovie(qb, movieName)
+ movie := p.queryMovie(ctx, qb, movieName)
if movie != nil {
if _, found := moviesSet[movie.ID]; !found {
result.Movies = append(result.Movies, &SceneMovieID{
@@ -676,7 +693,7 @@ func (p *SceneFilenameParser) setMovies(qb models.MovieReader, h sceneHolder, re
}
}
-func (p *SceneFilenameParser) setParserResult(repo models.ReaderRepository, h sceneHolder, result *SceneParserResult) {
+func (p *SceneFilenameParser) setParserResult(ctx context.Context, repo SceneFilenameParserRepository, h sceneHolder, result *SceneParserResult) {
if h.result.Title.Valid {
title := h.result.Title.String
title = p.replaceWhitespaceCharacters(title)
@@ -698,15 +715,15 @@ func (p *SceneFilenameParser) setParserResult(repo models.ReaderRepository, h sc
}
if len(h.performers) > 0 {
- p.setPerformers(repo.Performer(), h, result)
+ p.setPerformers(ctx, repo.Performer, h, result)
}
if len(h.tags) > 0 {
- p.setTags(repo.Tag(), h, result)
+ p.setTags(ctx, repo.Tag, h, result)
}
- p.setStudio(repo.Studio(), h, result)
+ p.setStudio(ctx, repo.Studio, h, result)
if len(h.movies) > 0 {
- p.setMovies(repo.Movie(), h, result)
+ p.setMovies(ctx, repo.Movie, h, result)
}
}
diff --git a/internal/manager/import.go b/internal/manager/import.go
index 8e3140577..0762096c2 100644
--- a/internal/manager/import.go
+++ b/internal/manager/import.go
@@ -1,6 +1,7 @@
package manager
import (
+ "context"
"fmt"
"io"
"strconv"
@@ -52,22 +53,22 @@ func (e ImportDuplicateEnum) MarshalGQL(w io.Writer) {
}
type importer interface {
- PreImport() error
- PostImport(id int) error
+ PreImport(ctx context.Context) error
+ PostImport(ctx context.Context, id int) error
Name() string
- FindExistingID() (*int, error)
- Create() (*int, error)
- Update(id int) error
+ FindExistingID(ctx context.Context) (*int, error)
+ Create(ctx context.Context) (*int, error)
+ Update(ctx context.Context, id int) error
}
-func performImport(i importer, duplicateBehaviour ImportDuplicateEnum) error {
- if err := i.PreImport(); err != nil {
+func performImport(ctx context.Context, i importer, duplicateBehaviour ImportDuplicateEnum) error {
+ if err := i.PreImport(ctx); err != nil {
return err
}
// try to find an existing object with the same name
name := i.Name()
- existing, err := i.FindExistingID()
+ existing, err := i.FindExistingID(ctx)
if err != nil {
return fmt.Errorf("error finding existing objects: %v", err)
}
@@ -84,12 +85,12 @@ func performImport(i importer, duplicateBehaviour ImportDuplicateEnum) error {
// must be overwriting
id = *existing
- if err := i.Update(id); err != nil {
+ if err := i.Update(ctx, id); err != nil {
return fmt.Errorf("error updating existing object: %v", err)
}
} else {
// creating
- createdID, err := i.Create()
+ createdID, err := i.Create(ctx)
if err != nil {
return fmt.Errorf("error creating object: %v", err)
}
@@ -97,7 +98,7 @@ func performImport(i importer, duplicateBehaviour ImportDuplicateEnum) error {
id = *createdID
}
- if err := i.PostImport(id); err != nil {
+ if err := i.PostImport(ctx, id); err != nil {
return err
}
diff --git a/internal/manager/manager.go b/internal/manager/manager.go
index 56221327b..4f292750e 100644
--- a/internal/manager/manager.go
+++ b/internal/manager/manager.go
@@ -17,7 +17,6 @@ import (
"github.com/stashapp/stash/internal/dlna"
"github.com/stashapp/stash/internal/log"
"github.com/stashapp/stash/internal/manager/config"
- "github.com/stashapp/stash/pkg/database"
"github.com/stashapp/stash/pkg/ffmpeg"
"github.com/stashapp/stash/pkg/fsutil"
"github.com/stashapp/stash/pkg/job"
@@ -115,7 +114,8 @@ type Manager struct {
DLNAService *dlna.Service
- TxnManager models.TransactionManager
+ Database *sqlite.Database
+ Repository models.Repository
scanSubs *subscriptionManager
}
@@ -150,6 +150,8 @@ func initialize() error {
l := initLog()
initProfiling(cfg.GetCPUProfilePath())
+ db := &sqlite.Database{}
+
instance = &Manager{
Config: cfg,
Logger: l,
@@ -157,7 +159,20 @@ func initialize() error {
DownloadStore: NewDownloadStore(),
PluginCache: plugin.NewCache(cfg),
- TxnManager: sqlite.NewTransactionManager(),
+ Database: db,
+ Repository: models.Repository{
+ TxnManager: db,
+ Gallery: sqlite.GalleryReaderWriter,
+ Image: sqlite.ImageReaderWriter,
+ Movie: sqlite.MovieReaderWriter,
+ Performer: sqlite.PerformerReaderWriter,
+ Scene: sqlite.SceneReaderWriter,
+ SceneMarker: sqlite.SceneMarkerReaderWriter,
+ ScrapedItem: sqlite.ScrapedItemReaderWriter,
+ Studio: sqlite.StudioReaderWriter,
+ Tag: sqlite.TagReaderWriter,
+ SavedFilter: sqlite.SavedFilterReaderWriter,
+ },
scanSubs: &subscriptionManager{},
}
@@ -165,9 +180,17 @@ func initialize() error {
instance.JobManager = initJobManager()
sceneServer := SceneServer{
- TXNManager: instance.TxnManager,
+ TxnManager: instance.Repository,
+ SceneCoverGetter: instance.Repository.Scene,
}
- instance.DLNAService = dlna.NewService(instance.TxnManager, instance.Config, &sceneServer)
+
+ instance.DLNAService = dlna.NewService(instance.Repository, dlna.Repository{
+ SceneFinder: instance.Repository.Scene,
+ StudioFinder: instance.Repository.Studio,
+ TagFinder: instance.Repository.Tag,
+ PerformerFinder: instance.Repository.Performer,
+ MovieFinder: instance.Repository.Movie,
+ }, instance.Config, &sceneServer)
if !cfg.IsNewSystem() {
logger.Infof("using config file: %s", cfg.GetConfigFile())
@@ -177,9 +200,14 @@ func initialize() error {
}
if err != nil {
- return fmt.Errorf("error initializing configuration: %w", err)
+ panic(fmt.Sprintf("error initializing configuration: %s", err.Error()))
} else if err := instance.PostInit(ctx); err != nil {
- return err
+ var migrationNeededErr *sqlite.MigrationNeededError
+ if errors.As(err, &migrationNeededErr) {
+ logger.Warn(err.Error())
+ } else {
+ panic(err)
+ }
}
initSecurity(cfg)
@@ -352,7 +380,8 @@ func (s *Manager) PostInit(ctx context.Context) error {
})
}
- if err := database.Initialize(s.Config.GetDatabasePath()); err != nil {
+ database := s.Database
+ if err := database.Open(s.Config.GetDatabasePath()); err != nil {
return err
}
@@ -377,7 +406,14 @@ func writeStashIcon() {
// initScraperCache initializes a new scraper cache and returns it.
func (s *Manager) initScraperCache() *scraper.Cache {
- ret, err := scraper.NewCache(config.GetInstance(), s.TxnManager)
+ ret, err := scraper.NewCache(config.GetInstance(), s.Repository, scraper.Repository{
+ SceneFinder: s.Repository.Scene,
+ GalleryFinder: s.Repository.Gallery,
+ TagFinder: s.Repository.Tag,
+ PerformerFinder: s.Repository.Performer,
+ MovieFinder: s.Repository.Movie,
+ StudioFinder: s.Repository.Studio,
+ })
if err != nil {
logger.Errorf("Error reading scraper configs: %s", err.Error())
@@ -476,7 +512,12 @@ func (s *Manager) Setup(ctx context.Context, input SetupInput) error {
// initialise the database
if err := s.PostInit(ctx); err != nil {
- return fmt.Errorf("error initializing the database: %v", err)
+ var migrationNeededErr *sqlite.MigrationNeededError
+ if errors.As(err, &migrationNeededErr) {
+ logger.Warn(err.Error())
+ } else {
+ return fmt.Errorf("error initializing the database: %v", err)
+ }
}
s.Config.FinalizeSetup()
@@ -501,6 +542,8 @@ type MigrateInput struct {
}
func (s *Manager) Migrate(ctx context.Context, input MigrateInput) error {
+ database := s.Database
+
// always backup so that we can roll back to the previous version if
// migration fails
backupPath := input.BackupPath
@@ -509,7 +552,7 @@ func (s *Manager) Migrate(ctx context.Context, input MigrateInput) error {
}
// perform database backup
- if err := database.Backup(database.DB, backupPath); err != nil {
+ if err := database.Backup(backupPath); err != nil {
return fmt.Errorf("error backing up database: %s", err)
}
@@ -541,6 +584,7 @@ func (s *Manager) Migrate(ctx context.Context, input MigrateInput) error {
}
func (s *Manager) GetSystemStatus() *SystemStatus {
+ database := s.Database
status := SystemStatusEnumOk
dbSchema := int(database.Version())
dbPath := database.DatabasePath()
@@ -569,7 +613,7 @@ func (s *Manager) Shutdown(code int) {
// TODO: Each part of the manager needs to gracefully stop at some point
// for now, we just close the database.
- err := database.Close()
+ err := s.Database.Close()
if err != nil {
logger.Errorf("Error closing database: %s", err)
if code == 0 {
diff --git a/internal/manager/manager_tasks.go b/internal/manager/manager_tasks.go
index 3178f7846..95f5c935f 100644
--- a/internal/manager/manager_tasks.go
+++ b/internal/manager/manager_tasks.go
@@ -84,7 +84,7 @@ func (s *Manager) Scan(ctx context.Context, input ScanMetadataInput) (int, error
}
scanJob := ScanJob{
- txnManager: s.TxnManager,
+ txnManager: s.Repository,
input: input,
subscriptions: s.scanSubs,
}
@@ -101,7 +101,7 @@ func (s *Manager) Import(ctx context.Context) (int, error) {
j := job.MakeJobExec(func(ctx context.Context, progress *job.Progress) {
task := ImportTask{
- txnManager: s.TxnManager,
+ txnManager: s.Repository,
BaseDir: metadataPath,
Reset: true,
DuplicateBehaviour: ImportDuplicateEnumFail,
@@ -125,7 +125,7 @@ func (s *Manager) Export(ctx context.Context) (int, error) {
var wg sync.WaitGroup
wg.Add(1)
task := ExportTask{
- txnManager: s.TxnManager,
+ txnManager: s.Repository,
full: true,
fileNamingAlgorithm: config.GetVideoFileNamingAlgorithm(),
}
@@ -156,7 +156,7 @@ func (s *Manager) Generate(ctx context.Context, input GenerateMetadataInput) (in
}
j := &GenerateJob{
- txnManager: s.TxnManager,
+ txnManager: s.Repository,
input: input,
}
@@ -185,9 +185,9 @@ func (s *Manager) generateScreenshot(ctx context.Context, sceneId string, at *fl
}
var scene *models.Scene
- if err := s.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
+ if err := s.Repository.WithTxn(ctx, func(ctx context.Context) error {
var err error
- scene, err = r.Scene().Find(sceneIdInt)
+ scene, err = s.Repository.Scene.Find(ctx, sceneIdInt)
return err
}); err != nil || scene == nil {
logger.Errorf("failed to get scene for generate: %s", err.Error())
@@ -195,7 +195,7 @@ func (s *Manager) generateScreenshot(ctx context.Context, sceneId string, at *fl
}
task := GenerateScreenshotTask{
- txnManager: s.TxnManager,
+ txnManager: s.Repository,
Scene: *scene,
ScreenshotAt: at,
fileNamingAlgorithm: config.GetInstance().GetVideoFileNamingAlgorithm(),
@@ -222,7 +222,7 @@ type AutoTagMetadataInput struct {
func (s *Manager) AutoTag(ctx context.Context, input AutoTagMetadataInput) int {
j := autoTagJob{
- txnManager: s.TxnManager,
+ txnManager: s.Repository,
input: input,
}
@@ -237,7 +237,7 @@ type CleanMetadataInput struct {
func (s *Manager) Clean(ctx context.Context, input CleanMetadataInput) int {
j := cleanJob{
- txnManager: s.TxnManager,
+ txnManager: s.Repository,
input: input,
scanSubs: s.scanSubs,
}
@@ -251,9 +251,9 @@ func (s *Manager) MigrateHash(ctx context.Context) int {
logger.Infof("Migrating generated files for %s naming hash", fileNamingAlgo.String())
var scenes []*models.Scene
- if err := s.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
+ if err := s.Repository.WithTxn(ctx, func(ctx context.Context) error {
var err error
- scenes, err = r.Scene().All()
+ scenes, err = s.Repository.Scene.All(ctx)
return err
}); err != nil {
logger.Errorf("failed to fetch list of scenes for migration: %s", err.Error())
@@ -327,15 +327,14 @@ func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxB
// This is why we mark this section nolint. In principle, we should look to
// rewrite the section at some point, to avoid the linter warning.
if len(input.PerformerIds) > 0 { //nolint:gocritic
- if err := s.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
- performerQuery := r.Performer()
+ if err := s.Repository.WithTxn(ctx, func(ctx context.Context) error {
+ performerQuery := s.Repository.Performer
for _, performerID := range input.PerformerIds {
if id, err := strconv.Atoi(performerID); err == nil {
- performer, err := performerQuery.Find(id)
+ performer, err := performerQuery.Find(ctx, id)
if err == nil {
tasks = append(tasks, StashBoxPerformerTagTask{
- txnManager: s.TxnManager,
performer: performer,
refresh: input.Refresh,
box: box,
@@ -354,7 +353,6 @@ func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxB
for i := range input.PerformerNames {
if len(input.PerformerNames[i]) > 0 {
tasks = append(tasks, StashBoxPerformerTagTask{
- txnManager: s.TxnManager,
name: &input.PerformerNames[i],
refresh: input.Refresh,
box: box,
@@ -367,14 +365,14 @@ func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxB
// However, this doesn't really help with readability of the current section. Mark it
// as nolint for now. In the future we'd like to rewrite this code by factoring some of
// this into separate functions.
- if err := s.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
- performerQuery := r.Performer()
+ if err := s.Repository.WithTxn(ctx, func(ctx context.Context) error {
+ performerQuery := s.Repository.Performer
var performers []*models.Performer
var err error
if input.Refresh {
- performers, err = performerQuery.FindByStashIDStatus(true, box.Endpoint)
+ performers, err = performerQuery.FindByStashIDStatus(ctx, true, box.Endpoint)
} else {
- performers, err = performerQuery.FindByStashIDStatus(false, box.Endpoint)
+ performers, err = performerQuery.FindByStashIDStatus(ctx, false, box.Endpoint)
}
if err != nil {
return fmt.Errorf("error querying performers: %v", err)
@@ -382,7 +380,6 @@ func (s *Manager) StashBoxBatchPerformerTag(ctx context.Context, input StashBoxB
for _, performer := range performers {
tasks = append(tasks, StashBoxPerformerTagTask{
- txnManager: s.TxnManager,
performer: performer,
refresh: input.Refresh,
box: box,
diff --git a/internal/manager/post_migrate.go b/internal/manager/post_migrate.go
index 1db1aac40..acc93ae69 100644
--- a/internal/manager/post_migrate.go
+++ b/internal/manager/post_migrate.go
@@ -4,5 +4,5 @@ import "context"
// PostMigrate is executed after migrations have been executed.
func (s *Manager) PostMigrate(ctx context.Context) {
- setInitialMD5Config(ctx, s.TxnManager)
+ setInitialMD5Config(ctx, s.Repository, s.Repository.Scene)
}
diff --git a/internal/manager/running_streams.go b/internal/manager/running_streams.go
index 41c196462..9d43d26d2 100644
--- a/internal/manager/running_streams.go
+++ b/internal/manager/running_streams.go
@@ -8,6 +8,7 @@ import (
"github.com/stashapp/stash/pkg/fsutil"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
+ "github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils"
)
@@ -49,8 +50,13 @@ func KillRunningStreams(scene *models.Scene, fileNamingAlgo models.HashAlgorithm
instance.ReadLockManager.Cancel(transcodePath)
}
+type SceneCoverGetter interface {
+ GetCover(ctx context.Context, sceneID int) ([]byte, error)
+}
+
type SceneServer struct {
- TXNManager models.TransactionManager
+ TxnManager txn.Manager
+ SceneCoverGetter SceneCoverGetter
}
func (s *SceneServer) StreamSceneDirect(scene *models.Scene, w http.ResponseWriter, r *http.Request) {
@@ -75,8 +81,8 @@ func (s *SceneServer) ServeScreenshot(scene *models.Scene, w http.ResponseWriter
http.ServeFile(w, r, filepath)
} else {
var cover []byte
- err := s.TXNManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
- cover, _ = repo.Scene().GetCover(scene.ID)
+ err := txn.WithTxn(r.Context(), s.TxnManager, func(ctx context.Context) error {
+ cover, _ = s.SceneCoverGetter.GetCover(ctx, scene.ID)
return nil
})
if err != nil {
diff --git a/internal/manager/studio.go b/internal/manager/studio.go
index 3b0d81ceb..6b517af6f 100644
--- a/internal/manager/studio.go
+++ b/internal/manager/studio.go
@@ -1,13 +1,15 @@
package manager
import (
+ "context"
"errors"
"fmt"
"github.com/stashapp/stash/pkg/models"
+ "github.com/stashapp/stash/pkg/studio"
)
-func ValidateModifyStudio(studio models.StudioPartial, qb models.StudioReader) error {
+func ValidateModifyStudio(ctx context.Context, studio models.StudioPartial, qb studio.Finder) error {
if studio.ParentID == nil || !studio.ParentID.Valid {
return nil
}
@@ -22,7 +24,7 @@ func ValidateModifyStudio(studio models.StudioPartial, qb models.StudioReader) e
return errors.New("studio cannot be an ancestor of itself")
}
- currentStudio, err := qb.Find(int(currentParentID.Int64))
+ currentStudio, err := qb.Find(ctx, int(currentParentID.Int64))
if err != nil {
return fmt.Errorf("error finding parent studio: %v", err)
}
diff --git a/internal/manager/task_autotag.go b/internal/manager/task_autotag.go
index a912dcedc..674fdfe64 100644
--- a/internal/manager/task_autotag.go
+++ b/internal/manager/task_autotag.go
@@ -19,7 +19,7 @@ import (
)
type autoTagJob struct {
- txnManager models.TransactionManager
+ txnManager models.Repository
input AutoTagMetadataInput
cache match.Cache
@@ -73,27 +73,28 @@ func (j *autoTagJob) autoTagSpecific(ctx context.Context, progress *job.Progress
studioCount := len(studioIds)
tagCount := len(tagIds)
- if err := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
- performerQuery := r.Performer()
- studioQuery := r.Studio()
- tagQuery := r.Tag()
+ if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
+ r := j.txnManager
+ performerQuery := r.Performer
+ studioQuery := r.Studio
+ tagQuery := r.Tag
const wildcard = "*"
var err error
if performerCount == 1 && performerIds[0] == wildcard {
- performerCount, err = performerQuery.Count()
+ performerCount, err = performerQuery.Count(ctx)
if err != nil {
return fmt.Errorf("error getting performer count: %v", err)
}
}
if studioCount == 1 && studioIds[0] == wildcard {
- studioCount, err = studioQuery.Count()
+ studioCount, err = studioQuery.Count(ctx)
if err != nil {
return fmt.Errorf("error getting studio count: %v", err)
}
}
if tagCount == 1 && tagIds[0] == wildcard {
- tagCount, err = tagQuery.Count()
+ tagCount, err = tagQuery.Count(ctx)
if err != nil {
return fmt.Errorf("error getting tag count: %v", err)
}
@@ -123,14 +124,14 @@ func (j *autoTagJob) autoTagPerformers(ctx context.Context, progress *job.Progre
for _, performerId := range performerIds {
var performers []*models.Performer
- if err := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
- performerQuery := r.Performer()
+ if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
+ performerQuery := j.txnManager.Performer
ignoreAutoTag := false
perPage := -1
if performerId == "*" {
var err error
- performers, _, err = performerQuery.Query(&models.PerformerFilterType{
+ performers, _, err = performerQuery.Query(ctx, &models.PerformerFilterType{
IgnoreAutoTag: &ignoreAutoTag,
}, &models.FindFilterType{
PerPage: &perPage,
@@ -144,7 +145,7 @@ func (j *autoTagJob) autoTagPerformers(ctx context.Context, progress *job.Progre
return fmt.Errorf("error parsing performer id %s: %s", performerId, err.Error())
}
- performer, err := performerQuery.Find(performerIdInt)
+ performer, err := performerQuery.Find(ctx, performerIdInt)
if err != nil {
return fmt.Errorf("error finding performer id %s: %s", performerId, err.Error())
}
@@ -161,14 +162,15 @@ func (j *autoTagJob) autoTagPerformers(ctx context.Context, progress *job.Progre
return nil
}
- if err := j.txnManager.WithTxn(ctx, func(r models.Repository) error {
- if err := autotag.PerformerScenes(performer, paths, r.Scene(), &j.cache); err != nil {
+ if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
+ r := j.txnManager
+ if err := autotag.PerformerScenes(ctx, performer, paths, r.Scene, &j.cache); err != nil {
return err
}
- if err := autotag.PerformerImages(performer, paths, r.Image(), &j.cache); err != nil {
+ if err := autotag.PerformerImages(ctx, performer, paths, r.Image, &j.cache); err != nil {
return err
}
- if err := autotag.PerformerGalleries(performer, paths, r.Gallery(), &j.cache); err != nil {
+ if err := autotag.PerformerGalleries(ctx, performer, paths, r.Gallery, &j.cache); err != nil {
return err
}
@@ -193,16 +195,18 @@ func (j *autoTagJob) autoTagStudios(ctx context.Context, progress *job.Progress,
return
}
+ r := j.txnManager
+
for _, studioId := range studioIds {
var studios []*models.Studio
- if err := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
- studioQuery := r.Studio()
+ if err := r.WithTxn(ctx, func(ctx context.Context) error {
+ studioQuery := r.Studio
ignoreAutoTag := false
perPage := -1
if studioId == "*" {
var err error
- studios, _, err = studioQuery.Query(&models.StudioFilterType{
+ studios, _, err = studioQuery.Query(ctx, &models.StudioFilterType{
IgnoreAutoTag: &ignoreAutoTag,
}, &models.FindFilterType{
PerPage: &perPage,
@@ -216,7 +220,7 @@ func (j *autoTagJob) autoTagStudios(ctx context.Context, progress *job.Progress,
return fmt.Errorf("error parsing studio id %s: %s", studioId, err.Error())
}
- studio, err := studioQuery.Find(studioIdInt)
+ studio, err := studioQuery.Find(ctx, studioIdInt)
if err != nil {
return fmt.Errorf("error finding studio id %s: %s", studioId, err.Error())
}
@@ -234,19 +238,19 @@ func (j *autoTagJob) autoTagStudios(ctx context.Context, progress *job.Progress,
return nil
}
- if err := j.txnManager.WithTxn(ctx, func(r models.Repository) error {
- aliases, err := r.Studio().GetAliases(studio.ID)
+ if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
+ aliases, err := r.Studio.GetAliases(ctx, studio.ID)
if err != nil {
return err
}
- if err := autotag.StudioScenes(studio, paths, aliases, r.Scene(), &j.cache); err != nil {
+ if err := autotag.StudioScenes(ctx, studio, paths, aliases, r.Scene, &j.cache); err != nil {
return err
}
- if err := autotag.StudioImages(studio, paths, aliases, r.Image(), &j.cache); err != nil {
+ if err := autotag.StudioImages(ctx, studio, paths, aliases, r.Image, &j.cache); err != nil {
return err
}
- if err := autotag.StudioGalleries(studio, paths, aliases, r.Gallery(), &j.cache); err != nil {
+ if err := autotag.StudioGalleries(ctx, studio, paths, aliases, r.Gallery, &j.cache); err != nil {
return err
}
@@ -271,15 +275,17 @@ func (j *autoTagJob) autoTagTags(ctx context.Context, progress *job.Progress, pa
return
}
+ r := j.txnManager
+
for _, tagId := range tagIds {
var tags []*models.Tag
- if err := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
- tagQuery := r.Tag()
+ if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
+ tagQuery := r.Tag
ignoreAutoTag := false
perPage := -1
if tagId == "*" {
var err error
- tags, _, err = tagQuery.Query(&models.TagFilterType{
+ tags, _, err = tagQuery.Query(ctx, &models.TagFilterType{
IgnoreAutoTag: &ignoreAutoTag,
}, &models.FindFilterType{
PerPage: &perPage,
@@ -293,7 +299,7 @@ func (j *autoTagJob) autoTagTags(ctx context.Context, progress *job.Progress, pa
return fmt.Errorf("error parsing tag id %s: %s", tagId, err.Error())
}
- tag, err := tagQuery.Find(tagIdInt)
+ tag, err := tagQuery.Find(ctx, tagIdInt)
if err != nil {
return fmt.Errorf("error finding tag id %s: %s", tagId, err.Error())
}
@@ -306,19 +312,19 @@ func (j *autoTagJob) autoTagTags(ctx context.Context, progress *job.Progress, pa
return nil
}
- if err := j.txnManager.WithTxn(ctx, func(r models.Repository) error {
- aliases, err := r.Tag().GetAliases(tag.ID)
+ if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
+ aliases, err := r.Tag.GetAliases(ctx, tag.ID)
if err != nil {
return err
}
- if err := autotag.TagScenes(tag, paths, aliases, r.Scene(), &j.cache); err != nil {
+ if err := autotag.TagScenes(ctx, tag, paths, aliases, r.Scene, &j.cache); err != nil {
return err
}
- if err := autotag.TagImages(tag, paths, aliases, r.Image(), &j.cache); err != nil {
+ if err := autotag.TagImages(ctx, tag, paths, aliases, r.Image, &j.cache); err != nil {
return err
}
- if err := autotag.TagGalleries(tag, paths, aliases, r.Gallery(), &j.cache); err != nil {
+ if err := autotag.TagGalleries(ctx, tag, paths, aliases, r.Gallery, &j.cache); err != nil {
return err
}
@@ -345,7 +351,7 @@ type autoTagFilesTask struct {
tags bool
progress *job.Progress
- txnManager models.TransactionManager
+ txnManager models.Repository
cache *match.Cache
}
@@ -425,13 +431,13 @@ func (t *autoTagFilesTask) makeGalleryFilter() *models.GalleryFilterType {
return ret
}
-func (t *autoTagFilesTask) getCount(r models.ReaderRepository) (int, error) {
+func (t *autoTagFilesTask) getCount(ctx context.Context, r models.Repository) (int, error) {
pp := 0
findFilter := &models.FindFilterType{
PerPage: &pp,
}
- sceneResults, err := r.Scene().Query(models.SceneQueryOptions{
+ sceneResults, err := r.Scene.Query(ctx, models.SceneQueryOptions{
QueryOptions: models.QueryOptions{
FindFilter: findFilter,
Count: true,
@@ -444,7 +450,7 @@ func (t *autoTagFilesTask) getCount(r models.ReaderRepository) (int, error) {
sceneCount := sceneResults.Count
- imageResults, err := r.Image().Query(models.ImageQueryOptions{
+ imageResults, err := r.Image.Query(ctx, models.ImageQueryOptions{
QueryOptions: models.QueryOptions{
FindFilter: findFilter,
Count: true,
@@ -457,7 +463,7 @@ func (t *autoTagFilesTask) getCount(r models.ReaderRepository) (int, error) {
imageCount := imageResults.Count
- _, galleryCount, err := r.Gallery().Query(t.makeGalleryFilter(), findFilter)
+ _, galleryCount, err := r.Gallery.Query(ctx, t.makeGalleryFilter(), findFilter)
if err != nil {
return 0, err
}
@@ -465,7 +471,7 @@ func (t *autoTagFilesTask) getCount(r models.ReaderRepository) (int, error) {
return sceneCount + imageCount + galleryCount, nil
}
-func (t *autoTagFilesTask) processScenes(ctx context.Context, r models.ReaderRepository) error {
+func (t *autoTagFilesTask) processScenes(ctx context.Context, r models.Repository) error {
if job.IsCancelled(ctx) {
return nil
}
@@ -477,7 +483,7 @@ func (t *autoTagFilesTask) processScenes(ctx context.Context, r models.ReaderRep
more := true
for more {
- scenes, err := scene.Query(r.Scene(), sceneFilter, findFilter)
+ scenes, err := scene.Query(ctx, r.Scene, sceneFilter, findFilter)
if err != nil {
return err
}
@@ -518,7 +524,7 @@ func (t *autoTagFilesTask) processScenes(ctx context.Context, r models.ReaderRep
return nil
}
-func (t *autoTagFilesTask) processImages(ctx context.Context, r models.ReaderRepository) error {
+func (t *autoTagFilesTask) processImages(ctx context.Context, r models.Repository) error {
if job.IsCancelled(ctx) {
return nil
}
@@ -530,7 +536,7 @@ func (t *autoTagFilesTask) processImages(ctx context.Context, r models.ReaderRep
more := true
for more {
- images, err := image.Query(r.Image(), imageFilter, findFilter)
+ images, err := image.Query(ctx, r.Image, imageFilter, findFilter)
if err != nil {
return err
}
@@ -571,7 +577,7 @@ func (t *autoTagFilesTask) processImages(ctx context.Context, r models.ReaderRep
return nil
}
-func (t *autoTagFilesTask) processGalleries(ctx context.Context, r models.ReaderRepository) error {
+func (t *autoTagFilesTask) processGalleries(ctx context.Context, r models.Repository) error {
if job.IsCancelled(ctx) {
return nil
}
@@ -583,7 +589,7 @@ func (t *autoTagFilesTask) processGalleries(ctx context.Context, r models.Reader
more := true
for more {
- galleries, _, err := r.Gallery().Query(galleryFilter, findFilter)
+ galleries, _, err := r.Gallery.Query(ctx, galleryFilter, findFilter)
if err != nil {
return err
}
@@ -625,8 +631,9 @@ func (t *autoTagFilesTask) processGalleries(ctx context.Context, r models.Reader
}
func (t *autoTagFilesTask) process(ctx context.Context) {
- if err := t.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
- total, err := t.getCount(r)
+ r := t.txnManager
+ if err := r.WithTxn(ctx, func(ctx context.Context) error {
+ total, err := t.getCount(ctx, t.txnManager)
if err != nil {
return err
}
@@ -661,7 +668,7 @@ func (t *autoTagFilesTask) process(ctx context.Context) {
}
type autoTagSceneTask struct {
- txnManager models.TransactionManager
+ txnManager models.Repository
scene *models.Scene
performers bool
@@ -673,19 +680,20 @@ type autoTagSceneTask struct {
func (t *autoTagSceneTask) Start(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done()
- if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
+ r := t.txnManager
+ if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
if t.performers {
- if err := autotag.ScenePerformers(t.scene, r.Scene(), r.Performer(), t.cache); err != nil {
+ if err := autotag.ScenePerformers(ctx, t.scene, r.Scene, r.Performer, t.cache); err != nil {
return fmt.Errorf("error tagging scene performers for %s: %v", t.scene.Path, err)
}
}
if t.studios {
- if err := autotag.SceneStudios(t.scene, r.Scene(), r.Studio(), t.cache); err != nil {
+ if err := autotag.SceneStudios(ctx, t.scene, r.Scene, r.Studio, t.cache); err != nil {
return fmt.Errorf("error tagging scene studio for %s: %v", t.scene.Path, err)
}
}
if t.tags {
- if err := autotag.SceneTags(t.scene, r.Scene(), r.Tag(), t.cache); err != nil {
+ if err := autotag.SceneTags(ctx, t.scene, r.Scene, r.Tag, t.cache); err != nil {
return fmt.Errorf("error tagging scene tags for %s: %v", t.scene.Path, err)
}
}
@@ -697,7 +705,7 @@ func (t *autoTagSceneTask) Start(ctx context.Context, wg *sync.WaitGroup) {
}
type autoTagImageTask struct {
- txnManager models.TransactionManager
+ txnManager models.Repository
image *models.Image
performers bool
@@ -709,19 +717,20 @@ type autoTagImageTask struct {
func (t *autoTagImageTask) Start(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done()
- if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
+ r := t.txnManager
+ if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
if t.performers {
- if err := autotag.ImagePerformers(t.image, r.Image(), r.Performer(), t.cache); err != nil {
+ if err := autotag.ImagePerformers(ctx, t.image, r.Image, r.Performer, t.cache); err != nil {
return fmt.Errorf("error tagging image performers for %s: %v", t.image.Path, err)
}
}
if t.studios {
- if err := autotag.ImageStudios(t.image, r.Image(), r.Studio(), t.cache); err != nil {
+ if err := autotag.ImageStudios(ctx, t.image, r.Image, r.Studio, t.cache); err != nil {
return fmt.Errorf("error tagging image studio for %s: %v", t.image.Path, err)
}
}
if t.tags {
- if err := autotag.ImageTags(t.image, r.Image(), r.Tag(), t.cache); err != nil {
+ if err := autotag.ImageTags(ctx, t.image, r.Image, r.Tag, t.cache); err != nil {
return fmt.Errorf("error tagging image tags for %s: %v", t.image.Path, err)
}
}
@@ -733,7 +742,7 @@ func (t *autoTagImageTask) Start(ctx context.Context, wg *sync.WaitGroup) {
}
type autoTagGalleryTask struct {
- txnManager models.TransactionManager
+ txnManager models.Repository
gallery *models.Gallery
performers bool
@@ -745,19 +754,20 @@ type autoTagGalleryTask struct {
func (t *autoTagGalleryTask) Start(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done()
- if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
+ r := t.txnManager
+ if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
if t.performers {
- if err := autotag.GalleryPerformers(t.gallery, r.Gallery(), r.Performer(), t.cache); err != nil {
+ if err := autotag.GalleryPerformers(ctx, t.gallery, r.Gallery, r.Performer, t.cache); err != nil {
return fmt.Errorf("error tagging gallery performers for %s: %v", t.gallery.Path.String, err)
}
}
if t.studios {
- if err := autotag.GalleryStudios(t.gallery, r.Gallery(), r.Studio(), t.cache); err != nil {
+ if err := autotag.GalleryStudios(ctx, t.gallery, r.Gallery, r.Studio, t.cache); err != nil {
return fmt.Errorf("error tagging gallery studio for %s: %v", t.gallery.Path.String, err)
}
}
if t.tags {
- if err := autotag.GalleryTags(t.gallery, r.Gallery(), r.Tag(), t.cache); err != nil {
+ if err := autotag.GalleryTags(ctx, t.gallery, r.Gallery, r.Tag, t.cache); err != nil {
return fmt.Errorf("error tagging gallery tags for %s: %v", t.gallery.Path.String, err)
}
}
diff --git a/internal/manager/task_clean.go b/internal/manager/task_clean.go
index 2cbd168fb..d165a9eba 100644
--- a/internal/manager/task_clean.go
+++ b/internal/manager/task_clean.go
@@ -18,7 +18,7 @@ import (
)
type cleanJob struct {
- txnManager models.TransactionManager
+ txnManager models.Repository
input CleanMetadataInput
scanSubs *subscriptionManager
}
@@ -29,8 +29,10 @@ func (j *cleanJob) Execute(ctx context.Context, progress *job.Progress) {
logger.Infof("Running in Dry Mode")
}
- if err := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
- total, err := j.getCount(r)
+ r := j.txnManager
+
+ if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
+ total, err := j.getCount(ctx, r)
if err != nil {
return fmt.Errorf("error getting count: %w", err)
}
@@ -41,13 +43,13 @@ func (j *cleanJob) Execute(ctx context.Context, progress *job.Progress) {
return nil
}
- if err := j.processScenes(ctx, progress, r.Scene()); err != nil {
+ if err := j.processScenes(ctx, progress, r.Scene); err != nil {
return fmt.Errorf("error cleaning scenes: %w", err)
}
- if err := j.processImages(ctx, progress, r.Image()); err != nil {
+ if err := j.processImages(ctx, progress, r.Image); err != nil {
return fmt.Errorf("error cleaning images: %w", err)
}
- if err := j.processGalleries(ctx, progress, r.Gallery(), r.Image()); err != nil {
+ if err := j.processGalleries(ctx, progress, r.Gallery, r.Image); err != nil {
return fmt.Errorf("error cleaning galleries: %w", err)
}
@@ -66,9 +68,9 @@ func (j *cleanJob) Execute(ctx context.Context, progress *job.Progress) {
logger.Info("Finished Cleaning")
}
-func (j *cleanJob) getCount(r models.ReaderRepository) (int, error) {
+func (j *cleanJob) getCount(ctx context.Context, r models.Repository) (int, error) {
sceneFilter := scene.PathsFilter(j.input.Paths)
- sceneResult, err := r.Scene().Query(models.SceneQueryOptions{
+ sceneResult, err := r.Scene.Query(ctx, models.SceneQueryOptions{
QueryOptions: models.QueryOptions{
Count: true,
},
@@ -78,12 +80,12 @@ func (j *cleanJob) getCount(r models.ReaderRepository) (int, error) {
return 0, err
}
- imageCount, err := r.Image().QueryCount(image.PathsFilter(j.input.Paths), nil)
+ imageCount, err := r.Image.QueryCount(ctx, image.PathsFilter(j.input.Paths), nil)
if err != nil {
return 0, err
}
- galleryCount, err := r.Gallery().QueryCount(gallery.PathsFilter(j.input.Paths), nil)
+ galleryCount, err := r.Gallery.QueryCount(ctx, gallery.PathsFilter(j.input.Paths), nil)
if err != nil {
return 0, err
}
@@ -91,7 +93,7 @@ func (j *cleanJob) getCount(r models.ReaderRepository) (int, error) {
return sceneResult.Count + imageCount + galleryCount, nil
}
-func (j *cleanJob) processScenes(ctx context.Context, progress *job.Progress, qb models.SceneReader) error {
+func (j *cleanJob) processScenes(ctx context.Context, progress *job.Progress, qb scene.Queryer) error {
batchSize := 1000
findFilter := models.BatchFindFilter(batchSize)
@@ -107,7 +109,7 @@ func (j *cleanJob) processScenes(ctx context.Context, progress *job.Progress, qb
return nil
}
- scenes, err := scene.Query(qb, sceneFilter, findFilter)
+ scenes, err := scene.Query(ctx, qb, sceneFilter, findFilter)
if err != nil {
return fmt.Errorf("error querying for scenes: %w", err)
}
@@ -154,7 +156,7 @@ func (j *cleanJob) processScenes(ctx context.Context, progress *job.Progress, qb
return nil
}
-func (j *cleanJob) processGalleries(ctx context.Context, progress *job.Progress, qb models.GalleryReader, iqb models.ImageReader) error {
+func (j *cleanJob) processGalleries(ctx context.Context, progress *job.Progress, qb gallery.Queryer, iqb models.ImageReader) error {
batchSize := 1000
findFilter := models.BatchFindFilter(batchSize)
@@ -170,14 +172,14 @@ func (j *cleanJob) processGalleries(ctx context.Context, progress *job.Progress,
return nil
}
- galleries, _, err := qb.Query(galleryFilter, findFilter)
+ galleries, _, err := qb.Query(ctx, galleryFilter, findFilter)
if err != nil {
return fmt.Errorf("error querying for galleries: %w", err)
}
for _, gallery := range galleries {
progress.ExecuteTask(fmt.Sprintf("Assessing gallery %s for clean", gallery.GetTitle()), func() {
- if j.shouldCleanGallery(gallery, iqb) {
+ if j.shouldCleanGallery(ctx, gallery, iqb) {
toDelete = append(toDelete, gallery.ID)
} else {
// increment progress, no further processing
@@ -215,7 +217,7 @@ func (j *cleanJob) processGalleries(ctx context.Context, progress *job.Progress,
return nil
}
-func (j *cleanJob) processImages(ctx context.Context, progress *job.Progress, qb models.ImageReader) error {
+func (j *cleanJob) processImages(ctx context.Context, progress *job.Progress, qb image.Queryer) error {
batchSize := 1000
findFilter := models.BatchFindFilter(batchSize)
@@ -234,7 +236,7 @@ func (j *cleanJob) processImages(ctx context.Context, progress *job.Progress, qb
return nil
}
- images, err := image.Query(qb, imageFilter, findFilter)
+ images, err := image.Query(ctx, qb, imageFilter, findFilter)
if err != nil {
return fmt.Errorf("error querying for images: %w", err)
}
@@ -318,7 +320,7 @@ func (j *cleanJob) shouldCleanScene(s *models.Scene) bool {
return false
}
-func (j *cleanJob) shouldCleanGallery(g *models.Gallery, qb models.ImageReader) bool {
+func (j *cleanJob) shouldCleanGallery(ctx context.Context, g *models.Gallery, qb models.ImageReader) bool {
// never clean manually created galleries
if !g.Path.Valid {
return false
@@ -348,7 +350,7 @@ func (j *cleanJob) shouldCleanGallery(g *models.Gallery, qb models.ImageReader)
}
} else {
// folder-based - delete if it has no images
- count, err := qb.CountByGalleryID(g.ID)
+ count, err := qb.CountByGalleryID(ctx, g.ID)
if err != nil {
logger.Warnf("Error trying to count gallery images for %q: %v", path, err)
return false
@@ -401,16 +403,17 @@ func (j *cleanJob) deleteScene(ctx context.Context, fileNamingAlgorithm models.H
Paths: GetInstance().Paths,
}
var s *models.Scene
- if err := j.txnManager.WithTxn(ctx, func(repo models.Repository) error {
- qb := repo.Scene()
+ if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
+ repo := j.txnManager
+ qb := repo.Scene
var err error
- s, err = qb.Find(sceneID)
+ s, err = qb.Find(ctx, sceneID)
if err != nil {
return err
}
- return scene.Destroy(s, repo, fileDeleter, true, false)
+ return scene.Destroy(ctx, s, repo.Scene, repo.SceneMarker, fileDeleter, true, false)
}); err != nil {
fileDeleter.Rollback()
@@ -431,16 +434,16 @@ func (j *cleanJob) deleteScene(ctx context.Context, fileNamingAlgorithm models.H
func (j *cleanJob) deleteGallery(ctx context.Context, galleryID int) {
var g *models.Gallery
- if err := j.txnManager.WithTxn(ctx, func(repo models.Repository) error {
- qb := repo.Gallery()
+ if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
+ qb := j.txnManager.Gallery
var err error
- g, err = qb.Find(galleryID)
+ g, err = qb.Find(ctx, galleryID)
if err != nil {
return err
}
- return qb.Destroy(galleryID)
+ return qb.Destroy(ctx, galleryID)
}); err != nil {
logger.Errorf("Error deleting gallery from database: %s", err.Error())
return
@@ -459,11 +462,11 @@ func (j *cleanJob) deleteImage(ctx context.Context, imageID int) {
}
var i *models.Image
- if err := j.txnManager.WithTxn(ctx, func(repo models.Repository) error {
- qb := repo.Image()
+ if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
+ qb := j.txnManager.Image
var err error
- i, err = qb.Find(imageID)
+ i, err = qb.Find(ctx, imageID)
if err != nil {
return err
}
@@ -472,7 +475,7 @@ func (j *cleanJob) deleteImage(ctx context.Context, imageID int) {
return fmt.Errorf("image not found: %d", imageID)
}
- return image.Destroy(i, qb, fileDeleter, true, false)
+ return image.Destroy(ctx, i, qb, fileDeleter, true, false)
}); err != nil {
fileDeleter.Rollback()
diff --git a/internal/manager/task_export.go b/internal/manager/task_export.go
index 512b7e42f..3219252cb 100644
--- a/internal/manager/task_export.go
+++ b/internal/manager/task_export.go
@@ -32,7 +32,7 @@ import (
)
type ExportTask struct {
- txnManager models.TransactionManager
+ txnManager models.Repository
full bool
baseDir string
@@ -100,7 +100,7 @@ func CreateExportTask(a models.HashAlgorithm, input ExportObjectsInput) *ExportT
}
return &ExportTask{
- txnManager: GetInstance().TxnManager,
+ txnManager: GetInstance().Repository,
fileNamingAlgorithm: a,
scenes: newExportSpec(input.Scenes),
images: newExportSpec(input.Images),
@@ -146,30 +146,32 @@ func (t *ExportTask) Start(ctx context.Context, wg *sync.WaitGroup) {
paths.EnsureJSONDirs(t.baseDir)
- txnErr := t.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
+ txnErr := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
+ r := t.txnManager
+
// include movie scenes and gallery images
if !t.full {
// only include movie scenes if includeDependencies is also set
if !t.scenes.all && t.includeDependencies {
- t.populateMovieScenes(r)
+ t.populateMovieScenes(ctx, r)
}
// always export gallery images
if !t.images.all {
- t.populateGalleryImages(r)
+ t.populateGalleryImages(ctx, r)
}
}
- t.ExportScenes(workerCount, r)
- t.ExportImages(workerCount, r)
- t.ExportGalleries(workerCount, r)
- t.ExportMovies(workerCount, r)
- t.ExportPerformers(workerCount, r)
- t.ExportStudios(workerCount, r)
- t.ExportTags(workerCount, r)
+ t.ExportScenes(ctx, workerCount, r)
+ t.ExportImages(ctx, workerCount, r)
+ t.ExportGalleries(ctx, workerCount, r)
+ t.ExportMovies(ctx, workerCount, r)
+ t.ExportPerformers(ctx, workerCount, r)
+ t.ExportStudios(ctx, workerCount, r)
+ t.ExportTags(ctx, workerCount, r)
if t.full {
- t.ExportScrapedItems(r)
+ t.ExportScrapedItems(ctx, r)
}
return nil
@@ -284,17 +286,17 @@ func (t *ExportTask) zipFile(fn, outDir string, z *zip.Writer) error {
return nil
}
-func (t *ExportTask) populateMovieScenes(repo models.ReaderRepository) {
- reader := repo.Movie()
- sceneReader := repo.Scene()
+func (t *ExportTask) populateMovieScenes(ctx context.Context, repo models.Repository) {
+ reader := repo.Movie
+ sceneReader := repo.Scene
var movies []*models.Movie
var err error
all := t.full || (t.movies != nil && t.movies.all)
if all {
- movies, err = reader.All()
+ movies, err = reader.All(ctx)
} else if t.movies != nil && len(t.movies.IDs) > 0 {
- movies, err = reader.FindMany(t.movies.IDs)
+ movies, err = reader.FindMany(ctx, t.movies.IDs)
}
if err != nil {
@@ -302,7 +304,7 @@ func (t *ExportTask) populateMovieScenes(repo models.ReaderRepository) {
}
for _, m := range movies {
- scenes, err := sceneReader.FindByMovieID(m.ID)
+ scenes, err := sceneReader.FindByMovieID(ctx, m.ID)
if err != nil {
logger.Errorf("[movies] <%s> failed to fetch scenes for movie: %s", m.Checksum, err.Error())
continue
@@ -314,17 +316,17 @@ func (t *ExportTask) populateMovieScenes(repo models.ReaderRepository) {
}
}
-func (t *ExportTask) populateGalleryImages(repo models.ReaderRepository) {
- reader := repo.Gallery()
- imageReader := repo.Image()
+func (t *ExportTask) populateGalleryImages(ctx context.Context, repo models.Repository) {
+ reader := repo.Gallery
+ imageReader := repo.Image
var galleries []*models.Gallery
var err error
all := t.full || (t.galleries != nil && t.galleries.all)
if all {
- galleries, err = reader.All()
+ galleries, err = reader.All(ctx)
} else if t.galleries != nil && len(t.galleries.IDs) > 0 {
- galleries, err = reader.FindMany(t.galleries.IDs)
+ galleries, err = reader.FindMany(ctx, t.galleries.IDs)
}
if err != nil {
@@ -332,7 +334,7 @@ func (t *ExportTask) populateGalleryImages(repo models.ReaderRepository) {
}
for _, g := range galleries {
- images, err := imageReader.FindByGalleryID(g.ID)
+ images, err := imageReader.FindByGalleryID(ctx, g.ID)
if err != nil {
logger.Errorf("[galleries] <%s> failed to fetch images for gallery: %s", g.Checksum, err.Error())
continue
@@ -344,18 +346,18 @@ func (t *ExportTask) populateGalleryImages(repo models.ReaderRepository) {
}
}
-func (t *ExportTask) ExportScenes(workers int, repo models.ReaderRepository) {
+func (t *ExportTask) ExportScenes(ctx context.Context, workers int, repo models.Repository) {
var scenesWg sync.WaitGroup
- sceneReader := repo.Scene()
+ sceneReader := repo.Scene
var scenes []*models.Scene
var err error
all := t.full || (t.scenes != nil && t.scenes.all)
if all {
- scenes, err = sceneReader.All()
+ scenes, err = sceneReader.All(ctx)
} else if t.scenes != nil && len(t.scenes.IDs) > 0 {
- scenes, err = sceneReader.FindMany(t.scenes.IDs)
+ scenes, err = sceneReader.FindMany(ctx, t.scenes.IDs)
}
if err != nil {
@@ -369,7 +371,7 @@ func (t *ExportTask) ExportScenes(workers int, repo models.ReaderRepository) {
for w := 0; w < workers; w++ { // create export Scene workers
scenesWg.Add(1)
- go exportScene(&scenesWg, jobCh, repo, t)
+ go exportScene(ctx, &scenesWg, jobCh, repo, t)
}
for i, scene := range scenes {
@@ -388,32 +390,32 @@ func (t *ExportTask) ExportScenes(workers int, repo models.ReaderRepository) {
logger.Infof("[scenes] export complete in %s. %d workers used.", time.Since(startTime), workers)
}
-func exportScene(wg *sync.WaitGroup, jobChan <-chan *models.Scene, repo models.ReaderRepository, t *ExportTask) {
+func exportScene(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Scene, repo models.Repository, t *ExportTask) {
defer wg.Done()
- sceneReader := repo.Scene()
- studioReader := repo.Studio()
- movieReader := repo.Movie()
- galleryReader := repo.Gallery()
- performerReader := repo.Performer()
- tagReader := repo.Tag()
- sceneMarkerReader := repo.SceneMarker()
+ sceneReader := repo.Scene
+ studioReader := repo.Studio
+ movieReader := repo.Movie
+ galleryReader := repo.Gallery
+ performerReader := repo.Performer
+ tagReader := repo.Tag
+ sceneMarkerReader := repo.SceneMarker
for s := range jobChan {
sceneHash := s.GetHash(t.fileNamingAlgorithm)
- newSceneJSON, err := scene.ToBasicJSON(sceneReader, s)
+ newSceneJSON, err := scene.ToBasicJSON(ctx, sceneReader, s)
if err != nil {
logger.Errorf("[scenes] <%s> error getting scene JSON: %s", sceneHash, err.Error())
continue
}
- newSceneJSON.Studio, err = scene.GetStudioName(studioReader, s)
+ newSceneJSON.Studio, err = scene.GetStudioName(ctx, studioReader, s)
if err != nil {
logger.Errorf("[scenes] <%s> error getting scene studio name: %s", sceneHash, err.Error())
continue
}
- galleries, err := galleryReader.FindBySceneID(s.ID)
+ galleries, err := galleryReader.FindBySceneID(ctx, s.ID)
if err != nil {
logger.Errorf("[scenes] <%s> error getting scene gallery checksums: %s", sceneHash, err.Error())
continue
@@ -421,7 +423,7 @@ func exportScene(wg *sync.WaitGroup, jobChan <-chan *models.Scene, repo models.R
newSceneJSON.Galleries = gallery.GetChecksums(galleries)
- performers, err := performerReader.FindBySceneID(s.ID)
+ performers, err := performerReader.FindBySceneID(ctx, s.ID)
if err != nil {
logger.Errorf("[scenes] <%s> error getting scene performer names: %s", sceneHash, err.Error())
continue
@@ -429,19 +431,19 @@ func exportScene(wg *sync.WaitGroup, jobChan <-chan *models.Scene, repo models.R
newSceneJSON.Performers = performer.GetNames(performers)
- newSceneJSON.Tags, err = scene.GetTagNames(tagReader, s)
+ newSceneJSON.Tags, err = scene.GetTagNames(ctx, tagReader, s)
if err != nil {
logger.Errorf("[scenes] <%s> error getting scene tag names: %s", sceneHash, err.Error())
continue
}
- newSceneJSON.Markers, err = scene.GetSceneMarkersJSON(sceneMarkerReader, tagReader, s)
+ newSceneJSON.Markers, err = scene.GetSceneMarkersJSON(ctx, sceneMarkerReader, tagReader, s)
if err != nil {
logger.Errorf("[scenes] <%s> error getting scene markers JSON: %s", sceneHash, err.Error())
continue
}
- newSceneJSON.Movies, err = scene.GetSceneMoviesJSON(movieReader, sceneReader, s)
+ newSceneJSON.Movies, err = scene.GetSceneMoviesJSON(ctx, movieReader, sceneReader, s)
if err != nil {
logger.Errorf("[scenes] <%s> error getting scene movies JSON: %s", sceneHash, err.Error())
continue
@@ -454,14 +456,14 @@ func exportScene(wg *sync.WaitGroup, jobChan <-chan *models.Scene, repo models.R
t.galleries.IDs = intslice.IntAppendUniques(t.galleries.IDs, gallery.GetIDs(galleries))
- tagIDs, err := scene.GetDependentTagIDs(tagReader, sceneMarkerReader, s)
+ tagIDs, err := scene.GetDependentTagIDs(ctx, tagReader, sceneMarkerReader, s)
if err != nil {
logger.Errorf("[scenes] <%s> error getting scene tags: %s", sceneHash, err.Error())
continue
}
t.tags.IDs = intslice.IntAppendUniques(t.tags.IDs, tagIDs)
- movieIDs, err := scene.GetDependentMovieIDs(sceneReader, s)
+ movieIDs, err := scene.GetDependentMovieIDs(ctx, sceneReader, s)
if err != nil {
logger.Errorf("[scenes] <%s> error getting scene movies: %s", sceneHash, err.Error())
continue
@@ -482,18 +484,18 @@ func exportScene(wg *sync.WaitGroup, jobChan <-chan *models.Scene, repo models.R
}
}
-func (t *ExportTask) ExportImages(workers int, repo models.ReaderRepository) {
+func (t *ExportTask) ExportImages(ctx context.Context, workers int, repo models.Repository) {
var imagesWg sync.WaitGroup
- imageReader := repo.Image()
+ imageReader := repo.Image
var images []*models.Image
var err error
all := t.full || (t.images != nil && t.images.all)
if all {
- images, err = imageReader.All()
+ images, err = imageReader.All(ctx)
} else if t.images != nil && len(t.images.IDs) > 0 {
- images, err = imageReader.FindMany(t.images.IDs)
+ images, err = imageReader.FindMany(ctx, t.images.IDs)
}
if err != nil {
@@ -507,7 +509,7 @@ func (t *ExportTask) ExportImages(workers int, repo models.ReaderRepository) {
for w := 0; w < workers; w++ { // create export Image workers
imagesWg.Add(1)
- go exportImage(&imagesWg, jobCh, repo, t)
+ go exportImage(ctx, &imagesWg, jobCh, repo, t)
}
for i, image := range images {
@@ -526,12 +528,12 @@ func (t *ExportTask) ExportImages(workers int, repo models.ReaderRepository) {
logger.Infof("[images] export complete in %s. %d workers used.", time.Since(startTime), workers)
}
-func exportImage(wg *sync.WaitGroup, jobChan <-chan *models.Image, repo models.ReaderRepository, t *ExportTask) {
+func exportImage(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Image, repo models.Repository, t *ExportTask) {
defer wg.Done()
- studioReader := repo.Studio()
- galleryReader := repo.Gallery()
- performerReader := repo.Performer()
- tagReader := repo.Tag()
+ studioReader := repo.Studio
+ galleryReader := repo.Gallery
+ performerReader := repo.Performer
+ tagReader := repo.Tag
for s := range jobChan {
imageHash := s.Checksum
@@ -539,13 +541,13 @@ func exportImage(wg *sync.WaitGroup, jobChan <-chan *models.Image, repo models.R
newImageJSON := image.ToBasicJSON(s)
var err error
- newImageJSON.Studio, err = image.GetStudioName(studioReader, s)
+ newImageJSON.Studio, err = image.GetStudioName(ctx, studioReader, s)
if err != nil {
logger.Errorf("[images] <%s> error getting image studio name: %s", imageHash, err.Error())
continue
}
- imageGalleries, err := galleryReader.FindByImageID(s.ID)
+ imageGalleries, err := galleryReader.FindByImageID(ctx, s.ID)
if err != nil {
logger.Errorf("[images] <%s> error getting image galleries: %s", imageHash, err.Error())
continue
@@ -553,7 +555,7 @@ func exportImage(wg *sync.WaitGroup, jobChan <-chan *models.Image, repo models.R
newImageJSON.Galleries = t.getGalleryChecksums(imageGalleries)
- performers, err := performerReader.FindByImageID(s.ID)
+ performers, err := performerReader.FindByImageID(ctx, s.ID)
if err != nil {
logger.Errorf("[images] <%s> error getting image performer names: %s", imageHash, err.Error())
continue
@@ -561,7 +563,7 @@ func exportImage(wg *sync.WaitGroup, jobChan <-chan *models.Image, repo models.R
newImageJSON.Performers = performer.GetNames(performers)
- tags, err := tagReader.FindByImageID(s.ID)
+ tags, err := tagReader.FindByImageID(ctx, s.ID)
if err != nil {
logger.Errorf("[images] <%s> error getting image tag names: %s", imageHash, err.Error())
continue
@@ -597,18 +599,18 @@ func (t *ExportTask) getGalleryChecksums(galleries []*models.Gallery) (ret []str
return
}
-func (t *ExportTask) ExportGalleries(workers int, repo models.ReaderRepository) {
+func (t *ExportTask) ExportGalleries(ctx context.Context, workers int, repo models.Repository) {
var galleriesWg sync.WaitGroup
- reader := repo.Gallery()
+ reader := repo.Gallery
var galleries []*models.Gallery
var err error
all := t.full || (t.galleries != nil && t.galleries.all)
if all {
- galleries, err = reader.All()
+ galleries, err = reader.All(ctx)
} else if t.galleries != nil && len(t.galleries.IDs) > 0 {
- galleries, err = reader.FindMany(t.galleries.IDs)
+ galleries, err = reader.FindMany(ctx, t.galleries.IDs)
}
if err != nil {
@@ -622,7 +624,7 @@ func (t *ExportTask) ExportGalleries(workers int, repo models.ReaderRepository)
for w := 0; w < workers; w++ { // create export Scene workers
galleriesWg.Add(1)
- go exportGallery(&galleriesWg, jobCh, repo, t)
+ go exportGallery(ctx, &galleriesWg, jobCh, repo, t)
}
for i, gallery := range galleries {
@@ -646,11 +648,11 @@ func (t *ExportTask) ExportGalleries(workers int, repo models.ReaderRepository)
logger.Infof("[galleries] export complete in %s. %d workers used.", time.Since(startTime), workers)
}
-func exportGallery(wg *sync.WaitGroup, jobChan <-chan *models.Gallery, repo models.ReaderRepository, t *ExportTask) {
+func exportGallery(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Gallery, repo models.Repository, t *ExportTask) {
defer wg.Done()
- studioReader := repo.Studio()
- performerReader := repo.Performer()
- tagReader := repo.Tag()
+ studioReader := repo.Studio
+ performerReader := repo.Performer
+ tagReader := repo.Tag
for g := range jobChan {
galleryHash := g.Checksum
@@ -661,13 +663,13 @@ func exportGallery(wg *sync.WaitGroup, jobChan <-chan *models.Gallery, repo mode
continue
}
- newGalleryJSON.Studio, err = gallery.GetStudioName(studioReader, g)
+ newGalleryJSON.Studio, err = gallery.GetStudioName(ctx, studioReader, g)
if err != nil {
logger.Errorf("[galleries] <%s> error getting gallery studio name: %s", galleryHash, err.Error())
continue
}
- performers, err := performerReader.FindByGalleryID(g.ID)
+ performers, err := performerReader.FindByGalleryID(ctx, g.ID)
if err != nil {
logger.Errorf("[galleries] <%s> error getting gallery performer names: %s", galleryHash, err.Error())
continue
@@ -675,7 +677,7 @@ func exportGallery(wg *sync.WaitGroup, jobChan <-chan *models.Gallery, repo mode
newGalleryJSON.Performers = performer.GetNames(performers)
- tags, err := tagReader.FindByGalleryID(g.ID)
+ tags, err := tagReader.FindByGalleryID(ctx, g.ID)
if err != nil {
logger.Errorf("[galleries] <%s> error getting gallery tag names: %s", galleryHash, err.Error())
continue
@@ -703,17 +705,17 @@ func exportGallery(wg *sync.WaitGroup, jobChan <-chan *models.Gallery, repo mode
}
}
-func (t *ExportTask) ExportPerformers(workers int, repo models.ReaderRepository) {
+func (t *ExportTask) ExportPerformers(ctx context.Context, workers int, repo models.Repository) {
var performersWg sync.WaitGroup
- reader := repo.Performer()
+ reader := repo.Performer
var performers []*models.Performer
var err error
all := t.full || (t.performers != nil && t.performers.all)
if all {
- performers, err = reader.All()
+ performers, err = reader.All(ctx)
} else if t.performers != nil && len(t.performers.IDs) > 0 {
- performers, err = reader.FindMany(t.performers.IDs)
+ performers, err = reader.FindMany(ctx, t.performers.IDs)
}
if err != nil {
@@ -726,7 +728,7 @@ func (t *ExportTask) ExportPerformers(workers int, repo models.ReaderRepository)
for w := 0; w < workers; w++ { // create export Performer workers
performersWg.Add(1)
- go t.exportPerformer(&performersWg, jobCh, repo)
+ go t.exportPerformer(ctx, &performersWg, jobCh, repo)
}
for i, performer := range performers {
@@ -743,20 +745,20 @@ func (t *ExportTask) ExportPerformers(workers int, repo models.ReaderRepository)
logger.Infof("[performers] export complete in %s. %d workers used.", time.Since(startTime), workers)
}
-func (t *ExportTask) exportPerformer(wg *sync.WaitGroup, jobChan <-chan *models.Performer, repo models.ReaderRepository) {
+func (t *ExportTask) exportPerformer(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Performer, repo models.Repository) {
defer wg.Done()
- performerReader := repo.Performer()
+ performerReader := repo.Performer
for p := range jobChan {
- newPerformerJSON, err := performer.ToJSON(performerReader, p)
+ newPerformerJSON, err := performer.ToJSON(ctx, performerReader, p)
if err != nil {
logger.Errorf("[performers] <%s> error getting performer JSON: %s", p.Checksum, err.Error())
continue
}
- tags, err := repo.Tag().FindByPerformerID(p.ID)
+ tags, err := repo.Tag.FindByPerformerID(ctx, p.ID)
if err != nil {
logger.Errorf("[performers] <%s> error getting performer tags: %s", p.Checksum, err.Error())
continue
@@ -781,17 +783,17 @@ func (t *ExportTask) exportPerformer(wg *sync.WaitGroup, jobChan <-chan *models.
}
}
-func (t *ExportTask) ExportStudios(workers int, repo models.ReaderRepository) {
+func (t *ExportTask) ExportStudios(ctx context.Context, workers int, repo models.Repository) {
var studiosWg sync.WaitGroup
- reader := repo.Studio()
+ reader := repo.Studio
var studios []*models.Studio
var err error
all := t.full || (t.studios != nil && t.studios.all)
if all {
- studios, err = reader.All()
+ studios, err = reader.All(ctx)
} else if t.studios != nil && len(t.studios.IDs) > 0 {
- studios, err = reader.FindMany(t.studios.IDs)
+ studios, err = reader.FindMany(ctx, t.studios.IDs)
}
if err != nil {
@@ -805,7 +807,7 @@ func (t *ExportTask) ExportStudios(workers int, repo models.ReaderRepository) {
for w := 0; w < workers; w++ { // create export Studio workers
studiosWg.Add(1)
- go t.exportStudio(&studiosWg, jobCh, repo)
+ go t.exportStudio(ctx, &studiosWg, jobCh, repo)
}
for i, studio := range studios {
@@ -822,13 +824,13 @@ func (t *ExportTask) ExportStudios(workers int, repo models.ReaderRepository) {
logger.Infof("[studios] export complete in %s. %d workers used.", time.Since(startTime), workers)
}
-func (t *ExportTask) exportStudio(wg *sync.WaitGroup, jobChan <-chan *models.Studio, repo models.ReaderRepository) {
+func (t *ExportTask) exportStudio(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Studio, repo models.Repository) {
defer wg.Done()
- studioReader := repo.Studio()
+ studioReader := repo.Studio
for s := range jobChan {
- newStudioJSON, err := studio.ToJSON(studioReader, s)
+ newStudioJSON, err := studio.ToJSON(ctx, studioReader, s)
if err != nil {
logger.Errorf("[studios] <%s> error getting studio JSON: %s", s.Checksum, err.Error())
@@ -846,17 +848,17 @@ func (t *ExportTask) exportStudio(wg *sync.WaitGroup, jobChan <-chan *models.Stu
}
}
-func (t *ExportTask) ExportTags(workers int, repo models.ReaderRepository) {
+func (t *ExportTask) ExportTags(ctx context.Context, workers int, repo models.Repository) {
var tagsWg sync.WaitGroup
- reader := repo.Tag()
+ reader := repo.Tag
var tags []*models.Tag
var err error
all := t.full || (t.tags != nil && t.tags.all)
if all {
- tags, err = reader.All()
+ tags, err = reader.All(ctx)
} else if t.tags != nil && len(t.tags.IDs) > 0 {
- tags, err = reader.FindMany(t.tags.IDs)
+ tags, err = reader.FindMany(ctx, t.tags.IDs)
}
if err != nil {
@@ -870,7 +872,7 @@ func (t *ExportTask) ExportTags(workers int, repo models.ReaderRepository) {
for w := 0; w < workers; w++ { // create export Tag workers
tagsWg.Add(1)
- go t.exportTag(&tagsWg, jobCh, repo)
+ go t.exportTag(ctx, &tagsWg, jobCh, repo)
}
for i, tag := range tags {
@@ -890,13 +892,13 @@ func (t *ExportTask) ExportTags(workers int, repo models.ReaderRepository) {
logger.Infof("[tags] export complete in %s. %d workers used.", time.Since(startTime), workers)
}
-func (t *ExportTask) exportTag(wg *sync.WaitGroup, jobChan <-chan *models.Tag, repo models.ReaderRepository) {
+func (t *ExportTask) exportTag(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Tag, repo models.Repository) {
defer wg.Done()
- tagReader := repo.Tag()
+ tagReader := repo.Tag
for thisTag := range jobChan {
- newTagJSON, err := tag.ToJSON(tagReader, thisTag)
+ newTagJSON, err := tag.ToJSON(ctx, tagReader, thisTag)
if err != nil {
logger.Errorf("[tags] <%s> error getting tag JSON: %s", thisTag.Name, err.Error())
@@ -917,17 +919,17 @@ func (t *ExportTask) exportTag(wg *sync.WaitGroup, jobChan <-chan *models.Tag, r
}
}
-func (t *ExportTask) ExportMovies(workers int, repo models.ReaderRepository) {
+func (t *ExportTask) ExportMovies(ctx context.Context, workers int, repo models.Repository) {
var moviesWg sync.WaitGroup
- reader := repo.Movie()
+ reader := repo.Movie
var movies []*models.Movie
var err error
all := t.full || (t.movies != nil && t.movies.all)
if all {
- movies, err = reader.All()
+ movies, err = reader.All(ctx)
} else if t.movies != nil && len(t.movies.IDs) > 0 {
- movies, err = reader.FindMany(t.movies.IDs)
+ movies, err = reader.FindMany(ctx, t.movies.IDs)
}
if err != nil {
@@ -941,7 +943,7 @@ func (t *ExportTask) ExportMovies(workers int, repo models.ReaderRepository) {
for w := 0; w < workers; w++ { // create export Studio workers
moviesWg.Add(1)
- go t.exportMovie(&moviesWg, jobCh, repo)
+ go t.exportMovie(ctx, &moviesWg, jobCh, repo)
}
for i, movie := range movies {
@@ -958,14 +960,14 @@ func (t *ExportTask) ExportMovies(workers int, repo models.ReaderRepository) {
logger.Infof("[movies] export complete in %s. %d workers used.", time.Since(startTime), workers)
}
-func (t *ExportTask) exportMovie(wg *sync.WaitGroup, jobChan <-chan *models.Movie, repo models.ReaderRepository) {
+func (t *ExportTask) exportMovie(ctx context.Context, wg *sync.WaitGroup, jobChan <-chan *models.Movie, repo models.Repository) {
defer wg.Done()
- movieReader := repo.Movie()
- studioReader := repo.Studio()
+ movieReader := repo.Movie
+ studioReader := repo.Studio
for m := range jobChan {
- newMovieJSON, err := movie.ToJSON(movieReader, studioReader, m)
+ newMovieJSON, err := movie.ToJSON(ctx, movieReader, studioReader, m)
if err != nil {
logger.Errorf("[movies] <%s> error getting tag JSON: %s", m.Checksum, err.Error())
@@ -991,10 +993,10 @@ func (t *ExportTask) exportMovie(wg *sync.WaitGroup, jobChan <-chan *models.Movi
}
}
-func (t *ExportTask) ExportScrapedItems(repo models.ReaderRepository) {
- qb := repo.ScrapedItem()
- sqb := repo.Studio()
- scrapedItems, err := qb.All()
+func (t *ExportTask) ExportScrapedItems(ctx context.Context, repo models.Repository) {
+ qb := repo.ScrapedItem
+ sqb := repo.Studio
+ scrapedItems, err := qb.All(ctx)
if err != nil {
logger.Errorf("[scraped sites] failed to fetch all items: %s", err.Error())
}
@@ -1009,7 +1011,7 @@ func (t *ExportTask) ExportScrapedItems(repo models.ReaderRepository) {
var studioName string
if scrapedItem.StudioID.Valid {
- studio, _ := sqb.Find(int(scrapedItem.StudioID.Int64))
+ studio, _ := sqb.Find(ctx, int(scrapedItem.StudioID.Int64))
if studio != nil {
studioName = studio.Name.String
}
diff --git a/internal/manager/task_generate.go b/internal/manager/task_generate.go
index bc385ab22..a8f71f7c6 100644
--- a/internal/manager/task_generate.go
+++ b/internal/manager/task_generate.go
@@ -54,7 +54,7 @@ type GeneratePreviewOptionsInput struct {
const generateQueueSize = 200000
type GenerateJob struct {
- txnManager models.TransactionManager
+ txnManager models.Repository
input GenerateMetadataInput
overwrite bool
@@ -110,20 +110,20 @@ func (j *GenerateJob) Execute(ctx context.Context, progress *job.Progress) {
Overwrite: j.overwrite,
}
- if err := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
- qb := r.Scene()
+ if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
+ qb := j.txnManager.Scene
if len(j.input.SceneIDs) == 0 && len(j.input.MarkerIDs) == 0 {
totals = j.queueTasks(ctx, g, queue)
} else {
if len(j.input.SceneIDs) > 0 {
- scenes, err = qb.FindMany(sceneIDs)
+ scenes, err = qb.FindMany(ctx, sceneIDs)
for _, s := range scenes {
j.queueSceneJobs(ctx, g, s, queue, &totals)
}
}
if len(j.input.MarkerIDs) > 0 {
- markers, err = r.SceneMarker().FindMany(markerIDs)
+ markers, err = j.txnManager.SceneMarker.FindMany(ctx, markerIDs)
if err != nil {
return err
}
@@ -192,13 +192,13 @@ func (j *GenerateJob) queueTasks(ctx context.Context, g *generate.Generator, que
findFilter := models.BatchFindFilter(batchSize)
- if err := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
+ if err := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
for more := true; more; {
if job.IsCancelled(ctx) {
return context.Canceled
}
- scenes, err := scene.Query(r.Scene(), nil, findFilter)
+ scenes, err := scene.Query(ctx, j.txnManager.Scene, nil, findFilter)
if err != nil {
return err
}
diff --git a/internal/manager/task_generate_interactive_heatmap_speed.go b/internal/manager/task_generate_interactive_heatmap_speed.go
index f6ca0a04e..f9a1e8360 100644
--- a/internal/manager/task_generate_interactive_heatmap_speed.go
+++ b/internal/manager/task_generate_interactive_heatmap_speed.go
@@ -15,7 +15,7 @@ type GenerateInteractiveHeatmapSpeedTask struct {
Scene models.Scene
Overwrite bool
fileNamingAlgorithm models.HashAlgorithm
- TxnManager models.TransactionManager
+ TxnManager models.Repository
}
func (t *GenerateInteractiveHeatmapSpeedTask) GetDescription() string {
@@ -47,22 +47,22 @@ func (t *GenerateInteractiveHeatmapSpeedTask) Start(ctx context.Context) {
var s *models.Scene
- if err := t.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
+ if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
var err error
- s, err = r.Scene().FindByPath(t.Scene.Path)
+ s, err = t.TxnManager.Scene.FindByPath(ctx, t.Scene.Path)
return err
}); err != nil {
logger.Error(err.Error())
return
}
- if err := t.TxnManager.WithTxn(ctx, func(r models.Repository) error {
- qb := r.Scene()
+ if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
+ qb := t.TxnManager.Scene
scenePartial := models.ScenePartial{
ID: s.ID,
InteractiveSpeed: &median,
}
- _, err := qb.Update(scenePartial)
+ _, err := qb.Update(ctx, scenePartial)
return err
}); err != nil {
logger.Error(err.Error())
diff --git a/internal/manager/task_generate_markers.go b/internal/manager/task_generate_markers.go
index 3ef53ddd0..59ddefe63 100644
--- a/internal/manager/task_generate_markers.go
+++ b/internal/manager/task_generate_markers.go
@@ -13,7 +13,7 @@ import (
)
type GenerateMarkersTask struct {
- TxnManager models.TransactionManager
+ TxnManager models.Repository
Scene *models.Scene
Marker *models.SceneMarker
Overwrite bool
@@ -42,9 +42,9 @@ func (t *GenerateMarkersTask) Start(ctx context.Context) {
if t.Marker != nil {
var scene *models.Scene
- if err := t.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
+ if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
var err error
- scene, err = r.Scene().Find(int(t.Marker.SceneID.Int64))
+ scene, err = t.TxnManager.Scene.Find(ctx, int(t.Marker.SceneID.Int64))
return err
}); err != nil {
logger.Errorf("error finding scene for marker: %s", err.Error())
@@ -69,9 +69,9 @@ func (t *GenerateMarkersTask) Start(ctx context.Context) {
func (t *GenerateMarkersTask) generateSceneMarkers(ctx context.Context) {
var sceneMarkers []*models.SceneMarker
- if err := t.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
+ if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
var err error
- sceneMarkers, err = r.SceneMarker().FindBySceneID(t.Scene.ID)
+ sceneMarkers, err = t.TxnManager.SceneMarker.FindBySceneID(ctx, t.Scene.ID)
return err
}); err != nil {
logger.Errorf("error getting scene markers: %s", err.Error())
@@ -134,9 +134,9 @@ func (t *GenerateMarkersTask) generateMarker(videoFile *ffmpeg.VideoFile, scene
func (t *GenerateMarkersTask) markersNeeded(ctx context.Context) int {
markers := 0
var sceneMarkers []*models.SceneMarker
- if err := t.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
+ if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
var err error
- sceneMarkers, err = r.SceneMarker().FindBySceneID(t.Scene.ID)
+ sceneMarkers, err = t.TxnManager.SceneMarker.FindBySceneID(ctx, t.Scene.ID)
return err
}); err != nil {
logger.Errorf("errror finding scene markers: %s", err.Error())
diff --git a/internal/manager/task_generate_phash.go b/internal/manager/task_generate_phash.go
index 880bb7794..b4350cc8b 100644
--- a/internal/manager/task_generate_phash.go
+++ b/internal/manager/task_generate_phash.go
@@ -14,7 +14,7 @@ type GeneratePhashTask struct {
Scene models.Scene
Overwrite bool
fileNamingAlgorithm models.HashAlgorithm
- txnManager models.TransactionManager
+ txnManager models.Repository
}
func (t *GeneratePhashTask) GetDescription() string {
@@ -40,14 +40,14 @@ func (t *GeneratePhashTask) Start(ctx context.Context) {
return
}
- if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
- qb := r.Scene()
+ if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
+ qb := t.txnManager.Scene
hashValue := sql.NullInt64{Int64: int64(*hash), Valid: true}
scenePartial := models.ScenePartial{
ID: t.Scene.ID,
Phash: &hashValue,
}
- _, err := qb.Update(scenePartial)
+ _, err := qb.Update(ctx, scenePartial)
return err
}); err != nil {
logger.Error(err.Error())
diff --git a/internal/manager/task_generate_screenshot.go b/internal/manager/task_generate_screenshot.go
index 80ef9e40d..9b941de8e 100644
--- a/internal/manager/task_generate_screenshot.go
+++ b/internal/manager/task_generate_screenshot.go
@@ -17,7 +17,7 @@ type GenerateScreenshotTask struct {
Scene models.Scene
ScreenshotAt *float64
fileNamingAlgorithm models.HashAlgorithm
- txnManager models.TransactionManager
+ txnManager models.Repository
}
func (t *GenerateScreenshotTask) Start(ctx context.Context) {
@@ -74,8 +74,8 @@ func (t *GenerateScreenshotTask) Start(ctx context.Context) {
return
}
- if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
- qb := r.Scene()
+ if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
+ qb := t.txnManager.Scene
updatedTime := time.Now()
updatedScene := models.ScenePartial{
ID: t.Scene.ID,
@@ -87,12 +87,12 @@ func (t *GenerateScreenshotTask) Start(ctx context.Context) {
}
// update the scene cover table
- if err := qb.UpdateCover(t.Scene.ID, coverImageData); err != nil {
+ if err := qb.UpdateCover(ctx, t.Scene.ID, coverImageData); err != nil {
return fmt.Errorf("error setting screenshot: %v", err)
}
// update the scene with the update date
- _, err = qb.Update(updatedScene)
+ _, err = qb.Update(ctx, updatedScene)
if err != nil {
return fmt.Errorf("error updating scene: %v", err)
}
diff --git a/internal/manager/task_identify.go b/internal/manager/task_identify.go
index 457c59dd1..beec6fca9 100644
--- a/internal/manager/task_identify.go
+++ b/internal/manager/task_identify.go
@@ -14,12 +14,12 @@ import (
"github.com/stashapp/stash/pkg/scraper"
"github.com/stashapp/stash/pkg/scraper/stashbox"
"github.com/stashapp/stash/pkg/sliceutil/stringslice"
+ "github.com/stashapp/stash/pkg/txn"
)
var ErrInput = errors.New("invalid request input")
type IdentifyJob struct {
- txnManager models.TransactionManager
postHookExecutor identify.SceneUpdatePostHookExecutor
input identify.Options
@@ -29,7 +29,6 @@ type IdentifyJob struct {
func CreateIdentifyJob(input identify.Options) *IdentifyJob {
return &IdentifyJob{
- txnManager: instance.TxnManager,
postHookExecutor: instance.PluginCache,
input: input,
stashBoxes: instance.Config.GetStashBoxes(),
@@ -52,9 +51,9 @@ func (j *IdentifyJob) Execute(ctx context.Context, progress *job.Progress) {
// if scene ids provided, use those
// otherwise, batch query for all scenes - ordering by path
- if err := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
+ if err := txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error {
if len(j.input.SceneIDs) == 0 {
- return j.identifyAllScenes(ctx, r, sources)
+ return j.identifyAllScenes(ctx, sources)
}
sceneIDs, err := stringslice.StringSliceToIntSlice(j.input.SceneIDs)
@@ -70,7 +69,7 @@ func (j *IdentifyJob) Execute(ctx context.Context, progress *job.Progress) {
// find the scene
var err error
- scene, err := r.Scene().Find(id)
+ scene, err := instance.Repository.Scene.Find(ctx, id)
if err != nil {
return fmt.Errorf("error finding scene with id %d: %w", id, err)
}
@@ -88,7 +87,7 @@ func (j *IdentifyJob) Execute(ctx context.Context, progress *job.Progress) {
}
}
-func (j *IdentifyJob) identifyAllScenes(ctx context.Context, r models.ReaderRepository, sources []identify.ScraperSource) error {
+func (j *IdentifyJob) identifyAllScenes(ctx context.Context, sources []identify.ScraperSource) error {
// exclude organised
organised := false
sceneFilter := scene.FilterFromPaths(j.input.Paths)
@@ -102,7 +101,7 @@ func (j *IdentifyJob) identifyAllScenes(ctx context.Context, r models.ReaderRepo
// get the count
pp := 0
findFilter.PerPage = &pp
- countResult, err := r.Scene().Query(models.SceneQueryOptions{
+ countResult, err := instance.Repository.Scene.Query(ctx, models.SceneQueryOptions{
QueryOptions: models.QueryOptions{
FindFilter: findFilter,
Count: true,
@@ -115,7 +114,7 @@ func (j *IdentifyJob) identifyAllScenes(ctx context.Context, r models.ReaderRepo
j.progress.SetTotal(countResult.Count)
- return scene.BatchProcess(ctx, r.Scene(), sceneFilter, findFilter, func(scene *models.Scene) error {
+ return scene.BatchProcess(ctx, instance.Repository.Scene, sceneFilter, findFilter, func(scene *models.Scene) error {
if job.IsCancelled(ctx) {
return nil
}
@@ -133,6 +132,11 @@ func (j *IdentifyJob) identifyScene(ctx context.Context, s *models.Scene, source
var taskError error
j.progress.ExecuteTask("Identifying "+s.Path, func() {
task := identify.SceneIdentifier{
+ SceneReaderUpdater: instance.Repository.Scene,
+ StudioCreator: instance.Repository.Studio,
+ PerformerCreator: instance.Repository.Performer,
+ TagCreator: instance.Repository.Tag,
+
DefaultOptions: j.input.Options,
Sources: sources,
ScreenshotSetter: &scene.PathsScreenshotSetter{
@@ -142,7 +146,7 @@ func (j *IdentifyJob) identifyScene(ctx context.Context, s *models.Scene, source
SceneUpdatePostHookExecutor: j.postHookExecutor,
}
- taskError = task.Identify(ctx, j.txnManager, s)
+ taskError = task.Identify(ctx, instance.Repository, s)
})
if taskError != nil {
@@ -166,7 +170,12 @@ func (j *IdentifyJob) getSources() ([]identify.ScraperSource, error) {
src = identify.ScraperSource{
Name: "stash-box: " + stashBox.Endpoint,
Scraper: stashboxSource{
- stashbox.NewClient(*stashBox, j.txnManager),
+ stashbox.NewClient(*stashBox, instance.Repository, stashbox.Repository{
+ Scene: instance.Repository.Scene,
+ Performer: instance.Repository.Performer,
+ Tag: instance.Repository.Tag,
+ Studio: instance.Repository.Studio,
+ }),
stashBox.Endpoint,
},
RemoteSite: stashBox.Endpoint,
diff --git a/internal/manager/task_import.go b/internal/manager/task_import.go
index 8175bfc59..dccc98354 100644
--- a/internal/manager/task_import.go
+++ b/internal/manager/task_import.go
@@ -12,8 +12,6 @@ import (
"time"
"github.com/99designs/gqlgen/graphql"
- "github.com/stashapp/stash/internal/manager/config"
- "github.com/stashapp/stash/pkg/database"
"github.com/stashapp/stash/pkg/fsutil"
"github.com/stashapp/stash/pkg/gallery"
"github.com/stashapp/stash/pkg/image"
@@ -30,7 +28,7 @@ import (
)
type ImportTask struct {
- txnManager models.TransactionManager
+ txnManager models.Repository
json jsonUtils
BaseDir string
@@ -73,7 +71,7 @@ func CreateImportTask(a models.HashAlgorithm, input ImportObjectsInput) (*Import
}
return &ImportTask{
- txnManager: GetInstance().TxnManager,
+ txnManager: GetInstance().Repository,
BaseDir: baseDir,
TmpZip: tmpZip,
Reset: false,
@@ -126,7 +124,7 @@ func (t *ImportTask) Start(ctx context.Context) {
t.scraped = scraped
if t.Reset {
- err := database.Reset(config.GetInstance().GetDatabasePath())
+ err := t.txnManager.Reset()
if err != nil {
logger.Errorf("Error resetting database: %s", err.Error())
@@ -211,15 +209,16 @@ func (t *ImportTask) ImportPerformers(ctx context.Context) {
logger.Progressf("[performers] %d of %d", index, len(t.mappings.Performers))
- if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
- readerWriter := r.Performer()
+ if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
+ r := t.txnManager
+ readerWriter := r.Performer
importer := &performer.Importer{
ReaderWriter: readerWriter,
- TagWriter: r.Tag(),
+ TagWriter: r.Tag,
Input: *performerJSON,
}
- return performImport(importer, t.DuplicateBehaviour)
+ return performImport(ctx, importer, t.DuplicateBehaviour)
}); err != nil {
logger.Errorf("[performers] <%s> import failed: %s", mappingJSON.Checksum, err.Error())
}
@@ -243,8 +242,8 @@ func (t *ImportTask) ImportStudios(ctx context.Context) {
logger.Progressf("[studios] %d of %d", index, len(t.mappings.Studios))
- if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
- return t.ImportStudio(studioJSON, pendingParent, r.Studio())
+ if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
+ return t.ImportStudio(ctx, studioJSON, pendingParent, t.txnManager.Studio)
}); err != nil {
if errors.Is(err, studio.ErrParentStudioNotExist) {
// add to the pending parent list so that it is created after the parent
@@ -265,8 +264,8 @@ func (t *ImportTask) ImportStudios(ctx context.Context) {
for _, s := range pendingParent {
for _, orphanStudioJSON := range s {
- if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
- return t.ImportStudio(orphanStudioJSON, nil, r.Studio())
+ if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
+ return t.ImportStudio(ctx, orphanStudioJSON, nil, t.txnManager.Studio)
}); err != nil {
logger.Errorf("[studios] <%s> failed to create: %s", orphanStudioJSON.Name, err.Error())
continue
@@ -278,7 +277,7 @@ func (t *ImportTask) ImportStudios(ctx context.Context) {
logger.Info("[studios] import complete")
}
-func (t *ImportTask) ImportStudio(studioJSON *jsonschema.Studio, pendingParent map[string][]*jsonschema.Studio, readerWriter models.StudioReaderWriter) error {
+func (t *ImportTask) ImportStudio(ctx context.Context, studioJSON *jsonschema.Studio, pendingParent map[string][]*jsonschema.Studio, readerWriter studio.NameFinderCreatorUpdater) error {
importer := &studio.Importer{
ReaderWriter: readerWriter,
Input: *studioJSON,
@@ -290,7 +289,7 @@ func (t *ImportTask) ImportStudio(studioJSON *jsonschema.Studio, pendingParent m
importer.MissingRefBehaviour = models.ImportMissingRefEnumFail
}
- if err := performImport(importer, t.DuplicateBehaviour); err != nil {
+ if err := performImport(ctx, importer, t.DuplicateBehaviour); err != nil {
return err
}
@@ -298,7 +297,7 @@ func (t *ImportTask) ImportStudio(studioJSON *jsonschema.Studio, pendingParent m
s := pendingParent[studioJSON.Name]
for _, childStudioJSON := range s {
// map is nil since we're not checking parent studios at this point
- if err := t.ImportStudio(childStudioJSON, nil, readerWriter); err != nil {
+ if err := t.ImportStudio(ctx, childStudioJSON, nil, readerWriter); err != nil {
return fmt.Errorf("failed to create child studio <%s>: %s", childStudioJSON.Name, err.Error())
}
}
@@ -322,9 +321,10 @@ func (t *ImportTask) ImportMovies(ctx context.Context) {
logger.Progressf("[movies] %d of %d", index, len(t.mappings.Movies))
- if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
- readerWriter := r.Movie()
- studioReaderWriter := r.Studio()
+ if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
+ r := t.txnManager
+ readerWriter := r.Movie
+ studioReaderWriter := r.Studio
movieImporter := &movie.Importer{
ReaderWriter: readerWriter,
@@ -333,7 +333,7 @@ func (t *ImportTask) ImportMovies(ctx context.Context) {
MissingRefBehaviour: t.MissingRefBehaviour,
}
- return performImport(movieImporter, t.DuplicateBehaviour)
+ return performImport(ctx, movieImporter, t.DuplicateBehaviour)
}); err != nil {
logger.Errorf("[movies] <%s> import failed: %s", mappingJSON.Checksum, err.Error())
continue
@@ -356,11 +356,12 @@ func (t *ImportTask) ImportGalleries(ctx context.Context) {
logger.Progressf("[galleries] %d of %d", index, len(t.mappings.Galleries))
- if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
- readerWriter := r.Gallery()
- tagWriter := r.Tag()
- performerWriter := r.Performer()
- studioWriter := r.Studio()
+ if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
+ r := t.txnManager
+ readerWriter := r.Gallery
+ tagWriter := r.Tag
+ performerWriter := r.Performer
+ studioWriter := r.Studio
galleryImporter := &gallery.Importer{
ReaderWriter: readerWriter,
@@ -371,7 +372,7 @@ func (t *ImportTask) ImportGalleries(ctx context.Context) {
MissingRefBehaviour: t.MissingRefBehaviour,
}
- return performImport(galleryImporter, t.DuplicateBehaviour)
+ return performImport(ctx, galleryImporter, t.DuplicateBehaviour)
}); err != nil {
logger.Errorf("[galleries] <%s> import failed to commit: %s", mappingJSON.Checksum, err.Error())
continue
@@ -395,8 +396,8 @@ func (t *ImportTask) ImportTags(ctx context.Context) {
logger.Progressf("[tags] %d of %d", index, len(t.mappings.Tags))
- if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
- return t.ImportTag(tagJSON, pendingParent, false, r.Tag())
+ if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
+ return t.ImportTag(ctx, tagJSON, pendingParent, false, t.txnManager.Tag)
}); err != nil {
var parentError tag.ParentTagNotExistError
if errors.As(err, &parentError) {
@@ -411,8 +412,8 @@ func (t *ImportTask) ImportTags(ctx context.Context) {
for _, s := range pendingParent {
for _, orphanTagJSON := range s {
- if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
- return t.ImportTag(orphanTagJSON, nil, true, r.Tag())
+ if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
+ return t.ImportTag(ctx, orphanTagJSON, nil, true, t.txnManager.Tag)
}); err != nil {
logger.Errorf("[tags] <%s> failed to create: %s", orphanTagJSON.Name, err.Error())
continue
@@ -423,7 +424,7 @@ func (t *ImportTask) ImportTags(ctx context.Context) {
logger.Info("[tags] import complete")
}
-func (t *ImportTask) ImportTag(tagJSON *jsonschema.Tag, pendingParent map[string][]*jsonschema.Tag, fail bool, readerWriter models.TagReaderWriter) error {
+func (t *ImportTask) ImportTag(ctx context.Context, tagJSON *jsonschema.Tag, pendingParent map[string][]*jsonschema.Tag, fail bool, readerWriter tag.NameFinderCreatorUpdater) error {
importer := &tag.Importer{
ReaderWriter: readerWriter,
Input: *tagJSON,
@@ -435,12 +436,12 @@ func (t *ImportTask) ImportTag(tagJSON *jsonschema.Tag, pendingParent map[string
importer.MissingRefBehaviour = models.ImportMissingRefEnumFail
}
- if err := performImport(importer, t.DuplicateBehaviour); err != nil {
+ if err := performImport(ctx, importer, t.DuplicateBehaviour); err != nil {
return err
}
for _, childTagJSON := range pendingParent[tagJSON.Name] {
- if err := t.ImportTag(childTagJSON, pendingParent, fail, readerWriter); err != nil {
+ if err := t.ImportTag(ctx, childTagJSON, pendingParent, fail, readerWriter); err != nil {
var parentError tag.ParentTagNotExistError
if errors.As(err, &parentError) {
pendingParent[parentError.MissingParent()] = append(pendingParent[parentError.MissingParent()], tagJSON)
@@ -457,10 +458,11 @@ func (t *ImportTask) ImportTag(tagJSON *jsonschema.Tag, pendingParent map[string
}
func (t *ImportTask) ImportScrapedItems(ctx context.Context) {
- if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
+ if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
logger.Info("[scraped sites] importing")
- qb := r.ScrapedItem()
- sqb := r.Studio()
+ r := t.txnManager
+ qb := r.ScrapedItem
+ sqb := r.Studio
currentTime := time.Now()
for i, mappingJSON := range t.scraped {
@@ -484,7 +486,7 @@ func (t *ImportTask) ImportScrapedItems(ctx context.Context) {
UpdatedAt: models.SQLiteTimestamp{Timestamp: t.getTimeFromJSONTime(mappingJSON.UpdatedAt)},
}
- studio, err := sqb.FindByName(mappingJSON.Studio, false)
+ studio, err := sqb.FindByName(ctx, mappingJSON.Studio, false)
if err != nil {
logger.Errorf("[scraped sites] failed to fetch studio: %s", err.Error())
}
@@ -492,7 +494,7 @@ func (t *ImportTask) ImportScrapedItems(ctx context.Context) {
newScrapedItem.StudioID = sql.NullInt64{Int64: int64(studio.ID), Valid: true}
}
- _, err = qb.Create(newScrapedItem)
+ _, err = qb.Create(ctx, newScrapedItem)
if err != nil {
logger.Errorf("[scraped sites] <%s> failed to create: %s", newScrapedItem.Title.String, err.Error())
}
@@ -522,14 +524,15 @@ func (t *ImportTask) ImportScenes(ctx context.Context) {
sceneHash := mappingJSON.Checksum
- if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
- readerWriter := r.Scene()
- tagWriter := r.Tag()
- galleryWriter := r.Gallery()
- movieWriter := r.Movie()
- performerWriter := r.Performer()
- studioWriter := r.Studio()
- markerWriter := r.SceneMarker()
+ if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
+ r := t.txnManager
+ readerWriter := r.Scene
+ tagWriter := r.Tag
+ galleryWriter := r.Gallery
+ movieWriter := r.Movie
+ performerWriter := r.Performer
+ studioWriter := r.Studio
+ markerWriter := r.SceneMarker
sceneImporter := &scene.Importer{
ReaderWriter: readerWriter,
@@ -546,7 +549,7 @@ func (t *ImportTask) ImportScenes(ctx context.Context) {
TagWriter: tagWriter,
}
- if err := performImport(sceneImporter, t.DuplicateBehaviour); err != nil {
+ if err := performImport(ctx, sceneImporter, t.DuplicateBehaviour); err != nil {
return err
}
@@ -560,7 +563,7 @@ func (t *ImportTask) ImportScenes(ctx context.Context) {
TagWriter: tagWriter,
}
- if err := performImport(markerImporter, t.DuplicateBehaviour); err != nil {
+ if err := performImport(ctx, markerImporter, t.DuplicateBehaviour); err != nil {
return err
}
}
@@ -590,12 +593,13 @@ func (t *ImportTask) ImportImages(ctx context.Context) {
imageHash := mappingJSON.Checksum
- if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
- readerWriter := r.Image()
- tagWriter := r.Tag()
- galleryWriter := r.Gallery()
- performerWriter := r.Performer()
- studioWriter := r.Studio()
+ if err := t.txnManager.WithTxn(ctx, func(ctx context.Context) error {
+ r := t.txnManager
+ readerWriter := r.Image
+ tagWriter := r.Tag
+ galleryWriter := r.Gallery
+ performerWriter := r.Performer
+ studioWriter := r.Studio
imageImporter := &image.Importer{
ReaderWriter: readerWriter,
@@ -610,7 +614,7 @@ func (t *ImportTask) ImportImages(ctx context.Context) {
TagWriter: tagWriter,
}
- return performImport(imageImporter, t.DuplicateBehaviour)
+ return performImport(ctx, imageImporter, t.DuplicateBehaviour)
}); err != nil {
logger.Errorf("[images] <%s> import failed: %s", imageHash, err.Error())
}
diff --git a/internal/manager/task_scan.go b/internal/manager/task_scan.go
index 042ff80ac..99fafceb0 100644
--- a/internal/manager/task_scan.go
+++ b/internal/manager/task_scan.go
@@ -24,7 +24,7 @@ import (
const scanQueueSize = 200000
type ScanJob struct {
- txnManager models.TransactionManager
+ txnManager models.Repository
input ScanMetadataInput
subscriptions *subscriptionManager
}
@@ -220,20 +220,21 @@ func (j *ScanJob) doesPathExist(ctx context.Context, path string) bool {
gExt := config.GetGalleryExtensions()
ret := false
- txnErr := j.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
+ txnErr := j.txnManager.WithTxn(ctx, func(ctx context.Context) error {
+ r := j.txnManager
switch {
case fsutil.MatchExtension(path, gExt):
- g, _ := r.Gallery().FindByPath(path)
+ g, _ := r.Gallery.FindByPath(ctx, path)
if g != nil {
ret = true
}
case fsutil.MatchExtension(path, vidExt):
- s, _ := r.Scene().FindByPath(path)
+ s, _ := r.Scene.FindByPath(ctx, path)
if s != nil {
ret = true
}
case fsutil.MatchExtension(path, imgExt):
- i, _ := r.Image().FindByPath(path)
+ i, _ := r.Image.FindByPath(ctx, path)
if i != nil {
ret = true
}
@@ -249,7 +250,7 @@ func (j *ScanJob) doesPathExist(ctx context.Context, path string) bool {
}
type ScanTask struct {
- TxnManager models.TransactionManager
+ TxnManager models.Repository
file file.SourceFile
UseFileMetadata bool
StripFileExtension bool
diff --git a/internal/manager/task_scan_gallery.go b/internal/manager/task_scan_gallery.go
index 8c3f5c550..2a2669e28 100644
--- a/internal/manager/task_scan_gallery.go
+++ b/internal/manager/task_scan_gallery.go
@@ -21,12 +21,12 @@ func (t *ScanTask) scanGallery(ctx context.Context) {
images := 0
scanImages := false
- if err := t.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
+ if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
var err error
- g, err = r.Gallery().FindByPath(path)
+ g, err = t.TxnManager.Gallery.FindByPath(ctx, path)
if g != nil && err == nil {
- images, err = r.Image().CountByGalleryID(g.ID)
+ images, err = t.TxnManager.Image.CountByGalleryID(ctx, g.ID)
if err != nil {
return fmt.Errorf("error getting images for zip gallery %s: %s", path, err.Error())
}
@@ -43,7 +43,7 @@ func (t *ScanTask) scanGallery(ctx context.Context) {
ImageExtensions: instance.Config.GetImageExtensions(),
StripFileExtension: t.StripFileExtension,
CaseSensitiveFs: t.CaseSensitiveFs,
- TxnManager: t.TxnManager,
+ CreatorUpdater: t.TxnManager.Gallery,
Paths: instance.Paths,
PluginCache: instance.PluginCache,
MutexManager: t.mutexManager,
@@ -79,10 +79,11 @@ func (t *ScanTask) scanGallery(ctx context.Context) {
// associates a gallery to a scene with the same basename
func (t *ScanTask) associateGallery(ctx context.Context, wg *sizedwaitgroup.SizedWaitGroup) {
path := t.file.Path()
- if err := t.TxnManager.WithTxn(ctx, func(r models.Repository) error {
- qb := r.Gallery()
- sqb := r.Scene()
- g, err := qb.FindByPath(path)
+ if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
+ r := t.TxnManager
+ qb := r.Gallery
+ sqb := r.Scene
+ g, err := qb.FindByPath(ctx, path)
if err != nil {
return err
}
@@ -106,10 +107,10 @@ func (t *ScanTask) associateGallery(ctx context.Context, wg *sizedwaitgroup.Size
}
}
for _, scenePath := range relatedFiles {
- scene, _ := sqb.FindByPath(scenePath)
+ scene, _ := sqb.FindByPath(ctx, scenePath)
// found related Scene
if scene != nil {
- sceneGalleries, _ := sqb.FindByGalleryID(g.ID) // check if gallery is already associated to the scene
+ sceneGalleries, _ := sqb.FindByGalleryID(ctx, g.ID) // check if gallery is already associated to the scene
isAssoc := false
for _, sg := range sceneGalleries {
if scene.ID == sg.ID {
@@ -119,7 +120,7 @@ func (t *ScanTask) associateGallery(ctx context.Context, wg *sizedwaitgroup.Size
}
if !isAssoc {
logger.Infof("associate: Gallery %s is related to scene: %d", path, scene.ID)
- if err := sqb.UpdateGalleries(scene.ID, []int{g.ID}); err != nil {
+ if err := sqb.UpdateGalleries(ctx, scene.ID, []int{g.ID}); err != nil {
return err
}
}
@@ -152,11 +153,11 @@ func (t *ScanTask) scanZipImages(ctx context.Context, zipGallery *models.Gallery
func (t *ScanTask) regenerateZipImages(ctx context.Context, zipGallery *models.Gallery) {
var images []*models.Image
- if err := t.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
- iqb := r.Image()
+ if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
+ iqb := t.TxnManager.Image
var err error
- images, err = iqb.FindByGalleryID(zipGallery.ID)
+ images, err = iqb.FindByGalleryID(ctx, zipGallery.ID)
return err
}); err != nil {
logger.Warnf("failed to find gallery images: %s", err.Error())
diff --git a/internal/manager/task_scan_image.go b/internal/manager/task_scan_image.go
index 36aff5a04..20bd78224 100644
--- a/internal/manager/task_scan_image.go
+++ b/internal/manager/task_scan_image.go
@@ -23,9 +23,9 @@ func (t *ScanTask) scanImage(ctx context.Context) {
var i *models.Image
path := t.file.Path()
- if err := t.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
+ if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
var err error
- i, err = r.Image().FindByPath(path)
+ i, err = t.TxnManager.Image.FindByPath(ctx, path)
return err
}); err != nil {
logger.Error(err.Error())
@@ -36,6 +36,8 @@ func (t *ScanTask) scanImage(ctx context.Context) {
Scanner: image.FileScanner(&file.FSHasher{}),
StripFileExtension: t.StripFileExtension,
TxnManager: t.TxnManager,
+ CreatorUpdater: t.TxnManager.Image,
+ CaseSensitiveFs: t.CaseSensitiveFs,
Paths: GetInstance().Paths,
PluginCache: instance.PluginCache,
MutexManager: t.mutexManager,
@@ -58,8 +60,8 @@ func (t *ScanTask) scanImage(ctx context.Context) {
if i != nil {
if t.zipGallery != nil {
// associate with gallery
- if err := t.TxnManager.WithTxn(ctx, func(r models.Repository) error {
- return gallery.AddImage(r.Gallery(), t.zipGallery.ID, i.ID)
+ if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
+ return gallery.AddImage(ctx, t.TxnManager.Gallery, t.zipGallery.ID, i.ID)
}); err != nil {
logger.Error(err.Error())
return
@@ -69,9 +71,9 @@ func (t *ScanTask) scanImage(ctx context.Context) {
logger.Infof("Associating image %s with folder gallery", i.Path)
var galleryID int
var isNewGallery bool
- if err := t.TxnManager.WithTxn(ctx, func(r models.Repository) error {
+ if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
var err error
- galleryID, isNewGallery, err = t.associateImageWithFolderGallery(i.ID, r.Gallery())
+ galleryID, isNewGallery, err = t.associateImageWithFolderGallery(ctx, i.ID, t.TxnManager.Gallery)
return err
}); err != nil {
logger.Error(err.Error())
@@ -90,11 +92,17 @@ func (t *ScanTask) scanImage(ctx context.Context) {
}
}
-func (t *ScanTask) associateImageWithFolderGallery(imageID int, qb models.GalleryReaderWriter) (galleryID int, isNew bool, err error) {
+type GalleryImageAssociator interface {
+ FindByPath(ctx context.Context, path string) (*models.Gallery, error)
+ Create(ctx context.Context, newGallery models.Gallery) (*models.Gallery, error)
+ gallery.ImageUpdater
+}
+
+func (t *ScanTask) associateImageWithFolderGallery(ctx context.Context, imageID int, qb GalleryImageAssociator) (galleryID int, isNew bool, err error) {
// find a gallery with the path specified
path := filepath.Dir(t.file.Path())
var g *models.Gallery
- g, err = qb.FindByPath(path)
+ g, err = qb.FindByPath(ctx, path)
if err != nil {
return
}
@@ -120,7 +128,7 @@ func (t *ScanTask) associateImageWithFolderGallery(imageID int, qb models.Galler
}
logger.Infof("Creating gallery for folder %s", path)
- g, err = qb.Create(newGallery)
+ g, err = qb.Create(ctx, newGallery)
if err != nil {
return 0, false, err
}
@@ -129,7 +137,7 @@ func (t *ScanTask) associateImageWithFolderGallery(imageID int, qb models.Galler
}
// associate image with gallery
- err = gallery.AddImage(qb, g.ID, imageID)
+ err = gallery.AddImage(ctx, qb, g.ID, imageID)
galleryID = g.ID
return
}
diff --git a/internal/manager/task_scan_scene.go b/internal/manager/task_scan_scene.go
index 218a2e012..295a0c7ef 100644
--- a/internal/manager/task_scan_scene.go
+++ b/internal/manager/task_scan_scene.go
@@ -34,9 +34,9 @@ func (t *ScanTask) scanScene(ctx context.Context) *models.Scene {
var retScene *models.Scene
var s *models.Scene
- if err := t.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
+ if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
var err error
- s, err = r.Scene().FindByPath(t.file.Path())
+ s, err = t.TxnManager.Scene.FindByPath(ctx, t.file.Path())
return err
}); err != nil {
logger.Error(err.Error())
@@ -54,7 +54,9 @@ func (t *ScanTask) scanScene(ctx context.Context) *models.Scene {
StripFileExtension: t.StripFileExtension,
FileNamingAlgorithm: t.fileNamingAlgorithm,
TxnManager: t.TxnManager,
+ CreatorUpdater: t.TxnManager.Scene,
Paths: GetInstance().Paths,
+ CaseSensitiveFs: t.CaseSensitiveFs,
Screenshotter: &sceneScreenshotter{
g: g,
},
@@ -88,12 +90,12 @@ func (t *ScanTask) associateCaptions(ctx context.Context) {
captionLang := scene.GetCaptionsLangFromPath(captionPath)
relatedFiles := scene.GenerateCaptionCandidates(captionPath, vExt)
- if err := t.TxnManager.WithTxn(ctx, func(r models.Repository) error {
+ if err := t.TxnManager.WithTxn(ctx, func(ctx context.Context) error {
var err error
- sqb := r.Scene()
+ sqb := t.TxnManager.Scene
for _, scenePath := range relatedFiles {
- s, er := sqb.FindByPath(scenePath)
+ s, er := sqb.FindByPath(ctx, scenePath)
if er != nil {
logger.Errorf("Error searching for scene %s: %v", scenePath, er)
@@ -101,7 +103,7 @@ func (t *ScanTask) associateCaptions(ctx context.Context) {
}
if s != nil { // found related Scene
logger.Debugf("Matched captions to scene %s", s.Path)
- captions, er := sqb.GetCaptions(s.ID)
+ captions, er := sqb.GetCaptions(ctx, s.ID)
if er == nil {
fileExt := filepath.Ext(captionPath)
ext := fileExt[1:]
@@ -112,7 +114,7 @@ func (t *ScanTask) associateCaptions(ctx context.Context) {
CaptionType: ext,
}
captions = append(captions, newCaption)
- er = sqb.UpdateCaptions(s.ID, captions)
+ er = sqb.UpdateCaptions(ctx, s.ID, captions)
if er == nil {
logger.Debugf("Updated captions for scene %s. Added %s", s.Path, captionLang)
}
diff --git a/internal/manager/task_stash_box_tag.go b/internal/manager/task_stash_box_tag.go
index 2932a3167..cf7add510 100644
--- a/internal/manager/task_stash_box_tag.go
+++ b/internal/manager/task_stash_box_tag.go
@@ -10,11 +10,11 @@ import (
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scraper/stashbox"
+ "github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils"
)
type StashBoxPerformerTagTask struct {
- txnManager models.TransactionManager
box *models.StashBox
name *string
performer *models.Performer
@@ -41,12 +41,17 @@ func (t *StashBoxPerformerTagTask) stashBoxPerformerTag(ctx context.Context) {
var performer *models.ScrapedPerformer
var err error
- client := stashbox.NewClient(*t.box, t.txnManager)
+ client := stashbox.NewClient(*t.box, instance.Repository, stashbox.Repository{
+ Scene: instance.Repository.Scene,
+ Performer: instance.Repository.Performer,
+ Tag: instance.Repository.Tag,
+ Studio: instance.Repository.Studio,
+ })
if t.refresh {
var performerID string
- txnErr := t.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
- stashids, _ := r.Performer().GetStashIDs(t.performer.ID)
+ txnErr := txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error {
+ stashids, _ := instance.Repository.Performer.GetStashIDs(ctx, t.performer.ID)
for _, id := range stashids {
if id.Endpoint == t.box.Endpoint {
performerID = id.StashID
@@ -156,11 +161,12 @@ func (t *StashBoxPerformerTagTask) stashBoxPerformerTag(ctx context.Context) {
partial.URL = &value
}
- txnErr := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
- _, err := r.Performer().Update(partial)
+ txnErr := txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error {
+ r := instance.Repository
+ _, err := r.Performer.Update(ctx, partial)
if !t.refresh {
- err = r.Performer().UpdateStashIDs(t.performer.ID, []models.StashID{
+ err = r.Performer.UpdateStashIDs(ctx, t.performer.ID, []models.StashID{
{
Endpoint: t.box.Endpoint,
StashID: *performer.RemoteSiteID,
@@ -176,7 +182,7 @@ func (t *StashBoxPerformerTagTask) stashBoxPerformerTag(ctx context.Context) {
if err != nil {
return err
}
- err = r.Performer().UpdateImage(t.performer.ID, image)
+ err = r.Performer.UpdateImage(ctx, t.performer.ID, image)
if err != nil {
return err
}
@@ -218,13 +224,14 @@ func (t *StashBoxPerformerTagTask) stashBoxPerformerTag(ctx context.Context) {
URL: getNullString(performer.URL),
UpdatedAt: models.SQLiteTimestamp{Timestamp: currentTime},
}
- err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
- createdPerformer, err := r.Performer().Create(newPerformer)
+ err := txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error {
+ r := instance.Repository
+ createdPerformer, err := r.Performer.Create(ctx, newPerformer)
if err != nil {
return err
}
- err = r.Performer().UpdateStashIDs(createdPerformer.ID, []models.StashID{
+ err = r.Performer.UpdateStashIDs(ctx, createdPerformer.ID, []models.StashID{
{
Endpoint: t.box.Endpoint,
StashID: *performer.RemoteSiteID,
@@ -239,7 +246,7 @@ func (t *StashBoxPerformerTagTask) stashBoxPerformerTag(ctx context.Context) {
if imageErr != nil {
return imageErr
}
- err = r.Performer().UpdateImage(createdPerformer.ID, image)
+ err = r.Performer.UpdateImage(ctx, createdPerformer.ID, image)
}
return err
})
diff --git a/pkg/database/transaction.go b/pkg/database/transaction.go
deleted file mode 100644
index d8c23fb3b..000000000
--- a/pkg/database/transaction.go
+++ /dev/null
@@ -1,40 +0,0 @@
-package database
-
-import (
- "context"
-
- "github.com/jmoiron/sqlx"
- "github.com/stashapp/stash/pkg/logger"
-)
-
-// WithTxn executes the provided function within a transaction. It rolls back
-// the transaction if the function returns an error, otherwise the transaction
-// is committed.
-func WithTxn(fn func(tx *sqlx.Tx) error) error {
- ctx := context.TODO()
- tx := DB.MustBeginTx(ctx, nil)
-
- var err error
- defer func() {
- if p := recover(); p != nil {
- // a panic occurred, rollback and repanic
- if err := tx.Rollback(); err != nil {
- logger.Warnf("failure when performing transaction rollback: %v", err)
- }
- panic(p)
- }
-
- if err != nil {
- // something went wrong, rollback
- if err := tx.Rollback(); err != nil {
- logger.Warnf("failure when performing transaction rollback: %v", err)
- }
- } else {
- // all good, commit
- err = tx.Commit()
- }
- }()
-
- err = fn(tx)
- return err
-}
diff --git a/pkg/gallery/export.go b/pkg/gallery/export.go
index f24660e60..296929b35 100644
--- a/pkg/gallery/export.go
+++ b/pkg/gallery/export.go
@@ -1,9 +1,12 @@
package gallery
import (
+ "context"
+
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/json"
"github.com/stashapp/stash/pkg/models/jsonschema"
+ "github.com/stashapp/stash/pkg/studio"
"github.com/stashapp/stash/pkg/utils"
)
@@ -52,9 +55,9 @@ func ToBasicJSON(gallery *models.Gallery) (*jsonschema.Gallery, error) {
// GetStudioName returns the name of the provided gallery's studio. It returns an
// empty string if there is no studio assigned to the gallery.
-func GetStudioName(reader models.StudioReader, gallery *models.Gallery) (string, error) {
+func GetStudioName(ctx context.Context, reader studio.Finder, gallery *models.Gallery) (string, error) {
if gallery.StudioID.Valid {
- studio, err := reader.Find(int(gallery.StudioID.Int64))
+ studio, err := reader.Find(ctx, int(gallery.StudioID.Int64))
if err != nil {
return "", err
}
diff --git a/pkg/gallery/export_test.go b/pkg/gallery/export_test.go
index 80418d7e0..fe371fad7 100644
--- a/pkg/gallery/export_test.go
+++ b/pkg/gallery/export_test.go
@@ -154,15 +154,15 @@ func TestGetStudioName(t *testing.T) {
studioErr := errors.New("error getting image")
- mockStudioReader.On("Find", studioID).Return(&models.Studio{
+ mockStudioReader.On("Find", testCtx, studioID).Return(&models.Studio{
Name: models.NullString(studioName),
}, nil).Once()
- mockStudioReader.On("Find", missingStudioID).Return(nil, nil).Once()
- mockStudioReader.On("Find", errStudioID).Return(nil, studioErr).Once()
+ mockStudioReader.On("Find", testCtx, missingStudioID).Return(nil, nil).Once()
+ mockStudioReader.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once()
for i, s := range getStudioScenarios {
gallery := s.input
- json, err := GetStudioName(mockStudioReader, &gallery)
+ json, err := GetStudioName(testCtx, mockStudioReader, &gallery)
switch {
case !s.err && err != nil:
diff --git a/pkg/gallery/import.go b/pkg/gallery/import.go
index f82cff13b..85c90e3f0 100644
--- a/pkg/gallery/import.go
+++ b/pkg/gallery/import.go
@@ -1,20 +1,30 @@
package gallery
import (
+ "context"
"database/sql"
"fmt"
"strings"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/jsonschema"
+ "github.com/stashapp/stash/pkg/performer"
"github.com/stashapp/stash/pkg/sliceutil/stringslice"
+ "github.com/stashapp/stash/pkg/studio"
+ "github.com/stashapp/stash/pkg/tag"
)
+type FullCreatorUpdater interface {
+ FinderCreatorUpdater
+ UpdatePerformers(ctx context.Context, galleryID int, performerIDs []int) error
+ UpdateTags(ctx context.Context, galleryID int, tagIDs []int) error
+}
+
type Importer struct {
- ReaderWriter models.GalleryReaderWriter
- StudioWriter models.StudioReaderWriter
- PerformerWriter models.PerformerReaderWriter
- TagWriter models.TagReaderWriter
+ ReaderWriter FullCreatorUpdater
+ StudioWriter studio.NameFinderCreator
+ PerformerWriter performer.NameFinderCreator
+ TagWriter tag.NameFinderCreator
Input jsonschema.Gallery
MissingRefBehaviour models.ImportMissingRefEnum
@@ -23,18 +33,18 @@ type Importer struct {
tags []*models.Tag
}
-func (i *Importer) PreImport() error {
+func (i *Importer) PreImport(ctx context.Context) error {
i.gallery = i.galleryJSONToGallery(i.Input)
- if err := i.populateStudio(); err != nil {
+ if err := i.populateStudio(ctx); err != nil {
return err
}
- if err := i.populatePerformers(); err != nil {
+ if err := i.populatePerformers(ctx); err != nil {
return err
}
- if err := i.populateTags(); err != nil {
+ if err := i.populateTags(ctx); err != nil {
return err
}
@@ -74,9 +84,9 @@ func (i *Importer) galleryJSONToGallery(galleryJSON jsonschema.Gallery) models.G
return newGallery
}
-func (i *Importer) populateStudio() error {
+func (i *Importer) populateStudio(ctx context.Context) error {
if i.Input.Studio != "" {
- studio, err := i.StudioWriter.FindByName(i.Input.Studio, false)
+ studio, err := i.StudioWriter.FindByName(ctx, i.Input.Studio, false)
if err != nil {
return fmt.Errorf("error finding studio by name: %v", err)
}
@@ -91,7 +101,7 @@ func (i *Importer) populateStudio() error {
}
if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate {
- studioID, err := i.createStudio(i.Input.Studio)
+ studioID, err := i.createStudio(ctx, i.Input.Studio)
if err != nil {
return err
}
@@ -108,10 +118,10 @@ func (i *Importer) populateStudio() error {
return nil
}
-func (i *Importer) createStudio(name string) (int, error) {
+func (i *Importer) createStudio(ctx context.Context, name string) (int, error) {
newStudio := *models.NewStudio(name)
- created, err := i.StudioWriter.Create(newStudio)
+ created, err := i.StudioWriter.Create(ctx, newStudio)
if err != nil {
return 0, err
}
@@ -119,10 +129,10 @@ func (i *Importer) createStudio(name string) (int, error) {
return created.ID, nil
}
-func (i *Importer) populatePerformers() error {
+func (i *Importer) populatePerformers(ctx context.Context) error {
if len(i.Input.Performers) > 0 {
names := i.Input.Performers
- performers, err := i.PerformerWriter.FindByNames(names, false)
+ performers, err := i.PerformerWriter.FindByNames(ctx, names, false)
if err != nil {
return err
}
@@ -145,7 +155,7 @@ func (i *Importer) populatePerformers() error {
}
if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate {
- createdPerformers, err := i.createPerformers(missingPerformers)
+ createdPerformers, err := i.createPerformers(ctx, missingPerformers)
if err != nil {
return fmt.Errorf("error creating gallery performers: %v", err)
}
@@ -162,12 +172,12 @@ func (i *Importer) populatePerformers() error {
return nil
}
-func (i *Importer) createPerformers(names []string) ([]*models.Performer, error) {
+func (i *Importer) createPerformers(ctx context.Context, names []string) ([]*models.Performer, error) {
var ret []*models.Performer
for _, name := range names {
newPerformer := *models.NewPerformer(name)
- created, err := i.PerformerWriter.Create(newPerformer)
+ created, err := i.PerformerWriter.Create(ctx, newPerformer)
if err != nil {
return nil, err
}
@@ -178,10 +188,10 @@ func (i *Importer) createPerformers(names []string) ([]*models.Performer, error)
return ret, nil
}
-func (i *Importer) populateTags() error {
+func (i *Importer) populateTags(ctx context.Context) error {
if len(i.Input.Tags) > 0 {
names := i.Input.Tags
- tags, err := i.TagWriter.FindByNames(names, false)
+ tags, err := i.TagWriter.FindByNames(ctx, names, false)
if err != nil {
return err
}
@@ -201,7 +211,7 @@ func (i *Importer) populateTags() error {
}
if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate {
- createdTags, err := i.createTags(missingTags)
+ createdTags, err := i.createTags(ctx, missingTags)
if err != nil {
return fmt.Errorf("error creating gallery tags: %v", err)
}
@@ -218,12 +228,12 @@ func (i *Importer) populateTags() error {
return nil
}
-func (i *Importer) createTags(names []string) ([]*models.Tag, error) {
+func (i *Importer) createTags(ctx context.Context, names []string) ([]*models.Tag, error) {
var ret []*models.Tag
for _, name := range names {
newTag := *models.NewTag(name)
- created, err := i.TagWriter.Create(newTag)
+ created, err := i.TagWriter.Create(ctx, newTag)
if err != nil {
return nil, err
}
@@ -234,14 +244,14 @@ func (i *Importer) createTags(names []string) ([]*models.Tag, error) {
return ret, nil
}
-func (i *Importer) PostImport(id int) error {
+func (i *Importer) PostImport(ctx context.Context, id int) error {
if len(i.performers) > 0 {
var performerIDs []int
for _, performer := range i.performers {
performerIDs = append(performerIDs, performer.ID)
}
- if err := i.ReaderWriter.UpdatePerformers(id, performerIDs); err != nil {
+ if err := i.ReaderWriter.UpdatePerformers(ctx, id, performerIDs); err != nil {
return fmt.Errorf("failed to associate performers: %v", err)
}
}
@@ -251,7 +261,7 @@ func (i *Importer) PostImport(id int) error {
for _, t := range i.tags {
tagIDs = append(tagIDs, t.ID)
}
- if err := i.ReaderWriter.UpdateTags(id, tagIDs); err != nil {
+ if err := i.ReaderWriter.UpdateTags(ctx, id, tagIDs); err != nil {
return fmt.Errorf("failed to associate tags: %v", err)
}
}
@@ -263,8 +273,8 @@ func (i *Importer) Name() string {
return i.Input.Path
}
-func (i *Importer) FindExistingID() (*int, error) {
- existing, err := i.ReaderWriter.FindByChecksum(i.Input.Checksum)
+func (i *Importer) FindExistingID(ctx context.Context) (*int, error) {
+ existing, err := i.ReaderWriter.FindByChecksum(ctx, i.Input.Checksum)
if err != nil {
return nil, err
}
@@ -277,8 +287,8 @@ func (i *Importer) FindExistingID() (*int, error) {
return nil, nil
}
-func (i *Importer) Create() (*int, error) {
- created, err := i.ReaderWriter.Create(i.gallery)
+func (i *Importer) Create(ctx context.Context) (*int, error) {
+ created, err := i.ReaderWriter.Create(ctx, i.gallery)
if err != nil {
return nil, fmt.Errorf("error creating gallery: %v", err)
}
@@ -287,10 +297,10 @@ func (i *Importer) Create() (*int, error) {
return &id, nil
}
-func (i *Importer) Update(id int) error {
+func (i *Importer) Update(ctx context.Context, id int) error {
gallery := i.gallery
gallery.ID = id
- _, err := i.ReaderWriter.Update(gallery)
+ _, err := i.ReaderWriter.Update(ctx, gallery)
if err != nil {
return fmt.Errorf("error updating existing gallery: %v", err)
}
diff --git a/pkg/gallery/import_test.go b/pkg/gallery/import_test.go
index d50fd16d1..6f111aa4b 100644
--- a/pkg/gallery/import_test.go
+++ b/pkg/gallery/import_test.go
@@ -1,6 +1,7 @@
package gallery
import (
+ "context"
"errors"
"testing"
"time"
@@ -40,6 +41,8 @@ const (
errChecksum = "errChecksum"
)
+var testCtx = context.Background()
+
var (
createdAt = time.Date(2001, time.January, 2, 1, 2, 3, 4, time.Local)
updatedAt = time.Date(2002, time.January, 2, 1, 2, 3, 4, time.Local)
@@ -75,7 +78,7 @@ func TestImporterPreImport(t *testing.T) {
},
}
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.Nil(t, err)
expectedGallery := models.Gallery{
@@ -112,17 +115,17 @@ func TestImporterPreImportWithStudio(t *testing.T) {
},
}
- studioReaderWriter.On("FindByName", existingStudioName, false).Return(&models.Studio{
+ studioReaderWriter.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
ID: existingStudioID,
}, nil).Once()
- studioReaderWriter.On("FindByName", existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
+ studioReaderWriter.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, int64(existingStudioID), i.gallery.StudioID.Int64)
i.Input.Studio = existingStudioErr
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.NotNil(t, err)
studioReaderWriter.AssertExpectations(t)
@@ -140,20 +143,20 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
- studioReaderWriter.On("FindByName", missingStudioName, false).Return(nil, nil).Times(3)
- studioReaderWriter.On("Create", mock.AnythingOfType("models.Studio")).Return(&models.Studio{
+ studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
+ studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Studio")).Return(&models.Studio{
ID: existingStudioID,
}, nil)
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.NotNil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.Nil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumCreate
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, int64(existingStudioID), i.gallery.StudioID.Int64)
@@ -172,10 +175,10 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
- studioReaderWriter.On("FindByName", missingStudioName, false).Return(nil, nil).Once()
- studioReaderWriter.On("Create", mock.AnythingOfType("models.Studio")).Return(nil, errors.New("Create error"))
+ studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
+ studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Studio")).Return(nil, errors.New("Create error"))
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.NotNil(t, err)
}
@@ -193,20 +196,20 @@ func TestImporterPreImportWithPerformer(t *testing.T) {
},
}
- performerReaderWriter.On("FindByNames", []string{existingPerformerName}, false).Return([]*models.Performer{
+ performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{
{
ID: existingPerformerID,
Name: models.NullString(existingPerformerName),
},
}, nil).Once()
- performerReaderWriter.On("FindByNames", []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once()
+ performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once()
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, existingPerformerID, i.performers[0].ID)
i.Input.Performers = []string{existingPerformerErr}
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.NotNil(t, err)
performerReaderWriter.AssertExpectations(t)
@@ -226,20 +229,20 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
- performerReaderWriter.On("FindByNames", []string{missingPerformerName}, false).Return(nil, nil).Times(3)
- performerReaderWriter.On("Create", mock.AnythingOfType("models.Performer")).Return(&models.Performer{
+ performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3)
+ performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Performer")).Return(&models.Performer{
ID: existingPerformerID,
}, nil)
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.NotNil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.Nil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumCreate
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, existingPerformerID, i.performers[0].ID)
@@ -260,10 +263,10 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
- performerReaderWriter.On("FindByNames", []string{missingPerformerName}, false).Return(nil, nil).Once()
- performerReaderWriter.On("Create", mock.AnythingOfType("models.Performer")).Return(nil, errors.New("Create error"))
+ performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once()
+ performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Performer")).Return(nil, errors.New("Create error"))
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.NotNil(t, err)
}
@@ -281,20 +284,20 @@ func TestImporterPreImportWithTag(t *testing.T) {
},
}
- tagReaderWriter.On("FindByNames", []string{existingTagName}, false).Return([]*models.Tag{
+ tagReaderWriter.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
{
ID: existingTagID,
Name: existingTagName,
},
}, nil).Once()
- tagReaderWriter.On("FindByNames", []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
+ tagReaderWriter.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, existingTagID, i.tags[0].ID)
i.Input.Tags = []string{existingTagErr}
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.NotNil(t, err)
tagReaderWriter.AssertExpectations(t)
@@ -314,20 +317,20 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
- tagReaderWriter.On("FindByNames", []string{missingTagName}, false).Return(nil, nil).Times(3)
- tagReaderWriter.On("Create", mock.AnythingOfType("models.Tag")).Return(&models.Tag{
+ tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
+ tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Tag")).Return(&models.Tag{
ID: existingTagID,
}, nil)
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.NotNil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.Nil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumCreate
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, existingTagID, i.tags[0].ID)
@@ -348,10 +351,10 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
- tagReaderWriter.On("FindByNames", []string{missingTagName}, false).Return(nil, nil).Once()
- tagReaderWriter.On("Create", mock.AnythingOfType("models.Tag")).Return(nil, errors.New("Create error"))
+ tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
+ tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Tag")).Return(nil, errors.New("Create error"))
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.NotNil(t, err)
}
@@ -369,13 +372,13 @@ func TestImporterPostImportUpdatePerformers(t *testing.T) {
updateErr := errors.New("UpdatePerformers error")
- galleryReaderWriter.On("UpdatePerformers", galleryID, []int{existingPerformerID}).Return(nil).Once()
- galleryReaderWriter.On("UpdatePerformers", errPerformersID, mock.AnythingOfType("[]int")).Return(updateErr).Once()
+ galleryReaderWriter.On("UpdatePerformers", testCtx, galleryID, []int{existingPerformerID}).Return(nil).Once()
+ galleryReaderWriter.On("UpdatePerformers", testCtx, errPerformersID, mock.AnythingOfType("[]int")).Return(updateErr).Once()
- err := i.PostImport(galleryID)
+ err := i.PostImport(testCtx, galleryID)
assert.Nil(t, err)
- err = i.PostImport(errPerformersID)
+ err = i.PostImport(testCtx, errPerformersID)
assert.NotNil(t, err)
galleryReaderWriter.AssertExpectations(t)
@@ -395,13 +398,13 @@ func TestImporterPostImportUpdateTags(t *testing.T) {
updateErr := errors.New("UpdateTags error")
- galleryReaderWriter.On("UpdateTags", galleryID, []int{existingTagID}).Return(nil).Once()
- galleryReaderWriter.On("UpdateTags", errTagsID, mock.AnythingOfType("[]int")).Return(updateErr).Once()
+ galleryReaderWriter.On("UpdateTags", testCtx, galleryID, []int{existingTagID}).Return(nil).Once()
+ galleryReaderWriter.On("UpdateTags", testCtx, errTagsID, mock.AnythingOfType("[]int")).Return(updateErr).Once()
- err := i.PostImport(galleryID)
+ err := i.PostImport(testCtx, galleryID)
assert.Nil(t, err)
- err = i.PostImport(errTagsID)
+ err = i.PostImport(testCtx, errTagsID)
assert.NotNil(t, err)
galleryReaderWriter.AssertExpectations(t)
@@ -419,23 +422,23 @@ func TestImporterFindExistingID(t *testing.T) {
}
expectedErr := errors.New("FindBy* error")
- readerWriter.On("FindByChecksum", missingChecksum).Return(nil, nil).Once()
- readerWriter.On("FindByChecksum", checksum).Return(&models.Gallery{
+ readerWriter.On("FindByChecksum", testCtx, missingChecksum).Return(nil, nil).Once()
+ readerWriter.On("FindByChecksum", testCtx, checksum).Return(&models.Gallery{
ID: existingGalleryID,
}, nil).Once()
- readerWriter.On("FindByChecksum", errChecksum).Return(nil, expectedErr).Once()
+ readerWriter.On("FindByChecksum", testCtx, errChecksum).Return(nil, expectedErr).Once()
- id, err := i.FindExistingID()
+ id, err := i.FindExistingID(testCtx)
assert.Nil(t, id)
assert.Nil(t, err)
i.Input.Checksum = checksum
- id, err = i.FindExistingID()
+ id, err = i.FindExistingID(testCtx)
assert.Equal(t, existingGalleryID, *id)
assert.Nil(t, err)
i.Input.Checksum = errChecksum
- id, err = i.FindExistingID()
+ id, err = i.FindExistingID(testCtx)
assert.Nil(t, id)
assert.NotNil(t, err)
@@ -459,17 +462,17 @@ func TestCreate(t *testing.T) {
}
errCreate := errors.New("Create error")
- readerWriter.On("Create", gallery).Return(&models.Gallery{
+ readerWriter.On("Create", testCtx, gallery).Return(&models.Gallery{
ID: galleryID,
}, nil).Once()
- readerWriter.On("Create", galleryErr).Return(nil, errCreate).Once()
+ readerWriter.On("Create", testCtx, galleryErr).Return(nil, errCreate).Once()
- id, err := i.Create()
+ id, err := i.Create(testCtx)
assert.Equal(t, galleryID, *id)
assert.Nil(t, err)
i.gallery = galleryErr
- id, err = i.Create()
+ id, err = i.Create(testCtx)
assert.Nil(t, id)
assert.NotNil(t, err)
@@ -490,9 +493,9 @@ func TestUpdate(t *testing.T) {
// id needs to be set for the mock input
gallery.ID = galleryID
- readerWriter.On("Update", gallery).Return(nil, nil).Once()
+ readerWriter.On("Update", testCtx, gallery).Return(nil, nil).Once()
- err := i.Update(galleryID)
+ err := i.Update(testCtx, galleryID)
assert.Nil(t, err)
readerWriter.AssertExpectations(t)
diff --git a/pkg/gallery/query.go b/pkg/gallery/query.go
index f15e480f2..34065435b 100644
--- a/pkg/gallery/query.go
+++ b/pkg/gallery/query.go
@@ -1,12 +1,25 @@
package gallery
import (
+ "context"
"strconv"
"github.com/stashapp/stash/pkg/models"
)
-func CountByPerformerID(r models.GalleryReader, id int) (int, error) {
+type Queryer interface {
+ Query(ctx context.Context, galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) ([]*models.Gallery, int, error)
+}
+
+type CountQueryer interface {
+ QueryCount(ctx context.Context, galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) (int, error)
+}
+
+type ChecksumsFinder interface {
+ FindByChecksums(ctx context.Context, checksums []string) ([]*models.Gallery, error)
+}
+
+func CountByPerformerID(ctx context.Context, r CountQueryer, id int) (int, error) {
filter := &models.GalleryFilterType{
Performers: &models.MultiCriterionInput{
Value: []string{strconv.Itoa(id)},
@@ -14,10 +27,10 @@ func CountByPerformerID(r models.GalleryReader, id int) (int, error) {
},
}
- return r.QueryCount(filter, nil)
+ return r.QueryCount(ctx, filter, nil)
}
-func CountByStudioID(r models.GalleryReader, id int) (int, error) {
+func CountByStudioID(ctx context.Context, r CountQueryer, id int) (int, error) {
filter := &models.GalleryFilterType{
Studios: &models.HierarchicalMultiCriterionInput{
Value: []string{strconv.Itoa(id)},
@@ -25,10 +38,10 @@ func CountByStudioID(r models.GalleryReader, id int) (int, error) {
},
}
- return r.QueryCount(filter, nil)
+ return r.QueryCount(ctx, filter, nil)
}
-func CountByTagID(r models.GalleryReader, id int) (int, error) {
+func CountByTagID(ctx context.Context, r CountQueryer, id int) (int, error) {
filter := &models.GalleryFilterType{
Tags: &models.HierarchicalMultiCriterionInput{
Value: []string{strconv.Itoa(id)},
@@ -36,5 +49,5 @@ func CountByTagID(r models.GalleryReader, id int) (int, error) {
},
}
- return r.QueryCount(filter, nil)
+ return r.QueryCount(ctx, filter, nil)
}
diff --git a/pkg/gallery/scan.go b/pkg/gallery/scan.go
index f45a26d77..643ba8988 100644
--- a/pkg/gallery/scan.go
+++ b/pkg/gallery/scan.go
@@ -14,18 +14,26 @@ import (
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/paths"
"github.com/stashapp/stash/pkg/plugin"
+ "github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils"
)
const mutexType = "gallery"
+type FinderCreatorUpdater interface {
+ FindByChecksum(ctx context.Context, checksum string) (*models.Gallery, error)
+ Create(ctx context.Context, newGallery models.Gallery) (*models.Gallery, error)
+ Update(ctx context.Context, updatedGallery models.Gallery) (*models.Gallery, error)
+}
+
type Scanner struct {
file.Scanner
ImageExtensions []string
StripFileExtension bool
CaseSensitiveFs bool
- TxnManager models.TransactionManager
+ TxnManager txn.Manager
+ CreatorUpdater FinderCreatorUpdater
Paths *paths.Paths
PluginCache *plugin.Cache
MutexManager *utils.MutexManager
@@ -75,19 +83,19 @@ func (scanner *Scanner) ScanExisting(ctx context.Context, existing file.FileBase
done := make(chan struct{})
scanner.MutexManager.Claim(mutexType, scanned.New.Checksum, done)
- if err := scanner.TxnManager.WithTxn(ctx, func(r models.Repository) error {
+ if err := txn.WithTxn(ctx, scanner.TxnManager, func(ctx context.Context) error {
// free the mutex once transaction is complete
defer close(done)
// ensure no clashes of hashes
if scanned.New.Checksum != "" && scanned.Old.Checksum != scanned.New.Checksum {
- dupe, _ := r.Gallery().FindByChecksum(retGallery.Checksum)
+ dupe, _ := scanner.CreatorUpdater.FindByChecksum(ctx, retGallery.Checksum)
if dupe != nil {
return fmt.Errorf("MD5 for file %s is the same as that of %s", path, dupe.Path.String)
}
}
- retGallery, err = r.Gallery().Update(*retGallery)
+ retGallery, err = scanner.CreatorUpdater.Update(ctx, *retGallery)
return err
}); err != nil {
return nil, false, err
@@ -116,10 +124,10 @@ func (scanner *Scanner) ScanNew(ctx context.Context, file file.SourceFile) (retG
scanner.MutexManager.Claim(mutexType, checksum, done)
defer close(done)
- if err := scanner.TxnManager.WithTxn(ctx, func(r models.Repository) error {
- qb := r.Gallery()
+ if err := txn.WithTxn(ctx, scanner.TxnManager, func(ctx context.Context) error {
+ qb := scanner.CreatorUpdater
- g, _ = qb.FindByChecksum(checksum)
+ g, _ = qb.FindByChecksum(ctx, checksum)
if g != nil {
exists, _ := fsutil.FileExists(g.Path.String)
if !scanner.CaseSensitiveFs {
@@ -138,7 +146,7 @@ func (scanner *Scanner) ScanNew(ctx context.Context, file file.SourceFile) (retG
String: path,
Valid: true,
}
- g, err = qb.Update(*g)
+ g, err = qb.Update(ctx, *g)
if err != nil {
return err
}
@@ -167,7 +175,7 @@ func (scanner *Scanner) ScanNew(ctx context.Context, file file.SourceFile) (retG
}
logger.Infof("%s doesn't exist. Creating new item...", path)
- g, err = qb.Create(*g)
+ g, err = qb.Create(ctx, *g)
if err != nil {
return err
}
diff --git a/pkg/gallery/update.go b/pkg/gallery/update.go
index 4c16793ca..1c94faea6 100644
--- a/pkg/gallery/update.go
+++ b/pkg/gallery/update.go
@@ -1,29 +1,50 @@
package gallery
import (
+ "context"
+
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/sliceutil/intslice"
)
-func UpdateFileModTime(qb models.GalleryWriter, id int, modTime models.NullSQLiteTimestamp) (*models.Gallery, error) {
- return qb.UpdatePartial(models.GalleryPartial{
+type PartialUpdater interface {
+ UpdatePartial(ctx context.Context, updatedGallery models.GalleryPartial) (*models.Gallery, error)
+}
+
+type ImageUpdater interface {
+ GetImageIDs(ctx context.Context, galleryID int) ([]int, error)
+ UpdateImages(ctx context.Context, galleryID int, imageIDs []int) error
+}
+
+type PerformerUpdater interface {
+ GetPerformerIDs(ctx context.Context, galleryID int) ([]int, error)
+ UpdatePerformers(ctx context.Context, galleryID int, performerIDs []int) error
+}
+
+type TagUpdater interface {
+ GetTagIDs(ctx context.Context, galleryID int) ([]int, error)
+ UpdateTags(ctx context.Context, galleryID int, tagIDs []int) error
+}
+
+func UpdateFileModTime(ctx context.Context, qb PartialUpdater, id int, modTime models.NullSQLiteTimestamp) (*models.Gallery, error) {
+ return qb.UpdatePartial(ctx, models.GalleryPartial{
ID: id,
FileModTime: &modTime,
})
}
-func AddImage(qb models.GalleryReaderWriter, galleryID int, imageID int) error {
- imageIDs, err := qb.GetImageIDs(galleryID)
+func AddImage(ctx context.Context, qb ImageUpdater, galleryID int, imageID int) error {
+ imageIDs, err := qb.GetImageIDs(ctx, galleryID)
if err != nil {
return err
}
imageIDs = intslice.IntAppendUnique(imageIDs, imageID)
- return qb.UpdateImages(galleryID, imageIDs)
+ return qb.UpdateImages(ctx, galleryID, imageIDs)
}
-func AddPerformer(qb models.GalleryReaderWriter, id int, performerID int) (bool, error) {
- performerIDs, err := qb.GetPerformerIDs(id)
+func AddPerformer(ctx context.Context, qb PerformerUpdater, id int, performerID int) (bool, error) {
+ performerIDs, err := qb.GetPerformerIDs(ctx, id)
if err != nil {
return false, err
}
@@ -32,7 +53,7 @@ func AddPerformer(qb models.GalleryReaderWriter, id int, performerID int) (bool,
performerIDs = intslice.IntAppendUnique(performerIDs, performerID)
if len(performerIDs) != oldLen {
- if err := qb.UpdatePerformers(id, performerIDs); err != nil {
+ if err := qb.UpdatePerformers(ctx, id, performerIDs); err != nil {
return false, err
}
@@ -42,8 +63,8 @@ func AddPerformer(qb models.GalleryReaderWriter, id int, performerID int) (bool,
return false, nil
}
-func AddTag(qb models.GalleryReaderWriter, id int, tagID int) (bool, error) {
- tagIDs, err := qb.GetTagIDs(id)
+func AddTag(ctx context.Context, qb TagUpdater, id int, tagID int) (bool, error) {
+ tagIDs, err := qb.GetTagIDs(ctx, id)
if err != nil {
return false, err
}
@@ -52,7 +73,7 @@ func AddTag(qb models.GalleryReaderWriter, id int, tagID int) (bool, error) {
tagIDs = intslice.IntAppendUnique(tagIDs, tagID)
if len(tagIDs) != oldLen {
- if err := qb.UpdateTags(id, tagIDs); err != nil {
+ if err := qb.UpdateTags(ctx, id, tagIDs); err != nil {
return false, err
}
diff --git a/pkg/image/delete.go b/pkg/image/delete.go
index 35ab3704b..8e2ca8237 100644
--- a/pkg/image/delete.go
+++ b/pkg/image/delete.go
@@ -1,6 +1,8 @@
package image
import (
+ "context"
+
"github.com/stashapp/stash/pkg/file"
"github.com/stashapp/stash/pkg/fsutil"
"github.com/stashapp/stash/pkg/models"
@@ -8,7 +10,7 @@ import (
)
type Destroyer interface {
- Destroy(id int) error
+ Destroy(ctx context.Context, id int) error
}
// FileDeleter is an extension of file.Deleter that handles deletion of image files.
@@ -30,7 +32,7 @@ func (d *FileDeleter) MarkGeneratedFiles(image *models.Image) error {
}
// Destroy destroys an image, optionally marking the file and generated files for deletion.
-func Destroy(i *models.Image, destroyer Destroyer, fileDeleter *FileDeleter, deleteGenerated, deleteFile bool) error {
+func Destroy(ctx context.Context, i *models.Image, destroyer Destroyer, fileDeleter *FileDeleter, deleteGenerated, deleteFile bool) error {
// don't try to delete if the image is in a zip file
if deleteFile && !file.IsZipPath(i.Path) {
if err := fileDeleter.Files([]string{i.Path}); err != nil {
@@ -44,5 +46,5 @@ func Destroy(i *models.Image, destroyer Destroyer, fileDeleter *FileDeleter, del
}
}
- return destroyer.Destroy(i.ID)
+ return destroyer.Destroy(ctx, i.ID)
}
diff --git a/pkg/image/export.go b/pkg/image/export.go
index 3938a39bf..da7306bdb 100644
--- a/pkg/image/export.go
+++ b/pkg/image/export.go
@@ -1,9 +1,12 @@
package image
import (
+ "context"
+
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/json"
"github.com/stashapp/stash/pkg/models/jsonschema"
+ "github.com/stashapp/stash/pkg/studio"
)
// ToBasicJSON converts a image object into its JSON object equivalent. It
@@ -56,9 +59,9 @@ func getImageFileJSON(image *models.Image) *jsonschema.ImageFile {
// GetStudioName returns the name of the provided image's studio. It returns an
// empty string if there is no studio assigned to the image.
-func GetStudioName(reader models.StudioReader, image *models.Image) (string, error) {
+func GetStudioName(ctx context.Context, reader studio.Finder, image *models.Image) (string, error) {
if image.StudioID.Valid {
- studio, err := reader.Find(int(image.StudioID.Int64))
+ studio, err := reader.Find(ctx, int(image.StudioID.Int64))
if err != nil {
return "", err
}
diff --git a/pkg/image/export_test.go b/pkg/image/export_test.go
index 0a449c443..2aacac5ad 100644
--- a/pkg/image/export_test.go
+++ b/pkg/image/export_test.go
@@ -156,15 +156,15 @@ func TestGetStudioName(t *testing.T) {
studioErr := errors.New("error getting image")
- mockStudioReader.On("Find", studioID).Return(&models.Studio{
+ mockStudioReader.On("Find", testCtx, studioID).Return(&models.Studio{
Name: models.NullString(studioName),
}, nil).Once()
- mockStudioReader.On("Find", missingStudioID).Return(nil, nil).Once()
- mockStudioReader.On("Find", errStudioID).Return(nil, studioErr).Once()
+ mockStudioReader.On("Find", testCtx, missingStudioID).Return(nil, nil).Once()
+ mockStudioReader.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once()
for i, s := range getStudioScenarios {
image := s.input
- json, err := GetStudioName(mockStudioReader, &image)
+ json, err := GetStudioName(testCtx, mockStudioReader, &image)
switch {
case !s.err && err != nil:
diff --git a/pkg/image/import.go b/pkg/image/import.go
index 78b60c4b1..d1de6b2a5 100644
--- a/pkg/image/import.go
+++ b/pkg/image/import.go
@@ -1,21 +1,33 @@
package image
import (
+ "context"
"database/sql"
"fmt"
"strings"
+ "github.com/stashapp/stash/pkg/gallery"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/jsonschema"
+ "github.com/stashapp/stash/pkg/performer"
"github.com/stashapp/stash/pkg/sliceutil/stringslice"
+ "github.com/stashapp/stash/pkg/studio"
+ "github.com/stashapp/stash/pkg/tag"
)
+type FullCreatorUpdater interface {
+ FinderCreatorUpdater
+ UpdatePerformers(ctx context.Context, imageID int, performerIDs []int) error
+ UpdateTags(ctx context.Context, imageID int, tagIDs []int) error
+ UpdateGalleries(ctx context.Context, imageID int, galleryIDs []int) error
+}
+
type Importer struct {
- ReaderWriter models.ImageReaderWriter
- StudioWriter models.StudioReaderWriter
- GalleryWriter models.GalleryReaderWriter
- PerformerWriter models.PerformerReaderWriter
- TagWriter models.TagReaderWriter
+ ReaderWriter FullCreatorUpdater
+ StudioWriter studio.NameFinderCreator
+ GalleryWriter gallery.ChecksumsFinder
+ PerformerWriter performer.NameFinderCreator
+ TagWriter tag.NameFinderCreator
Input jsonschema.Image
Path string
MissingRefBehaviour models.ImportMissingRefEnum
@@ -27,22 +39,22 @@ type Importer struct {
tags []*models.Tag
}
-func (i *Importer) PreImport() error {
+func (i *Importer) PreImport(ctx context.Context) error {
i.image = i.imageJSONToImage(i.Input)
- if err := i.populateStudio(); err != nil {
+ if err := i.populateStudio(ctx); err != nil {
return err
}
- if err := i.populateGalleries(); err != nil {
+ if err := i.populateGalleries(ctx); err != nil {
return err
}
- if err := i.populatePerformers(); err != nil {
+ if err := i.populatePerformers(ctx); err != nil {
return err
}
- if err := i.populateTags(); err != nil {
+ if err := i.populateTags(ctx); err != nil {
return err
}
@@ -82,9 +94,9 @@ func (i *Importer) imageJSONToImage(imageJSON jsonschema.Image) models.Image {
return newImage
}
-func (i *Importer) populateStudio() error {
+func (i *Importer) populateStudio(ctx context.Context) error {
if i.Input.Studio != "" {
- studio, err := i.StudioWriter.FindByName(i.Input.Studio, false)
+ studio, err := i.StudioWriter.FindByName(ctx, i.Input.Studio, false)
if err != nil {
return fmt.Errorf("error finding studio by name: %v", err)
}
@@ -99,7 +111,7 @@ func (i *Importer) populateStudio() error {
}
if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate {
- studioID, err := i.createStudio(i.Input.Studio)
+ studioID, err := i.createStudio(ctx, i.Input.Studio)
if err != nil {
return err
}
@@ -116,10 +128,10 @@ func (i *Importer) populateStudio() error {
return nil
}
-func (i *Importer) createStudio(name string) (int, error) {
+func (i *Importer) createStudio(ctx context.Context, name string) (int, error) {
newStudio := *models.NewStudio(name)
- created, err := i.StudioWriter.Create(newStudio)
+ created, err := i.StudioWriter.Create(ctx, newStudio)
if err != nil {
return 0, err
}
@@ -127,14 +139,14 @@ func (i *Importer) createStudio(name string) (int, error) {
return created.ID, nil
}
-func (i *Importer) populateGalleries() error {
+func (i *Importer) populateGalleries(ctx context.Context) error {
for _, checksum := range i.Input.Galleries {
- gallery, err := i.GalleryWriter.FindByChecksum(checksum)
+ gallery, err := i.GalleryWriter.FindByChecksums(ctx, []string{checksum})
if err != nil {
return fmt.Errorf("error finding gallery: %v", err)
}
- if gallery == nil {
+ if len(gallery) == 0 {
if i.MissingRefBehaviour == models.ImportMissingRefEnumFail {
return fmt.Errorf("image gallery '%s' not found", i.Input.Studio)
}
@@ -144,17 +156,17 @@ func (i *Importer) populateGalleries() error {
continue
}
} else {
- i.galleries = append(i.galleries, gallery)
+ i.galleries = append(i.galleries, gallery[0])
}
}
return nil
}
-func (i *Importer) populatePerformers() error {
+func (i *Importer) populatePerformers(ctx context.Context) error {
if len(i.Input.Performers) > 0 {
names := i.Input.Performers
- performers, err := i.PerformerWriter.FindByNames(names, false)
+ performers, err := i.PerformerWriter.FindByNames(ctx, names, false)
if err != nil {
return err
}
@@ -177,7 +189,7 @@ func (i *Importer) populatePerformers() error {
}
if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate {
- createdPerformers, err := i.createPerformers(missingPerformers)
+ createdPerformers, err := i.createPerformers(ctx, missingPerformers)
if err != nil {
return fmt.Errorf("error creating image performers: %v", err)
}
@@ -194,12 +206,12 @@ func (i *Importer) populatePerformers() error {
return nil
}
-func (i *Importer) createPerformers(names []string) ([]*models.Performer, error) {
+func (i *Importer) createPerformers(ctx context.Context, names []string) ([]*models.Performer, error) {
var ret []*models.Performer
for _, name := range names {
newPerformer := *models.NewPerformer(name)
- created, err := i.PerformerWriter.Create(newPerformer)
+ created, err := i.PerformerWriter.Create(ctx, newPerformer)
if err != nil {
return nil, err
}
@@ -210,10 +222,10 @@ func (i *Importer) createPerformers(names []string) ([]*models.Performer, error)
return ret, nil
}
-func (i *Importer) populateTags() error {
+func (i *Importer) populateTags(ctx context.Context) error {
if len(i.Input.Tags) > 0 {
- tags, err := importTags(i.TagWriter, i.Input.Tags, i.MissingRefBehaviour)
+ tags, err := importTags(ctx, i.TagWriter, i.Input.Tags, i.MissingRefBehaviour)
if err != nil {
return err
}
@@ -224,14 +236,14 @@ func (i *Importer) populateTags() error {
return nil
}
-func (i *Importer) PostImport(id int) error {
+func (i *Importer) PostImport(ctx context.Context, id int) error {
if len(i.galleries) > 0 {
var galleryIDs []int
for _, g := range i.galleries {
galleryIDs = append(galleryIDs, g.ID)
}
- if err := i.ReaderWriter.UpdateGalleries(id, galleryIDs); err != nil {
+ if err := i.ReaderWriter.UpdateGalleries(ctx, id, galleryIDs); err != nil {
return fmt.Errorf("failed to associate galleries: %v", err)
}
}
@@ -242,7 +254,7 @@ func (i *Importer) PostImport(id int) error {
performerIDs = append(performerIDs, performer.ID)
}
- if err := i.ReaderWriter.UpdatePerformers(id, performerIDs); err != nil {
+ if err := i.ReaderWriter.UpdatePerformers(ctx, id, performerIDs); err != nil {
return fmt.Errorf("failed to associate performers: %v", err)
}
}
@@ -252,7 +264,7 @@ func (i *Importer) PostImport(id int) error {
for _, t := range i.tags {
tagIDs = append(tagIDs, t.ID)
}
- if err := i.ReaderWriter.UpdateTags(id, tagIDs); err != nil {
+ if err := i.ReaderWriter.UpdateTags(ctx, id, tagIDs); err != nil {
return fmt.Errorf("failed to associate tags: %v", err)
}
}
@@ -264,10 +276,10 @@ func (i *Importer) Name() string {
return i.Path
}
-func (i *Importer) FindExistingID() (*int, error) {
+func (i *Importer) FindExistingID(ctx context.Context) (*int, error) {
var existing *models.Image
var err error
- existing, err = i.ReaderWriter.FindByChecksum(i.Input.Checksum)
+ existing, err = i.ReaderWriter.FindByChecksum(ctx, i.Input.Checksum)
if err != nil {
return nil, err
@@ -281,8 +293,8 @@ func (i *Importer) FindExistingID() (*int, error) {
return nil, nil
}
-func (i *Importer) Create() (*int, error) {
- created, err := i.ReaderWriter.Create(i.image)
+func (i *Importer) Create(ctx context.Context) (*int, error) {
+ created, err := i.ReaderWriter.Create(ctx, i.image)
if err != nil {
return nil, fmt.Errorf("error creating image: %v", err)
}
@@ -292,11 +304,11 @@ func (i *Importer) Create() (*int, error) {
return &id, nil
}
-func (i *Importer) Update(id int) error {
+func (i *Importer) Update(ctx context.Context, id int) error {
image := i.image
image.ID = id
i.ID = id
- _, err := i.ReaderWriter.UpdateFull(image)
+ _, err := i.ReaderWriter.UpdateFull(ctx, image)
if err != nil {
return fmt.Errorf("error updating existing image: %v", err)
}
@@ -304,8 +316,8 @@ func (i *Importer) Update(id int) error {
return nil
}
-func importTags(tagWriter models.TagReaderWriter, names []string, missingRefBehaviour models.ImportMissingRefEnum) ([]*models.Tag, error) {
- tags, err := tagWriter.FindByNames(names, false)
+func importTags(ctx context.Context, tagWriter tag.NameFinderCreator, names []string, missingRefBehaviour models.ImportMissingRefEnum) ([]*models.Tag, error) {
+ tags, err := tagWriter.FindByNames(ctx, names, false)
if err != nil {
return nil, err
}
@@ -325,7 +337,7 @@ func importTags(tagWriter models.TagReaderWriter, names []string, missingRefBeha
}
if missingRefBehaviour == models.ImportMissingRefEnumCreate {
- createdTags, err := createTags(tagWriter, missingTags)
+ createdTags, err := createTags(ctx, tagWriter, missingTags)
if err != nil {
return nil, fmt.Errorf("error creating tags: %v", err)
}
@@ -339,12 +351,12 @@ func importTags(tagWriter models.TagReaderWriter, names []string, missingRefBeha
return tags, nil
}
-func createTags(tagWriter models.TagWriter, names []string) ([]*models.Tag, error) {
+func createTags(ctx context.Context, tagWriter tag.NameFinderCreator, names []string) ([]*models.Tag, error) {
var ret []*models.Tag
for _, name := range names {
newTag := *models.NewTag(name)
- created, err := tagWriter.Create(newTag)
+ created, err := tagWriter.Create(ctx, newTag)
if err != nil {
return nil, err
}
diff --git a/pkg/image/import_test.go b/pkg/image/import_test.go
index 156ec96d2..856c338c1 100644
--- a/pkg/image/import_test.go
+++ b/pkg/image/import_test.go
@@ -1,6 +1,7 @@
package image
import (
+ "context"
"errors"
"testing"
@@ -47,6 +48,8 @@ const (
errChecksum = "errChecksum"
)
+var testCtx = context.Background()
+
func TestImporterName(t *testing.T) {
i := Importer{
Path: path,
@@ -61,7 +64,7 @@ func TestImporterPreImport(t *testing.T) {
Path: path,
}
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.Nil(t, err)
}
@@ -76,17 +79,17 @@ func TestImporterPreImportWithStudio(t *testing.T) {
},
}
- studioReaderWriter.On("FindByName", existingStudioName, false).Return(&models.Studio{
+ studioReaderWriter.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
ID: existingStudioID,
}, nil).Once()
- studioReaderWriter.On("FindByName", existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
+ studioReaderWriter.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, int64(existingStudioID), i.image.StudioID.Int64)
i.Input.Studio = existingStudioErr
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.NotNil(t, err)
studioReaderWriter.AssertExpectations(t)
@@ -104,20 +107,20 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
- studioReaderWriter.On("FindByName", missingStudioName, false).Return(nil, nil).Times(3)
- studioReaderWriter.On("Create", mock.AnythingOfType("models.Studio")).Return(&models.Studio{
+ studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
+ studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Studio")).Return(&models.Studio{
ID: existingStudioID,
}, nil)
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.NotNil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.Nil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumCreate
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, int64(existingStudioID), i.image.StudioID.Int64)
@@ -136,10 +139,10 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
- studioReaderWriter.On("FindByName", missingStudioName, false).Return(nil, nil).Once()
- studioReaderWriter.On("Create", mock.AnythingOfType("models.Studio")).Return(nil, errors.New("Create error"))
+ studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
+ studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Studio")).Return(nil, errors.New("Create error"))
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.NotNil(t, err)
}
@@ -156,12 +159,12 @@ func TestImporterPreImportWithGallery(t *testing.T) {
},
}
- galleryReaderWriter.On("FindByChecksum", existingGalleryChecksum).Return(&models.Gallery{
+ galleryReaderWriter.On("FindByChecksums", testCtx, []string{existingGalleryChecksum}).Return([]*models.Gallery{{
ID: existingGalleryID,
- }, nil).Once()
- galleryReaderWriter.On("FindByChecksum", existingGalleryErr).Return(nil, errors.New("FindByChecksum error")).Once()
+ }}, nil).Once()
+ galleryReaderWriter.On("FindByChecksums", testCtx, []string{existingGalleryErr}).Return(nil, errors.New("FindByChecksum error")).Once()
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, existingGalleryID, i.galleries[0].ID)
@@ -169,7 +172,7 @@ func TestImporterPreImportWithGallery(t *testing.T) {
existingGalleryErr,
}
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.NotNil(t, err)
galleryReaderWriter.AssertExpectations(t)
@@ -189,18 +192,18 @@ func TestImporterPreImportWithMissingGallery(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
- galleryReaderWriter.On("FindByChecksum", missingGalleryChecksum).Return(nil, nil).Times(3)
+ galleryReaderWriter.On("FindByChecksums", testCtx, []string{missingGalleryChecksum}).Return(nil, nil).Times(3)
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.NotNil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.Nil(t, err)
assert.Nil(t, i.galleries)
i.MissingRefBehaviour = models.ImportMissingRefEnumCreate
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.Nil(t, err)
assert.Nil(t, i.galleries)
@@ -221,20 +224,20 @@ func TestImporterPreImportWithPerformer(t *testing.T) {
},
}
- performerReaderWriter.On("FindByNames", []string{existingPerformerName}, false).Return([]*models.Performer{
+ performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{
{
ID: existingPerformerID,
Name: models.NullString(existingPerformerName),
},
}, nil).Once()
- performerReaderWriter.On("FindByNames", []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once()
+ performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once()
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, existingPerformerID, i.performers[0].ID)
i.Input.Performers = []string{existingPerformerErr}
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.NotNil(t, err)
performerReaderWriter.AssertExpectations(t)
@@ -254,20 +257,20 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
- performerReaderWriter.On("FindByNames", []string{missingPerformerName}, false).Return(nil, nil).Times(3)
- performerReaderWriter.On("Create", mock.AnythingOfType("models.Performer")).Return(&models.Performer{
+ performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3)
+ performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Performer")).Return(&models.Performer{
ID: existingPerformerID,
}, nil)
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.NotNil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.Nil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumCreate
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, existingPerformerID, i.performers[0].ID)
@@ -288,10 +291,10 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
- performerReaderWriter.On("FindByNames", []string{missingPerformerName}, false).Return(nil, nil).Once()
- performerReaderWriter.On("Create", mock.AnythingOfType("models.Performer")).Return(nil, errors.New("Create error"))
+ performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once()
+ performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Performer")).Return(nil, errors.New("Create error"))
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.NotNil(t, err)
}
@@ -309,20 +312,20 @@ func TestImporterPreImportWithTag(t *testing.T) {
},
}
- tagReaderWriter.On("FindByNames", []string{existingTagName}, false).Return([]*models.Tag{
+ tagReaderWriter.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
{
ID: existingTagID,
Name: existingTagName,
},
}, nil).Once()
- tagReaderWriter.On("FindByNames", []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
+ tagReaderWriter.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, existingTagID, i.tags[0].ID)
i.Input.Tags = []string{existingTagErr}
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.NotNil(t, err)
tagReaderWriter.AssertExpectations(t)
@@ -342,20 +345,20 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
- tagReaderWriter.On("FindByNames", []string{missingTagName}, false).Return(nil, nil).Times(3)
- tagReaderWriter.On("Create", mock.AnythingOfType("models.Tag")).Return(&models.Tag{
+ tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
+ tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Tag")).Return(&models.Tag{
ID: existingTagID,
}, nil)
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.NotNil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.Nil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumCreate
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, existingTagID, i.tags[0].ID)
@@ -376,10 +379,10 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
- tagReaderWriter.On("FindByNames", []string{missingTagName}, false).Return(nil, nil).Once()
- tagReaderWriter.On("Create", mock.AnythingOfType("models.Tag")).Return(nil, errors.New("Create error"))
+ tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
+ tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Tag")).Return(nil, errors.New("Create error"))
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.NotNil(t, err)
}
@@ -397,13 +400,13 @@ func TestImporterPostImportUpdateGallery(t *testing.T) {
updateErr := errors.New("UpdateGalleries error")
- readerWriter.On("UpdateGalleries", imageID, []int{existingGalleryID}).Return(nil).Once()
- readerWriter.On("UpdateGalleries", errGalleriesID, mock.AnythingOfType("[]int")).Return(updateErr).Once()
+ readerWriter.On("UpdateGalleries", testCtx, imageID, []int{existingGalleryID}).Return(nil).Once()
+ readerWriter.On("UpdateGalleries", testCtx, errGalleriesID, mock.AnythingOfType("[]int")).Return(updateErr).Once()
- err := i.PostImport(imageID)
+ err := i.PostImport(testCtx, imageID)
assert.Nil(t, err)
- err = i.PostImport(errGalleriesID)
+ err = i.PostImport(testCtx, errGalleriesID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
@@ -423,13 +426,13 @@ func TestImporterPostImportUpdatePerformers(t *testing.T) {
updateErr := errors.New("UpdatePerformers error")
- readerWriter.On("UpdatePerformers", imageID, []int{existingPerformerID}).Return(nil).Once()
- readerWriter.On("UpdatePerformers", errPerformersID, mock.AnythingOfType("[]int")).Return(updateErr).Once()
+ readerWriter.On("UpdatePerformers", testCtx, imageID, []int{existingPerformerID}).Return(nil).Once()
+ readerWriter.On("UpdatePerformers", testCtx, errPerformersID, mock.AnythingOfType("[]int")).Return(updateErr).Once()
- err := i.PostImport(imageID)
+ err := i.PostImport(testCtx, imageID)
assert.Nil(t, err)
- err = i.PostImport(errPerformersID)
+ err = i.PostImport(testCtx, errPerformersID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
@@ -449,13 +452,13 @@ func TestImporterPostImportUpdateTags(t *testing.T) {
updateErr := errors.New("UpdateTags error")
- readerWriter.On("UpdateTags", imageID, []int{existingTagID}).Return(nil).Once()
- readerWriter.On("UpdateTags", errTagsID, mock.AnythingOfType("[]int")).Return(updateErr).Once()
+ readerWriter.On("UpdateTags", testCtx, imageID, []int{existingTagID}).Return(nil).Once()
+ readerWriter.On("UpdateTags", testCtx, errTagsID, mock.AnythingOfType("[]int")).Return(updateErr).Once()
- err := i.PostImport(imageID)
+ err := i.PostImport(testCtx, imageID)
assert.Nil(t, err)
- err = i.PostImport(errTagsID)
+ err = i.PostImport(testCtx, errTagsID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
@@ -473,23 +476,23 @@ func TestImporterFindExistingID(t *testing.T) {
}
expectedErr := errors.New("FindBy* error")
- readerWriter.On("FindByChecksum", missingChecksum).Return(nil, nil).Once()
- readerWriter.On("FindByChecksum", checksum).Return(&models.Image{
+ readerWriter.On("FindByChecksum", testCtx, missingChecksum).Return(nil, nil).Once()
+ readerWriter.On("FindByChecksum", testCtx, checksum).Return(&models.Image{
ID: existingImageID,
}, nil).Once()
- readerWriter.On("FindByChecksum", errChecksum).Return(nil, expectedErr).Once()
+ readerWriter.On("FindByChecksum", testCtx, errChecksum).Return(nil, expectedErr).Once()
- id, err := i.FindExistingID()
+ id, err := i.FindExistingID(testCtx)
assert.Nil(t, id)
assert.Nil(t, err)
i.Input.Checksum = checksum
- id, err = i.FindExistingID()
+ id, err = i.FindExistingID(testCtx)
assert.Equal(t, existingImageID, *id)
assert.Nil(t, err)
i.Input.Checksum = errChecksum
- id, err = i.FindExistingID()
+ id, err = i.FindExistingID(testCtx)
assert.Nil(t, id)
assert.NotNil(t, err)
@@ -513,18 +516,18 @@ func TestCreate(t *testing.T) {
}
errCreate := errors.New("Create error")
- readerWriter.On("Create", image).Return(&models.Image{
+ readerWriter.On("Create", testCtx, image).Return(&models.Image{
ID: imageID,
}, nil).Once()
- readerWriter.On("Create", imageErr).Return(nil, errCreate).Once()
+ readerWriter.On("Create", testCtx, imageErr).Return(nil, errCreate).Once()
- id, err := i.Create()
+ id, err := i.Create(testCtx)
assert.Equal(t, imageID, *id)
assert.Nil(t, err)
assert.Equal(t, imageID, i.ID)
i.image = imageErr
- id, err = i.Create()
+ id, err = i.Create(testCtx)
assert.Nil(t, id)
assert.NotNil(t, err)
@@ -551,9 +554,9 @@ func TestUpdate(t *testing.T) {
// id needs to be set for the mock input
image.ID = imageID
- readerWriter.On("UpdateFull", image).Return(nil, nil).Once()
+ readerWriter.On("UpdateFull", testCtx, image).Return(nil, nil).Once()
- err := i.Update(imageID)
+ err := i.Update(testCtx, imageID)
assert.Nil(t, err)
assert.Equal(t, imageID, i.ID)
@@ -561,9 +564,9 @@ func TestUpdate(t *testing.T) {
// need to set id separately
imageErr.ID = errImageID
- readerWriter.On("UpdateFull", imageErr).Return(nil, errUpdate).Once()
+ readerWriter.On("UpdateFull", testCtx, imageErr).Return(nil, errUpdate).Once()
- err = i.Update(errImageID)
+ err = i.Update(testCtx, errImageID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
diff --git a/pkg/image/query.go b/pkg/image/query.go
index 058d0a842..36ed3a8c3 100644
--- a/pkg/image/query.go
+++ b/pkg/image/query.go
@@ -1,13 +1,18 @@
package image
import (
+ "context"
"strconv"
"github.com/stashapp/stash/pkg/models"
)
type Queryer interface {
- Query(options models.ImageQueryOptions) (*models.ImageQueryResult, error)
+ Query(ctx context.Context, options models.ImageQueryOptions) (*models.ImageQueryResult, error)
+}
+
+type CountQueryer interface {
+ QueryCount(ctx context.Context, imageFilter *models.ImageFilterType, findFilter *models.FindFilterType) (int, error)
}
// QueryOptions returns a ImageQueryResult populated with the provided filters.
@@ -22,13 +27,13 @@ func QueryOptions(imageFilter *models.ImageFilterType, findFilter *models.FindFi
}
// Query queries for images using the provided filters.
-func Query(qb Queryer, imageFilter *models.ImageFilterType, findFilter *models.FindFilterType) ([]*models.Image, error) {
- result, err := qb.Query(QueryOptions(imageFilter, findFilter, false))
+func Query(ctx context.Context, qb Queryer, imageFilter *models.ImageFilterType, findFilter *models.FindFilterType) ([]*models.Image, error) {
+ result, err := qb.Query(ctx, QueryOptions(imageFilter, findFilter, false))
if err != nil {
return nil, err
}
- images, err := result.Resolve()
+ images, err := result.Resolve(ctx)
if err != nil {
return nil, err
}
@@ -36,7 +41,7 @@ func Query(qb Queryer, imageFilter *models.ImageFilterType, findFilter *models.F
return images, nil
}
-func CountByPerformerID(r models.ImageReader, id int) (int, error) {
+func CountByPerformerID(ctx context.Context, r CountQueryer, id int) (int, error) {
filter := &models.ImageFilterType{
Performers: &models.MultiCriterionInput{
Value: []string{strconv.Itoa(id)},
@@ -44,10 +49,10 @@ func CountByPerformerID(r models.ImageReader, id int) (int, error) {
},
}
- return r.QueryCount(filter, nil)
+ return r.QueryCount(ctx, filter, nil)
}
-func CountByStudioID(r models.ImageReader, id int) (int, error) {
+func CountByStudioID(ctx context.Context, r CountQueryer, id int) (int, error) {
filter := &models.ImageFilterType{
Studios: &models.HierarchicalMultiCriterionInput{
Value: []string{strconv.Itoa(id)},
@@ -55,10 +60,10 @@ func CountByStudioID(r models.ImageReader, id int) (int, error) {
},
}
- return r.QueryCount(filter, nil)
+ return r.QueryCount(ctx, filter, nil)
}
-func CountByTagID(r models.ImageReader, id int) (int, error) {
+func CountByTagID(ctx context.Context, r CountQueryer, id int) (int, error) {
filter := &models.ImageFilterType{
Tags: &models.HierarchicalMultiCriterionInput{
Value: []string{strconv.Itoa(id)},
@@ -66,10 +71,10 @@ func CountByTagID(r models.ImageReader, id int) (int, error) {
},
}
- return r.QueryCount(filter, nil)
+ return r.QueryCount(ctx, filter, nil)
}
-func FindByGalleryID(r models.ImageReader, galleryID int, sortBy string, sortDir models.SortDirectionEnum) ([]*models.Image, error) {
+func FindByGalleryID(ctx context.Context, r Queryer, galleryID int, sortBy string, sortDir models.SortDirectionEnum) ([]*models.Image, error) {
perPage := -1
findFilter := models.FindFilterType{
@@ -84,7 +89,7 @@ func FindByGalleryID(r models.ImageReader, galleryID int, sortBy string, sortDir
findFilter.Direction = &sortDir
}
- return Query(r, &models.ImageFilterType{
+ return Query(ctx, r, &models.ImageFilterType{
Galleries: &models.MultiCriterionInput{
Value: []string{strconv.Itoa(galleryID)},
Modifier: models.CriterionModifierIncludes,
diff --git a/pkg/image/scan.go b/pkg/image/scan.go
index 8fa2f24a6..751f41000 100644
--- a/pkg/image/scan.go
+++ b/pkg/image/scan.go
@@ -12,18 +12,27 @@ import (
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/paths"
"github.com/stashapp/stash/pkg/plugin"
+ "github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils"
)
const mutexType = "image"
+type FinderCreatorUpdater interface {
+ FindByChecksum(ctx context.Context, checksum string) (*models.Image, error)
+ Create(ctx context.Context, newImage models.Image) (*models.Image, error)
+ UpdateFull(ctx context.Context, updatedImage models.Image) (*models.Image, error)
+ Update(ctx context.Context, updatedImage models.ImagePartial) (*models.Image, error)
+}
+
type Scanner struct {
file.Scanner
StripFileExtension bool
CaseSensitiveFs bool
- TxnManager models.TransactionManager
+ TxnManager txn.Manager
+ CreatorUpdater FinderCreatorUpdater
Paths *paths.Paths
PluginCache *plugin.Cache
MutexManager *utils.MutexManager
@@ -71,20 +80,20 @@ func (scanner *Scanner) ScanExisting(ctx context.Context, existing file.FileBase
done := make(chan struct{})
scanner.MutexManager.Claim(mutexType, scanned.New.Checksum, done)
- if err := scanner.TxnManager.WithTxn(ctx, func(r models.Repository) error {
+ if err := txn.WithTxn(ctx, scanner.TxnManager, func(ctx context.Context) error {
// free the mutex once transaction is complete
defer close(done)
var err error
// ensure no clashes of hashes
if scanned.New.Checksum != "" && scanned.Old.Checksum != scanned.New.Checksum {
- dupe, _ := r.Image().FindByChecksum(i.Checksum)
+ dupe, _ := scanner.CreatorUpdater.FindByChecksum(ctx, i.Checksum)
if dupe != nil {
return fmt.Errorf("MD5 for file %s is the same as that of %s", path, dupe.Path)
}
}
- retImage, err = r.Image().UpdateFull(*i)
+ retImage, err = scanner.CreatorUpdater.UpdateFull(ctx, *i)
return err
}); err != nil {
return nil, err
@@ -121,9 +130,9 @@ func (scanner *Scanner) ScanNew(ctx context.Context, f file.SourceFile) (retImag
// check for image by checksum
var existingImage *models.Image
- if err := scanner.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
+ if err := txn.WithTxn(ctx, scanner.TxnManager, func(ctx context.Context) error {
var err error
- existingImage, err = r.Image().FindByChecksum(checksum)
+ existingImage, err = scanner.CreatorUpdater.FindByChecksum(ctx, checksum)
return err
}); err != nil {
return nil, err
@@ -151,8 +160,8 @@ func (scanner *Scanner) ScanNew(ctx context.Context, f file.SourceFile) (retImag
Path: &path,
}
- if err := scanner.TxnManager.WithTxn(ctx, func(r models.Repository) error {
- retImage, err = r.Image().Update(imagePartial)
+ if err := txn.WithTxn(ctx, scanner.TxnManager, func(ctx context.Context) error {
+ retImage, err = scanner.CreatorUpdater.Update(ctx, imagePartial)
return err
}); err != nil {
return nil, err
@@ -176,9 +185,9 @@ func (scanner *Scanner) ScanNew(ctx context.Context, f file.SourceFile) (retImag
return nil, err
}
- if err := scanner.TxnManager.WithTxn(ctx, func(r models.Repository) error {
+ if err := txn.WithTxn(ctx, scanner.TxnManager, func(ctx context.Context) error {
var err error
- retImage, err = r.Image().Create(newImage)
+ retImage, err = scanner.CreatorUpdater.Create(ctx, newImage)
return err
}); err != nil {
return nil, err
diff --git a/pkg/image/update.go b/pkg/image/update.go
index f8c43f965..1b1b22535 100644
--- a/pkg/image/update.go
+++ b/pkg/image/update.go
@@ -1,19 +1,35 @@
package image
import (
+ "context"
+
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/sliceutil/intslice"
)
-func UpdateFileModTime(qb models.ImageWriter, id int, modTime models.NullSQLiteTimestamp) (*models.Image, error) {
- return qb.Update(models.ImagePartial{
+type PartialUpdater interface {
+ Update(ctx context.Context, updatedImage models.ImagePartial) (*models.Image, error)
+}
+
+type PerformerUpdater interface {
+ GetPerformerIDs(ctx context.Context, imageID int) ([]int, error)
+ UpdatePerformers(ctx context.Context, imageID int, performerIDs []int) error
+}
+
+type TagUpdater interface {
+ GetTagIDs(ctx context.Context, imageID int) ([]int, error)
+ UpdateTags(ctx context.Context, imageID int, tagIDs []int) error
+}
+
+func UpdateFileModTime(ctx context.Context, qb PartialUpdater, id int, modTime models.NullSQLiteTimestamp) (*models.Image, error) {
+ return qb.Update(ctx, models.ImagePartial{
ID: id,
FileModTime: &modTime,
})
}
-func AddPerformer(qb models.ImageReaderWriter, id int, performerID int) (bool, error) {
- performerIDs, err := qb.GetPerformerIDs(id)
+func AddPerformer(ctx context.Context, qb PerformerUpdater, id int, performerID int) (bool, error) {
+ performerIDs, err := qb.GetPerformerIDs(ctx, id)
if err != nil {
return false, err
}
@@ -22,7 +38,7 @@ func AddPerformer(qb models.ImageReaderWriter, id int, performerID int) (bool, e
performerIDs = intslice.IntAppendUnique(performerIDs, performerID)
if len(performerIDs) != oldLen {
- if err := qb.UpdatePerformers(id, performerIDs); err != nil {
+ if err := qb.UpdatePerformers(ctx, id, performerIDs); err != nil {
return false, err
}
@@ -32,8 +48,8 @@ func AddPerformer(qb models.ImageReaderWriter, id int, performerID int) (bool, e
return false, nil
}
-func AddTag(qb models.ImageReaderWriter, id int, tagID int) (bool, error) {
- tagIDs, err := qb.GetTagIDs(id)
+func AddTag(ctx context.Context, qb TagUpdater, id int, tagID int) (bool, error) {
+ tagIDs, err := qb.GetTagIDs(ctx, id)
if err != nil {
return false, err
}
@@ -42,7 +58,7 @@ func AddTag(qb models.ImageReaderWriter, id int, tagID int) (bool, error) {
tagIDs = intslice.IntAppendUnique(tagIDs, tagID)
if len(tagIDs) != oldLen {
- if err := qb.UpdateTags(id, tagIDs); err != nil {
+ if err := qb.UpdateTags(ctx, id, tagIDs); err != nil {
return false, err
}
diff --git a/pkg/match/cache.go b/pkg/match/cache.go
index 6d7238809..06237c7f6 100644
--- a/pkg/match/cache.go
+++ b/pkg/match/cache.go
@@ -1,6 +1,10 @@
package match
-import "github.com/stashapp/stash/pkg/models"
+import (
+ "context"
+
+ "github.com/stashapp/stash/pkg/models"
+)
const singleFirstCharacterRegex = `^[\p{L}][.\-_ ]`
@@ -16,14 +20,14 @@ type Cache struct {
// against. This means that performers with single-letter words in their names could potentially
// be missed.
// This query is expensive, so it's queried once and cached, if the cache if provided.
-func getSingleLetterPerformers(c *Cache, reader models.PerformerReader) ([]*models.Performer, error) {
+func getSingleLetterPerformers(ctx context.Context, c *Cache, reader PerformerAutoTagQueryer) ([]*models.Performer, error) {
if c == nil {
c = &Cache{}
}
if c.singleCharPerformers == nil {
pp := -1
- performers, _, err := reader.Query(&models.PerformerFilterType{
+ performers, _, err := reader.Query(ctx, &models.PerformerFilterType{
Name: &models.StringCriterionInput{
Value: singleFirstCharacterRegex,
Modifier: models.CriterionModifierMatchesRegex,
@@ -49,14 +53,14 @@ func getSingleLetterPerformers(c *Cache, reader models.PerformerReader) ([]*mode
// getSingleLetterStudios returns all studios with names that start with single character words.
// See getSingleLetterPerformers for details.
-func getSingleLetterStudios(c *Cache, reader models.StudioReader) ([]*models.Studio, error) {
+func getSingleLetterStudios(ctx context.Context, c *Cache, reader StudioAutoTagQueryer) ([]*models.Studio, error) {
if c == nil {
c = &Cache{}
}
if c.singleCharStudios == nil {
pp := -1
- studios, _, err := reader.Query(&models.StudioFilterType{
+ studios, _, err := reader.Query(ctx, &models.StudioFilterType{
Name: &models.StringCriterionInput{
Value: singleFirstCharacterRegex,
Modifier: models.CriterionModifierMatchesRegex,
@@ -82,14 +86,14 @@ func getSingleLetterStudios(c *Cache, reader models.StudioReader) ([]*models.Stu
// getSingleLetterTags returns all tags with names that start with single character words.
// See getSingleLetterPerformers for details.
-func getSingleLetterTags(c *Cache, reader models.TagReader) ([]*models.Tag, error) {
+func getSingleLetterTags(ctx context.Context, c *Cache, reader TagAutoTagQueryer) ([]*models.Tag, error) {
if c == nil {
c = &Cache{}
}
if c.singleCharTags == nil {
pp := -1
- tags, _, err := reader.Query(&models.TagFilterType{
+ tags, _, err := reader.Query(ctx, &models.TagFilterType{
Name: &models.StringCriterionInput{
Value: singleFirstCharacterRegex,
Modifier: models.CriterionModifierMatchesRegex,
diff --git a/pkg/match/path.go b/pkg/match/path.go
index 4f20423dd..a20678834 100644
--- a/pkg/match/path.go
+++ b/pkg/match/path.go
@@ -1,6 +1,7 @@
package match
import (
+ "context"
"fmt"
"path/filepath"
"regexp"
@@ -12,6 +13,8 @@ import (
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene"
"github.com/stashapp/stash/pkg/sliceutil/stringslice"
+ "github.com/stashapp/stash/pkg/studio"
+ "github.com/stashapp/stash/pkg/tag"
)
const (
@@ -24,6 +27,23 @@ const (
var separatorRE = regexp.MustCompile(separatorPattern)
+type PerformerAutoTagQueryer interface {
+ Query(ctx context.Context, performerFilter *models.PerformerFilterType, findFilter *models.FindFilterType) ([]*models.Performer, int, error)
+ QueryForAutoTag(ctx context.Context, words []string) ([]*models.Performer, error)
+}
+
+type StudioAutoTagQueryer interface {
+ QueryForAutoTag(ctx context.Context, words []string) ([]*models.Studio, error)
+ studio.Queryer
+ GetAliases(ctx context.Context, studioID int) ([]string, error)
+}
+
+type TagAutoTagQueryer interface {
+ QueryForAutoTag(ctx context.Context, words []string) ([]*models.Tag, error)
+ tag.Queryer
+ GetAliases(ctx context.Context, tagID int) ([]string, error)
+}
+
func getPathQueryRegex(name string) string {
// escape specific regex characters
name = regexp.QuoteMeta(name)
@@ -124,13 +144,13 @@ func regexpMatchesPath(r *regexp.Regexp, path string) int {
return found[len(found)-1][0]
}
-func getPerformers(words []string, performerReader models.PerformerReader, cache *Cache) ([]*models.Performer, error) {
- performers, err := performerReader.QueryForAutoTag(words)
+func getPerformers(ctx context.Context, words []string, performerReader PerformerAutoTagQueryer, cache *Cache) ([]*models.Performer, error) {
+ performers, err := performerReader.QueryForAutoTag(ctx, words)
if err != nil {
return nil, err
}
- swPerformers, err := getSingleLetterPerformers(cache, performerReader)
+ swPerformers, err := getSingleLetterPerformers(ctx, cache, performerReader)
if err != nil {
return nil, err
}
@@ -138,10 +158,10 @@ func getPerformers(words []string, performerReader models.PerformerReader, cache
return append(performers, swPerformers...), nil
}
-func PathToPerformers(path string, reader models.PerformerReader, cache *Cache, trimExt bool) ([]*models.Performer, error) {
+func PathToPerformers(ctx context.Context, path string, reader PerformerAutoTagQueryer, cache *Cache, trimExt bool) ([]*models.Performer, error) {
words := getPathWords(path, trimExt)
- performers, err := getPerformers(words, reader, cache)
+ performers, err := getPerformers(ctx, words, reader, cache)
if err != nil {
return nil, err
}
@@ -157,13 +177,13 @@ func PathToPerformers(path string, reader models.PerformerReader, cache *Cache,
return ret, nil
}
-func getStudios(words []string, reader models.StudioReader, cache *Cache) ([]*models.Studio, error) {
- studios, err := reader.QueryForAutoTag(words)
+func getStudios(ctx context.Context, words []string, reader StudioAutoTagQueryer, cache *Cache) ([]*models.Studio, error) {
+ studios, err := reader.QueryForAutoTag(ctx, words)
if err != nil {
return nil, err
}
- swStudios, err := getSingleLetterStudios(cache, reader)
+ swStudios, err := getSingleLetterStudios(ctx, cache, reader)
if err != nil {
return nil, err
}
@@ -174,9 +194,9 @@ func getStudios(words []string, reader models.StudioReader, cache *Cache) ([]*mo
// PathToStudio returns the Studio that matches the given path.
// Where multiple matching studios are found, the one that matches the latest
// position in the path is returned.
-func PathToStudio(path string, reader models.StudioReader, cache *Cache, trimExt bool) (*models.Studio, error) {
+func PathToStudio(ctx context.Context, path string, reader StudioAutoTagQueryer, cache *Cache, trimExt bool) (*models.Studio, error) {
words := getPathWords(path, trimExt)
- candidates, err := getStudios(words, reader, cache)
+ candidates, err := getStudios(ctx, words, reader, cache)
if err != nil {
return nil, err
@@ -191,7 +211,7 @@ func PathToStudio(path string, reader models.StudioReader, cache *Cache, trimExt
index = matchIndex
}
- aliases, err := reader.GetAliases(c.ID)
+ aliases, err := reader.GetAliases(ctx, c.ID)
if err != nil {
return nil, err
}
@@ -208,13 +228,13 @@ func PathToStudio(path string, reader models.StudioReader, cache *Cache, trimExt
return ret, nil
}
-func getTags(words []string, reader models.TagReader, cache *Cache) ([]*models.Tag, error) {
- tags, err := reader.QueryForAutoTag(words)
+func getTags(ctx context.Context, words []string, reader TagAutoTagQueryer, cache *Cache) ([]*models.Tag, error) {
+ tags, err := reader.QueryForAutoTag(ctx, words)
if err != nil {
return nil, err
}
- swTags, err := getSingleLetterTags(cache, reader)
+ swTags, err := getSingleLetterTags(ctx, cache, reader)
if err != nil {
return nil, err
}
@@ -222,9 +242,9 @@ func getTags(words []string, reader models.TagReader, cache *Cache) ([]*models.T
return append(tags, swTags...), nil
}
-func PathToTags(path string, reader models.TagReader, cache *Cache, trimExt bool) ([]*models.Tag, error) {
+func PathToTags(ctx context.Context, path string, reader TagAutoTagQueryer, cache *Cache, trimExt bool) ([]*models.Tag, error) {
words := getPathWords(path, trimExt)
- tags, err := getTags(words, reader, cache)
+ tags, err := getTags(ctx, words, reader, cache)
if err != nil {
return nil, err
@@ -238,7 +258,7 @@ func PathToTags(path string, reader models.TagReader, cache *Cache, trimExt bool
}
if !matches {
- aliases, err := reader.GetAliases(t.ID)
+ aliases, err := reader.GetAliases(ctx, t.ID)
if err != nil {
return nil, err
}
@@ -258,7 +278,7 @@ func PathToTags(path string, reader models.TagReader, cache *Cache, trimExt bool
return ret, nil
}
-func PathToScenes(name string, paths []string, sceneReader models.SceneReader) ([]*models.Scene, error) {
+func PathToScenes(ctx context.Context, name string, paths []string, sceneReader scene.Queryer) ([]*models.Scene, error) {
regex := getPathQueryRegex(name)
organized := false
filter := models.SceneFilterType{
@@ -272,7 +292,7 @@ func PathToScenes(name string, paths []string, sceneReader models.SceneReader) (
filter.And = scene.PathsFilter(paths)
pp := models.PerPageAll
- scenes, err := scene.Query(sceneReader, &filter, &models.FindFilterType{
+ scenes, err := scene.Query(ctx, sceneReader, &filter, &models.FindFilterType{
PerPage: &pp,
})
@@ -295,7 +315,7 @@ func PathToScenes(name string, paths []string, sceneReader models.SceneReader) (
return ret, nil
}
-func PathToImages(name string, paths []string, imageReader models.ImageReader) ([]*models.Image, error) {
+func PathToImages(ctx context.Context, name string, paths []string, imageReader image.Queryer) ([]*models.Image, error) {
regex := getPathQueryRegex(name)
organized := false
filter := models.ImageFilterType{
@@ -309,7 +329,7 @@ func PathToImages(name string, paths []string, imageReader models.ImageReader) (
filter.And = image.PathsFilter(paths)
pp := models.PerPageAll
- images, err := image.Query(imageReader, &filter, &models.FindFilterType{
+ images, err := image.Query(ctx, imageReader, &filter, &models.FindFilterType{
PerPage: &pp,
})
@@ -332,7 +352,7 @@ func PathToImages(name string, paths []string, imageReader models.ImageReader) (
return ret, nil
}
-func PathToGalleries(name string, paths []string, galleryReader models.GalleryReader) ([]*models.Gallery, error) {
+func PathToGalleries(ctx context.Context, name string, paths []string, galleryReader gallery.Queryer) ([]*models.Gallery, error) {
regex := getPathQueryRegex(name)
organized := false
filter := models.GalleryFilterType{
@@ -346,7 +366,7 @@ func PathToGalleries(name string, paths []string, galleryReader models.GalleryRe
filter.And = gallery.PathsFilter(paths)
pp := models.PerPageAll
- gallerys, _, err := galleryReader.Query(&filter, &models.FindFilterType{
+ gallerys, _, err := galleryReader.Query(ctx, &filter, &models.FindFilterType{
PerPage: &pp,
})
diff --git a/pkg/match/scraped.go b/pkg/match/scraped.go
index 1e9de81e1..d1182a329 100644
--- a/pkg/match/scraped.go
+++ b/pkg/match/scraped.go
@@ -1,6 +1,7 @@
package match
import (
+ "context"
"strconv"
"github.com/stashapp/stash/pkg/models"
@@ -8,16 +9,25 @@ import (
"github.com/stashapp/stash/pkg/tag"
)
+type PerformerFinder interface {
+ FindByNames(ctx context.Context, names []string, nocase bool) ([]*models.Performer, error)
+ FindByStashID(ctx context.Context, stashID models.StashID) ([]*models.Performer, error)
+}
+
+type MovieNamesFinder interface {
+ FindByNames(ctx context.Context, names []string, nocase bool) ([]*models.Movie, error)
+}
+
// ScrapedPerformer matches the provided performer with the
// performers in the database and sets the ID field if one is found.
-func ScrapedPerformer(qb models.PerformerReader, p *models.ScrapedPerformer, stashBoxEndpoint *string) error {
+func ScrapedPerformer(ctx context.Context, qb PerformerFinder, p *models.ScrapedPerformer, stashBoxEndpoint *string) error {
if p.StoredID != nil || p.Name == nil {
return nil
}
// Check if a performer with the StashID already exists
if stashBoxEndpoint != nil && p.RemoteSiteID != nil {
- performers, err := qb.FindByStashID(models.StashID{
+ performers, err := qb.FindByStashID(ctx, models.StashID{
StashID: *p.RemoteSiteID,
Endpoint: *stashBoxEndpoint,
})
@@ -31,7 +41,7 @@ func ScrapedPerformer(qb models.PerformerReader, p *models.ScrapedPerformer, sta
}
}
- performers, err := qb.FindByNames([]string{*p.Name}, true)
+ performers, err := qb.FindByNames(ctx, []string{*p.Name}, true)
if err != nil {
return err
@@ -47,16 +57,21 @@ func ScrapedPerformer(qb models.PerformerReader, p *models.ScrapedPerformer, sta
return nil
}
+type StudioFinder interface {
+ studio.Queryer
+ FindByStashID(ctx context.Context, stashID models.StashID) ([]*models.Studio, error)
+}
+
// ScrapedStudio matches the provided studio with the studios
// in the database and sets the ID field if one is found.
-func ScrapedStudio(qb models.StudioReader, s *models.ScrapedStudio, stashBoxEndpoint *string) error {
+func ScrapedStudio(ctx context.Context, qb StudioFinder, s *models.ScrapedStudio, stashBoxEndpoint *string) error {
if s.StoredID != nil {
return nil
}
// Check if a studio with the StashID already exists
if stashBoxEndpoint != nil && s.RemoteSiteID != nil {
- studios, err := qb.FindByStashID(models.StashID{
+ studios, err := qb.FindByStashID(ctx, models.StashID{
StashID: *s.RemoteSiteID,
Endpoint: *stashBoxEndpoint,
})
@@ -70,7 +85,7 @@ func ScrapedStudio(qb models.StudioReader, s *models.ScrapedStudio, stashBoxEndp
}
}
- st, err := studio.ByName(qb, s.Name)
+ st, err := studio.ByName(ctx, qb, s.Name)
if err != nil {
return err
@@ -78,7 +93,7 @@ func ScrapedStudio(qb models.StudioReader, s *models.ScrapedStudio, stashBoxEndp
if st == nil {
// try matching by alias
- st, err = studio.ByAlias(qb, s.Name)
+ st, err = studio.ByAlias(ctx, qb, s.Name)
if err != nil {
return err
}
@@ -96,12 +111,12 @@ func ScrapedStudio(qb models.StudioReader, s *models.ScrapedStudio, stashBoxEndp
// ScrapedMovie matches the provided movie with the movies
// in the database and sets the ID field if one is found.
-func ScrapedMovie(qb models.MovieReader, m *models.ScrapedMovie) error {
+func ScrapedMovie(ctx context.Context, qb MovieNamesFinder, m *models.ScrapedMovie) error {
if m.StoredID != nil || m.Name == nil {
return nil
}
- movies, err := qb.FindByNames([]string{*m.Name}, true)
+ movies, err := qb.FindByNames(ctx, []string{*m.Name}, true)
if err != nil {
return err
@@ -119,12 +134,12 @@ func ScrapedMovie(qb models.MovieReader, m *models.ScrapedMovie) error {
// ScrapedTag matches the provided tag with the tags
// in the database and sets the ID field if one is found.
-func ScrapedTag(qb models.TagReader, s *models.ScrapedTag) error {
+func ScrapedTag(ctx context.Context, qb tag.Queryer, s *models.ScrapedTag) error {
if s.StoredID != nil {
return nil
}
- t, err := tag.ByName(qb, s.Name)
+ t, err := tag.ByName(ctx, qb, s.Name)
if err != nil {
return err
@@ -132,7 +147,7 @@ func ScrapedTag(qb models.TagReader, s *models.ScrapedTag) error {
if t == nil {
// try matching by alias
- t, err = tag.ByAlias(qb, s.Name)
+ t, err = tag.ByAlias(ctx, qb, s.Name)
if err != nil {
return err
}
diff --git a/pkg/models/gallery.go b/pkg/models/gallery.go
index 3c85e7193..676b61937 100644
--- a/pkg/models/gallery.go
+++ b/pkg/models/gallery.go
@@ -1,5 +1,7 @@
package models
+import "context"
+
type GalleryFilterType struct {
And *GalleryFilterType `json:"AND"`
Or *GalleryFilterType `json:"OR"`
@@ -67,33 +69,33 @@ type GalleryDestroyInput struct {
}
type GalleryReader interface {
- Find(id int) (*Gallery, error)
- FindMany(ids []int) ([]*Gallery, error)
- FindByChecksum(checksum string) (*Gallery, error)
- FindByChecksums(checksums []string) ([]*Gallery, error)
- FindByPath(path string) (*Gallery, error)
- FindBySceneID(sceneID int) ([]*Gallery, error)
- FindByImageID(imageID int) ([]*Gallery, error)
- Count() (int, error)
- All() ([]*Gallery, error)
- Query(galleryFilter *GalleryFilterType, findFilter *FindFilterType) ([]*Gallery, int, error)
- QueryCount(galleryFilter *GalleryFilterType, findFilter *FindFilterType) (int, error)
- GetPerformerIDs(galleryID int) ([]int, error)
- GetTagIDs(galleryID int) ([]int, error)
- GetSceneIDs(galleryID int) ([]int, error)
- GetImageIDs(galleryID int) ([]int, error)
+ Find(ctx context.Context, id int) (*Gallery, error)
+ FindMany(ctx context.Context, ids []int) ([]*Gallery, error)
+ FindByChecksum(ctx context.Context, checksum string) (*Gallery, error)
+ FindByChecksums(ctx context.Context, checksums []string) ([]*Gallery, error)
+ FindByPath(ctx context.Context, path string) (*Gallery, error)
+ FindBySceneID(ctx context.Context, sceneID int) ([]*Gallery, error)
+ FindByImageID(ctx context.Context, imageID int) ([]*Gallery, error)
+ Count(ctx context.Context) (int, error)
+ All(ctx context.Context) ([]*Gallery, error)
+ Query(ctx context.Context, galleryFilter *GalleryFilterType, findFilter *FindFilterType) ([]*Gallery, int, error)
+ QueryCount(ctx context.Context, galleryFilter *GalleryFilterType, findFilter *FindFilterType) (int, error)
+ GetPerformerIDs(ctx context.Context, galleryID int) ([]int, error)
+ GetTagIDs(ctx context.Context, galleryID int) ([]int, error)
+ GetSceneIDs(ctx context.Context, galleryID int) ([]int, error)
+ GetImageIDs(ctx context.Context, galleryID int) ([]int, error)
}
type GalleryWriter interface {
- Create(newGallery Gallery) (*Gallery, error)
- Update(updatedGallery Gallery) (*Gallery, error)
- UpdatePartial(updatedGallery GalleryPartial) (*Gallery, error)
- UpdateFileModTime(id int, modTime NullSQLiteTimestamp) error
- Destroy(id int) error
- UpdatePerformers(galleryID int, performerIDs []int) error
- UpdateTags(galleryID int, tagIDs []int) error
- UpdateScenes(galleryID int, sceneIDs []int) error
- UpdateImages(galleryID int, imageIDs []int) error
+ Create(ctx context.Context, newGallery Gallery) (*Gallery, error)
+ Update(ctx context.Context, updatedGallery Gallery) (*Gallery, error)
+ UpdatePartial(ctx context.Context, updatedGallery GalleryPartial) (*Gallery, error)
+ UpdateFileModTime(ctx context.Context, id int, modTime NullSQLiteTimestamp) error
+ Destroy(ctx context.Context, id int) error
+ UpdatePerformers(ctx context.Context, galleryID int, performerIDs []int) error
+ UpdateTags(ctx context.Context, galleryID int, tagIDs []int) error
+ UpdateScenes(ctx context.Context, galleryID int, sceneIDs []int) error
+ UpdateImages(ctx context.Context, galleryID int, imageIDs []int) error
}
type GalleryReaderWriter interface {
diff --git a/pkg/models/image.go b/pkg/models/image.go
index 4b28c4c36..4509ef709 100644
--- a/pkg/models/image.go
+++ b/pkg/models/image.go
@@ -1,5 +1,7 @@
package models
+import "context"
+
type ImageFilterType struct {
And *ImageFilterType `json:"AND"`
Or *ImageFilterType `json:"OR"`
@@ -73,54 +75,54 @@ func NewImageQueryResult(finder ImageFinder) *ImageQueryResult {
}
}
-func (r *ImageQueryResult) Resolve() ([]*Image, error) {
+func (r *ImageQueryResult) Resolve(ctx context.Context) ([]*Image, error) {
// cache results
if r.images == nil && r.resolveErr == nil {
- r.images, r.resolveErr = r.finder.FindMany(r.IDs)
+ r.images, r.resolveErr = r.finder.FindMany(ctx, r.IDs)
}
return r.images, r.resolveErr
}
type ImageFinder interface {
// TODO - rename to Find and remove existing method
- FindMany(ids []int) ([]*Image, error)
+ FindMany(ctx context.Context, ids []int) ([]*Image, error)
}
type ImageReader interface {
ImageFinder
// TODO - remove this in another PR
- Find(id int) (*Image, error)
- FindByChecksum(checksum string) (*Image, error)
- FindByGalleryID(galleryID int) ([]*Image, error)
- CountByGalleryID(galleryID int) (int, error)
- FindByPath(path string) (*Image, error)
+ Find(ctx context.Context, id int) (*Image, error)
+ FindByChecksum(ctx context.Context, checksum string) (*Image, error)
+ FindByGalleryID(ctx context.Context, galleryID int) ([]*Image, error)
+ CountByGalleryID(ctx context.Context, galleryID int) (int, error)
+ FindByPath(ctx context.Context, path string) (*Image, error)
// FindByPerformerID(performerID int) ([]*Image, error)
// CountByPerformerID(performerID int) (int, error)
// FindByStudioID(studioID int) ([]*Image, error)
- Count() (int, error)
- Size() (float64, error)
+ Count(ctx context.Context) (int, error)
+ Size(ctx context.Context) (float64, error)
// SizeCount() (string, error)
// CountByStudioID(studioID int) (int, error)
// CountByTagID(tagID int) (int, error)
- All() ([]*Image, error)
- Query(options ImageQueryOptions) (*ImageQueryResult, error)
- QueryCount(imageFilter *ImageFilterType, findFilter *FindFilterType) (int, error)
- GetGalleryIDs(imageID int) ([]int, error)
- GetTagIDs(imageID int) ([]int, error)
- GetPerformerIDs(imageID int) ([]int, error)
+ All(ctx context.Context) ([]*Image, error)
+ Query(ctx context.Context, options ImageQueryOptions) (*ImageQueryResult, error)
+ QueryCount(ctx context.Context, imageFilter *ImageFilterType, findFilter *FindFilterType) (int, error)
+ GetGalleryIDs(ctx context.Context, imageID int) ([]int, error)
+ GetTagIDs(ctx context.Context, imageID int) ([]int, error)
+ GetPerformerIDs(ctx context.Context, imageID int) ([]int, error)
}
type ImageWriter interface {
- Create(newImage Image) (*Image, error)
- Update(updatedImage ImagePartial) (*Image, error)
- UpdateFull(updatedImage Image) (*Image, error)
- IncrementOCounter(id int) (int, error)
- DecrementOCounter(id int) (int, error)
- ResetOCounter(id int) (int, error)
- Destroy(id int) error
- UpdateGalleries(imageID int, galleryIDs []int) error
- UpdatePerformers(imageID int, performerIDs []int) error
- UpdateTags(imageID int, tagIDs []int) error
+ Create(ctx context.Context, newImage Image) (*Image, error)
+ Update(ctx context.Context, updatedImage ImagePartial) (*Image, error)
+ UpdateFull(ctx context.Context, updatedImage Image) (*Image, error)
+ IncrementOCounter(ctx context.Context, id int) (int, error)
+ DecrementOCounter(ctx context.Context, id int) (int, error)
+ ResetOCounter(ctx context.Context, id int) (int, error)
+ Destroy(ctx context.Context, id int) error
+ UpdateGalleries(ctx context.Context, imageID int, galleryIDs []int) error
+ UpdatePerformers(ctx context.Context, imageID int, performerIDs []int) error
+ UpdateTags(ctx context.Context, imageID int, tagIDs []int) error
}
type ImageReaderWriter interface {
diff --git a/pkg/models/mocks/GalleryReaderWriter.go b/pkg/models/mocks/GalleryReaderWriter.go
index 9731147fe..ee8ec643d 100644
--- a/pkg/models/mocks/GalleryReaderWriter.go
+++ b/pkg/models/mocks/GalleryReaderWriter.go
@@ -3,6 +3,8 @@
package mocks
import (
+ context "context"
+
models "github.com/stashapp/stash/pkg/models"
mock "github.com/stretchr/testify/mock"
)
@@ -12,13 +14,13 @@ type GalleryReaderWriter struct {
mock.Mock
}
-// All provides a mock function with given fields:
-func (_m *GalleryReaderWriter) All() ([]*models.Gallery, error) {
- ret := _m.Called()
+// All provides a mock function with given fields: ctx
+func (_m *GalleryReaderWriter) All(ctx context.Context) ([]*models.Gallery, error) {
+ ret := _m.Called(ctx)
var r0 []*models.Gallery
- if rf, ok := ret.Get(0).(func() []*models.Gallery); ok {
- r0 = rf()
+ if rf, ok := ret.Get(0).(func(context.Context) []*models.Gallery); ok {
+ r0 = rf(ctx)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Gallery)
@@ -26,8 +28,8 @@ func (_m *GalleryReaderWriter) All() ([]*models.Gallery, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func() error); ok {
- r1 = rf()
+ if rf, ok := ret.Get(1).(func(context.Context) error); ok {
+ r1 = rf(ctx)
} else {
r1 = ret.Error(1)
}
@@ -35,20 +37,20 @@ func (_m *GalleryReaderWriter) All() ([]*models.Gallery, error) {
return r0, r1
}
-// Count provides a mock function with given fields:
-func (_m *GalleryReaderWriter) Count() (int, error) {
- ret := _m.Called()
+// Count provides a mock function with given fields: ctx
+func (_m *GalleryReaderWriter) Count(ctx context.Context) (int, error) {
+ ret := _m.Called(ctx)
var r0 int
- if rf, ok := ret.Get(0).(func() int); ok {
- r0 = rf()
+ if rf, ok := ret.Get(0).(func(context.Context) int); ok {
+ r0 = rf(ctx)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
- if rf, ok := ret.Get(1).(func() error); ok {
- r1 = rf()
+ if rf, ok := ret.Get(1).(func(context.Context) error); ok {
+ r1 = rf(ctx)
} else {
r1 = ret.Error(1)
}
@@ -56,13 +58,13 @@ func (_m *GalleryReaderWriter) Count() (int, error) {
return r0, r1
}
-// Create provides a mock function with given fields: newGallery
-func (_m *GalleryReaderWriter) Create(newGallery models.Gallery) (*models.Gallery, error) {
- ret := _m.Called(newGallery)
+// Create provides a mock function with given fields: ctx, newGallery
+func (_m *GalleryReaderWriter) Create(ctx context.Context, newGallery models.Gallery) (*models.Gallery, error) {
+ ret := _m.Called(ctx, newGallery)
var r0 *models.Gallery
- if rf, ok := ret.Get(0).(func(models.Gallery) *models.Gallery); ok {
- r0 = rf(newGallery)
+ if rf, ok := ret.Get(0).(func(context.Context, models.Gallery) *models.Gallery); ok {
+ r0 = rf(ctx, newGallery)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Gallery)
@@ -70,8 +72,8 @@ func (_m *GalleryReaderWriter) Create(newGallery models.Gallery) (*models.Galler
}
var r1 error
- if rf, ok := ret.Get(1).(func(models.Gallery) error); ok {
- r1 = rf(newGallery)
+ if rf, ok := ret.Get(1).(func(context.Context, models.Gallery) error); ok {
+ r1 = rf(ctx, newGallery)
} else {
r1 = ret.Error(1)
}
@@ -79,13 +81,13 @@ func (_m *GalleryReaderWriter) Create(newGallery models.Gallery) (*models.Galler
return r0, r1
}
-// Destroy provides a mock function with given fields: id
-func (_m *GalleryReaderWriter) Destroy(id int) error {
- ret := _m.Called(id)
+// Destroy provides a mock function with given fields: ctx, id
+func (_m *GalleryReaderWriter) Destroy(ctx context.Context, id int) error {
+ ret := _m.Called(ctx, id)
var r0 error
- if rf, ok := ret.Get(0).(func(int) error); ok {
- r0 = rf(id)
+ if rf, ok := ret.Get(0).(func(context.Context, int) error); ok {
+ r0 = rf(ctx, id)
} else {
r0 = ret.Error(0)
}
@@ -93,13 +95,13 @@ func (_m *GalleryReaderWriter) Destroy(id int) error {
return r0
}
-// Find provides a mock function with given fields: id
-func (_m *GalleryReaderWriter) Find(id int) (*models.Gallery, error) {
- ret := _m.Called(id)
+// Find provides a mock function with given fields: ctx, id
+func (_m *GalleryReaderWriter) Find(ctx context.Context, id int) (*models.Gallery, error) {
+ ret := _m.Called(ctx, id)
var r0 *models.Gallery
- if rf, ok := ret.Get(0).(func(int) *models.Gallery); ok {
- r0 = rf(id)
+ if rf, ok := ret.Get(0).(func(context.Context, int) *models.Gallery); ok {
+ r0 = rf(ctx, id)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Gallery)
@@ -107,8 +109,8 @@ func (_m *GalleryReaderWriter) Find(id int) (*models.Gallery, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(id)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, id)
} else {
r1 = ret.Error(1)
}
@@ -116,13 +118,13 @@ func (_m *GalleryReaderWriter) Find(id int) (*models.Gallery, error) {
return r0, r1
}
-// FindByChecksum provides a mock function with given fields: checksum
-func (_m *GalleryReaderWriter) FindByChecksum(checksum string) (*models.Gallery, error) {
- ret := _m.Called(checksum)
+// FindByChecksum provides a mock function with given fields: ctx, checksum
+func (_m *GalleryReaderWriter) FindByChecksum(ctx context.Context, checksum string) (*models.Gallery, error) {
+ ret := _m.Called(ctx, checksum)
var r0 *models.Gallery
- if rf, ok := ret.Get(0).(func(string) *models.Gallery); ok {
- r0 = rf(checksum)
+ if rf, ok := ret.Get(0).(func(context.Context, string) *models.Gallery); ok {
+ r0 = rf(ctx, checksum)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Gallery)
@@ -130,8 +132,8 @@ func (_m *GalleryReaderWriter) FindByChecksum(checksum string) (*models.Gallery,
}
var r1 error
- if rf, ok := ret.Get(1).(func(string) error); ok {
- r1 = rf(checksum)
+ if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
+ r1 = rf(ctx, checksum)
} else {
r1 = ret.Error(1)
}
@@ -139,13 +141,13 @@ func (_m *GalleryReaderWriter) FindByChecksum(checksum string) (*models.Gallery,
return r0, r1
}
-// FindByChecksums provides a mock function with given fields: checksums
-func (_m *GalleryReaderWriter) FindByChecksums(checksums []string) ([]*models.Gallery, error) {
- ret := _m.Called(checksums)
+// FindByChecksums provides a mock function with given fields: ctx, checksums
+func (_m *GalleryReaderWriter) FindByChecksums(ctx context.Context, checksums []string) ([]*models.Gallery, error) {
+ ret := _m.Called(ctx, checksums)
var r0 []*models.Gallery
- if rf, ok := ret.Get(0).(func([]string) []*models.Gallery); ok {
- r0 = rf(checksums)
+ if rf, ok := ret.Get(0).(func(context.Context, []string) []*models.Gallery); ok {
+ r0 = rf(ctx, checksums)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Gallery)
@@ -153,8 +155,8 @@ func (_m *GalleryReaderWriter) FindByChecksums(checksums []string) ([]*models.Ga
}
var r1 error
- if rf, ok := ret.Get(1).(func([]string) error); ok {
- r1 = rf(checksums)
+ if rf, ok := ret.Get(1).(func(context.Context, []string) error); ok {
+ r1 = rf(ctx, checksums)
} else {
r1 = ret.Error(1)
}
@@ -162,13 +164,13 @@ func (_m *GalleryReaderWriter) FindByChecksums(checksums []string) ([]*models.Ga
return r0, r1
}
-// FindByImageID provides a mock function with given fields: imageID
-func (_m *GalleryReaderWriter) FindByImageID(imageID int) ([]*models.Gallery, error) {
- ret := _m.Called(imageID)
+// FindByImageID provides a mock function with given fields: ctx, imageID
+func (_m *GalleryReaderWriter) FindByImageID(ctx context.Context, imageID int) ([]*models.Gallery, error) {
+ ret := _m.Called(ctx, imageID)
var r0 []*models.Gallery
- if rf, ok := ret.Get(0).(func(int) []*models.Gallery); ok {
- r0 = rf(imageID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Gallery); ok {
+ r0 = rf(ctx, imageID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Gallery)
@@ -176,8 +178,8 @@ func (_m *GalleryReaderWriter) FindByImageID(imageID int) ([]*models.Gallery, er
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(imageID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, imageID)
} else {
r1 = ret.Error(1)
}
@@ -185,13 +187,13 @@ func (_m *GalleryReaderWriter) FindByImageID(imageID int) ([]*models.Gallery, er
return r0, r1
}
-// FindByPath provides a mock function with given fields: path
-func (_m *GalleryReaderWriter) FindByPath(path string) (*models.Gallery, error) {
- ret := _m.Called(path)
+// FindByPath provides a mock function with given fields: ctx, path
+func (_m *GalleryReaderWriter) FindByPath(ctx context.Context, path string) (*models.Gallery, error) {
+ ret := _m.Called(ctx, path)
var r0 *models.Gallery
- if rf, ok := ret.Get(0).(func(string) *models.Gallery); ok {
- r0 = rf(path)
+ if rf, ok := ret.Get(0).(func(context.Context, string) *models.Gallery); ok {
+ r0 = rf(ctx, path)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Gallery)
@@ -199,8 +201,8 @@ func (_m *GalleryReaderWriter) FindByPath(path string) (*models.Gallery, error)
}
var r1 error
- if rf, ok := ret.Get(1).(func(string) error); ok {
- r1 = rf(path)
+ if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
+ r1 = rf(ctx, path)
} else {
r1 = ret.Error(1)
}
@@ -208,13 +210,13 @@ func (_m *GalleryReaderWriter) FindByPath(path string) (*models.Gallery, error)
return r0, r1
}
-// FindBySceneID provides a mock function with given fields: sceneID
-func (_m *GalleryReaderWriter) FindBySceneID(sceneID int) ([]*models.Gallery, error) {
- ret := _m.Called(sceneID)
+// FindBySceneID provides a mock function with given fields: ctx, sceneID
+func (_m *GalleryReaderWriter) FindBySceneID(ctx context.Context, sceneID int) ([]*models.Gallery, error) {
+ ret := _m.Called(ctx, sceneID)
var r0 []*models.Gallery
- if rf, ok := ret.Get(0).(func(int) []*models.Gallery); ok {
- r0 = rf(sceneID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Gallery); ok {
+ r0 = rf(ctx, sceneID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Gallery)
@@ -222,8 +224,8 @@ func (_m *GalleryReaderWriter) FindBySceneID(sceneID int) ([]*models.Gallery, er
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(sceneID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, sceneID)
} else {
r1 = ret.Error(1)
}
@@ -231,13 +233,13 @@ func (_m *GalleryReaderWriter) FindBySceneID(sceneID int) ([]*models.Gallery, er
return r0, r1
}
-// FindMany provides a mock function with given fields: ids
-func (_m *GalleryReaderWriter) FindMany(ids []int) ([]*models.Gallery, error) {
- ret := _m.Called(ids)
+// FindMany provides a mock function with given fields: ctx, ids
+func (_m *GalleryReaderWriter) FindMany(ctx context.Context, ids []int) ([]*models.Gallery, error) {
+ ret := _m.Called(ctx, ids)
var r0 []*models.Gallery
- if rf, ok := ret.Get(0).(func([]int) []*models.Gallery); ok {
- r0 = rf(ids)
+ if rf, ok := ret.Get(0).(func(context.Context, []int) []*models.Gallery); ok {
+ r0 = rf(ctx, ids)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Gallery)
@@ -245,8 +247,8 @@ func (_m *GalleryReaderWriter) FindMany(ids []int) ([]*models.Gallery, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func([]int) error); ok {
- r1 = rf(ids)
+ if rf, ok := ret.Get(1).(func(context.Context, []int) error); ok {
+ r1 = rf(ctx, ids)
} else {
r1 = ret.Error(1)
}
@@ -254,13 +256,13 @@ func (_m *GalleryReaderWriter) FindMany(ids []int) ([]*models.Gallery, error) {
return r0, r1
}
-// GetImageIDs provides a mock function with given fields: galleryID
-func (_m *GalleryReaderWriter) GetImageIDs(galleryID int) ([]int, error) {
- ret := _m.Called(galleryID)
+// GetImageIDs provides a mock function with given fields: ctx, galleryID
+func (_m *GalleryReaderWriter) GetImageIDs(ctx context.Context, galleryID int) ([]int, error) {
+ ret := _m.Called(ctx, galleryID)
var r0 []int
- if rf, ok := ret.Get(0).(func(int) []int); ok {
- r0 = rf(galleryID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []int); ok {
+ r0 = rf(ctx, galleryID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]int)
@@ -268,8 +270,8 @@ func (_m *GalleryReaderWriter) GetImageIDs(galleryID int) ([]int, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(galleryID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, galleryID)
} else {
r1 = ret.Error(1)
}
@@ -277,13 +279,13 @@ func (_m *GalleryReaderWriter) GetImageIDs(galleryID int) ([]int, error) {
return r0, r1
}
-// GetPerformerIDs provides a mock function with given fields: galleryID
-func (_m *GalleryReaderWriter) GetPerformerIDs(galleryID int) ([]int, error) {
- ret := _m.Called(galleryID)
+// GetPerformerIDs provides a mock function with given fields: ctx, galleryID
+func (_m *GalleryReaderWriter) GetPerformerIDs(ctx context.Context, galleryID int) ([]int, error) {
+ ret := _m.Called(ctx, galleryID)
var r0 []int
- if rf, ok := ret.Get(0).(func(int) []int); ok {
- r0 = rf(galleryID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []int); ok {
+ r0 = rf(ctx, galleryID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]int)
@@ -291,8 +293,8 @@ func (_m *GalleryReaderWriter) GetPerformerIDs(galleryID int) ([]int, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(galleryID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, galleryID)
} else {
r1 = ret.Error(1)
}
@@ -300,13 +302,13 @@ func (_m *GalleryReaderWriter) GetPerformerIDs(galleryID int) ([]int, error) {
return r0, r1
}
-// GetSceneIDs provides a mock function with given fields: galleryID
-func (_m *GalleryReaderWriter) GetSceneIDs(galleryID int) ([]int, error) {
- ret := _m.Called(galleryID)
+// GetSceneIDs provides a mock function with given fields: ctx, galleryID
+func (_m *GalleryReaderWriter) GetSceneIDs(ctx context.Context, galleryID int) ([]int, error) {
+ ret := _m.Called(ctx, galleryID)
var r0 []int
- if rf, ok := ret.Get(0).(func(int) []int); ok {
- r0 = rf(galleryID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []int); ok {
+ r0 = rf(ctx, galleryID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]int)
@@ -314,8 +316,8 @@ func (_m *GalleryReaderWriter) GetSceneIDs(galleryID int) ([]int, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(galleryID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, galleryID)
} else {
r1 = ret.Error(1)
}
@@ -323,13 +325,13 @@ func (_m *GalleryReaderWriter) GetSceneIDs(galleryID int) ([]int, error) {
return r0, r1
}
-// GetTagIDs provides a mock function with given fields: galleryID
-func (_m *GalleryReaderWriter) GetTagIDs(galleryID int) ([]int, error) {
- ret := _m.Called(galleryID)
+// GetTagIDs provides a mock function with given fields: ctx, galleryID
+func (_m *GalleryReaderWriter) GetTagIDs(ctx context.Context, galleryID int) ([]int, error) {
+ ret := _m.Called(ctx, galleryID)
var r0 []int
- if rf, ok := ret.Get(0).(func(int) []int); ok {
- r0 = rf(galleryID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []int); ok {
+ r0 = rf(ctx, galleryID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]int)
@@ -337,8 +339,8 @@ func (_m *GalleryReaderWriter) GetTagIDs(galleryID int) ([]int, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(galleryID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, galleryID)
} else {
r1 = ret.Error(1)
}
@@ -346,13 +348,13 @@ func (_m *GalleryReaderWriter) GetTagIDs(galleryID int) ([]int, error) {
return r0, r1
}
-// Query provides a mock function with given fields: galleryFilter, findFilter
-func (_m *GalleryReaderWriter) Query(galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) ([]*models.Gallery, int, error) {
- ret := _m.Called(galleryFilter, findFilter)
+// Query provides a mock function with given fields: ctx, galleryFilter, findFilter
+func (_m *GalleryReaderWriter) Query(ctx context.Context, galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) ([]*models.Gallery, int, error) {
+ ret := _m.Called(ctx, galleryFilter, findFilter)
var r0 []*models.Gallery
- if rf, ok := ret.Get(0).(func(*models.GalleryFilterType, *models.FindFilterType) []*models.Gallery); ok {
- r0 = rf(galleryFilter, findFilter)
+ if rf, ok := ret.Get(0).(func(context.Context, *models.GalleryFilterType, *models.FindFilterType) []*models.Gallery); ok {
+ r0 = rf(ctx, galleryFilter, findFilter)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Gallery)
@@ -360,15 +362,15 @@ func (_m *GalleryReaderWriter) Query(galleryFilter *models.GalleryFilterType, fi
}
var r1 int
- if rf, ok := ret.Get(1).(func(*models.GalleryFilterType, *models.FindFilterType) int); ok {
- r1 = rf(galleryFilter, findFilter)
+ if rf, ok := ret.Get(1).(func(context.Context, *models.GalleryFilterType, *models.FindFilterType) int); ok {
+ r1 = rf(ctx, galleryFilter, findFilter)
} else {
r1 = ret.Get(1).(int)
}
var r2 error
- if rf, ok := ret.Get(2).(func(*models.GalleryFilterType, *models.FindFilterType) error); ok {
- r2 = rf(galleryFilter, findFilter)
+ if rf, ok := ret.Get(2).(func(context.Context, *models.GalleryFilterType, *models.FindFilterType) error); ok {
+ r2 = rf(ctx, galleryFilter, findFilter)
} else {
r2 = ret.Error(2)
}
@@ -376,20 +378,20 @@ func (_m *GalleryReaderWriter) Query(galleryFilter *models.GalleryFilterType, fi
return r0, r1, r2
}
-// QueryCount provides a mock function with given fields: galleryFilter, findFilter
-func (_m *GalleryReaderWriter) QueryCount(galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) (int, error) {
- ret := _m.Called(galleryFilter, findFilter)
+// QueryCount provides a mock function with given fields: ctx, galleryFilter, findFilter
+func (_m *GalleryReaderWriter) QueryCount(ctx context.Context, galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) (int, error) {
+ ret := _m.Called(ctx, galleryFilter, findFilter)
var r0 int
- if rf, ok := ret.Get(0).(func(*models.GalleryFilterType, *models.FindFilterType) int); ok {
- r0 = rf(galleryFilter, findFilter)
+ if rf, ok := ret.Get(0).(func(context.Context, *models.GalleryFilterType, *models.FindFilterType) int); ok {
+ r0 = rf(ctx, galleryFilter, findFilter)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
- if rf, ok := ret.Get(1).(func(*models.GalleryFilterType, *models.FindFilterType) error); ok {
- r1 = rf(galleryFilter, findFilter)
+ if rf, ok := ret.Get(1).(func(context.Context, *models.GalleryFilterType, *models.FindFilterType) error); ok {
+ r1 = rf(ctx, galleryFilter, findFilter)
} else {
r1 = ret.Error(1)
}
@@ -397,13 +399,13 @@ func (_m *GalleryReaderWriter) QueryCount(galleryFilter *models.GalleryFilterTyp
return r0, r1
}
-// Update provides a mock function with given fields: updatedGallery
-func (_m *GalleryReaderWriter) Update(updatedGallery models.Gallery) (*models.Gallery, error) {
- ret := _m.Called(updatedGallery)
+// Update provides a mock function with given fields: ctx, updatedGallery
+func (_m *GalleryReaderWriter) Update(ctx context.Context, updatedGallery models.Gallery) (*models.Gallery, error) {
+ ret := _m.Called(ctx, updatedGallery)
var r0 *models.Gallery
- if rf, ok := ret.Get(0).(func(models.Gallery) *models.Gallery); ok {
- r0 = rf(updatedGallery)
+ if rf, ok := ret.Get(0).(func(context.Context, models.Gallery) *models.Gallery); ok {
+ r0 = rf(ctx, updatedGallery)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Gallery)
@@ -411,8 +413,8 @@ func (_m *GalleryReaderWriter) Update(updatedGallery models.Gallery) (*models.Ga
}
var r1 error
- if rf, ok := ret.Get(1).(func(models.Gallery) error); ok {
- r1 = rf(updatedGallery)
+ if rf, ok := ret.Get(1).(func(context.Context, models.Gallery) error); ok {
+ r1 = rf(ctx, updatedGallery)
} else {
r1 = ret.Error(1)
}
@@ -420,13 +422,13 @@ func (_m *GalleryReaderWriter) Update(updatedGallery models.Gallery) (*models.Ga
return r0, r1
}
-// UpdateFileModTime provides a mock function with given fields: id, modTime
-func (_m *GalleryReaderWriter) UpdateFileModTime(id int, modTime models.NullSQLiteTimestamp) error {
- ret := _m.Called(id, modTime)
+// UpdateFileModTime provides a mock function with given fields: ctx, id, modTime
+func (_m *GalleryReaderWriter) UpdateFileModTime(ctx context.Context, id int, modTime models.NullSQLiteTimestamp) error {
+ ret := _m.Called(ctx, id, modTime)
var r0 error
- if rf, ok := ret.Get(0).(func(int, models.NullSQLiteTimestamp) error); ok {
- r0 = rf(id, modTime)
+ if rf, ok := ret.Get(0).(func(context.Context, int, models.NullSQLiteTimestamp) error); ok {
+ r0 = rf(ctx, id, modTime)
} else {
r0 = ret.Error(0)
}
@@ -434,13 +436,13 @@ func (_m *GalleryReaderWriter) UpdateFileModTime(id int, modTime models.NullSQLi
return r0
}
-// UpdateImages provides a mock function with given fields: galleryID, imageIDs
-func (_m *GalleryReaderWriter) UpdateImages(galleryID int, imageIDs []int) error {
- ret := _m.Called(galleryID, imageIDs)
+// UpdateImages provides a mock function with given fields: ctx, galleryID, imageIDs
+func (_m *GalleryReaderWriter) UpdateImages(ctx context.Context, galleryID int, imageIDs []int) error {
+ ret := _m.Called(ctx, galleryID, imageIDs)
var r0 error
- if rf, ok := ret.Get(0).(func(int, []int) error); ok {
- r0 = rf(galleryID, imageIDs)
+ if rf, ok := ret.Get(0).(func(context.Context, int, []int) error); ok {
+ r0 = rf(ctx, galleryID, imageIDs)
} else {
r0 = ret.Error(0)
}
@@ -448,13 +450,13 @@ func (_m *GalleryReaderWriter) UpdateImages(galleryID int, imageIDs []int) error
return r0
}
-// UpdatePartial provides a mock function with given fields: updatedGallery
-func (_m *GalleryReaderWriter) UpdatePartial(updatedGallery models.GalleryPartial) (*models.Gallery, error) {
- ret := _m.Called(updatedGallery)
+// UpdatePartial provides a mock function with given fields: ctx, updatedGallery
+func (_m *GalleryReaderWriter) UpdatePartial(ctx context.Context, updatedGallery models.GalleryPartial) (*models.Gallery, error) {
+ ret := _m.Called(ctx, updatedGallery)
var r0 *models.Gallery
- if rf, ok := ret.Get(0).(func(models.GalleryPartial) *models.Gallery); ok {
- r0 = rf(updatedGallery)
+ if rf, ok := ret.Get(0).(func(context.Context, models.GalleryPartial) *models.Gallery); ok {
+ r0 = rf(ctx, updatedGallery)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Gallery)
@@ -462,8 +464,8 @@ func (_m *GalleryReaderWriter) UpdatePartial(updatedGallery models.GalleryPartia
}
var r1 error
- if rf, ok := ret.Get(1).(func(models.GalleryPartial) error); ok {
- r1 = rf(updatedGallery)
+ if rf, ok := ret.Get(1).(func(context.Context, models.GalleryPartial) error); ok {
+ r1 = rf(ctx, updatedGallery)
} else {
r1 = ret.Error(1)
}
@@ -471,13 +473,13 @@ func (_m *GalleryReaderWriter) UpdatePartial(updatedGallery models.GalleryPartia
return r0, r1
}
-// UpdatePerformers provides a mock function with given fields: galleryID, performerIDs
-func (_m *GalleryReaderWriter) UpdatePerformers(galleryID int, performerIDs []int) error {
- ret := _m.Called(galleryID, performerIDs)
+// UpdatePerformers provides a mock function with given fields: ctx, galleryID, performerIDs
+func (_m *GalleryReaderWriter) UpdatePerformers(ctx context.Context, galleryID int, performerIDs []int) error {
+ ret := _m.Called(ctx, galleryID, performerIDs)
var r0 error
- if rf, ok := ret.Get(0).(func(int, []int) error); ok {
- r0 = rf(galleryID, performerIDs)
+ if rf, ok := ret.Get(0).(func(context.Context, int, []int) error); ok {
+ r0 = rf(ctx, galleryID, performerIDs)
} else {
r0 = ret.Error(0)
}
@@ -485,13 +487,13 @@ func (_m *GalleryReaderWriter) UpdatePerformers(galleryID int, performerIDs []in
return r0
}
-// UpdateScenes provides a mock function with given fields: galleryID, sceneIDs
-func (_m *GalleryReaderWriter) UpdateScenes(galleryID int, sceneIDs []int) error {
- ret := _m.Called(galleryID, sceneIDs)
+// UpdateScenes provides a mock function with given fields: ctx, galleryID, sceneIDs
+func (_m *GalleryReaderWriter) UpdateScenes(ctx context.Context, galleryID int, sceneIDs []int) error {
+ ret := _m.Called(ctx, galleryID, sceneIDs)
var r0 error
- if rf, ok := ret.Get(0).(func(int, []int) error); ok {
- r0 = rf(galleryID, sceneIDs)
+ if rf, ok := ret.Get(0).(func(context.Context, int, []int) error); ok {
+ r0 = rf(ctx, galleryID, sceneIDs)
} else {
r0 = ret.Error(0)
}
@@ -499,13 +501,13 @@ func (_m *GalleryReaderWriter) UpdateScenes(galleryID int, sceneIDs []int) error
return r0
}
-// UpdateTags provides a mock function with given fields: galleryID, tagIDs
-func (_m *GalleryReaderWriter) UpdateTags(galleryID int, tagIDs []int) error {
- ret := _m.Called(galleryID, tagIDs)
+// UpdateTags provides a mock function with given fields: ctx, galleryID, tagIDs
+func (_m *GalleryReaderWriter) UpdateTags(ctx context.Context, galleryID int, tagIDs []int) error {
+ ret := _m.Called(ctx, galleryID, tagIDs)
var r0 error
- if rf, ok := ret.Get(0).(func(int, []int) error); ok {
- r0 = rf(galleryID, tagIDs)
+ if rf, ok := ret.Get(0).(func(context.Context, int, []int) error); ok {
+ r0 = rf(ctx, galleryID, tagIDs)
} else {
r0 = ret.Error(0)
}
diff --git a/pkg/models/mocks/ImageReaderWriter.go b/pkg/models/mocks/ImageReaderWriter.go
index 5a13ad986..9660849f1 100644
--- a/pkg/models/mocks/ImageReaderWriter.go
+++ b/pkg/models/mocks/ImageReaderWriter.go
@@ -3,6 +3,8 @@
package mocks
import (
+ context "context"
+
models "github.com/stashapp/stash/pkg/models"
mock "github.com/stretchr/testify/mock"
)
@@ -12,13 +14,13 @@ type ImageReaderWriter struct {
mock.Mock
}
-// All provides a mock function with given fields:
-func (_m *ImageReaderWriter) All() ([]*models.Image, error) {
- ret := _m.Called()
+// All provides a mock function with given fields: ctx
+func (_m *ImageReaderWriter) All(ctx context.Context) ([]*models.Image, error) {
+ ret := _m.Called(ctx)
var r0 []*models.Image
- if rf, ok := ret.Get(0).(func() []*models.Image); ok {
- r0 = rf()
+ if rf, ok := ret.Get(0).(func(context.Context) []*models.Image); ok {
+ r0 = rf(ctx)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Image)
@@ -26,8 +28,8 @@ func (_m *ImageReaderWriter) All() ([]*models.Image, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func() error); ok {
- r1 = rf()
+ if rf, ok := ret.Get(1).(func(context.Context) error); ok {
+ r1 = rf(ctx)
} else {
r1 = ret.Error(1)
}
@@ -35,20 +37,20 @@ func (_m *ImageReaderWriter) All() ([]*models.Image, error) {
return r0, r1
}
-// Count provides a mock function with given fields:
-func (_m *ImageReaderWriter) Count() (int, error) {
- ret := _m.Called()
+// Count provides a mock function with given fields: ctx
+func (_m *ImageReaderWriter) Count(ctx context.Context) (int, error) {
+ ret := _m.Called(ctx)
var r0 int
- if rf, ok := ret.Get(0).(func() int); ok {
- r0 = rf()
+ if rf, ok := ret.Get(0).(func(context.Context) int); ok {
+ r0 = rf(ctx)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
- if rf, ok := ret.Get(1).(func() error); ok {
- r1 = rf()
+ if rf, ok := ret.Get(1).(func(context.Context) error); ok {
+ r1 = rf(ctx)
} else {
r1 = ret.Error(1)
}
@@ -56,20 +58,20 @@ func (_m *ImageReaderWriter) Count() (int, error) {
return r0, r1
}
-// CountByGalleryID provides a mock function with given fields: galleryID
-func (_m *ImageReaderWriter) CountByGalleryID(galleryID int) (int, error) {
- ret := _m.Called(galleryID)
+// CountByGalleryID provides a mock function with given fields: ctx, galleryID
+func (_m *ImageReaderWriter) CountByGalleryID(ctx context.Context, galleryID int) (int, error) {
+ ret := _m.Called(ctx, galleryID)
var r0 int
- if rf, ok := ret.Get(0).(func(int) int); ok {
- r0 = rf(galleryID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) int); ok {
+ r0 = rf(ctx, galleryID)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(galleryID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, galleryID)
} else {
r1 = ret.Error(1)
}
@@ -77,13 +79,13 @@ func (_m *ImageReaderWriter) CountByGalleryID(galleryID int) (int, error) {
return r0, r1
}
-// Create provides a mock function with given fields: newImage
-func (_m *ImageReaderWriter) Create(newImage models.Image) (*models.Image, error) {
- ret := _m.Called(newImage)
+// Create provides a mock function with given fields: ctx, newImage
+func (_m *ImageReaderWriter) Create(ctx context.Context, newImage models.Image) (*models.Image, error) {
+ ret := _m.Called(ctx, newImage)
var r0 *models.Image
- if rf, ok := ret.Get(0).(func(models.Image) *models.Image); ok {
- r0 = rf(newImage)
+ if rf, ok := ret.Get(0).(func(context.Context, models.Image) *models.Image); ok {
+ r0 = rf(ctx, newImage)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Image)
@@ -91,8 +93,8 @@ func (_m *ImageReaderWriter) Create(newImage models.Image) (*models.Image, error
}
var r1 error
- if rf, ok := ret.Get(1).(func(models.Image) error); ok {
- r1 = rf(newImage)
+ if rf, ok := ret.Get(1).(func(context.Context, models.Image) error); ok {
+ r1 = rf(ctx, newImage)
} else {
r1 = ret.Error(1)
}
@@ -100,20 +102,20 @@ func (_m *ImageReaderWriter) Create(newImage models.Image) (*models.Image, error
return r0, r1
}
-// DecrementOCounter provides a mock function with given fields: id
-func (_m *ImageReaderWriter) DecrementOCounter(id int) (int, error) {
- ret := _m.Called(id)
+// DecrementOCounter provides a mock function with given fields: ctx, id
+func (_m *ImageReaderWriter) DecrementOCounter(ctx context.Context, id int) (int, error) {
+ ret := _m.Called(ctx, id)
var r0 int
- if rf, ok := ret.Get(0).(func(int) int); ok {
- r0 = rf(id)
+ if rf, ok := ret.Get(0).(func(context.Context, int) int); ok {
+ r0 = rf(ctx, id)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(id)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, id)
} else {
r1 = ret.Error(1)
}
@@ -121,13 +123,13 @@ func (_m *ImageReaderWriter) DecrementOCounter(id int) (int, error) {
return r0, r1
}
-// Destroy provides a mock function with given fields: id
-func (_m *ImageReaderWriter) Destroy(id int) error {
- ret := _m.Called(id)
+// Destroy provides a mock function with given fields: ctx, id
+func (_m *ImageReaderWriter) Destroy(ctx context.Context, id int) error {
+ ret := _m.Called(ctx, id)
var r0 error
- if rf, ok := ret.Get(0).(func(int) error); ok {
- r0 = rf(id)
+ if rf, ok := ret.Get(0).(func(context.Context, int) error); ok {
+ r0 = rf(ctx, id)
} else {
r0 = ret.Error(0)
}
@@ -135,13 +137,13 @@ func (_m *ImageReaderWriter) Destroy(id int) error {
return r0
}
-// Find provides a mock function with given fields: id
-func (_m *ImageReaderWriter) Find(id int) (*models.Image, error) {
- ret := _m.Called(id)
+// Find provides a mock function with given fields: ctx, id
+func (_m *ImageReaderWriter) Find(ctx context.Context, id int) (*models.Image, error) {
+ ret := _m.Called(ctx, id)
var r0 *models.Image
- if rf, ok := ret.Get(0).(func(int) *models.Image); ok {
- r0 = rf(id)
+ if rf, ok := ret.Get(0).(func(context.Context, int) *models.Image); ok {
+ r0 = rf(ctx, id)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Image)
@@ -149,8 +151,8 @@ func (_m *ImageReaderWriter) Find(id int) (*models.Image, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(id)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, id)
} else {
r1 = ret.Error(1)
}
@@ -158,13 +160,13 @@ func (_m *ImageReaderWriter) Find(id int) (*models.Image, error) {
return r0, r1
}
-// FindByChecksum provides a mock function with given fields: checksum
-func (_m *ImageReaderWriter) FindByChecksum(checksum string) (*models.Image, error) {
- ret := _m.Called(checksum)
+// FindByChecksum provides a mock function with given fields: ctx, checksum
+func (_m *ImageReaderWriter) FindByChecksum(ctx context.Context, checksum string) (*models.Image, error) {
+ ret := _m.Called(ctx, checksum)
var r0 *models.Image
- if rf, ok := ret.Get(0).(func(string) *models.Image); ok {
- r0 = rf(checksum)
+ if rf, ok := ret.Get(0).(func(context.Context, string) *models.Image); ok {
+ r0 = rf(ctx, checksum)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Image)
@@ -172,8 +174,8 @@ func (_m *ImageReaderWriter) FindByChecksum(checksum string) (*models.Image, err
}
var r1 error
- if rf, ok := ret.Get(1).(func(string) error); ok {
- r1 = rf(checksum)
+ if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
+ r1 = rf(ctx, checksum)
} else {
r1 = ret.Error(1)
}
@@ -181,13 +183,13 @@ func (_m *ImageReaderWriter) FindByChecksum(checksum string) (*models.Image, err
return r0, r1
}
-// FindByGalleryID provides a mock function with given fields: galleryID
-func (_m *ImageReaderWriter) FindByGalleryID(galleryID int) ([]*models.Image, error) {
- ret := _m.Called(galleryID)
+// FindByGalleryID provides a mock function with given fields: ctx, galleryID
+func (_m *ImageReaderWriter) FindByGalleryID(ctx context.Context, galleryID int) ([]*models.Image, error) {
+ ret := _m.Called(ctx, galleryID)
var r0 []*models.Image
- if rf, ok := ret.Get(0).(func(int) []*models.Image); ok {
- r0 = rf(galleryID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Image); ok {
+ r0 = rf(ctx, galleryID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Image)
@@ -195,8 +197,8 @@ func (_m *ImageReaderWriter) FindByGalleryID(galleryID int) ([]*models.Image, er
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(galleryID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, galleryID)
} else {
r1 = ret.Error(1)
}
@@ -204,13 +206,13 @@ func (_m *ImageReaderWriter) FindByGalleryID(galleryID int) ([]*models.Image, er
return r0, r1
}
-// FindByPath provides a mock function with given fields: path
-func (_m *ImageReaderWriter) FindByPath(path string) (*models.Image, error) {
- ret := _m.Called(path)
+// FindByPath provides a mock function with given fields: ctx, path
+func (_m *ImageReaderWriter) FindByPath(ctx context.Context, path string) (*models.Image, error) {
+ ret := _m.Called(ctx, path)
var r0 *models.Image
- if rf, ok := ret.Get(0).(func(string) *models.Image); ok {
- r0 = rf(path)
+ if rf, ok := ret.Get(0).(func(context.Context, string) *models.Image); ok {
+ r0 = rf(ctx, path)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Image)
@@ -218,8 +220,8 @@ func (_m *ImageReaderWriter) FindByPath(path string) (*models.Image, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(string) error); ok {
- r1 = rf(path)
+ if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
+ r1 = rf(ctx, path)
} else {
r1 = ret.Error(1)
}
@@ -227,13 +229,13 @@ func (_m *ImageReaderWriter) FindByPath(path string) (*models.Image, error) {
return r0, r1
}
-// FindMany provides a mock function with given fields: ids
-func (_m *ImageReaderWriter) FindMany(ids []int) ([]*models.Image, error) {
- ret := _m.Called(ids)
+// FindMany provides a mock function with given fields: ctx, ids
+func (_m *ImageReaderWriter) FindMany(ctx context.Context, ids []int) ([]*models.Image, error) {
+ ret := _m.Called(ctx, ids)
var r0 []*models.Image
- if rf, ok := ret.Get(0).(func([]int) []*models.Image); ok {
- r0 = rf(ids)
+ if rf, ok := ret.Get(0).(func(context.Context, []int) []*models.Image); ok {
+ r0 = rf(ctx, ids)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Image)
@@ -241,8 +243,8 @@ func (_m *ImageReaderWriter) FindMany(ids []int) ([]*models.Image, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func([]int) error); ok {
- r1 = rf(ids)
+ if rf, ok := ret.Get(1).(func(context.Context, []int) error); ok {
+ r1 = rf(ctx, ids)
} else {
r1 = ret.Error(1)
}
@@ -250,13 +252,13 @@ func (_m *ImageReaderWriter) FindMany(ids []int) ([]*models.Image, error) {
return r0, r1
}
-// GetGalleryIDs provides a mock function with given fields: imageID
-func (_m *ImageReaderWriter) GetGalleryIDs(imageID int) ([]int, error) {
- ret := _m.Called(imageID)
+// GetGalleryIDs provides a mock function with given fields: ctx, imageID
+func (_m *ImageReaderWriter) GetGalleryIDs(ctx context.Context, imageID int) ([]int, error) {
+ ret := _m.Called(ctx, imageID)
var r0 []int
- if rf, ok := ret.Get(0).(func(int) []int); ok {
- r0 = rf(imageID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []int); ok {
+ r0 = rf(ctx, imageID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]int)
@@ -264,8 +266,8 @@ func (_m *ImageReaderWriter) GetGalleryIDs(imageID int) ([]int, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(imageID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, imageID)
} else {
r1 = ret.Error(1)
}
@@ -273,13 +275,13 @@ func (_m *ImageReaderWriter) GetGalleryIDs(imageID int) ([]int, error) {
return r0, r1
}
-// GetPerformerIDs provides a mock function with given fields: imageID
-func (_m *ImageReaderWriter) GetPerformerIDs(imageID int) ([]int, error) {
- ret := _m.Called(imageID)
+// GetPerformerIDs provides a mock function with given fields: ctx, imageID
+func (_m *ImageReaderWriter) GetPerformerIDs(ctx context.Context, imageID int) ([]int, error) {
+ ret := _m.Called(ctx, imageID)
var r0 []int
- if rf, ok := ret.Get(0).(func(int) []int); ok {
- r0 = rf(imageID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []int); ok {
+ r0 = rf(ctx, imageID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]int)
@@ -287,8 +289,8 @@ func (_m *ImageReaderWriter) GetPerformerIDs(imageID int) ([]int, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(imageID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, imageID)
} else {
r1 = ret.Error(1)
}
@@ -296,13 +298,13 @@ func (_m *ImageReaderWriter) GetPerformerIDs(imageID int) ([]int, error) {
return r0, r1
}
-// GetTagIDs provides a mock function with given fields: imageID
-func (_m *ImageReaderWriter) GetTagIDs(imageID int) ([]int, error) {
- ret := _m.Called(imageID)
+// GetTagIDs provides a mock function with given fields: ctx, imageID
+func (_m *ImageReaderWriter) GetTagIDs(ctx context.Context, imageID int) ([]int, error) {
+ ret := _m.Called(ctx, imageID)
var r0 []int
- if rf, ok := ret.Get(0).(func(int) []int); ok {
- r0 = rf(imageID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []int); ok {
+ r0 = rf(ctx, imageID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]int)
@@ -310,8 +312,8 @@ func (_m *ImageReaderWriter) GetTagIDs(imageID int) ([]int, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(imageID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, imageID)
} else {
r1 = ret.Error(1)
}
@@ -319,20 +321,20 @@ func (_m *ImageReaderWriter) GetTagIDs(imageID int) ([]int, error) {
return r0, r1
}
-// IncrementOCounter provides a mock function with given fields: id
-func (_m *ImageReaderWriter) IncrementOCounter(id int) (int, error) {
- ret := _m.Called(id)
+// IncrementOCounter provides a mock function with given fields: ctx, id
+func (_m *ImageReaderWriter) IncrementOCounter(ctx context.Context, id int) (int, error) {
+ ret := _m.Called(ctx, id)
var r0 int
- if rf, ok := ret.Get(0).(func(int) int); ok {
- r0 = rf(id)
+ if rf, ok := ret.Get(0).(func(context.Context, int) int); ok {
+ r0 = rf(ctx, id)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(id)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, id)
} else {
r1 = ret.Error(1)
}
@@ -340,13 +342,13 @@ func (_m *ImageReaderWriter) IncrementOCounter(id int) (int, error) {
return r0, r1
}
-// Query provides a mock function with given fields: options
-func (_m *ImageReaderWriter) Query(options models.ImageQueryOptions) (*models.ImageQueryResult, error) {
- ret := _m.Called(options)
+// Query provides a mock function with given fields: ctx, options
+func (_m *ImageReaderWriter) Query(ctx context.Context, options models.ImageQueryOptions) (*models.ImageQueryResult, error) {
+ ret := _m.Called(ctx, options)
var r0 *models.ImageQueryResult
- if rf, ok := ret.Get(0).(func(models.ImageQueryOptions) *models.ImageQueryResult); ok {
- r0 = rf(options)
+ if rf, ok := ret.Get(0).(func(context.Context, models.ImageQueryOptions) *models.ImageQueryResult); ok {
+ r0 = rf(ctx, options)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.ImageQueryResult)
@@ -354,8 +356,8 @@ func (_m *ImageReaderWriter) Query(options models.ImageQueryOptions) (*models.Im
}
var r1 error
- if rf, ok := ret.Get(1).(func(models.ImageQueryOptions) error); ok {
- r1 = rf(options)
+ if rf, ok := ret.Get(1).(func(context.Context, models.ImageQueryOptions) error); ok {
+ r1 = rf(ctx, options)
} else {
r1 = ret.Error(1)
}
@@ -363,20 +365,20 @@ func (_m *ImageReaderWriter) Query(options models.ImageQueryOptions) (*models.Im
return r0, r1
}
-// QueryCount provides a mock function with given fields: imageFilter, findFilter
-func (_m *ImageReaderWriter) QueryCount(imageFilter *models.ImageFilterType, findFilter *models.FindFilterType) (int, error) {
- ret := _m.Called(imageFilter, findFilter)
+// QueryCount provides a mock function with given fields: ctx, imageFilter, findFilter
+func (_m *ImageReaderWriter) QueryCount(ctx context.Context, imageFilter *models.ImageFilterType, findFilter *models.FindFilterType) (int, error) {
+ ret := _m.Called(ctx, imageFilter, findFilter)
var r0 int
- if rf, ok := ret.Get(0).(func(*models.ImageFilterType, *models.FindFilterType) int); ok {
- r0 = rf(imageFilter, findFilter)
+ if rf, ok := ret.Get(0).(func(context.Context, *models.ImageFilterType, *models.FindFilterType) int); ok {
+ r0 = rf(ctx, imageFilter, findFilter)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
- if rf, ok := ret.Get(1).(func(*models.ImageFilterType, *models.FindFilterType) error); ok {
- r1 = rf(imageFilter, findFilter)
+ if rf, ok := ret.Get(1).(func(context.Context, *models.ImageFilterType, *models.FindFilterType) error); ok {
+ r1 = rf(ctx, imageFilter, findFilter)
} else {
r1 = ret.Error(1)
}
@@ -384,20 +386,20 @@ func (_m *ImageReaderWriter) QueryCount(imageFilter *models.ImageFilterType, fin
return r0, r1
}
-// ResetOCounter provides a mock function with given fields: id
-func (_m *ImageReaderWriter) ResetOCounter(id int) (int, error) {
- ret := _m.Called(id)
+// ResetOCounter provides a mock function with given fields: ctx, id
+func (_m *ImageReaderWriter) ResetOCounter(ctx context.Context, id int) (int, error) {
+ ret := _m.Called(ctx, id)
var r0 int
- if rf, ok := ret.Get(0).(func(int) int); ok {
- r0 = rf(id)
+ if rf, ok := ret.Get(0).(func(context.Context, int) int); ok {
+ r0 = rf(ctx, id)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(id)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, id)
} else {
r1 = ret.Error(1)
}
@@ -405,20 +407,20 @@ func (_m *ImageReaderWriter) ResetOCounter(id int) (int, error) {
return r0, r1
}
-// Size provides a mock function with given fields:
-func (_m *ImageReaderWriter) Size() (float64, error) {
- ret := _m.Called()
+// Size provides a mock function with given fields: ctx
+func (_m *ImageReaderWriter) Size(ctx context.Context) (float64, error) {
+ ret := _m.Called(ctx)
var r0 float64
- if rf, ok := ret.Get(0).(func() float64); ok {
- r0 = rf()
+ if rf, ok := ret.Get(0).(func(context.Context) float64); ok {
+ r0 = rf(ctx)
} else {
r0 = ret.Get(0).(float64)
}
var r1 error
- if rf, ok := ret.Get(1).(func() error); ok {
- r1 = rf()
+ if rf, ok := ret.Get(1).(func(context.Context) error); ok {
+ r1 = rf(ctx)
} else {
r1 = ret.Error(1)
}
@@ -426,13 +428,13 @@ func (_m *ImageReaderWriter) Size() (float64, error) {
return r0, r1
}
-// Update provides a mock function with given fields: updatedImage
-func (_m *ImageReaderWriter) Update(updatedImage models.ImagePartial) (*models.Image, error) {
- ret := _m.Called(updatedImage)
+// Update provides a mock function with given fields: ctx, updatedImage
+func (_m *ImageReaderWriter) Update(ctx context.Context, updatedImage models.ImagePartial) (*models.Image, error) {
+ ret := _m.Called(ctx, updatedImage)
var r0 *models.Image
- if rf, ok := ret.Get(0).(func(models.ImagePartial) *models.Image); ok {
- r0 = rf(updatedImage)
+ if rf, ok := ret.Get(0).(func(context.Context, models.ImagePartial) *models.Image); ok {
+ r0 = rf(ctx, updatedImage)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Image)
@@ -440,8 +442,8 @@ func (_m *ImageReaderWriter) Update(updatedImage models.ImagePartial) (*models.I
}
var r1 error
- if rf, ok := ret.Get(1).(func(models.ImagePartial) error); ok {
- r1 = rf(updatedImage)
+ if rf, ok := ret.Get(1).(func(context.Context, models.ImagePartial) error); ok {
+ r1 = rf(ctx, updatedImage)
} else {
r1 = ret.Error(1)
}
@@ -449,13 +451,13 @@ func (_m *ImageReaderWriter) Update(updatedImage models.ImagePartial) (*models.I
return r0, r1
}
-// UpdateFull provides a mock function with given fields: updatedImage
-func (_m *ImageReaderWriter) UpdateFull(updatedImage models.Image) (*models.Image, error) {
- ret := _m.Called(updatedImage)
+// UpdateFull provides a mock function with given fields: ctx, updatedImage
+func (_m *ImageReaderWriter) UpdateFull(ctx context.Context, updatedImage models.Image) (*models.Image, error) {
+ ret := _m.Called(ctx, updatedImage)
var r0 *models.Image
- if rf, ok := ret.Get(0).(func(models.Image) *models.Image); ok {
- r0 = rf(updatedImage)
+ if rf, ok := ret.Get(0).(func(context.Context, models.Image) *models.Image); ok {
+ r0 = rf(ctx, updatedImage)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Image)
@@ -463,8 +465,8 @@ func (_m *ImageReaderWriter) UpdateFull(updatedImage models.Image) (*models.Imag
}
var r1 error
- if rf, ok := ret.Get(1).(func(models.Image) error); ok {
- r1 = rf(updatedImage)
+ if rf, ok := ret.Get(1).(func(context.Context, models.Image) error); ok {
+ r1 = rf(ctx, updatedImage)
} else {
r1 = ret.Error(1)
}
@@ -472,13 +474,13 @@ func (_m *ImageReaderWriter) UpdateFull(updatedImage models.Image) (*models.Imag
return r0, r1
}
-// UpdateGalleries provides a mock function with given fields: imageID, galleryIDs
-func (_m *ImageReaderWriter) UpdateGalleries(imageID int, galleryIDs []int) error {
- ret := _m.Called(imageID, galleryIDs)
+// UpdateGalleries provides a mock function with given fields: ctx, imageID, galleryIDs
+func (_m *ImageReaderWriter) UpdateGalleries(ctx context.Context, imageID int, galleryIDs []int) error {
+ ret := _m.Called(ctx, imageID, galleryIDs)
var r0 error
- if rf, ok := ret.Get(0).(func(int, []int) error); ok {
- r0 = rf(imageID, galleryIDs)
+ if rf, ok := ret.Get(0).(func(context.Context, int, []int) error); ok {
+ r0 = rf(ctx, imageID, galleryIDs)
} else {
r0 = ret.Error(0)
}
@@ -486,13 +488,13 @@ func (_m *ImageReaderWriter) UpdateGalleries(imageID int, galleryIDs []int) erro
return r0
}
-// UpdatePerformers provides a mock function with given fields: imageID, performerIDs
-func (_m *ImageReaderWriter) UpdatePerformers(imageID int, performerIDs []int) error {
- ret := _m.Called(imageID, performerIDs)
+// UpdatePerformers provides a mock function with given fields: ctx, imageID, performerIDs
+func (_m *ImageReaderWriter) UpdatePerformers(ctx context.Context, imageID int, performerIDs []int) error {
+ ret := _m.Called(ctx, imageID, performerIDs)
var r0 error
- if rf, ok := ret.Get(0).(func(int, []int) error); ok {
- r0 = rf(imageID, performerIDs)
+ if rf, ok := ret.Get(0).(func(context.Context, int, []int) error); ok {
+ r0 = rf(ctx, imageID, performerIDs)
} else {
r0 = ret.Error(0)
}
@@ -500,13 +502,13 @@ func (_m *ImageReaderWriter) UpdatePerformers(imageID int, performerIDs []int) e
return r0
}
-// UpdateTags provides a mock function with given fields: imageID, tagIDs
-func (_m *ImageReaderWriter) UpdateTags(imageID int, tagIDs []int) error {
- ret := _m.Called(imageID, tagIDs)
+// UpdateTags provides a mock function with given fields: ctx, imageID, tagIDs
+func (_m *ImageReaderWriter) UpdateTags(ctx context.Context, imageID int, tagIDs []int) error {
+ ret := _m.Called(ctx, imageID, tagIDs)
var r0 error
- if rf, ok := ret.Get(0).(func(int, []int) error); ok {
- r0 = rf(imageID, tagIDs)
+ if rf, ok := ret.Get(0).(func(context.Context, int, []int) error); ok {
+ r0 = rf(ctx, imageID, tagIDs)
} else {
r0 = ret.Error(0)
}
diff --git a/pkg/models/mocks/MovieReaderWriter.go b/pkg/models/mocks/MovieReaderWriter.go
index 288eb6fd4..c125fc7b1 100644
--- a/pkg/models/mocks/MovieReaderWriter.go
+++ b/pkg/models/mocks/MovieReaderWriter.go
@@ -3,6 +3,8 @@
package mocks
import (
+ context "context"
+
models "github.com/stashapp/stash/pkg/models"
mock "github.com/stretchr/testify/mock"
)
@@ -12,13 +14,13 @@ type MovieReaderWriter struct {
mock.Mock
}
-// All provides a mock function with given fields:
-func (_m *MovieReaderWriter) All() ([]*models.Movie, error) {
- ret := _m.Called()
+// All provides a mock function with given fields: ctx
+func (_m *MovieReaderWriter) All(ctx context.Context) ([]*models.Movie, error) {
+ ret := _m.Called(ctx)
var r0 []*models.Movie
- if rf, ok := ret.Get(0).(func() []*models.Movie); ok {
- r0 = rf()
+ if rf, ok := ret.Get(0).(func(context.Context) []*models.Movie); ok {
+ r0 = rf(ctx)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Movie)
@@ -26,8 +28,8 @@ func (_m *MovieReaderWriter) All() ([]*models.Movie, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func() error); ok {
- r1 = rf()
+ if rf, ok := ret.Get(1).(func(context.Context) error); ok {
+ r1 = rf(ctx)
} else {
r1 = ret.Error(1)
}
@@ -35,20 +37,20 @@ func (_m *MovieReaderWriter) All() ([]*models.Movie, error) {
return r0, r1
}
-// Count provides a mock function with given fields:
-func (_m *MovieReaderWriter) Count() (int, error) {
- ret := _m.Called()
+// Count provides a mock function with given fields: ctx
+func (_m *MovieReaderWriter) Count(ctx context.Context) (int, error) {
+ ret := _m.Called(ctx)
var r0 int
- if rf, ok := ret.Get(0).(func() int); ok {
- r0 = rf()
+ if rf, ok := ret.Get(0).(func(context.Context) int); ok {
+ r0 = rf(ctx)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
- if rf, ok := ret.Get(1).(func() error); ok {
- r1 = rf()
+ if rf, ok := ret.Get(1).(func(context.Context) error); ok {
+ r1 = rf(ctx)
} else {
r1 = ret.Error(1)
}
@@ -56,20 +58,20 @@ func (_m *MovieReaderWriter) Count() (int, error) {
return r0, r1
}
-// CountByPerformerID provides a mock function with given fields: performerID
-func (_m *MovieReaderWriter) CountByPerformerID(performerID int) (int, error) {
- ret := _m.Called(performerID)
+// CountByPerformerID provides a mock function with given fields: ctx, performerID
+func (_m *MovieReaderWriter) CountByPerformerID(ctx context.Context, performerID int) (int, error) {
+ ret := _m.Called(ctx, performerID)
var r0 int
- if rf, ok := ret.Get(0).(func(int) int); ok {
- r0 = rf(performerID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) int); ok {
+ r0 = rf(ctx, performerID)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(performerID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, performerID)
} else {
r1 = ret.Error(1)
}
@@ -77,20 +79,20 @@ func (_m *MovieReaderWriter) CountByPerformerID(performerID int) (int, error) {
return r0, r1
}
-// CountByStudioID provides a mock function with given fields: studioID
-func (_m *MovieReaderWriter) CountByStudioID(studioID int) (int, error) {
- ret := _m.Called(studioID)
+// CountByStudioID provides a mock function with given fields: ctx, studioID
+func (_m *MovieReaderWriter) CountByStudioID(ctx context.Context, studioID int) (int, error) {
+ ret := _m.Called(ctx, studioID)
var r0 int
- if rf, ok := ret.Get(0).(func(int) int); ok {
- r0 = rf(studioID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) int); ok {
+ r0 = rf(ctx, studioID)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(studioID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, studioID)
} else {
r1 = ret.Error(1)
}
@@ -98,13 +100,13 @@ func (_m *MovieReaderWriter) CountByStudioID(studioID int) (int, error) {
return r0, r1
}
-// Create provides a mock function with given fields: newMovie
-func (_m *MovieReaderWriter) Create(newMovie models.Movie) (*models.Movie, error) {
- ret := _m.Called(newMovie)
+// Create provides a mock function with given fields: ctx, newMovie
+func (_m *MovieReaderWriter) Create(ctx context.Context, newMovie models.Movie) (*models.Movie, error) {
+ ret := _m.Called(ctx, newMovie)
var r0 *models.Movie
- if rf, ok := ret.Get(0).(func(models.Movie) *models.Movie); ok {
- r0 = rf(newMovie)
+ if rf, ok := ret.Get(0).(func(context.Context, models.Movie) *models.Movie); ok {
+ r0 = rf(ctx, newMovie)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Movie)
@@ -112,8 +114,8 @@ func (_m *MovieReaderWriter) Create(newMovie models.Movie) (*models.Movie, error
}
var r1 error
- if rf, ok := ret.Get(1).(func(models.Movie) error); ok {
- r1 = rf(newMovie)
+ if rf, ok := ret.Get(1).(func(context.Context, models.Movie) error); ok {
+ r1 = rf(ctx, newMovie)
} else {
r1 = ret.Error(1)
}
@@ -121,13 +123,13 @@ func (_m *MovieReaderWriter) Create(newMovie models.Movie) (*models.Movie, error
return r0, r1
}
-// Destroy provides a mock function with given fields: id
-func (_m *MovieReaderWriter) Destroy(id int) error {
- ret := _m.Called(id)
+// Destroy provides a mock function with given fields: ctx, id
+func (_m *MovieReaderWriter) Destroy(ctx context.Context, id int) error {
+ ret := _m.Called(ctx, id)
var r0 error
- if rf, ok := ret.Get(0).(func(int) error); ok {
- r0 = rf(id)
+ if rf, ok := ret.Get(0).(func(context.Context, int) error); ok {
+ r0 = rf(ctx, id)
} else {
r0 = ret.Error(0)
}
@@ -135,13 +137,13 @@ func (_m *MovieReaderWriter) Destroy(id int) error {
return r0
}
-// DestroyImages provides a mock function with given fields: movieID
-func (_m *MovieReaderWriter) DestroyImages(movieID int) error {
- ret := _m.Called(movieID)
+// DestroyImages provides a mock function with given fields: ctx, movieID
+func (_m *MovieReaderWriter) DestroyImages(ctx context.Context, movieID int) error {
+ ret := _m.Called(ctx, movieID)
var r0 error
- if rf, ok := ret.Get(0).(func(int) error); ok {
- r0 = rf(movieID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) error); ok {
+ r0 = rf(ctx, movieID)
} else {
r0 = ret.Error(0)
}
@@ -149,13 +151,13 @@ func (_m *MovieReaderWriter) DestroyImages(movieID int) error {
return r0
}
-// Find provides a mock function with given fields: id
-func (_m *MovieReaderWriter) Find(id int) (*models.Movie, error) {
- ret := _m.Called(id)
+// Find provides a mock function with given fields: ctx, id
+func (_m *MovieReaderWriter) Find(ctx context.Context, id int) (*models.Movie, error) {
+ ret := _m.Called(ctx, id)
var r0 *models.Movie
- if rf, ok := ret.Get(0).(func(int) *models.Movie); ok {
- r0 = rf(id)
+ if rf, ok := ret.Get(0).(func(context.Context, int) *models.Movie); ok {
+ r0 = rf(ctx, id)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Movie)
@@ -163,8 +165,8 @@ func (_m *MovieReaderWriter) Find(id int) (*models.Movie, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(id)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, id)
} else {
r1 = ret.Error(1)
}
@@ -172,13 +174,13 @@ func (_m *MovieReaderWriter) Find(id int) (*models.Movie, error) {
return r0, r1
}
-// FindByName provides a mock function with given fields: name, nocase
-func (_m *MovieReaderWriter) FindByName(name string, nocase bool) (*models.Movie, error) {
- ret := _m.Called(name, nocase)
+// FindByName provides a mock function with given fields: ctx, name, nocase
+func (_m *MovieReaderWriter) FindByName(ctx context.Context, name string, nocase bool) (*models.Movie, error) {
+ ret := _m.Called(ctx, name, nocase)
var r0 *models.Movie
- if rf, ok := ret.Get(0).(func(string, bool) *models.Movie); ok {
- r0 = rf(name, nocase)
+ if rf, ok := ret.Get(0).(func(context.Context, string, bool) *models.Movie); ok {
+ r0 = rf(ctx, name, nocase)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Movie)
@@ -186,8 +188,8 @@ func (_m *MovieReaderWriter) FindByName(name string, nocase bool) (*models.Movie
}
var r1 error
- if rf, ok := ret.Get(1).(func(string, bool) error); ok {
- r1 = rf(name, nocase)
+ if rf, ok := ret.Get(1).(func(context.Context, string, bool) error); ok {
+ r1 = rf(ctx, name, nocase)
} else {
r1 = ret.Error(1)
}
@@ -195,13 +197,13 @@ func (_m *MovieReaderWriter) FindByName(name string, nocase bool) (*models.Movie
return r0, r1
}
-// FindByNames provides a mock function with given fields: names, nocase
-func (_m *MovieReaderWriter) FindByNames(names []string, nocase bool) ([]*models.Movie, error) {
- ret := _m.Called(names, nocase)
+// FindByNames provides a mock function with given fields: ctx, names, nocase
+func (_m *MovieReaderWriter) FindByNames(ctx context.Context, names []string, nocase bool) ([]*models.Movie, error) {
+ ret := _m.Called(ctx, names, nocase)
var r0 []*models.Movie
- if rf, ok := ret.Get(0).(func([]string, bool) []*models.Movie); ok {
- r0 = rf(names, nocase)
+ if rf, ok := ret.Get(0).(func(context.Context, []string, bool) []*models.Movie); ok {
+ r0 = rf(ctx, names, nocase)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Movie)
@@ -209,8 +211,8 @@ func (_m *MovieReaderWriter) FindByNames(names []string, nocase bool) ([]*models
}
var r1 error
- if rf, ok := ret.Get(1).(func([]string, bool) error); ok {
- r1 = rf(names, nocase)
+ if rf, ok := ret.Get(1).(func(context.Context, []string, bool) error); ok {
+ r1 = rf(ctx, names, nocase)
} else {
r1 = ret.Error(1)
}
@@ -218,13 +220,13 @@ func (_m *MovieReaderWriter) FindByNames(names []string, nocase bool) ([]*models
return r0, r1
}
-// FindByPerformerID provides a mock function with given fields: performerID
-func (_m *MovieReaderWriter) FindByPerformerID(performerID int) ([]*models.Movie, error) {
- ret := _m.Called(performerID)
+// FindByPerformerID provides a mock function with given fields: ctx, performerID
+func (_m *MovieReaderWriter) FindByPerformerID(ctx context.Context, performerID int) ([]*models.Movie, error) {
+ ret := _m.Called(ctx, performerID)
var r0 []*models.Movie
- if rf, ok := ret.Get(0).(func(int) []*models.Movie); ok {
- r0 = rf(performerID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Movie); ok {
+ r0 = rf(ctx, performerID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Movie)
@@ -232,8 +234,8 @@ func (_m *MovieReaderWriter) FindByPerformerID(performerID int) ([]*models.Movie
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(performerID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, performerID)
} else {
r1 = ret.Error(1)
}
@@ -241,13 +243,13 @@ func (_m *MovieReaderWriter) FindByPerformerID(performerID int) ([]*models.Movie
return r0, r1
}
-// FindByStudioID provides a mock function with given fields: studioID
-func (_m *MovieReaderWriter) FindByStudioID(studioID int) ([]*models.Movie, error) {
- ret := _m.Called(studioID)
+// FindByStudioID provides a mock function with given fields: ctx, studioID
+func (_m *MovieReaderWriter) FindByStudioID(ctx context.Context, studioID int) ([]*models.Movie, error) {
+ ret := _m.Called(ctx, studioID)
var r0 []*models.Movie
- if rf, ok := ret.Get(0).(func(int) []*models.Movie); ok {
- r0 = rf(studioID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Movie); ok {
+ r0 = rf(ctx, studioID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Movie)
@@ -255,8 +257,8 @@ func (_m *MovieReaderWriter) FindByStudioID(studioID int) ([]*models.Movie, erro
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(studioID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, studioID)
} else {
r1 = ret.Error(1)
}
@@ -264,13 +266,13 @@ func (_m *MovieReaderWriter) FindByStudioID(studioID int) ([]*models.Movie, erro
return r0, r1
}
-// FindMany provides a mock function with given fields: ids
-func (_m *MovieReaderWriter) FindMany(ids []int) ([]*models.Movie, error) {
- ret := _m.Called(ids)
+// FindMany provides a mock function with given fields: ctx, ids
+func (_m *MovieReaderWriter) FindMany(ctx context.Context, ids []int) ([]*models.Movie, error) {
+ ret := _m.Called(ctx, ids)
var r0 []*models.Movie
- if rf, ok := ret.Get(0).(func([]int) []*models.Movie); ok {
- r0 = rf(ids)
+ if rf, ok := ret.Get(0).(func(context.Context, []int) []*models.Movie); ok {
+ r0 = rf(ctx, ids)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Movie)
@@ -278,8 +280,8 @@ func (_m *MovieReaderWriter) FindMany(ids []int) ([]*models.Movie, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func([]int) error); ok {
- r1 = rf(ids)
+ if rf, ok := ret.Get(1).(func(context.Context, []int) error); ok {
+ r1 = rf(ctx, ids)
} else {
r1 = ret.Error(1)
}
@@ -287,13 +289,13 @@ func (_m *MovieReaderWriter) FindMany(ids []int) ([]*models.Movie, error) {
return r0, r1
}
-// GetBackImage provides a mock function with given fields: movieID
-func (_m *MovieReaderWriter) GetBackImage(movieID int) ([]byte, error) {
- ret := _m.Called(movieID)
+// GetBackImage provides a mock function with given fields: ctx, movieID
+func (_m *MovieReaderWriter) GetBackImage(ctx context.Context, movieID int) ([]byte, error) {
+ ret := _m.Called(ctx, movieID)
var r0 []byte
- if rf, ok := ret.Get(0).(func(int) []byte); ok {
- r0 = rf(movieID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []byte); ok {
+ r0 = rf(ctx, movieID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]byte)
@@ -301,8 +303,8 @@ func (_m *MovieReaderWriter) GetBackImage(movieID int) ([]byte, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(movieID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, movieID)
} else {
r1 = ret.Error(1)
}
@@ -310,13 +312,13 @@ func (_m *MovieReaderWriter) GetBackImage(movieID int) ([]byte, error) {
return r0, r1
}
-// GetFrontImage provides a mock function with given fields: movieID
-func (_m *MovieReaderWriter) GetFrontImage(movieID int) ([]byte, error) {
- ret := _m.Called(movieID)
+// GetFrontImage provides a mock function with given fields: ctx, movieID
+func (_m *MovieReaderWriter) GetFrontImage(ctx context.Context, movieID int) ([]byte, error) {
+ ret := _m.Called(ctx, movieID)
var r0 []byte
- if rf, ok := ret.Get(0).(func(int) []byte); ok {
- r0 = rf(movieID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []byte); ok {
+ r0 = rf(ctx, movieID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]byte)
@@ -324,8 +326,8 @@ func (_m *MovieReaderWriter) GetFrontImage(movieID int) ([]byte, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(movieID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, movieID)
} else {
r1 = ret.Error(1)
}
@@ -333,13 +335,13 @@ func (_m *MovieReaderWriter) GetFrontImage(movieID int) ([]byte, error) {
return r0, r1
}
-// Query provides a mock function with given fields: movieFilter, findFilter
-func (_m *MovieReaderWriter) Query(movieFilter *models.MovieFilterType, findFilter *models.FindFilterType) ([]*models.Movie, int, error) {
- ret := _m.Called(movieFilter, findFilter)
+// Query provides a mock function with given fields: ctx, movieFilter, findFilter
+func (_m *MovieReaderWriter) Query(ctx context.Context, movieFilter *models.MovieFilterType, findFilter *models.FindFilterType) ([]*models.Movie, int, error) {
+ ret := _m.Called(ctx, movieFilter, findFilter)
var r0 []*models.Movie
- if rf, ok := ret.Get(0).(func(*models.MovieFilterType, *models.FindFilterType) []*models.Movie); ok {
- r0 = rf(movieFilter, findFilter)
+ if rf, ok := ret.Get(0).(func(context.Context, *models.MovieFilterType, *models.FindFilterType) []*models.Movie); ok {
+ r0 = rf(ctx, movieFilter, findFilter)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Movie)
@@ -347,15 +349,15 @@ func (_m *MovieReaderWriter) Query(movieFilter *models.MovieFilterType, findFilt
}
var r1 int
- if rf, ok := ret.Get(1).(func(*models.MovieFilterType, *models.FindFilterType) int); ok {
- r1 = rf(movieFilter, findFilter)
+ if rf, ok := ret.Get(1).(func(context.Context, *models.MovieFilterType, *models.FindFilterType) int); ok {
+ r1 = rf(ctx, movieFilter, findFilter)
} else {
r1 = ret.Get(1).(int)
}
var r2 error
- if rf, ok := ret.Get(2).(func(*models.MovieFilterType, *models.FindFilterType) error); ok {
- r2 = rf(movieFilter, findFilter)
+ if rf, ok := ret.Get(2).(func(context.Context, *models.MovieFilterType, *models.FindFilterType) error); ok {
+ r2 = rf(ctx, movieFilter, findFilter)
} else {
r2 = ret.Error(2)
}
@@ -363,13 +365,13 @@ func (_m *MovieReaderWriter) Query(movieFilter *models.MovieFilterType, findFilt
return r0, r1, r2
}
-// Update provides a mock function with given fields: updatedMovie
-func (_m *MovieReaderWriter) Update(updatedMovie models.MoviePartial) (*models.Movie, error) {
- ret := _m.Called(updatedMovie)
+// Update provides a mock function with given fields: ctx, updatedMovie
+func (_m *MovieReaderWriter) Update(ctx context.Context, updatedMovie models.MoviePartial) (*models.Movie, error) {
+ ret := _m.Called(ctx, updatedMovie)
var r0 *models.Movie
- if rf, ok := ret.Get(0).(func(models.MoviePartial) *models.Movie); ok {
- r0 = rf(updatedMovie)
+ if rf, ok := ret.Get(0).(func(context.Context, models.MoviePartial) *models.Movie); ok {
+ r0 = rf(ctx, updatedMovie)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Movie)
@@ -377,8 +379,8 @@ func (_m *MovieReaderWriter) Update(updatedMovie models.MoviePartial) (*models.M
}
var r1 error
- if rf, ok := ret.Get(1).(func(models.MoviePartial) error); ok {
- r1 = rf(updatedMovie)
+ if rf, ok := ret.Get(1).(func(context.Context, models.MoviePartial) error); ok {
+ r1 = rf(ctx, updatedMovie)
} else {
r1 = ret.Error(1)
}
@@ -386,13 +388,13 @@ func (_m *MovieReaderWriter) Update(updatedMovie models.MoviePartial) (*models.M
return r0, r1
}
-// UpdateFull provides a mock function with given fields: updatedMovie
-func (_m *MovieReaderWriter) UpdateFull(updatedMovie models.Movie) (*models.Movie, error) {
- ret := _m.Called(updatedMovie)
+// UpdateFull provides a mock function with given fields: ctx, updatedMovie
+func (_m *MovieReaderWriter) UpdateFull(ctx context.Context, updatedMovie models.Movie) (*models.Movie, error) {
+ ret := _m.Called(ctx, updatedMovie)
var r0 *models.Movie
- if rf, ok := ret.Get(0).(func(models.Movie) *models.Movie); ok {
- r0 = rf(updatedMovie)
+ if rf, ok := ret.Get(0).(func(context.Context, models.Movie) *models.Movie); ok {
+ r0 = rf(ctx, updatedMovie)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Movie)
@@ -400,8 +402,8 @@ func (_m *MovieReaderWriter) UpdateFull(updatedMovie models.Movie) (*models.Movi
}
var r1 error
- if rf, ok := ret.Get(1).(func(models.Movie) error); ok {
- r1 = rf(updatedMovie)
+ if rf, ok := ret.Get(1).(func(context.Context, models.Movie) error); ok {
+ r1 = rf(ctx, updatedMovie)
} else {
r1 = ret.Error(1)
}
@@ -409,13 +411,13 @@ func (_m *MovieReaderWriter) UpdateFull(updatedMovie models.Movie) (*models.Movi
return r0, r1
}
-// UpdateImages provides a mock function with given fields: movieID, frontImage, backImage
-func (_m *MovieReaderWriter) UpdateImages(movieID int, frontImage []byte, backImage []byte) error {
- ret := _m.Called(movieID, frontImage, backImage)
+// UpdateImages provides a mock function with given fields: ctx, movieID, frontImage, backImage
+func (_m *MovieReaderWriter) UpdateImages(ctx context.Context, movieID int, frontImage []byte, backImage []byte) error {
+ ret := _m.Called(ctx, movieID, frontImage, backImage)
var r0 error
- if rf, ok := ret.Get(0).(func(int, []byte, []byte) error); ok {
- r0 = rf(movieID, frontImage, backImage)
+ if rf, ok := ret.Get(0).(func(context.Context, int, []byte, []byte) error); ok {
+ r0 = rf(ctx, movieID, frontImage, backImage)
} else {
r0 = ret.Error(0)
}
diff --git a/pkg/models/mocks/PerformerReaderWriter.go b/pkg/models/mocks/PerformerReaderWriter.go
index 485f75170..2f97b66eb 100644
--- a/pkg/models/mocks/PerformerReaderWriter.go
+++ b/pkg/models/mocks/PerformerReaderWriter.go
@@ -3,6 +3,8 @@
package mocks
import (
+ context "context"
+
models "github.com/stashapp/stash/pkg/models"
mock "github.com/stretchr/testify/mock"
)
@@ -12,13 +14,13 @@ type PerformerReaderWriter struct {
mock.Mock
}
-// All provides a mock function with given fields:
-func (_m *PerformerReaderWriter) All() ([]*models.Performer, error) {
- ret := _m.Called()
+// All provides a mock function with given fields: ctx
+func (_m *PerformerReaderWriter) All(ctx context.Context) ([]*models.Performer, error) {
+ ret := _m.Called(ctx)
var r0 []*models.Performer
- if rf, ok := ret.Get(0).(func() []*models.Performer); ok {
- r0 = rf()
+ if rf, ok := ret.Get(0).(func(context.Context) []*models.Performer); ok {
+ r0 = rf(ctx)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Performer)
@@ -26,8 +28,8 @@ func (_m *PerformerReaderWriter) All() ([]*models.Performer, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func() error); ok {
- r1 = rf()
+ if rf, ok := ret.Get(1).(func(context.Context) error); ok {
+ r1 = rf(ctx)
} else {
r1 = ret.Error(1)
}
@@ -35,20 +37,20 @@ func (_m *PerformerReaderWriter) All() ([]*models.Performer, error) {
return r0, r1
}
-// Count provides a mock function with given fields:
-func (_m *PerformerReaderWriter) Count() (int, error) {
- ret := _m.Called()
+// Count provides a mock function with given fields: ctx
+func (_m *PerformerReaderWriter) Count(ctx context.Context) (int, error) {
+ ret := _m.Called(ctx)
var r0 int
- if rf, ok := ret.Get(0).(func() int); ok {
- r0 = rf()
+ if rf, ok := ret.Get(0).(func(context.Context) int); ok {
+ r0 = rf(ctx)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
- if rf, ok := ret.Get(1).(func() error); ok {
- r1 = rf()
+ if rf, ok := ret.Get(1).(func(context.Context) error); ok {
+ r1 = rf(ctx)
} else {
r1 = ret.Error(1)
}
@@ -56,20 +58,20 @@ func (_m *PerformerReaderWriter) Count() (int, error) {
return r0, r1
}
-// CountByTagID provides a mock function with given fields: tagID
-func (_m *PerformerReaderWriter) CountByTagID(tagID int) (int, error) {
- ret := _m.Called(tagID)
+// CountByTagID provides a mock function with given fields: ctx, tagID
+func (_m *PerformerReaderWriter) CountByTagID(ctx context.Context, tagID int) (int, error) {
+ ret := _m.Called(ctx, tagID)
var r0 int
- if rf, ok := ret.Get(0).(func(int) int); ok {
- r0 = rf(tagID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) int); ok {
+ r0 = rf(ctx, tagID)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(tagID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, tagID)
} else {
r1 = ret.Error(1)
}
@@ -77,13 +79,13 @@ func (_m *PerformerReaderWriter) CountByTagID(tagID int) (int, error) {
return r0, r1
}
-// Create provides a mock function with given fields: newPerformer
-func (_m *PerformerReaderWriter) Create(newPerformer models.Performer) (*models.Performer, error) {
- ret := _m.Called(newPerformer)
+// Create provides a mock function with given fields: ctx, newPerformer
+func (_m *PerformerReaderWriter) Create(ctx context.Context, newPerformer models.Performer) (*models.Performer, error) {
+ ret := _m.Called(ctx, newPerformer)
var r0 *models.Performer
- if rf, ok := ret.Get(0).(func(models.Performer) *models.Performer); ok {
- r0 = rf(newPerformer)
+ if rf, ok := ret.Get(0).(func(context.Context, models.Performer) *models.Performer); ok {
+ r0 = rf(ctx, newPerformer)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Performer)
@@ -91,8 +93,8 @@ func (_m *PerformerReaderWriter) Create(newPerformer models.Performer) (*models.
}
var r1 error
- if rf, ok := ret.Get(1).(func(models.Performer) error); ok {
- r1 = rf(newPerformer)
+ if rf, ok := ret.Get(1).(func(context.Context, models.Performer) error); ok {
+ r1 = rf(ctx, newPerformer)
} else {
r1 = ret.Error(1)
}
@@ -100,13 +102,13 @@ func (_m *PerformerReaderWriter) Create(newPerformer models.Performer) (*models.
return r0, r1
}
-// Destroy provides a mock function with given fields: id
-func (_m *PerformerReaderWriter) Destroy(id int) error {
- ret := _m.Called(id)
+// Destroy provides a mock function with given fields: ctx, id
+func (_m *PerformerReaderWriter) Destroy(ctx context.Context, id int) error {
+ ret := _m.Called(ctx, id)
var r0 error
- if rf, ok := ret.Get(0).(func(int) error); ok {
- r0 = rf(id)
+ if rf, ok := ret.Get(0).(func(context.Context, int) error); ok {
+ r0 = rf(ctx, id)
} else {
r0 = ret.Error(0)
}
@@ -114,13 +116,13 @@ func (_m *PerformerReaderWriter) Destroy(id int) error {
return r0
}
-// DestroyImage provides a mock function with given fields: performerID
-func (_m *PerformerReaderWriter) DestroyImage(performerID int) error {
- ret := _m.Called(performerID)
+// DestroyImage provides a mock function with given fields: ctx, performerID
+func (_m *PerformerReaderWriter) DestroyImage(ctx context.Context, performerID int) error {
+ ret := _m.Called(ctx, performerID)
var r0 error
- if rf, ok := ret.Get(0).(func(int) error); ok {
- r0 = rf(performerID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) error); ok {
+ r0 = rf(ctx, performerID)
} else {
r0 = ret.Error(0)
}
@@ -128,13 +130,13 @@ func (_m *PerformerReaderWriter) DestroyImage(performerID int) error {
return r0
}
-// Find provides a mock function with given fields: id
-func (_m *PerformerReaderWriter) Find(id int) (*models.Performer, error) {
- ret := _m.Called(id)
+// Find provides a mock function with given fields: ctx, id
+func (_m *PerformerReaderWriter) Find(ctx context.Context, id int) (*models.Performer, error) {
+ ret := _m.Called(ctx, id)
var r0 *models.Performer
- if rf, ok := ret.Get(0).(func(int) *models.Performer); ok {
- r0 = rf(id)
+ if rf, ok := ret.Get(0).(func(context.Context, int) *models.Performer); ok {
+ r0 = rf(ctx, id)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Performer)
@@ -142,8 +144,8 @@ func (_m *PerformerReaderWriter) Find(id int) (*models.Performer, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(id)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, id)
} else {
r1 = ret.Error(1)
}
@@ -151,13 +153,13 @@ func (_m *PerformerReaderWriter) Find(id int) (*models.Performer, error) {
return r0, r1
}
-// FindByGalleryID provides a mock function with given fields: galleryID
-func (_m *PerformerReaderWriter) FindByGalleryID(galleryID int) ([]*models.Performer, error) {
- ret := _m.Called(galleryID)
+// FindByGalleryID provides a mock function with given fields: ctx, galleryID
+func (_m *PerformerReaderWriter) FindByGalleryID(ctx context.Context, galleryID int) ([]*models.Performer, error) {
+ ret := _m.Called(ctx, galleryID)
var r0 []*models.Performer
- if rf, ok := ret.Get(0).(func(int) []*models.Performer); ok {
- r0 = rf(galleryID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Performer); ok {
+ r0 = rf(ctx, galleryID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Performer)
@@ -165,8 +167,8 @@ func (_m *PerformerReaderWriter) FindByGalleryID(galleryID int) ([]*models.Perfo
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(galleryID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, galleryID)
} else {
r1 = ret.Error(1)
}
@@ -174,13 +176,13 @@ func (_m *PerformerReaderWriter) FindByGalleryID(galleryID int) ([]*models.Perfo
return r0, r1
}
-// FindByImageID provides a mock function with given fields: imageID
-func (_m *PerformerReaderWriter) FindByImageID(imageID int) ([]*models.Performer, error) {
- ret := _m.Called(imageID)
+// FindByImageID provides a mock function with given fields: ctx, imageID
+func (_m *PerformerReaderWriter) FindByImageID(ctx context.Context, imageID int) ([]*models.Performer, error) {
+ ret := _m.Called(ctx, imageID)
var r0 []*models.Performer
- if rf, ok := ret.Get(0).(func(int) []*models.Performer); ok {
- r0 = rf(imageID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Performer); ok {
+ r0 = rf(ctx, imageID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Performer)
@@ -188,8 +190,8 @@ func (_m *PerformerReaderWriter) FindByImageID(imageID int) ([]*models.Performer
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(imageID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, imageID)
} else {
r1 = ret.Error(1)
}
@@ -197,13 +199,13 @@ func (_m *PerformerReaderWriter) FindByImageID(imageID int) ([]*models.Performer
return r0, r1
}
-// FindByNames provides a mock function with given fields: names, nocase
-func (_m *PerformerReaderWriter) FindByNames(names []string, nocase bool) ([]*models.Performer, error) {
- ret := _m.Called(names, nocase)
+// FindByNames provides a mock function with given fields: ctx, names, nocase
+func (_m *PerformerReaderWriter) FindByNames(ctx context.Context, names []string, nocase bool) ([]*models.Performer, error) {
+ ret := _m.Called(ctx, names, nocase)
var r0 []*models.Performer
- if rf, ok := ret.Get(0).(func([]string, bool) []*models.Performer); ok {
- r0 = rf(names, nocase)
+ if rf, ok := ret.Get(0).(func(context.Context, []string, bool) []*models.Performer); ok {
+ r0 = rf(ctx, names, nocase)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Performer)
@@ -211,8 +213,8 @@ func (_m *PerformerReaderWriter) FindByNames(names []string, nocase bool) ([]*mo
}
var r1 error
- if rf, ok := ret.Get(1).(func([]string, bool) error); ok {
- r1 = rf(names, nocase)
+ if rf, ok := ret.Get(1).(func(context.Context, []string, bool) error); ok {
+ r1 = rf(ctx, names, nocase)
} else {
r1 = ret.Error(1)
}
@@ -220,13 +222,13 @@ func (_m *PerformerReaderWriter) FindByNames(names []string, nocase bool) ([]*mo
return r0, r1
}
-// FindBySceneID provides a mock function with given fields: sceneID
-func (_m *PerformerReaderWriter) FindBySceneID(sceneID int) ([]*models.Performer, error) {
- ret := _m.Called(sceneID)
+// FindBySceneID provides a mock function with given fields: ctx, sceneID
+func (_m *PerformerReaderWriter) FindBySceneID(ctx context.Context, sceneID int) ([]*models.Performer, error) {
+ ret := _m.Called(ctx, sceneID)
var r0 []*models.Performer
- if rf, ok := ret.Get(0).(func(int) []*models.Performer); ok {
- r0 = rf(sceneID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Performer); ok {
+ r0 = rf(ctx, sceneID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Performer)
@@ -234,8 +236,8 @@ func (_m *PerformerReaderWriter) FindBySceneID(sceneID int) ([]*models.Performer
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(sceneID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, sceneID)
} else {
r1 = ret.Error(1)
}
@@ -243,13 +245,13 @@ func (_m *PerformerReaderWriter) FindBySceneID(sceneID int) ([]*models.Performer
return r0, r1
}
-// FindByStashID provides a mock function with given fields: stashID
-func (_m *PerformerReaderWriter) FindByStashID(stashID models.StashID) ([]*models.Performer, error) {
- ret := _m.Called(stashID)
+// FindByStashID provides a mock function with given fields: ctx, stashID
+func (_m *PerformerReaderWriter) FindByStashID(ctx context.Context, stashID models.StashID) ([]*models.Performer, error) {
+ ret := _m.Called(ctx, stashID)
var r0 []*models.Performer
- if rf, ok := ret.Get(0).(func(models.StashID) []*models.Performer); ok {
- r0 = rf(stashID)
+ if rf, ok := ret.Get(0).(func(context.Context, models.StashID) []*models.Performer); ok {
+ r0 = rf(ctx, stashID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Performer)
@@ -257,8 +259,8 @@ func (_m *PerformerReaderWriter) FindByStashID(stashID models.StashID) ([]*model
}
var r1 error
- if rf, ok := ret.Get(1).(func(models.StashID) error); ok {
- r1 = rf(stashID)
+ if rf, ok := ret.Get(1).(func(context.Context, models.StashID) error); ok {
+ r1 = rf(ctx, stashID)
} else {
r1 = ret.Error(1)
}
@@ -266,13 +268,13 @@ func (_m *PerformerReaderWriter) FindByStashID(stashID models.StashID) ([]*model
return r0, r1
}
-// FindByStashIDStatus provides a mock function with given fields: hasStashID, stashboxEndpoint
-func (_m *PerformerReaderWriter) FindByStashIDStatus(hasStashID bool, stashboxEndpoint string) ([]*models.Performer, error) {
- ret := _m.Called(hasStashID, stashboxEndpoint)
+// FindByStashIDStatus provides a mock function with given fields: ctx, hasStashID, stashboxEndpoint
+func (_m *PerformerReaderWriter) FindByStashIDStatus(ctx context.Context, hasStashID bool, stashboxEndpoint string) ([]*models.Performer, error) {
+ ret := _m.Called(ctx, hasStashID, stashboxEndpoint)
var r0 []*models.Performer
- if rf, ok := ret.Get(0).(func(bool, string) []*models.Performer); ok {
- r0 = rf(hasStashID, stashboxEndpoint)
+ if rf, ok := ret.Get(0).(func(context.Context, bool, string) []*models.Performer); ok {
+ r0 = rf(ctx, hasStashID, stashboxEndpoint)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Performer)
@@ -280,8 +282,8 @@ func (_m *PerformerReaderWriter) FindByStashIDStatus(hasStashID bool, stashboxEn
}
var r1 error
- if rf, ok := ret.Get(1).(func(bool, string) error); ok {
- r1 = rf(hasStashID, stashboxEndpoint)
+ if rf, ok := ret.Get(1).(func(context.Context, bool, string) error); ok {
+ r1 = rf(ctx, hasStashID, stashboxEndpoint)
} else {
r1 = ret.Error(1)
}
@@ -289,13 +291,13 @@ func (_m *PerformerReaderWriter) FindByStashIDStatus(hasStashID bool, stashboxEn
return r0, r1
}
-// FindMany provides a mock function with given fields: ids
-func (_m *PerformerReaderWriter) FindMany(ids []int) ([]*models.Performer, error) {
- ret := _m.Called(ids)
+// FindMany provides a mock function with given fields: ctx, ids
+func (_m *PerformerReaderWriter) FindMany(ctx context.Context, ids []int) ([]*models.Performer, error) {
+ ret := _m.Called(ctx, ids)
var r0 []*models.Performer
- if rf, ok := ret.Get(0).(func([]int) []*models.Performer); ok {
- r0 = rf(ids)
+ if rf, ok := ret.Get(0).(func(context.Context, []int) []*models.Performer); ok {
+ r0 = rf(ctx, ids)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Performer)
@@ -303,8 +305,8 @@ func (_m *PerformerReaderWriter) FindMany(ids []int) ([]*models.Performer, error
}
var r1 error
- if rf, ok := ret.Get(1).(func([]int) error); ok {
- r1 = rf(ids)
+ if rf, ok := ret.Get(1).(func(context.Context, []int) error); ok {
+ r1 = rf(ctx, ids)
} else {
r1 = ret.Error(1)
}
@@ -312,13 +314,13 @@ func (_m *PerformerReaderWriter) FindMany(ids []int) ([]*models.Performer, error
return r0, r1
}
-// FindNamesBySceneID provides a mock function with given fields: sceneID
-func (_m *PerformerReaderWriter) FindNamesBySceneID(sceneID int) ([]*models.Performer, error) {
- ret := _m.Called(sceneID)
+// FindNamesBySceneID provides a mock function with given fields: ctx, sceneID
+func (_m *PerformerReaderWriter) FindNamesBySceneID(ctx context.Context, sceneID int) ([]*models.Performer, error) {
+ ret := _m.Called(ctx, sceneID)
var r0 []*models.Performer
- if rf, ok := ret.Get(0).(func(int) []*models.Performer); ok {
- r0 = rf(sceneID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Performer); ok {
+ r0 = rf(ctx, sceneID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Performer)
@@ -326,8 +328,8 @@ func (_m *PerformerReaderWriter) FindNamesBySceneID(sceneID int) ([]*models.Perf
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(sceneID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, sceneID)
} else {
r1 = ret.Error(1)
}
@@ -335,13 +337,13 @@ func (_m *PerformerReaderWriter) FindNamesBySceneID(sceneID int) ([]*models.Perf
return r0, r1
}
-// GetImage provides a mock function with given fields: performerID
-func (_m *PerformerReaderWriter) GetImage(performerID int) ([]byte, error) {
- ret := _m.Called(performerID)
+// GetImage provides a mock function with given fields: ctx, performerID
+func (_m *PerformerReaderWriter) GetImage(ctx context.Context, performerID int) ([]byte, error) {
+ ret := _m.Called(ctx, performerID)
var r0 []byte
- if rf, ok := ret.Get(0).(func(int) []byte); ok {
- r0 = rf(performerID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []byte); ok {
+ r0 = rf(ctx, performerID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]byte)
@@ -349,8 +351,8 @@ func (_m *PerformerReaderWriter) GetImage(performerID int) ([]byte, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(performerID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, performerID)
} else {
r1 = ret.Error(1)
}
@@ -358,13 +360,13 @@ func (_m *PerformerReaderWriter) GetImage(performerID int) ([]byte, error) {
return r0, r1
}
-// GetStashIDs provides a mock function with given fields: performerID
-func (_m *PerformerReaderWriter) GetStashIDs(performerID int) ([]*models.StashID, error) {
- ret := _m.Called(performerID)
+// GetStashIDs provides a mock function with given fields: ctx, performerID
+func (_m *PerformerReaderWriter) GetStashIDs(ctx context.Context, performerID int) ([]*models.StashID, error) {
+ ret := _m.Called(ctx, performerID)
var r0 []*models.StashID
- if rf, ok := ret.Get(0).(func(int) []*models.StashID); ok {
- r0 = rf(performerID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []*models.StashID); ok {
+ r0 = rf(ctx, performerID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.StashID)
@@ -372,8 +374,8 @@ func (_m *PerformerReaderWriter) GetStashIDs(performerID int) ([]*models.StashID
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(performerID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, performerID)
} else {
r1 = ret.Error(1)
}
@@ -381,13 +383,13 @@ func (_m *PerformerReaderWriter) GetStashIDs(performerID int) ([]*models.StashID
return r0, r1
}
-// GetTagIDs provides a mock function with given fields: performerID
-func (_m *PerformerReaderWriter) GetTagIDs(performerID int) ([]int, error) {
- ret := _m.Called(performerID)
+// GetTagIDs provides a mock function with given fields: ctx, performerID
+func (_m *PerformerReaderWriter) GetTagIDs(ctx context.Context, performerID int) ([]int, error) {
+ ret := _m.Called(ctx, performerID)
var r0 []int
- if rf, ok := ret.Get(0).(func(int) []int); ok {
- r0 = rf(performerID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []int); ok {
+ r0 = rf(ctx, performerID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]int)
@@ -395,8 +397,8 @@ func (_m *PerformerReaderWriter) GetTagIDs(performerID int) ([]int, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(performerID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, performerID)
} else {
r1 = ret.Error(1)
}
@@ -404,13 +406,13 @@ func (_m *PerformerReaderWriter) GetTagIDs(performerID int) ([]int, error) {
return r0, r1
}
-// Query provides a mock function with given fields: performerFilter, findFilter
-func (_m *PerformerReaderWriter) Query(performerFilter *models.PerformerFilterType, findFilter *models.FindFilterType) ([]*models.Performer, int, error) {
- ret := _m.Called(performerFilter, findFilter)
+// Query provides a mock function with given fields: ctx, performerFilter, findFilter
+func (_m *PerformerReaderWriter) Query(ctx context.Context, performerFilter *models.PerformerFilterType, findFilter *models.FindFilterType) ([]*models.Performer, int, error) {
+ ret := _m.Called(ctx, performerFilter, findFilter)
var r0 []*models.Performer
- if rf, ok := ret.Get(0).(func(*models.PerformerFilterType, *models.FindFilterType) []*models.Performer); ok {
- r0 = rf(performerFilter, findFilter)
+ if rf, ok := ret.Get(0).(func(context.Context, *models.PerformerFilterType, *models.FindFilterType) []*models.Performer); ok {
+ r0 = rf(ctx, performerFilter, findFilter)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Performer)
@@ -418,15 +420,15 @@ func (_m *PerformerReaderWriter) Query(performerFilter *models.PerformerFilterTy
}
var r1 int
- if rf, ok := ret.Get(1).(func(*models.PerformerFilterType, *models.FindFilterType) int); ok {
- r1 = rf(performerFilter, findFilter)
+ if rf, ok := ret.Get(1).(func(context.Context, *models.PerformerFilterType, *models.FindFilterType) int); ok {
+ r1 = rf(ctx, performerFilter, findFilter)
} else {
r1 = ret.Get(1).(int)
}
var r2 error
- if rf, ok := ret.Get(2).(func(*models.PerformerFilterType, *models.FindFilterType) error); ok {
- r2 = rf(performerFilter, findFilter)
+ if rf, ok := ret.Get(2).(func(context.Context, *models.PerformerFilterType, *models.FindFilterType) error); ok {
+ r2 = rf(ctx, performerFilter, findFilter)
} else {
r2 = ret.Error(2)
}
@@ -434,13 +436,13 @@ func (_m *PerformerReaderWriter) Query(performerFilter *models.PerformerFilterTy
return r0, r1, r2
}
-// QueryForAutoTag provides a mock function with given fields: words
-func (_m *PerformerReaderWriter) QueryForAutoTag(words []string) ([]*models.Performer, error) {
- ret := _m.Called(words)
+// QueryForAutoTag provides a mock function with given fields: ctx, words
+func (_m *PerformerReaderWriter) QueryForAutoTag(ctx context.Context, words []string) ([]*models.Performer, error) {
+ ret := _m.Called(ctx, words)
var r0 []*models.Performer
- if rf, ok := ret.Get(0).(func([]string) []*models.Performer); ok {
- r0 = rf(words)
+ if rf, ok := ret.Get(0).(func(context.Context, []string) []*models.Performer); ok {
+ r0 = rf(ctx, words)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Performer)
@@ -448,8 +450,8 @@ func (_m *PerformerReaderWriter) QueryForAutoTag(words []string) ([]*models.Perf
}
var r1 error
- if rf, ok := ret.Get(1).(func([]string) error); ok {
- r1 = rf(words)
+ if rf, ok := ret.Get(1).(func(context.Context, []string) error); ok {
+ r1 = rf(ctx, words)
} else {
r1 = ret.Error(1)
}
@@ -457,13 +459,13 @@ func (_m *PerformerReaderWriter) QueryForAutoTag(words []string) ([]*models.Perf
return r0, r1
}
-// Update provides a mock function with given fields: updatedPerformer
-func (_m *PerformerReaderWriter) Update(updatedPerformer models.PerformerPartial) (*models.Performer, error) {
- ret := _m.Called(updatedPerformer)
+// Update provides a mock function with given fields: ctx, updatedPerformer
+func (_m *PerformerReaderWriter) Update(ctx context.Context, updatedPerformer models.PerformerPartial) (*models.Performer, error) {
+ ret := _m.Called(ctx, updatedPerformer)
var r0 *models.Performer
- if rf, ok := ret.Get(0).(func(models.PerformerPartial) *models.Performer); ok {
- r0 = rf(updatedPerformer)
+ if rf, ok := ret.Get(0).(func(context.Context, models.PerformerPartial) *models.Performer); ok {
+ r0 = rf(ctx, updatedPerformer)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Performer)
@@ -471,8 +473,8 @@ func (_m *PerformerReaderWriter) Update(updatedPerformer models.PerformerPartial
}
var r1 error
- if rf, ok := ret.Get(1).(func(models.PerformerPartial) error); ok {
- r1 = rf(updatedPerformer)
+ if rf, ok := ret.Get(1).(func(context.Context, models.PerformerPartial) error); ok {
+ r1 = rf(ctx, updatedPerformer)
} else {
r1 = ret.Error(1)
}
@@ -480,13 +482,13 @@ func (_m *PerformerReaderWriter) Update(updatedPerformer models.PerformerPartial
return r0, r1
}
-// UpdateFull provides a mock function with given fields: updatedPerformer
-func (_m *PerformerReaderWriter) UpdateFull(updatedPerformer models.Performer) (*models.Performer, error) {
- ret := _m.Called(updatedPerformer)
+// UpdateFull provides a mock function with given fields: ctx, updatedPerformer
+func (_m *PerformerReaderWriter) UpdateFull(ctx context.Context, updatedPerformer models.Performer) (*models.Performer, error) {
+ ret := _m.Called(ctx, updatedPerformer)
var r0 *models.Performer
- if rf, ok := ret.Get(0).(func(models.Performer) *models.Performer); ok {
- r0 = rf(updatedPerformer)
+ if rf, ok := ret.Get(0).(func(context.Context, models.Performer) *models.Performer); ok {
+ r0 = rf(ctx, updatedPerformer)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Performer)
@@ -494,8 +496,8 @@ func (_m *PerformerReaderWriter) UpdateFull(updatedPerformer models.Performer) (
}
var r1 error
- if rf, ok := ret.Get(1).(func(models.Performer) error); ok {
- r1 = rf(updatedPerformer)
+ if rf, ok := ret.Get(1).(func(context.Context, models.Performer) error); ok {
+ r1 = rf(ctx, updatedPerformer)
} else {
r1 = ret.Error(1)
}
@@ -503,13 +505,13 @@ func (_m *PerformerReaderWriter) UpdateFull(updatedPerformer models.Performer) (
return r0, r1
}
-// UpdateImage provides a mock function with given fields: performerID, image
-func (_m *PerformerReaderWriter) UpdateImage(performerID int, image []byte) error {
- ret := _m.Called(performerID, image)
+// UpdateImage provides a mock function with given fields: ctx, performerID, image
+func (_m *PerformerReaderWriter) UpdateImage(ctx context.Context, performerID int, image []byte) error {
+ ret := _m.Called(ctx, performerID, image)
var r0 error
- if rf, ok := ret.Get(0).(func(int, []byte) error); ok {
- r0 = rf(performerID, image)
+ if rf, ok := ret.Get(0).(func(context.Context, int, []byte) error); ok {
+ r0 = rf(ctx, performerID, image)
} else {
r0 = ret.Error(0)
}
@@ -517,13 +519,13 @@ func (_m *PerformerReaderWriter) UpdateImage(performerID int, image []byte) erro
return r0
}
-// UpdateStashIDs provides a mock function with given fields: performerID, stashIDs
-func (_m *PerformerReaderWriter) UpdateStashIDs(performerID int, stashIDs []models.StashID) error {
- ret := _m.Called(performerID, stashIDs)
+// UpdateStashIDs provides a mock function with given fields: ctx, performerID, stashIDs
+func (_m *PerformerReaderWriter) UpdateStashIDs(ctx context.Context, performerID int, stashIDs []models.StashID) error {
+ ret := _m.Called(ctx, performerID, stashIDs)
var r0 error
- if rf, ok := ret.Get(0).(func(int, []models.StashID) error); ok {
- r0 = rf(performerID, stashIDs)
+ if rf, ok := ret.Get(0).(func(context.Context, int, []models.StashID) error); ok {
+ r0 = rf(ctx, performerID, stashIDs)
} else {
r0 = ret.Error(0)
}
@@ -531,13 +533,13 @@ func (_m *PerformerReaderWriter) UpdateStashIDs(performerID int, stashIDs []mode
return r0
}
-// UpdateTags provides a mock function with given fields: performerID, tagIDs
-func (_m *PerformerReaderWriter) UpdateTags(performerID int, tagIDs []int) error {
- ret := _m.Called(performerID, tagIDs)
+// UpdateTags provides a mock function with given fields: ctx, performerID, tagIDs
+func (_m *PerformerReaderWriter) UpdateTags(ctx context.Context, performerID int, tagIDs []int) error {
+ ret := _m.Called(ctx, performerID, tagIDs)
var r0 error
- if rf, ok := ret.Get(0).(func(int, []int) error); ok {
- r0 = rf(performerID, tagIDs)
+ if rf, ok := ret.Get(0).(func(context.Context, int, []int) error); ok {
+ r0 = rf(ctx, performerID, tagIDs)
} else {
r0 = ret.Error(0)
}
diff --git a/pkg/models/mocks/SavedFilterReaderWriter.go b/pkg/models/mocks/SavedFilterReaderWriter.go
index 952497be2..8f9e6e553 100644
--- a/pkg/models/mocks/SavedFilterReaderWriter.go
+++ b/pkg/models/mocks/SavedFilterReaderWriter.go
@@ -3,6 +3,8 @@
package mocks
import (
+ context "context"
+
models "github.com/stashapp/stash/pkg/models"
mock "github.com/stretchr/testify/mock"
)
@@ -12,13 +14,13 @@ type SavedFilterReaderWriter struct {
mock.Mock
}
-// All provides a mock function with given fields:
-func (_m *SavedFilterReaderWriter) All() ([]*models.SavedFilter, error) {
- ret := _m.Called()
+// All provides a mock function with given fields: ctx
+func (_m *SavedFilterReaderWriter) All(ctx context.Context) ([]*models.SavedFilter, error) {
+ ret := _m.Called(ctx)
var r0 []*models.SavedFilter
- if rf, ok := ret.Get(0).(func() []*models.SavedFilter); ok {
- r0 = rf()
+ if rf, ok := ret.Get(0).(func(context.Context) []*models.SavedFilter); ok {
+ r0 = rf(ctx)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.SavedFilter)
@@ -26,8 +28,8 @@ func (_m *SavedFilterReaderWriter) All() ([]*models.SavedFilter, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func() error); ok {
- r1 = rf()
+ if rf, ok := ret.Get(1).(func(context.Context) error); ok {
+ r1 = rf(ctx)
} else {
r1 = ret.Error(1)
}
@@ -35,13 +37,13 @@ func (_m *SavedFilterReaderWriter) All() ([]*models.SavedFilter, error) {
return r0, r1
}
-// Create provides a mock function with given fields: obj
-func (_m *SavedFilterReaderWriter) Create(obj models.SavedFilter) (*models.SavedFilter, error) {
- ret := _m.Called(obj)
+// Create provides a mock function with given fields: ctx, obj
+func (_m *SavedFilterReaderWriter) Create(ctx context.Context, obj models.SavedFilter) (*models.SavedFilter, error) {
+ ret := _m.Called(ctx, obj)
var r0 *models.SavedFilter
- if rf, ok := ret.Get(0).(func(models.SavedFilter) *models.SavedFilter); ok {
- r0 = rf(obj)
+ if rf, ok := ret.Get(0).(func(context.Context, models.SavedFilter) *models.SavedFilter); ok {
+ r0 = rf(ctx, obj)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.SavedFilter)
@@ -49,8 +51,8 @@ func (_m *SavedFilterReaderWriter) Create(obj models.SavedFilter) (*models.Saved
}
var r1 error
- if rf, ok := ret.Get(1).(func(models.SavedFilter) error); ok {
- r1 = rf(obj)
+ if rf, ok := ret.Get(1).(func(context.Context, models.SavedFilter) error); ok {
+ r1 = rf(ctx, obj)
} else {
r1 = ret.Error(1)
}
@@ -58,13 +60,13 @@ func (_m *SavedFilterReaderWriter) Create(obj models.SavedFilter) (*models.Saved
return r0, r1
}
-// Destroy provides a mock function with given fields: id
-func (_m *SavedFilterReaderWriter) Destroy(id int) error {
- ret := _m.Called(id)
+// Destroy provides a mock function with given fields: ctx, id
+func (_m *SavedFilterReaderWriter) Destroy(ctx context.Context, id int) error {
+ ret := _m.Called(ctx, id)
var r0 error
- if rf, ok := ret.Get(0).(func(int) error); ok {
- r0 = rf(id)
+ if rf, ok := ret.Get(0).(func(context.Context, int) error); ok {
+ r0 = rf(ctx, id)
} else {
r0 = ret.Error(0)
}
@@ -72,13 +74,13 @@ func (_m *SavedFilterReaderWriter) Destroy(id int) error {
return r0
}
-// Find provides a mock function with given fields: id
-func (_m *SavedFilterReaderWriter) Find(id int) (*models.SavedFilter, error) {
- ret := _m.Called(id)
+// Find provides a mock function with given fields: ctx, id
+func (_m *SavedFilterReaderWriter) Find(ctx context.Context, id int) (*models.SavedFilter, error) {
+ ret := _m.Called(ctx, id)
var r0 *models.SavedFilter
- if rf, ok := ret.Get(0).(func(int) *models.SavedFilter); ok {
- r0 = rf(id)
+ if rf, ok := ret.Get(0).(func(context.Context, int) *models.SavedFilter); ok {
+ r0 = rf(ctx, id)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.SavedFilter)
@@ -86,8 +88,8 @@ func (_m *SavedFilterReaderWriter) Find(id int) (*models.SavedFilter, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(id)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, id)
} else {
r1 = ret.Error(1)
}
@@ -95,13 +97,13 @@ func (_m *SavedFilterReaderWriter) Find(id int) (*models.SavedFilter, error) {
return r0, r1
}
-// FindByMode provides a mock function with given fields: mode
-func (_m *SavedFilterReaderWriter) FindByMode(mode models.FilterMode) ([]*models.SavedFilter, error) {
- ret := _m.Called(mode)
+// FindByMode provides a mock function with given fields: ctx, mode
+func (_m *SavedFilterReaderWriter) FindByMode(ctx context.Context, mode models.FilterMode) ([]*models.SavedFilter, error) {
+ ret := _m.Called(ctx, mode)
var r0 []*models.SavedFilter
- if rf, ok := ret.Get(0).(func(models.FilterMode) []*models.SavedFilter); ok {
- r0 = rf(mode)
+ if rf, ok := ret.Get(0).(func(context.Context, models.FilterMode) []*models.SavedFilter); ok {
+ r0 = rf(ctx, mode)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.SavedFilter)
@@ -109,8 +111,8 @@ func (_m *SavedFilterReaderWriter) FindByMode(mode models.FilterMode) ([]*models
}
var r1 error
- if rf, ok := ret.Get(1).(func(models.FilterMode) error); ok {
- r1 = rf(mode)
+ if rf, ok := ret.Get(1).(func(context.Context, models.FilterMode) error); ok {
+ r1 = rf(ctx, mode)
} else {
r1 = ret.Error(1)
}
@@ -118,13 +120,13 @@ func (_m *SavedFilterReaderWriter) FindByMode(mode models.FilterMode) ([]*models
return r0, r1
}
-// FindDefault provides a mock function with given fields: mode
-func (_m *SavedFilterReaderWriter) FindDefault(mode models.FilterMode) (*models.SavedFilter, error) {
- ret := _m.Called(mode)
+// FindDefault provides a mock function with given fields: ctx, mode
+func (_m *SavedFilterReaderWriter) FindDefault(ctx context.Context, mode models.FilterMode) (*models.SavedFilter, error) {
+ ret := _m.Called(ctx, mode)
var r0 *models.SavedFilter
- if rf, ok := ret.Get(0).(func(models.FilterMode) *models.SavedFilter); ok {
- r0 = rf(mode)
+ if rf, ok := ret.Get(0).(func(context.Context, models.FilterMode) *models.SavedFilter); ok {
+ r0 = rf(ctx, mode)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.SavedFilter)
@@ -132,8 +134,8 @@ func (_m *SavedFilterReaderWriter) FindDefault(mode models.FilterMode) (*models.
}
var r1 error
- if rf, ok := ret.Get(1).(func(models.FilterMode) error); ok {
- r1 = rf(mode)
+ if rf, ok := ret.Get(1).(func(context.Context, models.FilterMode) error); ok {
+ r1 = rf(ctx, mode)
} else {
r1 = ret.Error(1)
}
@@ -141,13 +143,13 @@ func (_m *SavedFilterReaderWriter) FindDefault(mode models.FilterMode) (*models.
return r0, r1
}
-// FindMany provides a mock function with given fields: ids, ignoreNotFound
-func (_m *SavedFilterReaderWriter) FindMany(ids []int, ignoreNotFound bool) ([]*models.SavedFilter, error) {
- ret := _m.Called(ids, ignoreNotFound)
+// FindMany provides a mock function with given fields: ctx, ids, ignoreNotFound
+func (_m *SavedFilterReaderWriter) FindMany(ctx context.Context, ids []int, ignoreNotFound bool) ([]*models.SavedFilter, error) {
+ ret := _m.Called(ctx, ids, ignoreNotFound)
var r0 []*models.SavedFilter
- if rf, ok := ret.Get(0).(func([]int, bool) []*models.SavedFilter); ok {
- r0 = rf(ids, ignoreNotFound)
+ if rf, ok := ret.Get(0).(func(context.Context, []int, bool) []*models.SavedFilter); ok {
+ r0 = rf(ctx, ids, ignoreNotFound)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.SavedFilter)
@@ -155,8 +157,8 @@ func (_m *SavedFilterReaderWriter) FindMany(ids []int, ignoreNotFound bool) ([]*
}
var r1 error
- if rf, ok := ret.Get(1).(func([]int, bool) error); ok {
- r1 = rf(ids, ignoreNotFound)
+ if rf, ok := ret.Get(1).(func(context.Context, []int, bool) error); ok {
+ r1 = rf(ctx, ids, ignoreNotFound)
} else {
r1 = ret.Error(1)
}
@@ -164,13 +166,13 @@ func (_m *SavedFilterReaderWriter) FindMany(ids []int, ignoreNotFound bool) ([]*
return r0, r1
}
-// SetDefault provides a mock function with given fields: obj
-func (_m *SavedFilterReaderWriter) SetDefault(obj models.SavedFilter) (*models.SavedFilter, error) {
- ret := _m.Called(obj)
+// SetDefault provides a mock function with given fields: ctx, obj
+func (_m *SavedFilterReaderWriter) SetDefault(ctx context.Context, obj models.SavedFilter) (*models.SavedFilter, error) {
+ ret := _m.Called(ctx, obj)
var r0 *models.SavedFilter
- if rf, ok := ret.Get(0).(func(models.SavedFilter) *models.SavedFilter); ok {
- r0 = rf(obj)
+ if rf, ok := ret.Get(0).(func(context.Context, models.SavedFilter) *models.SavedFilter); ok {
+ r0 = rf(ctx, obj)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.SavedFilter)
@@ -178,8 +180,8 @@ func (_m *SavedFilterReaderWriter) SetDefault(obj models.SavedFilter) (*models.S
}
var r1 error
- if rf, ok := ret.Get(1).(func(models.SavedFilter) error); ok {
- r1 = rf(obj)
+ if rf, ok := ret.Get(1).(func(context.Context, models.SavedFilter) error); ok {
+ r1 = rf(ctx, obj)
} else {
r1 = ret.Error(1)
}
@@ -187,13 +189,13 @@ func (_m *SavedFilterReaderWriter) SetDefault(obj models.SavedFilter) (*models.S
return r0, r1
}
-// Update provides a mock function with given fields: obj
-func (_m *SavedFilterReaderWriter) Update(obj models.SavedFilter) (*models.SavedFilter, error) {
- ret := _m.Called(obj)
+// Update provides a mock function with given fields: ctx, obj
+func (_m *SavedFilterReaderWriter) Update(ctx context.Context, obj models.SavedFilter) (*models.SavedFilter, error) {
+ ret := _m.Called(ctx, obj)
var r0 *models.SavedFilter
- if rf, ok := ret.Get(0).(func(models.SavedFilter) *models.SavedFilter); ok {
- r0 = rf(obj)
+ if rf, ok := ret.Get(0).(func(context.Context, models.SavedFilter) *models.SavedFilter); ok {
+ r0 = rf(ctx, obj)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.SavedFilter)
@@ -201,8 +203,8 @@ func (_m *SavedFilterReaderWriter) Update(obj models.SavedFilter) (*models.Saved
}
var r1 error
- if rf, ok := ret.Get(1).(func(models.SavedFilter) error); ok {
- r1 = rf(obj)
+ if rf, ok := ret.Get(1).(func(context.Context, models.SavedFilter) error); ok {
+ r1 = rf(ctx, obj)
} else {
r1 = ret.Error(1)
}
diff --git a/pkg/models/mocks/SceneMarkerReaderWriter.go b/pkg/models/mocks/SceneMarkerReaderWriter.go
index 2e6fea3a0..695a54391 100644
--- a/pkg/models/mocks/SceneMarkerReaderWriter.go
+++ b/pkg/models/mocks/SceneMarkerReaderWriter.go
@@ -3,6 +3,8 @@
package mocks
import (
+ context "context"
+
models "github.com/stashapp/stash/pkg/models"
mock "github.com/stretchr/testify/mock"
)
@@ -12,20 +14,20 @@ type SceneMarkerReaderWriter struct {
mock.Mock
}
-// CountByTagID provides a mock function with given fields: tagID
-func (_m *SceneMarkerReaderWriter) CountByTagID(tagID int) (int, error) {
- ret := _m.Called(tagID)
+// CountByTagID provides a mock function with given fields: ctx, tagID
+func (_m *SceneMarkerReaderWriter) CountByTagID(ctx context.Context, tagID int) (int, error) {
+ ret := _m.Called(ctx, tagID)
var r0 int
- if rf, ok := ret.Get(0).(func(int) int); ok {
- r0 = rf(tagID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) int); ok {
+ r0 = rf(ctx, tagID)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(tagID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, tagID)
} else {
r1 = ret.Error(1)
}
@@ -33,13 +35,13 @@ func (_m *SceneMarkerReaderWriter) CountByTagID(tagID int) (int, error) {
return r0, r1
}
-// Create provides a mock function with given fields: newSceneMarker
-func (_m *SceneMarkerReaderWriter) Create(newSceneMarker models.SceneMarker) (*models.SceneMarker, error) {
- ret := _m.Called(newSceneMarker)
+// Create provides a mock function with given fields: ctx, newSceneMarker
+func (_m *SceneMarkerReaderWriter) Create(ctx context.Context, newSceneMarker models.SceneMarker) (*models.SceneMarker, error) {
+ ret := _m.Called(ctx, newSceneMarker)
var r0 *models.SceneMarker
- if rf, ok := ret.Get(0).(func(models.SceneMarker) *models.SceneMarker); ok {
- r0 = rf(newSceneMarker)
+ if rf, ok := ret.Get(0).(func(context.Context, models.SceneMarker) *models.SceneMarker); ok {
+ r0 = rf(ctx, newSceneMarker)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.SceneMarker)
@@ -47,8 +49,8 @@ func (_m *SceneMarkerReaderWriter) Create(newSceneMarker models.SceneMarker) (*m
}
var r1 error
- if rf, ok := ret.Get(1).(func(models.SceneMarker) error); ok {
- r1 = rf(newSceneMarker)
+ if rf, ok := ret.Get(1).(func(context.Context, models.SceneMarker) error); ok {
+ r1 = rf(ctx, newSceneMarker)
} else {
r1 = ret.Error(1)
}
@@ -56,13 +58,13 @@ func (_m *SceneMarkerReaderWriter) Create(newSceneMarker models.SceneMarker) (*m
return r0, r1
}
-// Destroy provides a mock function with given fields: id
-func (_m *SceneMarkerReaderWriter) Destroy(id int) error {
- ret := _m.Called(id)
+// Destroy provides a mock function with given fields: ctx, id
+func (_m *SceneMarkerReaderWriter) Destroy(ctx context.Context, id int) error {
+ ret := _m.Called(ctx, id)
var r0 error
- if rf, ok := ret.Get(0).(func(int) error); ok {
- r0 = rf(id)
+ if rf, ok := ret.Get(0).(func(context.Context, int) error); ok {
+ r0 = rf(ctx, id)
} else {
r0 = ret.Error(0)
}
@@ -70,13 +72,13 @@ func (_m *SceneMarkerReaderWriter) Destroy(id int) error {
return r0
}
-// Find provides a mock function with given fields: id
-func (_m *SceneMarkerReaderWriter) Find(id int) (*models.SceneMarker, error) {
- ret := _m.Called(id)
+// Find provides a mock function with given fields: ctx, id
+func (_m *SceneMarkerReaderWriter) Find(ctx context.Context, id int) (*models.SceneMarker, error) {
+ ret := _m.Called(ctx, id)
var r0 *models.SceneMarker
- if rf, ok := ret.Get(0).(func(int) *models.SceneMarker); ok {
- r0 = rf(id)
+ if rf, ok := ret.Get(0).(func(context.Context, int) *models.SceneMarker); ok {
+ r0 = rf(ctx, id)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.SceneMarker)
@@ -84,8 +86,8 @@ func (_m *SceneMarkerReaderWriter) Find(id int) (*models.SceneMarker, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(id)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, id)
} else {
r1 = ret.Error(1)
}
@@ -93,13 +95,13 @@ func (_m *SceneMarkerReaderWriter) Find(id int) (*models.SceneMarker, error) {
return r0, r1
}
-// FindBySceneID provides a mock function with given fields: sceneID
-func (_m *SceneMarkerReaderWriter) FindBySceneID(sceneID int) ([]*models.SceneMarker, error) {
- ret := _m.Called(sceneID)
+// FindBySceneID provides a mock function with given fields: ctx, sceneID
+func (_m *SceneMarkerReaderWriter) FindBySceneID(ctx context.Context, sceneID int) ([]*models.SceneMarker, error) {
+ ret := _m.Called(ctx, sceneID)
var r0 []*models.SceneMarker
- if rf, ok := ret.Get(0).(func(int) []*models.SceneMarker); ok {
- r0 = rf(sceneID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []*models.SceneMarker); ok {
+ r0 = rf(ctx, sceneID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.SceneMarker)
@@ -107,8 +109,8 @@ func (_m *SceneMarkerReaderWriter) FindBySceneID(sceneID int) ([]*models.SceneMa
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(sceneID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, sceneID)
} else {
r1 = ret.Error(1)
}
@@ -116,13 +118,13 @@ func (_m *SceneMarkerReaderWriter) FindBySceneID(sceneID int) ([]*models.SceneMa
return r0, r1
}
-// FindMany provides a mock function with given fields: ids
-func (_m *SceneMarkerReaderWriter) FindMany(ids []int) ([]*models.SceneMarker, error) {
- ret := _m.Called(ids)
+// FindMany provides a mock function with given fields: ctx, ids
+func (_m *SceneMarkerReaderWriter) FindMany(ctx context.Context, ids []int) ([]*models.SceneMarker, error) {
+ ret := _m.Called(ctx, ids)
var r0 []*models.SceneMarker
- if rf, ok := ret.Get(0).(func([]int) []*models.SceneMarker); ok {
- r0 = rf(ids)
+ if rf, ok := ret.Get(0).(func(context.Context, []int) []*models.SceneMarker); ok {
+ r0 = rf(ctx, ids)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.SceneMarker)
@@ -130,8 +132,8 @@ func (_m *SceneMarkerReaderWriter) FindMany(ids []int) ([]*models.SceneMarker, e
}
var r1 error
- if rf, ok := ret.Get(1).(func([]int) error); ok {
- r1 = rf(ids)
+ if rf, ok := ret.Get(1).(func(context.Context, []int) error); ok {
+ r1 = rf(ctx, ids)
} else {
r1 = ret.Error(1)
}
@@ -139,13 +141,13 @@ func (_m *SceneMarkerReaderWriter) FindMany(ids []int) ([]*models.SceneMarker, e
return r0, r1
}
-// GetMarkerStrings provides a mock function with given fields: q, sort
-func (_m *SceneMarkerReaderWriter) GetMarkerStrings(q *string, sort *string) ([]*models.MarkerStringsResultType, error) {
- ret := _m.Called(q, sort)
+// GetMarkerStrings provides a mock function with given fields: ctx, q, sort
+func (_m *SceneMarkerReaderWriter) GetMarkerStrings(ctx context.Context, q *string, sort *string) ([]*models.MarkerStringsResultType, error) {
+ ret := _m.Called(ctx, q, sort)
var r0 []*models.MarkerStringsResultType
- if rf, ok := ret.Get(0).(func(*string, *string) []*models.MarkerStringsResultType); ok {
- r0 = rf(q, sort)
+ if rf, ok := ret.Get(0).(func(context.Context, *string, *string) []*models.MarkerStringsResultType); ok {
+ r0 = rf(ctx, q, sort)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.MarkerStringsResultType)
@@ -153,8 +155,8 @@ func (_m *SceneMarkerReaderWriter) GetMarkerStrings(q *string, sort *string) ([]
}
var r1 error
- if rf, ok := ret.Get(1).(func(*string, *string) error); ok {
- r1 = rf(q, sort)
+ if rf, ok := ret.Get(1).(func(context.Context, *string, *string) error); ok {
+ r1 = rf(ctx, q, sort)
} else {
r1 = ret.Error(1)
}
@@ -162,13 +164,13 @@ func (_m *SceneMarkerReaderWriter) GetMarkerStrings(q *string, sort *string) ([]
return r0, r1
}
-// GetTagIDs provides a mock function with given fields: imageID
-func (_m *SceneMarkerReaderWriter) GetTagIDs(imageID int) ([]int, error) {
- ret := _m.Called(imageID)
+// GetTagIDs provides a mock function with given fields: ctx, imageID
+func (_m *SceneMarkerReaderWriter) GetTagIDs(ctx context.Context, imageID int) ([]int, error) {
+ ret := _m.Called(ctx, imageID)
var r0 []int
- if rf, ok := ret.Get(0).(func(int) []int); ok {
- r0 = rf(imageID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []int); ok {
+ r0 = rf(ctx, imageID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]int)
@@ -176,8 +178,8 @@ func (_m *SceneMarkerReaderWriter) GetTagIDs(imageID int) ([]int, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(imageID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, imageID)
} else {
r1 = ret.Error(1)
}
@@ -185,13 +187,13 @@ func (_m *SceneMarkerReaderWriter) GetTagIDs(imageID int) ([]int, error) {
return r0, r1
}
-// Query provides a mock function with given fields: sceneMarkerFilter, findFilter
-func (_m *SceneMarkerReaderWriter) Query(sceneMarkerFilter *models.SceneMarkerFilterType, findFilter *models.FindFilterType) ([]*models.SceneMarker, int, error) {
- ret := _m.Called(sceneMarkerFilter, findFilter)
+// Query provides a mock function with given fields: ctx, sceneMarkerFilter, findFilter
+func (_m *SceneMarkerReaderWriter) Query(ctx context.Context, sceneMarkerFilter *models.SceneMarkerFilterType, findFilter *models.FindFilterType) ([]*models.SceneMarker, int, error) {
+ ret := _m.Called(ctx, sceneMarkerFilter, findFilter)
var r0 []*models.SceneMarker
- if rf, ok := ret.Get(0).(func(*models.SceneMarkerFilterType, *models.FindFilterType) []*models.SceneMarker); ok {
- r0 = rf(sceneMarkerFilter, findFilter)
+ if rf, ok := ret.Get(0).(func(context.Context, *models.SceneMarkerFilterType, *models.FindFilterType) []*models.SceneMarker); ok {
+ r0 = rf(ctx, sceneMarkerFilter, findFilter)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.SceneMarker)
@@ -199,15 +201,15 @@ func (_m *SceneMarkerReaderWriter) Query(sceneMarkerFilter *models.SceneMarkerFi
}
var r1 int
- if rf, ok := ret.Get(1).(func(*models.SceneMarkerFilterType, *models.FindFilterType) int); ok {
- r1 = rf(sceneMarkerFilter, findFilter)
+ if rf, ok := ret.Get(1).(func(context.Context, *models.SceneMarkerFilterType, *models.FindFilterType) int); ok {
+ r1 = rf(ctx, sceneMarkerFilter, findFilter)
} else {
r1 = ret.Get(1).(int)
}
var r2 error
- if rf, ok := ret.Get(2).(func(*models.SceneMarkerFilterType, *models.FindFilterType) error); ok {
- r2 = rf(sceneMarkerFilter, findFilter)
+ if rf, ok := ret.Get(2).(func(context.Context, *models.SceneMarkerFilterType, *models.FindFilterType) error); ok {
+ r2 = rf(ctx, sceneMarkerFilter, findFilter)
} else {
r2 = ret.Error(2)
}
@@ -215,13 +217,13 @@ func (_m *SceneMarkerReaderWriter) Query(sceneMarkerFilter *models.SceneMarkerFi
return r0, r1, r2
}
-// Update provides a mock function with given fields: updatedSceneMarker
-func (_m *SceneMarkerReaderWriter) Update(updatedSceneMarker models.SceneMarker) (*models.SceneMarker, error) {
- ret := _m.Called(updatedSceneMarker)
+// Update provides a mock function with given fields: ctx, updatedSceneMarker
+func (_m *SceneMarkerReaderWriter) Update(ctx context.Context, updatedSceneMarker models.SceneMarker) (*models.SceneMarker, error) {
+ ret := _m.Called(ctx, updatedSceneMarker)
var r0 *models.SceneMarker
- if rf, ok := ret.Get(0).(func(models.SceneMarker) *models.SceneMarker); ok {
- r0 = rf(updatedSceneMarker)
+ if rf, ok := ret.Get(0).(func(context.Context, models.SceneMarker) *models.SceneMarker); ok {
+ r0 = rf(ctx, updatedSceneMarker)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.SceneMarker)
@@ -229,8 +231,8 @@ func (_m *SceneMarkerReaderWriter) Update(updatedSceneMarker models.SceneMarker)
}
var r1 error
- if rf, ok := ret.Get(1).(func(models.SceneMarker) error); ok {
- r1 = rf(updatedSceneMarker)
+ if rf, ok := ret.Get(1).(func(context.Context, models.SceneMarker) error); ok {
+ r1 = rf(ctx, updatedSceneMarker)
} else {
r1 = ret.Error(1)
}
@@ -238,13 +240,13 @@ func (_m *SceneMarkerReaderWriter) Update(updatedSceneMarker models.SceneMarker)
return r0, r1
}
-// UpdateTags provides a mock function with given fields: markerID, tagIDs
-func (_m *SceneMarkerReaderWriter) UpdateTags(markerID int, tagIDs []int) error {
- ret := _m.Called(markerID, tagIDs)
+// UpdateTags provides a mock function with given fields: ctx, markerID, tagIDs
+func (_m *SceneMarkerReaderWriter) UpdateTags(ctx context.Context, markerID int, tagIDs []int) error {
+ ret := _m.Called(ctx, markerID, tagIDs)
var r0 error
- if rf, ok := ret.Get(0).(func(int, []int) error); ok {
- r0 = rf(markerID, tagIDs)
+ if rf, ok := ret.Get(0).(func(context.Context, int, []int) error); ok {
+ r0 = rf(ctx, markerID, tagIDs)
} else {
r0 = ret.Error(0)
}
@@ -252,13 +254,13 @@ func (_m *SceneMarkerReaderWriter) UpdateTags(markerID int, tagIDs []int) error
return r0
}
-// Wall provides a mock function with given fields: q
-func (_m *SceneMarkerReaderWriter) Wall(q *string) ([]*models.SceneMarker, error) {
- ret := _m.Called(q)
+// Wall provides a mock function with given fields: ctx, q
+func (_m *SceneMarkerReaderWriter) Wall(ctx context.Context, q *string) ([]*models.SceneMarker, error) {
+ ret := _m.Called(ctx, q)
var r0 []*models.SceneMarker
- if rf, ok := ret.Get(0).(func(*string) []*models.SceneMarker); ok {
- r0 = rf(q)
+ if rf, ok := ret.Get(0).(func(context.Context, *string) []*models.SceneMarker); ok {
+ r0 = rf(ctx, q)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.SceneMarker)
@@ -266,8 +268,8 @@ func (_m *SceneMarkerReaderWriter) Wall(q *string) ([]*models.SceneMarker, error
}
var r1 error
- if rf, ok := ret.Get(1).(func(*string) error); ok {
- r1 = rf(q)
+ if rf, ok := ret.Get(1).(func(context.Context, *string) error); ok {
+ r1 = rf(ctx, q)
} else {
r1 = ret.Error(1)
}
diff --git a/pkg/models/mocks/SceneReaderWriter.go b/pkg/models/mocks/SceneReaderWriter.go
index 0635fd200..a9ab69097 100644
--- a/pkg/models/mocks/SceneReaderWriter.go
+++ b/pkg/models/mocks/SceneReaderWriter.go
@@ -3,6 +3,8 @@
package mocks
import (
+ context "context"
+
models "github.com/stashapp/stash/pkg/models"
mock "github.com/stretchr/testify/mock"
)
@@ -12,13 +14,13 @@ type SceneReaderWriter struct {
mock.Mock
}
-// All provides a mock function with given fields:
-func (_m *SceneReaderWriter) All() ([]*models.Scene, error) {
- ret := _m.Called()
+// All provides a mock function with given fields: ctx
+func (_m *SceneReaderWriter) All(ctx context.Context) ([]*models.Scene, error) {
+ ret := _m.Called(ctx)
var r0 []*models.Scene
- if rf, ok := ret.Get(0).(func() []*models.Scene); ok {
- r0 = rf()
+ if rf, ok := ret.Get(0).(func(context.Context) []*models.Scene); ok {
+ r0 = rf(ctx)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Scene)
@@ -26,8 +28,8 @@ func (_m *SceneReaderWriter) All() ([]*models.Scene, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func() error); ok {
- r1 = rf()
+ if rf, ok := ret.Get(1).(func(context.Context) error); ok {
+ r1 = rf(ctx)
} else {
r1 = ret.Error(1)
}
@@ -35,20 +37,20 @@ func (_m *SceneReaderWriter) All() ([]*models.Scene, error) {
return r0, r1
}
-// Count provides a mock function with given fields:
-func (_m *SceneReaderWriter) Count() (int, error) {
- ret := _m.Called()
+// Count provides a mock function with given fields: ctx
+func (_m *SceneReaderWriter) Count(ctx context.Context) (int, error) {
+ ret := _m.Called(ctx)
var r0 int
- if rf, ok := ret.Get(0).(func() int); ok {
- r0 = rf()
+ if rf, ok := ret.Get(0).(func(context.Context) int); ok {
+ r0 = rf(ctx)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
- if rf, ok := ret.Get(1).(func() error); ok {
- r1 = rf()
+ if rf, ok := ret.Get(1).(func(context.Context) error); ok {
+ r1 = rf(ctx)
} else {
r1 = ret.Error(1)
}
@@ -56,20 +58,20 @@ func (_m *SceneReaderWriter) Count() (int, error) {
return r0, r1
}
-// CountByMovieID provides a mock function with given fields: movieID
-func (_m *SceneReaderWriter) CountByMovieID(movieID int) (int, error) {
- ret := _m.Called(movieID)
+// CountByMovieID provides a mock function with given fields: ctx, movieID
+func (_m *SceneReaderWriter) CountByMovieID(ctx context.Context, movieID int) (int, error) {
+ ret := _m.Called(ctx, movieID)
var r0 int
- if rf, ok := ret.Get(0).(func(int) int); ok {
- r0 = rf(movieID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) int); ok {
+ r0 = rf(ctx, movieID)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(movieID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, movieID)
} else {
r1 = ret.Error(1)
}
@@ -77,20 +79,20 @@ func (_m *SceneReaderWriter) CountByMovieID(movieID int) (int, error) {
return r0, r1
}
-// CountByPerformerID provides a mock function with given fields: performerID
-func (_m *SceneReaderWriter) CountByPerformerID(performerID int) (int, error) {
- ret := _m.Called(performerID)
+// CountByPerformerID provides a mock function with given fields: ctx, performerID
+func (_m *SceneReaderWriter) CountByPerformerID(ctx context.Context, performerID int) (int, error) {
+ ret := _m.Called(ctx, performerID)
var r0 int
- if rf, ok := ret.Get(0).(func(int) int); ok {
- r0 = rf(performerID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) int); ok {
+ r0 = rf(ctx, performerID)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(performerID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, performerID)
} else {
r1 = ret.Error(1)
}
@@ -98,20 +100,20 @@ func (_m *SceneReaderWriter) CountByPerformerID(performerID int) (int, error) {
return r0, r1
}
-// CountByStudioID provides a mock function with given fields: studioID
-func (_m *SceneReaderWriter) CountByStudioID(studioID int) (int, error) {
- ret := _m.Called(studioID)
+// CountByStudioID provides a mock function with given fields: ctx, studioID
+func (_m *SceneReaderWriter) CountByStudioID(ctx context.Context, studioID int) (int, error) {
+ ret := _m.Called(ctx, studioID)
var r0 int
- if rf, ok := ret.Get(0).(func(int) int); ok {
- r0 = rf(studioID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) int); ok {
+ r0 = rf(ctx, studioID)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(studioID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, studioID)
} else {
r1 = ret.Error(1)
}
@@ -119,20 +121,20 @@ func (_m *SceneReaderWriter) CountByStudioID(studioID int) (int, error) {
return r0, r1
}
-// CountByTagID provides a mock function with given fields: tagID
-func (_m *SceneReaderWriter) CountByTagID(tagID int) (int, error) {
- ret := _m.Called(tagID)
+// CountByTagID provides a mock function with given fields: ctx, tagID
+func (_m *SceneReaderWriter) CountByTagID(ctx context.Context, tagID int) (int, error) {
+ ret := _m.Called(ctx, tagID)
var r0 int
- if rf, ok := ret.Get(0).(func(int) int); ok {
- r0 = rf(tagID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) int); ok {
+ r0 = rf(ctx, tagID)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(tagID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, tagID)
} else {
r1 = ret.Error(1)
}
@@ -140,20 +142,20 @@ func (_m *SceneReaderWriter) CountByTagID(tagID int) (int, error) {
return r0, r1
}
-// CountMissingChecksum provides a mock function with given fields:
-func (_m *SceneReaderWriter) CountMissingChecksum() (int, error) {
- ret := _m.Called()
+// CountMissingChecksum provides a mock function with given fields: ctx
+func (_m *SceneReaderWriter) CountMissingChecksum(ctx context.Context) (int, error) {
+ ret := _m.Called(ctx)
var r0 int
- if rf, ok := ret.Get(0).(func() int); ok {
- r0 = rf()
+ if rf, ok := ret.Get(0).(func(context.Context) int); ok {
+ r0 = rf(ctx)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
- if rf, ok := ret.Get(1).(func() error); ok {
- r1 = rf()
+ if rf, ok := ret.Get(1).(func(context.Context) error); ok {
+ r1 = rf(ctx)
} else {
r1 = ret.Error(1)
}
@@ -161,20 +163,20 @@ func (_m *SceneReaderWriter) CountMissingChecksum() (int, error) {
return r0, r1
}
-// CountMissingOSHash provides a mock function with given fields:
-func (_m *SceneReaderWriter) CountMissingOSHash() (int, error) {
- ret := _m.Called()
+// CountMissingOSHash provides a mock function with given fields: ctx
+func (_m *SceneReaderWriter) CountMissingOSHash(ctx context.Context) (int, error) {
+ ret := _m.Called(ctx)
var r0 int
- if rf, ok := ret.Get(0).(func() int); ok {
- r0 = rf()
+ if rf, ok := ret.Get(0).(func(context.Context) int); ok {
+ r0 = rf(ctx)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
- if rf, ok := ret.Get(1).(func() error); ok {
- r1 = rf()
+ if rf, ok := ret.Get(1).(func(context.Context) error); ok {
+ r1 = rf(ctx)
} else {
r1 = ret.Error(1)
}
@@ -182,13 +184,13 @@ func (_m *SceneReaderWriter) CountMissingOSHash() (int, error) {
return r0, r1
}
-// Create provides a mock function with given fields: newScene
-func (_m *SceneReaderWriter) Create(newScene models.Scene) (*models.Scene, error) {
- ret := _m.Called(newScene)
+// Create provides a mock function with given fields: ctx, newScene
+func (_m *SceneReaderWriter) Create(ctx context.Context, newScene models.Scene) (*models.Scene, error) {
+ ret := _m.Called(ctx, newScene)
var r0 *models.Scene
- if rf, ok := ret.Get(0).(func(models.Scene) *models.Scene); ok {
- r0 = rf(newScene)
+ if rf, ok := ret.Get(0).(func(context.Context, models.Scene) *models.Scene); ok {
+ r0 = rf(ctx, newScene)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Scene)
@@ -196,8 +198,8 @@ func (_m *SceneReaderWriter) Create(newScene models.Scene) (*models.Scene, error
}
var r1 error
- if rf, ok := ret.Get(1).(func(models.Scene) error); ok {
- r1 = rf(newScene)
+ if rf, ok := ret.Get(1).(func(context.Context, models.Scene) error); ok {
+ r1 = rf(ctx, newScene)
} else {
r1 = ret.Error(1)
}
@@ -205,20 +207,20 @@ func (_m *SceneReaderWriter) Create(newScene models.Scene) (*models.Scene, error
return r0, r1
}
-// DecrementOCounter provides a mock function with given fields: id
-func (_m *SceneReaderWriter) DecrementOCounter(id int) (int, error) {
- ret := _m.Called(id)
+// DecrementOCounter provides a mock function with given fields: ctx, id
+func (_m *SceneReaderWriter) DecrementOCounter(ctx context.Context, id int) (int, error) {
+ ret := _m.Called(ctx, id)
var r0 int
- if rf, ok := ret.Get(0).(func(int) int); ok {
- r0 = rf(id)
+ if rf, ok := ret.Get(0).(func(context.Context, int) int); ok {
+ r0 = rf(ctx, id)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(id)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, id)
} else {
r1 = ret.Error(1)
}
@@ -226,13 +228,13 @@ func (_m *SceneReaderWriter) DecrementOCounter(id int) (int, error) {
return r0, r1
}
-// Destroy provides a mock function with given fields: id
-func (_m *SceneReaderWriter) Destroy(id int) error {
- ret := _m.Called(id)
+// Destroy provides a mock function with given fields: ctx, id
+func (_m *SceneReaderWriter) Destroy(ctx context.Context, id int) error {
+ ret := _m.Called(ctx, id)
var r0 error
- if rf, ok := ret.Get(0).(func(int) error); ok {
- r0 = rf(id)
+ if rf, ok := ret.Get(0).(func(context.Context, int) error); ok {
+ r0 = rf(ctx, id)
} else {
r0 = ret.Error(0)
}
@@ -240,13 +242,13 @@ func (_m *SceneReaderWriter) Destroy(id int) error {
return r0
}
-// DestroyCover provides a mock function with given fields: sceneID
-func (_m *SceneReaderWriter) DestroyCover(sceneID int) error {
- ret := _m.Called(sceneID)
+// DestroyCover provides a mock function with given fields: ctx, sceneID
+func (_m *SceneReaderWriter) DestroyCover(ctx context.Context, sceneID int) error {
+ ret := _m.Called(ctx, sceneID)
var r0 error
- if rf, ok := ret.Get(0).(func(int) error); ok {
- r0 = rf(sceneID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) error); ok {
+ r0 = rf(ctx, sceneID)
} else {
r0 = ret.Error(0)
}
@@ -254,20 +256,20 @@ func (_m *SceneReaderWriter) DestroyCover(sceneID int) error {
return r0
}
-// Duration provides a mock function with given fields:
-func (_m *SceneReaderWriter) Duration() (float64, error) {
- ret := _m.Called()
+// Duration provides a mock function with given fields: ctx
+func (_m *SceneReaderWriter) Duration(ctx context.Context) (float64, error) {
+ ret := _m.Called(ctx)
var r0 float64
- if rf, ok := ret.Get(0).(func() float64); ok {
- r0 = rf()
+ if rf, ok := ret.Get(0).(func(context.Context) float64); ok {
+ r0 = rf(ctx)
} else {
r0 = ret.Get(0).(float64)
}
var r1 error
- if rf, ok := ret.Get(1).(func() error); ok {
- r1 = rf()
+ if rf, ok := ret.Get(1).(func(context.Context) error); ok {
+ r1 = rf(ctx)
} else {
r1 = ret.Error(1)
}
@@ -275,13 +277,13 @@ func (_m *SceneReaderWriter) Duration() (float64, error) {
return r0, r1
}
-// Find provides a mock function with given fields: id
-func (_m *SceneReaderWriter) Find(id int) (*models.Scene, error) {
- ret := _m.Called(id)
+// Find provides a mock function with given fields: ctx, id
+func (_m *SceneReaderWriter) Find(ctx context.Context, id int) (*models.Scene, error) {
+ ret := _m.Called(ctx, id)
var r0 *models.Scene
- if rf, ok := ret.Get(0).(func(int) *models.Scene); ok {
- r0 = rf(id)
+ if rf, ok := ret.Get(0).(func(context.Context, int) *models.Scene); ok {
+ r0 = rf(ctx, id)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Scene)
@@ -289,8 +291,8 @@ func (_m *SceneReaderWriter) Find(id int) (*models.Scene, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(id)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, id)
} else {
r1 = ret.Error(1)
}
@@ -298,13 +300,13 @@ func (_m *SceneReaderWriter) Find(id int) (*models.Scene, error) {
return r0, r1
}
-// FindByChecksum provides a mock function with given fields: checksum
-func (_m *SceneReaderWriter) FindByChecksum(checksum string) (*models.Scene, error) {
- ret := _m.Called(checksum)
+// FindByChecksum provides a mock function with given fields: ctx, checksum
+func (_m *SceneReaderWriter) FindByChecksum(ctx context.Context, checksum string) (*models.Scene, error) {
+ ret := _m.Called(ctx, checksum)
var r0 *models.Scene
- if rf, ok := ret.Get(0).(func(string) *models.Scene); ok {
- r0 = rf(checksum)
+ if rf, ok := ret.Get(0).(func(context.Context, string) *models.Scene); ok {
+ r0 = rf(ctx, checksum)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Scene)
@@ -312,8 +314,8 @@ func (_m *SceneReaderWriter) FindByChecksum(checksum string) (*models.Scene, err
}
var r1 error
- if rf, ok := ret.Get(1).(func(string) error); ok {
- r1 = rf(checksum)
+ if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
+ r1 = rf(ctx, checksum)
} else {
r1 = ret.Error(1)
}
@@ -321,13 +323,13 @@ func (_m *SceneReaderWriter) FindByChecksum(checksum string) (*models.Scene, err
return r0, r1
}
-// FindByGalleryID provides a mock function with given fields: performerID
-func (_m *SceneReaderWriter) FindByGalleryID(performerID int) ([]*models.Scene, error) {
- ret := _m.Called(performerID)
+// FindByGalleryID provides a mock function with given fields: ctx, performerID
+func (_m *SceneReaderWriter) FindByGalleryID(ctx context.Context, performerID int) ([]*models.Scene, error) {
+ ret := _m.Called(ctx, performerID)
var r0 []*models.Scene
- if rf, ok := ret.Get(0).(func(int) []*models.Scene); ok {
- r0 = rf(performerID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Scene); ok {
+ r0 = rf(ctx, performerID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Scene)
@@ -335,8 +337,8 @@ func (_m *SceneReaderWriter) FindByGalleryID(performerID int) ([]*models.Scene,
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(performerID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, performerID)
} else {
r1 = ret.Error(1)
}
@@ -344,13 +346,13 @@ func (_m *SceneReaderWriter) FindByGalleryID(performerID int) ([]*models.Scene,
return r0, r1
}
-// FindByMovieID provides a mock function with given fields: movieID
-func (_m *SceneReaderWriter) FindByMovieID(movieID int) ([]*models.Scene, error) {
- ret := _m.Called(movieID)
+// FindByMovieID provides a mock function with given fields: ctx, movieID
+func (_m *SceneReaderWriter) FindByMovieID(ctx context.Context, movieID int) ([]*models.Scene, error) {
+ ret := _m.Called(ctx, movieID)
var r0 []*models.Scene
- if rf, ok := ret.Get(0).(func(int) []*models.Scene); ok {
- r0 = rf(movieID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Scene); ok {
+ r0 = rf(ctx, movieID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Scene)
@@ -358,8 +360,8 @@ func (_m *SceneReaderWriter) FindByMovieID(movieID int) ([]*models.Scene, error)
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(movieID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, movieID)
} else {
r1 = ret.Error(1)
}
@@ -367,13 +369,13 @@ func (_m *SceneReaderWriter) FindByMovieID(movieID int) ([]*models.Scene, error)
return r0, r1
}
-// FindByOSHash provides a mock function with given fields: oshash
-func (_m *SceneReaderWriter) FindByOSHash(oshash string) (*models.Scene, error) {
- ret := _m.Called(oshash)
+// FindByOSHash provides a mock function with given fields: ctx, oshash
+func (_m *SceneReaderWriter) FindByOSHash(ctx context.Context, oshash string) (*models.Scene, error) {
+ ret := _m.Called(ctx, oshash)
var r0 *models.Scene
- if rf, ok := ret.Get(0).(func(string) *models.Scene); ok {
- r0 = rf(oshash)
+ if rf, ok := ret.Get(0).(func(context.Context, string) *models.Scene); ok {
+ r0 = rf(ctx, oshash)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Scene)
@@ -381,8 +383,8 @@ func (_m *SceneReaderWriter) FindByOSHash(oshash string) (*models.Scene, error)
}
var r1 error
- if rf, ok := ret.Get(1).(func(string) error); ok {
- r1 = rf(oshash)
+ if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
+ r1 = rf(ctx, oshash)
} else {
r1 = ret.Error(1)
}
@@ -390,13 +392,13 @@ func (_m *SceneReaderWriter) FindByOSHash(oshash string) (*models.Scene, error)
return r0, r1
}
-// FindByPath provides a mock function with given fields: path
-func (_m *SceneReaderWriter) FindByPath(path string) (*models.Scene, error) {
- ret := _m.Called(path)
+// FindByPath provides a mock function with given fields: ctx, path
+func (_m *SceneReaderWriter) FindByPath(ctx context.Context, path string) (*models.Scene, error) {
+ ret := _m.Called(ctx, path)
var r0 *models.Scene
- if rf, ok := ret.Get(0).(func(string) *models.Scene); ok {
- r0 = rf(path)
+ if rf, ok := ret.Get(0).(func(context.Context, string) *models.Scene); ok {
+ r0 = rf(ctx, path)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Scene)
@@ -404,8 +406,8 @@ func (_m *SceneReaderWriter) FindByPath(path string) (*models.Scene, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(string) error); ok {
- r1 = rf(path)
+ if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
+ r1 = rf(ctx, path)
} else {
r1 = ret.Error(1)
}
@@ -413,13 +415,13 @@ func (_m *SceneReaderWriter) FindByPath(path string) (*models.Scene, error) {
return r0, r1
}
-// FindByPerformerID provides a mock function with given fields: performerID
-func (_m *SceneReaderWriter) FindByPerformerID(performerID int) ([]*models.Scene, error) {
- ret := _m.Called(performerID)
+// FindByPerformerID provides a mock function with given fields: ctx, performerID
+func (_m *SceneReaderWriter) FindByPerformerID(ctx context.Context, performerID int) ([]*models.Scene, error) {
+ ret := _m.Called(ctx, performerID)
var r0 []*models.Scene
- if rf, ok := ret.Get(0).(func(int) []*models.Scene); ok {
- r0 = rf(performerID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Scene); ok {
+ r0 = rf(ctx, performerID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Scene)
@@ -427,8 +429,8 @@ func (_m *SceneReaderWriter) FindByPerformerID(performerID int) ([]*models.Scene
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(performerID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, performerID)
} else {
r1 = ret.Error(1)
}
@@ -436,13 +438,13 @@ func (_m *SceneReaderWriter) FindByPerformerID(performerID int) ([]*models.Scene
return r0, r1
}
-// FindDuplicates provides a mock function with given fields: distance
-func (_m *SceneReaderWriter) FindDuplicates(distance int) ([][]*models.Scene, error) {
- ret := _m.Called(distance)
+// FindDuplicates provides a mock function with given fields: ctx, distance
+func (_m *SceneReaderWriter) FindDuplicates(ctx context.Context, distance int) ([][]*models.Scene, error) {
+ ret := _m.Called(ctx, distance)
var r0 [][]*models.Scene
- if rf, ok := ret.Get(0).(func(int) [][]*models.Scene); ok {
- r0 = rf(distance)
+ if rf, ok := ret.Get(0).(func(context.Context, int) [][]*models.Scene); ok {
+ r0 = rf(ctx, distance)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([][]*models.Scene)
@@ -450,8 +452,8 @@ func (_m *SceneReaderWriter) FindDuplicates(distance int) ([][]*models.Scene, er
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(distance)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, distance)
} else {
r1 = ret.Error(1)
}
@@ -459,13 +461,13 @@ func (_m *SceneReaderWriter) FindDuplicates(distance int) ([][]*models.Scene, er
return r0, r1
}
-// FindMany provides a mock function with given fields: ids
-func (_m *SceneReaderWriter) FindMany(ids []int) ([]*models.Scene, error) {
- ret := _m.Called(ids)
+// FindMany provides a mock function with given fields: ctx, ids
+func (_m *SceneReaderWriter) FindMany(ctx context.Context, ids []int) ([]*models.Scene, error) {
+ ret := _m.Called(ctx, ids)
var r0 []*models.Scene
- if rf, ok := ret.Get(0).(func([]int) []*models.Scene); ok {
- r0 = rf(ids)
+ if rf, ok := ret.Get(0).(func(context.Context, []int) []*models.Scene); ok {
+ r0 = rf(ctx, ids)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Scene)
@@ -473,8 +475,8 @@ func (_m *SceneReaderWriter) FindMany(ids []int) ([]*models.Scene, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func([]int) error); ok {
- r1 = rf(ids)
+ if rf, ok := ret.Get(1).(func(context.Context, []int) error); ok {
+ r1 = rf(ctx, ids)
} else {
r1 = ret.Error(1)
}
@@ -482,13 +484,13 @@ func (_m *SceneReaderWriter) FindMany(ids []int) ([]*models.Scene, error) {
return r0, r1
}
-// GetCaptions provides a mock function with given fields: sceneID
-func (_m *SceneReaderWriter) GetCaptions(sceneID int) ([]*models.SceneCaption, error) {
- ret := _m.Called(sceneID)
+// GetCaptions provides a mock function with given fields: ctx, sceneID
+func (_m *SceneReaderWriter) GetCaptions(ctx context.Context, sceneID int) ([]*models.SceneCaption, error) {
+ ret := _m.Called(ctx, sceneID)
var r0 []*models.SceneCaption
- if rf, ok := ret.Get(0).(func(int) []*models.SceneCaption); ok {
- r0 = rf(sceneID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []*models.SceneCaption); ok {
+ r0 = rf(ctx, sceneID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.SceneCaption)
@@ -496,8 +498,8 @@ func (_m *SceneReaderWriter) GetCaptions(sceneID int) ([]*models.SceneCaption, e
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(sceneID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, sceneID)
} else {
r1 = ret.Error(1)
}
@@ -505,13 +507,13 @@ func (_m *SceneReaderWriter) GetCaptions(sceneID int) ([]*models.SceneCaption, e
return r0, r1
}
-// GetCover provides a mock function with given fields: sceneID
-func (_m *SceneReaderWriter) GetCover(sceneID int) ([]byte, error) {
- ret := _m.Called(sceneID)
+// GetCover provides a mock function with given fields: ctx, sceneID
+func (_m *SceneReaderWriter) GetCover(ctx context.Context, sceneID int) ([]byte, error) {
+ ret := _m.Called(ctx, sceneID)
var r0 []byte
- if rf, ok := ret.Get(0).(func(int) []byte); ok {
- r0 = rf(sceneID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []byte); ok {
+ r0 = rf(ctx, sceneID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]byte)
@@ -519,8 +521,8 @@ func (_m *SceneReaderWriter) GetCover(sceneID int) ([]byte, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(sceneID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, sceneID)
} else {
r1 = ret.Error(1)
}
@@ -528,13 +530,13 @@ func (_m *SceneReaderWriter) GetCover(sceneID int) ([]byte, error) {
return r0, r1
}
-// GetGalleryIDs provides a mock function with given fields: sceneID
-func (_m *SceneReaderWriter) GetGalleryIDs(sceneID int) ([]int, error) {
- ret := _m.Called(sceneID)
+// GetGalleryIDs provides a mock function with given fields: ctx, sceneID
+func (_m *SceneReaderWriter) GetGalleryIDs(ctx context.Context, sceneID int) ([]int, error) {
+ ret := _m.Called(ctx, sceneID)
var r0 []int
- if rf, ok := ret.Get(0).(func(int) []int); ok {
- r0 = rf(sceneID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []int); ok {
+ r0 = rf(ctx, sceneID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]int)
@@ -542,8 +544,8 @@ func (_m *SceneReaderWriter) GetGalleryIDs(sceneID int) ([]int, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(sceneID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, sceneID)
} else {
r1 = ret.Error(1)
}
@@ -551,13 +553,13 @@ func (_m *SceneReaderWriter) GetGalleryIDs(sceneID int) ([]int, error) {
return r0, r1
}
-// GetMovies provides a mock function with given fields: sceneID
-func (_m *SceneReaderWriter) GetMovies(sceneID int) ([]models.MoviesScenes, error) {
- ret := _m.Called(sceneID)
+// GetMovies provides a mock function with given fields: ctx, sceneID
+func (_m *SceneReaderWriter) GetMovies(ctx context.Context, sceneID int) ([]models.MoviesScenes, error) {
+ ret := _m.Called(ctx, sceneID)
var r0 []models.MoviesScenes
- if rf, ok := ret.Get(0).(func(int) []models.MoviesScenes); ok {
- r0 = rf(sceneID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []models.MoviesScenes); ok {
+ r0 = rf(ctx, sceneID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]models.MoviesScenes)
@@ -565,8 +567,8 @@ func (_m *SceneReaderWriter) GetMovies(sceneID int) ([]models.MoviesScenes, erro
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(sceneID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, sceneID)
} else {
r1 = ret.Error(1)
}
@@ -574,13 +576,13 @@ func (_m *SceneReaderWriter) GetMovies(sceneID int) ([]models.MoviesScenes, erro
return r0, r1
}
-// GetPerformerIDs provides a mock function with given fields: sceneID
-func (_m *SceneReaderWriter) GetPerformerIDs(sceneID int) ([]int, error) {
- ret := _m.Called(sceneID)
+// GetPerformerIDs provides a mock function with given fields: ctx, sceneID
+func (_m *SceneReaderWriter) GetPerformerIDs(ctx context.Context, sceneID int) ([]int, error) {
+ ret := _m.Called(ctx, sceneID)
var r0 []int
- if rf, ok := ret.Get(0).(func(int) []int); ok {
- r0 = rf(sceneID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []int); ok {
+ r0 = rf(ctx, sceneID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]int)
@@ -588,8 +590,8 @@ func (_m *SceneReaderWriter) GetPerformerIDs(sceneID int) ([]int, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(sceneID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, sceneID)
} else {
r1 = ret.Error(1)
}
@@ -597,13 +599,13 @@ func (_m *SceneReaderWriter) GetPerformerIDs(sceneID int) ([]int, error) {
return r0, r1
}
-// GetStashIDs provides a mock function with given fields: sceneID
-func (_m *SceneReaderWriter) GetStashIDs(sceneID int) ([]*models.StashID, error) {
- ret := _m.Called(sceneID)
+// GetStashIDs provides a mock function with given fields: ctx, sceneID
+func (_m *SceneReaderWriter) GetStashIDs(ctx context.Context, sceneID int) ([]*models.StashID, error) {
+ ret := _m.Called(ctx, sceneID)
var r0 []*models.StashID
- if rf, ok := ret.Get(0).(func(int) []*models.StashID); ok {
- r0 = rf(sceneID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []*models.StashID); ok {
+ r0 = rf(ctx, sceneID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.StashID)
@@ -611,8 +613,8 @@ func (_m *SceneReaderWriter) GetStashIDs(sceneID int) ([]*models.StashID, error)
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(sceneID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, sceneID)
} else {
r1 = ret.Error(1)
}
@@ -620,13 +622,13 @@ func (_m *SceneReaderWriter) GetStashIDs(sceneID int) ([]*models.StashID, error)
return r0, r1
}
-// GetTagIDs provides a mock function with given fields: sceneID
-func (_m *SceneReaderWriter) GetTagIDs(sceneID int) ([]int, error) {
- ret := _m.Called(sceneID)
+// GetTagIDs provides a mock function with given fields: ctx, sceneID
+func (_m *SceneReaderWriter) GetTagIDs(ctx context.Context, sceneID int) ([]int, error) {
+ ret := _m.Called(ctx, sceneID)
var r0 []int
- if rf, ok := ret.Get(0).(func(int) []int); ok {
- r0 = rf(sceneID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []int); ok {
+ r0 = rf(ctx, sceneID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]int)
@@ -634,8 +636,8 @@ func (_m *SceneReaderWriter) GetTagIDs(sceneID int) ([]int, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(sceneID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, sceneID)
} else {
r1 = ret.Error(1)
}
@@ -643,20 +645,20 @@ func (_m *SceneReaderWriter) GetTagIDs(sceneID int) ([]int, error) {
return r0, r1
}
-// IncrementOCounter provides a mock function with given fields: id
-func (_m *SceneReaderWriter) IncrementOCounter(id int) (int, error) {
- ret := _m.Called(id)
+// IncrementOCounter provides a mock function with given fields: ctx, id
+func (_m *SceneReaderWriter) IncrementOCounter(ctx context.Context, id int) (int, error) {
+ ret := _m.Called(ctx, id)
var r0 int
- if rf, ok := ret.Get(0).(func(int) int); ok {
- r0 = rf(id)
+ if rf, ok := ret.Get(0).(func(context.Context, int) int); ok {
+ r0 = rf(ctx, id)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(id)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, id)
} else {
r1 = ret.Error(1)
}
@@ -664,13 +666,13 @@ func (_m *SceneReaderWriter) IncrementOCounter(id int) (int, error) {
return r0, r1
}
-// Query provides a mock function with given fields: options
-func (_m *SceneReaderWriter) Query(options models.SceneQueryOptions) (*models.SceneQueryResult, error) {
- ret := _m.Called(options)
+// Query provides a mock function with given fields: ctx, options
+func (_m *SceneReaderWriter) Query(ctx context.Context, options models.SceneQueryOptions) (*models.SceneQueryResult, error) {
+ ret := _m.Called(ctx, options)
var r0 *models.SceneQueryResult
- if rf, ok := ret.Get(0).(func(models.SceneQueryOptions) *models.SceneQueryResult); ok {
- r0 = rf(options)
+ if rf, ok := ret.Get(0).(func(context.Context, models.SceneQueryOptions) *models.SceneQueryResult); ok {
+ r0 = rf(ctx, options)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.SceneQueryResult)
@@ -678,8 +680,8 @@ func (_m *SceneReaderWriter) Query(options models.SceneQueryOptions) (*models.Sc
}
var r1 error
- if rf, ok := ret.Get(1).(func(models.SceneQueryOptions) error); ok {
- r1 = rf(options)
+ if rf, ok := ret.Get(1).(func(context.Context, models.SceneQueryOptions) error); ok {
+ r1 = rf(ctx, options)
} else {
r1 = ret.Error(1)
}
@@ -687,20 +689,20 @@ func (_m *SceneReaderWriter) Query(options models.SceneQueryOptions) (*models.Sc
return r0, r1
}
-// ResetOCounter provides a mock function with given fields: id
-func (_m *SceneReaderWriter) ResetOCounter(id int) (int, error) {
- ret := _m.Called(id)
+// ResetOCounter provides a mock function with given fields: ctx, id
+func (_m *SceneReaderWriter) ResetOCounter(ctx context.Context, id int) (int, error) {
+ ret := _m.Called(ctx, id)
var r0 int
- if rf, ok := ret.Get(0).(func(int) int); ok {
- r0 = rf(id)
+ if rf, ok := ret.Get(0).(func(context.Context, int) int); ok {
+ r0 = rf(ctx, id)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(id)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, id)
} else {
r1 = ret.Error(1)
}
@@ -708,20 +710,20 @@ func (_m *SceneReaderWriter) ResetOCounter(id int) (int, error) {
return r0, r1
}
-// Size provides a mock function with given fields:
-func (_m *SceneReaderWriter) Size() (float64, error) {
- ret := _m.Called()
+// Size provides a mock function with given fields: ctx
+func (_m *SceneReaderWriter) Size(ctx context.Context) (float64, error) {
+ ret := _m.Called(ctx)
var r0 float64
- if rf, ok := ret.Get(0).(func() float64); ok {
- r0 = rf()
+ if rf, ok := ret.Get(0).(func(context.Context) float64); ok {
+ r0 = rf(ctx)
} else {
r0 = ret.Get(0).(float64)
}
var r1 error
- if rf, ok := ret.Get(1).(func() error); ok {
- r1 = rf()
+ if rf, ok := ret.Get(1).(func(context.Context) error); ok {
+ r1 = rf(ctx)
} else {
r1 = ret.Error(1)
}
@@ -729,13 +731,13 @@ func (_m *SceneReaderWriter) Size() (float64, error) {
return r0, r1
}
-// Update provides a mock function with given fields: updatedScene
-func (_m *SceneReaderWriter) Update(updatedScene models.ScenePartial) (*models.Scene, error) {
- ret := _m.Called(updatedScene)
+// Update provides a mock function with given fields: ctx, updatedScene
+func (_m *SceneReaderWriter) Update(ctx context.Context, updatedScene models.ScenePartial) (*models.Scene, error) {
+ ret := _m.Called(ctx, updatedScene)
var r0 *models.Scene
- if rf, ok := ret.Get(0).(func(models.ScenePartial) *models.Scene); ok {
- r0 = rf(updatedScene)
+ if rf, ok := ret.Get(0).(func(context.Context, models.ScenePartial) *models.Scene); ok {
+ r0 = rf(ctx, updatedScene)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Scene)
@@ -743,8 +745,8 @@ func (_m *SceneReaderWriter) Update(updatedScene models.ScenePartial) (*models.S
}
var r1 error
- if rf, ok := ret.Get(1).(func(models.ScenePartial) error); ok {
- r1 = rf(updatedScene)
+ if rf, ok := ret.Get(1).(func(context.Context, models.ScenePartial) error); ok {
+ r1 = rf(ctx, updatedScene)
} else {
r1 = ret.Error(1)
}
@@ -752,13 +754,13 @@ func (_m *SceneReaderWriter) Update(updatedScene models.ScenePartial) (*models.S
return r0, r1
}
-// UpdateCaptions provides a mock function with given fields: id, captions
-func (_m *SceneReaderWriter) UpdateCaptions(id int, captions []*models.SceneCaption) error {
- ret := _m.Called(id, captions)
+// UpdateCaptions provides a mock function with given fields: ctx, id, captions
+func (_m *SceneReaderWriter) UpdateCaptions(ctx context.Context, id int, captions []*models.SceneCaption) error {
+ ret := _m.Called(ctx, id, captions)
var r0 error
- if rf, ok := ret.Get(0).(func(int, []*models.SceneCaption) error); ok {
- r0 = rf(id, captions)
+ if rf, ok := ret.Get(0).(func(context.Context, int, []*models.SceneCaption) error); ok {
+ r0 = rf(ctx, id, captions)
} else {
r0 = ret.Error(0)
}
@@ -766,13 +768,13 @@ func (_m *SceneReaderWriter) UpdateCaptions(id int, captions []*models.SceneCapt
return r0
}
-// UpdateCover provides a mock function with given fields: sceneID, cover
-func (_m *SceneReaderWriter) UpdateCover(sceneID int, cover []byte) error {
- ret := _m.Called(sceneID, cover)
+// UpdateCover provides a mock function with given fields: ctx, sceneID, cover
+func (_m *SceneReaderWriter) UpdateCover(ctx context.Context, sceneID int, cover []byte) error {
+ ret := _m.Called(ctx, sceneID, cover)
var r0 error
- if rf, ok := ret.Get(0).(func(int, []byte) error); ok {
- r0 = rf(sceneID, cover)
+ if rf, ok := ret.Get(0).(func(context.Context, int, []byte) error); ok {
+ r0 = rf(ctx, sceneID, cover)
} else {
r0 = ret.Error(0)
}
@@ -780,13 +782,13 @@ func (_m *SceneReaderWriter) UpdateCover(sceneID int, cover []byte) error {
return r0
}
-// UpdateFileModTime provides a mock function with given fields: id, modTime
-func (_m *SceneReaderWriter) UpdateFileModTime(id int, modTime models.NullSQLiteTimestamp) error {
- ret := _m.Called(id, modTime)
+// UpdateFileModTime provides a mock function with given fields: ctx, id, modTime
+func (_m *SceneReaderWriter) UpdateFileModTime(ctx context.Context, id int, modTime models.NullSQLiteTimestamp) error {
+ ret := _m.Called(ctx, id, modTime)
var r0 error
- if rf, ok := ret.Get(0).(func(int, models.NullSQLiteTimestamp) error); ok {
- r0 = rf(id, modTime)
+ if rf, ok := ret.Get(0).(func(context.Context, int, models.NullSQLiteTimestamp) error); ok {
+ r0 = rf(ctx, id, modTime)
} else {
r0 = ret.Error(0)
}
@@ -794,13 +796,13 @@ func (_m *SceneReaderWriter) UpdateFileModTime(id int, modTime models.NullSQLite
return r0
}
-// UpdateFull provides a mock function with given fields: updatedScene
-func (_m *SceneReaderWriter) UpdateFull(updatedScene models.Scene) (*models.Scene, error) {
- ret := _m.Called(updatedScene)
+// UpdateFull provides a mock function with given fields: ctx, updatedScene
+func (_m *SceneReaderWriter) UpdateFull(ctx context.Context, updatedScene models.Scene) (*models.Scene, error) {
+ ret := _m.Called(ctx, updatedScene)
var r0 *models.Scene
- if rf, ok := ret.Get(0).(func(models.Scene) *models.Scene); ok {
- r0 = rf(updatedScene)
+ if rf, ok := ret.Get(0).(func(context.Context, models.Scene) *models.Scene); ok {
+ r0 = rf(ctx, updatedScene)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Scene)
@@ -808,8 +810,8 @@ func (_m *SceneReaderWriter) UpdateFull(updatedScene models.Scene) (*models.Scen
}
var r1 error
- if rf, ok := ret.Get(1).(func(models.Scene) error); ok {
- r1 = rf(updatedScene)
+ if rf, ok := ret.Get(1).(func(context.Context, models.Scene) error); ok {
+ r1 = rf(ctx, updatedScene)
} else {
r1 = ret.Error(1)
}
@@ -817,13 +819,13 @@ func (_m *SceneReaderWriter) UpdateFull(updatedScene models.Scene) (*models.Scen
return r0, r1
}
-// UpdateGalleries provides a mock function with given fields: sceneID, galleryIDs
-func (_m *SceneReaderWriter) UpdateGalleries(sceneID int, galleryIDs []int) error {
- ret := _m.Called(sceneID, galleryIDs)
+// UpdateGalleries provides a mock function with given fields: ctx, sceneID, galleryIDs
+func (_m *SceneReaderWriter) UpdateGalleries(ctx context.Context, sceneID int, galleryIDs []int) error {
+ ret := _m.Called(ctx, sceneID, galleryIDs)
var r0 error
- if rf, ok := ret.Get(0).(func(int, []int) error); ok {
- r0 = rf(sceneID, galleryIDs)
+ if rf, ok := ret.Get(0).(func(context.Context, int, []int) error); ok {
+ r0 = rf(ctx, sceneID, galleryIDs)
} else {
r0 = ret.Error(0)
}
@@ -831,13 +833,13 @@ func (_m *SceneReaderWriter) UpdateGalleries(sceneID int, galleryIDs []int) erro
return r0
}
-// UpdateMovies provides a mock function with given fields: sceneID, movies
-func (_m *SceneReaderWriter) UpdateMovies(sceneID int, movies []models.MoviesScenes) error {
- ret := _m.Called(sceneID, movies)
+// UpdateMovies provides a mock function with given fields: ctx, sceneID, movies
+func (_m *SceneReaderWriter) UpdateMovies(ctx context.Context, sceneID int, movies []models.MoviesScenes) error {
+ ret := _m.Called(ctx, sceneID, movies)
var r0 error
- if rf, ok := ret.Get(0).(func(int, []models.MoviesScenes) error); ok {
- r0 = rf(sceneID, movies)
+ if rf, ok := ret.Get(0).(func(context.Context, int, []models.MoviesScenes) error); ok {
+ r0 = rf(ctx, sceneID, movies)
} else {
r0 = ret.Error(0)
}
@@ -845,13 +847,13 @@ func (_m *SceneReaderWriter) UpdateMovies(sceneID int, movies []models.MoviesSce
return r0
}
-// UpdatePerformers provides a mock function with given fields: sceneID, performerIDs
-func (_m *SceneReaderWriter) UpdatePerformers(sceneID int, performerIDs []int) error {
- ret := _m.Called(sceneID, performerIDs)
+// UpdatePerformers provides a mock function with given fields: ctx, sceneID, performerIDs
+func (_m *SceneReaderWriter) UpdatePerformers(ctx context.Context, sceneID int, performerIDs []int) error {
+ ret := _m.Called(ctx, sceneID, performerIDs)
var r0 error
- if rf, ok := ret.Get(0).(func(int, []int) error); ok {
- r0 = rf(sceneID, performerIDs)
+ if rf, ok := ret.Get(0).(func(context.Context, int, []int) error); ok {
+ r0 = rf(ctx, sceneID, performerIDs)
} else {
r0 = ret.Error(0)
}
@@ -859,13 +861,13 @@ func (_m *SceneReaderWriter) UpdatePerformers(sceneID int, performerIDs []int) e
return r0
}
-// UpdateStashIDs provides a mock function with given fields: sceneID, stashIDs
-func (_m *SceneReaderWriter) UpdateStashIDs(sceneID int, stashIDs []models.StashID) error {
- ret := _m.Called(sceneID, stashIDs)
+// UpdateStashIDs provides a mock function with given fields: ctx, sceneID, stashIDs
+func (_m *SceneReaderWriter) UpdateStashIDs(ctx context.Context, sceneID int, stashIDs []models.StashID) error {
+ ret := _m.Called(ctx, sceneID, stashIDs)
var r0 error
- if rf, ok := ret.Get(0).(func(int, []models.StashID) error); ok {
- r0 = rf(sceneID, stashIDs)
+ if rf, ok := ret.Get(0).(func(context.Context, int, []models.StashID) error); ok {
+ r0 = rf(ctx, sceneID, stashIDs)
} else {
r0 = ret.Error(0)
}
@@ -873,13 +875,13 @@ func (_m *SceneReaderWriter) UpdateStashIDs(sceneID int, stashIDs []models.Stash
return r0
}
-// UpdateTags provides a mock function with given fields: sceneID, tagIDs
-func (_m *SceneReaderWriter) UpdateTags(sceneID int, tagIDs []int) error {
- ret := _m.Called(sceneID, tagIDs)
+// UpdateTags provides a mock function with given fields: ctx, sceneID, tagIDs
+func (_m *SceneReaderWriter) UpdateTags(ctx context.Context, sceneID int, tagIDs []int) error {
+ ret := _m.Called(ctx, sceneID, tagIDs)
var r0 error
- if rf, ok := ret.Get(0).(func(int, []int) error); ok {
- r0 = rf(sceneID, tagIDs)
+ if rf, ok := ret.Get(0).(func(context.Context, int, []int) error); ok {
+ r0 = rf(ctx, sceneID, tagIDs)
} else {
r0 = ret.Error(0)
}
@@ -887,13 +889,13 @@ func (_m *SceneReaderWriter) UpdateTags(sceneID int, tagIDs []int) error {
return r0
}
-// Wall provides a mock function with given fields: q
-func (_m *SceneReaderWriter) Wall(q *string) ([]*models.Scene, error) {
- ret := _m.Called(q)
+// Wall provides a mock function with given fields: ctx, q
+func (_m *SceneReaderWriter) Wall(ctx context.Context, q *string) ([]*models.Scene, error) {
+ ret := _m.Called(ctx, q)
var r0 []*models.Scene
- if rf, ok := ret.Get(0).(func(*string) []*models.Scene); ok {
- r0 = rf(q)
+ if rf, ok := ret.Get(0).(func(context.Context, *string) []*models.Scene); ok {
+ r0 = rf(ctx, q)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Scene)
@@ -901,8 +903,8 @@ func (_m *SceneReaderWriter) Wall(q *string) ([]*models.Scene, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(*string) error); ok {
- r1 = rf(q)
+ if rf, ok := ret.Get(1).(func(context.Context, *string) error); ok {
+ r1 = rf(ctx, q)
} else {
r1 = ret.Error(1)
}
diff --git a/pkg/models/mocks/ScrapedItemReaderWriter.go b/pkg/models/mocks/ScrapedItemReaderWriter.go
index e06b7451d..7157ab855 100644
--- a/pkg/models/mocks/ScrapedItemReaderWriter.go
+++ b/pkg/models/mocks/ScrapedItemReaderWriter.go
@@ -3,6 +3,8 @@
package mocks
import (
+ context "context"
+
models "github.com/stashapp/stash/pkg/models"
mock "github.com/stretchr/testify/mock"
)
@@ -12,13 +14,13 @@ type ScrapedItemReaderWriter struct {
mock.Mock
}
-// All provides a mock function with given fields:
-func (_m *ScrapedItemReaderWriter) All() ([]*models.ScrapedItem, error) {
- ret := _m.Called()
+// All provides a mock function with given fields: ctx
+func (_m *ScrapedItemReaderWriter) All(ctx context.Context) ([]*models.ScrapedItem, error) {
+ ret := _m.Called(ctx)
var r0 []*models.ScrapedItem
- if rf, ok := ret.Get(0).(func() []*models.ScrapedItem); ok {
- r0 = rf()
+ if rf, ok := ret.Get(0).(func(context.Context) []*models.ScrapedItem); ok {
+ r0 = rf(ctx)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.ScrapedItem)
@@ -26,8 +28,8 @@ func (_m *ScrapedItemReaderWriter) All() ([]*models.ScrapedItem, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func() error); ok {
- r1 = rf()
+ if rf, ok := ret.Get(1).(func(context.Context) error); ok {
+ r1 = rf(ctx)
} else {
r1 = ret.Error(1)
}
@@ -35,13 +37,13 @@ func (_m *ScrapedItemReaderWriter) All() ([]*models.ScrapedItem, error) {
return r0, r1
}
-// Create provides a mock function with given fields: newObject
-func (_m *ScrapedItemReaderWriter) Create(newObject models.ScrapedItem) (*models.ScrapedItem, error) {
- ret := _m.Called(newObject)
+// Create provides a mock function with given fields: ctx, newObject
+func (_m *ScrapedItemReaderWriter) Create(ctx context.Context, newObject models.ScrapedItem) (*models.ScrapedItem, error) {
+ ret := _m.Called(ctx, newObject)
var r0 *models.ScrapedItem
- if rf, ok := ret.Get(0).(func(models.ScrapedItem) *models.ScrapedItem); ok {
- r0 = rf(newObject)
+ if rf, ok := ret.Get(0).(func(context.Context, models.ScrapedItem) *models.ScrapedItem); ok {
+ r0 = rf(ctx, newObject)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.ScrapedItem)
@@ -49,8 +51,8 @@ func (_m *ScrapedItemReaderWriter) Create(newObject models.ScrapedItem) (*models
}
var r1 error
- if rf, ok := ret.Get(1).(func(models.ScrapedItem) error); ok {
- r1 = rf(newObject)
+ if rf, ok := ret.Get(1).(func(context.Context, models.ScrapedItem) error); ok {
+ r1 = rf(ctx, newObject)
} else {
r1 = ret.Error(1)
}
diff --git a/pkg/models/mocks/StudioReaderWriter.go b/pkg/models/mocks/StudioReaderWriter.go
index c15c73719..bc8891983 100644
--- a/pkg/models/mocks/StudioReaderWriter.go
+++ b/pkg/models/mocks/StudioReaderWriter.go
@@ -3,6 +3,8 @@
package mocks
import (
+ context "context"
+
models "github.com/stashapp/stash/pkg/models"
mock "github.com/stretchr/testify/mock"
)
@@ -12,13 +14,13 @@ type StudioReaderWriter struct {
mock.Mock
}
-// All provides a mock function with given fields:
-func (_m *StudioReaderWriter) All() ([]*models.Studio, error) {
- ret := _m.Called()
+// All provides a mock function with given fields: ctx
+func (_m *StudioReaderWriter) All(ctx context.Context) ([]*models.Studio, error) {
+ ret := _m.Called(ctx)
var r0 []*models.Studio
- if rf, ok := ret.Get(0).(func() []*models.Studio); ok {
- r0 = rf()
+ if rf, ok := ret.Get(0).(func(context.Context) []*models.Studio); ok {
+ r0 = rf(ctx)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Studio)
@@ -26,8 +28,8 @@ func (_m *StudioReaderWriter) All() ([]*models.Studio, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func() error); ok {
- r1 = rf()
+ if rf, ok := ret.Get(1).(func(context.Context) error); ok {
+ r1 = rf(ctx)
} else {
r1 = ret.Error(1)
}
@@ -35,20 +37,20 @@ func (_m *StudioReaderWriter) All() ([]*models.Studio, error) {
return r0, r1
}
-// Count provides a mock function with given fields:
-func (_m *StudioReaderWriter) Count() (int, error) {
- ret := _m.Called()
+// Count provides a mock function with given fields: ctx
+func (_m *StudioReaderWriter) Count(ctx context.Context) (int, error) {
+ ret := _m.Called(ctx)
var r0 int
- if rf, ok := ret.Get(0).(func() int); ok {
- r0 = rf()
+ if rf, ok := ret.Get(0).(func(context.Context) int); ok {
+ r0 = rf(ctx)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
- if rf, ok := ret.Get(1).(func() error); ok {
- r1 = rf()
+ if rf, ok := ret.Get(1).(func(context.Context) error); ok {
+ r1 = rf(ctx)
} else {
r1 = ret.Error(1)
}
@@ -56,13 +58,13 @@ func (_m *StudioReaderWriter) Count() (int, error) {
return r0, r1
}
-// Create provides a mock function with given fields: newStudio
-func (_m *StudioReaderWriter) Create(newStudio models.Studio) (*models.Studio, error) {
- ret := _m.Called(newStudio)
+// Create provides a mock function with given fields: ctx, newStudio
+func (_m *StudioReaderWriter) Create(ctx context.Context, newStudio models.Studio) (*models.Studio, error) {
+ ret := _m.Called(ctx, newStudio)
var r0 *models.Studio
- if rf, ok := ret.Get(0).(func(models.Studio) *models.Studio); ok {
- r0 = rf(newStudio)
+ if rf, ok := ret.Get(0).(func(context.Context, models.Studio) *models.Studio); ok {
+ r0 = rf(ctx, newStudio)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Studio)
@@ -70,8 +72,8 @@ func (_m *StudioReaderWriter) Create(newStudio models.Studio) (*models.Studio, e
}
var r1 error
- if rf, ok := ret.Get(1).(func(models.Studio) error); ok {
- r1 = rf(newStudio)
+ if rf, ok := ret.Get(1).(func(context.Context, models.Studio) error); ok {
+ r1 = rf(ctx, newStudio)
} else {
r1 = ret.Error(1)
}
@@ -79,13 +81,13 @@ func (_m *StudioReaderWriter) Create(newStudio models.Studio) (*models.Studio, e
return r0, r1
}
-// Destroy provides a mock function with given fields: id
-func (_m *StudioReaderWriter) Destroy(id int) error {
- ret := _m.Called(id)
+// Destroy provides a mock function with given fields: ctx, id
+func (_m *StudioReaderWriter) Destroy(ctx context.Context, id int) error {
+ ret := _m.Called(ctx, id)
var r0 error
- if rf, ok := ret.Get(0).(func(int) error); ok {
- r0 = rf(id)
+ if rf, ok := ret.Get(0).(func(context.Context, int) error); ok {
+ r0 = rf(ctx, id)
} else {
r0 = ret.Error(0)
}
@@ -93,13 +95,13 @@ func (_m *StudioReaderWriter) Destroy(id int) error {
return r0
}
-// DestroyImage provides a mock function with given fields: studioID
-func (_m *StudioReaderWriter) DestroyImage(studioID int) error {
- ret := _m.Called(studioID)
+// DestroyImage provides a mock function with given fields: ctx, studioID
+func (_m *StudioReaderWriter) DestroyImage(ctx context.Context, studioID int) error {
+ ret := _m.Called(ctx, studioID)
var r0 error
- if rf, ok := ret.Get(0).(func(int) error); ok {
- r0 = rf(studioID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) error); ok {
+ r0 = rf(ctx, studioID)
} else {
r0 = ret.Error(0)
}
@@ -107,13 +109,13 @@ func (_m *StudioReaderWriter) DestroyImage(studioID int) error {
return r0
}
-// Find provides a mock function with given fields: id
-func (_m *StudioReaderWriter) Find(id int) (*models.Studio, error) {
- ret := _m.Called(id)
+// Find provides a mock function with given fields: ctx, id
+func (_m *StudioReaderWriter) Find(ctx context.Context, id int) (*models.Studio, error) {
+ ret := _m.Called(ctx, id)
var r0 *models.Studio
- if rf, ok := ret.Get(0).(func(int) *models.Studio); ok {
- r0 = rf(id)
+ if rf, ok := ret.Get(0).(func(context.Context, int) *models.Studio); ok {
+ r0 = rf(ctx, id)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Studio)
@@ -121,8 +123,8 @@ func (_m *StudioReaderWriter) Find(id int) (*models.Studio, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(id)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, id)
} else {
r1 = ret.Error(1)
}
@@ -130,13 +132,13 @@ func (_m *StudioReaderWriter) Find(id int) (*models.Studio, error) {
return r0, r1
}
-// FindByName provides a mock function with given fields: name, nocase
-func (_m *StudioReaderWriter) FindByName(name string, nocase bool) (*models.Studio, error) {
- ret := _m.Called(name, nocase)
+// FindByName provides a mock function with given fields: ctx, name, nocase
+func (_m *StudioReaderWriter) FindByName(ctx context.Context, name string, nocase bool) (*models.Studio, error) {
+ ret := _m.Called(ctx, name, nocase)
var r0 *models.Studio
- if rf, ok := ret.Get(0).(func(string, bool) *models.Studio); ok {
- r0 = rf(name, nocase)
+ if rf, ok := ret.Get(0).(func(context.Context, string, bool) *models.Studio); ok {
+ r0 = rf(ctx, name, nocase)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Studio)
@@ -144,8 +146,8 @@ func (_m *StudioReaderWriter) FindByName(name string, nocase bool) (*models.Stud
}
var r1 error
- if rf, ok := ret.Get(1).(func(string, bool) error); ok {
- r1 = rf(name, nocase)
+ if rf, ok := ret.Get(1).(func(context.Context, string, bool) error); ok {
+ r1 = rf(ctx, name, nocase)
} else {
r1 = ret.Error(1)
}
@@ -153,13 +155,13 @@ func (_m *StudioReaderWriter) FindByName(name string, nocase bool) (*models.Stud
return r0, r1
}
-// FindByStashID provides a mock function with given fields: stashID
-func (_m *StudioReaderWriter) FindByStashID(stashID models.StashID) ([]*models.Studio, error) {
- ret := _m.Called(stashID)
+// FindByStashID provides a mock function with given fields: ctx, stashID
+func (_m *StudioReaderWriter) FindByStashID(ctx context.Context, stashID models.StashID) ([]*models.Studio, error) {
+ ret := _m.Called(ctx, stashID)
var r0 []*models.Studio
- if rf, ok := ret.Get(0).(func(models.StashID) []*models.Studio); ok {
- r0 = rf(stashID)
+ if rf, ok := ret.Get(0).(func(context.Context, models.StashID) []*models.Studio); ok {
+ r0 = rf(ctx, stashID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Studio)
@@ -167,8 +169,8 @@ func (_m *StudioReaderWriter) FindByStashID(stashID models.StashID) ([]*models.S
}
var r1 error
- if rf, ok := ret.Get(1).(func(models.StashID) error); ok {
- r1 = rf(stashID)
+ if rf, ok := ret.Get(1).(func(context.Context, models.StashID) error); ok {
+ r1 = rf(ctx, stashID)
} else {
r1 = ret.Error(1)
}
@@ -176,13 +178,13 @@ func (_m *StudioReaderWriter) FindByStashID(stashID models.StashID) ([]*models.S
return r0, r1
}
-// FindChildren provides a mock function with given fields: id
-func (_m *StudioReaderWriter) FindChildren(id int) ([]*models.Studio, error) {
- ret := _m.Called(id)
+// FindChildren provides a mock function with given fields: ctx, id
+func (_m *StudioReaderWriter) FindChildren(ctx context.Context, id int) ([]*models.Studio, error) {
+ ret := _m.Called(ctx, id)
var r0 []*models.Studio
- if rf, ok := ret.Get(0).(func(int) []*models.Studio); ok {
- r0 = rf(id)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Studio); ok {
+ r0 = rf(ctx, id)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Studio)
@@ -190,8 +192,8 @@ func (_m *StudioReaderWriter) FindChildren(id int) ([]*models.Studio, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(id)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, id)
} else {
r1 = ret.Error(1)
}
@@ -199,13 +201,13 @@ func (_m *StudioReaderWriter) FindChildren(id int) ([]*models.Studio, error) {
return r0, r1
}
-// FindMany provides a mock function with given fields: ids
-func (_m *StudioReaderWriter) FindMany(ids []int) ([]*models.Studio, error) {
- ret := _m.Called(ids)
+// FindMany provides a mock function with given fields: ctx, ids
+func (_m *StudioReaderWriter) FindMany(ctx context.Context, ids []int) ([]*models.Studio, error) {
+ ret := _m.Called(ctx, ids)
var r0 []*models.Studio
- if rf, ok := ret.Get(0).(func([]int) []*models.Studio); ok {
- r0 = rf(ids)
+ if rf, ok := ret.Get(0).(func(context.Context, []int) []*models.Studio); ok {
+ r0 = rf(ctx, ids)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Studio)
@@ -213,8 +215,8 @@ func (_m *StudioReaderWriter) FindMany(ids []int) ([]*models.Studio, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func([]int) error); ok {
- r1 = rf(ids)
+ if rf, ok := ret.Get(1).(func(context.Context, []int) error); ok {
+ r1 = rf(ctx, ids)
} else {
r1 = ret.Error(1)
}
@@ -222,13 +224,13 @@ func (_m *StudioReaderWriter) FindMany(ids []int) ([]*models.Studio, error) {
return r0, r1
}
-// GetAliases provides a mock function with given fields: studioID
-func (_m *StudioReaderWriter) GetAliases(studioID int) ([]string, error) {
- ret := _m.Called(studioID)
+// GetAliases provides a mock function with given fields: ctx, studioID
+func (_m *StudioReaderWriter) GetAliases(ctx context.Context, studioID int) ([]string, error) {
+ ret := _m.Called(ctx, studioID)
var r0 []string
- if rf, ok := ret.Get(0).(func(int) []string); ok {
- r0 = rf(studioID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []string); ok {
+ r0 = rf(ctx, studioID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]string)
@@ -236,8 +238,8 @@ func (_m *StudioReaderWriter) GetAliases(studioID int) ([]string, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(studioID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, studioID)
} else {
r1 = ret.Error(1)
}
@@ -245,13 +247,13 @@ func (_m *StudioReaderWriter) GetAliases(studioID int) ([]string, error) {
return r0, r1
}
-// GetImage provides a mock function with given fields: studioID
-func (_m *StudioReaderWriter) GetImage(studioID int) ([]byte, error) {
- ret := _m.Called(studioID)
+// GetImage provides a mock function with given fields: ctx, studioID
+func (_m *StudioReaderWriter) GetImage(ctx context.Context, studioID int) ([]byte, error) {
+ ret := _m.Called(ctx, studioID)
var r0 []byte
- if rf, ok := ret.Get(0).(func(int) []byte); ok {
- r0 = rf(studioID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []byte); ok {
+ r0 = rf(ctx, studioID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]byte)
@@ -259,8 +261,8 @@ func (_m *StudioReaderWriter) GetImage(studioID int) ([]byte, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(studioID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, studioID)
} else {
r1 = ret.Error(1)
}
@@ -268,13 +270,13 @@ func (_m *StudioReaderWriter) GetImage(studioID int) ([]byte, error) {
return r0, r1
}
-// GetStashIDs provides a mock function with given fields: studioID
-func (_m *StudioReaderWriter) GetStashIDs(studioID int) ([]*models.StashID, error) {
- ret := _m.Called(studioID)
+// GetStashIDs provides a mock function with given fields: ctx, studioID
+func (_m *StudioReaderWriter) GetStashIDs(ctx context.Context, studioID int) ([]*models.StashID, error) {
+ ret := _m.Called(ctx, studioID)
var r0 []*models.StashID
- if rf, ok := ret.Get(0).(func(int) []*models.StashID); ok {
- r0 = rf(studioID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []*models.StashID); ok {
+ r0 = rf(ctx, studioID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.StashID)
@@ -282,8 +284,8 @@ func (_m *StudioReaderWriter) GetStashIDs(studioID int) ([]*models.StashID, erro
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(studioID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, studioID)
} else {
r1 = ret.Error(1)
}
@@ -291,20 +293,20 @@ func (_m *StudioReaderWriter) GetStashIDs(studioID int) ([]*models.StashID, erro
return r0, r1
}
-// HasImage provides a mock function with given fields: studioID
-func (_m *StudioReaderWriter) HasImage(studioID int) (bool, error) {
- ret := _m.Called(studioID)
+// HasImage provides a mock function with given fields: ctx, studioID
+func (_m *StudioReaderWriter) HasImage(ctx context.Context, studioID int) (bool, error) {
+ ret := _m.Called(ctx, studioID)
var r0 bool
- if rf, ok := ret.Get(0).(func(int) bool); ok {
- r0 = rf(studioID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) bool); ok {
+ r0 = rf(ctx, studioID)
} else {
r0 = ret.Get(0).(bool)
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(studioID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, studioID)
} else {
r1 = ret.Error(1)
}
@@ -312,13 +314,13 @@ func (_m *StudioReaderWriter) HasImage(studioID int) (bool, error) {
return r0, r1
}
-// Query provides a mock function with given fields: studioFilter, findFilter
-func (_m *StudioReaderWriter) Query(studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) ([]*models.Studio, int, error) {
- ret := _m.Called(studioFilter, findFilter)
+// Query provides a mock function with given fields: ctx, studioFilter, findFilter
+func (_m *StudioReaderWriter) Query(ctx context.Context, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) ([]*models.Studio, int, error) {
+ ret := _m.Called(ctx, studioFilter, findFilter)
var r0 []*models.Studio
- if rf, ok := ret.Get(0).(func(*models.StudioFilterType, *models.FindFilterType) []*models.Studio); ok {
- r0 = rf(studioFilter, findFilter)
+ if rf, ok := ret.Get(0).(func(context.Context, *models.StudioFilterType, *models.FindFilterType) []*models.Studio); ok {
+ r0 = rf(ctx, studioFilter, findFilter)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Studio)
@@ -326,15 +328,15 @@ func (_m *StudioReaderWriter) Query(studioFilter *models.StudioFilterType, findF
}
var r1 int
- if rf, ok := ret.Get(1).(func(*models.StudioFilterType, *models.FindFilterType) int); ok {
- r1 = rf(studioFilter, findFilter)
+ if rf, ok := ret.Get(1).(func(context.Context, *models.StudioFilterType, *models.FindFilterType) int); ok {
+ r1 = rf(ctx, studioFilter, findFilter)
} else {
r1 = ret.Get(1).(int)
}
var r2 error
- if rf, ok := ret.Get(2).(func(*models.StudioFilterType, *models.FindFilterType) error); ok {
- r2 = rf(studioFilter, findFilter)
+ if rf, ok := ret.Get(2).(func(context.Context, *models.StudioFilterType, *models.FindFilterType) error); ok {
+ r2 = rf(ctx, studioFilter, findFilter)
} else {
r2 = ret.Error(2)
}
@@ -342,13 +344,13 @@ func (_m *StudioReaderWriter) Query(studioFilter *models.StudioFilterType, findF
return r0, r1, r2
}
-// QueryForAutoTag provides a mock function with given fields: words
-func (_m *StudioReaderWriter) QueryForAutoTag(words []string) ([]*models.Studio, error) {
- ret := _m.Called(words)
+// QueryForAutoTag provides a mock function with given fields: ctx, words
+func (_m *StudioReaderWriter) QueryForAutoTag(ctx context.Context, words []string) ([]*models.Studio, error) {
+ ret := _m.Called(ctx, words)
var r0 []*models.Studio
- if rf, ok := ret.Get(0).(func([]string) []*models.Studio); ok {
- r0 = rf(words)
+ if rf, ok := ret.Get(0).(func(context.Context, []string) []*models.Studio); ok {
+ r0 = rf(ctx, words)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Studio)
@@ -356,8 +358,8 @@ func (_m *StudioReaderWriter) QueryForAutoTag(words []string) ([]*models.Studio,
}
var r1 error
- if rf, ok := ret.Get(1).(func([]string) error); ok {
- r1 = rf(words)
+ if rf, ok := ret.Get(1).(func(context.Context, []string) error); ok {
+ r1 = rf(ctx, words)
} else {
r1 = ret.Error(1)
}
@@ -365,13 +367,13 @@ func (_m *StudioReaderWriter) QueryForAutoTag(words []string) ([]*models.Studio,
return r0, r1
}
-// Update provides a mock function with given fields: updatedStudio
-func (_m *StudioReaderWriter) Update(updatedStudio models.StudioPartial) (*models.Studio, error) {
- ret := _m.Called(updatedStudio)
+// Update provides a mock function with given fields: ctx, updatedStudio
+func (_m *StudioReaderWriter) Update(ctx context.Context, updatedStudio models.StudioPartial) (*models.Studio, error) {
+ ret := _m.Called(ctx, updatedStudio)
var r0 *models.Studio
- if rf, ok := ret.Get(0).(func(models.StudioPartial) *models.Studio); ok {
- r0 = rf(updatedStudio)
+ if rf, ok := ret.Get(0).(func(context.Context, models.StudioPartial) *models.Studio); ok {
+ r0 = rf(ctx, updatedStudio)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Studio)
@@ -379,8 +381,8 @@ func (_m *StudioReaderWriter) Update(updatedStudio models.StudioPartial) (*model
}
var r1 error
- if rf, ok := ret.Get(1).(func(models.StudioPartial) error); ok {
- r1 = rf(updatedStudio)
+ if rf, ok := ret.Get(1).(func(context.Context, models.StudioPartial) error); ok {
+ r1 = rf(ctx, updatedStudio)
} else {
r1 = ret.Error(1)
}
@@ -388,13 +390,13 @@ func (_m *StudioReaderWriter) Update(updatedStudio models.StudioPartial) (*model
return r0, r1
}
-// UpdateAliases provides a mock function with given fields: studioID, aliases
-func (_m *StudioReaderWriter) UpdateAliases(studioID int, aliases []string) error {
- ret := _m.Called(studioID, aliases)
+// UpdateAliases provides a mock function with given fields: ctx, studioID, aliases
+func (_m *StudioReaderWriter) UpdateAliases(ctx context.Context, studioID int, aliases []string) error {
+ ret := _m.Called(ctx, studioID, aliases)
var r0 error
- if rf, ok := ret.Get(0).(func(int, []string) error); ok {
- r0 = rf(studioID, aliases)
+ if rf, ok := ret.Get(0).(func(context.Context, int, []string) error); ok {
+ r0 = rf(ctx, studioID, aliases)
} else {
r0 = ret.Error(0)
}
@@ -402,13 +404,13 @@ func (_m *StudioReaderWriter) UpdateAliases(studioID int, aliases []string) erro
return r0
}
-// UpdateFull provides a mock function with given fields: updatedStudio
-func (_m *StudioReaderWriter) UpdateFull(updatedStudio models.Studio) (*models.Studio, error) {
- ret := _m.Called(updatedStudio)
+// UpdateFull provides a mock function with given fields: ctx, updatedStudio
+func (_m *StudioReaderWriter) UpdateFull(ctx context.Context, updatedStudio models.Studio) (*models.Studio, error) {
+ ret := _m.Called(ctx, updatedStudio)
var r0 *models.Studio
- if rf, ok := ret.Get(0).(func(models.Studio) *models.Studio); ok {
- r0 = rf(updatedStudio)
+ if rf, ok := ret.Get(0).(func(context.Context, models.Studio) *models.Studio); ok {
+ r0 = rf(ctx, updatedStudio)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Studio)
@@ -416,8 +418,8 @@ func (_m *StudioReaderWriter) UpdateFull(updatedStudio models.Studio) (*models.S
}
var r1 error
- if rf, ok := ret.Get(1).(func(models.Studio) error); ok {
- r1 = rf(updatedStudio)
+ if rf, ok := ret.Get(1).(func(context.Context, models.Studio) error); ok {
+ r1 = rf(ctx, updatedStudio)
} else {
r1 = ret.Error(1)
}
@@ -425,13 +427,13 @@ func (_m *StudioReaderWriter) UpdateFull(updatedStudio models.Studio) (*models.S
return r0, r1
}
-// UpdateImage provides a mock function with given fields: studioID, image
-func (_m *StudioReaderWriter) UpdateImage(studioID int, image []byte) error {
- ret := _m.Called(studioID, image)
+// UpdateImage provides a mock function with given fields: ctx, studioID, image
+func (_m *StudioReaderWriter) UpdateImage(ctx context.Context, studioID int, image []byte) error {
+ ret := _m.Called(ctx, studioID, image)
var r0 error
- if rf, ok := ret.Get(0).(func(int, []byte) error); ok {
- r0 = rf(studioID, image)
+ if rf, ok := ret.Get(0).(func(context.Context, int, []byte) error); ok {
+ r0 = rf(ctx, studioID, image)
} else {
r0 = ret.Error(0)
}
@@ -439,13 +441,13 @@ func (_m *StudioReaderWriter) UpdateImage(studioID int, image []byte) error {
return r0
}
-// UpdateStashIDs provides a mock function with given fields: studioID, stashIDs
-func (_m *StudioReaderWriter) UpdateStashIDs(studioID int, stashIDs []models.StashID) error {
- ret := _m.Called(studioID, stashIDs)
+// UpdateStashIDs provides a mock function with given fields: ctx, studioID, stashIDs
+func (_m *StudioReaderWriter) UpdateStashIDs(ctx context.Context, studioID int, stashIDs []models.StashID) error {
+ ret := _m.Called(ctx, studioID, stashIDs)
var r0 error
- if rf, ok := ret.Get(0).(func(int, []models.StashID) error); ok {
- r0 = rf(studioID, stashIDs)
+ if rf, ok := ret.Get(0).(func(context.Context, int, []models.StashID) error); ok {
+ r0 = rf(ctx, studioID, stashIDs)
} else {
r0 = ret.Error(0)
}
diff --git a/pkg/models/mocks/TagReaderWriter.go b/pkg/models/mocks/TagReaderWriter.go
index 64a8088a6..1a53adf05 100644
--- a/pkg/models/mocks/TagReaderWriter.go
+++ b/pkg/models/mocks/TagReaderWriter.go
@@ -3,6 +3,8 @@
package mocks
import (
+ context "context"
+
models "github.com/stashapp/stash/pkg/models"
mock "github.com/stretchr/testify/mock"
)
@@ -12,13 +14,13 @@ type TagReaderWriter struct {
mock.Mock
}
-// All provides a mock function with given fields:
-func (_m *TagReaderWriter) All() ([]*models.Tag, error) {
- ret := _m.Called()
+// All provides a mock function with given fields: ctx
+func (_m *TagReaderWriter) All(ctx context.Context) ([]*models.Tag, error) {
+ ret := _m.Called(ctx)
var r0 []*models.Tag
- if rf, ok := ret.Get(0).(func() []*models.Tag); ok {
- r0 = rf()
+ if rf, ok := ret.Get(0).(func(context.Context) []*models.Tag); ok {
+ r0 = rf(ctx)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Tag)
@@ -26,8 +28,8 @@ func (_m *TagReaderWriter) All() ([]*models.Tag, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func() error); ok {
- r1 = rf()
+ if rf, ok := ret.Get(1).(func(context.Context) error); ok {
+ r1 = rf(ctx)
} else {
r1 = ret.Error(1)
}
@@ -35,20 +37,20 @@ func (_m *TagReaderWriter) All() ([]*models.Tag, error) {
return r0, r1
}
-// Count provides a mock function with given fields:
-func (_m *TagReaderWriter) Count() (int, error) {
- ret := _m.Called()
+// Count provides a mock function with given fields: ctx
+func (_m *TagReaderWriter) Count(ctx context.Context) (int, error) {
+ ret := _m.Called(ctx)
var r0 int
- if rf, ok := ret.Get(0).(func() int); ok {
- r0 = rf()
+ if rf, ok := ret.Get(0).(func(context.Context) int); ok {
+ r0 = rf(ctx)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
- if rf, ok := ret.Get(1).(func() error); ok {
- r1 = rf()
+ if rf, ok := ret.Get(1).(func(context.Context) error); ok {
+ r1 = rf(ctx)
} else {
r1 = ret.Error(1)
}
@@ -56,13 +58,13 @@ func (_m *TagReaderWriter) Count() (int, error) {
return r0, r1
}
-// Create provides a mock function with given fields: newTag
-func (_m *TagReaderWriter) Create(newTag models.Tag) (*models.Tag, error) {
- ret := _m.Called(newTag)
+// Create provides a mock function with given fields: ctx, newTag
+func (_m *TagReaderWriter) Create(ctx context.Context, newTag models.Tag) (*models.Tag, error) {
+ ret := _m.Called(ctx, newTag)
var r0 *models.Tag
- if rf, ok := ret.Get(0).(func(models.Tag) *models.Tag); ok {
- r0 = rf(newTag)
+ if rf, ok := ret.Get(0).(func(context.Context, models.Tag) *models.Tag); ok {
+ r0 = rf(ctx, newTag)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Tag)
@@ -70,8 +72,8 @@ func (_m *TagReaderWriter) Create(newTag models.Tag) (*models.Tag, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(models.Tag) error); ok {
- r1 = rf(newTag)
+ if rf, ok := ret.Get(1).(func(context.Context, models.Tag) error); ok {
+ r1 = rf(ctx, newTag)
} else {
r1 = ret.Error(1)
}
@@ -79,13 +81,13 @@ func (_m *TagReaderWriter) Create(newTag models.Tag) (*models.Tag, error) {
return r0, r1
}
-// Destroy provides a mock function with given fields: id
-func (_m *TagReaderWriter) Destroy(id int) error {
- ret := _m.Called(id)
+// Destroy provides a mock function with given fields: ctx, id
+func (_m *TagReaderWriter) Destroy(ctx context.Context, id int) error {
+ ret := _m.Called(ctx, id)
var r0 error
- if rf, ok := ret.Get(0).(func(int) error); ok {
- r0 = rf(id)
+ if rf, ok := ret.Get(0).(func(context.Context, int) error); ok {
+ r0 = rf(ctx, id)
} else {
r0 = ret.Error(0)
}
@@ -93,13 +95,13 @@ func (_m *TagReaderWriter) Destroy(id int) error {
return r0
}
-// DestroyImage provides a mock function with given fields: tagID
-func (_m *TagReaderWriter) DestroyImage(tagID int) error {
- ret := _m.Called(tagID)
+// DestroyImage provides a mock function with given fields: ctx, tagID
+func (_m *TagReaderWriter) DestroyImage(ctx context.Context, tagID int) error {
+ ret := _m.Called(ctx, tagID)
var r0 error
- if rf, ok := ret.Get(0).(func(int) error); ok {
- r0 = rf(tagID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) error); ok {
+ r0 = rf(ctx, tagID)
} else {
r0 = ret.Error(0)
}
@@ -107,13 +109,13 @@ func (_m *TagReaderWriter) DestroyImage(tagID int) error {
return r0
}
-// Find provides a mock function with given fields: id
-func (_m *TagReaderWriter) Find(id int) (*models.Tag, error) {
- ret := _m.Called(id)
+// Find provides a mock function with given fields: ctx, id
+func (_m *TagReaderWriter) Find(ctx context.Context, id int) (*models.Tag, error) {
+ ret := _m.Called(ctx, id)
var r0 *models.Tag
- if rf, ok := ret.Get(0).(func(int) *models.Tag); ok {
- r0 = rf(id)
+ if rf, ok := ret.Get(0).(func(context.Context, int) *models.Tag); ok {
+ r0 = rf(ctx, id)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Tag)
@@ -121,8 +123,8 @@ func (_m *TagReaderWriter) Find(id int) (*models.Tag, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(id)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, id)
} else {
r1 = ret.Error(1)
}
@@ -130,13 +132,13 @@ func (_m *TagReaderWriter) Find(id int) (*models.Tag, error) {
return r0, r1
}
-// FindAllAncestors provides a mock function with given fields: tagID, excludeIDs
-func (_m *TagReaderWriter) FindAllAncestors(tagID int, excludeIDs []int) ([]*models.TagPath, error) {
- ret := _m.Called(tagID, excludeIDs)
+// FindAllAncestors provides a mock function with given fields: ctx, tagID, excludeIDs
+func (_m *TagReaderWriter) FindAllAncestors(ctx context.Context, tagID int, excludeIDs []int) ([]*models.TagPath, error) {
+ ret := _m.Called(ctx, tagID, excludeIDs)
var r0 []*models.TagPath
- if rf, ok := ret.Get(0).(func(int, []int) []*models.TagPath); ok {
- r0 = rf(tagID, excludeIDs)
+ if rf, ok := ret.Get(0).(func(context.Context, int, []int) []*models.TagPath); ok {
+ r0 = rf(ctx, tagID, excludeIDs)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.TagPath)
@@ -144,8 +146,8 @@ func (_m *TagReaderWriter) FindAllAncestors(tagID int, excludeIDs []int) ([]*mod
}
var r1 error
- if rf, ok := ret.Get(1).(func(int, []int) error); ok {
- r1 = rf(tagID, excludeIDs)
+ if rf, ok := ret.Get(1).(func(context.Context, int, []int) error); ok {
+ r1 = rf(ctx, tagID, excludeIDs)
} else {
r1 = ret.Error(1)
}
@@ -153,13 +155,13 @@ func (_m *TagReaderWriter) FindAllAncestors(tagID int, excludeIDs []int) ([]*mod
return r0, r1
}
-// FindAllDescendants provides a mock function with given fields: tagID, excludeIDs
-func (_m *TagReaderWriter) FindAllDescendants(tagID int, excludeIDs []int) ([]*models.TagPath, error) {
- ret := _m.Called(tagID, excludeIDs)
+// FindAllDescendants provides a mock function with given fields: ctx, tagID, excludeIDs
+func (_m *TagReaderWriter) FindAllDescendants(ctx context.Context, tagID int, excludeIDs []int) ([]*models.TagPath, error) {
+ ret := _m.Called(ctx, tagID, excludeIDs)
var r0 []*models.TagPath
- if rf, ok := ret.Get(0).(func(int, []int) []*models.TagPath); ok {
- r0 = rf(tagID, excludeIDs)
+ if rf, ok := ret.Get(0).(func(context.Context, int, []int) []*models.TagPath); ok {
+ r0 = rf(ctx, tagID, excludeIDs)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.TagPath)
@@ -167,8 +169,8 @@ func (_m *TagReaderWriter) FindAllDescendants(tagID int, excludeIDs []int) ([]*m
}
var r1 error
- if rf, ok := ret.Get(1).(func(int, []int) error); ok {
- r1 = rf(tagID, excludeIDs)
+ if rf, ok := ret.Get(1).(func(context.Context, int, []int) error); ok {
+ r1 = rf(ctx, tagID, excludeIDs)
} else {
r1 = ret.Error(1)
}
@@ -176,13 +178,13 @@ func (_m *TagReaderWriter) FindAllDescendants(tagID int, excludeIDs []int) ([]*m
return r0, r1
}
-// FindByChildTagID provides a mock function with given fields: childID
-func (_m *TagReaderWriter) FindByChildTagID(childID int) ([]*models.Tag, error) {
- ret := _m.Called(childID)
+// FindByChildTagID provides a mock function with given fields: ctx, childID
+func (_m *TagReaderWriter) FindByChildTagID(ctx context.Context, childID int) ([]*models.Tag, error) {
+ ret := _m.Called(ctx, childID)
var r0 []*models.Tag
- if rf, ok := ret.Get(0).(func(int) []*models.Tag); ok {
- r0 = rf(childID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Tag); ok {
+ r0 = rf(ctx, childID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Tag)
@@ -190,8 +192,8 @@ func (_m *TagReaderWriter) FindByChildTagID(childID int) ([]*models.Tag, error)
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(childID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, childID)
} else {
r1 = ret.Error(1)
}
@@ -199,13 +201,13 @@ func (_m *TagReaderWriter) FindByChildTagID(childID int) ([]*models.Tag, error)
return r0, r1
}
-// FindByGalleryID provides a mock function with given fields: galleryID
-func (_m *TagReaderWriter) FindByGalleryID(galleryID int) ([]*models.Tag, error) {
- ret := _m.Called(galleryID)
+// FindByGalleryID provides a mock function with given fields: ctx, galleryID
+func (_m *TagReaderWriter) FindByGalleryID(ctx context.Context, galleryID int) ([]*models.Tag, error) {
+ ret := _m.Called(ctx, galleryID)
var r0 []*models.Tag
- if rf, ok := ret.Get(0).(func(int) []*models.Tag); ok {
- r0 = rf(galleryID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Tag); ok {
+ r0 = rf(ctx, galleryID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Tag)
@@ -213,8 +215,8 @@ func (_m *TagReaderWriter) FindByGalleryID(galleryID int) ([]*models.Tag, error)
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(galleryID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, galleryID)
} else {
r1 = ret.Error(1)
}
@@ -222,13 +224,13 @@ func (_m *TagReaderWriter) FindByGalleryID(galleryID int) ([]*models.Tag, error)
return r0, r1
}
-// FindByImageID provides a mock function with given fields: imageID
-func (_m *TagReaderWriter) FindByImageID(imageID int) ([]*models.Tag, error) {
- ret := _m.Called(imageID)
+// FindByImageID provides a mock function with given fields: ctx, imageID
+func (_m *TagReaderWriter) FindByImageID(ctx context.Context, imageID int) ([]*models.Tag, error) {
+ ret := _m.Called(ctx, imageID)
var r0 []*models.Tag
- if rf, ok := ret.Get(0).(func(int) []*models.Tag); ok {
- r0 = rf(imageID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Tag); ok {
+ r0 = rf(ctx, imageID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Tag)
@@ -236,8 +238,8 @@ func (_m *TagReaderWriter) FindByImageID(imageID int) ([]*models.Tag, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(imageID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, imageID)
} else {
r1 = ret.Error(1)
}
@@ -245,13 +247,13 @@ func (_m *TagReaderWriter) FindByImageID(imageID int) ([]*models.Tag, error) {
return r0, r1
}
-// FindByName provides a mock function with given fields: name, nocase
-func (_m *TagReaderWriter) FindByName(name string, nocase bool) (*models.Tag, error) {
- ret := _m.Called(name, nocase)
+// FindByName provides a mock function with given fields: ctx, name, nocase
+func (_m *TagReaderWriter) FindByName(ctx context.Context, name string, nocase bool) (*models.Tag, error) {
+ ret := _m.Called(ctx, name, nocase)
var r0 *models.Tag
- if rf, ok := ret.Get(0).(func(string, bool) *models.Tag); ok {
- r0 = rf(name, nocase)
+ if rf, ok := ret.Get(0).(func(context.Context, string, bool) *models.Tag); ok {
+ r0 = rf(ctx, name, nocase)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Tag)
@@ -259,8 +261,8 @@ func (_m *TagReaderWriter) FindByName(name string, nocase bool) (*models.Tag, er
}
var r1 error
- if rf, ok := ret.Get(1).(func(string, bool) error); ok {
- r1 = rf(name, nocase)
+ if rf, ok := ret.Get(1).(func(context.Context, string, bool) error); ok {
+ r1 = rf(ctx, name, nocase)
} else {
r1 = ret.Error(1)
}
@@ -268,13 +270,13 @@ func (_m *TagReaderWriter) FindByName(name string, nocase bool) (*models.Tag, er
return r0, r1
}
-// FindByNames provides a mock function with given fields: names, nocase
-func (_m *TagReaderWriter) FindByNames(names []string, nocase bool) ([]*models.Tag, error) {
- ret := _m.Called(names, nocase)
+// FindByNames provides a mock function with given fields: ctx, names, nocase
+func (_m *TagReaderWriter) FindByNames(ctx context.Context, names []string, nocase bool) ([]*models.Tag, error) {
+ ret := _m.Called(ctx, names, nocase)
var r0 []*models.Tag
- if rf, ok := ret.Get(0).(func([]string, bool) []*models.Tag); ok {
- r0 = rf(names, nocase)
+ if rf, ok := ret.Get(0).(func(context.Context, []string, bool) []*models.Tag); ok {
+ r0 = rf(ctx, names, nocase)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Tag)
@@ -282,8 +284,8 @@ func (_m *TagReaderWriter) FindByNames(names []string, nocase bool) ([]*models.T
}
var r1 error
- if rf, ok := ret.Get(1).(func([]string, bool) error); ok {
- r1 = rf(names, nocase)
+ if rf, ok := ret.Get(1).(func(context.Context, []string, bool) error); ok {
+ r1 = rf(ctx, names, nocase)
} else {
r1 = ret.Error(1)
}
@@ -291,13 +293,13 @@ func (_m *TagReaderWriter) FindByNames(names []string, nocase bool) ([]*models.T
return r0, r1
}
-// FindByParentTagID provides a mock function with given fields: parentID
-func (_m *TagReaderWriter) FindByParentTagID(parentID int) ([]*models.Tag, error) {
- ret := _m.Called(parentID)
+// FindByParentTagID provides a mock function with given fields: ctx, parentID
+func (_m *TagReaderWriter) FindByParentTagID(ctx context.Context, parentID int) ([]*models.Tag, error) {
+ ret := _m.Called(ctx, parentID)
var r0 []*models.Tag
- if rf, ok := ret.Get(0).(func(int) []*models.Tag); ok {
- r0 = rf(parentID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Tag); ok {
+ r0 = rf(ctx, parentID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Tag)
@@ -305,8 +307,8 @@ func (_m *TagReaderWriter) FindByParentTagID(parentID int) ([]*models.Tag, error
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(parentID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, parentID)
} else {
r1 = ret.Error(1)
}
@@ -314,13 +316,13 @@ func (_m *TagReaderWriter) FindByParentTagID(parentID int) ([]*models.Tag, error
return r0, r1
}
-// FindByPerformerID provides a mock function with given fields: performerID
-func (_m *TagReaderWriter) FindByPerformerID(performerID int) ([]*models.Tag, error) {
- ret := _m.Called(performerID)
+// FindByPerformerID provides a mock function with given fields: ctx, performerID
+func (_m *TagReaderWriter) FindByPerformerID(ctx context.Context, performerID int) ([]*models.Tag, error) {
+ ret := _m.Called(ctx, performerID)
var r0 []*models.Tag
- if rf, ok := ret.Get(0).(func(int) []*models.Tag); ok {
- r0 = rf(performerID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Tag); ok {
+ r0 = rf(ctx, performerID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Tag)
@@ -328,8 +330,8 @@ func (_m *TagReaderWriter) FindByPerformerID(performerID int) ([]*models.Tag, er
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(performerID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, performerID)
} else {
r1 = ret.Error(1)
}
@@ -337,13 +339,13 @@ func (_m *TagReaderWriter) FindByPerformerID(performerID int) ([]*models.Tag, er
return r0, r1
}
-// FindBySceneID provides a mock function with given fields: sceneID
-func (_m *TagReaderWriter) FindBySceneID(sceneID int) ([]*models.Tag, error) {
- ret := _m.Called(sceneID)
+// FindBySceneID provides a mock function with given fields: ctx, sceneID
+func (_m *TagReaderWriter) FindBySceneID(ctx context.Context, sceneID int) ([]*models.Tag, error) {
+ ret := _m.Called(ctx, sceneID)
var r0 []*models.Tag
- if rf, ok := ret.Get(0).(func(int) []*models.Tag); ok {
- r0 = rf(sceneID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Tag); ok {
+ r0 = rf(ctx, sceneID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Tag)
@@ -351,8 +353,8 @@ func (_m *TagReaderWriter) FindBySceneID(sceneID int) ([]*models.Tag, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(sceneID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, sceneID)
} else {
r1 = ret.Error(1)
}
@@ -360,13 +362,13 @@ func (_m *TagReaderWriter) FindBySceneID(sceneID int) ([]*models.Tag, error) {
return r0, r1
}
-// FindBySceneMarkerID provides a mock function with given fields: sceneMarkerID
-func (_m *TagReaderWriter) FindBySceneMarkerID(sceneMarkerID int) ([]*models.Tag, error) {
- ret := _m.Called(sceneMarkerID)
+// FindBySceneMarkerID provides a mock function with given fields: ctx, sceneMarkerID
+func (_m *TagReaderWriter) FindBySceneMarkerID(ctx context.Context, sceneMarkerID int) ([]*models.Tag, error) {
+ ret := _m.Called(ctx, sceneMarkerID)
var r0 []*models.Tag
- if rf, ok := ret.Get(0).(func(int) []*models.Tag); ok {
- r0 = rf(sceneMarkerID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Tag); ok {
+ r0 = rf(ctx, sceneMarkerID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Tag)
@@ -374,8 +376,8 @@ func (_m *TagReaderWriter) FindBySceneMarkerID(sceneMarkerID int) ([]*models.Tag
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(sceneMarkerID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, sceneMarkerID)
} else {
r1 = ret.Error(1)
}
@@ -383,13 +385,13 @@ func (_m *TagReaderWriter) FindBySceneMarkerID(sceneMarkerID int) ([]*models.Tag
return r0, r1
}
-// FindMany provides a mock function with given fields: ids
-func (_m *TagReaderWriter) FindMany(ids []int) ([]*models.Tag, error) {
- ret := _m.Called(ids)
+// FindMany provides a mock function with given fields: ctx, ids
+func (_m *TagReaderWriter) FindMany(ctx context.Context, ids []int) ([]*models.Tag, error) {
+ ret := _m.Called(ctx, ids)
var r0 []*models.Tag
- if rf, ok := ret.Get(0).(func([]int) []*models.Tag); ok {
- r0 = rf(ids)
+ if rf, ok := ret.Get(0).(func(context.Context, []int) []*models.Tag); ok {
+ r0 = rf(ctx, ids)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Tag)
@@ -397,8 +399,8 @@ func (_m *TagReaderWriter) FindMany(ids []int) ([]*models.Tag, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func([]int) error); ok {
- r1 = rf(ids)
+ if rf, ok := ret.Get(1).(func(context.Context, []int) error); ok {
+ r1 = rf(ctx, ids)
} else {
r1 = ret.Error(1)
}
@@ -406,13 +408,13 @@ func (_m *TagReaderWriter) FindMany(ids []int) ([]*models.Tag, error) {
return r0, r1
}
-// GetAliases provides a mock function with given fields: tagID
-func (_m *TagReaderWriter) GetAliases(tagID int) ([]string, error) {
- ret := _m.Called(tagID)
+// GetAliases provides a mock function with given fields: ctx, tagID
+func (_m *TagReaderWriter) GetAliases(ctx context.Context, tagID int) ([]string, error) {
+ ret := _m.Called(ctx, tagID)
var r0 []string
- if rf, ok := ret.Get(0).(func(int) []string); ok {
- r0 = rf(tagID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []string); ok {
+ r0 = rf(ctx, tagID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]string)
@@ -420,8 +422,8 @@ func (_m *TagReaderWriter) GetAliases(tagID int) ([]string, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(tagID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, tagID)
} else {
r1 = ret.Error(1)
}
@@ -429,13 +431,13 @@ func (_m *TagReaderWriter) GetAliases(tagID int) ([]string, error) {
return r0, r1
}
-// GetImage provides a mock function with given fields: tagID
-func (_m *TagReaderWriter) GetImage(tagID int) ([]byte, error) {
- ret := _m.Called(tagID)
+// GetImage provides a mock function with given fields: ctx, tagID
+func (_m *TagReaderWriter) GetImage(ctx context.Context, tagID int) ([]byte, error) {
+ ret := _m.Called(ctx, tagID)
var r0 []byte
- if rf, ok := ret.Get(0).(func(int) []byte); ok {
- r0 = rf(tagID)
+ if rf, ok := ret.Get(0).(func(context.Context, int) []byte); ok {
+ r0 = rf(ctx, tagID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]byte)
@@ -443,8 +445,8 @@ func (_m *TagReaderWriter) GetImage(tagID int) ([]byte, error) {
}
var r1 error
- if rf, ok := ret.Get(1).(func(int) error); ok {
- r1 = rf(tagID)
+ if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+ r1 = rf(ctx, tagID)
} else {
r1 = ret.Error(1)
}
@@ -452,13 +454,13 @@ func (_m *TagReaderWriter) GetImage(tagID int) ([]byte, error) {
return r0, r1
}
-// Merge provides a mock function with given fields: source, destination
-func (_m *TagReaderWriter) Merge(source []int, destination int) error {
- ret := _m.Called(source, destination)
+// Merge provides a mock function with given fields: ctx, source, destination
+func (_m *TagReaderWriter) Merge(ctx context.Context, source []int, destination int) error {
+ ret := _m.Called(ctx, source, destination)
var r0 error
- if rf, ok := ret.Get(0).(func([]int, int) error); ok {
- r0 = rf(source, destination)
+ if rf, ok := ret.Get(0).(func(context.Context, []int, int) error); ok {
+ r0 = rf(ctx, source, destination)
} else {
r0 = ret.Error(0)
}
@@ -466,13 +468,13 @@ func (_m *TagReaderWriter) Merge(source []int, destination int) error {
return r0
}
-// Query provides a mock function with given fields: tagFilter, findFilter
-func (_m *TagReaderWriter) Query(tagFilter *models.TagFilterType, findFilter *models.FindFilterType) ([]*models.Tag, int, error) {
- ret := _m.Called(tagFilter, findFilter)
+// Query provides a mock function with given fields: ctx, tagFilter, findFilter
+func (_m *TagReaderWriter) Query(ctx context.Context, tagFilter *models.TagFilterType, findFilter *models.FindFilterType) ([]*models.Tag, int, error) {
+ ret := _m.Called(ctx, tagFilter, findFilter)
var r0 []*models.Tag
- if rf, ok := ret.Get(0).(func(*models.TagFilterType, *models.FindFilterType) []*models.Tag); ok {
- r0 = rf(tagFilter, findFilter)
+ if rf, ok := ret.Get(0).(func(context.Context, *models.TagFilterType, *models.FindFilterType) []*models.Tag); ok {
+ r0 = rf(ctx, tagFilter, findFilter)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Tag)
@@ -480,15 +482,15 @@ func (_m *TagReaderWriter) Query(tagFilter *models.TagFilterType, findFilter *mo
}
var r1 int
- if rf, ok := ret.Get(1).(func(*models.TagFilterType, *models.FindFilterType) int); ok {
- r1 = rf(tagFilter, findFilter)
+ if rf, ok := ret.Get(1).(func(context.Context, *models.TagFilterType, *models.FindFilterType) int); ok {
+ r1 = rf(ctx, tagFilter, findFilter)
} else {
r1 = ret.Get(1).(int)
}
var r2 error
- if rf, ok := ret.Get(2).(func(*models.TagFilterType, *models.FindFilterType) error); ok {
- r2 = rf(tagFilter, findFilter)
+ if rf, ok := ret.Get(2).(func(context.Context, *models.TagFilterType, *models.FindFilterType) error); ok {
+ r2 = rf(ctx, tagFilter, findFilter)
} else {
r2 = ret.Error(2)
}
@@ -496,13 +498,13 @@ func (_m *TagReaderWriter) Query(tagFilter *models.TagFilterType, findFilter *mo
return r0, r1, r2
}
-// QueryForAutoTag provides a mock function with given fields: words
-func (_m *TagReaderWriter) QueryForAutoTag(words []string) ([]*models.Tag, error) {
- ret := _m.Called(words)
+// QueryForAutoTag provides a mock function with given fields: ctx, words
+func (_m *TagReaderWriter) QueryForAutoTag(ctx context.Context, words []string) ([]*models.Tag, error) {
+ ret := _m.Called(ctx, words)
var r0 []*models.Tag
- if rf, ok := ret.Get(0).(func([]string) []*models.Tag); ok {
- r0 = rf(words)
+ if rf, ok := ret.Get(0).(func(context.Context, []string) []*models.Tag); ok {
+ r0 = rf(ctx, words)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Tag)
@@ -510,8 +512,8 @@ func (_m *TagReaderWriter) QueryForAutoTag(words []string) ([]*models.Tag, error
}
var r1 error
- if rf, ok := ret.Get(1).(func([]string) error); ok {
- r1 = rf(words)
+ if rf, ok := ret.Get(1).(func(context.Context, []string) error); ok {
+ r1 = rf(ctx, words)
} else {
r1 = ret.Error(1)
}
@@ -519,13 +521,13 @@ func (_m *TagReaderWriter) QueryForAutoTag(words []string) ([]*models.Tag, error
return r0, r1
}
-// Update provides a mock function with given fields: updateTag
-func (_m *TagReaderWriter) Update(updateTag models.TagPartial) (*models.Tag, error) {
- ret := _m.Called(updateTag)
+// Update provides a mock function with given fields: ctx, updateTag
+func (_m *TagReaderWriter) Update(ctx context.Context, updateTag models.TagPartial) (*models.Tag, error) {
+ ret := _m.Called(ctx, updateTag)
var r0 *models.Tag
- if rf, ok := ret.Get(0).(func(models.TagPartial) *models.Tag); ok {
- r0 = rf(updateTag)
+ if rf, ok := ret.Get(0).(func(context.Context, models.TagPartial) *models.Tag); ok {
+ r0 = rf(ctx, updateTag)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Tag)
@@ -533,8 +535,8 @@ func (_m *TagReaderWriter) Update(updateTag models.TagPartial) (*models.Tag, err
}
var r1 error
- if rf, ok := ret.Get(1).(func(models.TagPartial) error); ok {
- r1 = rf(updateTag)
+ if rf, ok := ret.Get(1).(func(context.Context, models.TagPartial) error); ok {
+ r1 = rf(ctx, updateTag)
} else {
r1 = ret.Error(1)
}
@@ -542,13 +544,13 @@ func (_m *TagReaderWriter) Update(updateTag models.TagPartial) (*models.Tag, err
return r0, r1
}
-// UpdateAliases provides a mock function with given fields: tagID, aliases
-func (_m *TagReaderWriter) UpdateAliases(tagID int, aliases []string) error {
- ret := _m.Called(tagID, aliases)
+// UpdateAliases provides a mock function with given fields: ctx, tagID, aliases
+func (_m *TagReaderWriter) UpdateAliases(ctx context.Context, tagID int, aliases []string) error {
+ ret := _m.Called(ctx, tagID, aliases)
var r0 error
- if rf, ok := ret.Get(0).(func(int, []string) error); ok {
- r0 = rf(tagID, aliases)
+ if rf, ok := ret.Get(0).(func(context.Context, int, []string) error); ok {
+ r0 = rf(ctx, tagID, aliases)
} else {
r0 = ret.Error(0)
}
@@ -556,13 +558,13 @@ func (_m *TagReaderWriter) UpdateAliases(tagID int, aliases []string) error {
return r0
}
-// UpdateChildTags provides a mock function with given fields: tagID, parentIDs
-func (_m *TagReaderWriter) UpdateChildTags(tagID int, parentIDs []int) error {
- ret := _m.Called(tagID, parentIDs)
+// UpdateChildTags provides a mock function with given fields: ctx, tagID, parentIDs
+func (_m *TagReaderWriter) UpdateChildTags(ctx context.Context, tagID int, parentIDs []int) error {
+ ret := _m.Called(ctx, tagID, parentIDs)
var r0 error
- if rf, ok := ret.Get(0).(func(int, []int) error); ok {
- r0 = rf(tagID, parentIDs)
+ if rf, ok := ret.Get(0).(func(context.Context, int, []int) error); ok {
+ r0 = rf(ctx, tagID, parentIDs)
} else {
r0 = ret.Error(0)
}
@@ -570,13 +572,13 @@ func (_m *TagReaderWriter) UpdateChildTags(tagID int, parentIDs []int) error {
return r0
}
-// UpdateFull provides a mock function with given fields: updatedTag
-func (_m *TagReaderWriter) UpdateFull(updatedTag models.Tag) (*models.Tag, error) {
- ret := _m.Called(updatedTag)
+// UpdateFull provides a mock function with given fields: ctx, updatedTag
+func (_m *TagReaderWriter) UpdateFull(ctx context.Context, updatedTag models.Tag) (*models.Tag, error) {
+ ret := _m.Called(ctx, updatedTag)
var r0 *models.Tag
- if rf, ok := ret.Get(0).(func(models.Tag) *models.Tag); ok {
- r0 = rf(updatedTag)
+ if rf, ok := ret.Get(0).(func(context.Context, models.Tag) *models.Tag); ok {
+ r0 = rf(ctx, updatedTag)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Tag)
@@ -584,8 +586,8 @@ func (_m *TagReaderWriter) UpdateFull(updatedTag models.Tag) (*models.Tag, error
}
var r1 error
- if rf, ok := ret.Get(1).(func(models.Tag) error); ok {
- r1 = rf(updatedTag)
+ if rf, ok := ret.Get(1).(func(context.Context, models.Tag) error); ok {
+ r1 = rf(ctx, updatedTag)
} else {
r1 = ret.Error(1)
}
@@ -593,13 +595,13 @@ func (_m *TagReaderWriter) UpdateFull(updatedTag models.Tag) (*models.Tag, error
return r0, r1
}
-// UpdateImage provides a mock function with given fields: tagID, image
-func (_m *TagReaderWriter) UpdateImage(tagID int, image []byte) error {
- ret := _m.Called(tagID, image)
+// UpdateImage provides a mock function with given fields: ctx, tagID, image
+func (_m *TagReaderWriter) UpdateImage(ctx context.Context, tagID int, image []byte) error {
+ ret := _m.Called(ctx, tagID, image)
var r0 error
- if rf, ok := ret.Get(0).(func(int, []byte) error); ok {
- r0 = rf(tagID, image)
+ if rf, ok := ret.Get(0).(func(context.Context, int, []byte) error); ok {
+ r0 = rf(ctx, tagID, image)
} else {
r0 = ret.Error(0)
}
@@ -607,13 +609,13 @@ func (_m *TagReaderWriter) UpdateImage(tagID int, image []byte) error {
return r0
}
-// UpdateParentTags provides a mock function with given fields: tagID, parentIDs
-func (_m *TagReaderWriter) UpdateParentTags(tagID int, parentIDs []int) error {
- ret := _m.Called(tagID, parentIDs)
+// UpdateParentTags provides a mock function with given fields: ctx, tagID, parentIDs
+func (_m *TagReaderWriter) UpdateParentTags(ctx context.Context, tagID int, parentIDs []int) error {
+ ret := _m.Called(ctx, tagID, parentIDs)
var r0 error
- if rf, ok := ret.Get(0).(func(int, []int) error); ok {
- r0 = rf(tagID, parentIDs)
+ if rf, ok := ret.Get(0).(func(context.Context, int, []int) error); ok {
+ r0 = rf(ctx, tagID, parentIDs)
} else {
r0 = ret.Error(0)
}
diff --git a/pkg/models/mocks/query.go b/pkg/models/mocks/query.go
index 152335fc2..346bd1e55 100644
--- a/pkg/models/mocks/query.go
+++ b/pkg/models/mocks/query.go
@@ -1,16 +1,20 @@
package mocks
-import "github.com/stashapp/stash/pkg/models"
+import (
+ context "context"
+
+ "github.com/stashapp/stash/pkg/models"
+)
type sceneResolver struct {
scenes []*models.Scene
}
-func (s *sceneResolver) Find(id int) (*models.Scene, error) {
+func (s *sceneResolver) Find(ctx context.Context, id int) (*models.Scene, error) {
panic("not implemented")
}
-func (s *sceneResolver) FindMany(ids []int) ([]*models.Scene, error) {
+func (s *sceneResolver) FindMany(ctx context.Context, ids []int) ([]*models.Scene, error) {
return s.scenes, nil
}
@@ -27,7 +31,7 @@ type imageResolver struct {
images []*models.Image
}
-func (s *imageResolver) FindMany(ids []int) ([]*models.Image, error) {
+func (s *imageResolver) FindMany(ctx context.Context, ids []int) ([]*models.Image, error) {
return s.images, nil
}
diff --git a/pkg/models/mocks/transaction.go b/pkg/models/mocks/transaction.go
index 886fef7d6..ab5c7dba3 100644
--- a/pkg/models/mocks/transaction.go
+++ b/pkg/models/mocks/transaction.go
@@ -1,167 +1,41 @@
package mocks
import (
- "context"
+ context "context"
- models "github.com/stashapp/stash/pkg/models"
+ "github.com/stashapp/stash/pkg/models"
)
-type TransactionManager struct {
- gallery *GalleryReaderWriter
- image *ImageReaderWriter
- movie *MovieReaderWriter
- performer *PerformerReaderWriter
- scene *SceneReaderWriter
- sceneMarker *SceneMarkerReaderWriter
- scrapedItem *ScrapedItemReaderWriter
- studio *StudioReaderWriter
- tag *TagReaderWriter
- savedFilter *SavedFilterReaderWriter
+type TxnManager struct{}
+
+func (*TxnManager) Begin(ctx context.Context) (context.Context, error) {
+ return ctx, nil
}
-func NewTransactionManager() *TransactionManager {
- return &TransactionManager{
- gallery: &GalleryReaderWriter{},
- image: &ImageReaderWriter{},
- movie: &MovieReaderWriter{},
- performer: &PerformerReaderWriter{},
- scene: &SceneReaderWriter{},
- sceneMarker: &SceneMarkerReaderWriter{},
- scrapedItem: &ScrapedItemReaderWriter{},
- studio: &StudioReaderWriter{},
- tag: &TagReaderWriter{},
- savedFilter: &SavedFilterReaderWriter{},
+func (*TxnManager) Commit(ctx context.Context) error {
+ return nil
+}
+
+func (*TxnManager) Rollback(ctx context.Context) error {
+ return nil
+}
+
+func (*TxnManager) Reset() error {
+ return nil
+}
+
+func NewTxnRepository() models.Repository {
+ return models.Repository{
+ TxnManager: &TxnManager{},
+ Gallery: &GalleryReaderWriter{},
+ Image: &ImageReaderWriter{},
+ Movie: &MovieReaderWriter{},
+ Performer: &PerformerReaderWriter{},
+ Scene: &SceneReaderWriter{},
+ SceneMarker: &SceneMarkerReaderWriter{},
+ ScrapedItem: &ScrapedItemReaderWriter{},
+ Studio: &StudioReaderWriter{},
+ Tag: &TagReaderWriter{},
+ SavedFilter: &SavedFilterReaderWriter{},
}
}
-
-func (t *TransactionManager) WithTxn(ctx context.Context, fn func(r models.Repository) error) error {
- return fn(t)
-}
-
-func (t *TransactionManager) GalleryMock() *GalleryReaderWriter {
- return t.gallery
-}
-
-func (t *TransactionManager) ImageMock() *ImageReaderWriter {
- return t.image
-}
-
-func (t *TransactionManager) MovieMock() *MovieReaderWriter {
- return t.movie
-}
-
-func (t *TransactionManager) PerformerMock() *PerformerReaderWriter {
- return t.performer
-}
-
-func (t *TransactionManager) SceneMarkerMock() *SceneMarkerReaderWriter {
- return t.sceneMarker
-}
-
-func (t *TransactionManager) SceneMock() *SceneReaderWriter {
- return t.scene
-}
-
-func (t *TransactionManager) ScrapedItemMock() *ScrapedItemReaderWriter {
- return t.scrapedItem
-}
-
-func (t *TransactionManager) StudioMock() *StudioReaderWriter {
- return t.studio
-}
-
-func (t *TransactionManager) TagMock() *TagReaderWriter {
- return t.tag
-}
-
-func (t *TransactionManager) SavedFilterMock() *SavedFilterReaderWriter {
- return t.savedFilter
-}
-
-func (t *TransactionManager) Gallery() models.GalleryReaderWriter {
- return t.GalleryMock()
-}
-
-func (t *TransactionManager) Image() models.ImageReaderWriter {
- return t.ImageMock()
-}
-
-func (t *TransactionManager) Movie() models.MovieReaderWriter {
- return t.MovieMock()
-}
-
-func (t *TransactionManager) Performer() models.PerformerReaderWriter {
- return t.PerformerMock()
-}
-
-func (t *TransactionManager) SceneMarker() models.SceneMarkerReaderWriter {
- return t.SceneMarkerMock()
-}
-
-func (t *TransactionManager) Scene() models.SceneReaderWriter {
- return t.SceneMock()
-}
-
-func (t *TransactionManager) ScrapedItem() models.ScrapedItemReaderWriter {
- return t.ScrapedItemMock()
-}
-
-func (t *TransactionManager) Studio() models.StudioReaderWriter {
- return t.StudioMock()
-}
-
-func (t *TransactionManager) Tag() models.TagReaderWriter {
- return t.TagMock()
-}
-
-func (t *TransactionManager) SavedFilter() models.SavedFilterReaderWriter {
- return t.SavedFilterMock()
-}
-
-type ReadTransaction struct {
- *TransactionManager
-}
-
-func (t *TransactionManager) WithReadTxn(ctx context.Context, fn func(r models.ReaderRepository) error) error {
- return fn(&ReadTransaction{t})
-}
-
-func (r *ReadTransaction) Gallery() models.GalleryReader {
- return r.GalleryMock()
-}
-
-func (r *ReadTransaction) Image() models.ImageReader {
- return r.ImageMock()
-}
-
-func (r *ReadTransaction) Movie() models.MovieReader {
- return r.MovieMock()
-}
-
-func (r *ReadTransaction) Performer() models.PerformerReader {
- return r.PerformerMock()
-}
-
-func (r *ReadTransaction) SceneMarker() models.SceneMarkerReader {
- return r.SceneMarkerMock()
-}
-
-func (r *ReadTransaction) Scene() models.SceneReader {
- return r.SceneMock()
-}
-
-func (r *ReadTransaction) ScrapedItem() models.ScrapedItemReader {
- return r.ScrapedItemMock()
-}
-
-func (r *ReadTransaction) Studio() models.StudioReader {
- return r.StudioMock()
-}
-
-func (r *ReadTransaction) Tag() models.TagReader {
- return r.TagMock()
-}
-
-func (r *ReadTransaction) SavedFilter() models.SavedFilterReader {
- return r.SavedFilterMock()
-}
diff --git a/pkg/models/movie.go b/pkg/models/movie.go
index 8b68217ce..3fc1890a6 100644
--- a/pkg/models/movie.go
+++ b/pkg/models/movie.go
@@ -1,5 +1,7 @@
package models
+import "context"
+
type MovieFilterType struct {
Name *StringCriterionInput `json:"name"`
Director *StringCriterionInput `json:"director"`
@@ -19,29 +21,29 @@ type MovieFilterType struct {
}
type MovieReader interface {
- Find(id int) (*Movie, error)
- FindMany(ids []int) ([]*Movie, error)
+ Find(ctx context.Context, id int) (*Movie, error)
+ FindMany(ctx context.Context, ids []int) ([]*Movie, error)
// FindBySceneID(sceneID int) ([]*Movie, error)
- FindByName(name string, nocase bool) (*Movie, error)
- FindByNames(names []string, nocase bool) ([]*Movie, error)
- All() ([]*Movie, error)
- Count() (int, error)
- Query(movieFilter *MovieFilterType, findFilter *FindFilterType) ([]*Movie, int, error)
- GetFrontImage(movieID int) ([]byte, error)
- GetBackImage(movieID int) ([]byte, error)
- FindByPerformerID(performerID int) ([]*Movie, error)
- CountByPerformerID(performerID int) (int, error)
- FindByStudioID(studioID int) ([]*Movie, error)
- CountByStudioID(studioID int) (int, error)
+ FindByName(ctx context.Context, name string, nocase bool) (*Movie, error)
+ FindByNames(ctx context.Context, names []string, nocase bool) ([]*Movie, error)
+ All(ctx context.Context) ([]*Movie, error)
+ Count(ctx context.Context) (int, error)
+ Query(ctx context.Context, movieFilter *MovieFilterType, findFilter *FindFilterType) ([]*Movie, int, error)
+ GetFrontImage(ctx context.Context, movieID int) ([]byte, error)
+ GetBackImage(ctx context.Context, movieID int) ([]byte, error)
+ FindByPerformerID(ctx context.Context, performerID int) ([]*Movie, error)
+ CountByPerformerID(ctx context.Context, performerID int) (int, error)
+ FindByStudioID(ctx context.Context, studioID int) ([]*Movie, error)
+ CountByStudioID(ctx context.Context, studioID int) (int, error)
}
type MovieWriter interface {
- Create(newMovie Movie) (*Movie, error)
- Update(updatedMovie MoviePartial) (*Movie, error)
- UpdateFull(updatedMovie Movie) (*Movie, error)
- Destroy(id int) error
- UpdateImages(movieID int, frontImage []byte, backImage []byte) error
- DestroyImages(movieID int) error
+ Create(ctx context.Context, newMovie Movie) (*Movie, error)
+ Update(ctx context.Context, updatedMovie MoviePartial) (*Movie, error)
+ UpdateFull(ctx context.Context, updatedMovie Movie) (*Movie, error)
+ Destroy(ctx context.Context, id int) error
+ UpdateImages(ctx context.Context, movieID int, frontImage []byte, backImage []byte) error
+ DestroyImages(ctx context.Context, movieID int) error
}
type MovieReaderWriter interface {
diff --git a/pkg/models/performer.go b/pkg/models/performer.go
index d50c360a4..1bf3ec918 100644
--- a/pkg/models/performer.go
+++ b/pkg/models/performer.go
@@ -1,6 +1,7 @@
package models
import (
+ "context"
"fmt"
"io"
"strconv"
@@ -125,36 +126,36 @@ type PerformerFilterType struct {
}
type PerformerReader interface {
- Find(id int) (*Performer, error)
- FindMany(ids []int) ([]*Performer, error)
- FindBySceneID(sceneID int) ([]*Performer, error)
- FindNamesBySceneID(sceneID int) ([]*Performer, error)
- FindByImageID(imageID int) ([]*Performer, error)
- FindByGalleryID(galleryID int) ([]*Performer, error)
- FindByNames(names []string, nocase bool) ([]*Performer, error)
- FindByStashID(stashID StashID) ([]*Performer, error)
- FindByStashIDStatus(hasStashID bool, stashboxEndpoint string) ([]*Performer, error)
- CountByTagID(tagID int) (int, error)
- Count() (int, error)
- All() ([]*Performer, error)
+ Find(ctx context.Context, id int) (*Performer, error)
+ FindMany(ctx context.Context, ids []int) ([]*Performer, error)
+ FindBySceneID(ctx context.Context, sceneID int) ([]*Performer, error)
+ FindNamesBySceneID(ctx context.Context, sceneID int) ([]*Performer, error)
+ FindByImageID(ctx context.Context, imageID int) ([]*Performer, error)
+ FindByGalleryID(ctx context.Context, galleryID int) ([]*Performer, error)
+ FindByNames(ctx context.Context, names []string, nocase bool) ([]*Performer, error)
+ FindByStashID(ctx context.Context, stashID StashID) ([]*Performer, error)
+ FindByStashIDStatus(ctx context.Context, hasStashID bool, stashboxEndpoint string) ([]*Performer, error)
+ CountByTagID(ctx context.Context, tagID int) (int, error)
+ Count(ctx context.Context) (int, error)
+ All(ctx context.Context) ([]*Performer, error)
// TODO - this interface is temporary until the filter schema can fully
// support the query needed
- QueryForAutoTag(words []string) ([]*Performer, error)
- Query(performerFilter *PerformerFilterType, findFilter *FindFilterType) ([]*Performer, int, error)
- GetImage(performerID int) ([]byte, error)
- GetStashIDs(performerID int) ([]*StashID, error)
- GetTagIDs(performerID int) ([]int, error)
+ QueryForAutoTag(ctx context.Context, words []string) ([]*Performer, error)
+ Query(ctx context.Context, performerFilter *PerformerFilterType, findFilter *FindFilterType) ([]*Performer, int, error)
+ GetImage(ctx context.Context, performerID int) ([]byte, error)
+ GetStashIDs(ctx context.Context, performerID int) ([]*StashID, error)
+ GetTagIDs(ctx context.Context, performerID int) ([]int, error)
}
type PerformerWriter interface {
- Create(newPerformer Performer) (*Performer, error)
- Update(updatedPerformer PerformerPartial) (*Performer, error)
- UpdateFull(updatedPerformer Performer) (*Performer, error)
- Destroy(id int) error
- UpdateImage(performerID int, image []byte) error
- DestroyImage(performerID int) error
- UpdateStashIDs(performerID int, stashIDs []StashID) error
- UpdateTags(performerID int, tagIDs []int) error
+ Create(ctx context.Context, newPerformer Performer) (*Performer, error)
+ Update(ctx context.Context, updatedPerformer PerformerPartial) (*Performer, error)
+ UpdateFull(ctx context.Context, updatedPerformer Performer) (*Performer, error)
+ Destroy(ctx context.Context, id int) error
+ UpdateImage(ctx context.Context, performerID int, image []byte) error
+ DestroyImage(ctx context.Context, performerID int) error
+ UpdateStashIDs(ctx context.Context, performerID int, stashIDs []StashID) error
+ UpdateTags(ctx context.Context, performerID int, tagIDs []int) error
}
type PerformerReaderWriter interface {
diff --git a/pkg/models/repository.go b/pkg/models/repository.go
index 2686d7c3a..0056ccad3 100644
--- a/pkg/models/repository.go
+++ b/pkg/models/repository.go
@@ -1,27 +1,31 @@
package models
-type Repository interface {
- Gallery() GalleryReaderWriter
- Image() ImageReaderWriter
- Movie() MovieReaderWriter
- Performer() PerformerReaderWriter
- Scene() SceneReaderWriter
- SceneMarker() SceneMarkerReaderWriter
- ScrapedItem() ScrapedItemReaderWriter
- Studio() StudioReaderWriter
- Tag() TagReaderWriter
- SavedFilter() SavedFilterReaderWriter
+import (
+ "context"
+
+ "github.com/stashapp/stash/pkg/txn"
+)
+
+type TxnManager interface {
+ txn.Manager
+ Reset() error
}
-type ReaderRepository interface {
- Gallery() GalleryReader
- Image() ImageReader
- Movie() MovieReader
- Performer() PerformerReader
- Scene() SceneReader
- SceneMarker() SceneMarkerReader
- ScrapedItem() ScrapedItemReader
- Studio() StudioReader
- Tag() TagReader
- SavedFilter() SavedFilterReader
+type Repository struct {
+ TxnManager
+
+ Gallery GalleryReaderWriter
+ Image ImageReaderWriter
+ Movie MovieReaderWriter
+ Performer PerformerReaderWriter
+ Scene SceneReaderWriter
+ SceneMarker SceneMarkerReaderWriter
+ ScrapedItem ScrapedItemReaderWriter
+ Studio StudioReaderWriter
+ Tag TagReaderWriter
+ SavedFilter SavedFilterReaderWriter
+}
+
+func (r *Repository) WithTxn(ctx context.Context, fn txn.TxnFunc) error {
+ return txn.WithTxn(ctx, r, fn)
}
diff --git a/pkg/models/saved_filter.go b/pkg/models/saved_filter.go
index e6cd2f8e0..10dd4af36 100644
--- a/pkg/models/saved_filter.go
+++ b/pkg/models/saved_filter.go
@@ -1,18 +1,20 @@
package models
+import "context"
+
type SavedFilterReader interface {
- All() ([]*SavedFilter, error)
- Find(id int) (*SavedFilter, error)
- FindMany(ids []int, ignoreNotFound bool) ([]*SavedFilter, error)
- FindByMode(mode FilterMode) ([]*SavedFilter, error)
- FindDefault(mode FilterMode) (*SavedFilter, error)
+ All(ctx context.Context) ([]*SavedFilter, error)
+ Find(ctx context.Context, id int) (*SavedFilter, error)
+ FindMany(ctx context.Context, ids []int, ignoreNotFound bool) ([]*SavedFilter, error)
+ FindByMode(ctx context.Context, mode FilterMode) ([]*SavedFilter, error)
+ FindDefault(ctx context.Context, mode FilterMode) (*SavedFilter, error)
}
type SavedFilterWriter interface {
- Create(obj SavedFilter) (*SavedFilter, error)
- Update(obj SavedFilter) (*SavedFilter, error)
- SetDefault(obj SavedFilter) (*SavedFilter, error)
- Destroy(id int) error
+ Create(ctx context.Context, obj SavedFilter) (*SavedFilter, error)
+ Update(ctx context.Context, obj SavedFilter) (*SavedFilter, error)
+ SetDefault(ctx context.Context, obj SavedFilter) (*SavedFilter, error)
+ Destroy(ctx context.Context, id int) error
}
type SavedFilterReaderWriter interface {
diff --git a/pkg/models/scene.go b/pkg/models/scene.go
index 6940abe84..4c6c3aabe 100644
--- a/pkg/models/scene.go
+++ b/pkg/models/scene.go
@@ -1,5 +1,7 @@
package models
+import "context"
+
type PHashDuplicationCriterionInput struct {
Duplicated *bool `json:"duplicated"`
// Currently unimplemented
@@ -102,70 +104,70 @@ func NewSceneQueryResult(finder SceneFinder) *SceneQueryResult {
}
}
-func (r *SceneQueryResult) Resolve() ([]*Scene, error) {
+func (r *SceneQueryResult) Resolve(ctx context.Context) ([]*Scene, error) {
// cache results
if r.scenes == nil && r.resolveErr == nil {
- r.scenes, r.resolveErr = r.finder.FindMany(r.IDs)
+ r.scenes, r.resolveErr = r.finder.FindMany(ctx, r.IDs)
}
return r.scenes, r.resolveErr
}
type SceneFinder interface {
// TODO - rename this to Find and remove existing method
- FindMany(ids []int) ([]*Scene, error)
+ FindMany(ctx context.Context, ids []int) ([]*Scene, error)
}
type SceneReader interface {
SceneFinder
// TODO - remove this in another PR
- Find(id int) (*Scene, error)
- FindByChecksum(checksum string) (*Scene, error)
- FindByOSHash(oshash string) (*Scene, error)
- FindByPath(path string) (*Scene, error)
- FindByPerformerID(performerID int) ([]*Scene, error)
- FindByGalleryID(performerID int) ([]*Scene, error)
- FindDuplicates(distance int) ([][]*Scene, error)
- CountByPerformerID(performerID int) (int, error)
+ Find(ctx context.Context, id int) (*Scene, error)
+ FindByChecksum(ctx context.Context, checksum string) (*Scene, error)
+ FindByOSHash(ctx context.Context, oshash string) (*Scene, error)
+ FindByPath(ctx context.Context, path string) (*Scene, error)
+ FindByPerformerID(ctx context.Context, performerID int) ([]*Scene, error)
+ FindByGalleryID(ctx context.Context, performerID int) ([]*Scene, error)
+ FindDuplicates(ctx context.Context, distance int) ([][]*Scene, error)
+ CountByPerformerID(ctx context.Context, performerID int) (int, error)
// FindByStudioID(studioID int) ([]*Scene, error)
- FindByMovieID(movieID int) ([]*Scene, error)
- CountByMovieID(movieID int) (int, error)
- Count() (int, error)
- Size() (float64, error)
- Duration() (float64, error)
+ FindByMovieID(ctx context.Context, movieID int) ([]*Scene, error)
+ CountByMovieID(ctx context.Context, movieID int) (int, error)
+ Count(ctx context.Context) (int, error)
+ Size(ctx context.Context) (float64, error)
+ Duration(ctx context.Context) (float64, error)
// SizeCount() (string, error)
- CountByStudioID(studioID int) (int, error)
- CountByTagID(tagID int) (int, error)
- CountMissingChecksum() (int, error)
- CountMissingOSHash() (int, error)
- Wall(q *string) ([]*Scene, error)
- All() ([]*Scene, error)
- Query(options SceneQueryOptions) (*SceneQueryResult, error)
- GetCaptions(sceneID int) ([]*SceneCaption, error)
- GetCover(sceneID int) ([]byte, error)
- GetMovies(sceneID int) ([]MoviesScenes, error)
- GetTagIDs(sceneID int) ([]int, error)
- GetGalleryIDs(sceneID int) ([]int, error)
- GetPerformerIDs(sceneID int) ([]int, error)
- GetStashIDs(sceneID int) ([]*StashID, error)
+ CountByStudioID(ctx context.Context, studioID int) (int, error)
+ CountByTagID(ctx context.Context, tagID int) (int, error)
+ CountMissingChecksum(ctx context.Context) (int, error)
+ CountMissingOSHash(ctx context.Context) (int, error)
+ Wall(ctx context.Context, q *string) ([]*Scene, error)
+ All(ctx context.Context) ([]*Scene, error)
+ Query(ctx context.Context, options SceneQueryOptions) (*SceneQueryResult, error)
+ GetCaptions(ctx context.Context, sceneID int) ([]*SceneCaption, error)
+ GetCover(ctx context.Context, sceneID int) ([]byte, error)
+ GetMovies(ctx context.Context, sceneID int) ([]MoviesScenes, error)
+ GetTagIDs(ctx context.Context, sceneID int) ([]int, error)
+ GetGalleryIDs(ctx context.Context, sceneID int) ([]int, error)
+ GetPerformerIDs(ctx context.Context, sceneID int) ([]int, error)
+ GetStashIDs(ctx context.Context, sceneID int) ([]*StashID, error)
}
type SceneWriter interface {
- Create(newScene Scene) (*Scene, error)
- Update(updatedScene ScenePartial) (*Scene, error)
- UpdateFull(updatedScene Scene) (*Scene, error)
- IncrementOCounter(id int) (int, error)
- DecrementOCounter(id int) (int, error)
- ResetOCounter(id int) (int, error)
- UpdateFileModTime(id int, modTime NullSQLiteTimestamp) error
- Destroy(id int) error
- UpdateCaptions(id int, captions []*SceneCaption) error
- UpdateCover(sceneID int, cover []byte) error
- DestroyCover(sceneID int) error
- UpdatePerformers(sceneID int, performerIDs []int) error
- UpdateTags(sceneID int, tagIDs []int) error
- UpdateGalleries(sceneID int, galleryIDs []int) error
- UpdateMovies(sceneID int, movies []MoviesScenes) error
- UpdateStashIDs(sceneID int, stashIDs []StashID) error
+ Create(ctx context.Context, newScene Scene) (*Scene, error)
+ Update(ctx context.Context, updatedScene ScenePartial) (*Scene, error)
+ UpdateFull(ctx context.Context, updatedScene Scene) (*Scene, error)
+ IncrementOCounter(ctx context.Context, id int) (int, error)
+ DecrementOCounter(ctx context.Context, id int) (int, error)
+ ResetOCounter(ctx context.Context, id int) (int, error)
+ UpdateFileModTime(ctx context.Context, id int, modTime NullSQLiteTimestamp) error
+ Destroy(ctx context.Context, id int) error
+ UpdateCaptions(ctx context.Context, id int, captions []*SceneCaption) error
+ UpdateCover(ctx context.Context, sceneID int, cover []byte) error
+ DestroyCover(ctx context.Context, sceneID int) error
+ UpdatePerformers(ctx context.Context, sceneID int, performerIDs []int) error
+ UpdateTags(ctx context.Context, sceneID int, tagIDs []int) error
+ UpdateGalleries(ctx context.Context, sceneID int, galleryIDs []int) error
+ UpdateMovies(ctx context.Context, sceneID int, movies []MoviesScenes) error
+ UpdateStashIDs(ctx context.Context, sceneID int, stashIDs []StashID) error
}
type SceneReaderWriter interface {
diff --git a/pkg/models/scene_marker.go b/pkg/models/scene_marker.go
index ac7501672..dd0b786f6 100644
--- a/pkg/models/scene_marker.go
+++ b/pkg/models/scene_marker.go
@@ -1,5 +1,7 @@
package models
+import "context"
+
type SceneMarkerFilterType struct {
// Filter to only include scene markers with this tag
TagID *string `json:"tag_id"`
@@ -18,21 +20,21 @@ type MarkerStringsResultType struct {
}
type SceneMarkerReader interface {
- Find(id int) (*SceneMarker, error)
- FindMany(ids []int) ([]*SceneMarker, error)
- FindBySceneID(sceneID int) ([]*SceneMarker, error)
- CountByTagID(tagID int) (int, error)
- GetMarkerStrings(q *string, sort *string) ([]*MarkerStringsResultType, error)
- Wall(q *string) ([]*SceneMarker, error)
- Query(sceneMarkerFilter *SceneMarkerFilterType, findFilter *FindFilterType) ([]*SceneMarker, int, error)
- GetTagIDs(imageID int) ([]int, error)
+ Find(ctx context.Context, id int) (*SceneMarker, error)
+ FindMany(ctx context.Context, ids []int) ([]*SceneMarker, error)
+ FindBySceneID(ctx context.Context, sceneID int) ([]*SceneMarker, error)
+ CountByTagID(ctx context.Context, tagID int) (int, error)
+ GetMarkerStrings(ctx context.Context, q *string, sort *string) ([]*MarkerStringsResultType, error)
+ Wall(ctx context.Context, q *string) ([]*SceneMarker, error)
+ Query(ctx context.Context, sceneMarkerFilter *SceneMarkerFilterType, findFilter *FindFilterType) ([]*SceneMarker, int, error)
+ GetTagIDs(ctx context.Context, imageID int) ([]int, error)
}
type SceneMarkerWriter interface {
- Create(newSceneMarker SceneMarker) (*SceneMarker, error)
- Update(updatedSceneMarker SceneMarker) (*SceneMarker, error)
- Destroy(id int) error
- UpdateTags(markerID int, tagIDs []int) error
+ Create(ctx context.Context, newSceneMarker SceneMarker) (*SceneMarker, error)
+ Update(ctx context.Context, updatedSceneMarker SceneMarker) (*SceneMarker, error)
+ Destroy(ctx context.Context, id int) error
+ UpdateTags(ctx context.Context, markerID int, tagIDs []int) error
}
type SceneMarkerReaderWriter interface {
diff --git a/pkg/models/scraped.go b/pkg/models/scraped.go
index f57a8409a..be424147b 100644
--- a/pkg/models/scraped.go
+++ b/pkg/models/scraped.go
@@ -1,15 +1,18 @@
package models
-import "errors"
+import (
+ "context"
+ "errors"
+)
var ErrScraperSource = errors.New("invalid ScraperSource")
type ScrapedItemReader interface {
- All() ([]*ScrapedItem, error)
+ All(ctx context.Context) ([]*ScrapedItem, error)
}
type ScrapedItemWriter interface {
- Create(newObject ScrapedItem) (*ScrapedItem, error)
+ Create(ctx context.Context, newObject ScrapedItem) (*ScrapedItem, error)
}
type ScrapedItemReaderWriter interface {
diff --git a/pkg/models/studio.go b/pkg/models/studio.go
index fb9f7c415..c1f077ce7 100644
--- a/pkg/models/studio.go
+++ b/pkg/models/studio.go
@@ -1,5 +1,7 @@
package models
+import "context"
+
type StudioFilterType struct {
And *StudioFilterType `json:"AND"`
Or *StudioFilterType `json:"OR"`
@@ -29,32 +31,32 @@ type StudioFilterType struct {
}
type StudioReader interface {
- Find(id int) (*Studio, error)
- FindMany(ids []int) ([]*Studio, error)
- FindChildren(id int) ([]*Studio, error)
- FindByName(name string, nocase bool) (*Studio, error)
- FindByStashID(stashID StashID) ([]*Studio, error)
- Count() (int, error)
- All() ([]*Studio, error)
+ Find(ctx context.Context, id int) (*Studio, error)
+ FindMany(ctx context.Context, ids []int) ([]*Studio, error)
+ FindChildren(ctx context.Context, id int) ([]*Studio, error)
+ FindByName(ctx context.Context, name string, nocase bool) (*Studio, error)
+ FindByStashID(ctx context.Context, stashID StashID) ([]*Studio, error)
+ Count(ctx context.Context) (int, error)
+ All(ctx context.Context) ([]*Studio, error)
// TODO - this interface is temporary until the filter schema can fully
// support the query needed
- QueryForAutoTag(words []string) ([]*Studio, error)
- Query(studioFilter *StudioFilterType, findFilter *FindFilterType) ([]*Studio, int, error)
- GetImage(studioID int) ([]byte, error)
- HasImage(studioID int) (bool, error)
- GetStashIDs(studioID int) ([]*StashID, error)
- GetAliases(studioID int) ([]string, error)
+ QueryForAutoTag(ctx context.Context, words []string) ([]*Studio, error)
+ Query(ctx context.Context, studioFilter *StudioFilterType, findFilter *FindFilterType) ([]*Studio, int, error)
+ GetImage(ctx context.Context, studioID int) ([]byte, error)
+ HasImage(ctx context.Context, studioID int) (bool, error)
+ GetStashIDs(ctx context.Context, studioID int) ([]*StashID, error)
+ GetAliases(ctx context.Context, studioID int) ([]string, error)
}
type StudioWriter interface {
- Create(newStudio Studio) (*Studio, error)
- Update(updatedStudio StudioPartial) (*Studio, error)
- UpdateFull(updatedStudio Studio) (*Studio, error)
- Destroy(id int) error
- UpdateImage(studioID int, image []byte) error
- DestroyImage(studioID int) error
- UpdateStashIDs(studioID int, stashIDs []StashID) error
- UpdateAliases(studioID int, aliases []string) error
+ Create(ctx context.Context, newStudio Studio) (*Studio, error)
+ Update(ctx context.Context, updatedStudio StudioPartial) (*Studio, error)
+ UpdateFull(ctx context.Context, updatedStudio Studio) (*Studio, error)
+ Destroy(ctx context.Context, id int) error
+ UpdateImage(ctx context.Context, studioID int, image []byte) error
+ DestroyImage(ctx context.Context, studioID int) error
+ UpdateStashIDs(ctx context.Context, studioID int, stashIDs []StashID) error
+ UpdateAliases(ctx context.Context, studioID int, aliases []string) error
}
type StudioReaderWriter interface {
diff --git a/pkg/models/tag.go b/pkg/models/tag.go
index a7f7518bf..33ff859c6 100644
--- a/pkg/models/tag.go
+++ b/pkg/models/tag.go
@@ -1,5 +1,7 @@
package models
+import "context"
+
type TagFilterType struct {
And *TagFilterType `json:"AND"`
Or *TagFilterType `json:"OR"`
@@ -33,40 +35,40 @@ type TagFilterType struct {
}
type TagReader interface {
- Find(id int) (*Tag, error)
- FindMany(ids []int) ([]*Tag, error)
- FindBySceneID(sceneID int) ([]*Tag, error)
- FindByPerformerID(performerID int) ([]*Tag, error)
- FindBySceneMarkerID(sceneMarkerID int) ([]*Tag, error)
- FindByImageID(imageID int) ([]*Tag, error)
- FindByGalleryID(galleryID int) ([]*Tag, error)
- FindByName(name string, nocase bool) (*Tag, error)
- FindByNames(names []string, nocase bool) ([]*Tag, error)
- FindByParentTagID(parentID int) ([]*Tag, error)
- FindByChildTagID(childID int) ([]*Tag, error)
- Count() (int, error)
- All() ([]*Tag, error)
+ Find(ctx context.Context, id int) (*Tag, error)
+ FindMany(ctx context.Context, ids []int) ([]*Tag, error)
+ FindBySceneID(ctx context.Context, sceneID int) ([]*Tag, error)
+ FindByPerformerID(ctx context.Context, performerID int) ([]*Tag, error)
+ FindBySceneMarkerID(ctx context.Context, sceneMarkerID int) ([]*Tag, error)
+ FindByImageID(ctx context.Context, imageID int) ([]*Tag, error)
+ FindByGalleryID(ctx context.Context, galleryID int) ([]*Tag, error)
+ FindByName(ctx context.Context, name string, nocase bool) (*Tag, error)
+ FindByNames(ctx context.Context, names []string, nocase bool) ([]*Tag, error)
+ FindByParentTagID(ctx context.Context, parentID int) ([]*Tag, error)
+ FindByChildTagID(ctx context.Context, childID int) ([]*Tag, error)
+ Count(ctx context.Context) (int, error)
+ All(ctx context.Context) ([]*Tag, error)
// TODO - this interface is temporary until the filter schema can fully
// support the query needed
- QueryForAutoTag(words []string) ([]*Tag, error)
- Query(tagFilter *TagFilterType, findFilter *FindFilterType) ([]*Tag, int, error)
- GetImage(tagID int) ([]byte, error)
- GetAliases(tagID int) ([]string, error)
- FindAllAncestors(tagID int, excludeIDs []int) ([]*TagPath, error)
- FindAllDescendants(tagID int, excludeIDs []int) ([]*TagPath, error)
+ QueryForAutoTag(ctx context.Context, words []string) ([]*Tag, error)
+ Query(ctx context.Context, tagFilter *TagFilterType, findFilter *FindFilterType) ([]*Tag, int, error)
+ GetImage(ctx context.Context, tagID int) ([]byte, error)
+ GetAliases(ctx context.Context, tagID int) ([]string, error)
+ FindAllAncestors(ctx context.Context, tagID int, excludeIDs []int) ([]*TagPath, error)
+ FindAllDescendants(ctx context.Context, tagID int, excludeIDs []int) ([]*TagPath, error)
}
type TagWriter interface {
- Create(newTag Tag) (*Tag, error)
- Update(updateTag TagPartial) (*Tag, error)
- UpdateFull(updatedTag Tag) (*Tag, error)
- Destroy(id int) error
- UpdateImage(tagID int, image []byte) error
- DestroyImage(tagID int) error
- UpdateAliases(tagID int, aliases []string) error
- Merge(source []int, destination int) error
- UpdateParentTags(tagID int, parentIDs []int) error
- UpdateChildTags(tagID int, parentIDs []int) error
+ Create(ctx context.Context, newTag Tag) (*Tag, error)
+ Update(ctx context.Context, updateTag TagPartial) (*Tag, error)
+ UpdateFull(ctx context.Context, updatedTag Tag) (*Tag, error)
+ Destroy(ctx context.Context, id int) error
+ UpdateImage(ctx context.Context, tagID int, image []byte) error
+ DestroyImage(ctx context.Context, tagID int) error
+ UpdateAliases(ctx context.Context, tagID int, aliases []string) error
+ Merge(ctx context.Context, source []int, destination int) error
+ UpdateParentTags(ctx context.Context, tagID int, parentIDs []int) error
+ UpdateChildTags(ctx context.Context, tagID int, parentIDs []int) error
}
type TagReaderWriter interface {
diff --git a/pkg/models/transaction.go b/pkg/models/transaction.go
deleted file mode 100644
index 291038b0c..000000000
--- a/pkg/models/transaction.go
+++ /dev/null
@@ -1,86 +0,0 @@
-package models
-
-import (
- "context"
-
- "github.com/stashapp/stash/pkg/logger"
-)
-
-type Transaction interface {
- Begin() error
- Rollback() error
- Commit() error
- Repository() Repository
-}
-
-type ReadTransaction interface {
- Begin() error
- Rollback() error
- Commit() error
- Repository() ReaderRepository
-}
-
-type TransactionManager interface {
- WithTxn(ctx context.Context, fn func(r Repository) error) error
- WithReadTxn(ctx context.Context, fn func(r ReaderRepository) error) error
-}
-
-func WithTxn(txn Transaction, fn func(r Repository) error) error {
- err := txn.Begin()
- if err != nil {
- return err
- }
-
- defer func() {
- if p := recover(); p != nil {
- // a panic occurred, rollback and repanic
- if err := txn.Rollback(); err != nil {
- logger.Warnf("error while trying to roll back transaction: %v", err)
- }
- panic(p)
- }
-
- if err != nil {
- // something went wrong, rollback
- if err := txn.Rollback(); err != nil {
- logger.Warnf("error while trying to roll back transaction: %v", err)
- }
- } else {
- // all good, commit
- err = txn.Commit()
- }
- }()
-
- err = fn(txn.Repository())
- return err
-}
-
-func WithROTxn(txn ReadTransaction, fn func(r ReaderRepository) error) error {
- err := txn.Begin()
- if err != nil {
- return err
- }
-
- defer func() {
- if p := recover(); p != nil {
- // a panic occurred, rollback and repanic
- if err := txn.Rollback(); err != nil {
- logger.Warnf("error while trying to roll back RO transaction: %v", err)
- }
- panic(p)
- }
-
- if err != nil {
- // something went wrong, rollback
- if err := txn.Rollback(); err != nil {
- logger.Warnf("error while trying to roll back RO transaction: %v", err)
- }
- } else {
- // all good, commit
- err = txn.Commit()
- }
- }()
-
- err = fn(txn.Repository())
- return err
-}
diff --git a/pkg/movie/export.go b/pkg/movie/export.go
index a70e30290..2af697a49 100644
--- a/pkg/movie/export.go
+++ b/pkg/movie/export.go
@@ -1,16 +1,23 @@
package movie
import (
+ "context"
"fmt"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/json"
"github.com/stashapp/stash/pkg/models/jsonschema"
+ "github.com/stashapp/stash/pkg/studio"
"github.com/stashapp/stash/pkg/utils"
)
+type ImageGetter interface {
+ GetFrontImage(ctx context.Context, movieID int) ([]byte, error)
+ GetBackImage(ctx context.Context, movieID int) ([]byte, error)
+}
+
// ToJSON converts a Movie into its JSON equivalent.
-func ToJSON(reader models.MovieReader, studioReader models.StudioReader, movie *models.Movie) (*jsonschema.Movie, error) {
+func ToJSON(ctx context.Context, reader ImageGetter, studioReader studio.Finder, movie *models.Movie) (*jsonschema.Movie, error) {
newMovieJSON := jsonschema.Movie{
CreatedAt: json.JSONTime{Time: movie.CreatedAt.Timestamp},
UpdatedAt: json.JSONTime{Time: movie.UpdatedAt.Timestamp},
@@ -45,7 +52,7 @@ func ToJSON(reader models.MovieReader, studioReader models.StudioReader, movie *
}
if movie.StudioID.Valid {
- studio, err := studioReader.Find(int(movie.StudioID.Int64))
+ studio, err := studioReader.Find(ctx, int(movie.StudioID.Int64))
if err != nil {
return nil, fmt.Errorf("error getting movie studio: %v", err)
}
@@ -55,7 +62,7 @@ func ToJSON(reader models.MovieReader, studioReader models.StudioReader, movie *
}
}
- frontImage, err := reader.GetFrontImage(movie.ID)
+ frontImage, err := reader.GetFrontImage(ctx, movie.ID)
if err != nil {
return nil, fmt.Errorf("error getting movie front image: %v", err)
}
@@ -64,7 +71,7 @@ func ToJSON(reader models.MovieReader, studioReader models.StudioReader, movie *
newMovieJSON.FrontImage = utils.GetBase64StringFromData(frontImage)
}
- backImage, err := reader.GetBackImage(movie.ID)
+ backImage, err := reader.GetBackImage(ctx, movie.ID)
if err != nil {
return nil, fmt.Errorf("error getting movie back image: %v", err)
}
diff --git a/pkg/movie/export_test.go b/pkg/movie/export_test.go
index 11be97b7b..007383902 100644
--- a/pkg/movie/export_test.go
+++ b/pkg/movie/export_test.go
@@ -55,7 +55,7 @@ var (
backImageBytes = []byte("backImageBytes")
)
-var studio models.Studio = models.Studio{
+var movieStudio models.Studio = models.Studio{
Name: models.NullString(studioName),
}
@@ -189,30 +189,30 @@ func TestToJSON(t *testing.T) {
imageErr := errors.New("error getting image")
- mockMovieReader.On("GetFrontImage", movieID).Return(frontImageBytes, nil).Once()
- mockMovieReader.On("GetFrontImage", missingStudioMovieID).Return(frontImageBytes, nil).Once()
- mockMovieReader.On("GetFrontImage", emptyID).Return(nil, nil).Once().Maybe()
- mockMovieReader.On("GetFrontImage", errFrontImageID).Return(nil, imageErr).Once()
- mockMovieReader.On("GetFrontImage", errBackImageID).Return(frontImageBytes, nil).Once()
+ mockMovieReader.On("GetFrontImage", testCtx, movieID).Return(frontImageBytes, nil).Once()
+ mockMovieReader.On("GetFrontImage", testCtx, missingStudioMovieID).Return(frontImageBytes, nil).Once()
+ mockMovieReader.On("GetFrontImage", testCtx, emptyID).Return(nil, nil).Once().Maybe()
+ mockMovieReader.On("GetFrontImage", testCtx, errFrontImageID).Return(nil, imageErr).Once()
+ mockMovieReader.On("GetFrontImage", testCtx, errBackImageID).Return(frontImageBytes, nil).Once()
- mockMovieReader.On("GetBackImage", movieID).Return(backImageBytes, nil).Once()
- mockMovieReader.On("GetBackImage", missingStudioMovieID).Return(backImageBytes, nil).Once()
- mockMovieReader.On("GetBackImage", emptyID).Return(nil, nil).Once()
- mockMovieReader.On("GetBackImage", errBackImageID).Return(nil, imageErr).Once()
- mockMovieReader.On("GetBackImage", errFrontImageID).Return(backImageBytes, nil).Maybe()
- mockMovieReader.On("GetBackImage", errStudioMovieID).Return(backImageBytes, nil).Maybe()
+ mockMovieReader.On("GetBackImage", testCtx, movieID).Return(backImageBytes, nil).Once()
+ mockMovieReader.On("GetBackImage", testCtx, missingStudioMovieID).Return(backImageBytes, nil).Once()
+ mockMovieReader.On("GetBackImage", testCtx, emptyID).Return(nil, nil).Once()
+ mockMovieReader.On("GetBackImage", testCtx, errBackImageID).Return(nil, imageErr).Once()
+ mockMovieReader.On("GetBackImage", testCtx, errFrontImageID).Return(backImageBytes, nil).Maybe()
+ mockMovieReader.On("GetBackImage", testCtx, errStudioMovieID).Return(backImageBytes, nil).Maybe()
mockStudioReader := &mocks.StudioReaderWriter{}
studioErr := errors.New("error getting studio")
- mockStudioReader.On("Find", studioID).Return(&studio, nil)
- mockStudioReader.On("Find", missingStudioID).Return(nil, nil)
- mockStudioReader.On("Find", errStudioID).Return(nil, studioErr)
+ mockStudioReader.On("Find", testCtx, studioID).Return(&movieStudio, nil)
+ mockStudioReader.On("Find", testCtx, missingStudioID).Return(nil, nil)
+ mockStudioReader.On("Find", testCtx, errStudioID).Return(nil, studioErr)
for i, s := range scenarios {
movie := s.movie
- json, err := ToJSON(mockMovieReader, mockStudioReader, &movie)
+ json, err := ToJSON(testCtx, mockMovieReader, mockStudioReader, &movie)
switch {
case !s.err && err != nil:
diff --git a/pkg/movie/import.go b/pkg/movie/import.go
index 6afdef8b9..461df0f84 100644
--- a/pkg/movie/import.go
+++ b/pkg/movie/import.go
@@ -1,18 +1,26 @@
package movie
import (
+ "context"
"database/sql"
"fmt"
"github.com/stashapp/stash/pkg/hash/md5"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/jsonschema"
+ "github.com/stashapp/stash/pkg/studio"
"github.com/stashapp/stash/pkg/utils"
)
+type NameFinderCreatorUpdater interface {
+ NameFinderCreator
+ UpdateFull(ctx context.Context, updatedMovie models.Movie) (*models.Movie, error)
+ UpdateImages(ctx context.Context, movieID int, frontImage []byte, backImage []byte) error
+}
+
type Importer struct {
- ReaderWriter models.MovieReaderWriter
- StudioWriter models.StudioReaderWriter
+ ReaderWriter NameFinderCreatorUpdater
+ StudioWriter studio.NameFinderCreator
Input jsonschema.Movie
MissingRefBehaviour models.ImportMissingRefEnum
@@ -21,10 +29,10 @@ type Importer struct {
backImageData []byte
}
-func (i *Importer) PreImport() error {
+func (i *Importer) PreImport(ctx context.Context) error {
i.movie = i.movieJSONToMovie(i.Input)
- if err := i.populateStudio(); err != nil {
+ if err := i.populateStudio(ctx); err != nil {
return err
}
@@ -71,9 +79,9 @@ func (i *Importer) movieJSONToMovie(movieJSON jsonschema.Movie) models.Movie {
return newMovie
}
-func (i *Importer) populateStudio() error {
+func (i *Importer) populateStudio(ctx context.Context) error {
if i.Input.Studio != "" {
- studio, err := i.StudioWriter.FindByName(i.Input.Studio, false)
+ studio, err := i.StudioWriter.FindByName(ctx, i.Input.Studio, false)
if err != nil {
return fmt.Errorf("error finding studio by name: %v", err)
}
@@ -88,7 +96,7 @@ func (i *Importer) populateStudio() error {
}
if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate {
- studioID, err := i.createStudio(i.Input.Studio)
+ studioID, err := i.createStudio(ctx, i.Input.Studio)
if err != nil {
return err
}
@@ -105,10 +113,10 @@ func (i *Importer) populateStudio() error {
return nil
}
-func (i *Importer) createStudio(name string) (int, error) {
+func (i *Importer) createStudio(ctx context.Context, name string) (int, error) {
newStudio := *models.NewStudio(name)
- created, err := i.StudioWriter.Create(newStudio)
+ created, err := i.StudioWriter.Create(ctx, newStudio)
if err != nil {
return 0, err
}
@@ -116,9 +124,9 @@ func (i *Importer) createStudio(name string) (int, error) {
return created.ID, nil
}
-func (i *Importer) PostImport(id int) error {
+func (i *Importer) PostImport(ctx context.Context, id int) error {
if len(i.frontImageData) > 0 {
- if err := i.ReaderWriter.UpdateImages(id, i.frontImageData, i.backImageData); err != nil {
+ if err := i.ReaderWriter.UpdateImages(ctx, id, i.frontImageData, i.backImageData); err != nil {
return fmt.Errorf("error setting movie images: %v", err)
}
}
@@ -130,9 +138,9 @@ func (i *Importer) Name() string {
return i.Input.Name
}
-func (i *Importer) FindExistingID() (*int, error) {
+func (i *Importer) FindExistingID(ctx context.Context) (*int, error) {
const nocase = false
- existing, err := i.ReaderWriter.FindByName(i.Name(), nocase)
+ existing, err := i.ReaderWriter.FindByName(ctx, i.Name(), nocase)
if err != nil {
return nil, err
}
@@ -145,8 +153,8 @@ func (i *Importer) FindExistingID() (*int, error) {
return nil, nil
}
-func (i *Importer) Create() (*int, error) {
- created, err := i.ReaderWriter.Create(i.movie)
+func (i *Importer) Create(ctx context.Context) (*int, error) {
+ created, err := i.ReaderWriter.Create(ctx, i.movie)
if err != nil {
return nil, fmt.Errorf("error creating movie: %v", err)
}
@@ -155,10 +163,10 @@ func (i *Importer) Create() (*int, error) {
return &id, nil
}
-func (i *Importer) Update(id int) error {
+func (i *Importer) Update(ctx context.Context, id int) error {
movie := i.movie
movie.ID = id
- _, err := i.ReaderWriter.UpdateFull(movie)
+ _, err := i.ReaderWriter.UpdateFull(ctx, movie)
if err != nil {
return fmt.Errorf("error updating existing movie: %v", err)
}
diff --git a/pkg/movie/import_test.go b/pkg/movie/import_test.go
index 7aff71d47..26b6d9f27 100644
--- a/pkg/movie/import_test.go
+++ b/pkg/movie/import_test.go
@@ -1,6 +1,7 @@
package movie
import (
+ "context"
"errors"
"testing"
@@ -27,6 +28,8 @@ const (
errImageID = 3
)
+var testCtx = context.Background()
+
func TestImporterName(t *testing.T) {
i := Importer{
Input: jsonschema.Movie{
@@ -45,23 +48,23 @@ func TestImporterPreImport(t *testing.T) {
},
}
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.NotNil(t, err)
i.Input.FrontImage = frontImage
i.Input.BackImage = invalidImage
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.NotNil(t, err)
i.Input.BackImage = ""
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.Nil(t, err)
i.Input.BackImage = backImage
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.Nil(t, err)
}
@@ -79,17 +82,17 @@ func TestImporterPreImportWithStudio(t *testing.T) {
},
}
- studioReaderWriter.On("FindByName", existingStudioName, false).Return(&models.Studio{
+ studioReaderWriter.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
ID: existingStudioID,
}, nil).Once()
- studioReaderWriter.On("FindByName", existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
+ studioReaderWriter.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, int64(existingStudioID), i.movie.StudioID.Int64)
i.Input.Studio = existingStudioErr
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.NotNil(t, err)
studioReaderWriter.AssertExpectations(t)
@@ -108,20 +111,20 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
- studioReaderWriter.On("FindByName", missingStudioName, false).Return(nil, nil).Times(3)
- studioReaderWriter.On("Create", mock.AnythingOfType("models.Studio")).Return(&models.Studio{
+ studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
+ studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Studio")).Return(&models.Studio{
ID: existingStudioID,
}, nil)
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.NotNil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.Nil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumCreate
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, int64(existingStudioID), i.movie.StudioID.Int64)
@@ -141,10 +144,10 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
- studioReaderWriter.On("FindByName", missingStudioName, false).Return(nil, nil).Once()
- studioReaderWriter.On("Create", mock.AnythingOfType("models.Studio")).Return(nil, errors.New("Create error"))
+ studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
+ studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Studio")).Return(nil, errors.New("Create error"))
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.NotNil(t, err)
}
@@ -159,13 +162,13 @@ func TestImporterPostImport(t *testing.T) {
updateMovieImageErr := errors.New("UpdateImages error")
- readerWriter.On("UpdateImages", movieID, frontImageBytes, backImageBytes).Return(nil).Once()
- readerWriter.On("UpdateImages", errImageID, frontImageBytes, backImageBytes).Return(updateMovieImageErr).Once()
+ readerWriter.On("UpdateImages", testCtx, movieID, frontImageBytes, backImageBytes).Return(nil).Once()
+ readerWriter.On("UpdateImages", testCtx, errImageID, frontImageBytes, backImageBytes).Return(updateMovieImageErr).Once()
- err := i.PostImport(movieID)
+ err := i.PostImport(testCtx, movieID)
assert.Nil(t, err)
- err = i.PostImport(errImageID)
+ err = i.PostImport(testCtx, errImageID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
@@ -182,23 +185,23 @@ func TestImporterFindExistingID(t *testing.T) {
}
errFindByName := errors.New("FindByName error")
- readerWriter.On("FindByName", movieName, false).Return(nil, nil).Once()
- readerWriter.On("FindByName", existingMovieName, false).Return(&models.Movie{
+ readerWriter.On("FindByName", testCtx, movieName, false).Return(nil, nil).Once()
+ readerWriter.On("FindByName", testCtx, existingMovieName, false).Return(&models.Movie{
ID: existingMovieID,
}, nil).Once()
- readerWriter.On("FindByName", movieNameErr, false).Return(nil, errFindByName).Once()
+ readerWriter.On("FindByName", testCtx, movieNameErr, false).Return(nil, errFindByName).Once()
- id, err := i.FindExistingID()
+ id, err := i.FindExistingID(testCtx)
assert.Nil(t, id)
assert.Nil(t, err)
i.Input.Name = existingMovieName
- id, err = i.FindExistingID()
+ id, err = i.FindExistingID(testCtx)
assert.Equal(t, existingMovieID, *id)
assert.Nil(t, err)
i.Input.Name = movieNameErr
- id, err = i.FindExistingID()
+ id, err = i.FindExistingID(testCtx)
assert.Nil(t, id)
assert.NotNil(t, err)
@@ -222,17 +225,17 @@ func TestCreate(t *testing.T) {
}
errCreate := errors.New("Create error")
- readerWriter.On("Create", movie).Return(&models.Movie{
+ readerWriter.On("Create", testCtx, movie).Return(&models.Movie{
ID: movieID,
}, nil).Once()
- readerWriter.On("Create", movieErr).Return(nil, errCreate).Once()
+ readerWriter.On("Create", testCtx, movieErr).Return(nil, errCreate).Once()
- id, err := i.Create()
+ id, err := i.Create(testCtx)
assert.Equal(t, movieID, *id)
assert.Nil(t, err)
i.movie = movieErr
- id, err = i.Create()
+ id, err = i.Create(testCtx)
assert.Nil(t, id)
assert.NotNil(t, err)
@@ -259,18 +262,18 @@ func TestUpdate(t *testing.T) {
// id needs to be set for the mock input
movie.ID = movieID
- readerWriter.On("UpdateFull", movie).Return(nil, nil).Once()
+ readerWriter.On("UpdateFull", testCtx, movie).Return(nil, nil).Once()
- err := i.Update(movieID)
+ err := i.Update(testCtx, movieID)
assert.Nil(t, err)
i.movie = movieErr
// need to set id separately
movieErr.ID = errImageID
- readerWriter.On("UpdateFull", movieErr).Return(nil, errUpdate).Once()
+ readerWriter.On("UpdateFull", testCtx, movieErr).Return(nil, errUpdate).Once()
- err = i.Update(errImageID)
+ err = i.Update(testCtx, errImageID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
diff --git a/pkg/movie/update.go b/pkg/movie/update.go
new file mode 100644
index 000000000..48dc9c123
--- /dev/null
+++ b/pkg/movie/update.go
@@ -0,0 +1,12 @@
+package movie
+
+import (
+ "context"
+
+ "github.com/stashapp/stash/pkg/models"
+)
+
+type NameFinderCreator interface {
+ FindByName(ctx context.Context, name string, nocase bool) (*models.Movie, error)
+ Create(ctx context.Context, newMovie models.Movie) (*models.Movie, error)
+}
diff --git a/pkg/performer/export.go b/pkg/performer/export.go
index 240d0fc28..a15df7e99 100644
--- a/pkg/performer/export.go
+++ b/pkg/performer/export.go
@@ -1,6 +1,7 @@
package performer
import (
+ "context"
"fmt"
"github.com/stashapp/stash/pkg/models"
@@ -9,8 +10,13 @@ import (
"github.com/stashapp/stash/pkg/utils"
)
+type ImageStashIDGetter interface {
+ GetImage(ctx context.Context, performerID int) ([]byte, error)
+ GetStashIDs(ctx context.Context, performerID int) ([]*models.StashID, error)
+}
+
// ToJSON converts a Performer object into its JSON equivalent.
-func ToJSON(reader models.PerformerReader, performer *models.Performer) (*jsonschema.Performer, error) {
+func ToJSON(ctx context.Context, reader ImageStashIDGetter, performer *models.Performer) (*jsonschema.Performer, error) {
newPerformerJSON := jsonschema.Performer{
IgnoreAutoTag: performer.IgnoreAutoTag,
CreatedAt: json.JSONTime{Time: performer.CreatedAt.Timestamp},
@@ -84,7 +90,7 @@ func ToJSON(reader models.PerformerReader, performer *models.Performer) (*jsonsc
newPerformerJSON.Weight = int(performer.Weight.Int64)
}
- image, err := reader.GetImage(performer.ID)
+ image, err := reader.GetImage(ctx, performer.ID)
if err != nil {
return nil, fmt.Errorf("error getting performers image: %v", err)
}
@@ -93,7 +99,7 @@ func ToJSON(reader models.PerformerReader, performer *models.Performer) (*jsonsc
newPerformerJSON.Image = utils.GetBase64StringFromData(image)
}
- stashIDs, _ := reader.GetStashIDs(performer.ID)
+ stashIDs, _ := reader.GetStashIDs(ctx, performer.ID)
var ret []models.StashID
for _, stashID := range stashIDs {
newJoin := models.StashID{
diff --git a/pkg/performer/export_test.go b/pkg/performer/export_test.go
index 7cfdabb7f..e83d0e189 100644
--- a/pkg/performer/export_test.go
+++ b/pkg/performer/export_test.go
@@ -208,16 +208,16 @@ func TestToJSON(t *testing.T) {
imageErr := errors.New("error getting image")
- mockPerformerReader.On("GetImage", performerID).Return(imageBytes, nil).Once()
- mockPerformerReader.On("GetImage", noImageID).Return(nil, nil).Once()
- mockPerformerReader.On("GetImage", errImageID).Return(nil, imageErr).Once()
+ mockPerformerReader.On("GetImage", testCtx, performerID).Return(imageBytes, nil).Once()
+ mockPerformerReader.On("GetImage", testCtx, noImageID).Return(nil, nil).Once()
+ mockPerformerReader.On("GetImage", testCtx, errImageID).Return(nil, imageErr).Once()
- mockPerformerReader.On("GetStashIDs", performerID).Return(stashIDs, nil).Once()
- mockPerformerReader.On("GetStashIDs", noImageID).Return(nil, nil).Once()
+ mockPerformerReader.On("GetStashIDs", testCtx, performerID).Return(stashIDs, nil).Once()
+ mockPerformerReader.On("GetStashIDs", testCtx, noImageID).Return(nil, nil).Once()
for i, s := range scenarios {
tag := s.input
- json, err := ToJSON(mockPerformerReader, &tag)
+ json, err := ToJSON(testCtx, mockPerformerReader, &tag)
switch {
case !s.err && err != nil:
diff --git a/pkg/performer/import.go b/pkg/performer/import.go
index 9e4ec77f7..7c673fb34 100644
--- a/pkg/performer/import.go
+++ b/pkg/performer/import.go
@@ -1,6 +1,7 @@
package performer
import (
+ "context"
"database/sql"
"fmt"
"strings"
@@ -9,12 +10,21 @@ import (
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/jsonschema"
"github.com/stashapp/stash/pkg/sliceutil/stringslice"
+ "github.com/stashapp/stash/pkg/tag"
"github.com/stashapp/stash/pkg/utils"
)
+type NameFinderCreatorUpdater interface {
+ NameFinderCreator
+ UpdateFull(ctx context.Context, updatedPerformer models.Performer) (*models.Performer, error)
+ UpdateTags(ctx context.Context, performerID int, tagIDs []int) error
+ UpdateImage(ctx context.Context, performerID int, image []byte) error
+ UpdateStashIDs(ctx context.Context, performerID int, stashIDs []models.StashID) error
+}
+
type Importer struct {
- ReaderWriter models.PerformerReaderWriter
- TagWriter models.TagReaderWriter
+ ReaderWriter NameFinderCreatorUpdater
+ TagWriter tag.NameFinderCreator
Input jsonschema.Performer
MissingRefBehaviour models.ImportMissingRefEnum
@@ -25,10 +35,10 @@ type Importer struct {
tags []*models.Tag
}
-func (i *Importer) PreImport() error {
+func (i *Importer) PreImport(ctx context.Context) error {
i.performer = performerJSONToPerformer(i.Input)
- if err := i.populateTags(); err != nil {
+ if err := i.populateTags(ctx); err != nil {
return err
}
@@ -43,10 +53,10 @@ func (i *Importer) PreImport() error {
return nil
}
-func (i *Importer) populateTags() error {
+func (i *Importer) populateTags(ctx context.Context) error {
if len(i.Input.Tags) > 0 {
- tags, err := importTags(i.TagWriter, i.Input.Tags, i.MissingRefBehaviour)
+ tags, err := importTags(ctx, i.TagWriter, i.Input.Tags, i.MissingRefBehaviour)
if err != nil {
return err
}
@@ -57,8 +67,8 @@ func (i *Importer) populateTags() error {
return nil
}
-func importTags(tagWriter models.TagReaderWriter, names []string, missingRefBehaviour models.ImportMissingRefEnum) ([]*models.Tag, error) {
- tags, err := tagWriter.FindByNames(names, false)
+func importTags(ctx context.Context, tagWriter tag.NameFinderCreator, names []string, missingRefBehaviour models.ImportMissingRefEnum) ([]*models.Tag, error) {
+ tags, err := tagWriter.FindByNames(ctx, names, false)
if err != nil {
return nil, err
}
@@ -78,7 +88,7 @@ func importTags(tagWriter models.TagReaderWriter, names []string, missingRefBeha
}
if missingRefBehaviour == models.ImportMissingRefEnumCreate {
- createdTags, err := createTags(tagWriter, missingTags)
+ createdTags, err := createTags(ctx, tagWriter, missingTags)
if err != nil {
return nil, fmt.Errorf("error creating tags: %v", err)
}
@@ -92,12 +102,12 @@ func importTags(tagWriter models.TagReaderWriter, names []string, missingRefBeha
return tags, nil
}
-func createTags(tagWriter models.TagWriter, names []string) ([]*models.Tag, error) {
+func createTags(ctx context.Context, tagWriter tag.NameFinderCreator, names []string) ([]*models.Tag, error) {
var ret []*models.Tag
for _, name := range names {
newTag := *models.NewTag(name)
- created, err := tagWriter.Create(newTag)
+ created, err := tagWriter.Create(ctx, newTag)
if err != nil {
return nil, err
}
@@ -108,25 +118,25 @@ func createTags(tagWriter models.TagWriter, names []string) ([]*models.Tag, erro
return ret, nil
}
-func (i *Importer) PostImport(id int) error {
+func (i *Importer) PostImport(ctx context.Context, id int) error {
if len(i.tags) > 0 {
var tagIDs []int
for _, t := range i.tags {
tagIDs = append(tagIDs, t.ID)
}
- if err := i.ReaderWriter.UpdateTags(id, tagIDs); err != nil {
+ if err := i.ReaderWriter.UpdateTags(ctx, id, tagIDs); err != nil {
return fmt.Errorf("failed to associate tags: %v", err)
}
}
if len(i.imageData) > 0 {
- if err := i.ReaderWriter.UpdateImage(id, i.imageData); err != nil {
+ if err := i.ReaderWriter.UpdateImage(ctx, id, i.imageData); err != nil {
return fmt.Errorf("error setting performer image: %v", err)
}
}
if len(i.Input.StashIDs) > 0 {
- if err := i.ReaderWriter.UpdateStashIDs(id, i.Input.StashIDs); err != nil {
+ if err := i.ReaderWriter.UpdateStashIDs(ctx, id, i.Input.StashIDs); err != nil {
return fmt.Errorf("error setting stash id: %v", err)
}
}
@@ -138,9 +148,9 @@ func (i *Importer) Name() string {
return i.Input.Name
}
-func (i *Importer) FindExistingID() (*int, error) {
+func (i *Importer) FindExistingID(ctx context.Context) (*int, error) {
const nocase = false
- existing, err := i.ReaderWriter.FindByNames([]string{i.Name()}, nocase)
+ existing, err := i.ReaderWriter.FindByNames(ctx, []string{i.Name()}, nocase)
if err != nil {
return nil, err
}
@@ -153,8 +163,8 @@ func (i *Importer) FindExistingID() (*int, error) {
return nil, nil
}
-func (i *Importer) Create() (*int, error) {
- created, err := i.ReaderWriter.Create(i.performer)
+func (i *Importer) Create(ctx context.Context) (*int, error) {
+ created, err := i.ReaderWriter.Create(ctx, i.performer)
if err != nil {
return nil, fmt.Errorf("error creating performer: %v", err)
}
@@ -163,10 +173,10 @@ func (i *Importer) Create() (*int, error) {
return &id, nil
}
-func (i *Importer) Update(id int) error {
+func (i *Importer) Update(ctx context.Context, id int) error {
performer := i.performer
performer.ID = id
- _, err := i.ReaderWriter.UpdateFull(performer)
+ _, err := i.ReaderWriter.UpdateFull(ctx, performer)
if err != nil {
return fmt.Errorf("error updating existing performer: %v", err)
}
diff --git a/pkg/performer/import_test.go b/pkg/performer/import_test.go
index 30ddbae5e..4f80a67c0 100644
--- a/pkg/performer/import_test.go
+++ b/pkg/performer/import_test.go
@@ -1,6 +1,7 @@
package performer
import (
+ "context"
"errors"
"github.com/stretchr/testify/mock"
@@ -29,6 +30,8 @@ const (
missingTagName = "missingTagName"
)
+var testCtx = context.Background()
+
func TestImporterName(t *testing.T) {
i := Importer{
Input: jsonschema.Performer{
@@ -47,13 +50,13 @@ func TestImporterPreImport(t *testing.T) {
},
}
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.NotNil(t, err)
i.Input = *createFullJSONPerformer(performerName, image)
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.Nil(t, err)
expectedPerformer := *createFullPerformer(0, performerName)
@@ -74,20 +77,20 @@ func TestImporterPreImportWithTag(t *testing.T) {
},
}
- tagReaderWriter.On("FindByNames", []string{existingTagName}, false).Return([]*models.Tag{
+ tagReaderWriter.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
{
ID: existingTagID,
Name: existingTagName,
},
}, nil).Once()
- tagReaderWriter.On("FindByNames", []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
+ tagReaderWriter.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, existingTagID, i.tags[0].ID)
i.Input.Tags = []string{existingTagErr}
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.NotNil(t, err)
tagReaderWriter.AssertExpectations(t)
@@ -106,20 +109,20 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
- tagReaderWriter.On("FindByNames", []string{missingTagName}, false).Return(nil, nil).Times(3)
- tagReaderWriter.On("Create", mock.AnythingOfType("models.Tag")).Return(&models.Tag{
+ tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
+ tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Tag")).Return(&models.Tag{
ID: existingTagID,
}, nil)
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.NotNil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.Nil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumCreate
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, existingTagID, i.tags[0].ID)
@@ -139,10 +142,10 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
- tagReaderWriter.On("FindByNames", []string{missingTagName}, false).Return(nil, nil).Once()
- tagReaderWriter.On("Create", mock.AnythingOfType("models.Tag")).Return(nil, errors.New("Create error"))
+ tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
+ tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Tag")).Return(nil, errors.New("Create error"))
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.NotNil(t, err)
}
@@ -156,13 +159,13 @@ func TestImporterPostImport(t *testing.T) {
updatePerformerImageErr := errors.New("UpdateImage error")
- readerWriter.On("UpdateImage", performerID, imageBytes).Return(nil).Once()
- readerWriter.On("UpdateImage", errImageID, imageBytes).Return(updatePerformerImageErr).Once()
+ readerWriter.On("UpdateImage", testCtx, performerID, imageBytes).Return(nil).Once()
+ readerWriter.On("UpdateImage", testCtx, errImageID, imageBytes).Return(updatePerformerImageErr).Once()
- err := i.PostImport(performerID)
+ err := i.PostImport(testCtx, performerID)
assert.Nil(t, err)
- err = i.PostImport(errImageID)
+ err = i.PostImport(testCtx, errImageID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
@@ -179,25 +182,25 @@ func TestImporterFindExistingID(t *testing.T) {
}
errFindByNames := errors.New("FindByNames error")
- readerWriter.On("FindByNames", []string{performerName}, false).Return(nil, nil).Once()
- readerWriter.On("FindByNames", []string{existingPerformerName}, false).Return([]*models.Performer{
+ readerWriter.On("FindByNames", testCtx, []string{performerName}, false).Return(nil, nil).Once()
+ readerWriter.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{
{
ID: existingPerformerID,
},
}, nil).Once()
- readerWriter.On("FindByNames", []string{performerNameErr}, false).Return(nil, errFindByNames).Once()
+ readerWriter.On("FindByNames", testCtx, []string{performerNameErr}, false).Return(nil, errFindByNames).Once()
- id, err := i.FindExistingID()
+ id, err := i.FindExistingID(testCtx)
assert.Nil(t, id)
assert.Nil(t, err)
i.Input.Name = existingPerformerName
- id, err = i.FindExistingID()
+ id, err = i.FindExistingID(testCtx)
assert.Equal(t, existingPerformerID, *id)
assert.Nil(t, err)
i.Input.Name = performerNameErr
- id, err = i.FindExistingID()
+ id, err = i.FindExistingID(testCtx)
assert.Nil(t, id)
assert.NotNil(t, err)
@@ -218,13 +221,13 @@ func TestImporterPostImportUpdateTags(t *testing.T) {
updateErr := errors.New("UpdateTags error")
- readerWriter.On("UpdateTags", performerID, []int{existingTagID}).Return(nil).Once()
- readerWriter.On("UpdateTags", errTagsID, mock.AnythingOfType("[]int")).Return(updateErr).Once()
+ readerWriter.On("UpdateTags", testCtx, performerID, []int{existingTagID}).Return(nil).Once()
+ readerWriter.On("UpdateTags", testCtx, errTagsID, mock.AnythingOfType("[]int")).Return(updateErr).Once()
- err := i.PostImport(performerID)
+ err := i.PostImport(testCtx, performerID)
assert.Nil(t, err)
- err = i.PostImport(errTagsID)
+ err = i.PostImport(testCtx, errTagsID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
@@ -247,17 +250,17 @@ func TestCreate(t *testing.T) {
}
errCreate := errors.New("Create error")
- readerWriter.On("Create", performer).Return(&models.Performer{
+ readerWriter.On("Create", testCtx, performer).Return(&models.Performer{
ID: performerID,
}, nil).Once()
- readerWriter.On("Create", performerErr).Return(nil, errCreate).Once()
+ readerWriter.On("Create", testCtx, performerErr).Return(nil, errCreate).Once()
- id, err := i.Create()
+ id, err := i.Create(testCtx)
assert.Equal(t, performerID, *id)
assert.Nil(t, err)
i.performer = performerErr
- id, err = i.Create()
+ id, err = i.Create(testCtx)
assert.Nil(t, id)
assert.NotNil(t, err)
@@ -284,18 +287,18 @@ func TestUpdate(t *testing.T) {
// id needs to be set for the mock input
performer.ID = performerID
- readerWriter.On("UpdateFull", performer).Return(nil, nil).Once()
+ readerWriter.On("UpdateFull", testCtx, performer).Return(nil, nil).Once()
- err := i.Update(performerID)
+ err := i.Update(testCtx, performerID)
assert.Nil(t, err)
i.performer = performerErr
// need to set id separately
performerErr.ID = errImageID
- readerWriter.On("UpdateFull", performerErr).Return(nil, errUpdate).Once()
+ readerWriter.On("UpdateFull", testCtx, performerErr).Return(nil, errUpdate).Once()
- err = i.Update(errImageID)
+ err = i.Update(testCtx, errImageID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
diff --git a/pkg/performer/update.go b/pkg/performer/update.go
new file mode 100644
index 000000000..5974a5eab
--- /dev/null
+++ b/pkg/performer/update.go
@@ -0,0 +1,12 @@
+package performer
+
+import (
+ "context"
+
+ "github.com/stashapp/stash/pkg/models"
+)
+
+type NameFinderCreator interface {
+ FindByNames(ctx context.Context, names []string, nocase bool) ([]*models.Performer, error)
+ Create(ctx context.Context, newPerformer models.Performer) (*models.Performer, error)
+}
diff --git a/pkg/scene/delete.go b/pkg/scene/delete.go
index 3a31d6f60..7347d68fd 100644
--- a/pkg/scene/delete.go
+++ b/pkg/scene/delete.go
@@ -1,6 +1,7 @@
package scene
import (
+ "context"
"path/filepath"
"github.com/stashapp/stash/pkg/file"
@@ -114,19 +115,25 @@ func (d *FileDeleter) MarkMarkerFiles(scene *models.Scene, seconds int) error {
return d.Files(files)
}
+type Destroyer interface {
+ Destroy(ctx context.Context, id int) error
+}
+
+type MarkerDestroyer interface {
+ FindBySceneID(ctx context.Context, sceneID int) ([]*models.SceneMarker, error)
+ Destroy(ctx context.Context, id int) error
+}
+
// Destroy deletes a scene and its associated relationships from the
// database.
-func Destroy(scene *models.Scene, repo models.Repository, fileDeleter *FileDeleter, deleteGenerated, deleteFile bool) error {
- qb := repo.Scene()
- mqb := repo.SceneMarker()
-
- markers, err := mqb.FindBySceneID(scene.ID)
+func Destroy(ctx context.Context, scene *models.Scene, qb Destroyer, mqb MarkerDestroyer, fileDeleter *FileDeleter, deleteGenerated, deleteFile bool) error {
+ markers, err := mqb.FindBySceneID(ctx, scene.ID)
if err != nil {
return err
}
for _, m := range markers {
- if err := DestroyMarker(scene, m, mqb, fileDeleter); err != nil {
+ if err := DestroyMarker(ctx, scene, m, mqb, fileDeleter); err != nil {
return err
}
}
@@ -151,7 +158,7 @@ func Destroy(scene *models.Scene, repo models.Repository, fileDeleter *FileDelet
}
}
- if err := qb.Destroy(scene.ID); err != nil {
+ if err := qb.Destroy(ctx, scene.ID); err != nil {
return err
}
@@ -161,8 +168,8 @@ func Destroy(scene *models.Scene, repo models.Repository, fileDeleter *FileDelet
// DestroyMarker deletes the scene marker from the database and returns a
// function that removes the generated files, to be executed after the
// transaction is successfully committed.
-func DestroyMarker(scene *models.Scene, sceneMarker *models.SceneMarker, qb models.SceneMarkerWriter, fileDeleter *FileDeleter) error {
- if err := qb.Destroy(sceneMarker.ID); err != nil {
+func DestroyMarker(ctx context.Context, scene *models.Scene, sceneMarker *models.SceneMarker, qb MarkerDestroyer, fileDeleter *FileDeleter) error {
+ if err := qb.Destroy(ctx, sceneMarker.ID); err != nil {
return err
}
diff --git a/pkg/scene/export.go b/pkg/scene/export.go
index c5bda2c47..57557f11a 100644
--- a/pkg/scene/export.go
+++ b/pkg/scene/export.go
@@ -1,6 +1,7 @@
package scene
import (
+ "context"
"fmt"
"math"
"strconv"
@@ -9,13 +10,38 @@ import (
"github.com/stashapp/stash/pkg/models/json"
"github.com/stashapp/stash/pkg/models/jsonschema"
"github.com/stashapp/stash/pkg/sliceutil/intslice"
+ "github.com/stashapp/stash/pkg/studio"
+ "github.com/stashapp/stash/pkg/tag"
"github.com/stashapp/stash/pkg/utils"
)
+type CoverStashIDGetter interface {
+ GetCover(ctx context.Context, sceneID int) ([]byte, error)
+ GetStashIDs(ctx context.Context, sceneID int) ([]*models.StashID, error)
+}
+
+type MovieGetter interface {
+ GetMovies(ctx context.Context, sceneID int) ([]models.MoviesScenes, error)
+}
+
+type MarkerTagFinder interface {
+ tag.Finder
+ TagFinder
+ FindBySceneMarkerID(ctx context.Context, sceneMarkerID int) ([]*models.Tag, error)
+}
+
+type MarkerFinder interface {
+ FindBySceneID(ctx context.Context, sceneID int) ([]*models.SceneMarker, error)
+}
+
+type TagFinder interface {
+ FindBySceneID(ctx context.Context, sceneID int) ([]*models.Tag, error)
+}
+
// ToBasicJSON converts a scene object into its JSON object equivalent. It
// does not convert the relationships to other objects, with the exception
// of cover image.
-func ToBasicJSON(reader models.SceneReader, scene *models.Scene) (*jsonschema.Scene, error) {
+func ToBasicJSON(ctx context.Context, reader CoverStashIDGetter, scene *models.Scene) (*jsonschema.Scene, error) {
newSceneJSON := jsonschema.Scene{
CreatedAt: json.JSONTime{Time: scene.CreatedAt.Timestamp},
UpdatedAt: json.JSONTime{Time: scene.UpdatedAt.Timestamp},
@@ -58,7 +84,7 @@ func ToBasicJSON(reader models.SceneReader, scene *models.Scene) (*jsonschema.Sc
newSceneJSON.File = getSceneFileJSON(scene)
- cover, err := reader.GetCover(scene.ID)
+ cover, err := reader.GetCover(ctx, scene.ID)
if err != nil {
return nil, fmt.Errorf("error getting scene cover: %v", err)
}
@@ -67,7 +93,7 @@ func ToBasicJSON(reader models.SceneReader, scene *models.Scene) (*jsonschema.Sc
newSceneJSON.Cover = utils.GetBase64StringFromData(cover)
}
- stashIDs, _ := reader.GetStashIDs(scene.ID)
+ stashIDs, _ := reader.GetStashIDs(ctx, scene.ID)
var ret []models.StashID
for _, stashID := range stashIDs {
newJoin := models.StashID{
@@ -130,9 +156,9 @@ func getSceneFileJSON(scene *models.Scene) *jsonschema.SceneFile {
// GetStudioName returns the name of the provided scene's studio. It returns an
// empty string if there is no studio assigned to the scene.
-func GetStudioName(reader models.StudioReader, scene *models.Scene) (string, error) {
+func GetStudioName(ctx context.Context, reader studio.Finder, scene *models.Scene) (string, error) {
if scene.StudioID.Valid {
- studio, err := reader.Find(int(scene.StudioID.Int64))
+ studio, err := reader.Find(ctx, int(scene.StudioID.Int64))
if err != nil {
return "", err
}
@@ -147,8 +173,8 @@ func GetStudioName(reader models.StudioReader, scene *models.Scene) (string, err
// GetTagNames returns a slice of tag names corresponding to the provided
// scene's tags.
-func GetTagNames(reader models.TagReader, scene *models.Scene) ([]string, error) {
- tags, err := reader.FindBySceneID(scene.ID)
+func GetTagNames(ctx context.Context, reader TagFinder, scene *models.Scene) ([]string, error) {
+ tags, err := reader.FindBySceneID(ctx, scene.ID)
if err != nil {
return nil, fmt.Errorf("error getting scene tags: %v", err)
}
@@ -168,10 +194,10 @@ func getTagNames(tags []*models.Tag) []string {
}
// GetDependentTagIDs returns a slice of unique tag IDs that this scene references.
-func GetDependentTagIDs(tags models.TagReader, markerReader models.SceneMarkerReader, scene *models.Scene) ([]int, error) {
+func GetDependentTagIDs(ctx context.Context, tags MarkerTagFinder, markerReader MarkerFinder, scene *models.Scene) ([]int, error) {
var ret []int
- t, err := tags.FindBySceneID(scene.ID)
+ t, err := tags.FindBySceneID(ctx, scene.ID)
if err != nil {
return nil, err
}
@@ -180,14 +206,14 @@ func GetDependentTagIDs(tags models.TagReader, markerReader models.SceneMarkerRe
ret = intslice.IntAppendUnique(ret, tt.ID)
}
- sm, err := markerReader.FindBySceneID(scene.ID)
+ sm, err := markerReader.FindBySceneID(ctx, scene.ID)
if err != nil {
return nil, err
}
for _, smm := range sm {
ret = intslice.IntAppendUnique(ret, smm.PrimaryTagID)
- smmt, err := tags.FindBySceneMarkerID(smm.ID)
+ smmt, err := tags.FindBySceneMarkerID(ctx, smm.ID)
if err != nil {
return nil, fmt.Errorf("invalid tags for scene marker: %v", err)
}
@@ -200,17 +226,21 @@ func GetDependentTagIDs(tags models.TagReader, markerReader models.SceneMarkerRe
return ret, nil
}
+type MovieFinder interface {
+ Find(ctx context.Context, id int) (*models.Movie, error)
+}
+
// GetSceneMoviesJSON returns a slice of SceneMovie JSON representation objects
// corresponding to the provided scene's scene movie relationships.
-func GetSceneMoviesJSON(movieReader models.MovieReader, sceneReader models.SceneReader, scene *models.Scene) ([]jsonschema.SceneMovie, error) {
- sceneMovies, err := sceneReader.GetMovies(scene.ID)
+func GetSceneMoviesJSON(ctx context.Context, movieReader MovieFinder, sceneReader MovieGetter, scene *models.Scene) ([]jsonschema.SceneMovie, error) {
+ sceneMovies, err := sceneReader.GetMovies(ctx, scene.ID)
if err != nil {
return nil, fmt.Errorf("error getting scene movies: %v", err)
}
var results []jsonschema.SceneMovie
for _, sceneMovie := range sceneMovies {
- movie, err := movieReader.Find(sceneMovie.MovieID)
+ movie, err := movieReader.Find(ctx, sceneMovie.MovieID)
if err != nil {
return nil, fmt.Errorf("error getting movie: %v", err)
}
@@ -228,10 +258,10 @@ func GetSceneMoviesJSON(movieReader models.MovieReader, sceneReader models.Scene
}
// GetDependentMovieIDs returns a slice of movie IDs that this scene references.
-func GetDependentMovieIDs(sceneReader models.SceneReader, scene *models.Scene) ([]int, error) {
+func GetDependentMovieIDs(ctx context.Context, sceneReader MovieGetter, scene *models.Scene) ([]int, error) {
var ret []int
- m, err := sceneReader.GetMovies(scene.ID)
+ m, err := sceneReader.GetMovies(ctx, scene.ID)
if err != nil {
return nil, err
}
@@ -245,8 +275,8 @@ func GetDependentMovieIDs(sceneReader models.SceneReader, scene *models.Scene) (
// GetSceneMarkersJSON returns a slice of SceneMarker JSON representation
// objects corresponding to the provided scene's markers.
-func GetSceneMarkersJSON(markerReader models.SceneMarkerReader, tagReader models.TagReader, scene *models.Scene) ([]jsonschema.SceneMarker, error) {
- sceneMarkers, err := markerReader.FindBySceneID(scene.ID)
+func GetSceneMarkersJSON(ctx context.Context, markerReader MarkerFinder, tagReader MarkerTagFinder, scene *models.Scene) ([]jsonschema.SceneMarker, error) {
+ sceneMarkers, err := markerReader.FindBySceneID(ctx, scene.ID)
if err != nil {
return nil, fmt.Errorf("error getting scene markers: %v", err)
}
@@ -254,12 +284,12 @@ func GetSceneMarkersJSON(markerReader models.SceneMarkerReader, tagReader models
var results []jsonschema.SceneMarker
for _, sceneMarker := range sceneMarkers {
- primaryTag, err := tagReader.Find(sceneMarker.PrimaryTagID)
+ primaryTag, err := tagReader.Find(ctx, sceneMarker.PrimaryTagID)
if err != nil {
return nil, fmt.Errorf("invalid primary tag for scene marker: %v", err)
}
- sceneMarkerTags, err := tagReader.FindBySceneMarkerID(sceneMarker.ID)
+ sceneMarkerTags, err := tagReader.FindBySceneMarkerID(ctx, sceneMarker.ID)
if err != nil {
return nil, fmt.Errorf("invalid tags for scene marker: %v", err)
}
diff --git a/pkg/scene/export_test.go b/pkg/scene/export_test.go
index aa8b7fb52..ae6efc725 100644
--- a/pkg/scene/export_test.go
+++ b/pkg/scene/export_test.go
@@ -230,16 +230,16 @@ func TestToJSON(t *testing.T) {
imageErr := errors.New("error getting image")
- mockSceneReader.On("GetCover", sceneID).Return(imageBytes, nil).Once()
- mockSceneReader.On("GetCover", noImageID).Return(nil, nil).Once()
- mockSceneReader.On("GetCover", errImageID).Return(nil, imageErr).Once()
+ mockSceneReader.On("GetCover", testCtx, sceneID).Return(imageBytes, nil).Once()
+ mockSceneReader.On("GetCover", testCtx, noImageID).Return(nil, nil).Once()
+ mockSceneReader.On("GetCover", testCtx, errImageID).Return(nil, imageErr).Once()
- mockSceneReader.On("GetStashIDs", sceneID).Return(stashIDs, nil).Once()
- mockSceneReader.On("GetStashIDs", noImageID).Return(nil, nil).Once()
+ mockSceneReader.On("GetStashIDs", testCtx, sceneID).Return(stashIDs, nil).Once()
+ mockSceneReader.On("GetStashIDs", testCtx, noImageID).Return(nil, nil).Once()
for i, s := range scenarios {
scene := s.input
- json, err := ToBasicJSON(mockSceneReader, &scene)
+ json, err := ToBasicJSON(testCtx, mockSceneReader, &scene)
switch {
case !s.err && err != nil:
@@ -289,15 +289,15 @@ func TestGetStudioName(t *testing.T) {
studioErr := errors.New("error getting image")
- mockStudioReader.On("Find", studioID).Return(&models.Studio{
+ mockStudioReader.On("Find", testCtx, studioID).Return(&models.Studio{
Name: models.NullString(studioName),
}, nil).Once()
- mockStudioReader.On("Find", missingStudioID).Return(nil, nil).Once()
- mockStudioReader.On("Find", errStudioID).Return(nil, studioErr).Once()
+ mockStudioReader.On("Find", testCtx, missingStudioID).Return(nil, nil).Once()
+ mockStudioReader.On("Find", testCtx, errStudioID).Return(nil, studioErr).Once()
for i, s := range getStudioScenarios {
scene := s.input
- json, err := GetStudioName(mockStudioReader, &scene)
+ json, err := GetStudioName(testCtx, mockStudioReader, &scene)
switch {
case !s.err && err != nil:
@@ -352,13 +352,13 @@ func TestGetTagNames(t *testing.T) {
tagErr := errors.New("error getting tag")
- mockTagReader.On("FindBySceneID", sceneID).Return(getTags(names), nil).Once()
- mockTagReader.On("FindBySceneID", noTagsID).Return(nil, nil).Once()
- mockTagReader.On("FindBySceneID", errTagsID).Return(nil, tagErr).Once()
+ mockTagReader.On("FindBySceneID", testCtx, sceneID).Return(getTags(names), nil).Once()
+ mockTagReader.On("FindBySceneID", testCtx, noTagsID).Return(nil, nil).Once()
+ mockTagReader.On("FindBySceneID", testCtx, errTagsID).Return(nil, tagErr).Once()
for i, s := range getTagNamesScenarios {
scene := s.input
- json, err := GetTagNames(mockTagReader, &scene)
+ json, err := GetTagNames(testCtx, mockTagReader, &scene)
switch {
case !s.err && err != nil:
@@ -436,22 +436,22 @@ func TestGetSceneMoviesJSON(t *testing.T) {
joinErr := errors.New("error getting scene movies")
movieErr := errors.New("error getting movie")
- mockSceneReader.On("GetMovies", sceneID).Return(validMovies, nil).Once()
- mockSceneReader.On("GetMovies", noMoviesID).Return(nil, nil).Once()
- mockSceneReader.On("GetMovies", errMoviesID).Return(nil, joinErr).Once()
- mockSceneReader.On("GetMovies", errFindMovieID).Return(invalidMovies, nil).Once()
+ mockSceneReader.On("GetMovies", testCtx, sceneID).Return(validMovies, nil).Once()
+ mockSceneReader.On("GetMovies", testCtx, noMoviesID).Return(nil, nil).Once()
+ mockSceneReader.On("GetMovies", testCtx, errMoviesID).Return(nil, joinErr).Once()
+ mockSceneReader.On("GetMovies", testCtx, errFindMovieID).Return(invalidMovies, nil).Once()
- mockMovieReader.On("Find", validMovie1).Return(&models.Movie{
+ mockMovieReader.On("Find", testCtx, validMovie1).Return(&models.Movie{
Name: models.NullString(movie1Name),
}, nil).Once()
- mockMovieReader.On("Find", validMovie2).Return(&models.Movie{
+ mockMovieReader.On("Find", testCtx, validMovie2).Return(&models.Movie{
Name: models.NullString(movie2Name),
}, nil).Once()
- mockMovieReader.On("Find", invalidMovie).Return(nil, movieErr).Once()
+ mockMovieReader.On("Find", testCtx, invalidMovie).Return(nil, movieErr).Once()
for i, s := range getSceneMoviesJSONScenarios {
scene := s.input
- json, err := GetSceneMoviesJSON(mockMovieReader, mockSceneReader, &scene)
+ json, err := GetSceneMoviesJSON(testCtx, mockMovieReader, mockSceneReader, &scene)
switch {
case !s.err && err != nil:
@@ -603,21 +603,21 @@ func TestGetSceneMarkersJSON(t *testing.T) {
markersErr := errors.New("error getting scene markers")
tagErr := errors.New("error getting tags")
- mockMarkerReader.On("FindBySceneID", sceneID).Return(validMarkers, nil).Once()
- mockMarkerReader.On("FindBySceneID", noMarkersID).Return(nil, nil).Once()
- mockMarkerReader.On("FindBySceneID", errMarkersID).Return(nil, markersErr).Once()
- mockMarkerReader.On("FindBySceneID", errFindPrimaryTagID).Return(invalidMarkers1, nil).Once()
- mockMarkerReader.On("FindBySceneID", errFindByMarkerID).Return(invalidMarkers2, nil).Once()
+ mockMarkerReader.On("FindBySceneID", testCtx, sceneID).Return(validMarkers, nil).Once()
+ mockMarkerReader.On("FindBySceneID", testCtx, noMarkersID).Return(nil, nil).Once()
+ mockMarkerReader.On("FindBySceneID", testCtx, errMarkersID).Return(nil, markersErr).Once()
+ mockMarkerReader.On("FindBySceneID", testCtx, errFindPrimaryTagID).Return(invalidMarkers1, nil).Once()
+ mockMarkerReader.On("FindBySceneID", testCtx, errFindByMarkerID).Return(invalidMarkers2, nil).Once()
- mockTagReader.On("Find", validTagID1).Return(&models.Tag{
+ mockTagReader.On("Find", testCtx, validTagID1).Return(&models.Tag{
Name: validTagName1,
}, nil)
- mockTagReader.On("Find", validTagID2).Return(&models.Tag{
+ mockTagReader.On("Find", testCtx, validTagID2).Return(&models.Tag{
Name: validTagName2,
}, nil)
- mockTagReader.On("Find", invalidTagID).Return(nil, tagErr)
+ mockTagReader.On("Find", testCtx, invalidTagID).Return(nil, tagErr)
- mockTagReader.On("FindBySceneMarkerID", validMarkerID1).Return([]*models.Tag{
+ mockTagReader.On("FindBySceneMarkerID", testCtx, validMarkerID1).Return([]*models.Tag{
{
Name: validTagName1,
},
@@ -625,16 +625,16 @@ func TestGetSceneMarkersJSON(t *testing.T) {
Name: validTagName2,
},
}, nil)
- mockTagReader.On("FindBySceneMarkerID", validMarkerID2).Return([]*models.Tag{
+ mockTagReader.On("FindBySceneMarkerID", testCtx, validMarkerID2).Return([]*models.Tag{
{
Name: validTagName2,
},
}, nil)
- mockTagReader.On("FindBySceneMarkerID", invalidMarkerID2).Return(nil, tagErr).Once()
+ mockTagReader.On("FindBySceneMarkerID", testCtx, invalidMarkerID2).Return(nil, tagErr).Once()
for i, s := range getSceneMarkersJSONScenarios {
scene := s.input
- json, err := GetSceneMarkersJSON(mockMarkerReader, mockTagReader, &scene)
+ json, err := GetSceneMarkersJSON(testCtx, mockMarkerReader, mockTagReader, &scene)
switch {
case !s.err && err != nil:
diff --git a/pkg/scene/import.go b/pkg/scene/import.go
index 103be88fd..d7b59cf8b 100644
--- a/pkg/scene/import.go
+++ b/pkg/scene/import.go
@@ -1,24 +1,37 @@
package scene
import (
+ "context"
"database/sql"
"fmt"
"strconv"
"strings"
+ "github.com/stashapp/stash/pkg/gallery"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/jsonschema"
+ "github.com/stashapp/stash/pkg/movie"
+ "github.com/stashapp/stash/pkg/performer"
"github.com/stashapp/stash/pkg/sliceutil/stringslice"
+ "github.com/stashapp/stash/pkg/studio"
+ "github.com/stashapp/stash/pkg/tag"
"github.com/stashapp/stash/pkg/utils"
)
+type FullCreatorUpdater interface {
+ CreatorUpdater
+ Updater
+ UpdateGalleries(ctx context.Context, sceneID int, galleryIDs []int) error
+ UpdateMovies(ctx context.Context, sceneID int, movies []models.MoviesScenes) error
+}
+
type Importer struct {
- ReaderWriter models.SceneReaderWriter
- StudioWriter models.StudioReaderWriter
- GalleryWriter models.GalleryReaderWriter
- PerformerWriter models.PerformerReaderWriter
- MovieWriter models.MovieReaderWriter
- TagWriter models.TagReaderWriter
+ ReaderWriter FullCreatorUpdater
+ StudioWriter studio.NameFinderCreator
+ GalleryWriter gallery.ChecksumsFinder
+ PerformerWriter performer.NameFinderCreator
+ MovieWriter movie.NameFinderCreator
+ TagWriter tag.NameFinderCreator
Input jsonschema.Scene
Path string
MissingRefBehaviour models.ImportMissingRefEnum
@@ -33,26 +46,26 @@ type Importer struct {
coverImageData []byte
}
-func (i *Importer) PreImport() error {
+func (i *Importer) PreImport(ctx context.Context) error {
i.scene = i.sceneJSONToScene(i.Input)
- if err := i.populateStudio(); err != nil {
+ if err := i.populateStudio(ctx); err != nil {
return err
}
- if err := i.populateGalleries(); err != nil {
+ if err := i.populateGalleries(ctx); err != nil {
return err
}
- if err := i.populatePerformers(); err != nil {
+ if err := i.populatePerformers(ctx); err != nil {
return err
}
- if err := i.populateTags(); err != nil {
+ if err := i.populateTags(ctx); err != nil {
return err
}
- if err := i.populateMovies(); err != nil {
+ if err := i.populateMovies(ctx); err != nil {
return err
}
@@ -135,9 +148,9 @@ func (i *Importer) sceneJSONToScene(sceneJSON jsonschema.Scene) models.Scene {
return newScene
}
-func (i *Importer) populateStudio() error {
+func (i *Importer) populateStudio(ctx context.Context) error {
if i.Input.Studio != "" {
- studio, err := i.StudioWriter.FindByName(i.Input.Studio, false)
+ studio, err := i.StudioWriter.FindByName(ctx, i.Input.Studio, false)
if err != nil {
return fmt.Errorf("error finding studio by name: %v", err)
}
@@ -152,7 +165,7 @@ func (i *Importer) populateStudio() error {
}
if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate {
- studioID, err := i.createStudio(i.Input.Studio)
+ studioID, err := i.createStudio(ctx, i.Input.Studio)
if err != nil {
return err
}
@@ -169,10 +182,10 @@ func (i *Importer) populateStudio() error {
return nil
}
-func (i *Importer) createStudio(name string) (int, error) {
+func (i *Importer) createStudio(ctx context.Context, name string) (int, error) {
newStudio := *models.NewStudio(name)
- created, err := i.StudioWriter.Create(newStudio)
+ created, err := i.StudioWriter.Create(ctx, newStudio)
if err != nil {
return 0, err
}
@@ -180,10 +193,10 @@ func (i *Importer) createStudio(name string) (int, error) {
return created.ID, nil
}
-func (i *Importer) populateGalleries() error {
+func (i *Importer) populateGalleries(ctx context.Context) error {
if len(i.Input.Galleries) > 0 {
checksums := i.Input.Galleries
- galleries, err := i.GalleryWriter.FindByChecksums(checksums)
+ galleries, err := i.GalleryWriter.FindByChecksums(ctx, checksums)
if err != nil {
return err
}
@@ -211,10 +224,10 @@ func (i *Importer) populateGalleries() error {
return nil
}
-func (i *Importer) populatePerformers() error {
+func (i *Importer) populatePerformers(ctx context.Context) error {
if len(i.Input.Performers) > 0 {
names := i.Input.Performers
- performers, err := i.PerformerWriter.FindByNames(names, false)
+ performers, err := i.PerformerWriter.FindByNames(ctx, names, false)
if err != nil {
return err
}
@@ -237,7 +250,7 @@ func (i *Importer) populatePerformers() error {
}
if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate {
- createdPerformers, err := i.createPerformers(missingPerformers)
+ createdPerformers, err := i.createPerformers(ctx, missingPerformers)
if err != nil {
return fmt.Errorf("error creating scene performers: %v", err)
}
@@ -254,12 +267,12 @@ func (i *Importer) populatePerformers() error {
return nil
}
-func (i *Importer) createPerformers(names []string) ([]*models.Performer, error) {
+func (i *Importer) createPerformers(ctx context.Context, names []string) ([]*models.Performer, error) {
var ret []*models.Performer
for _, name := range names {
newPerformer := *models.NewPerformer(name)
- created, err := i.PerformerWriter.Create(newPerformer)
+ created, err := i.PerformerWriter.Create(ctx, newPerformer)
if err != nil {
return nil, err
}
@@ -270,10 +283,10 @@ func (i *Importer) createPerformers(names []string) ([]*models.Performer, error)
return ret, nil
}
-func (i *Importer) populateMovies() error {
+func (i *Importer) populateMovies(ctx context.Context) error {
if len(i.Input.Movies) > 0 {
for _, inputMovie := range i.Input.Movies {
- movie, err := i.MovieWriter.FindByName(inputMovie.MovieName, false)
+ movie, err := i.MovieWriter.FindByName(ctx, inputMovie.MovieName, false)
if err != nil {
return fmt.Errorf("error finding scene movie: %v", err)
}
@@ -284,7 +297,7 @@ func (i *Importer) populateMovies() error {
}
if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate {
- movie, err = i.createMovie(inputMovie.MovieName)
+ movie, err = i.createMovie(ctx, inputMovie.MovieName)
if err != nil {
return fmt.Errorf("error creating scene movie: %v", err)
}
@@ -314,10 +327,10 @@ func (i *Importer) populateMovies() error {
return nil
}
-func (i *Importer) createMovie(name string) (*models.Movie, error) {
+func (i *Importer) createMovie(ctx context.Context, name string) (*models.Movie, error) {
newMovie := *models.NewMovie(name)
- created, err := i.MovieWriter.Create(newMovie)
+ created, err := i.MovieWriter.Create(ctx, newMovie)
if err != nil {
return nil, err
}
@@ -325,10 +338,10 @@ func (i *Importer) createMovie(name string) (*models.Movie, error) {
return created, nil
}
-func (i *Importer) populateTags() error {
+func (i *Importer) populateTags(ctx context.Context) error {
if len(i.Input.Tags) > 0 {
- tags, err := importTags(i.TagWriter, i.Input.Tags, i.MissingRefBehaviour)
+ tags, err := importTags(ctx, i.TagWriter, i.Input.Tags, i.MissingRefBehaviour)
if err != nil {
return err
}
@@ -339,9 +352,9 @@ func (i *Importer) populateTags() error {
return nil
}
-func (i *Importer) PostImport(id int) error {
+func (i *Importer) PostImport(ctx context.Context, id int) error {
if len(i.coverImageData) > 0 {
- if err := i.ReaderWriter.UpdateCover(id, i.coverImageData); err != nil {
+ if err := i.ReaderWriter.UpdateCover(ctx, id, i.coverImageData); err != nil {
return fmt.Errorf("error setting scene images: %v", err)
}
}
@@ -352,7 +365,7 @@ func (i *Importer) PostImport(id int) error {
galleryIDs = append(galleryIDs, gallery.ID)
}
- if err := i.ReaderWriter.UpdateGalleries(id, galleryIDs); err != nil {
+ if err := i.ReaderWriter.UpdateGalleries(ctx, id, galleryIDs); err != nil {
return fmt.Errorf("failed to associate galleries: %v", err)
}
}
@@ -363,7 +376,7 @@ func (i *Importer) PostImport(id int) error {
performerIDs = append(performerIDs, performer.ID)
}
- if err := i.ReaderWriter.UpdatePerformers(id, performerIDs); err != nil {
+ if err := i.ReaderWriter.UpdatePerformers(ctx, id, performerIDs); err != nil {
return fmt.Errorf("failed to associate performers: %v", err)
}
}
@@ -372,7 +385,7 @@ func (i *Importer) PostImport(id int) error {
for index := range i.movies {
i.movies[index].SceneID = id
}
- if err := i.ReaderWriter.UpdateMovies(id, i.movies); err != nil {
+ if err := i.ReaderWriter.UpdateMovies(ctx, id, i.movies); err != nil {
return fmt.Errorf("failed to associate movies: %v", err)
}
}
@@ -382,13 +395,13 @@ func (i *Importer) PostImport(id int) error {
for _, t := range i.tags {
tagIDs = append(tagIDs, t.ID)
}
- if err := i.ReaderWriter.UpdateTags(id, tagIDs); err != nil {
+ if err := i.ReaderWriter.UpdateTags(ctx, id, tagIDs); err != nil {
return fmt.Errorf("failed to associate tags: %v", err)
}
}
if len(i.Input.StashIDs) > 0 {
- if err := i.ReaderWriter.UpdateStashIDs(id, i.Input.StashIDs); err != nil {
+ if err := i.ReaderWriter.UpdateStashIDs(ctx, id, i.Input.StashIDs); err != nil {
return fmt.Errorf("error setting stash id: %v", err)
}
}
@@ -400,15 +413,15 @@ func (i *Importer) Name() string {
return i.Path
}
-func (i *Importer) FindExistingID() (*int, error) {
+func (i *Importer) FindExistingID(ctx context.Context) (*int, error) {
var existing *models.Scene
var err error
switch i.FileNamingAlgorithm {
case models.HashAlgorithmMd5:
- existing, err = i.ReaderWriter.FindByChecksum(i.Input.Checksum)
+ existing, err = i.ReaderWriter.FindByChecksum(ctx, i.Input.Checksum)
case models.HashAlgorithmOshash:
- existing, err = i.ReaderWriter.FindByOSHash(i.Input.OSHash)
+ existing, err = i.ReaderWriter.FindByOSHash(ctx, i.Input.OSHash)
default:
panic("unknown file naming algorithm")
}
@@ -425,8 +438,8 @@ func (i *Importer) FindExistingID() (*int, error) {
return nil, nil
}
-func (i *Importer) Create() (*int, error) {
- created, err := i.ReaderWriter.Create(i.scene)
+func (i *Importer) Create(ctx context.Context) (*int, error) {
+ created, err := i.ReaderWriter.Create(ctx, i.scene)
if err != nil {
return nil, fmt.Errorf("error creating scene: %v", err)
}
@@ -436,11 +449,11 @@ func (i *Importer) Create() (*int, error) {
return &id, nil
}
-func (i *Importer) Update(id int) error {
+func (i *Importer) Update(ctx context.Context, id int) error {
scene := i.scene
scene.ID = id
i.ID = id
- _, err := i.ReaderWriter.UpdateFull(scene)
+ _, err := i.ReaderWriter.UpdateFull(ctx, scene)
if err != nil {
return fmt.Errorf("error updating existing scene: %v", err)
}
@@ -448,8 +461,8 @@ func (i *Importer) Update(id int) error {
return nil
}
-func importTags(tagWriter models.TagReaderWriter, names []string, missingRefBehaviour models.ImportMissingRefEnum) ([]*models.Tag, error) {
- tags, err := tagWriter.FindByNames(names, false)
+func importTags(ctx context.Context, tagWriter tag.NameFinderCreator, names []string, missingRefBehaviour models.ImportMissingRefEnum) ([]*models.Tag, error) {
+ tags, err := tagWriter.FindByNames(ctx, names, false)
if err != nil {
return nil, err
}
@@ -469,7 +482,7 @@ func importTags(tagWriter models.TagReaderWriter, names []string, missingRefBeha
}
if missingRefBehaviour == models.ImportMissingRefEnumCreate {
- createdTags, err := createTags(tagWriter, missingTags)
+ createdTags, err := createTags(ctx, tagWriter, missingTags)
if err != nil {
return nil, fmt.Errorf("error creating tags: %v", err)
}
@@ -483,12 +496,12 @@ func importTags(tagWriter models.TagReaderWriter, names []string, missingRefBeha
return tags, nil
}
-func createTags(tagWriter models.TagWriter, names []string) ([]*models.Tag, error) {
+func createTags(ctx context.Context, tagWriter tag.NameFinderCreator, names []string) ([]*models.Tag, error) {
var ret []*models.Tag
for _, name := range names {
newTag := *models.NewTag(name)
- created, err := tagWriter.Create(newTag)
+ created, err := tagWriter.Create(ctx, newTag)
if err != nil {
return nil, err
}
diff --git a/pkg/scene/import_test.go b/pkg/scene/import_test.go
index 499f27299..75dab2200 100644
--- a/pkg/scene/import_test.go
+++ b/pkg/scene/import_test.go
@@ -1,6 +1,7 @@
package scene
import (
+ "context"
"errors"
"testing"
@@ -55,6 +56,8 @@ const (
errOSHash = "errOSHash"
)
+var testCtx = context.Background()
+
func TestImporterName(t *testing.T) {
i := Importer{
Path: path,
@@ -72,17 +75,18 @@ func TestImporterPreImport(t *testing.T) {
},
}
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.NotNil(t, err)
i.Input.Cover = imageBase64
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.Nil(t, err)
}
func TestImporterPreImportWithStudio(t *testing.T) {
studioReaderWriter := &mocks.StudioReaderWriter{}
+ testCtx := context.Background()
i := Importer{
StudioWriter: studioReaderWriter,
@@ -92,17 +96,17 @@ func TestImporterPreImportWithStudio(t *testing.T) {
},
}
- studioReaderWriter.On("FindByName", existingStudioName, false).Return(&models.Studio{
+ studioReaderWriter.On("FindByName", testCtx, existingStudioName, false).Return(&models.Studio{
ID: existingStudioID,
}, nil).Once()
- studioReaderWriter.On("FindByName", existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
+ studioReaderWriter.On("FindByName", testCtx, existingStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, int64(existingStudioID), i.scene.StudioID.Int64)
i.Input.Studio = existingStudioErr
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.NotNil(t, err)
studioReaderWriter.AssertExpectations(t)
@@ -120,20 +124,20 @@ func TestImporterPreImportWithMissingStudio(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
- studioReaderWriter.On("FindByName", missingStudioName, false).Return(nil, nil).Times(3)
- studioReaderWriter.On("Create", mock.AnythingOfType("models.Studio")).Return(&models.Studio{
+ studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Times(3)
+ studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Studio")).Return(&models.Studio{
ID: existingStudioID,
}, nil)
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.NotNil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.Nil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumCreate
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, int64(existingStudioID), i.scene.StudioID.Int64)
@@ -152,10 +156,10 @@ func TestImporterPreImportWithMissingStudioCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
- studioReaderWriter.On("FindByName", missingStudioName, false).Return(nil, nil).Once()
- studioReaderWriter.On("Create", mock.AnythingOfType("models.Studio")).Return(nil, errors.New("Create error"))
+ studioReaderWriter.On("FindByName", testCtx, missingStudioName, false).Return(nil, nil).Once()
+ studioReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Studio")).Return(nil, errors.New("Create error"))
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.NotNil(t, err)
}
@@ -173,21 +177,21 @@ func TestImporterPreImportWithGallery(t *testing.T) {
},
}
- galleryReaderWriter.On("FindByChecksums", []string{existingGalleryChecksum}).Return([]*models.Gallery{
+ galleryReaderWriter.On("FindByChecksums", testCtx, []string{existingGalleryChecksum}).Return([]*models.Gallery{
{
ID: existingGalleryID,
Checksum: existingGalleryChecksum,
},
}, nil).Once()
- galleryReaderWriter.On("FindByChecksums", []string{existingGalleryErr}).Return(nil, errors.New("FindByChecksums error")).Once()
+ galleryReaderWriter.On("FindByChecksums", testCtx, []string{existingGalleryErr}).Return(nil, errors.New("FindByChecksums error")).Once()
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, existingGalleryID, i.galleries[0].ID)
i.Input.Galleries = []string{existingGalleryErr}
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.NotNil(t, err)
galleryReaderWriter.AssertExpectations(t)
@@ -207,17 +211,17 @@ func TestImporterPreImportWithMissingGallery(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
- galleryReaderWriter.On("FindByChecksums", []string{missingGalleryChecksum}).Return(nil, nil).Times(3)
+ galleryReaderWriter.On("FindByChecksums", testCtx, []string{missingGalleryChecksum}).Return(nil, nil).Times(3)
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.NotNil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.Nil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumCreate
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.Nil(t, err)
galleryReaderWriter.AssertExpectations(t)
@@ -237,20 +241,20 @@ func TestImporterPreImportWithPerformer(t *testing.T) {
},
}
- performerReaderWriter.On("FindByNames", []string{existingPerformerName}, false).Return([]*models.Performer{
+ performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{
{
ID: existingPerformerID,
Name: models.NullString(existingPerformerName),
},
}, nil).Once()
- performerReaderWriter.On("FindByNames", []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once()
+ performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once()
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, existingPerformerID, i.performers[0].ID)
i.Input.Performers = []string{existingPerformerErr}
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.NotNil(t, err)
performerReaderWriter.AssertExpectations(t)
@@ -270,20 +274,20 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
- performerReaderWriter.On("FindByNames", []string{missingPerformerName}, false).Return(nil, nil).Times(3)
- performerReaderWriter.On("Create", mock.AnythingOfType("models.Performer")).Return(&models.Performer{
+ performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3)
+ performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Performer")).Return(&models.Performer{
ID: existingPerformerID,
}, nil)
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.NotNil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.Nil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumCreate
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, existingPerformerID, i.performers[0].ID)
@@ -304,15 +308,16 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
- performerReaderWriter.On("FindByNames", []string{missingPerformerName}, false).Return(nil, nil).Once()
- performerReaderWriter.On("Create", mock.AnythingOfType("models.Performer")).Return(nil, errors.New("Create error"))
+ performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once()
+ performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Performer")).Return(nil, errors.New("Create error"))
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.NotNil(t, err)
}
func TestImporterPreImportWithMovie(t *testing.T) {
movieReaderWriter := &mocks.MovieReaderWriter{}
+ testCtx := context.Background()
i := Importer{
MovieWriter: movieReaderWriter,
@@ -328,18 +333,18 @@ func TestImporterPreImportWithMovie(t *testing.T) {
},
}
- movieReaderWriter.On("FindByName", existingMovieName, false).Return(&models.Movie{
+ movieReaderWriter.On("FindByName", testCtx, existingMovieName, false).Return(&models.Movie{
ID: existingMovieID,
Name: models.NullString(existingMovieName),
}, nil).Once()
- movieReaderWriter.On("FindByName", existingMovieErr, false).Return(nil, errors.New("FindByName error")).Once()
+ movieReaderWriter.On("FindByName", testCtx, existingMovieErr, false).Return(nil, errors.New("FindByName error")).Once()
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, existingMovieID, i.movies[0].MovieID)
i.Input.Movies[0].MovieName = existingMovieErr
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.NotNil(t, err)
movieReaderWriter.AssertExpectations(t)
@@ -347,6 +352,7 @@ func TestImporterPreImportWithMovie(t *testing.T) {
func TestImporterPreImportWithMissingMovie(t *testing.T) {
movieReaderWriter := &mocks.MovieReaderWriter{}
+ testCtx := context.Background()
i := Importer{
Path: path,
@@ -361,20 +367,20 @@ func TestImporterPreImportWithMissingMovie(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
- movieReaderWriter.On("FindByName", missingMovieName, false).Return(nil, nil).Times(3)
- movieReaderWriter.On("Create", mock.AnythingOfType("models.Movie")).Return(&models.Movie{
+ movieReaderWriter.On("FindByName", testCtx, missingMovieName, false).Return(nil, nil).Times(3)
+ movieReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Movie")).Return(&models.Movie{
ID: existingMovieID,
}, nil)
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.NotNil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.Nil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumCreate
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, existingMovieID, i.movies[0].MovieID)
@@ -397,10 +403,10 @@ func TestImporterPreImportWithMissingMovieCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
- movieReaderWriter.On("FindByName", missingMovieName, false).Return(nil, nil).Once()
- movieReaderWriter.On("Create", mock.AnythingOfType("models.Movie")).Return(nil, errors.New("Create error"))
+ movieReaderWriter.On("FindByName", testCtx, missingMovieName, false).Return(nil, nil).Once()
+ movieReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Movie")).Return(nil, errors.New("Create error"))
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.NotNil(t, err)
}
@@ -418,20 +424,20 @@ func TestImporterPreImportWithTag(t *testing.T) {
},
}
- tagReaderWriter.On("FindByNames", []string{existingTagName}, false).Return([]*models.Tag{
+ tagReaderWriter.On("FindByNames", testCtx, []string{existingTagName}, false).Return([]*models.Tag{
{
ID: existingTagID,
Name: existingTagName,
},
}, nil).Once()
- tagReaderWriter.On("FindByNames", []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
+ tagReaderWriter.On("FindByNames", testCtx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Once()
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, existingTagID, i.tags[0].ID)
i.Input.Tags = []string{existingTagErr}
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.NotNil(t, err)
tagReaderWriter.AssertExpectations(t)
@@ -451,20 +457,20 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
- tagReaderWriter.On("FindByNames", []string{missingTagName}, false).Return(nil, nil).Times(3)
- tagReaderWriter.On("Create", mock.AnythingOfType("models.Tag")).Return(&models.Tag{
+ tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
+ tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Tag")).Return(&models.Tag{
ID: existingTagID,
}, nil)
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.NotNil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.Nil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumCreate
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.Nil(t, err)
assert.Equal(t, existingTagID, i.tags[0].ID)
@@ -485,10 +491,10 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
- tagReaderWriter.On("FindByNames", []string{missingTagName}, false).Return(nil, nil).Once()
- tagReaderWriter.On("Create", mock.AnythingOfType("models.Tag")).Return(nil, errors.New("Create error"))
+ tagReaderWriter.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
+ tagReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Tag")).Return(nil, errors.New("Create error"))
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.NotNil(t, err)
}
@@ -502,13 +508,13 @@ func TestImporterPostImport(t *testing.T) {
updateSceneImageErr := errors.New("UpdateCover error")
- readerWriter.On("UpdateCover", sceneID, imageBytes).Return(nil).Once()
- readerWriter.On("UpdateCover", errImageID, imageBytes).Return(updateSceneImageErr).Once()
+ readerWriter.On("UpdateCover", testCtx, sceneID, imageBytes).Return(nil).Once()
+ readerWriter.On("UpdateCover", testCtx, errImageID, imageBytes).Return(updateSceneImageErr).Once()
- err := i.PostImport(sceneID)
+ err := i.PostImport(testCtx, sceneID)
assert.Nil(t, err)
- err = i.PostImport(errImageID)
+ err = i.PostImport(testCtx, errImageID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
@@ -528,13 +534,13 @@ func TestImporterPostImportUpdateGalleries(t *testing.T) {
updateErr := errors.New("UpdateGalleries error")
- sceneReaderWriter.On("UpdateGalleries", sceneID, []int{existingGalleryID}).Return(nil).Once()
- sceneReaderWriter.On("UpdateGalleries", errGalleriesID, mock.AnythingOfType("[]int")).Return(updateErr).Once()
+ sceneReaderWriter.On("UpdateGalleries", testCtx, sceneID, []int{existingGalleryID}).Return(nil).Once()
+ sceneReaderWriter.On("UpdateGalleries", testCtx, errGalleriesID, mock.AnythingOfType("[]int")).Return(updateErr).Once()
- err := i.PostImport(sceneID)
+ err := i.PostImport(testCtx, sceneID)
assert.Nil(t, err)
- err = i.PostImport(errGalleriesID)
+ err = i.PostImport(testCtx, errGalleriesID)
assert.NotNil(t, err)
sceneReaderWriter.AssertExpectations(t)
@@ -554,13 +560,13 @@ func TestImporterPostImportUpdatePerformers(t *testing.T) {
updateErr := errors.New("UpdatePerformers error")
- sceneReaderWriter.On("UpdatePerformers", sceneID, []int{existingPerformerID}).Return(nil).Once()
- sceneReaderWriter.On("UpdatePerformers", errPerformersID, mock.AnythingOfType("[]int")).Return(updateErr).Once()
+ sceneReaderWriter.On("UpdatePerformers", testCtx, sceneID, []int{existingPerformerID}).Return(nil).Once()
+ sceneReaderWriter.On("UpdatePerformers", testCtx, errPerformersID, mock.AnythingOfType("[]int")).Return(updateErr).Once()
- err := i.PostImport(sceneID)
+ err := i.PostImport(testCtx, sceneID)
assert.Nil(t, err)
- err = i.PostImport(errPerformersID)
+ err = i.PostImport(testCtx, errPerformersID)
assert.NotNil(t, err)
sceneReaderWriter.AssertExpectations(t)
@@ -580,18 +586,18 @@ func TestImporterPostImportUpdateMovies(t *testing.T) {
updateErr := errors.New("UpdateMovies error")
- sceneReaderWriter.On("UpdateMovies", sceneID, []models.MoviesScenes{
+ sceneReaderWriter.On("UpdateMovies", testCtx, sceneID, []models.MoviesScenes{
{
MovieID: existingMovieID,
SceneID: sceneID,
},
}).Return(nil).Once()
- sceneReaderWriter.On("UpdateMovies", errMoviesID, mock.AnythingOfType("[]models.MoviesScenes")).Return(updateErr).Once()
+ sceneReaderWriter.On("UpdateMovies", testCtx, errMoviesID, mock.AnythingOfType("[]models.MoviesScenes")).Return(updateErr).Once()
- err := i.PostImport(sceneID)
+ err := i.PostImport(testCtx, sceneID)
assert.Nil(t, err)
- err = i.PostImport(errMoviesID)
+ err = i.PostImport(testCtx, errMoviesID)
assert.NotNil(t, err)
sceneReaderWriter.AssertExpectations(t)
@@ -611,13 +617,13 @@ func TestImporterPostImportUpdateTags(t *testing.T) {
updateErr := errors.New("UpdateTags error")
- sceneReaderWriter.On("UpdateTags", sceneID, []int{existingTagID}).Return(nil).Once()
- sceneReaderWriter.On("UpdateTags", errTagsID, mock.AnythingOfType("[]int")).Return(updateErr).Once()
+ sceneReaderWriter.On("UpdateTags", testCtx, sceneID, []int{existingTagID}).Return(nil).Once()
+ sceneReaderWriter.On("UpdateTags", testCtx, errTagsID, mock.AnythingOfType("[]int")).Return(updateErr).Once()
- err := i.PostImport(sceneID)
+ err := i.PostImport(testCtx, sceneID)
assert.Nil(t, err)
- err = i.PostImport(errTagsID)
+ err = i.PostImport(testCtx, errTagsID)
assert.NotNil(t, err)
sceneReaderWriter.AssertExpectations(t)
@@ -637,44 +643,44 @@ func TestImporterFindExistingID(t *testing.T) {
}
expectedErr := errors.New("FindBy* error")
- readerWriter.On("FindByChecksum", missingChecksum).Return(nil, nil).Once()
- readerWriter.On("FindByChecksum", checksum).Return(&models.Scene{
+ readerWriter.On("FindByChecksum", testCtx, missingChecksum).Return(nil, nil).Once()
+ readerWriter.On("FindByChecksum", testCtx, checksum).Return(&models.Scene{
ID: existingSceneID,
}, nil).Once()
- readerWriter.On("FindByChecksum", errChecksum).Return(nil, expectedErr).Once()
+ readerWriter.On("FindByChecksum", testCtx, errChecksum).Return(nil, expectedErr).Once()
- readerWriter.On("FindByOSHash", missingOSHash).Return(nil, nil).Once()
- readerWriter.On("FindByOSHash", oshash).Return(&models.Scene{
+ readerWriter.On("FindByOSHash", testCtx, missingOSHash).Return(nil, nil).Once()
+ readerWriter.On("FindByOSHash", testCtx, oshash).Return(&models.Scene{
ID: existingSceneID,
}, nil).Once()
- readerWriter.On("FindByOSHash", errOSHash).Return(nil, expectedErr).Once()
+ readerWriter.On("FindByOSHash", testCtx, errOSHash).Return(nil, expectedErr).Once()
- id, err := i.FindExistingID()
+ id, err := i.FindExistingID(testCtx)
assert.Nil(t, id)
assert.Nil(t, err)
i.Input.Checksum = checksum
- id, err = i.FindExistingID()
+ id, err = i.FindExistingID(testCtx)
assert.Equal(t, existingSceneID, *id)
assert.Nil(t, err)
i.Input.Checksum = errChecksum
- id, err = i.FindExistingID()
+ id, err = i.FindExistingID(testCtx)
assert.Nil(t, id)
assert.NotNil(t, err)
i.FileNamingAlgorithm = models.HashAlgorithmOshash
- id, err = i.FindExistingID()
+ id, err = i.FindExistingID(testCtx)
assert.Nil(t, id)
assert.Nil(t, err)
i.Input.OSHash = oshash
- id, err = i.FindExistingID()
+ id, err = i.FindExistingID(testCtx)
assert.Equal(t, existingSceneID, *id)
assert.Nil(t, err)
i.Input.OSHash = errOSHash
- id, err = i.FindExistingID()
+ id, err = i.FindExistingID(testCtx)
assert.Nil(t, id)
assert.NotNil(t, err)
@@ -698,18 +704,18 @@ func TestCreate(t *testing.T) {
}
errCreate := errors.New("Create error")
- readerWriter.On("Create", scene).Return(&models.Scene{
+ readerWriter.On("Create", testCtx, scene).Return(&models.Scene{
ID: sceneID,
}, nil).Once()
- readerWriter.On("Create", sceneErr).Return(nil, errCreate).Once()
+ readerWriter.On("Create", testCtx, sceneErr).Return(nil, errCreate).Once()
- id, err := i.Create()
+ id, err := i.Create(testCtx)
assert.Equal(t, sceneID, *id)
assert.Nil(t, err)
assert.Equal(t, sceneID, i.ID)
i.scene = sceneErr
- id, err = i.Create()
+ id, err = i.Create(testCtx)
assert.Nil(t, id)
assert.NotNil(t, err)
@@ -736,9 +742,9 @@ func TestUpdate(t *testing.T) {
// id needs to be set for the mock input
scene.ID = sceneID
- readerWriter.On("UpdateFull", scene).Return(nil, nil).Once()
+ readerWriter.On("UpdateFull", testCtx, scene).Return(nil, nil).Once()
- err := i.Update(sceneID)
+ err := i.Update(testCtx, sceneID)
assert.Nil(t, err)
assert.Equal(t, sceneID, i.ID)
@@ -746,9 +752,9 @@ func TestUpdate(t *testing.T) {
// need to set id separately
sceneErr.ID = errImageID
- readerWriter.On("UpdateFull", sceneErr).Return(nil, errUpdate).Once()
+ readerWriter.On("UpdateFull", testCtx, sceneErr).Return(nil, errUpdate).Once()
- err = i.Update(errImageID)
+ err = i.Update(testCtx, errImageID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
diff --git a/pkg/scene/marker_import.go b/pkg/scene/marker_import.go
index 530d025ea..32f6deb65 100644
--- a/pkg/scene/marker_import.go
+++ b/pkg/scene/marker_import.go
@@ -1,18 +1,27 @@
package scene
import (
+ "context"
"database/sql"
"fmt"
"strconv"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/jsonschema"
+ "github.com/stashapp/stash/pkg/tag"
)
+type MarkerCreatorUpdater interface {
+ Create(ctx context.Context, newSceneMarker models.SceneMarker) (*models.SceneMarker, error)
+ Update(ctx context.Context, updatedSceneMarker models.SceneMarker) (*models.SceneMarker, error)
+ FindBySceneID(ctx context.Context, sceneID int) ([]*models.SceneMarker, error)
+ UpdateTags(ctx context.Context, markerID int, tagIDs []int) error
+}
+
type MarkerImporter struct {
SceneID int
- ReaderWriter models.SceneMarkerReaderWriter
- TagWriter models.TagReaderWriter
+ ReaderWriter MarkerCreatorUpdater
+ TagWriter tag.NameFinderCreator
Input jsonschema.SceneMarker
MissingRefBehaviour models.ImportMissingRefEnum
@@ -20,7 +29,7 @@ type MarkerImporter struct {
marker models.SceneMarker
}
-func (i *MarkerImporter) PreImport() error {
+func (i *MarkerImporter) PreImport(ctx context.Context) error {
seconds, _ := strconv.ParseFloat(i.Input.Seconds, 64)
i.marker = models.SceneMarker{
Title: i.Input.Title,
@@ -30,21 +39,21 @@ func (i *MarkerImporter) PreImport() error {
UpdatedAt: models.SQLiteTimestamp{Timestamp: i.Input.UpdatedAt.GetTime()},
}
- if err := i.populateTags(); err != nil {
+ if err := i.populateTags(ctx); err != nil {
return err
}
return nil
}
-func (i *MarkerImporter) populateTags() error {
+func (i *MarkerImporter) populateTags(ctx context.Context) error {
// primary tag cannot be ignored
mrb := i.MissingRefBehaviour
if mrb == models.ImportMissingRefEnumIgnore {
mrb = models.ImportMissingRefEnumFail
}
- primaryTag, err := importTags(i.TagWriter, []string{i.Input.PrimaryTag}, mrb)
+ primaryTag, err := importTags(ctx, i.TagWriter, []string{i.Input.PrimaryTag}, mrb)
if err != nil {
return err
}
@@ -52,7 +61,7 @@ func (i *MarkerImporter) populateTags() error {
i.marker.PrimaryTagID = primaryTag[0].ID
if len(i.Input.Tags) > 0 {
- tags, err := importTags(i.TagWriter, i.Input.Tags, i.MissingRefBehaviour)
+ tags, err := importTags(ctx, i.TagWriter, i.Input.Tags, i.MissingRefBehaviour)
if err != nil {
return err
}
@@ -63,13 +72,13 @@ func (i *MarkerImporter) populateTags() error {
return nil
}
-func (i *MarkerImporter) PostImport(id int) error {
+func (i *MarkerImporter) PostImport(ctx context.Context, id int) error {
if len(i.tags) > 0 {
var tagIDs []int
for _, t := range i.tags {
tagIDs = append(tagIDs, t.ID)
}
- if err := i.ReaderWriter.UpdateTags(id, tagIDs); err != nil {
+ if err := i.ReaderWriter.UpdateTags(ctx, id, tagIDs); err != nil {
return fmt.Errorf("failed to associate tags: %v", err)
}
}
@@ -81,8 +90,8 @@ func (i *MarkerImporter) Name() string {
return fmt.Sprintf("%s (%s)", i.Input.Title, i.Input.Seconds)
}
-func (i *MarkerImporter) FindExistingID() (*int, error) {
- existingMarkers, err := i.ReaderWriter.FindBySceneID(i.SceneID)
+func (i *MarkerImporter) FindExistingID(ctx context.Context) (*int, error) {
+ existingMarkers, err := i.ReaderWriter.FindBySceneID(ctx, i.SceneID)
if err != nil {
return nil, err
@@ -98,8 +107,8 @@ func (i *MarkerImporter) FindExistingID() (*int, error) {
return nil, nil
}
-func (i *MarkerImporter) Create() (*int, error) {
- created, err := i.ReaderWriter.Create(i.marker)
+func (i *MarkerImporter) Create(ctx context.Context) (*int, error) {
+ created, err := i.ReaderWriter.Create(ctx, i.marker)
if err != nil {
return nil, fmt.Errorf("error creating marker: %v", err)
}
@@ -108,10 +117,10 @@ func (i *MarkerImporter) Create() (*int, error) {
return &id, nil
}
-func (i *MarkerImporter) Update(id int) error {
+func (i *MarkerImporter) Update(ctx context.Context, id int) error {
marker := i.marker
marker.ID = id
- _, err := i.ReaderWriter.Update(marker)
+ _, err := i.ReaderWriter.Update(ctx, marker)
if err != nil {
return fmt.Errorf("error updating existing marker: %v", err)
}
diff --git a/pkg/scene/marker_import_test.go b/pkg/scene/marker_import_test.go
index 0aa72a08b..f34d6b266 100644
--- a/pkg/scene/marker_import_test.go
+++ b/pkg/scene/marker_import_test.go
@@ -1,6 +1,7 @@
package scene
import (
+ "context"
"errors"
"testing"
@@ -30,6 +31,7 @@ func TestMarkerImporterName(t *testing.T) {
func TestMarkerImporterPreImportWithTag(t *testing.T) {
tagReaderWriter := &mocks.TagReaderWriter{}
+ ctx := context.Background()
i := MarkerImporter{
TagWriter: tagReaderWriter,
@@ -39,32 +41,32 @@ func TestMarkerImporterPreImportWithTag(t *testing.T) {
},
}
- tagReaderWriter.On("FindByNames", []string{existingTagName}, false).Return([]*models.Tag{
+ tagReaderWriter.On("FindByNames", ctx, []string{existingTagName}, false).Return([]*models.Tag{
{
ID: existingTagID,
Name: existingTagName,
},
}, nil).Times(4)
- tagReaderWriter.On("FindByNames", []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Times(2)
+ tagReaderWriter.On("FindByNames", ctx, []string{existingTagErr}, false).Return(nil, errors.New("FindByNames error")).Times(2)
- err := i.PreImport()
+ err := i.PreImport(ctx)
assert.Nil(t, err)
assert.Equal(t, existingTagID, i.marker.PrimaryTagID)
i.Input.PrimaryTag = existingTagErr
- err = i.PreImport()
+ err = i.PreImport(ctx)
assert.NotNil(t, err)
i.Input.PrimaryTag = existingTagName
i.Input.Tags = []string{
existingTagName,
}
- err = i.PreImport()
+ err = i.PreImport(ctx)
assert.Nil(t, err)
assert.Equal(t, existingTagID, i.tags[0].ID)
i.Input.Tags[0] = existingTagErr
- err = i.PreImport()
+ err = i.PreImport(ctx)
assert.NotNil(t, err)
tagReaderWriter.AssertExpectations(t)
@@ -72,6 +74,7 @@ func TestMarkerImporterPreImportWithTag(t *testing.T) {
func TestMarkerImporterPostImportUpdateTags(t *testing.T) {
sceneMarkerReaderWriter := &mocks.SceneMarkerReaderWriter{}
+ ctx := context.Background()
i := MarkerImporter{
ReaderWriter: sceneMarkerReaderWriter,
@@ -84,13 +87,13 @@ func TestMarkerImporterPostImportUpdateTags(t *testing.T) {
updateErr := errors.New("UpdateTags error")
- sceneMarkerReaderWriter.On("UpdateTags", sceneID, []int{existingTagID}).Return(nil).Once()
- sceneMarkerReaderWriter.On("UpdateTags", errTagsID, mock.AnythingOfType("[]int")).Return(updateErr).Once()
+ sceneMarkerReaderWriter.On("UpdateTags", ctx, sceneID, []int{existingTagID}).Return(nil).Once()
+ sceneMarkerReaderWriter.On("UpdateTags", ctx, errTagsID, mock.AnythingOfType("[]int")).Return(updateErr).Once()
- err := i.PostImport(sceneID)
+ err := i.PostImport(ctx, sceneID)
assert.Nil(t, err)
- err = i.PostImport(errTagsID)
+ err = i.PostImport(ctx, errTagsID)
assert.NotNil(t, err)
sceneMarkerReaderWriter.AssertExpectations(t)
@@ -98,6 +101,7 @@ func TestMarkerImporterPostImportUpdateTags(t *testing.T) {
func TestMarkerImporterFindExistingID(t *testing.T) {
readerWriter := &mocks.SceneMarkerReaderWriter{}
+ ctx := context.Background()
i := MarkerImporter{
ReaderWriter: readerWriter,
@@ -108,25 +112,25 @@ func TestMarkerImporterFindExistingID(t *testing.T) {
}
expectedErr := errors.New("FindBy* error")
- readerWriter.On("FindBySceneID", sceneID).Return([]*models.SceneMarker{
+ readerWriter.On("FindBySceneID", ctx, sceneID).Return([]*models.SceneMarker{
{
ID: existingSceneID,
Seconds: secondsFloat,
},
}, nil).Times(2)
- readerWriter.On("FindBySceneID", errSceneID).Return(nil, expectedErr).Once()
+ readerWriter.On("FindBySceneID", ctx, errSceneID).Return(nil, expectedErr).Once()
- id, err := i.FindExistingID()
+ id, err := i.FindExistingID(ctx)
assert.Equal(t, existingSceneID, *id)
assert.Nil(t, err)
i.marker.Seconds++
- id, err = i.FindExistingID()
+ id, err = i.FindExistingID(ctx)
assert.Nil(t, id)
assert.Nil(t, err)
i.SceneID = errSceneID
- id, err = i.FindExistingID()
+ id, err = i.FindExistingID(ctx)
assert.Nil(t, id)
assert.NotNil(t, err)
@@ -135,6 +139,7 @@ func TestMarkerImporterFindExistingID(t *testing.T) {
func TestMarkerImporterCreate(t *testing.T) {
readerWriter := &mocks.SceneMarkerReaderWriter{}
+ ctx := context.Background()
scene := models.SceneMarker{
Title: title,
@@ -150,17 +155,17 @@ func TestMarkerImporterCreate(t *testing.T) {
}
errCreate := errors.New("Create error")
- readerWriter.On("Create", scene).Return(&models.SceneMarker{
+ readerWriter.On("Create", ctx, scene).Return(&models.SceneMarker{
ID: sceneID,
}, nil).Once()
- readerWriter.On("Create", sceneErr).Return(nil, errCreate).Once()
+ readerWriter.On("Create", ctx, sceneErr).Return(nil, errCreate).Once()
- id, err := i.Create()
+ id, err := i.Create(ctx)
assert.Equal(t, sceneID, *id)
assert.Nil(t, err)
i.marker = sceneErr
- id, err = i.Create()
+ id, err = i.Create(ctx)
assert.Nil(t, id)
assert.NotNil(t, err)
@@ -169,6 +174,7 @@ func TestMarkerImporterCreate(t *testing.T) {
func TestMarkerImporterUpdate(t *testing.T) {
readerWriter := &mocks.SceneMarkerReaderWriter{}
+ ctx := context.Background()
scene := models.SceneMarker{
Title: title,
@@ -187,18 +193,18 @@ func TestMarkerImporterUpdate(t *testing.T) {
// id needs to be set for the mock input
scene.ID = sceneID
- readerWriter.On("Update", scene).Return(nil, nil).Once()
+ readerWriter.On("Update", ctx, scene).Return(nil, nil).Once()
- err := i.Update(sceneID)
+ err := i.Update(ctx, sceneID)
assert.Nil(t, err)
i.marker = sceneErr
// need to set id separately
sceneErr.ID = errImageID
- readerWriter.On("Update", sceneErr).Return(nil, errUpdate).Once()
+ readerWriter.On("Update", ctx, sceneErr).Return(nil, errUpdate).Once()
- err = i.Update(errImageID)
+ err = i.Update(ctx, errImageID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
diff --git a/pkg/scene/query.go b/pkg/scene/query.go
index f560430c3..928270f38 100644
--- a/pkg/scene/query.go
+++ b/pkg/scene/query.go
@@ -11,7 +11,11 @@ import (
)
type Queryer interface {
- Query(options models.SceneQueryOptions) (*models.SceneQueryResult, error)
+ Query(ctx context.Context, options models.SceneQueryOptions) (*models.SceneQueryResult, error)
+}
+
+type IDFinder interface {
+ Find(ctx context.Context, id int) (*models.Scene, error)
}
// QueryOptions returns a SceneQueryOptions populated with the provided filters.
@@ -26,15 +30,15 @@ func QueryOptions(sceneFilter *models.SceneFilterType, findFilter *models.FindFi
}
// QueryWithCount queries for scenes, returning the scene objects and the total count.
-func QueryWithCount(qb Queryer, sceneFilter *models.SceneFilterType, findFilter *models.FindFilterType) ([]*models.Scene, int, error) {
+func QueryWithCount(ctx context.Context, qb Queryer, sceneFilter *models.SceneFilterType, findFilter *models.FindFilterType) ([]*models.Scene, int, error) {
// this was moved from the queryBuilder code
// left here so that calling functions can reference this instead
- result, err := qb.Query(QueryOptions(sceneFilter, findFilter, true))
+ result, err := qb.Query(ctx, QueryOptions(sceneFilter, findFilter, true))
if err != nil {
return nil, 0, err
}
- scenes, err := result.Resolve()
+ scenes, err := result.Resolve(ctx)
if err != nil {
return nil, 0, err
}
@@ -43,13 +47,13 @@ func QueryWithCount(qb Queryer, sceneFilter *models.SceneFilterType, findFilter
}
// Query queries for scenes using the provided filters.
-func Query(qb Queryer, sceneFilter *models.SceneFilterType, findFilter *models.FindFilterType) ([]*models.Scene, error) {
- result, err := qb.Query(QueryOptions(sceneFilter, findFilter, false))
+func Query(ctx context.Context, qb Queryer, sceneFilter *models.SceneFilterType, findFilter *models.FindFilterType) ([]*models.Scene, error) {
+ result, err := qb.Query(ctx, QueryOptions(sceneFilter, findFilter, false))
if err != nil {
return nil, err
}
- scenes, err := result.Resolve()
+ scenes, err := result.Resolve(ctx)
if err != nil {
return nil, err
}
@@ -57,7 +61,7 @@ func Query(qb Queryer, sceneFilter *models.SceneFilterType, findFilter *models.F
return scenes, nil
}
-func BatchProcess(ctx context.Context, reader models.SceneReader, sceneFilter *models.SceneFilterType, findFilter *models.FindFilterType, fn func(scene *models.Scene) error) error {
+func BatchProcess(ctx context.Context, reader Queryer, sceneFilter *models.SceneFilterType, findFilter *models.FindFilterType, fn func(scene *models.Scene) error) error {
const batchSize = 1000
if findFilter == nil {
@@ -74,7 +78,7 @@ func BatchProcess(ctx context.Context, reader models.SceneReader, sceneFilter *m
return nil
}
- scenes, err := Query(reader, sceneFilter, findFilter)
+ scenes, err := Query(ctx, reader, sceneFilter, findFilter)
if err != nil {
return fmt.Errorf("error querying for scenes: %w", err)
}
diff --git a/pkg/scene/scan.go b/pkg/scene/scan.go
index 1f33fa9ff..e5d1ec739 100644
--- a/pkg/scene/scan.go
+++ b/pkg/scene/scan.go
@@ -17,11 +17,23 @@ import (
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/paths"
"github.com/stashapp/stash/pkg/plugin"
+ "github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils"
)
const mutexType = "scene"
+type CreatorUpdater interface {
+ FindByChecksum(ctx context.Context, checksum string) (*models.Scene, error)
+ FindByOSHash(ctx context.Context, oshash string) (*models.Scene, error)
+ Create(ctx context.Context, newScene models.Scene) (*models.Scene, error)
+ UpdateFull(ctx context.Context, updatedScene models.Scene) (*models.Scene, error)
+ Update(ctx context.Context, updatedScene models.ScenePartial) (*models.Scene, error)
+
+ GetCaptions(ctx context.Context, sceneID int) ([]*models.SceneCaption, error)
+ UpdateCaptions(ctx context.Context, id int, captions []*models.SceneCaption) error
+}
+
type videoFileCreator interface {
NewVideoFile(path string) (*ffmpeg.VideoFile, error)
}
@@ -34,7 +46,8 @@ type Scanner struct {
FileNamingAlgorithm models.HashAlgorithm
CaseSensitiveFs bool
- TxnManager models.TransactionManager
+ TxnManager txn.Manager
+ CreatorUpdater CreatorUpdater
Paths *paths.Paths
Screenshotter screenshotter
VideoFileCreator videoFileCreator
@@ -105,16 +118,17 @@ func (scanner *Scanner) ScanExisting(ctx context.Context, existing file.FileBase
changed = true
}
- if err := scanner.TxnManager.WithTxn(context.TODO(), func(r models.Repository) error {
- var err error
- sqb := r.Scene()
+ qb := scanner.CreatorUpdater
- captions, er := sqb.GetCaptions(s.ID)
+ if err := txn.WithTxn(ctx, scanner.TxnManager, func(ctx context.Context) error {
+ var err error
+
+ captions, er := qb.GetCaptions(ctx, s.ID)
if er == nil {
if len(captions) > 0 {
clean, altered := CleanCaptions(s.Path, captions)
if altered {
- er = sqb.UpdateCaptions(s.ID, clean)
+ er = qb.UpdateCaptions(ctx, s.ID, clean)
if er == nil {
logger.Debugf("Captions for %s cleaned: %s -> %s", path, captions, clean)
}
@@ -136,20 +150,20 @@ func (scanner *Scanner) ScanExisting(ctx context.Context, existing file.FileBase
scanner.MutexManager.Claim(mutexType, scanned.New.Checksum, done)
}
- if err := scanner.TxnManager.WithTxn(ctx, func(r models.Repository) error {
+ if err := txn.WithTxn(ctx, scanner.TxnManager, func(ctx context.Context) error {
defer close(done)
- qb := r.Scene()
+ qb := scanner.CreatorUpdater
// ensure no clashes of hashes
if scanned.New.Checksum != "" && scanned.Old.Checksum != scanned.New.Checksum {
- dupe, _ := qb.FindByChecksum(s.Checksum.String)
+ dupe, _ := qb.FindByChecksum(ctx, s.Checksum.String)
if dupe != nil {
return fmt.Errorf("MD5 for file %s is the same as that of %s", path, dupe.Path)
}
}
if scanned.New.OSHash != "" && scanned.Old.OSHash != scanned.New.OSHash {
- dupe, _ := qb.FindByOSHash(scanned.New.OSHash)
+ dupe, _ := qb.FindByOSHash(ctx, scanned.New.OSHash)
if dupe != nil {
return fmt.Errorf("OSHash for file %s is the same as that of %s", path, dupe.Path)
}
@@ -158,7 +172,7 @@ func (scanner *Scanner) ScanExisting(ctx context.Context, existing file.FileBase
s.Interactive = interactive
s.UpdatedAt = models.SQLiteTimestamp{Timestamp: time.Now()}
- _, err := qb.UpdateFull(*s)
+ _, err := qb.UpdateFull(ctx, *s)
return err
}); err != nil {
return err
@@ -204,14 +218,14 @@ func (scanner *Scanner) ScanNew(ctx context.Context, file file.SourceFile) (retS
// check for scene by checksum and oshash - MD5 should be
// redundant, but check both
var s *models.Scene
- if err := scanner.TxnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
- qb := r.Scene()
+ if err := txn.WithTxn(ctx, scanner.TxnManager, func(ctx context.Context) error {
+ qb := scanner.CreatorUpdater
if checksum != "" {
- s, _ = qb.FindByChecksum(checksum)
+ s, _ = qb.FindByChecksum(ctx, checksum)
}
if s == nil {
- s, _ = qb.FindByOSHash(oshash)
+ s, _ = qb.FindByOSHash(ctx, oshash)
}
return nil
@@ -246,8 +260,8 @@ func (scanner *Scanner) ScanNew(ctx context.Context, file file.SourceFile) (retS
Path: &path,
Interactive: &interactive,
}
- if err := scanner.TxnManager.WithTxn(ctx, func(r models.Repository) error {
- _, err := r.Scene().Update(scenePartial)
+ if err := txn.WithTxn(ctx, scanner.TxnManager, func(ctx context.Context) error {
+ _, err := scanner.CreatorUpdater.Update(ctx, scenePartial)
return err
}); err != nil {
return nil, err
@@ -297,9 +311,9 @@ func (scanner *Scanner) ScanNew(ctx context.Context, file file.SourceFile) (retS
_ = newScene.Date.Scan(videoFile.CreationTime)
}
- if err := scanner.TxnManager.WithTxn(ctx, func(r models.Repository) error {
+ if err := txn.WithTxn(ctx, scanner.TxnManager, func(ctx context.Context) error {
var err error
- retScene, err = r.Scene().Create(newScene)
+ retScene, err = scanner.CreatorUpdater.Create(ctx, newScene)
return err
}); err != nil {
return nil, err
diff --git a/pkg/scene/update.go b/pkg/scene/update.go
index e1155c368..8486ac5e1 100644
--- a/pkg/scene/update.go
+++ b/pkg/scene/update.go
@@ -1,6 +1,7 @@
package scene
import (
+ "context"
"database/sql"
"errors"
"fmt"
@@ -11,6 +12,33 @@ import (
"github.com/stashapp/stash/pkg/utils"
)
+type Updater interface {
+ PartialUpdater
+ UpdatePerformers(ctx context.Context, sceneID int, performerIDs []int) error
+ UpdateTags(ctx context.Context, sceneID int, tagIDs []int) error
+ UpdateStashIDs(ctx context.Context, sceneID int, stashIDs []models.StashID) error
+ UpdateCover(ctx context.Context, sceneID int, cover []byte) error
+}
+
+type PartialUpdater interface {
+ Update(ctx context.Context, updatedScene models.ScenePartial) (*models.Scene, error)
+}
+
+type PerformerUpdater interface {
+ GetPerformerIDs(ctx context.Context, sceneID int) ([]int, error)
+ UpdatePerformers(ctx context.Context, sceneID int, performerIDs []int) error
+}
+
+type TagUpdater interface {
+ GetTagIDs(ctx context.Context, sceneID int) ([]int, error)
+ UpdateTags(ctx context.Context, sceneID int, tagIDs []int) error
+}
+
+type GalleryUpdater interface {
+ GetGalleryIDs(ctx context.Context, sceneID int) ([]int, error)
+ UpdateGalleries(ctx context.Context, sceneID int, galleryIDs []int) error
+}
+
var ErrEmptyUpdater = errors.New("no fields have been set")
// UpdateSet is used to update a scene and its relationships.
@@ -47,7 +75,7 @@ func (u *UpdateSet) IsEmpty() bool {
// Update updates a scene by updating the fields in the Partial field, then
// updates non-nil relationships. Returns an error if there is no work to
// be done.
-func (u *UpdateSet) Update(qb models.SceneWriter, screenshotSetter ScreenshotSetter) (*models.Scene, error) {
+func (u *UpdateSet) Update(ctx context.Context, qb Updater, screenshotSetter ScreenshotSetter) (*models.Scene, error) {
if u.IsEmpty() {
return nil, ErrEmptyUpdater
}
@@ -58,31 +86,31 @@ func (u *UpdateSet) Update(qb models.SceneWriter, screenshotSetter ScreenshotSet
Timestamp: time.Now(),
}
- ret, err := qb.Update(partial)
+ ret, err := qb.Update(ctx, partial)
if err != nil {
return nil, fmt.Errorf("error updating scene: %w", err)
}
if u.PerformerIDs != nil {
- if err := qb.UpdatePerformers(u.ID, u.PerformerIDs); err != nil {
+ if err := qb.UpdatePerformers(ctx, u.ID, u.PerformerIDs); err != nil {
return nil, fmt.Errorf("error updating scene performers: %w", err)
}
}
if u.TagIDs != nil {
- if err := qb.UpdateTags(u.ID, u.TagIDs); err != nil {
+ if err := qb.UpdateTags(ctx, u.ID, u.TagIDs); err != nil {
return nil, fmt.Errorf("error updating scene tags: %w", err)
}
}
if u.StashIDs != nil {
- if err := qb.UpdateStashIDs(u.ID, u.StashIDs); err != nil {
+ if err := qb.UpdateStashIDs(ctx, u.ID, u.StashIDs); err != nil {
return nil, fmt.Errorf("error updating scene stash_ids: %w", err)
}
}
if u.CoverImage != nil {
- if err := qb.UpdateCover(u.ID, u.CoverImage); err != nil {
+ if err := qb.UpdateCover(ctx, u.ID, u.CoverImage); err != nil {
return nil, fmt.Errorf("error updating scene cover: %w", err)
}
@@ -124,8 +152,8 @@ func (u UpdateSet) UpdateInput() models.SceneUpdateInput {
return ret
}
-func UpdateFormat(qb models.SceneWriter, id int, format string) (*models.Scene, error) {
- return qb.Update(models.ScenePartial{
+func UpdateFormat(ctx context.Context, qb PartialUpdater, id int, format string) (*models.Scene, error) {
+ return qb.Update(ctx, models.ScenePartial{
ID: id,
Format: &sql.NullString{
String: format,
@@ -134,8 +162,8 @@ func UpdateFormat(qb models.SceneWriter, id int, format string) (*models.Scene,
})
}
-func UpdateOSHash(qb models.SceneWriter, id int, oshash string) (*models.Scene, error) {
- return qb.Update(models.ScenePartial{
+func UpdateOSHash(ctx context.Context, qb PartialUpdater, id int, oshash string) (*models.Scene, error) {
+ return qb.Update(ctx, models.ScenePartial{
ID: id,
OSHash: &sql.NullString{
String: oshash,
@@ -144,8 +172,8 @@ func UpdateOSHash(qb models.SceneWriter, id int, oshash string) (*models.Scene,
})
}
-func UpdateChecksum(qb models.SceneWriter, id int, checksum string) (*models.Scene, error) {
- return qb.Update(models.ScenePartial{
+func UpdateChecksum(ctx context.Context, qb PartialUpdater, id int, checksum string) (*models.Scene, error) {
+ return qb.Update(ctx, models.ScenePartial{
ID: id,
Checksum: &sql.NullString{
String: checksum,
@@ -154,15 +182,15 @@ func UpdateChecksum(qb models.SceneWriter, id int, checksum string) (*models.Sce
})
}
-func UpdateFileModTime(qb models.SceneWriter, id int, modTime models.NullSQLiteTimestamp) (*models.Scene, error) {
- return qb.Update(models.ScenePartial{
+func UpdateFileModTime(ctx context.Context, qb PartialUpdater, id int, modTime models.NullSQLiteTimestamp) (*models.Scene, error) {
+ return qb.Update(ctx, models.ScenePartial{
ID: id,
FileModTime: &modTime,
})
}
-func AddPerformer(qb models.SceneReaderWriter, id int, performerID int) (bool, error) {
- performerIDs, err := qb.GetPerformerIDs(id)
+func AddPerformer(ctx context.Context, qb PerformerUpdater, id int, performerID int) (bool, error) {
+ performerIDs, err := qb.GetPerformerIDs(ctx, id)
if err != nil {
return false, err
}
@@ -171,7 +199,7 @@ func AddPerformer(qb models.SceneReaderWriter, id int, performerID int) (bool, e
performerIDs = intslice.IntAppendUnique(performerIDs, performerID)
if len(performerIDs) != oldLen {
- if err := qb.UpdatePerformers(id, performerIDs); err != nil {
+ if err := qb.UpdatePerformers(ctx, id, performerIDs); err != nil {
return false, err
}
@@ -181,8 +209,8 @@ func AddPerformer(qb models.SceneReaderWriter, id int, performerID int) (bool, e
return false, nil
}
-func AddTag(qb models.SceneReaderWriter, id int, tagID int) (bool, error) {
- tagIDs, err := qb.GetTagIDs(id)
+func AddTag(ctx context.Context, qb TagUpdater, id int, tagID int) (bool, error) {
+ tagIDs, err := qb.GetTagIDs(ctx, id)
if err != nil {
return false, err
}
@@ -191,7 +219,7 @@ func AddTag(qb models.SceneReaderWriter, id int, tagID int) (bool, error) {
tagIDs = intslice.IntAppendUnique(tagIDs, tagID)
if len(tagIDs) != oldLen {
- if err := qb.UpdateTags(id, tagIDs); err != nil {
+ if err := qb.UpdateTags(ctx, id, tagIDs); err != nil {
return false, err
}
@@ -201,8 +229,8 @@ func AddTag(qb models.SceneReaderWriter, id int, tagID int) (bool, error) {
return false, nil
}
-func AddGallery(qb models.SceneReaderWriter, id int, galleryID int) (bool, error) {
- galleryIDs, err := qb.GetGalleryIDs(id)
+func AddGallery(ctx context.Context, qb GalleryUpdater, id int, galleryID int) (bool, error) {
+ galleryIDs, err := qb.GetGalleryIDs(ctx, id)
if err != nil {
return false, err
}
@@ -211,7 +239,7 @@ func AddGallery(qb models.SceneReaderWriter, id int, galleryID int) (bool, error
galleryIDs = intslice.IntAppendUnique(galleryIDs, galleryID)
if len(galleryIDs) != oldLen {
- if err := qb.UpdateGalleries(id, galleryIDs); err != nil {
+ if err := qb.UpdateGalleries(ctx, id, galleryIDs); err != nil {
return false, err
}
diff --git a/pkg/scene/update_test.go b/pkg/scene/update_test.go
index 4619fd137..c605f98a2 100644
--- a/pkg/scene/update_test.go
+++ b/pkg/scene/update_test.go
@@ -1,6 +1,7 @@
package scene
import (
+ "context"
"errors"
"strconv"
"testing"
@@ -104,6 +105,8 @@ func TestUpdater_Update(t *testing.T) {
tagID
)
+ ctx := context.Background()
+
performerIDs := []int{performerID}
tagIDs := []int{tagID}
stashID := "stashID"
@@ -123,22 +126,22 @@ func TestUpdater_Update(t *testing.T) {
updateErr := errors.New("error updating")
qb := mocks.SceneReaderWriter{}
- qb.On("Update", mock.MatchedBy(func(s models.ScenePartial) bool {
+ qb.On("Update", ctx, mock.MatchedBy(func(s models.ScenePartial) bool {
return s.ID != badUpdateID
})).Return(validScene, nil)
- qb.On("Update", mock.MatchedBy(func(s models.ScenePartial) bool {
+ qb.On("Update", ctx, mock.MatchedBy(func(s models.ScenePartial) bool {
return s.ID == badUpdateID
})).Return(nil, updateErr)
- qb.On("UpdatePerformers", sceneID, performerIDs).Return(nil).Once()
- qb.On("UpdateTags", sceneID, tagIDs).Return(nil).Once()
- qb.On("UpdateStashIDs", sceneID, stashIDs).Return(nil).Once()
- qb.On("UpdateCover", sceneID, cover).Return(nil).Once()
+ qb.On("UpdatePerformers", ctx, sceneID, performerIDs).Return(nil).Once()
+ qb.On("UpdateTags", ctx, sceneID, tagIDs).Return(nil).Once()
+ qb.On("UpdateStashIDs", ctx, sceneID, stashIDs).Return(nil).Once()
+ qb.On("UpdateCover", ctx, sceneID, cover).Return(nil).Once()
- qb.On("UpdatePerformers", badPerformersID, performerIDs).Return(updateErr).Once()
- qb.On("UpdateTags", badTagsID, tagIDs).Return(updateErr).Once()
- qb.On("UpdateStashIDs", badStashIDsID, stashIDs).Return(updateErr).Once()
- qb.On("UpdateCover", badCoverID, cover).Return(updateErr).Once()
+ qb.On("UpdatePerformers", ctx, badPerformersID, performerIDs).Return(updateErr).Once()
+ qb.On("UpdateTags", ctx, badTagsID, tagIDs).Return(updateErr).Once()
+ qb.On("UpdateStashIDs", ctx, badStashIDsID, stashIDs).Return(updateErr).Once()
+ qb.On("UpdateCover", ctx, badCoverID, cover).Return(updateErr).Once()
tests := []struct {
name string
@@ -232,7 +235,7 @@ func TestUpdater_Update(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- got, err := tt.u.Update(&qb, &mockScreenshotSetter{})
+ got, err := tt.u.Update(ctx, &qb, &mockScreenshotSetter{})
if (err != nil) != tt.wantErr {
t.Errorf("Updater.Update() error = %v, wantErr %v", err, tt.wantErr)
return
diff --git a/pkg/scraper/action.go b/pkg/scraper/action.go
index e674cb2a2..0011441fb 100644
--- a/pkg/scraper/action.go
+++ b/pkg/scraper/action.go
@@ -33,16 +33,16 @@ type scraperActionImpl interface {
scrapeGalleryByGallery(ctx context.Context, gallery *models.Gallery) (*ScrapedGallery, error)
}
-func (c config) getScraper(scraper scraperTypeConfig, client *http.Client, txnManager models.TransactionManager, globalConfig GlobalConfig) scraperActionImpl {
+func (c config) getScraper(scraper scraperTypeConfig, client *http.Client, globalConfig GlobalConfig) scraperActionImpl {
switch scraper.Action {
case scraperActionScript:
return newScriptScraper(scraper, c, globalConfig)
case scraperActionStash:
- return newStashScraper(scraper, client, txnManager, c, globalConfig)
+ return newStashScraper(scraper, client, c, globalConfig)
case scraperActionXPath:
- return newXpathScraper(scraper, client, txnManager, c, globalConfig)
+ return newXpathScraper(scraper, client, c, globalConfig)
case scraperActionJson:
- return newJsonScraper(scraper, client, txnManager, c, globalConfig)
+ return newJsonScraper(scraper, client, c, globalConfig)
}
panic("unknown scraper action: " + scraper.Action)
diff --git a/pkg/scraper/autotag.go b/pkg/scraper/autotag.go
index 52f3ce4af..ba10ace3f 100644
--- a/pkg/scraper/autotag.go
+++ b/pkg/scraper/autotag.go
@@ -8,6 +8,7 @@ import (
"github.com/stashapp/stash/pkg/match"
"github.com/stashapp/stash/pkg/models"
+ "github.com/stashapp/stash/pkg/txn"
)
// autoTagScraperID is the scraper ID for the built-in AutoTag scraper
@@ -17,12 +18,17 @@ const (
)
type autotagScraper struct {
- txnManager models.TransactionManager
+ // repository models.Repository
+ txnManager txn.Manager
+ performerReader match.PerformerAutoTagQueryer
+ studioReader match.StudioAutoTagQueryer
+ tagReader match.TagAutoTagQueryer
+
globalConfig GlobalConfig
}
-func autotagMatchPerformers(path string, performerReader models.PerformerReader, trimExt bool) ([]*models.ScrapedPerformer, error) {
- p, err := match.PathToPerformers(path, performerReader, nil, trimExt)
+func autotagMatchPerformers(ctx context.Context, path string, performerReader match.PerformerAutoTagQueryer, trimExt bool) ([]*models.ScrapedPerformer, error) {
+ p, err := match.PathToPerformers(ctx, path, performerReader, nil, trimExt)
if err != nil {
return nil, fmt.Errorf("error matching performers: %w", err)
}
@@ -45,8 +51,8 @@ func autotagMatchPerformers(path string, performerReader models.PerformerReader,
return ret, nil
}
-func autotagMatchStudio(path string, studioReader models.StudioReader, trimExt bool) (*models.ScrapedStudio, error) {
- studio, err := match.PathToStudio(path, studioReader, nil, trimExt)
+func autotagMatchStudio(ctx context.Context, path string, studioReader match.StudioAutoTagQueryer, trimExt bool) (*models.ScrapedStudio, error) {
+ studio, err := match.PathToStudio(ctx, path, studioReader, nil, trimExt)
if err != nil {
return nil, fmt.Errorf("error matching studios: %w", err)
}
@@ -62,8 +68,8 @@ func autotagMatchStudio(path string, studioReader models.StudioReader, trimExt b
return nil, nil
}
-func autotagMatchTags(path string, tagReader models.TagReader, trimExt bool) ([]*models.ScrapedTag, error) {
- t, err := match.PathToTags(path, tagReader, nil, trimExt)
+func autotagMatchTags(ctx context.Context, path string, tagReader match.TagAutoTagQueryer, trimExt bool) ([]*models.ScrapedTag, error) {
+ t, err := match.PathToTags(ctx, path, tagReader, nil, trimExt)
if err != nil {
return nil, fmt.Errorf("error matching tags: %w", err)
}
@@ -88,18 +94,18 @@ func (s autotagScraper) viaScene(ctx context.Context, _client *http.Client, scen
const trimExt = false
// populate performers, studio and tags based on scene path
- if err := s.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
+ if err := txn.WithTxn(ctx, s.txnManager, func(ctx context.Context) error {
path := scene.Path
- performers, err := autotagMatchPerformers(path, r.Performer(), trimExt)
+ performers, err := autotagMatchPerformers(ctx, path, s.performerReader, trimExt)
if err != nil {
return fmt.Errorf("autotag scraper viaScene: %w", err)
}
- studio, err := autotagMatchStudio(path, r.Studio(), trimExt)
+ studio, err := autotagMatchStudio(ctx, path, s.studioReader, trimExt)
if err != nil {
return fmt.Errorf("autotag scraper viaScene: %w", err)
}
- tags, err := autotagMatchTags(path, r.Tag(), trimExt)
+ tags, err := autotagMatchTags(ctx, path, s.tagReader, trimExt)
if err != nil {
return fmt.Errorf("autotag scraper viaScene: %w", err)
}
@@ -132,18 +138,18 @@ func (s autotagScraper) viaGallery(ctx context.Context, _client *http.Client, ga
var ret *ScrapedGallery
// populate performers, studio and tags based on scene path
- if err := s.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
+ if err := txn.WithTxn(ctx, s.txnManager, func(ctx context.Context) error {
path := gallery.Path.String
- performers, err := autotagMatchPerformers(path, r.Performer(), trimExt)
+ performers, err := autotagMatchPerformers(ctx, path, s.performerReader, trimExt)
if err != nil {
return fmt.Errorf("autotag scraper viaGallery: %w", err)
}
- studio, err := autotagMatchStudio(path, r.Studio(), trimExt)
+ studio, err := autotagMatchStudio(ctx, path, s.studioReader, trimExt)
if err != nil {
return fmt.Errorf("autotag scraper viaGallery: %w", err)
}
- tags, err := autotagMatchTags(path, r.Tag(), trimExt)
+ tags, err := autotagMatchTags(ctx, path, s.tagReader, trimExt)
if err != nil {
return fmt.Errorf("autotag scraper viaGallery: %w", err)
}
@@ -196,10 +202,13 @@ func (s autotagScraper) spec() Scraper {
}
}
-func getAutoTagScraper(txnManager models.TransactionManager, globalConfig GlobalConfig) scraper {
+func getAutoTagScraper(txnManager txn.Manager, repo Repository, globalConfig GlobalConfig) scraper {
base := autotagScraper{
- txnManager: txnManager,
- globalConfig: globalConfig,
+ txnManager: txnManager,
+ performerReader: repo.PerformerFinder,
+ studioReader: repo.StudioFinder,
+ tagReader: repo.TagFinder,
+ globalConfig: globalConfig,
}
return base
diff --git a/pkg/scraper/cache.go b/pkg/scraper/cache.go
index 5357f8b94..decf71cb2 100644
--- a/pkg/scraper/cache.go
+++ b/pkg/scraper/cache.go
@@ -13,7 +13,11 @@ import (
"github.com/stashapp/stash/pkg/fsutil"
"github.com/stashapp/stash/pkg/logger"
+ "github.com/stashapp/stash/pkg/match"
"github.com/stashapp/stash/pkg/models"
+ "github.com/stashapp/stash/pkg/scene"
+ "github.com/stashapp/stash/pkg/tag"
+ "github.com/stashapp/stash/pkg/txn"
)
const (
@@ -47,12 +51,42 @@ func isCDPPathWS(c GlobalConfig) bool {
return strings.HasPrefix(c.GetScraperCDPPath(), "ws://")
}
+type PerformerFinder interface {
+ match.PerformerAutoTagQueryer
+ match.PerformerFinder
+}
+
+type StudioFinder interface {
+ match.StudioAutoTagQueryer
+ match.StudioFinder
+}
+
+type TagFinder interface {
+ match.TagAutoTagQueryer
+ tag.Queryer
+}
+
+type GalleryFinder interface {
+ Find(ctx context.Context, id int) (*models.Gallery, error)
+}
+
+type Repository struct {
+ SceneFinder scene.IDFinder
+ GalleryFinder GalleryFinder
+ TagFinder TagFinder
+ PerformerFinder PerformerFinder
+ MovieFinder match.MovieNamesFinder
+ StudioFinder StudioFinder
+}
+
// Cache stores the database of scrapers
type Cache struct {
client *http.Client
scrapers map[string]scraper // Scraper ID -> Scraper
globalConfig GlobalConfig
- txnManager models.TransactionManager
+ txnManager txn.Manager
+
+ repository Repository
}
// newClient creates a scraper-local http client we use throughout the scraper subsystem.
@@ -81,30 +115,33 @@ func newClient(gc GlobalConfig) *http.Client {
//
// Scraper configurations are loaded from yml files in the provided scrapers
// directory and any subdirectories.
-func NewCache(globalConfig GlobalConfig, txnManager models.TransactionManager) (*Cache, error) {
+func NewCache(globalConfig GlobalConfig, txnManager txn.Manager, repo Repository) (*Cache, error) {
// HTTP Client setup
client := newClient(globalConfig)
- scrapers, err := loadScrapers(globalConfig, txnManager)
+ ret := &Cache{
+ client: client,
+ globalConfig: globalConfig,
+ txnManager: txnManager,
+ repository: repo,
+ }
+
+ var err error
+ ret.scrapers, err = ret.loadScrapers()
if err != nil {
return nil, err
}
- return &Cache{
- client: client,
- globalConfig: globalConfig,
- scrapers: scrapers,
- txnManager: txnManager,
- }, nil
+ return ret, nil
}
-func loadScrapers(globalConfig GlobalConfig, txnManager models.TransactionManager) (map[string]scraper, error) {
- path := globalConfig.GetScrapersPath()
+func (c *Cache) loadScrapers() (map[string]scraper, error) {
+ path := c.globalConfig.GetScrapersPath()
scrapers := make(map[string]scraper)
// Add built-in scrapers
- freeOnes := getFreeonesScraper(txnManager, globalConfig)
- autoTag := getAutoTagScraper(txnManager, globalConfig)
+ freeOnes := getFreeonesScraper(c.globalConfig)
+ autoTag := getAutoTagScraper(c.txnManager, c.repository, c.globalConfig)
scrapers[freeOnes.spec().ID] = freeOnes
scrapers[autoTag.spec().ID] = autoTag
@@ -113,11 +150,11 @@ func loadScrapers(globalConfig GlobalConfig, txnManager models.TransactionManage
scraperFiles := []string{}
err := fsutil.SymWalk(path, func(fp string, f os.FileInfo, err error) error {
if filepath.Ext(fp) == ".yml" {
- c, err := loadConfigFromYAMLFile(fp)
+ conf, err := loadConfigFromYAMLFile(fp)
if err != nil {
logger.Errorf("Error loading scraper %s: %v", fp, err)
} else {
- scraper := newGroupScraper(*c, txnManager, globalConfig)
+ scraper := newGroupScraper(*conf, c.globalConfig)
scrapers[scraper.spec().ID] = scraper
}
scraperFiles = append(scraperFiles, fp)
@@ -137,7 +174,7 @@ func loadScrapers(globalConfig GlobalConfig, txnManager models.TransactionManage
// In the event of an error during loading, the cache will be left empty.
func (c *Cache) ReloadScrapers() error {
c.scrapers = nil
- scrapers, err := loadScrapers(c.globalConfig, c.txnManager)
+ scrapers, err := c.loadScrapers()
if err != nil {
return err
}
@@ -269,7 +306,7 @@ func (c Cache) ScrapeID(ctx context.Context, scraperID string, id int, ty Scrape
return nil, fmt.Errorf("%w: cannot use scraper %s as a scene scraper", ErrNotSupported, scraperID)
}
- scene, err := getScene(ctx, id, c.txnManager)
+ scene, err := c.getScene(ctx, id)
if err != nil {
return nil, fmt.Errorf("scraper %s: unable to load scene id %v: %w", scraperID, id, err)
}
@@ -290,7 +327,7 @@ func (c Cache) ScrapeID(ctx context.Context, scraperID string, id int, ty Scrape
return nil, fmt.Errorf("%w: cannot use scraper %s as a gallery scraper", ErrNotSupported, scraperID)
}
- gallery, err := getGallery(ctx, id, c.txnManager)
+ gallery, err := c.getGallery(ctx, id)
if err != nil {
return nil, fmt.Errorf("scraper %s: unable to load gallery id %v: %w", scraperID, id, err)
}
@@ -309,3 +346,27 @@ func (c Cache) ScrapeID(ctx context.Context, scraperID string, id int, ty Scrape
return c.postScrape(ctx, ret)
}
+
+func (c Cache) getScene(ctx context.Context, sceneID int) (*models.Scene, error) {
+ var ret *models.Scene
+ if err := txn.WithTxn(ctx, c.txnManager, func(ctx context.Context) error {
+ var err error
+ ret, err = c.repository.SceneFinder.Find(ctx, sceneID)
+ return err
+ }); err != nil {
+ return nil, err
+ }
+ return ret, nil
+}
+
+func (c Cache) getGallery(ctx context.Context, galleryID int) (*models.Gallery, error) {
+ var ret *models.Gallery
+ if err := txn.WithTxn(ctx, c.txnManager, func(ctx context.Context) error {
+ var err error
+ ret, err = c.repository.GalleryFinder.Find(ctx, galleryID)
+ return err
+ }); err != nil {
+ return nil, err
+ }
+ return ret, nil
+}
diff --git a/pkg/scraper/freeones.go b/pkg/scraper/freeones.go
index 7b6c81649..9a8eb4859 100644
--- a/pkg/scraper/freeones.go
+++ b/pkg/scraper/freeones.go
@@ -4,7 +4,6 @@ import (
"strings"
"github.com/stashapp/stash/pkg/logger"
- "github.com/stashapp/stash/pkg/models"
)
// FreeonesScraperID is the scraper ID for the built-in Freeones scraper
@@ -123,7 +122,7 @@ xPathScrapers:
# Last updated April 13, 2021
`
-func getFreeonesScraper(txnManager models.TransactionManager, globalConfig GlobalConfig) scraper {
+func getFreeonesScraper(globalConfig GlobalConfig) scraper {
yml := freeonesScraperConfig
c, err := loadConfigFromYAML(FreeonesScraperID, strings.NewReader(yml))
@@ -131,5 +130,5 @@ func getFreeonesScraper(txnManager models.TransactionManager, globalConfig Globa
logger.Fatalf("Error loading builtin freeones scraper: %s", err.Error())
}
- return newGroupScraper(*c, txnManager, globalConfig)
+ return newGroupScraper(*c, globalConfig)
}
diff --git a/pkg/scraper/group.go b/pkg/scraper/group.go
index b423883c4..bbf0a680a 100644
--- a/pkg/scraper/group.go
+++ b/pkg/scraper/group.go
@@ -11,14 +11,12 @@ import (
type group struct {
config config
- txnManager models.TransactionManager
globalConf GlobalConfig
}
-func newGroupScraper(c config, txnManager models.TransactionManager, globalConfig GlobalConfig) scraper {
+func newGroupScraper(c config, globalConfig GlobalConfig) scraper {
return group{
config: c,
- txnManager: txnManager,
globalConf: globalConfig,
}
}
@@ -55,7 +53,7 @@ func (g group) viaFragment(ctx context.Context, client *http.Client, input Input
return nil, ErrNotSupported
}
- s := g.config.getScraper(*stc, client, g.txnManager, g.globalConf)
+ s := g.config.getScraper(*stc, client, g.globalConf)
return s.scrapeByFragment(ctx, input)
}
@@ -64,7 +62,7 @@ func (g group) viaScene(ctx context.Context, client *http.Client, scene *models.
return nil, ErrNotSupported
}
- s := g.config.getScraper(*g.config.SceneByFragment, client, g.txnManager, g.globalConf)
+ s := g.config.getScraper(*g.config.SceneByFragment, client, g.globalConf)
return s.scrapeSceneByScene(ctx, scene)
}
@@ -73,7 +71,7 @@ func (g group) viaGallery(ctx context.Context, client *http.Client, gallery *mod
return nil, ErrNotSupported
}
- s := g.config.getScraper(*g.config.GalleryByFragment, client, g.txnManager, g.globalConf)
+ s := g.config.getScraper(*g.config.GalleryByFragment, client, g.globalConf)
return s.scrapeGalleryByGallery(ctx, gallery)
}
@@ -96,7 +94,7 @@ func (g group) viaURL(ctx context.Context, client *http.Client, url string, ty S
candidates := loadUrlCandidates(g.config, ty)
for _, scraper := range candidates {
if scraper.matchesURL(url) {
- s := g.config.getScraper(scraper.scraperTypeConfig, client, g.txnManager, g.globalConf)
+ s := g.config.getScraper(scraper.scraperTypeConfig, client, g.globalConf)
ret, err := s.scrapeByURL(ctx, url, ty)
if err != nil {
return nil, err
@@ -118,14 +116,14 @@ func (g group) viaName(ctx context.Context, client *http.Client, name string, ty
break
}
- s := g.config.getScraper(*g.config.PerformerByName, client, g.txnManager, g.globalConf)
+ s := g.config.getScraper(*g.config.PerformerByName, client, g.globalConf)
return s.scrapeByName(ctx, name, ty)
case ScrapeContentTypeScene:
if g.config.SceneByName == nil {
break
}
- s := g.config.getScraper(*g.config.SceneByName, client, g.txnManager, g.globalConf)
+ s := g.config.getScraper(*g.config.SceneByName, client, g.globalConf)
return s.scrapeByName(ctx, name, ty)
}
diff --git a/pkg/scraper/json.go b/pkg/scraper/json.go
index dbcba38ef..1d6358a92 100644
--- a/pkg/scraper/json.go
+++ b/pkg/scraper/json.go
@@ -19,16 +19,14 @@ type jsonScraper struct {
config config
globalConfig GlobalConfig
client *http.Client
- txnManager models.TransactionManager
}
-func newJsonScraper(scraper scraperTypeConfig, client *http.Client, txnManager models.TransactionManager, config config, globalConfig GlobalConfig) *jsonScraper {
+func newJsonScraper(scraper scraperTypeConfig, client *http.Client, config config, globalConfig GlobalConfig) *jsonScraper {
return &jsonScraper{
scraper: scraper,
config: config,
client: client,
globalConfig: globalConfig,
- txnManager: txnManager,
}
}
diff --git a/pkg/scraper/postprocessing.go b/pkg/scraper/postprocessing.go
index ded3bc816..6351ebccb 100644
--- a/pkg/scraper/postprocessing.go
+++ b/pkg/scraper/postprocessing.go
@@ -6,6 +6,8 @@ import (
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/match"
"github.com/stashapp/stash/pkg/models"
+ "github.com/stashapp/stash/pkg/tag"
+ "github.com/stashapp/stash/pkg/txn"
)
// postScrape handles post-processing of scraped content. If the content
@@ -45,10 +47,10 @@ func (c Cache) postScrape(ctx context.Context, content ScrapedContent) (ScrapedC
}
func (c Cache) postScrapePerformer(ctx context.Context, p models.ScrapedPerformer) (ScrapedContent, error) {
- if err := c.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
- tqb := r.Tag()
+ if err := txn.WithTxn(ctx, c.txnManager, func(ctx context.Context) error {
+ tqb := c.repository.TagFinder
- tags, err := postProcessTags(tqb, p.Tags)
+ tags, err := postProcessTags(ctx, tqb, p.Tags)
if err != nil {
return err
}
@@ -69,8 +71,8 @@ func (c Cache) postScrapePerformer(ctx context.Context, p models.ScrapedPerforme
func (c Cache) postScrapeMovie(ctx context.Context, m models.ScrapedMovie) (ScrapedContent, error) {
if m.Studio != nil {
- if err := c.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
- return match.ScrapedStudio(r.Studio(), m.Studio, nil)
+ if err := txn.WithTxn(ctx, c.txnManager, func(ctx context.Context) error {
+ return match.ScrapedStudio(ctx, c.repository.StudioFinder, m.Studio, nil)
}); err != nil {
return nil, err
}
@@ -88,10 +90,10 @@ func (c Cache) postScrapeMovie(ctx context.Context, m models.ScrapedMovie) (Scra
}
func (c Cache) postScrapeScenePerformer(ctx context.Context, p models.ScrapedPerformer) error {
- if err := c.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
- tqb := r.Tag()
+ if err := txn.WithTxn(ctx, c.txnManager, func(ctx context.Context) error {
+ tqb := c.repository.TagFinder
- tags, err := postProcessTags(tqb, p.Tags)
+ tags, err := postProcessTags(ctx, tqb, p.Tags)
if err != nil {
return err
}
@@ -106,11 +108,11 @@ func (c Cache) postScrapeScenePerformer(ctx context.Context, p models.ScrapedPer
}
func (c Cache) postScrapeScene(ctx context.Context, scene ScrapedScene) (ScrapedContent, error) {
- if err := c.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
- pqb := r.Performer()
- mqb := r.Movie()
- tqb := r.Tag()
- sqb := r.Studio()
+ if err := txn.WithTxn(ctx, c.txnManager, func(ctx context.Context) error {
+ pqb := c.repository.PerformerFinder
+ mqb := c.repository.MovieFinder
+ tqb := c.repository.TagFinder
+ sqb := c.repository.StudioFinder
for _, p := range scene.Performers {
if p == nil {
@@ -121,26 +123,26 @@ func (c Cache) postScrapeScene(ctx context.Context, scene ScrapedScene) (Scraped
return err
}
- if err := match.ScrapedPerformer(pqb, p, nil); err != nil {
+ if err := match.ScrapedPerformer(ctx, pqb, p, nil); err != nil {
return err
}
}
for _, p := range scene.Movies {
- err := match.ScrapedMovie(mqb, p)
+ err := match.ScrapedMovie(ctx, mqb, p)
if err != nil {
return err
}
}
- tags, err := postProcessTags(tqb, scene.Tags)
+ tags, err := postProcessTags(ctx, tqb, scene.Tags)
if err != nil {
return err
}
scene.Tags = tags
if scene.Studio != nil {
- err := match.ScrapedStudio(sqb, scene.Studio, nil)
+ err := match.ScrapedStudio(ctx, sqb, scene.Studio, nil)
if err != nil {
return err
}
@@ -160,26 +162,26 @@ func (c Cache) postScrapeScene(ctx context.Context, scene ScrapedScene) (Scraped
}
func (c Cache) postScrapeGallery(ctx context.Context, g ScrapedGallery) (ScrapedContent, error) {
- if err := c.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
- pqb := r.Performer()
- tqb := r.Tag()
- sqb := r.Studio()
+ if err := txn.WithTxn(ctx, c.txnManager, func(ctx context.Context) error {
+ pqb := c.repository.PerformerFinder
+ tqb := c.repository.TagFinder
+ sqb := c.repository.StudioFinder
for _, p := range g.Performers {
- err := match.ScrapedPerformer(pqb, p, nil)
+ err := match.ScrapedPerformer(ctx, pqb, p, nil)
if err != nil {
return err
}
}
- tags, err := postProcessTags(tqb, g.Tags)
+ tags, err := postProcessTags(ctx, tqb, g.Tags)
if err != nil {
return err
}
g.Tags = tags
if g.Studio != nil {
- err := match.ScrapedStudio(sqb, g.Studio, nil)
+ err := match.ScrapedStudio(ctx, sqb, g.Studio, nil)
if err != nil {
return err
}
@@ -193,11 +195,11 @@ func (c Cache) postScrapeGallery(ctx context.Context, g ScrapedGallery) (Scraped
return g, nil
}
-func postProcessTags(tqb models.TagReader, scrapedTags []*models.ScrapedTag) ([]*models.ScrapedTag, error) {
+func postProcessTags(ctx context.Context, tqb tag.Queryer, scrapedTags []*models.ScrapedTag) ([]*models.ScrapedTag, error) {
var ret []*models.ScrapedTag
for _, t := range scrapedTags {
- err := match.ScrapedTag(tqb, t)
+ err := match.ScrapedTag(ctx, tqb, t)
if err != nil {
return nil, err
}
diff --git a/pkg/scraper/stash.go b/pkg/scraper/stash.go
index b6e0e7696..7095ab711 100644
--- a/pkg/scraper/stash.go
+++ b/pkg/scraper/stash.go
@@ -18,16 +18,14 @@ type stashScraper struct {
config config
globalConfig GlobalConfig
client *http.Client
- txnManager models.TransactionManager
}
-func newStashScraper(scraper scraperTypeConfig, client *http.Client, txnManager models.TransactionManager, config config, globalConfig GlobalConfig) *stashScraper {
+func newStashScraper(scraper scraperTypeConfig, client *http.Client, config config, globalConfig GlobalConfig) *stashScraper {
return &stashScraper{
scraper: scraper,
config: config,
client: client,
globalConfig: globalConfig,
- txnManager: txnManager,
}
}
@@ -308,18 +306,6 @@ func (s *stashScraper) scrapeByURL(_ context.Context, _ string, _ ScrapeContentT
return nil, ErrNotSupported
}
-func getScene(ctx context.Context, sceneID int, txnManager models.TransactionManager) (*models.Scene, error) {
- var ret *models.Scene
- if err := txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
- var err error
- ret, err = r.Scene().Find(sceneID)
- return err
- }); err != nil {
- return nil, err
- }
- return ret, nil
-}
-
func sceneToUpdateInput(scene *models.Scene) models.SceneUpdateInput {
toStringPtr := func(s sql.NullString) *string {
if s.Valid {
@@ -346,18 +332,6 @@ func sceneToUpdateInput(scene *models.Scene) models.SceneUpdateInput {
}
}
-func getGallery(ctx context.Context, galleryID int, txnManager models.TransactionManager) (*models.Gallery, error) {
- var ret *models.Gallery
- if err := txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
- var err error
- ret, err = r.Gallery().Find(galleryID)
- return err
- }); err != nil {
- return nil, err
- }
- return ret, nil
-}
-
func galleryToUpdateInput(gallery *models.Gallery) models.GalleryUpdateInput {
toStringPtr := func(s sql.NullString) *string {
if s.Valid {
diff --git a/pkg/scraper/stashbox/stash_box.go b/pkg/scraper/stashbox/stash_box.go
index 923cd4482..7e2017aef 100644
--- a/pkg/scraper/stashbox/stash_box.go
+++ b/pkg/scraper/stashbox/stash_box.go
@@ -24,18 +24,52 @@ import (
"github.com/stashapp/stash/pkg/scraper"
"github.com/stashapp/stash/pkg/scraper/stashbox/graphql"
"github.com/stashapp/stash/pkg/sliceutil/stringslice"
+ "github.com/stashapp/stash/pkg/studio"
+ "github.com/stashapp/stash/pkg/tag"
+ "github.com/stashapp/stash/pkg/txn"
"github.com/stashapp/stash/pkg/utils"
)
+type SceneReader interface {
+ Find(ctx context.Context, id int) (*models.Scene, error)
+ GetStashIDs(ctx context.Context, sceneID int) ([]*models.StashID, error)
+}
+
+type PerformerReader interface {
+ match.PerformerFinder
+ Find(ctx context.Context, id int) (*models.Performer, error)
+ FindBySceneID(ctx context.Context, sceneID int) ([]*models.Performer, error)
+ GetStashIDs(ctx context.Context, performerID int) ([]*models.StashID, error)
+ GetImage(ctx context.Context, performerID int) ([]byte, error)
+}
+
+type StudioReader interface {
+ match.StudioFinder
+ studio.Finder
+ GetStashIDs(ctx context.Context, studioID int) ([]*models.StashID, error)
+}
+type TagFinder interface {
+ tag.Queryer
+ FindBySceneID(ctx context.Context, sceneID int) ([]*models.Tag, error)
+}
+
+type Repository struct {
+ Scene SceneReader
+ Performer PerformerReader
+ Tag TagFinder
+ Studio StudioReader
+}
+
// Client represents the client interface to a stash-box server instance.
type Client struct {
client *graphql.Client
- txnManager models.TransactionManager
+ txnManager txn.Manager
+ repository Repository
box models.StashBox
}
// NewClient returns a new instance of a stash-box client.
-func NewClient(box models.StashBox, txnManager models.TransactionManager) *Client {
+func NewClient(box models.StashBox, txnManager txn.Manager, repo Repository) *Client {
authHeader := func(req *http.Request) {
req.Header.Set("ApiKey", box.APIKey)
}
@@ -47,6 +81,7 @@ func NewClient(box models.StashBox, txnManager models.TransactionManager) *Clien
return &Client{
client: client,
txnManager: txnManager,
+ repository: repo,
box: box,
}
}
@@ -92,11 +127,11 @@ func (c Client) FindStashBoxSceneByFingerprints(ctx context.Context, sceneID int
func (c Client) FindStashBoxScenesByFingerprints(ctx context.Context, ids []int) ([][]*scraper.ScrapedScene, error) {
var fingerprints [][]*graphql.FingerprintQueryInput
- if err := c.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
- qb := r.Scene()
+ if err := txn.WithTxn(ctx, c.txnManager, func(ctx context.Context) error {
+ qb := c.repository.Scene
for _, sceneID := range ids {
- scene, err := qb.Find(sceneID)
+ scene, err := qb.Find(ctx, sceneID)
if err != nil {
return err
}
@@ -177,11 +212,11 @@ func (c Client) SubmitStashBoxFingerprints(ctx context.Context, sceneIDs []strin
var fingerprints []graphql.FingerprintSubmission
- if err := c.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
- qb := r.Scene()
+ if err := txn.WithTxn(ctx, c.txnManager, func(ctx context.Context) error {
+ qb := c.repository.Scene
for _, sceneID := range ids {
- scene, err := qb.Find(sceneID)
+ scene, err := qb.Find(ctx, sceneID)
if err != nil {
return err
}
@@ -190,7 +225,7 @@ func (c Client) SubmitStashBoxFingerprints(ctx context.Context, sceneIDs []strin
continue
}
- stashIDs, err := qb.GetStashIDs(sceneID)
+ stashIDs, err := qb.GetStashIDs(ctx, sceneID)
if err != nil {
return err
}
@@ -307,11 +342,11 @@ func (c Client) FindStashBoxPerformersByNames(ctx context.Context, performerIDs
var performers []*models.Performer
- if err := c.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
- qb := r.Performer()
+ if err := txn.WithTxn(ctx, c.txnManager, func(ctx context.Context) error {
+ qb := c.repository.Performer
for _, performerID := range ids {
- performer, err := qb.Find(performerID)
+ performer, err := qb.Find(ctx, performerID)
if err != nil {
return err
}
@@ -341,11 +376,11 @@ func (c Client) FindStashBoxPerformersByPerformerNames(ctx context.Context, perf
var performers []*models.Performer
- if err := c.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
- qb := r.Performer()
+ if err := txn.WithTxn(ctx, c.txnManager, func(ctx context.Context) error {
+ qb := c.repository.Performer
for _, performerID := range ids {
- performer, err := qb.Find(performerID)
+ performer, err := qb.Find(ctx, performerID)
if err != nil {
return err
}
@@ -622,9 +657,9 @@ func (c Client) sceneFragmentToScrapedScene(ctx context.Context, s *graphql.Scen
ss.Image = getFirstImage(ctx, c.getHTTPClient(), s.Images)
}
- if err := c.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
- pqb := r.Performer()
- tqb := r.Tag()
+ if err := txn.WithTxn(ctx, c.txnManager, func(ctx context.Context) error {
+ pqb := c.repository.Performer
+ tqb := c.repository.Tag
if s.Studio != nil {
studioID := s.Studio.ID
@@ -634,7 +669,7 @@ func (c Client) sceneFragmentToScrapedScene(ctx context.Context, s *graphql.Scen
RemoteSiteID: &studioID,
}
- err := match.ScrapedStudio(r.Studio(), ss.Studio, &c.box.Endpoint)
+ err := match.ScrapedStudio(ctx, c.repository.Studio, ss.Studio, &c.box.Endpoint)
if err != nil {
return err
}
@@ -643,7 +678,7 @@ func (c Client) sceneFragmentToScrapedScene(ctx context.Context, s *graphql.Scen
for _, p := range s.Performers {
sp := performerFragmentToScrapedScenePerformer(p.Performer)
- err := match.ScrapedPerformer(pqb, sp, &c.box.Endpoint)
+ err := match.ScrapedPerformer(ctx, pqb, sp, &c.box.Endpoint)
if err != nil {
return err
}
@@ -656,7 +691,7 @@ func (c Client) sceneFragmentToScrapedScene(ctx context.Context, s *graphql.Scen
Name: t.Name,
}
- err := match.ScrapedTag(tqb, st)
+ err := match.ScrapedTag(ctx, tqb, st)
if err != nil {
return err
}
@@ -705,12 +740,13 @@ func (c Client) GetUser(ctx context.Context) (*graphql.Me, error) {
func (c Client) SubmitSceneDraft(ctx context.Context, sceneID int, endpoint string, imagePath string) (*string, error) {
draft := graphql.SceneDraftInput{}
var image *os.File
- if err := c.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
- qb := r.Scene()
- pqb := r.Performer()
- sqb := r.Studio()
+ if err := txn.WithTxn(ctx, c.txnManager, func(ctx context.Context) error {
+ r := c.repository
+ qb := r.Scene
+ pqb := r.Performer
+ sqb := r.Studio
- scene, err := qb.Find(sceneID)
+ scene, err := qb.Find(ctx, sceneID)
if err != nil {
return err
}
@@ -730,7 +766,7 @@ func (c Client) SubmitSceneDraft(ctx context.Context, sceneID int, endpoint stri
}
if scene.StudioID.Valid {
- studio, err := sqb.Find(int(scene.StudioID.Int64))
+ studio, err := sqb.Find(ctx, int(scene.StudioID.Int64))
if err != nil {
return err
}
@@ -738,7 +774,7 @@ func (c Client) SubmitSceneDraft(ctx context.Context, sceneID int, endpoint stri
Name: studio.Name.String,
}
- stashIDs, err := sqb.GetStashIDs(studio.ID)
+ stashIDs, err := sqb.GetStashIDs(ctx, studio.ID)
if err != nil {
return err
}
@@ -780,7 +816,7 @@ func (c Client) SubmitSceneDraft(ctx context.Context, sceneID int, endpoint stri
}
draft.Fingerprints = fingerprints
- scenePerformers, err := pqb.FindBySceneID(sceneID)
+ scenePerformers, err := pqb.FindBySceneID(ctx, sceneID)
if err != nil {
return err
}
@@ -791,7 +827,7 @@ func (c Client) SubmitSceneDraft(ctx context.Context, sceneID int, endpoint stri
Name: p.Name.String,
}
- stashIDs, err := pqb.GetStashIDs(p.ID)
+ stashIDs, err := pqb.GetStashIDs(ctx, p.ID)
if err != nil {
return err
}
@@ -808,7 +844,7 @@ func (c Client) SubmitSceneDraft(ctx context.Context, sceneID int, endpoint stri
draft.Performers = performers
var tags []*graphql.DraftEntityInput
- sceneTags, err := r.Tag().FindBySceneID(scene.ID)
+ sceneTags, err := r.Tag.FindBySceneID(ctx, scene.ID)
if err != nil {
return err
}
@@ -825,7 +861,7 @@ func (c Client) SubmitSceneDraft(ctx context.Context, sceneID int, endpoint stri
}
}
- stashIDs, err := qb.GetStashIDs(sceneID)
+ stashIDs, err := qb.GetStashIDs(ctx, sceneID)
if err != nil {
return err
}
@@ -862,9 +898,9 @@ func (c Client) SubmitSceneDraft(ctx context.Context, sceneID int, endpoint stri
func (c Client) SubmitPerformerDraft(ctx context.Context, performer *models.Performer, endpoint string) (*string, error) {
draft := graphql.PerformerDraftInput{}
var image io.Reader
- if err := c.txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
- pqb := r.Performer()
- img, _ := pqb.GetImage(performer.ID)
+ if err := txn.WithTxn(ctx, c.txnManager, func(ctx context.Context) error {
+ pqb := c.repository.Performer
+ img, _ := pqb.GetImage(ctx, performer.ID)
if img != nil {
image = bytes.NewReader(img)
}
@@ -923,7 +959,7 @@ func (c Client) SubmitPerformerDraft(ctx context.Context, performer *models.Perf
draft.Urls = urls
}
- stashIDs, err := pqb.GetStashIDs(performer.ID)
+ stashIDs, err := pqb.GetStashIDs(ctx, performer.ID)
if err != nil {
return err
}
diff --git a/pkg/scraper/xpath.go b/pkg/scraper/xpath.go
index 092968b5e..29a4b0a19 100644
--- a/pkg/scraper/xpath.go
+++ b/pkg/scraper/xpath.go
@@ -23,16 +23,14 @@ type xpathScraper struct {
config config
globalConfig GlobalConfig
client *http.Client
- txnManager models.TransactionManager
}
-func newXpathScraper(scraper scraperTypeConfig, client *http.Client, txnManager models.TransactionManager, config config, globalConfig GlobalConfig) *xpathScraper {
+func newXpathScraper(scraper scraperTypeConfig, client *http.Client, config config, globalConfig GlobalConfig) *xpathScraper {
return &xpathScraper{
scraper: scraper,
config: config,
globalConfig: globalConfig,
client: client,
- txnManager: txnManager,
}
}
diff --git a/pkg/scraper/xpath_test.go b/pkg/scraper/xpath_test.go
index 9aef91a23..7120f8574 100644
--- a/pkg/scraper/xpath_test.go
+++ b/pkg/scraper/xpath_test.go
@@ -885,7 +885,7 @@ xPathScrapers:
client := &http.Client{}
ctx := context.Background()
- s := newGroupScraper(*c, nil, globalConfig)
+ s := newGroupScraper(*c, globalConfig)
us, ok := s.(urlScraper)
if !ok {
t.Error("couldn't convert scraper into url scraper")
diff --git a/pkg/database/custom_migrations.go b/pkg/sqlite/custom_migrations.go
similarity index 81%
rename from pkg/database/custom_migrations.go
rename to pkg/sqlite/custom_migrations.go
index 340ffba55..768317707 100644
--- a/pkg/database/custom_migrations.go
+++ b/pkg/sqlite/custom_migrations.go
@@ -1,27 +1,33 @@
-package database
+package sqlite
import (
+ "context"
"database/sql"
"errors"
"fmt"
"strings"
- "github.com/jmoiron/sqlx"
"github.com/stashapp/stash/pkg/logger"
+ "github.com/stashapp/stash/pkg/txn"
)
-func runCustomMigrations() error {
- if err := createImagesChecksumIndex(); err != nil {
+func (db *Database) runCustomMigrations() error {
+ if err := db.createImagesChecksumIndex(); err != nil {
return err
}
return nil
}
-func createImagesChecksumIndex() error {
- return WithTxn(func(tx *sqlx.Tx) error {
+func (db *Database) createImagesChecksumIndex() error {
+ return txn.WithTxn(context.Background(), db, func(ctx context.Context) error {
+ tx, err := getTx(ctx)
+ if err != nil {
+ return err
+ }
+
row := tx.QueryRow("SELECT 1 AS found FROM sqlite_master WHERE type = 'index' AND name = 'images_checksum_unique'")
- err := row.Err()
+ err = row.Err()
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
}
diff --git a/pkg/database/database.go b/pkg/sqlite/database.go
similarity index 54%
rename from pkg/database/database.go
rename to pkg/sqlite/database.go
index 3fa260716..5897e844a 100644
--- a/pkg/database/database.go
+++ b/pkg/sqlite/database.go
@@ -1,4 +1,4 @@
-package database
+package sqlite
import (
"database/sql"
@@ -20,99 +20,135 @@ import (
"github.com/stashapp/stash/pkg/logger"
)
-var DB *sqlx.DB
-var WriteMu sync.Mutex
-var dbPath string
var appSchemaVersion uint = 31
-var databaseSchemaVersion uint
//go:embed migrations/*.sql
var migrationsBox embed.FS
var (
- // ErrMigrationNeeded indicates that a database migration is needed
- // before the database can be initialized
- ErrMigrationNeeded = errors.New("database migration required")
-
// ErrDatabaseNotInitialized indicates that the database is not
// initialized, usually due to an incomplete configuration.
ErrDatabaseNotInitialized = errors.New("database not initialized")
)
-const sqlite3Driver = "sqlite3ex"
-
-// Ready returns an error if the database is not ready to begin transactions.
-func Ready() error {
- if DB == nil {
- return ErrDatabaseNotInitialized
- }
-
- return nil
+// ErrMigrationNeeded indicates that a database migration is needed
+// before the database can be initialized
+type MigrationNeededError struct {
+ CurrentSchemaVersion uint
+ RequiredSchemaVersion uint
}
+func (e *MigrationNeededError) Error() string {
+ return fmt.Sprintf("database schema version %d does not match required schema version %d", e.CurrentSchemaVersion, e.RequiredSchemaVersion)
+}
+
+type MismatchedSchemaVersionError struct {
+ CurrentSchemaVersion uint
+ RequiredSchemaVersion uint
+}
+
+func (e *MismatchedSchemaVersionError) Error() string {
+ return fmt.Sprintf("schema version %d is incompatible with required schema version %d", e.CurrentSchemaVersion, e.RequiredSchemaVersion)
+}
+
+const sqlite3Driver = "sqlite3ex"
+
func init() {
// register custom driver with regexp function
registerCustomDriver()
}
-// Initialize initializes the database. If the database is new, then it
+type Database struct {
+ db *sqlx.DB
+ dbPath string
+
+ schemaVersion uint
+
+ writeMu sync.Mutex
+}
+
+// Ready returns an error if the database is not ready to begin transactions.
+func (db *Database) Ready() error {
+ if db.db == nil {
+ return ErrDatabaseNotInitialized
+ }
+
+ return nil
+}
+
+// Open initializes the database. If the database is new, then it
// performs a full migration to the latest schema version. Otherwise, any
// necessary migrations must be run separately using RunMigrations.
// Returns true if the database is new.
-func Initialize(databasePath string) error {
- dbPath = databasePath
+func (db *Database) Open(dbPath string) error {
+ db.writeMu.Lock()
+ defer db.writeMu.Unlock()
- if err := getDatabaseSchemaVersion(); err != nil {
- return fmt.Errorf("error getting database schema version: %v", err)
+ db.dbPath = dbPath
+
+ databaseSchemaVersion, err := db.getDatabaseSchemaVersion()
+ if err != nil {
+ return fmt.Errorf("getting database schema version: %w", err)
}
+ db.schemaVersion = databaseSchemaVersion
+
if databaseSchemaVersion == 0 {
// new database, just run the migrations
- if err := RunMigrations(); err != nil {
+ if err := db.RunMigrations(); err != nil {
return fmt.Errorf("error running initial schema migrations: %v", err)
}
- // RunMigrations calls Initialise. Just return
- return nil
} else {
if databaseSchemaVersion > appSchemaVersion {
- panic(fmt.Sprintf("Database schema version %d is incompatible with required schema version %d", databaseSchemaVersion, appSchemaVersion))
+ return &MismatchedSchemaVersionError{
+ CurrentSchemaVersion: databaseSchemaVersion,
+ RequiredSchemaVersion: appSchemaVersion,
+ }
}
// if migration is needed, then don't open the connection
- if NeedsMigration() {
- logger.Warnf("Database schema version %d does not match required schema version %d.", databaseSchemaVersion, appSchemaVersion)
- return nil
+ if db.needsMigration() {
+ return &MigrationNeededError{
+ CurrentSchemaVersion: databaseSchemaVersion,
+ RequiredSchemaVersion: appSchemaVersion,
+ }
}
}
- const disableForeignKeys = false
- DB = open(databasePath, disableForeignKeys)
+ // RunMigrations may have opened a connection already
+ if db.db == nil {
+ const disableForeignKeys = false
+ db.db, err = db.open(disableForeignKeys)
+ if err != nil {
+ return err
+ }
+ }
- if err := runCustomMigrations(); err != nil {
+ if err := db.runCustomMigrations(); err != nil {
return err
}
return nil
}
-func Close() error {
- WriteMu.Lock()
- defer WriteMu.Unlock()
+func (db *Database) Close() error {
+ db.writeMu.Lock()
+ defer db.writeMu.Unlock()
- if DB != nil {
- if err := DB.Close(); err != nil {
+ if db.db != nil {
+ if err := db.db.Close(); err != nil {
return err
}
- DB = nil
+ db.db = nil
}
return nil
}
-func open(databasePath string, disableForeignKeys bool) *sqlx.DB {
+func (db *Database) open(disableForeignKeys bool) (*sqlx.DB, error) {
// https://github.com/mattn/go-sqlite3
- url := "file:" + databasePath + "?_journal=WAL&_sync=NORMAL"
+ url := "file:" + db.dbPath + "?_journal=WAL&_sync=NORMAL"
if !disableForeignKeys {
url += "&_fk=true"
}
@@ -122,14 +158,15 @@ func open(databasePath string, disableForeignKeys bool) *sqlx.DB {
conn.SetMaxIdleConns(4)
conn.SetConnMaxLifetime(30 * time.Second)
if err != nil {
- logger.Fatalf("db.Open(): %q\n", err)
+ return nil, fmt.Errorf("db.Open(): %w", err)
}
- return conn
+ return conn, nil
}
-func Reset(databasePath string) error {
- err := DB.Close()
+func (db *Database) Reset() error {
+ databasePath := db.dbPath
+ err := db.Close()
if err != nil {
return errors.New("Error closing database: " + err.Error())
@@ -151,7 +188,7 @@ func Reset(databasePath string) error {
}
}
- if err := Initialize(databasePath); err != nil {
+ if err := db.Open(databasePath); err != nil {
return fmt.Errorf("[reset DB] unable to initialize: %w", err)
}
@@ -160,18 +197,19 @@ func Reset(databasePath string) error {
// Backup the database. If db is nil, then uses the existing database
// connection.
-func Backup(db *sqlx.DB, backupPath string) error {
- if db == nil {
+func (db *Database) Backup(backupPath string) error {
+ thisDB := db.db
+ if thisDB == nil {
var err error
- db, err = sqlx.Connect(sqlite3Driver, "file:"+dbPath+"?_fk=true")
+ thisDB, err = sqlx.Connect(sqlite3Driver, "file:"+db.dbPath+"?_fk=true")
if err != nil {
- return fmt.Errorf("open database %s failed: %v", dbPath, err)
+ return fmt.Errorf("open database %s failed: %v", db.dbPath, err)
}
- defer db.Close()
+ defer thisDB.Close()
}
logger.Infof("Backing up database into: %s", backupPath)
- _, err := db.Exec(`VACUUM INTO "` + backupPath + `"`)
+ _, err := thisDB.Exec(`VACUUM INTO "` + backupPath + `"`)
if err != nil {
return fmt.Errorf("vacuum failed: %v", err)
}
@@ -179,40 +217,43 @@ func Backup(db *sqlx.DB, backupPath string) error {
return nil
}
-func RestoreFromBackup(backupPath string) error {
- logger.Infof("Restoring backup database %s into %s", backupPath, dbPath)
- return os.Rename(backupPath, dbPath)
+func (db *Database) RestoreFromBackup(backupPath string) error {
+ logger.Infof("Restoring backup database %s into %s", backupPath, db.dbPath)
+ return os.Rename(backupPath, db.dbPath)
}
// Migrate the database
-func NeedsMigration() bool {
- return databaseSchemaVersion != appSchemaVersion
+func (db *Database) needsMigration() bool {
+ return db.schemaVersion != appSchemaVersion
}
-func AppSchemaVersion() uint {
+func (db *Database) AppSchemaVersion() uint {
return appSchemaVersion
}
-func DatabasePath() string {
- return dbPath
+func (db *Database) DatabasePath() string {
+ return db.dbPath
}
-func DatabaseBackupPath() string {
- return fmt.Sprintf("%s.%d.%s", dbPath, databaseSchemaVersion, time.Now().Format("20060102_150405"))
+func (db *Database) DatabaseBackupPath() string {
+ return fmt.Sprintf("%s.%d.%s", db.dbPath, db.schemaVersion, time.Now().Format("20060102_150405"))
}
-func Version() uint {
- return databaseSchemaVersion
+func (db *Database) Version() uint {
+ return db.schemaVersion
}
-func getMigrate() (*migrate.Migrate, error) {
+func (db *Database) getMigrate() (*migrate.Migrate, error) {
migrations, err := iofs.New(migrationsBox, "migrations")
if err != nil {
panic(err.Error())
}
const disableForeignKeys = true
- conn := open(dbPath, disableForeignKeys)
+ conn, err := db.open(disableForeignKeys)
+ if err != nil {
+ return nil, err
+ }
driver, err := sqlite3mig.WithInstance(conn.DB, &sqlite3mig.Config{})
if err != nil {
@@ -223,31 +264,31 @@ func getMigrate() (*migrate.Migrate, error) {
return migrate.NewWithInstance(
"iofs",
migrations,
- dbPath,
+ db.dbPath,
driver,
)
}
-func getDatabaseSchemaVersion() error {
- m, err := getMigrate()
+func (db *Database) getDatabaseSchemaVersion() (uint, error) {
+ m, err := db.getMigrate()
if err != nil {
- return err
- }
-
- databaseSchemaVersion, _, _ = m.Version()
- m.Close()
- return nil
-}
-
-// Migrate the database
-func RunMigrations() error {
- m, err := getMigrate()
- if err != nil {
- panic(err.Error())
+ return 0, err
}
defer m.Close()
- databaseSchemaVersion, _, _ = m.Version()
+ ret, _, _ := m.Version()
+ return ret, nil
+}
+
+// Migrate the database
+func (db *Database) RunMigrations() error {
+ m, err := db.getMigrate()
+ if err != nil {
+ return err
+ }
+ defer m.Close()
+
+ databaseSchemaVersion, _, _ := m.Version()
stepNumber := appSchemaVersion - databaseSchemaVersion
if stepNumber != 0 {
logger.Infof("Migrating database from version %d to %d", databaseSchemaVersion, appSchemaVersion)
@@ -258,14 +299,19 @@ func RunMigrations() error {
}
}
+ // update the schema version
+ db.schemaVersion, _, _ = m.Version()
+
// re-initialise the database
- if err = Initialize(dbPath); err != nil {
- logger.Warnf("Error re-initializing the database: %v", err)
+ const disableForeignKeys = false
+ db.db, err = db.open(disableForeignKeys)
+ if err != nil {
+ return fmt.Errorf("re-initializing the database: %w", err)
}
// run a vacuum on the database
logger.Info("Performing vacuum on database")
- _, err = DB.Exec("VACUUM")
+ _, err = db.db.Exec("VACUUM")
if err != nil {
logger.Warnf("error while performing post-migration vacuum: %v", err)
}
diff --git a/pkg/sqlite/filter.go b/pkg/sqlite/filter.go
index 79af98efd..af8578dbe 100644
--- a/pkg/sqlite/filter.go
+++ b/pkg/sqlite/filter.go
@@ -1,6 +1,7 @@
package sqlite
import (
+ "context"
"errors"
"fmt"
"regexp"
@@ -26,13 +27,13 @@ func makeClause(sql string, args ...interface{}) sqlClause {
}
type criterionHandler interface {
- handle(f *filterBuilder)
+ handle(ctx context.Context, f *filterBuilder)
}
-type criterionHandlerFunc func(f *filterBuilder)
+type criterionHandlerFunc func(ctx context.Context, f *filterBuilder)
-func (h criterionHandlerFunc) handle(f *filterBuilder) {
- h(f)
+func (h criterionHandlerFunc) handle(ctx context.Context, f *filterBuilder) {
+ h(ctx, f)
}
type join struct {
@@ -331,8 +332,8 @@ func (f *filterBuilder) getError() error {
// handleCriterion calls the handle function on the provided criterionHandler,
// providing itself.
-func (f *filterBuilder) handleCriterion(handler criterionHandler) {
- handler.handle(f)
+func (f *filterBuilder) handleCriterion(ctx context.Context, handler criterionHandler) {
+ handler.handle(ctx, f)
}
func (f *filterBuilder) setError(e error) {
@@ -361,7 +362,7 @@ func (f *filterBuilder) andClauses(input []sqlClause) (string, []interface{}) {
}
func stringCriterionHandler(c *models.StringCriterionInput, column string) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if c != nil {
if modifier := c.Modifier; c.Modifier.IsValid() {
switch modifier {
@@ -400,7 +401,7 @@ func stringCriterionHandler(c *models.StringCriterionInput, column string) crite
}
func intCriterionHandler(c *models.IntCriterionInput, column string) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if c != nil {
clause, args := getIntCriterionWhereClause(column, *c)
f.addWhere(clause, args...)
@@ -409,7 +410,7 @@ func intCriterionHandler(c *models.IntCriterionInput, column string) criterionHa
}
func boolCriterionHandler(c *bool, column string) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if c != nil {
var v string
if *c {
@@ -441,7 +442,7 @@ type joinedMultiCriterionHandlerBuilder struct {
}
func (m *joinedMultiCriterionHandlerBuilder) handler(criterion *models.MultiCriterionInput) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if criterion != nil {
joinAlias := m.joinAs
if joinAlias == "" {
@@ -511,7 +512,7 @@ type multiCriterionHandlerBuilder struct {
}
func (m *multiCriterionHandlerBuilder) handler(criterion *models.MultiCriterionInput) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if criterion != nil {
if criterion.Modifier == models.CriterionModifierIsNull || criterion.Modifier == models.CriterionModifierNotNull {
var notClause string
@@ -556,7 +557,7 @@ type countCriterionHandlerBuilder struct {
}
func (m *countCriterionHandlerBuilder) handler(criterion *models.IntCriterionInput) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if criterion != nil {
clause, args := getCountCriterionClause(m.primaryTable, m.joinTable, m.primaryFK, *criterion)
@@ -576,11 +577,11 @@ type stringListCriterionHandlerBuilder struct {
}
func (m *stringListCriterionHandlerBuilder) handler(criterion *models.StringCriterionInput) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if criterion != nil && len(criterion.Value) > 0 {
m.addJoinTable(f)
- stringCriterionHandler(criterion, m.joinTable+"."+m.stringColumn)(f)
+ stringCriterionHandler(criterion, m.joinTable+"."+m.stringColumn)(ctx, f)
}
}
}
@@ -597,7 +598,7 @@ type hierarchicalMultiCriterionHandlerBuilder struct {
relationsTable string
}
-func getHierarchicalValues(tx dbi, values []string, table, relationsTable, parentFK string, depth *int) string {
+func getHierarchicalValues(ctx context.Context, tx dbi, values []string, table, relationsTable, parentFK string, depth *int) string {
var args []interface{}
depthVal := 0
@@ -670,7 +671,7 @@ WHERE id in {inBinding}
query := fmt.Sprintf("WITH RECURSIVE %s SELECT 'VALUES' || GROUP_CONCAT('(' || root_id || ', ' || item_id || ')') AS val FROM items", withClause)
var valuesClause string
- err := tx.Get(&valuesClause, query, args...)
+ err := tx.Get(ctx, &valuesClause, query, args...)
if err != nil {
logger.Error(err)
// return record which never matches so we don't have to handle error here
@@ -693,7 +694,7 @@ func addHierarchicalConditionClauses(f *filterBuilder, criterion *models.Hierarc
}
func (m *hierarchicalMultiCriterionHandlerBuilder) handler(criterion *models.HierarchicalMultiCriterionInput) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if criterion != nil {
if criterion.Modifier == models.CriterionModifierIsNull || criterion.Modifier == models.CriterionModifierNotNull {
var notClause string
@@ -713,7 +714,7 @@ func (m *hierarchicalMultiCriterionHandlerBuilder) handler(criterion *models.Hie
return
}
- valuesClause := getHierarchicalValues(m.tx, criterion.Value, m.foreignTable, m.relationsTable, m.parentFK, criterion.Depth)
+ valuesClause := getHierarchicalValues(ctx, m.tx, criterion.Value, m.foreignTable, m.relationsTable, m.parentFK, criterion.Depth)
f.addLeftJoin("(SELECT column1 AS root_id, column2 AS item_id FROM ("+valuesClause+"))", m.derivedTable, fmt.Sprintf("%s.item_id = %s.%s", m.derivedTable, m.primaryTable, m.foreignFK))
@@ -738,7 +739,7 @@ type joinedHierarchicalMultiCriterionHandlerBuilder struct {
}
func (m *joinedHierarchicalMultiCriterionHandlerBuilder) handler(criterion *models.HierarchicalMultiCriterionInput) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if criterion != nil {
joinAlias := m.joinAs
@@ -762,7 +763,7 @@ func (m *joinedHierarchicalMultiCriterionHandlerBuilder) handler(criterion *mode
return
}
- valuesClause := getHierarchicalValues(m.tx, criterion.Value, m.foreignTable, m.relationsTable, m.parentFK, criterion.Depth)
+ valuesClause := getHierarchicalValues(ctx, m.tx, criterion.Value, m.foreignTable, m.relationsTable, m.parentFK, criterion.Depth)
joinTable := utils.StrFormat(`(
SELECT j.*, d.column1 AS root_id, d.column2 AS item_id FROM {joinTable} AS j
diff --git a/pkg/sqlite/filter_internal_test.go b/pkg/sqlite/filter_internal_test.go
index e9f173de0..f416b661c 100644
--- a/pkg/sqlite/filter_internal_test.go
+++ b/pkg/sqlite/filter_internal_test.go
@@ -1,6 +1,7 @@
package sqlite
import (
+ "context"
"errors"
"fmt"
"testing"
@@ -9,6 +10,8 @@ import (
"github.com/stretchr/testify/assert"
)
+var testCtx = context.Background()
+
func TestJoinsAddJoin(t *testing.T) {
var joins joins
@@ -462,7 +465,7 @@ func TestStringCriterionHandlerIncludes(t *testing.T) {
const quotedValue = `"two words"`
f := &filterBuilder{}
- f.handleCriterion(stringCriterionHandler(&models.StringCriterionInput{
+ f.handleCriterion(testCtx, stringCriterionHandler(&models.StringCriterionInput{
Modifier: models.CriterionModifierIncludes,
Value: value1,
}, column))
@@ -474,7 +477,7 @@ func TestStringCriterionHandlerIncludes(t *testing.T) {
assert.Equal("%words%", f.whereClauses[0].args[1])
f = &filterBuilder{}
- f.handleCriterion(stringCriterionHandler(&models.StringCriterionInput{
+ f.handleCriterion(testCtx, stringCriterionHandler(&models.StringCriterionInput{
Modifier: models.CriterionModifierIncludes,
Value: quotedValue,
}, column))
@@ -493,7 +496,7 @@ func TestStringCriterionHandlerExcludes(t *testing.T) {
const quotedValue = `"two words"`
f := &filterBuilder{}
- f.handleCriterion(stringCriterionHandler(&models.StringCriterionInput{
+ f.handleCriterion(testCtx, stringCriterionHandler(&models.StringCriterionInput{
Modifier: models.CriterionModifierExcludes,
Value: value1,
}, column))
@@ -505,7 +508,7 @@ func TestStringCriterionHandlerExcludes(t *testing.T) {
assert.Equal("%words%", f.whereClauses[0].args[1])
f = &filterBuilder{}
- f.handleCriterion(stringCriterionHandler(&models.StringCriterionInput{
+ f.handleCriterion(testCtx, stringCriterionHandler(&models.StringCriterionInput{
Modifier: models.CriterionModifierExcludes,
Value: quotedValue,
}, column))
@@ -523,7 +526,7 @@ func TestStringCriterionHandlerEquals(t *testing.T) {
const value1 = "two words"
f := &filterBuilder{}
- f.handleCriterion(stringCriterionHandler(&models.StringCriterionInput{
+ f.handleCriterion(testCtx, stringCriterionHandler(&models.StringCriterionInput{
Modifier: models.CriterionModifierEquals,
Value: value1,
}, column))
@@ -541,7 +544,7 @@ func TestStringCriterionHandlerNotEquals(t *testing.T) {
const value1 = "two words"
f := &filterBuilder{}
- f.handleCriterion(stringCriterionHandler(&models.StringCriterionInput{
+ f.handleCriterion(testCtx, stringCriterionHandler(&models.StringCriterionInput{
Modifier: models.CriterionModifierNotEquals,
Value: value1,
}, column))
@@ -560,7 +563,7 @@ func TestStringCriterionHandlerMatchesRegex(t *testing.T) {
const invalidValue = "*two words"
f := &filterBuilder{}
- f.handleCriterion(stringCriterionHandler(&models.StringCriterionInput{
+ f.handleCriterion(testCtx, stringCriterionHandler(&models.StringCriterionInput{
Modifier: models.CriterionModifierMatchesRegex,
Value: validValue,
}, column))
@@ -572,7 +575,7 @@ func TestStringCriterionHandlerMatchesRegex(t *testing.T) {
// ensure invalid regex sets error state
f = &filterBuilder{}
- f.handleCriterion(stringCriterionHandler(&models.StringCriterionInput{
+ f.handleCriterion(testCtx, stringCriterionHandler(&models.StringCriterionInput{
Modifier: models.CriterionModifierMatchesRegex,
Value: invalidValue,
}, column))
@@ -588,7 +591,7 @@ func TestStringCriterionHandlerNotMatchesRegex(t *testing.T) {
const invalidValue = "*two words"
f := &filterBuilder{}
- f.handleCriterion(stringCriterionHandler(&models.StringCriterionInput{
+ f.handleCriterion(testCtx, stringCriterionHandler(&models.StringCriterionInput{
Modifier: models.CriterionModifierNotMatchesRegex,
Value: validValue,
}, column))
@@ -600,7 +603,7 @@ func TestStringCriterionHandlerNotMatchesRegex(t *testing.T) {
// ensure invalid regex sets error state
f = &filterBuilder{}
- f.handleCriterion(stringCriterionHandler(&models.StringCriterionInput{
+ f.handleCriterion(testCtx, stringCriterionHandler(&models.StringCriterionInput{
Modifier: models.CriterionModifierNotMatchesRegex,
Value: invalidValue,
}, column))
@@ -614,7 +617,7 @@ func TestStringCriterionHandlerIsNull(t *testing.T) {
const column = "column"
f := &filterBuilder{}
- f.handleCriterion(stringCriterionHandler(&models.StringCriterionInput{
+ f.handleCriterion(testCtx, stringCriterionHandler(&models.StringCriterionInput{
Modifier: models.CriterionModifierIsNull,
}, column))
@@ -629,7 +632,7 @@ func TestStringCriterionHandlerNotNull(t *testing.T) {
const column = "column"
f := &filterBuilder{}
- f.handleCriterion(stringCriterionHandler(&models.StringCriterionInput{
+ f.handleCriterion(testCtx, stringCriterionHandler(&models.StringCriterionInput{
Modifier: models.CriterionModifierNotNull,
}, column))
diff --git a/pkg/database/functions.go b/pkg/sqlite/functions.go
similarity index 96%
rename from pkg/database/functions.go
rename to pkg/sqlite/functions.go
index 2971f1e22..29e93aa22 100644
--- a/pkg/database/functions.go
+++ b/pkg/sqlite/functions.go
@@ -1,4 +1,4 @@
-package database
+package sqlite
import (
"strconv"
diff --git a/pkg/sqlite/gallery.go b/pkg/sqlite/gallery.go
index b7f9276ac..bb94fa1f0 100644
--- a/pkg/sqlite/gallery.go
+++ b/pkg/sqlite/gallery.go
@@ -1,6 +1,7 @@
package sqlite
import (
+ "context"
"database/sql"
"errors"
"fmt"
@@ -20,62 +21,59 @@ type galleryQueryBuilder struct {
repository
}
-func NewGalleryReaderWriter(tx dbi) *galleryQueryBuilder {
- return &galleryQueryBuilder{
- repository{
- tx: tx,
- tableName: galleryTable,
- idColumn: idColumn,
- },
- }
+var GalleryReaderWriter = &galleryQueryBuilder{
+ repository{
+ tableName: galleryTable,
+ idColumn: idColumn,
+ },
}
-func (qb *galleryQueryBuilder) Create(newObject models.Gallery) (*models.Gallery, error) {
+func (qb *galleryQueryBuilder) Create(ctx context.Context, newObject models.Gallery) (*models.Gallery, error) {
var ret models.Gallery
- if err := qb.insertObject(newObject, &ret); err != nil {
+ if err := qb.insertObject(ctx, newObject, &ret); err != nil {
return nil, err
}
return &ret, nil
}
-func (qb *galleryQueryBuilder) Update(updatedObject models.Gallery) (*models.Gallery, error) {
+func (qb *galleryQueryBuilder) Update(ctx context.Context, updatedObject models.Gallery) (*models.Gallery, error) {
const partial = false
- if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil {
+ if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil {
return nil, err
}
- return qb.Find(updatedObject.ID)
+ return qb.Find(ctx, updatedObject.ID)
}
-func (qb *galleryQueryBuilder) UpdatePartial(updatedObject models.GalleryPartial) (*models.Gallery, error) {
+func (qb *galleryQueryBuilder) UpdatePartial(ctx context.Context, updatedObject models.GalleryPartial) (*models.Gallery, error) {
const partial = true
- if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil {
+ if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil {
return nil, err
}
- return qb.Find(updatedObject.ID)
+ return qb.Find(ctx, updatedObject.ID)
}
-func (qb *galleryQueryBuilder) UpdateChecksum(id int, checksum string) error {
- return qb.updateMap(id, map[string]interface{}{
+func (qb *galleryQueryBuilder) UpdateChecksum(ctx context.Context, id int, checksum string) error {
+ return qb.updateMap(ctx, id, map[string]interface{}{
"checksum": checksum,
})
}
-func (qb *galleryQueryBuilder) UpdateFileModTime(id int, modTime models.NullSQLiteTimestamp) error {
- return qb.updateMap(id, map[string]interface{}{
+func (qb *galleryQueryBuilder) UpdateFileModTime(ctx context.Context, id int, modTime models.NullSQLiteTimestamp) error {
+ return qb.updateMap(ctx, id, map[string]interface{}{
"file_mod_time": modTime,
})
}
-func (qb *galleryQueryBuilder) Destroy(id int) error {
- return qb.destroyExisting([]int{id})
+func (qb *galleryQueryBuilder) Destroy(ctx context.Context, id int) error {
+ return qb.destroyExisting(ctx, []int{id})
}
-func (qb *galleryQueryBuilder) Find(id int) (*models.Gallery, error) {
+func (qb *galleryQueryBuilder) Find(ctx context.Context, id int) (*models.Gallery, error) {
var ret models.Gallery
- if err := qb.get(id, &ret); err != nil {
+ if err := qb.getByID(ctx, id, &ret); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
@@ -84,10 +82,10 @@ func (qb *galleryQueryBuilder) Find(id int) (*models.Gallery, error) {
return &ret, nil
}
-func (qb *galleryQueryBuilder) FindMany(ids []int) ([]*models.Gallery, error) {
+func (qb *galleryQueryBuilder) FindMany(ctx context.Context, ids []int) ([]*models.Gallery, error) {
var galleries []*models.Gallery
for _, id := range ids {
- gallery, err := qb.Find(id)
+ gallery, err := qb.Find(ctx, id)
if err != nil {
return nil, err
}
@@ -102,61 +100,61 @@ func (qb *galleryQueryBuilder) FindMany(ids []int) ([]*models.Gallery, error) {
return galleries, nil
}
-func (qb *galleryQueryBuilder) FindByChecksum(checksum string) (*models.Gallery, error) {
+func (qb *galleryQueryBuilder) FindByChecksum(ctx context.Context, checksum string) (*models.Gallery, error) {
query := "SELECT * FROM galleries WHERE checksum = ? LIMIT 1"
args := []interface{}{checksum}
- return qb.queryGallery(query, args)
+ return qb.queryGallery(ctx, query, args)
}
-func (qb *galleryQueryBuilder) FindByChecksums(checksums []string) ([]*models.Gallery, error) {
+func (qb *galleryQueryBuilder) FindByChecksums(ctx context.Context, checksums []string) ([]*models.Gallery, error) {
query := "SELECT * FROM galleries WHERE checksum IN " + getInBinding(len(checksums))
var args []interface{}
for _, checksum := range checksums {
args = append(args, checksum)
}
- return qb.queryGalleries(query, args)
+ return qb.queryGalleries(ctx, query, args)
}
-func (qb *galleryQueryBuilder) FindByPath(path string) (*models.Gallery, error) {
+func (qb *galleryQueryBuilder) FindByPath(ctx context.Context, path string) (*models.Gallery, error) {
query := "SELECT * FROM galleries WHERE path = ? LIMIT 1"
args := []interface{}{path}
- return qb.queryGallery(query, args)
+ return qb.queryGallery(ctx, query, args)
}
-func (qb *galleryQueryBuilder) FindBySceneID(sceneID int) ([]*models.Gallery, error) {
+func (qb *galleryQueryBuilder) FindBySceneID(ctx context.Context, sceneID int) ([]*models.Gallery, error) {
query := selectAll(galleryTable) + `
LEFT JOIN scenes_galleries as scenes_join on scenes_join.gallery_id = galleries.id
WHERE scenes_join.scene_id = ?
GROUP BY galleries.id
`
args := []interface{}{sceneID}
- return qb.queryGalleries(query, args)
+ return qb.queryGalleries(ctx, query, args)
}
-func (qb *galleryQueryBuilder) FindByImageID(imageID int) ([]*models.Gallery, error) {
+func (qb *galleryQueryBuilder) FindByImageID(ctx context.Context, imageID int) ([]*models.Gallery, error) {
query := selectAll(galleryTable) + `
INNER JOIN galleries_images as images_join on images_join.gallery_id = galleries.id
WHERE images_join.image_id = ?
GROUP BY galleries.id
`
args := []interface{}{imageID}
- return qb.queryGalleries(query, args)
+ return qb.queryGalleries(ctx, query, args)
}
-func (qb *galleryQueryBuilder) CountByImageID(imageID int) (int, error) {
+func (qb *galleryQueryBuilder) CountByImageID(ctx context.Context, imageID int) (int, error) {
query := `SELECT image_id FROM galleries_images
WHERE image_id = ?
GROUP BY gallery_id`
args := []interface{}{imageID}
- return qb.runCountQuery(qb.buildCountQuery(query), args)
+ return qb.runCountQuery(ctx, qb.buildCountQuery(query), args)
}
-func (qb *galleryQueryBuilder) Count() (int, error) {
- return qb.runCountQuery(qb.buildCountQuery("SELECT galleries.id FROM galleries"), nil)
+func (qb *galleryQueryBuilder) Count(ctx context.Context) (int, error) {
+ return qb.runCountQuery(ctx, qb.buildCountQuery("SELECT galleries.id FROM galleries"), nil)
}
-func (qb *galleryQueryBuilder) All() ([]*models.Gallery, error) {
- return qb.queryGalleries(selectAll("galleries")+qb.getGallerySort(nil), nil)
+func (qb *galleryQueryBuilder) All(ctx context.Context) ([]*models.Gallery, error) {
+ return qb.queryGalleries(ctx, selectAll("galleries")+qb.getGallerySort(nil), nil)
}
func (qb *galleryQueryBuilder) validateFilter(galleryFilter *models.GalleryFilterType) error {
@@ -190,43 +188,43 @@ func (qb *galleryQueryBuilder) validateFilter(galleryFilter *models.GalleryFilte
return nil
}
-func (qb *galleryQueryBuilder) makeFilter(galleryFilter *models.GalleryFilterType) *filterBuilder {
+func (qb *galleryQueryBuilder) makeFilter(ctx context.Context, galleryFilter *models.GalleryFilterType) *filterBuilder {
query := &filterBuilder{}
if galleryFilter.And != nil {
- query.and(qb.makeFilter(galleryFilter.And))
+ query.and(qb.makeFilter(ctx, galleryFilter.And))
}
if galleryFilter.Or != nil {
- query.or(qb.makeFilter(galleryFilter.Or))
+ query.or(qb.makeFilter(ctx, galleryFilter.Or))
}
if galleryFilter.Not != nil {
- query.not(qb.makeFilter(galleryFilter.Not))
+ query.not(qb.makeFilter(ctx, galleryFilter.Not))
}
- query.handleCriterion(stringCriterionHandler(galleryFilter.Title, "galleries.title"))
- query.handleCriterion(stringCriterionHandler(galleryFilter.Details, "galleries.details"))
- query.handleCriterion(stringCriterionHandler(galleryFilter.Checksum, "galleries.checksum"))
- query.handleCriterion(boolCriterionHandler(galleryFilter.IsZip, "galleries.zip"))
- query.handleCriterion(stringCriterionHandler(galleryFilter.Path, "galleries.path"))
- query.handleCriterion(intCriterionHandler(galleryFilter.Rating, "galleries.rating"))
- query.handleCriterion(stringCriterionHandler(galleryFilter.URL, "galleries.url"))
- query.handleCriterion(boolCriterionHandler(galleryFilter.Organized, "galleries.organized"))
- query.handleCriterion(galleryIsMissingCriterionHandler(qb, galleryFilter.IsMissing))
- query.handleCriterion(galleryTagsCriterionHandler(qb, galleryFilter.Tags))
- query.handleCriterion(galleryTagCountCriterionHandler(qb, galleryFilter.TagCount))
- query.handleCriterion(galleryPerformersCriterionHandler(qb, galleryFilter.Performers))
- query.handleCriterion(galleryPerformerCountCriterionHandler(qb, galleryFilter.PerformerCount))
- query.handleCriterion(galleryStudioCriterionHandler(qb, galleryFilter.Studios))
- query.handleCriterion(galleryPerformerTagsCriterionHandler(qb, galleryFilter.PerformerTags))
- query.handleCriterion(galleryAverageResolutionCriterionHandler(qb, galleryFilter.AverageResolution))
- query.handleCriterion(galleryImageCountCriterionHandler(qb, galleryFilter.ImageCount))
- query.handleCriterion(galleryPerformerFavoriteCriterionHandler(galleryFilter.PerformerFavorite))
- query.handleCriterion(galleryPerformerAgeCriterionHandler(galleryFilter.PerformerAge))
+ query.handleCriterion(ctx, stringCriterionHandler(galleryFilter.Title, "galleries.title"))
+ query.handleCriterion(ctx, stringCriterionHandler(galleryFilter.Details, "galleries.details"))
+ query.handleCriterion(ctx, stringCriterionHandler(galleryFilter.Checksum, "galleries.checksum"))
+ query.handleCriterion(ctx, boolCriterionHandler(galleryFilter.IsZip, "galleries.zip"))
+ query.handleCriterion(ctx, stringCriterionHandler(galleryFilter.Path, "galleries.path"))
+ query.handleCriterion(ctx, intCriterionHandler(galleryFilter.Rating, "galleries.rating"))
+ query.handleCriterion(ctx, stringCriterionHandler(galleryFilter.URL, "galleries.url"))
+ query.handleCriterion(ctx, boolCriterionHandler(galleryFilter.Organized, "galleries.organized"))
+ query.handleCriterion(ctx, galleryIsMissingCriterionHandler(qb, galleryFilter.IsMissing))
+ query.handleCriterion(ctx, galleryTagsCriterionHandler(qb, galleryFilter.Tags))
+ query.handleCriterion(ctx, galleryTagCountCriterionHandler(qb, galleryFilter.TagCount))
+ query.handleCriterion(ctx, galleryPerformersCriterionHandler(qb, galleryFilter.Performers))
+ query.handleCriterion(ctx, galleryPerformerCountCriterionHandler(qb, galleryFilter.PerformerCount))
+ query.handleCriterion(ctx, galleryStudioCriterionHandler(qb, galleryFilter.Studios))
+ query.handleCriterion(ctx, galleryPerformerTagsCriterionHandler(qb, galleryFilter.PerformerTags))
+ query.handleCriterion(ctx, galleryAverageResolutionCriterionHandler(qb, galleryFilter.AverageResolution))
+ query.handleCriterion(ctx, galleryImageCountCriterionHandler(qb, galleryFilter.ImageCount))
+ query.handleCriterion(ctx, galleryPerformerFavoriteCriterionHandler(galleryFilter.PerformerFavorite))
+ query.handleCriterion(ctx, galleryPerformerAgeCriterionHandler(galleryFilter.PerformerAge))
return query
}
-func (qb *galleryQueryBuilder) makeQuery(galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) (*queryBuilder, error) {
+func (qb *galleryQueryBuilder) makeQuery(ctx context.Context, galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) (*queryBuilder, error) {
if galleryFilter == nil {
galleryFilter = &models.GalleryFilterType{}
}
@@ -245,7 +243,7 @@ func (qb *galleryQueryBuilder) makeQuery(galleryFilter *models.GalleryFilterType
if err := qb.validateFilter(galleryFilter); err != nil {
return nil, err
}
- filter := qb.makeFilter(galleryFilter)
+ filter := qb.makeFilter(ctx, galleryFilter)
query.addFilter(filter)
@@ -254,20 +252,20 @@ func (qb *galleryQueryBuilder) makeQuery(galleryFilter *models.GalleryFilterType
return &query, nil
}
-func (qb *galleryQueryBuilder) Query(galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) ([]*models.Gallery, int, error) {
- query, err := qb.makeQuery(galleryFilter, findFilter)
+func (qb *galleryQueryBuilder) Query(ctx context.Context, galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) ([]*models.Gallery, int, error) {
+ query, err := qb.makeQuery(ctx, galleryFilter, findFilter)
if err != nil {
return nil, 0, err
}
- idsResult, countResult, err := query.executeFind()
+ idsResult, countResult, err := query.executeFind(ctx)
if err != nil {
return nil, 0, err
}
var galleries []*models.Gallery
for _, id := range idsResult {
- gallery, err := qb.Find(id)
+ gallery, err := qb.Find(ctx, id)
if err != nil {
return nil, 0, err
}
@@ -278,17 +276,17 @@ func (qb *galleryQueryBuilder) Query(galleryFilter *models.GalleryFilterType, fi
return galleries, countResult, nil
}
-func (qb *galleryQueryBuilder) QueryCount(galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) (int, error) {
- query, err := qb.makeQuery(galleryFilter, findFilter)
+func (qb *galleryQueryBuilder) QueryCount(ctx context.Context, galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) (int, error) {
+ query, err := qb.makeQuery(ctx, galleryFilter, findFilter)
if err != nil {
return 0, err
}
- return query.executeCount()
+ return query.executeCount(ctx)
}
func galleryIsMissingCriterionHandler(qb *galleryQueryBuilder, isMissing *string) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if isMissing != nil && *isMissing != "" {
switch *isMissing {
case "scenes":
@@ -389,7 +387,7 @@ func galleryStudioCriterionHandler(qb *galleryQueryBuilder, studios *models.Hier
}
func galleryPerformerTagsCriterionHandler(qb *galleryQueryBuilder, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if tags != nil {
if tags.Modifier == models.CriterionModifierIsNull || tags.Modifier == models.CriterionModifierNotNull {
var notClause string
@@ -408,7 +406,7 @@ func galleryPerformerTagsCriterionHandler(qb *galleryQueryBuilder, tags *models.
return
}
- valuesClause := getHierarchicalValues(qb.tx, tags.Value, tagTable, "tags_relations", "", tags.Depth)
+ valuesClause := getHierarchicalValues(ctx, qb.tx, tags.Value, tagTable, "tags_relations", "", tags.Depth)
f.addWith(`performer_tags AS (
SELECT pg.gallery_id, t.column1 AS root_tag_id FROM performers_galleries pg
@@ -424,7 +422,7 @@ INNER JOIN (` + valuesClause + `) t ON t.column2 = pt.tag_id
}
func galleryPerformerFavoriteCriterionHandler(performerfavorite *bool) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if performerfavorite != nil {
f.addLeftJoin("performers_galleries", "", "galleries.id = performers_galleries.gallery_id")
@@ -444,7 +442,7 @@ GROUP BY performers_galleries.gallery_id HAVING SUM(performers.favorite) = 0)`,
}
func galleryPerformerAgeCriterionHandler(performerAge *models.IntCriterionInput) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if performerAge != nil {
f.addInnerJoin("performers_galleries", "", "galleries.id = performers_galleries.gallery_id")
f.addInnerJoin("performers", "", "performers_galleries.performer_id = performers.id")
@@ -461,7 +459,7 @@ func galleryPerformerAgeCriterionHandler(performerAge *models.IntCriterionInput)
}
func galleryAverageResolutionCriterionHandler(qb *galleryQueryBuilder, resolution *models.ResolutionCriterionInput) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if resolution != nil && resolution.Value.IsValid() {
qb.imagesRepository().join(f, "images_join", "galleries.id")
f.addLeftJoin("images", "", "images_join.image_id = images.id")
@@ -505,17 +503,17 @@ func (qb *galleryQueryBuilder) getGallerySort(findFilter *models.FindFilterType)
}
}
-func (qb *galleryQueryBuilder) queryGallery(query string, args []interface{}) (*models.Gallery, error) {
- results, err := qb.queryGalleries(query, args)
+func (qb *galleryQueryBuilder) queryGallery(ctx context.Context, query string, args []interface{}) (*models.Gallery, error) {
+ results, err := qb.queryGalleries(ctx, query, args)
if err != nil || len(results) < 1 {
return nil, err
}
return results[0], nil
}
-func (qb *galleryQueryBuilder) queryGalleries(query string, args []interface{}) ([]*models.Gallery, error) {
+func (qb *galleryQueryBuilder) queryGalleries(ctx context.Context, query string, args []interface{}) ([]*models.Gallery, error) {
var ret models.Galleries
- if err := qb.query(query, args, &ret); err != nil {
+ if err := qb.query(ctx, query, args, &ret); err != nil {
return nil, err
}
@@ -533,13 +531,13 @@ func (qb *galleryQueryBuilder) performersRepository() *joinRepository {
}
}
-func (qb *galleryQueryBuilder) GetPerformerIDs(galleryID int) ([]int, error) {
- return qb.performersRepository().getIDs(galleryID)
+func (qb *galleryQueryBuilder) GetPerformerIDs(ctx context.Context, galleryID int) ([]int, error) {
+ return qb.performersRepository().getIDs(ctx, galleryID)
}
-func (qb *galleryQueryBuilder) UpdatePerformers(galleryID int, performerIDs []int) error {
+func (qb *galleryQueryBuilder) UpdatePerformers(ctx context.Context, galleryID int, performerIDs []int) error {
// Delete the existing joins and then create new ones
- return qb.performersRepository().replace(galleryID, performerIDs)
+ return qb.performersRepository().replace(ctx, galleryID, performerIDs)
}
func (qb *galleryQueryBuilder) tagsRepository() *joinRepository {
@@ -553,13 +551,13 @@ func (qb *galleryQueryBuilder) tagsRepository() *joinRepository {
}
}
-func (qb *galleryQueryBuilder) GetTagIDs(galleryID int) ([]int, error) {
- return qb.tagsRepository().getIDs(galleryID)
+func (qb *galleryQueryBuilder) GetTagIDs(ctx context.Context, galleryID int) ([]int, error) {
+ return qb.tagsRepository().getIDs(ctx, galleryID)
}
-func (qb *galleryQueryBuilder) UpdateTags(galleryID int, tagIDs []int) error {
+func (qb *galleryQueryBuilder) UpdateTags(ctx context.Context, galleryID int, tagIDs []int) error {
// Delete the existing joins and then create new ones
- return qb.tagsRepository().replace(galleryID, tagIDs)
+ return qb.tagsRepository().replace(ctx, galleryID, tagIDs)
}
func (qb *galleryQueryBuilder) imagesRepository() *joinRepository {
@@ -573,13 +571,13 @@ func (qb *galleryQueryBuilder) imagesRepository() *joinRepository {
}
}
-func (qb *galleryQueryBuilder) GetImageIDs(galleryID int) ([]int, error) {
- return qb.imagesRepository().getIDs(galleryID)
+func (qb *galleryQueryBuilder) GetImageIDs(ctx context.Context, galleryID int) ([]int, error) {
+ return qb.imagesRepository().getIDs(ctx, galleryID)
}
-func (qb *galleryQueryBuilder) UpdateImages(galleryID int, imageIDs []int) error {
+func (qb *galleryQueryBuilder) UpdateImages(ctx context.Context, galleryID int, imageIDs []int) error {
// Delete the existing joins and then create new ones
- return qb.imagesRepository().replace(galleryID, imageIDs)
+ return qb.imagesRepository().replace(ctx, galleryID, imageIDs)
}
func (qb *galleryQueryBuilder) scenesRepository() *joinRepository {
@@ -593,11 +591,11 @@ func (qb *galleryQueryBuilder) scenesRepository() *joinRepository {
}
}
-func (qb *galleryQueryBuilder) GetSceneIDs(galleryID int) ([]int, error) {
- return qb.scenesRepository().getIDs(galleryID)
+func (qb *galleryQueryBuilder) GetSceneIDs(ctx context.Context, galleryID int) ([]int, error) {
+ return qb.scenesRepository().getIDs(ctx, galleryID)
}
-func (qb *galleryQueryBuilder) UpdateScenes(galleryID int, sceneIDs []int) error {
+func (qb *galleryQueryBuilder) UpdateScenes(ctx context.Context, galleryID int, sceneIDs []int) error {
// Delete the existing joins and then create new ones
- return qb.scenesRepository().replace(galleryID, sceneIDs)
+ return qb.scenesRepository().replace(ctx, galleryID, sceneIDs)
}
diff --git a/pkg/sqlite/gallery_test.go b/pkg/sqlite/gallery_test.go
index f9aa9ef5e..ae2cbe21b 100644
--- a/pkg/sqlite/gallery_test.go
+++ b/pkg/sqlite/gallery_test.go
@@ -4,6 +4,7 @@
package sqlite_test
import (
+ "context"
"math"
"strconv"
"testing"
@@ -11,14 +12,15 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stashapp/stash/pkg/models"
+ "github.com/stashapp/stash/pkg/sqlite"
)
func TestGalleryFind(t *testing.T) {
- withTxn(func(r models.Repository) error {
- gqb := r.Gallery()
+ withTxn(func(ctx context.Context) error {
+ gqb := sqlite.GalleryReaderWriter
const galleryIdx = 0
- gallery, err := gqb.Find(galleryIDs[galleryIdx])
+ gallery, err := gqb.Find(ctx, galleryIDs[galleryIdx])
if err != nil {
t.Errorf("Error finding gallery: %s", err.Error())
@@ -26,7 +28,7 @@ func TestGalleryFind(t *testing.T) {
assert.Equal(t, getGalleryStringValue(galleryIdx, "Path"), gallery.Path.String)
- gallery, err = gqb.Find(0)
+ gallery, err = gqb.Find(ctx, 0)
if err != nil {
t.Errorf("Error finding gallery: %s", err.Error())
@@ -39,12 +41,12 @@ func TestGalleryFind(t *testing.T) {
}
func TestGalleryFindByChecksum(t *testing.T) {
- withTxn(func(r models.Repository) error {
- gqb := r.Gallery()
+ withTxn(func(ctx context.Context) error {
+ gqb := sqlite.GalleryReaderWriter
const galleryIdx = 0
galleryChecksum := getGalleryStringValue(galleryIdx, "Checksum")
- gallery, err := gqb.FindByChecksum(galleryChecksum)
+ gallery, err := gqb.FindByChecksum(ctx, galleryChecksum)
if err != nil {
t.Errorf("Error finding gallery: %s", err.Error())
@@ -53,7 +55,7 @@ func TestGalleryFindByChecksum(t *testing.T) {
assert.Equal(t, getGalleryStringValue(galleryIdx, "Path"), gallery.Path.String)
galleryChecksum = "not exist"
- gallery, err = gqb.FindByChecksum(galleryChecksum)
+ gallery, err = gqb.FindByChecksum(ctx, galleryChecksum)
if err != nil {
t.Errorf("Error finding gallery: %s", err.Error())
@@ -66,12 +68,12 @@ func TestGalleryFindByChecksum(t *testing.T) {
}
func TestGalleryFindByPath(t *testing.T) {
- withTxn(func(r models.Repository) error {
- gqb := r.Gallery()
+ withTxn(func(ctx context.Context) error {
+ gqb := sqlite.GalleryReaderWriter
const galleryIdx = 0
galleryPath := getGalleryStringValue(galleryIdx, "Path")
- gallery, err := gqb.FindByPath(galleryPath)
+ gallery, err := gqb.FindByPath(ctx, galleryPath)
if err != nil {
t.Errorf("Error finding gallery: %s", err.Error())
@@ -80,7 +82,7 @@ func TestGalleryFindByPath(t *testing.T) {
assert.Equal(t, galleryPath, gallery.Path.String)
galleryPath = "not exist"
- gallery, err = gqb.FindByPath(galleryPath)
+ gallery, err = gqb.FindByPath(ctx, galleryPath)
if err != nil {
t.Errorf("Error finding gallery: %s", err.Error())
@@ -93,11 +95,11 @@ func TestGalleryFindByPath(t *testing.T) {
}
func TestGalleryFindBySceneID(t *testing.T) {
- withTxn(func(r models.Repository) error {
- gqb := r.Gallery()
+ withTxn(func(ctx context.Context) error {
+ gqb := sqlite.GalleryReaderWriter
sceneID := sceneIDs[sceneIdxWithGallery]
- galleries, err := gqb.FindBySceneID(sceneID)
+ galleries, err := gqb.FindBySceneID(ctx, sceneID)
if err != nil {
t.Errorf("Error finding gallery: %s", err.Error())
@@ -105,7 +107,7 @@ func TestGalleryFindBySceneID(t *testing.T) {
assert.Equal(t, getGalleryStringValue(galleryIdxWithScene, "Path"), galleries[0].Path.String)
- galleries, err = gqb.FindBySceneID(0)
+ galleries, err = gqb.FindBySceneID(ctx, 0)
if err != nil {
t.Errorf("Error finding gallery: %s", err.Error())
@@ -118,24 +120,24 @@ func TestGalleryFindBySceneID(t *testing.T) {
}
func TestGalleryQueryQ(t *testing.T) {
- withTxn(func(r models.Repository) error {
+ withTxn(func(ctx context.Context) error {
const galleryIdx = 0
q := getGalleryStringValue(galleryIdx, pathField)
- sqb := r.Gallery()
+ sqb := sqlite.GalleryReaderWriter
- galleryQueryQ(t, sqb, q, galleryIdx)
+ galleryQueryQ(ctx, t, sqb, q, galleryIdx)
return nil
})
}
-func galleryQueryQ(t *testing.T, qb models.GalleryReader, q string, expectedGalleryIdx int) {
+func galleryQueryQ(ctx context.Context, t *testing.T, qb models.GalleryReader, q string, expectedGalleryIdx int) {
filter := models.FindFilterType{
Q: &q,
}
- galleries, _, err := qb.Query(nil, &filter)
+ galleries, _, err := qb.Query(ctx, nil, &filter)
if err != nil {
t.Errorf("Error querying gallery: %s", err.Error())
}
@@ -146,7 +148,7 @@ func galleryQueryQ(t *testing.T, qb models.GalleryReader, q string, expectedGall
// no Q should return all results
filter.Q = nil
- galleries, _, err = qb.Query(nil, &filter)
+ galleries, _, err = qb.Query(ctx, nil, &filter)
if err != nil {
t.Errorf("Error querying gallery: %s", err.Error())
}
@@ -155,7 +157,7 @@ func galleryQueryQ(t *testing.T, qb models.GalleryReader, q string, expectedGall
}
func TestGalleryQueryPath(t *testing.T) {
- withTxn(func(r models.Repository) error {
+ withTxn(func(ctx context.Context) error {
const galleryIdx = 1
galleryPath := getGalleryStringValue(galleryIdx, "Path")
@@ -164,28 +166,28 @@ func TestGalleryQueryPath(t *testing.T) {
Modifier: models.CriterionModifierEquals,
}
- verifyGalleriesPath(t, r.Gallery(), pathCriterion)
+ verifyGalleriesPath(ctx, t, sqlite.GalleryReaderWriter, pathCriterion)
pathCriterion.Modifier = models.CriterionModifierNotEquals
- verifyGalleriesPath(t, r.Gallery(), pathCriterion)
+ verifyGalleriesPath(ctx, t, sqlite.GalleryReaderWriter, pathCriterion)
pathCriterion.Modifier = models.CriterionModifierMatchesRegex
pathCriterion.Value = "gallery.*1_Path"
- verifyGalleriesPath(t, r.Gallery(), pathCriterion)
+ verifyGalleriesPath(ctx, t, sqlite.GalleryReaderWriter, pathCriterion)
pathCriterion.Modifier = models.CriterionModifierNotMatchesRegex
- verifyGalleriesPath(t, r.Gallery(), pathCriterion)
+ verifyGalleriesPath(ctx, t, sqlite.GalleryReaderWriter, pathCriterion)
return nil
})
}
-func verifyGalleriesPath(t *testing.T, sqb models.GalleryReader, pathCriterion models.StringCriterionInput) {
+func verifyGalleriesPath(ctx context.Context, t *testing.T, sqb models.GalleryReader, pathCriterion models.StringCriterionInput) {
galleryFilter := models.GalleryFilterType{
Path: &pathCriterion,
}
- galleries, _, err := sqb.Query(&galleryFilter, nil)
+ galleries, _, err := sqb.Query(ctx, &galleryFilter, nil)
if err != nil {
t.Errorf("Error querying gallery: %s", err.Error())
}
@@ -215,10 +217,10 @@ func TestGalleryQueryPathOr(t *testing.T) {
},
}
- withTxn(func(r models.Repository) error {
- sqb := r.Gallery()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.GalleryReaderWriter
- galleries := queryGallery(t, sqb, &galleryFilter, nil)
+ galleries := queryGallery(ctx, t, sqb, &galleryFilter, nil)
assert.Len(t, galleries, 2)
assert.Equal(t, gallery1Path, galleries[0].Path.String)
@@ -246,10 +248,10 @@ func TestGalleryQueryPathAndRating(t *testing.T) {
},
}
- withTxn(func(r models.Repository) error {
- sqb := r.Gallery()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.GalleryReaderWriter
- galleries := queryGallery(t, sqb, &galleryFilter, nil)
+ galleries := queryGallery(ctx, t, sqb, &galleryFilter, nil)
assert.Len(t, galleries, 1)
assert.Equal(t, galleryPath, galleries[0].Path.String)
@@ -281,10 +283,10 @@ func TestGalleryQueryPathNotRating(t *testing.T) {
},
}
- withTxn(func(r models.Repository) error {
- sqb := r.Gallery()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.GalleryReaderWriter
- galleries := queryGallery(t, sqb, &galleryFilter, nil)
+ galleries := queryGallery(ctx, t, sqb, &galleryFilter, nil)
for _, gallery := range galleries {
verifyNullString(t, gallery.Path, pathCriterion)
@@ -312,20 +314,20 @@ func TestGalleryIllegalQuery(t *testing.T) {
Or: &subFilter,
}
- withTxn(func(r models.Repository) error {
- sqb := r.Gallery()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.GalleryReaderWriter
- _, _, err := sqb.Query(galleryFilter, nil)
+ _, _, err := sqb.Query(ctx, galleryFilter, nil)
assert.NotNil(err)
galleryFilter.Or = nil
galleryFilter.Not = &subFilter
- _, _, err = sqb.Query(galleryFilter, nil)
+ _, _, err = sqb.Query(ctx, galleryFilter, nil)
assert.NotNil(err)
galleryFilter.And = nil
galleryFilter.Or = &subFilter
- _, _, err = sqb.Query(galleryFilter, nil)
+ _, _, err = sqb.Query(ctx, galleryFilter, nil)
assert.NotNil(err)
return nil
@@ -371,11 +373,11 @@ func TestGalleryQueryURL(t *testing.T) {
}
func verifyGalleryQuery(t *testing.T, filter models.GalleryFilterType, verifyFn func(s *models.Gallery)) {
- withTxn(func(r models.Repository) error {
+ withTxn(func(ctx context.Context) error {
t.Helper()
- sqb := r.Gallery()
+ sqb := sqlite.GalleryReaderWriter
- galleries := queryGallery(t, sqb, &filter, nil)
+ galleries := queryGallery(ctx, t, sqb, &filter, nil)
// assume it should find at least one
assert.Greater(t, len(galleries), 0)
@@ -414,13 +416,13 @@ func TestGalleryQueryRating(t *testing.T) {
}
func verifyGalleriesRating(t *testing.T, ratingCriterion models.IntCriterionInput) {
- withTxn(func(r models.Repository) error {
- sqb := r.Gallery()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.GalleryReaderWriter
galleryFilter := models.GalleryFilterType{
Rating: &ratingCriterion,
}
- galleries, _, err := sqb.Query(&galleryFilter, nil)
+ galleries, _, err := sqb.Query(ctx, &galleryFilter, nil)
if err != nil {
t.Errorf("Error querying gallery: %s", err.Error())
}
@@ -434,8 +436,8 @@ func verifyGalleriesRating(t *testing.T, ratingCriterion models.IntCriterionInpu
}
func TestGalleryQueryIsMissingScene(t *testing.T) {
- withTxn(func(r models.Repository) error {
- qb := r.Gallery()
+ withTxn(func(ctx context.Context) error {
+ qb := sqlite.GalleryReaderWriter
isMissing := "scenes"
galleryFilter := models.GalleryFilterType{
IsMissing: &isMissing,
@@ -446,7 +448,7 @@ func TestGalleryQueryIsMissingScene(t *testing.T) {
Q: &q,
}
- galleries, _, err := qb.Query(&galleryFilter, &findFilter)
+ galleries, _, err := qb.Query(ctx, &galleryFilter, &findFilter)
if err != nil {
t.Errorf("Error querying gallery: %s", err.Error())
}
@@ -454,7 +456,7 @@ func TestGalleryQueryIsMissingScene(t *testing.T) {
assert.Len(t, galleries, 0)
findFilter.Q = nil
- galleries, _, err = qb.Query(&galleryFilter, &findFilter)
+ galleries, _, err = qb.Query(ctx, &galleryFilter, &findFilter)
if err != nil {
t.Errorf("Error querying gallery: %s", err.Error())
}
@@ -468,8 +470,8 @@ func TestGalleryQueryIsMissingScene(t *testing.T) {
})
}
-func queryGallery(t *testing.T, sqb models.GalleryReader, galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) []*models.Gallery {
- galleries, _, err := sqb.Query(galleryFilter, findFilter)
+func queryGallery(ctx context.Context, t *testing.T, sqb models.GalleryReader, galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) []*models.Gallery {
+ galleries, _, err := sqb.Query(ctx, galleryFilter, findFilter)
if err != nil {
t.Errorf("Error querying gallery: %s", err.Error())
}
@@ -478,8 +480,8 @@ func queryGallery(t *testing.T, sqb models.GalleryReader, galleryFilter *models.
}
func TestGalleryQueryIsMissingStudio(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Gallery()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.GalleryReaderWriter
isMissing := "studio"
galleryFilter := models.GalleryFilterType{
IsMissing: &isMissing,
@@ -490,12 +492,12 @@ func TestGalleryQueryIsMissingStudio(t *testing.T) {
Q: &q,
}
- galleries := queryGallery(t, sqb, &galleryFilter, &findFilter)
+ galleries := queryGallery(ctx, t, sqb, &galleryFilter, &findFilter)
assert.Len(t, galleries, 0)
findFilter.Q = nil
- galleries = queryGallery(t, sqb, &galleryFilter, &findFilter)
+ galleries = queryGallery(ctx, t, sqb, &galleryFilter, &findFilter)
// ensure non of the ids equal the one with studio
for _, gallery := range galleries {
@@ -507,8 +509,8 @@ func TestGalleryQueryIsMissingStudio(t *testing.T) {
}
func TestGalleryQueryIsMissingPerformers(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Gallery()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.GalleryReaderWriter
isMissing := "performers"
galleryFilter := models.GalleryFilterType{
IsMissing: &isMissing,
@@ -519,12 +521,12 @@ func TestGalleryQueryIsMissingPerformers(t *testing.T) {
Q: &q,
}
- galleries := queryGallery(t, sqb, &galleryFilter, &findFilter)
+ galleries := queryGallery(ctx, t, sqb, &galleryFilter, &findFilter)
assert.Len(t, galleries, 0)
findFilter.Q = nil
- galleries = queryGallery(t, sqb, &galleryFilter, &findFilter)
+ galleries = queryGallery(ctx, t, sqb, &galleryFilter, &findFilter)
assert.True(t, len(galleries) > 0)
@@ -538,8 +540,8 @@ func TestGalleryQueryIsMissingPerformers(t *testing.T) {
}
func TestGalleryQueryIsMissingTags(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Gallery()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.GalleryReaderWriter
isMissing := "tags"
galleryFilter := models.GalleryFilterType{
IsMissing: &isMissing,
@@ -550,12 +552,12 @@ func TestGalleryQueryIsMissingTags(t *testing.T) {
Q: &q,
}
- galleries := queryGallery(t, sqb, &galleryFilter, &findFilter)
+ galleries := queryGallery(ctx, t, sqb, &galleryFilter, &findFilter)
assert.Len(t, galleries, 0)
findFilter.Q = nil
- galleries = queryGallery(t, sqb, &galleryFilter, &findFilter)
+ galleries = queryGallery(ctx, t, sqb, &galleryFilter, &findFilter)
assert.True(t, len(galleries) > 0)
@@ -564,14 +566,14 @@ func TestGalleryQueryIsMissingTags(t *testing.T) {
}
func TestGalleryQueryIsMissingDate(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Gallery()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.GalleryReaderWriter
isMissing := "date"
galleryFilter := models.GalleryFilterType{
IsMissing: &isMissing,
}
- galleries := queryGallery(t, sqb, &galleryFilter, nil)
+ galleries := queryGallery(ctx, t, sqb, &galleryFilter, nil)
// three in four scenes have no date
assert.Len(t, galleries, int(math.Ceil(float64(totalGalleries)/4*3)))
@@ -586,8 +588,8 @@ func TestGalleryQueryIsMissingDate(t *testing.T) {
}
func TestGalleryQueryPerformers(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Gallery()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.GalleryReaderWriter
performerCriterion := models.MultiCriterionInput{
Value: []string{
strconv.Itoa(performerIDs[performerIdxWithGallery]),
@@ -600,7 +602,7 @@ func TestGalleryQueryPerformers(t *testing.T) {
Performers: &performerCriterion,
}
- galleries := queryGallery(t, sqb, &galleryFilter, nil)
+ galleries := queryGallery(ctx, t, sqb, &galleryFilter, nil)
assert.Len(t, galleries, 2)
@@ -617,7 +619,7 @@ func TestGalleryQueryPerformers(t *testing.T) {
Modifier: models.CriterionModifierIncludesAll,
}
- galleries = queryGallery(t, sqb, &galleryFilter, nil)
+ galleries = queryGallery(ctx, t, sqb, &galleryFilter, nil)
assert.Len(t, galleries, 1)
assert.Equal(t, galleryIDs[galleryIdxWithTwoPerformers], galleries[0].ID)
@@ -634,7 +636,7 @@ func TestGalleryQueryPerformers(t *testing.T) {
Q: &q,
}
- galleries = queryGallery(t, sqb, &galleryFilter, &findFilter)
+ galleries = queryGallery(ctx, t, sqb, &galleryFilter, &findFilter)
assert.Len(t, galleries, 0)
return nil
@@ -642,8 +644,8 @@ func TestGalleryQueryPerformers(t *testing.T) {
}
func TestGalleryQueryTags(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Gallery()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.GalleryReaderWriter
tagCriterion := models.HierarchicalMultiCriterionInput{
Value: []string{
strconv.Itoa(tagIDs[tagIdxWithGallery]),
@@ -656,7 +658,7 @@ func TestGalleryQueryTags(t *testing.T) {
Tags: &tagCriterion,
}
- galleries := queryGallery(t, sqb, &galleryFilter, nil)
+ galleries := queryGallery(ctx, t, sqb, &galleryFilter, nil)
assert.Len(t, galleries, 2)
// ensure ids are correct
@@ -672,7 +674,7 @@ func TestGalleryQueryTags(t *testing.T) {
Modifier: models.CriterionModifierIncludesAll,
}
- galleries = queryGallery(t, sqb, &galleryFilter, nil)
+ galleries = queryGallery(ctx, t, sqb, &galleryFilter, nil)
assert.Len(t, galleries, 1)
assert.Equal(t, galleryIDs[galleryIdxWithTwoTags], galleries[0].ID)
@@ -689,7 +691,7 @@ func TestGalleryQueryTags(t *testing.T) {
Q: &q,
}
- galleries = queryGallery(t, sqb, &galleryFilter, &findFilter)
+ galleries = queryGallery(ctx, t, sqb, &galleryFilter, &findFilter)
assert.Len(t, galleries, 0)
return nil
@@ -697,8 +699,8 @@ func TestGalleryQueryTags(t *testing.T) {
}
func TestGalleryQueryStudio(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Gallery()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.GalleryReaderWriter
studioCriterion := models.HierarchicalMultiCriterionInput{
Value: []string{
strconv.Itoa(studioIDs[studioIdxWithGallery]),
@@ -710,7 +712,7 @@ func TestGalleryQueryStudio(t *testing.T) {
Studios: &studioCriterion,
}
- galleries := queryGallery(t, sqb, &galleryFilter, nil)
+ galleries := queryGallery(ctx, t, sqb, &galleryFilter, nil)
assert.Len(t, galleries, 1)
@@ -729,7 +731,7 @@ func TestGalleryQueryStudio(t *testing.T) {
Q: &q,
}
- galleries = queryGallery(t, sqb, &galleryFilter, &findFilter)
+ galleries = queryGallery(ctx, t, sqb, &galleryFilter, &findFilter)
assert.Len(t, galleries, 0)
return nil
@@ -737,8 +739,8 @@ func TestGalleryQueryStudio(t *testing.T) {
}
func TestGalleryQueryStudioDepth(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Gallery()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.GalleryReaderWriter
depth := 2
studioCriterion := models.HierarchicalMultiCriterionInput{
Value: []string{
@@ -752,16 +754,16 @@ func TestGalleryQueryStudioDepth(t *testing.T) {
Studios: &studioCriterion,
}
- galleries := queryGallery(t, sqb, &galleryFilter, nil)
+ galleries := queryGallery(ctx, t, sqb, &galleryFilter, nil)
assert.Len(t, galleries, 1)
depth = 1
- galleries = queryGallery(t, sqb, &galleryFilter, nil)
+ galleries = queryGallery(ctx, t, sqb, &galleryFilter, nil)
assert.Len(t, galleries, 0)
studioCriterion.Value = []string{strconv.Itoa(studioIDs[studioIdxWithParentAndChild])}
- galleries = queryGallery(t, sqb, &galleryFilter, nil)
+ galleries = queryGallery(ctx, t, sqb, &galleryFilter, nil)
assert.Len(t, galleries, 1)
// ensure id is correct
@@ -782,15 +784,15 @@ func TestGalleryQueryStudioDepth(t *testing.T) {
Q: &q,
}
- galleries = queryGallery(t, sqb, &galleryFilter, &findFilter)
+ galleries = queryGallery(ctx, t, sqb, &galleryFilter, &findFilter)
assert.Len(t, galleries, 0)
depth = 1
- galleries = queryGallery(t, sqb, &galleryFilter, &findFilter)
+ galleries = queryGallery(ctx, t, sqb, &galleryFilter, &findFilter)
assert.Len(t, galleries, 1)
studioCriterion.Value = []string{strconv.Itoa(studioIDs[studioIdxWithParentAndChild])}
- galleries = queryGallery(t, sqb, &galleryFilter, &findFilter)
+ galleries = queryGallery(ctx, t, sqb, &galleryFilter, &findFilter)
assert.Len(t, galleries, 0)
return nil
@@ -798,8 +800,8 @@ func TestGalleryQueryStudioDepth(t *testing.T) {
}
func TestGalleryQueryPerformerTags(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Gallery()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.GalleryReaderWriter
tagCriterion := models.HierarchicalMultiCriterionInput{
Value: []string{
strconv.Itoa(tagIDs[tagIdxWithPerformer]),
@@ -812,7 +814,7 @@ func TestGalleryQueryPerformerTags(t *testing.T) {
PerformerTags: &tagCriterion,
}
- galleries := queryGallery(t, sqb, &galleryFilter, nil)
+ galleries := queryGallery(ctx, t, sqb, &galleryFilter, nil)
assert.Len(t, galleries, 2)
// ensure ids are correct
@@ -828,7 +830,7 @@ func TestGalleryQueryPerformerTags(t *testing.T) {
Modifier: models.CriterionModifierIncludesAll,
}
- galleries = queryGallery(t, sqb, &galleryFilter, nil)
+ galleries = queryGallery(ctx, t, sqb, &galleryFilter, nil)
assert.Len(t, galleries, 1)
assert.Equal(t, galleryIDs[galleryIdxWithPerformerTwoTags], galleries[0].ID)
@@ -845,7 +847,7 @@ func TestGalleryQueryPerformerTags(t *testing.T) {
Q: &q,
}
- galleries = queryGallery(t, sqb, &galleryFilter, &findFilter)
+ galleries = queryGallery(ctx, t, sqb, &galleryFilter, &findFilter)
assert.Len(t, galleries, 0)
tagCriterion = models.HierarchicalMultiCriterionInput{
@@ -853,22 +855,22 @@ func TestGalleryQueryPerformerTags(t *testing.T) {
}
q = getGalleryStringValue(galleryIdx1WithImage, titleField)
- galleries = queryGallery(t, sqb, &galleryFilter, &findFilter)
+ galleries = queryGallery(ctx, t, sqb, &galleryFilter, &findFilter)
assert.Len(t, galleries, 1)
assert.Equal(t, galleryIDs[galleryIdx1WithImage], galleries[0].ID)
q = getGalleryStringValue(galleryIdxWithPerformerTag, titleField)
- galleries = queryGallery(t, sqb, &galleryFilter, &findFilter)
+ galleries = queryGallery(ctx, t, sqb, &galleryFilter, &findFilter)
assert.Len(t, galleries, 0)
tagCriterion.Modifier = models.CriterionModifierNotNull
- galleries = queryGallery(t, sqb, &galleryFilter, &findFilter)
+ galleries = queryGallery(ctx, t, sqb, &galleryFilter, &findFilter)
assert.Len(t, galleries, 1)
assert.Equal(t, galleryIDs[galleryIdxWithPerformerTag], galleries[0].ID)
q = getGalleryStringValue(galleryIdx1WithImage, titleField)
- galleries = queryGallery(t, sqb, &galleryFilter, &findFilter)
+ galleries = queryGallery(ctx, t, sqb, &galleryFilter, &findFilter)
assert.Len(t, galleries, 0)
return nil
@@ -895,17 +897,17 @@ func TestGalleryQueryTagCount(t *testing.T) {
}
func verifyGalleriesTagCount(t *testing.T, tagCountCriterion models.IntCriterionInput) {
- withTxn(func(r models.Repository) error {
- sqb := r.Gallery()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.GalleryReaderWriter
galleryFilter := models.GalleryFilterType{
TagCount: &tagCountCriterion,
}
- galleries := queryGallery(t, sqb, &galleryFilter, nil)
+ galleries := queryGallery(ctx, t, sqb, &galleryFilter, nil)
assert.Greater(t, len(galleries), 0)
for _, gallery := range galleries {
- ids, err := sqb.GetTagIDs(gallery.ID)
+ ids, err := sqb.GetTagIDs(ctx, gallery.ID)
if err != nil {
return err
}
@@ -936,17 +938,17 @@ func TestGalleryQueryPerformerCount(t *testing.T) {
}
func verifyGalleriesPerformerCount(t *testing.T, performerCountCriterion models.IntCriterionInput) {
- withTxn(func(r models.Repository) error {
- sqb := r.Gallery()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.GalleryReaderWriter
galleryFilter := models.GalleryFilterType{
PerformerCount: &performerCountCriterion,
}
- galleries := queryGallery(t, sqb, &galleryFilter, nil)
+ galleries := queryGallery(ctx, t, sqb, &galleryFilter, nil)
assert.Greater(t, len(galleries), 0)
for _, gallery := range galleries {
- ids, err := sqb.GetPerformerIDs(gallery.ID)
+ ids, err := sqb.GetPerformerIDs(ctx, gallery.ID)
if err != nil {
return err
}
@@ -958,8 +960,8 @@ func verifyGalleriesPerformerCount(t *testing.T, performerCountCriterion models.
}
func TestGalleryQueryAverageResolution(t *testing.T) {
- withTxn(func(r models.Repository) error {
- qb := r.Gallery()
+ withTxn(func(ctx context.Context) error {
+ qb := sqlite.GalleryReaderWriter
resolution := models.ResolutionEnumLow
galleryFilter := models.GalleryFilterType{
AverageResolution: &models.ResolutionCriterionInput{
@@ -969,7 +971,7 @@ func TestGalleryQueryAverageResolution(t *testing.T) {
}
// not verifying average - just ensure we get at least one
- galleries := queryGallery(t, qb, &galleryFilter, nil)
+ galleries := queryGallery(ctx, t, qb, &galleryFilter, nil)
assert.Greater(t, len(galleries), 0)
return nil
@@ -996,19 +998,19 @@ func TestGalleryQueryImageCount(t *testing.T) {
}
func verifyGalleriesImageCount(t *testing.T, imageCountCriterion models.IntCriterionInput) {
- withTxn(func(r models.Repository) error {
- sqb := r.Gallery()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.GalleryReaderWriter
galleryFilter := models.GalleryFilterType{
ImageCount: &imageCountCriterion,
}
- galleries := queryGallery(t, sqb, &galleryFilter, nil)
+ galleries := queryGallery(ctx, t, sqb, &galleryFilter, nil)
assert.Greater(t, len(galleries), -1)
for _, gallery := range galleries {
pp := 0
- result, err := r.Image().Query(models.ImageQueryOptions{
+ result, err := sqlite.ImageReaderWriter.Query(ctx, models.ImageQueryOptions{
QueryOptions: models.QueryOptions{
FindFilter: &models.FindFilterType{
PerPage: &pp,
diff --git a/pkg/sqlite/image.go b/pkg/sqlite/image.go
index d2b3adb8f..3238595d7 100644
--- a/pkg/sqlite/image.go
+++ b/pkg/sqlite/image.go
@@ -1,6 +1,7 @@
package sqlite
import (
+ "context"
"database/sql"
"errors"
"fmt"
@@ -29,45 +30,42 @@ type imageQueryBuilder struct {
repository
}
-func NewImageReaderWriter(tx dbi) *imageQueryBuilder {
- return &imageQueryBuilder{
- repository{
- tx: tx,
- tableName: imageTable,
- idColumn: idColumn,
- },
- }
+var ImageReaderWriter = &imageQueryBuilder{
+ repository{
+ tableName: imageTable,
+ idColumn: idColumn,
+ },
}
-func (qb *imageQueryBuilder) Create(newObject models.Image) (*models.Image, error) {
+func (qb *imageQueryBuilder) Create(ctx context.Context, newObject models.Image) (*models.Image, error) {
var ret models.Image
- if err := qb.insertObject(newObject, &ret); err != nil {
+ if err := qb.insertObject(ctx, newObject, &ret); err != nil {
return nil, err
}
return &ret, nil
}
-func (qb *imageQueryBuilder) Update(updatedObject models.ImagePartial) (*models.Image, error) {
+func (qb *imageQueryBuilder) Update(ctx context.Context, updatedObject models.ImagePartial) (*models.Image, error) {
const partial = true
- if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil {
+ if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil {
return nil, err
}
- return qb.find(updatedObject.ID)
+ return qb.find(ctx, updatedObject.ID)
}
-func (qb *imageQueryBuilder) UpdateFull(updatedObject models.Image) (*models.Image, error) {
+func (qb *imageQueryBuilder) UpdateFull(ctx context.Context, updatedObject models.Image) (*models.Image, error) {
const partial = false
- if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil {
+ if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil {
return nil, err
}
- return qb.find(updatedObject.ID)
+ return qb.find(ctx, updatedObject.ID)
}
-func (qb *imageQueryBuilder) IncrementOCounter(id int) (int, error) {
- _, err := qb.tx.Exec(
+func (qb *imageQueryBuilder) IncrementOCounter(ctx context.Context, id int) (int, error) {
+ _, err := qb.tx.Exec(ctx,
`UPDATE `+imageTable+` SET o_counter = o_counter + 1 WHERE `+imageTable+`.id = ?`,
id,
)
@@ -75,7 +73,7 @@ func (qb *imageQueryBuilder) IncrementOCounter(id int) (int, error) {
return 0, err
}
- image, err := qb.find(id)
+ image, err := qb.find(ctx, id)
if err != nil {
return 0, err
}
@@ -83,8 +81,8 @@ func (qb *imageQueryBuilder) IncrementOCounter(id int) (int, error) {
return image.OCounter, nil
}
-func (qb *imageQueryBuilder) DecrementOCounter(id int) (int, error) {
- _, err := qb.tx.Exec(
+func (qb *imageQueryBuilder) DecrementOCounter(ctx context.Context, id int) (int, error) {
+ _, err := qb.tx.Exec(ctx,
`UPDATE `+imageTable+` SET o_counter = o_counter - 1 WHERE `+imageTable+`.id = ? and `+imageTable+`.o_counter > 0`,
id,
)
@@ -92,7 +90,7 @@ func (qb *imageQueryBuilder) DecrementOCounter(id int) (int, error) {
return 0, err
}
- image, err := qb.find(id)
+ image, err := qb.find(ctx, id)
if err != nil {
return 0, err
}
@@ -100,8 +98,8 @@ func (qb *imageQueryBuilder) DecrementOCounter(id int) (int, error) {
return image.OCounter, nil
}
-func (qb *imageQueryBuilder) ResetOCounter(id int) (int, error) {
- _, err := qb.tx.Exec(
+func (qb *imageQueryBuilder) ResetOCounter(ctx context.Context, id int) (int, error) {
+ _, err := qb.tx.Exec(ctx,
`UPDATE `+imageTable+` SET o_counter = 0 WHERE `+imageTable+`.id = ?`,
id,
)
@@ -109,7 +107,7 @@ func (qb *imageQueryBuilder) ResetOCounter(id int) (int, error) {
return 0, err
}
- image, err := qb.find(id)
+ image, err := qb.find(ctx, id)
if err != nil {
return 0, err
}
@@ -117,18 +115,18 @@ func (qb *imageQueryBuilder) ResetOCounter(id int) (int, error) {
return image.OCounter, nil
}
-func (qb *imageQueryBuilder) Destroy(id int) error {
- return qb.destroyExisting([]int{id})
+func (qb *imageQueryBuilder) Destroy(ctx context.Context, id int) error {
+ return qb.destroyExisting(ctx, []int{id})
}
-func (qb *imageQueryBuilder) Find(id int) (*models.Image, error) {
- return qb.find(id)
+func (qb *imageQueryBuilder) Find(ctx context.Context, id int) (*models.Image, error) {
+ return qb.find(ctx, id)
}
-func (qb *imageQueryBuilder) FindMany(ids []int) ([]*models.Image, error) {
+func (qb *imageQueryBuilder) FindMany(ctx context.Context, ids []int) ([]*models.Image, error) {
var images []*models.Image
for _, id := range ids {
- image, err := qb.Find(id)
+ image, err := qb.Find(ctx, id)
if err != nil {
return nil, err
}
@@ -143,9 +141,9 @@ func (qb *imageQueryBuilder) FindMany(ids []int) ([]*models.Image, error) {
return images, nil
}
-func (qb *imageQueryBuilder) find(id int) (*models.Image, error) {
+func (qb *imageQueryBuilder) find(ctx context.Context, id int) (*models.Image, error) {
var ret models.Image
- if err := qb.get(id, &ret); err != nil {
+ if err := qb.getByID(ctx, id, &ret); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
@@ -154,43 +152,43 @@ func (qb *imageQueryBuilder) find(id int) (*models.Image, error) {
return &ret, nil
}
-func (qb *imageQueryBuilder) FindByChecksum(checksum string) (*models.Image, error) {
+func (qb *imageQueryBuilder) FindByChecksum(ctx context.Context, checksum string) (*models.Image, error) {
query := "SELECT * FROM images WHERE checksum = ? LIMIT 1"
args := []interface{}{checksum}
- return qb.queryImage(query, args)
+ return qb.queryImage(ctx, query, args)
}
-func (qb *imageQueryBuilder) FindByPath(path string) (*models.Image, error) {
+func (qb *imageQueryBuilder) FindByPath(ctx context.Context, path string) (*models.Image, error) {
query := selectAll(imageTable) + "WHERE path = ? LIMIT 1"
args := []interface{}{path}
- return qb.queryImage(query, args)
+ return qb.queryImage(ctx, query, args)
}
-func (qb *imageQueryBuilder) FindByGalleryID(galleryID int) ([]*models.Image, error) {
+func (qb *imageQueryBuilder) FindByGalleryID(ctx context.Context, galleryID int) ([]*models.Image, error) {
args := []interface{}{galleryID}
sort := "path"
sortDir := models.SortDirectionEnumAsc
- return qb.queryImages(imagesForGalleryQuery+qb.getImageSort(&models.FindFilterType{
+ return qb.queryImages(ctx, imagesForGalleryQuery+qb.getImageSort(&models.FindFilterType{
Sort: &sort,
Direction: &sortDir,
}), args)
}
-func (qb *imageQueryBuilder) CountByGalleryID(galleryID int) (int, error) {
+func (qb *imageQueryBuilder) CountByGalleryID(ctx context.Context, galleryID int) (int, error) {
args := []interface{}{galleryID}
- return qb.runCountQuery(qb.buildCountQuery(countImagesForGalleryQuery), args)
+ return qb.runCountQuery(ctx, qb.buildCountQuery(countImagesForGalleryQuery), args)
}
-func (qb *imageQueryBuilder) Count() (int, error) {
- return qb.runCountQuery(qb.buildCountQuery("SELECT images.id FROM images"), nil)
+func (qb *imageQueryBuilder) Count(ctx context.Context) (int, error) {
+ return qb.runCountQuery(ctx, qb.buildCountQuery("SELECT images.id FROM images"), nil)
}
-func (qb *imageQueryBuilder) Size() (float64, error) {
- return qb.runSumQuery("SELECT SUM(cast(size as double)) as sum FROM images", nil)
+func (qb *imageQueryBuilder) Size(ctx context.Context) (float64, error) {
+ return qb.runSumQuery(ctx, "SELECT SUM(cast(size as double)) as sum FROM images", nil)
}
-func (qb *imageQueryBuilder) All() ([]*models.Image, error) {
- return qb.queryImages(selectAll(imageTable)+qb.getImageSort(nil), nil)
+func (qb *imageQueryBuilder) All(ctx context.Context) ([]*models.Image, error) {
+ return qb.queryImages(ctx, selectAll(imageTable)+qb.getImageSort(nil), nil)
}
func (qb *imageQueryBuilder) validateFilter(imageFilter *models.ImageFilterType) error {
@@ -224,41 +222,41 @@ func (qb *imageQueryBuilder) validateFilter(imageFilter *models.ImageFilterType)
return nil
}
-func (qb *imageQueryBuilder) makeFilter(imageFilter *models.ImageFilterType) *filterBuilder {
+func (qb *imageQueryBuilder) makeFilter(ctx context.Context, imageFilter *models.ImageFilterType) *filterBuilder {
query := &filterBuilder{}
if imageFilter.And != nil {
- query.and(qb.makeFilter(imageFilter.And))
+ query.and(qb.makeFilter(ctx, imageFilter.And))
}
if imageFilter.Or != nil {
- query.or(qb.makeFilter(imageFilter.Or))
+ query.or(qb.makeFilter(ctx, imageFilter.Or))
}
if imageFilter.Not != nil {
- query.not(qb.makeFilter(imageFilter.Not))
+ query.not(qb.makeFilter(ctx, imageFilter.Not))
}
- query.handleCriterion(stringCriterionHandler(imageFilter.Checksum, "images.checksum"))
- query.handleCriterion(stringCriterionHandler(imageFilter.Title, "images.title"))
- query.handleCriterion(stringCriterionHandler(imageFilter.Path, "images.path"))
- query.handleCriterion(intCriterionHandler(imageFilter.Rating, "images.rating"))
- query.handleCriterion(intCriterionHandler(imageFilter.OCounter, "images.o_counter"))
- query.handleCriterion(boolCriterionHandler(imageFilter.Organized, "images.organized"))
- query.handleCriterion(resolutionCriterionHandler(imageFilter.Resolution, "images.height", "images.width"))
- query.handleCriterion(imageIsMissingCriterionHandler(qb, imageFilter.IsMissing))
+ query.handleCriterion(ctx, stringCriterionHandler(imageFilter.Checksum, "images.checksum"))
+ query.handleCriterion(ctx, stringCriterionHandler(imageFilter.Title, "images.title"))
+ query.handleCriterion(ctx, stringCriterionHandler(imageFilter.Path, "images.path"))
+ query.handleCriterion(ctx, intCriterionHandler(imageFilter.Rating, "images.rating"))
+ query.handleCriterion(ctx, intCriterionHandler(imageFilter.OCounter, "images.o_counter"))
+ query.handleCriterion(ctx, boolCriterionHandler(imageFilter.Organized, "images.organized"))
+ query.handleCriterion(ctx, resolutionCriterionHandler(imageFilter.Resolution, "images.height", "images.width"))
+ query.handleCriterion(ctx, imageIsMissingCriterionHandler(qb, imageFilter.IsMissing))
- query.handleCriterion(imageTagsCriterionHandler(qb, imageFilter.Tags))
- query.handleCriterion(imageTagCountCriterionHandler(qb, imageFilter.TagCount))
- query.handleCriterion(imageGalleriesCriterionHandler(qb, imageFilter.Galleries))
- query.handleCriterion(imagePerformersCriterionHandler(qb, imageFilter.Performers))
- query.handleCriterion(imagePerformerCountCriterionHandler(qb, imageFilter.PerformerCount))
- query.handleCriterion(imageStudioCriterionHandler(qb, imageFilter.Studios))
- query.handleCriterion(imagePerformerTagsCriterionHandler(qb, imageFilter.PerformerTags))
- query.handleCriterion(imagePerformerFavoriteCriterionHandler(imageFilter.PerformerFavorite))
+ query.handleCriterion(ctx, imageTagsCriterionHandler(qb, imageFilter.Tags))
+ query.handleCriterion(ctx, imageTagCountCriterionHandler(qb, imageFilter.TagCount))
+ query.handleCriterion(ctx, imageGalleriesCriterionHandler(qb, imageFilter.Galleries))
+ query.handleCriterion(ctx, imagePerformersCriterionHandler(qb, imageFilter.Performers))
+ query.handleCriterion(ctx, imagePerformerCountCriterionHandler(qb, imageFilter.PerformerCount))
+ query.handleCriterion(ctx, imageStudioCriterionHandler(qb, imageFilter.Studios))
+ query.handleCriterion(ctx, imagePerformerTagsCriterionHandler(qb, imageFilter.PerformerTags))
+ query.handleCriterion(ctx, imagePerformerFavoriteCriterionHandler(imageFilter.PerformerFavorite))
return query
}
-func (qb *imageQueryBuilder) makeQuery(imageFilter *models.ImageFilterType, findFilter *models.FindFilterType) (*queryBuilder, error) {
+func (qb *imageQueryBuilder) makeQuery(ctx context.Context, imageFilter *models.ImageFilterType, findFilter *models.FindFilterType) (*queryBuilder, error) {
if imageFilter == nil {
imageFilter = &models.ImageFilterType{}
}
@@ -277,7 +275,7 @@ func (qb *imageQueryBuilder) makeQuery(imageFilter *models.ImageFilterType, find
if err := qb.validateFilter(imageFilter); err != nil {
return nil, err
}
- filter := qb.makeFilter(imageFilter)
+ filter := qb.makeFilter(ctx, imageFilter)
query.addFilter(filter)
@@ -286,18 +284,18 @@ func (qb *imageQueryBuilder) makeQuery(imageFilter *models.ImageFilterType, find
return &query, nil
}
-func (qb *imageQueryBuilder) Query(options models.ImageQueryOptions) (*models.ImageQueryResult, error) {
- query, err := qb.makeQuery(options.ImageFilter, options.FindFilter)
+func (qb *imageQueryBuilder) Query(ctx context.Context, options models.ImageQueryOptions) (*models.ImageQueryResult, error) {
+ query, err := qb.makeQuery(ctx, options.ImageFilter, options.FindFilter)
if err != nil {
return nil, err
}
- result, err := qb.queryGroupedFields(options, *query)
+ result, err := qb.queryGroupedFields(ctx, options, *query)
if err != nil {
return nil, fmt.Errorf("error querying aggregate fields: %w", err)
}
- idsResult, err := query.findIDs()
+ idsResult, err := query.findIDs(ctx)
if err != nil {
return nil, fmt.Errorf("error finding IDs: %w", err)
}
@@ -306,7 +304,7 @@ func (qb *imageQueryBuilder) Query(options models.ImageQueryOptions) (*models.Im
return result, nil
}
-func (qb *imageQueryBuilder) queryGroupedFields(options models.ImageQueryOptions, query queryBuilder) (*models.ImageQueryResult, error) {
+func (qb *imageQueryBuilder) queryGroupedFields(ctx context.Context, options models.ImageQueryOptions, query queryBuilder) (*models.ImageQueryResult, error) {
if !options.Count && !options.Megapixels && !options.TotalSize {
// nothing to do - return empty result
return models.NewImageQueryResult(qb), nil
@@ -336,7 +334,7 @@ func (qb *imageQueryBuilder) queryGroupedFields(options models.ImageQueryOptions
Megapixels float64
Size float64
}{}
- if err := qb.repository.queryStruct(aggregateQuery.toSQL(includeSortPagination), query.args, &out); err != nil {
+ if err := qb.repository.queryStruct(ctx, aggregateQuery.toSQL(includeSortPagination), query.args, &out); err != nil {
return nil, err
}
@@ -347,17 +345,17 @@ func (qb *imageQueryBuilder) queryGroupedFields(options models.ImageQueryOptions
return ret, nil
}
-func (qb *imageQueryBuilder) QueryCount(imageFilter *models.ImageFilterType, findFilter *models.FindFilterType) (int, error) {
- query, err := qb.makeQuery(imageFilter, findFilter)
+func (qb *imageQueryBuilder) QueryCount(ctx context.Context, imageFilter *models.ImageFilterType, findFilter *models.FindFilterType) (int, error) {
+ query, err := qb.makeQuery(ctx, imageFilter, findFilter)
if err != nil {
return 0, err
}
- return query.executeCount()
+ return query.executeCount(ctx)
}
func imageIsMissingCriterionHandler(qb *imageQueryBuilder, isMissing *string) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if isMissing != nil && *isMissing != "" {
switch *isMissing {
case "studio":
@@ -453,7 +451,7 @@ func imagePerformerCountCriterionHandler(qb *imageQueryBuilder, performerCount *
}
func imagePerformerFavoriteCriterionHandler(performerfavorite *bool) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if performerfavorite != nil {
f.addLeftJoin("performers_images", "", "images.id = performers_images.image_id")
@@ -487,7 +485,7 @@ func imageStudioCriterionHandler(qb *imageQueryBuilder, studios *models.Hierarch
}
func imagePerformerTagsCriterionHandler(qb *imageQueryBuilder, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if tags != nil {
if tags.Modifier == models.CriterionModifierIsNull || tags.Modifier == models.CriterionModifierNotNull {
var notClause string
@@ -506,7 +504,7 @@ func imagePerformerTagsCriterionHandler(qb *imageQueryBuilder, tags *models.Hier
return
}
- valuesClause := getHierarchicalValues(qb.tx, tags.Value, tagTable, "tags_relations", "", tags.Depth)
+ valuesClause := getHierarchicalValues(ctx, qb.tx, tags.Value, tagTable, "tags_relations", "", tags.Depth)
f.addWith(`performer_tags AS (
SELECT pi.image_id, t.column1 AS root_tag_id FROM performers_images pi
@@ -538,17 +536,17 @@ func (qb *imageQueryBuilder) getImageSort(findFilter *models.FindFilterType) str
}
}
-func (qb *imageQueryBuilder) queryImage(query string, args []interface{}) (*models.Image, error) {
- results, err := qb.queryImages(query, args)
+func (qb *imageQueryBuilder) queryImage(ctx context.Context, query string, args []interface{}) (*models.Image, error) {
+ results, err := qb.queryImages(ctx, query, args)
if err != nil || len(results) < 1 {
return nil, err
}
return results[0], nil
}
-func (qb *imageQueryBuilder) queryImages(query string, args []interface{}) ([]*models.Image, error) {
+func (qb *imageQueryBuilder) queryImages(ctx context.Context, query string, args []interface{}) ([]*models.Image, error) {
var ret models.Images
- if err := qb.query(query, args, &ret); err != nil {
+ if err := qb.query(ctx, query, args, &ret); err != nil {
return nil, err
}
@@ -566,13 +564,13 @@ func (qb *imageQueryBuilder) galleriesRepository() *joinRepository {
}
}
-func (qb *imageQueryBuilder) GetGalleryIDs(imageID int) ([]int, error) {
- return qb.galleriesRepository().getIDs(imageID)
+func (qb *imageQueryBuilder) GetGalleryIDs(ctx context.Context, imageID int) ([]int, error) {
+ return qb.galleriesRepository().getIDs(ctx, imageID)
}
-func (qb *imageQueryBuilder) UpdateGalleries(imageID int, galleryIDs []int) error {
+func (qb *imageQueryBuilder) UpdateGalleries(ctx context.Context, imageID int, galleryIDs []int) error {
// Delete the existing joins and then create new ones
- return qb.galleriesRepository().replace(imageID, galleryIDs)
+ return qb.galleriesRepository().replace(ctx, imageID, galleryIDs)
}
func (qb *imageQueryBuilder) performersRepository() *joinRepository {
@@ -586,13 +584,13 @@ func (qb *imageQueryBuilder) performersRepository() *joinRepository {
}
}
-func (qb *imageQueryBuilder) GetPerformerIDs(imageID int) ([]int, error) {
- return qb.performersRepository().getIDs(imageID)
+func (qb *imageQueryBuilder) GetPerformerIDs(ctx context.Context, imageID int) ([]int, error) {
+ return qb.performersRepository().getIDs(ctx, imageID)
}
-func (qb *imageQueryBuilder) UpdatePerformers(imageID int, performerIDs []int) error {
+func (qb *imageQueryBuilder) UpdatePerformers(ctx context.Context, imageID int, performerIDs []int) error {
// Delete the existing joins and then create new ones
- return qb.performersRepository().replace(imageID, performerIDs)
+ return qb.performersRepository().replace(ctx, imageID, performerIDs)
}
func (qb *imageQueryBuilder) tagsRepository() *joinRepository {
@@ -606,11 +604,11 @@ func (qb *imageQueryBuilder) tagsRepository() *joinRepository {
}
}
-func (qb *imageQueryBuilder) GetTagIDs(imageID int) ([]int, error) {
- return qb.tagsRepository().getIDs(imageID)
+func (qb *imageQueryBuilder) GetTagIDs(ctx context.Context, imageID int) ([]int, error) {
+ return qb.tagsRepository().getIDs(ctx, imageID)
}
-func (qb *imageQueryBuilder) UpdateTags(imageID int, tagIDs []int) error {
+func (qb *imageQueryBuilder) UpdateTags(ctx context.Context, imageID int, tagIDs []int) error {
// Delete the existing joins and then create new ones
- return qb.tagsRepository().replace(imageID, tagIDs)
+ return qb.tagsRepository().replace(ctx, imageID, tagIDs)
}
diff --git a/pkg/sqlite/image_test.go b/pkg/sqlite/image_test.go
index 552db2cdf..3c131ed56 100644
--- a/pkg/sqlite/image_test.go
+++ b/pkg/sqlite/image_test.go
@@ -4,6 +4,7 @@
package sqlite_test
import (
+ "context"
"database/sql"
"strconv"
"testing"
@@ -11,16 +12,17 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stashapp/stash/pkg/models"
+ "github.com/stashapp/stash/pkg/sqlite"
)
func TestImageFind(t *testing.T) {
- withTxn(func(r models.Repository) error {
+ withTxn(func(ctx context.Context) error {
// assume that the first image is imageWithGalleryPath
- sqb := r.Image()
+ sqb := sqlite.ImageReaderWriter
const imageIdx = 0
imageID := imageIDs[imageIdx]
- image, err := sqb.Find(imageID)
+ image, err := sqb.Find(ctx, imageID)
if err != nil {
t.Errorf("Error finding image: %s", err.Error())
@@ -29,7 +31,7 @@ func TestImageFind(t *testing.T) {
assert.Equal(t, getImageStringValue(imageIdx, "Path"), image.Path)
imageID = 0
- image, err = sqb.Find(imageID)
+ image, err = sqb.Find(ctx, imageID)
if err != nil {
t.Errorf("Error finding image: %s", err.Error())
@@ -42,12 +44,12 @@ func TestImageFind(t *testing.T) {
}
func TestImageFindByPath(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Image()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.ImageReaderWriter
const imageIdx = 1
imagePath := getImageStringValue(imageIdx, "Path")
- image, err := sqb.FindByPath(imagePath)
+ image, err := sqb.FindByPath(ctx, imagePath)
if err != nil {
t.Errorf("Error finding image: %s", err.Error())
@@ -57,7 +59,7 @@ func TestImageFindByPath(t *testing.T) {
assert.Equal(t, imagePath, image.Path)
imagePath = "not exist"
- image, err = sqb.FindByPath(imagePath)
+ image, err = sqb.FindByPath(ctx, imagePath)
if err != nil {
t.Errorf("Error finding image: %s", err.Error())
@@ -70,10 +72,10 @@ func TestImageFindByPath(t *testing.T) {
}
func TestImageFindByGalleryID(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Image()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.ImageReaderWriter
- images, err := sqb.FindByGalleryID(galleryIDs[galleryIdxWithTwoImages])
+ images, err := sqb.FindByGalleryID(ctx, galleryIDs[galleryIdxWithTwoImages])
if err != nil {
t.Errorf("Error finding images: %s", err.Error())
@@ -83,7 +85,7 @@ func TestImageFindByGalleryID(t *testing.T) {
assert.Equal(t, imageIDs[imageIdx1WithGallery], images[0].ID)
assert.Equal(t, imageIDs[imageIdx2WithGallery], images[1].ID)
- images, err = sqb.FindByGalleryID(galleryIDs[galleryIdxWithScene])
+ images, err = sqb.FindByGalleryID(ctx, galleryIDs[galleryIdxWithScene])
if err != nil {
t.Errorf("Error finding images: %s", err.Error())
@@ -96,21 +98,21 @@ func TestImageFindByGalleryID(t *testing.T) {
}
func TestImageQueryQ(t *testing.T) {
- withTxn(func(r models.Repository) error {
+ withTxn(func(ctx context.Context) error {
const imageIdx = 2
q := getImageStringValue(imageIdx, titleField)
- sqb := r.Image()
+ sqb := sqlite.ImageReaderWriter
- imageQueryQ(t, sqb, q, imageIdx)
+ imageQueryQ(ctx, t, sqb, q, imageIdx)
return nil
})
}
-func queryImagesWithCount(sqb models.ImageReader, imageFilter *models.ImageFilterType, findFilter *models.FindFilterType) ([]*models.Image, int, error) {
- result, err := sqb.Query(models.ImageQueryOptions{
+func queryImagesWithCount(ctx context.Context, sqb models.ImageReader, imageFilter *models.ImageFilterType, findFilter *models.FindFilterType) ([]*models.Image, int, error) {
+ result, err := sqb.Query(ctx, models.ImageQueryOptions{
QueryOptions: models.QueryOptions{
FindFilter: findFilter,
Count: true,
@@ -121,7 +123,7 @@ func queryImagesWithCount(sqb models.ImageReader, imageFilter *models.ImageFilte
return nil, 0, err
}
- images, err := result.Resolve()
+ images, err := result.Resolve(ctx)
if err != nil {
return nil, 0, err
}
@@ -129,17 +131,17 @@ func queryImagesWithCount(sqb models.ImageReader, imageFilter *models.ImageFilte
return images, result.Count, nil
}
-func imageQueryQ(t *testing.T, sqb models.ImageReader, q string, expectedImageIdx int) {
+func imageQueryQ(ctx context.Context, t *testing.T, sqb models.ImageReader, q string, expectedImageIdx int) {
filter := models.FindFilterType{
Q: &q,
}
- images := queryImages(t, sqb, nil, &filter)
+ images := queryImages(ctx, t, sqb, nil, &filter)
assert.Len(t, images, 1)
image := images[0]
assert.Equal(t, imageIDs[expectedImageIdx], image.ID)
- count, err := sqb.QueryCount(nil, &filter)
+ count, err := sqb.QueryCount(ctx, nil, &filter)
if err != nil {
t.Errorf("Error querying image: %s", err.Error())
}
@@ -147,7 +149,7 @@ func imageQueryQ(t *testing.T, sqb models.ImageReader, q string, expectedImageId
// no Q should return all results
filter.Q = nil
- images = queryImages(t, sqb, nil, &filter)
+ images = queryImages(ctx, t, sqb, nil, &filter)
assert.Len(t, images, totalImages)
}
@@ -175,13 +177,13 @@ func TestImageQueryPath(t *testing.T) {
}
func verifyImagePath(t *testing.T, pathCriterion models.StringCriterionInput, expected int) {
- withTxn(func(r models.Repository) error {
- sqb := r.Image()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.ImageReaderWriter
imageFilter := models.ImageFilterType{
Path: &pathCriterion,
}
- images := queryImages(t, sqb, &imageFilter, nil)
+ images := queryImages(ctx, t, sqb, &imageFilter, nil)
assert.Equal(t, expected, len(images), "number of returned images")
@@ -213,10 +215,10 @@ func TestImageQueryPathOr(t *testing.T) {
},
}
- withTxn(func(r models.Repository) error {
- sqb := r.Image()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.ImageReaderWriter
- images := queryImages(t, sqb, &imageFilter, nil)
+ images := queryImages(ctx, t, sqb, &imageFilter, nil)
assert.Len(t, images, 2)
assert.Equal(t, image1Path, images[0].Path)
@@ -244,10 +246,10 @@ func TestImageQueryPathAndRating(t *testing.T) {
},
}
- withTxn(func(r models.Repository) error {
- sqb := r.Image()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.ImageReaderWriter
- images := queryImages(t, sqb, &imageFilter, nil)
+ images := queryImages(ctx, t, sqb, &imageFilter, nil)
assert.Len(t, images, 1)
assert.Equal(t, imagePath, images[0].Path)
@@ -279,10 +281,10 @@ func TestImageQueryPathNotRating(t *testing.T) {
},
}
- withTxn(func(r models.Repository) error {
- sqb := r.Image()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.ImageReaderWriter
- images := queryImages(t, sqb, &imageFilter, nil)
+ images := queryImages(ctx, t, sqb, &imageFilter, nil)
for _, image := range images {
verifyString(t, image.Path, pathCriterion)
@@ -310,20 +312,20 @@ func TestImageIllegalQuery(t *testing.T) {
Or: &subFilter,
}
- withTxn(func(r models.Repository) error {
- sqb := r.Image()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.ImageReaderWriter
- _, _, err := queryImagesWithCount(sqb, imageFilter, nil)
+ _, _, err := queryImagesWithCount(ctx, sqb, imageFilter, nil)
assert.NotNil(err)
imageFilter.Or = nil
imageFilter.Not = &subFilter
- _, _, err = queryImagesWithCount(sqb, imageFilter, nil)
+ _, _, err = queryImagesWithCount(ctx, sqb, imageFilter, nil)
assert.NotNil(err)
imageFilter.And = nil
imageFilter.Or = &subFilter
- _, _, err = queryImagesWithCount(sqb, imageFilter, nil)
+ _, _, err = queryImagesWithCount(ctx, sqb, imageFilter, nil)
assert.NotNil(err)
return nil
@@ -356,13 +358,13 @@ func TestImageQueryRating(t *testing.T) {
}
func verifyImagesRating(t *testing.T, ratingCriterion models.IntCriterionInput) {
- withTxn(func(r models.Repository) error {
- sqb := r.Image()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.ImageReaderWriter
imageFilter := models.ImageFilterType{
Rating: &ratingCriterion,
}
- images, _, err := queryImagesWithCount(sqb, &imageFilter, nil)
+ images, _, err := queryImagesWithCount(ctx, sqb, &imageFilter, nil)
if err != nil {
t.Errorf("Error querying image: %s", err.Error())
}
@@ -395,13 +397,13 @@ func TestImageQueryOCounter(t *testing.T) {
}
func verifyImagesOCounter(t *testing.T, oCounterCriterion models.IntCriterionInput) {
- withTxn(func(r models.Repository) error {
- sqb := r.Image()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.ImageReaderWriter
imageFilter := models.ImageFilterType{
OCounter: &oCounterCriterion,
}
- images, _, err := queryImagesWithCount(sqb, &imageFilter, nil)
+ images, _, err := queryImagesWithCount(ctx, sqb, &imageFilter, nil)
if err != nil {
t.Errorf("Error querying image: %s", err.Error())
}
@@ -424,8 +426,8 @@ func TestImageQueryResolution(t *testing.T) {
}
func verifyImagesResolution(t *testing.T, resolution models.ResolutionEnum) {
- withTxn(func(r models.Repository) error {
- sqb := r.Image()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.ImageReaderWriter
imageFilter := models.ImageFilterType{
Resolution: &models.ResolutionCriterionInput{
Value: resolution,
@@ -433,7 +435,7 @@ func verifyImagesResolution(t *testing.T, resolution models.ResolutionEnum) {
},
}
- images, _, err := queryImagesWithCount(sqb, &imageFilter, nil)
+ images, _, err := queryImagesWithCount(ctx, sqb, &imageFilter, nil)
if err != nil {
t.Errorf("Error querying image: %s", err.Error())
}
@@ -465,8 +467,8 @@ func verifyImageResolution(t *testing.T, height sql.NullInt64, resolution models
}
func TestImageQueryIsMissingGalleries(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Image()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.ImageReaderWriter
isMissing := "galleries"
imageFilter := models.ImageFilterType{
IsMissing: &isMissing,
@@ -477,7 +479,7 @@ func TestImageQueryIsMissingGalleries(t *testing.T) {
Q: &q,
}
- images, _, err := queryImagesWithCount(sqb, &imageFilter, &findFilter)
+ images, _, err := queryImagesWithCount(ctx, sqb, &imageFilter, &findFilter)
if err != nil {
t.Errorf("Error querying image: %s", err.Error())
}
@@ -485,7 +487,7 @@ func TestImageQueryIsMissingGalleries(t *testing.T) {
assert.Len(t, images, 0)
findFilter.Q = nil
- images, _, err = queryImagesWithCount(sqb, &imageFilter, &findFilter)
+ images, _, err = queryImagesWithCount(ctx, sqb, &imageFilter, &findFilter)
if err != nil {
t.Errorf("Error querying image: %s", err.Error())
}
@@ -502,8 +504,8 @@ func TestImageQueryIsMissingGalleries(t *testing.T) {
}
func TestImageQueryIsMissingStudio(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Image()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.ImageReaderWriter
isMissing := "studio"
imageFilter := models.ImageFilterType{
IsMissing: &isMissing,
@@ -514,7 +516,7 @@ func TestImageQueryIsMissingStudio(t *testing.T) {
Q: &q,
}
- images, _, err := queryImagesWithCount(sqb, &imageFilter, &findFilter)
+ images, _, err := queryImagesWithCount(ctx, sqb, &imageFilter, &findFilter)
if err != nil {
t.Errorf("Error querying image: %s", err.Error())
}
@@ -522,7 +524,7 @@ func TestImageQueryIsMissingStudio(t *testing.T) {
assert.Len(t, images, 0)
findFilter.Q = nil
- images, _, err = queryImagesWithCount(sqb, &imageFilter, &findFilter)
+ images, _, err = queryImagesWithCount(ctx, sqb, &imageFilter, &findFilter)
if err != nil {
t.Errorf("Error querying image: %s", err.Error())
}
@@ -537,8 +539,8 @@ func TestImageQueryIsMissingStudio(t *testing.T) {
}
func TestImageQueryIsMissingPerformers(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Image()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.ImageReaderWriter
isMissing := "performers"
imageFilter := models.ImageFilterType{
IsMissing: &isMissing,
@@ -549,7 +551,7 @@ func TestImageQueryIsMissingPerformers(t *testing.T) {
Q: &q,
}
- images, _, err := queryImagesWithCount(sqb, &imageFilter, &findFilter)
+ images, _, err := queryImagesWithCount(ctx, sqb, &imageFilter, &findFilter)
if err != nil {
t.Errorf("Error querying image: %s", err.Error())
}
@@ -557,7 +559,7 @@ func TestImageQueryIsMissingPerformers(t *testing.T) {
assert.Len(t, images, 0)
findFilter.Q = nil
- images, _, err = queryImagesWithCount(sqb, &imageFilter, &findFilter)
+ images, _, err = queryImagesWithCount(ctx, sqb, &imageFilter, &findFilter)
if err != nil {
t.Errorf("Error querying image: %s", err.Error())
}
@@ -574,8 +576,8 @@ func TestImageQueryIsMissingPerformers(t *testing.T) {
}
func TestImageQueryIsMissingTags(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Image()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.ImageReaderWriter
isMissing := "tags"
imageFilter := models.ImageFilterType{
IsMissing: &isMissing,
@@ -586,7 +588,7 @@ func TestImageQueryIsMissingTags(t *testing.T) {
Q: &q,
}
- images, _, err := queryImagesWithCount(sqb, &imageFilter, &findFilter)
+ images, _, err := queryImagesWithCount(ctx, sqb, &imageFilter, &findFilter)
if err != nil {
t.Errorf("Error querying image: %s", err.Error())
}
@@ -594,7 +596,7 @@ func TestImageQueryIsMissingTags(t *testing.T) {
assert.Len(t, images, 0)
findFilter.Q = nil
- images, _, err = queryImagesWithCount(sqb, &imageFilter, &findFilter)
+ images, _, err = queryImagesWithCount(ctx, sqb, &imageFilter, &findFilter)
if err != nil {
t.Errorf("Error querying image: %s", err.Error())
}
@@ -606,14 +608,14 @@ func TestImageQueryIsMissingTags(t *testing.T) {
}
func TestImageQueryIsMissingRating(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Image()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.ImageReaderWriter
isMissing := "rating"
imageFilter := models.ImageFilterType{
IsMissing: &isMissing,
}
- images, _, err := queryImagesWithCount(sqb, &imageFilter, nil)
+ images, _, err := queryImagesWithCount(ctx, sqb, &imageFilter, nil)
if err != nil {
t.Errorf("Error querying image: %s", err.Error())
}
@@ -630,8 +632,8 @@ func TestImageQueryIsMissingRating(t *testing.T) {
}
func TestImageQueryGallery(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Image()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.ImageReaderWriter
galleryCriterion := models.MultiCriterionInput{
Value: []string{
strconv.Itoa(galleryIDs[galleryIdxWithImage]),
@@ -643,7 +645,7 @@ func TestImageQueryGallery(t *testing.T) {
Galleries: &galleryCriterion,
}
- images := queryImages(t, sqb, &imageFilter, nil)
+ images := queryImages(ctx, t, sqb, &imageFilter, nil)
assert.Len(t, images, 1)
// ensure ids are correct
@@ -659,7 +661,7 @@ func TestImageQueryGallery(t *testing.T) {
Modifier: models.CriterionModifierIncludesAll,
}
- images = queryImages(t, sqb, &imageFilter, nil)
+ images = queryImages(ctx, t, sqb, &imageFilter, nil)
assert.Len(t, images, 1)
assert.Equal(t, imageIDs[imageIdxWithTwoGalleries], images[0].ID)
@@ -676,11 +678,11 @@ func TestImageQueryGallery(t *testing.T) {
Q: &q,
}
- images = queryImages(t, sqb, &imageFilter, &findFilter)
+ images = queryImages(ctx, t, sqb, &imageFilter, &findFilter)
assert.Len(t, images, 0)
q = getImageStringValue(imageIdxWithPerformer, titleField)
- images = queryImages(t, sqb, &imageFilter, &findFilter)
+ images = queryImages(ctx, t, sqb, &imageFilter, &findFilter)
assert.Len(t, images, 1)
return nil
@@ -688,8 +690,8 @@ func TestImageQueryGallery(t *testing.T) {
}
func TestImageQueryPerformers(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Image()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.ImageReaderWriter
performerCriterion := models.MultiCriterionInput{
Value: []string{
strconv.Itoa(performerIDs[performerIdxWithImage]),
@@ -702,7 +704,7 @@ func TestImageQueryPerformers(t *testing.T) {
Performers: &performerCriterion,
}
- images := queryImages(t, sqb, &imageFilter, nil)
+ images := queryImages(ctx, t, sqb, &imageFilter, nil)
assert.Len(t, images, 2)
// ensure ids are correct
@@ -718,7 +720,7 @@ func TestImageQueryPerformers(t *testing.T) {
Modifier: models.CriterionModifierIncludesAll,
}
- images = queryImages(t, sqb, &imageFilter, nil)
+ images = queryImages(ctx, t, sqb, &imageFilter, nil)
assert.Len(t, images, 1)
assert.Equal(t, imageIDs[imageIdxWithTwoPerformers], images[0].ID)
@@ -734,7 +736,7 @@ func TestImageQueryPerformers(t *testing.T) {
Q: &q,
}
- images = queryImages(t, sqb, &imageFilter, &findFilter)
+ images = queryImages(ctx, t, sqb, &imageFilter, &findFilter)
assert.Len(t, images, 0)
performerCriterion = models.MultiCriterionInput{
@@ -742,22 +744,22 @@ func TestImageQueryPerformers(t *testing.T) {
}
q = getImageStringValue(imageIdxWithGallery, titleField)
- images = queryImages(t, sqb, &imageFilter, &findFilter)
+ images = queryImages(ctx, t, sqb, &imageFilter, &findFilter)
assert.Len(t, images, 1)
assert.Equal(t, imageIDs[imageIdxWithGallery], images[0].ID)
q = getImageStringValue(imageIdxWithPerformerTag, titleField)
- images = queryImages(t, sqb, &imageFilter, &findFilter)
+ images = queryImages(ctx, t, sqb, &imageFilter, &findFilter)
assert.Len(t, images, 0)
performerCriterion.Modifier = models.CriterionModifierNotNull
- images = queryImages(t, sqb, &imageFilter, &findFilter)
+ images = queryImages(ctx, t, sqb, &imageFilter, &findFilter)
assert.Len(t, images, 1)
assert.Equal(t, imageIDs[imageIdxWithPerformerTag], images[0].ID)
q = getImageStringValue(imageIdxWithGallery, titleField)
- images = queryImages(t, sqb, &imageFilter, &findFilter)
+ images = queryImages(ctx, t, sqb, &imageFilter, &findFilter)
assert.Len(t, images, 0)
return nil
@@ -765,8 +767,8 @@ func TestImageQueryPerformers(t *testing.T) {
}
func TestImageQueryTags(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Image()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.ImageReaderWriter
tagCriterion := models.HierarchicalMultiCriterionInput{
Value: []string{
strconv.Itoa(tagIDs[tagIdxWithImage]),
@@ -779,7 +781,7 @@ func TestImageQueryTags(t *testing.T) {
Tags: &tagCriterion,
}
- images := queryImages(t, sqb, &imageFilter, nil)
+ images := queryImages(ctx, t, sqb, &imageFilter, nil)
assert.Len(t, images, 2)
// ensure ids are correct
@@ -795,7 +797,7 @@ func TestImageQueryTags(t *testing.T) {
Modifier: models.CriterionModifierIncludesAll,
}
- images = queryImages(t, sqb, &imageFilter, nil)
+ images = queryImages(ctx, t, sqb, &imageFilter, nil)
assert.Len(t, images, 1)
assert.Equal(t, imageIDs[imageIdxWithTwoTags], images[0].ID)
@@ -811,7 +813,7 @@ func TestImageQueryTags(t *testing.T) {
Q: &q,
}
- images = queryImages(t, sqb, &imageFilter, &findFilter)
+ images = queryImages(ctx, t, sqb, &imageFilter, &findFilter)
assert.Len(t, images, 0)
tagCriterion = models.HierarchicalMultiCriterionInput{
@@ -819,22 +821,22 @@ func TestImageQueryTags(t *testing.T) {
}
q = getImageStringValue(imageIdxWithGallery, titleField)
- images = queryImages(t, sqb, &imageFilter, &findFilter)
+ images = queryImages(ctx, t, sqb, &imageFilter, &findFilter)
assert.Len(t, images, 1)
assert.Equal(t, imageIDs[imageIdxWithGallery], images[0].ID)
q = getImageStringValue(imageIdxWithTag, titleField)
- images = queryImages(t, sqb, &imageFilter, &findFilter)
+ images = queryImages(ctx, t, sqb, &imageFilter, &findFilter)
assert.Len(t, images, 0)
tagCriterion.Modifier = models.CriterionModifierNotNull
- images = queryImages(t, sqb, &imageFilter, &findFilter)
+ images = queryImages(ctx, t, sqb, &imageFilter, &findFilter)
assert.Len(t, images, 1)
assert.Equal(t, imageIDs[imageIdxWithTag], images[0].ID)
q = getImageStringValue(imageIdxWithGallery, titleField)
- images = queryImages(t, sqb, &imageFilter, &findFilter)
+ images = queryImages(ctx, t, sqb, &imageFilter, &findFilter)
assert.Len(t, images, 0)
return nil
@@ -842,8 +844,8 @@ func TestImageQueryTags(t *testing.T) {
}
func TestImageQueryStudio(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Image()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.ImageReaderWriter
studioCriterion := models.HierarchicalMultiCriterionInput{
Value: []string{
strconv.Itoa(studioIDs[studioIdxWithImage]),
@@ -855,7 +857,7 @@ func TestImageQueryStudio(t *testing.T) {
Studios: &studioCriterion,
}
- images, _, err := queryImagesWithCount(sqb, &imageFilter, nil)
+ images, _, err := queryImagesWithCount(ctx, sqb, &imageFilter, nil)
if err != nil {
t.Errorf("Error querying image: %s", err.Error())
}
@@ -877,7 +879,7 @@ func TestImageQueryStudio(t *testing.T) {
Q: &q,
}
- images, _, err = queryImagesWithCount(sqb, &imageFilter, &findFilter)
+ images, _, err = queryImagesWithCount(ctx, sqb, &imageFilter, &findFilter)
if err != nil {
t.Errorf("Error querying image: %s", err.Error())
}
@@ -888,8 +890,8 @@ func TestImageQueryStudio(t *testing.T) {
}
func TestImageQueryStudioDepth(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Image()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.ImageReaderWriter
depth := 2
studioCriterion := models.HierarchicalMultiCriterionInput{
Value: []string{
@@ -903,16 +905,16 @@ func TestImageQueryStudioDepth(t *testing.T) {
Studios: &studioCriterion,
}
- images := queryImages(t, sqb, &imageFilter, nil)
+ images := queryImages(ctx, t, sqb, &imageFilter, nil)
assert.Len(t, images, 1)
depth = 1
- images = queryImages(t, sqb, &imageFilter, nil)
+ images = queryImages(ctx, t, sqb, &imageFilter, nil)
assert.Len(t, images, 0)
studioCriterion.Value = []string{strconv.Itoa(studioIDs[studioIdxWithParentAndChild])}
- images = queryImages(t, sqb, &imageFilter, nil)
+ images = queryImages(ctx, t, sqb, &imageFilter, nil)
assert.Len(t, images, 1)
// ensure id is correct
@@ -933,23 +935,23 @@ func TestImageQueryStudioDepth(t *testing.T) {
Q: &q,
}
- images = queryImages(t, sqb, &imageFilter, &findFilter)
+ images = queryImages(ctx, t, sqb, &imageFilter, &findFilter)
assert.Len(t, images, 0)
depth = 1
- images = queryImages(t, sqb, &imageFilter, &findFilter)
+ images = queryImages(ctx, t, sqb, &imageFilter, &findFilter)
assert.Len(t, images, 1)
studioCriterion.Value = []string{strconv.Itoa(studioIDs[studioIdxWithParentAndChild])}
- images = queryImages(t, sqb, &imageFilter, &findFilter)
+ images = queryImages(ctx, t, sqb, &imageFilter, &findFilter)
assert.Len(t, images, 0)
return nil
})
}
-func queryImages(t *testing.T, sqb models.ImageReader, imageFilter *models.ImageFilterType, findFilter *models.FindFilterType) []*models.Image {
- images, _, err := queryImagesWithCount(sqb, imageFilter, findFilter)
+func queryImages(ctx context.Context, t *testing.T, sqb models.ImageReader, imageFilter *models.ImageFilterType, findFilter *models.FindFilterType) []*models.Image {
+ images, _, err := queryImagesWithCount(ctx, sqb, imageFilter, findFilter)
if err != nil {
t.Errorf("Error querying images: %s", err.Error())
}
@@ -958,8 +960,8 @@ func queryImages(t *testing.T, sqb models.ImageReader, imageFilter *models.Image
}
func TestImageQueryPerformerTags(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Image()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.ImageReaderWriter
tagCriterion := models.HierarchicalMultiCriterionInput{
Value: []string{
strconv.Itoa(tagIDs[tagIdxWithPerformer]),
@@ -972,7 +974,7 @@ func TestImageQueryPerformerTags(t *testing.T) {
PerformerTags: &tagCriterion,
}
- images := queryImages(t, sqb, &imageFilter, nil)
+ images := queryImages(ctx, t, sqb, &imageFilter, nil)
assert.Len(t, images, 2)
// ensure ids are correct
@@ -988,7 +990,7 @@ func TestImageQueryPerformerTags(t *testing.T) {
Modifier: models.CriterionModifierIncludesAll,
}
- images = queryImages(t, sqb, &imageFilter, nil)
+ images = queryImages(ctx, t, sqb, &imageFilter, nil)
assert.Len(t, images, 1)
assert.Equal(t, imageIDs[imageIdxWithPerformerTwoTags], images[0].ID)
@@ -1005,7 +1007,7 @@ func TestImageQueryPerformerTags(t *testing.T) {
Q: &q,
}
- images = queryImages(t, sqb, &imageFilter, &findFilter)
+ images = queryImages(ctx, t, sqb, &imageFilter, &findFilter)
assert.Len(t, images, 0)
tagCriterion = models.HierarchicalMultiCriterionInput{
@@ -1013,22 +1015,22 @@ func TestImageQueryPerformerTags(t *testing.T) {
}
q = getImageStringValue(imageIdxWithGallery, titleField)
- images = queryImages(t, sqb, &imageFilter, &findFilter)
+ images = queryImages(ctx, t, sqb, &imageFilter, &findFilter)
assert.Len(t, images, 1)
assert.Equal(t, imageIDs[imageIdxWithGallery], images[0].ID)
q = getImageStringValue(imageIdxWithPerformerTag, titleField)
- images = queryImages(t, sqb, &imageFilter, &findFilter)
+ images = queryImages(ctx, t, sqb, &imageFilter, &findFilter)
assert.Len(t, images, 0)
tagCriterion.Modifier = models.CriterionModifierNotNull
- images = queryImages(t, sqb, &imageFilter, &findFilter)
+ images = queryImages(ctx, t, sqb, &imageFilter, &findFilter)
assert.Len(t, images, 1)
assert.Equal(t, imageIDs[imageIdxWithPerformerTag], images[0].ID)
q = getImageStringValue(imageIdxWithGallery, titleField)
- images = queryImages(t, sqb, &imageFilter, &findFilter)
+ images = queryImages(ctx, t, sqb, &imageFilter, &findFilter)
assert.Len(t, images, 0)
return nil
@@ -1055,17 +1057,17 @@ func TestImageQueryTagCount(t *testing.T) {
}
func verifyImagesTagCount(t *testing.T, tagCountCriterion models.IntCriterionInput) {
- withTxn(func(r models.Repository) error {
- sqb := r.Image()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.ImageReaderWriter
imageFilter := models.ImageFilterType{
TagCount: &tagCountCriterion,
}
- images := queryImages(t, sqb, &imageFilter, nil)
+ images := queryImages(ctx, t, sqb, &imageFilter, nil)
assert.Greater(t, len(images), 0)
for _, image := range images {
- ids, err := sqb.GetTagIDs(image.ID)
+ ids, err := sqb.GetTagIDs(ctx, image.ID)
if err != nil {
return err
}
@@ -1096,17 +1098,17 @@ func TestImageQueryPerformerCount(t *testing.T) {
}
func verifyImagesPerformerCount(t *testing.T, performerCountCriterion models.IntCriterionInput) {
- withTxn(func(r models.Repository) error {
- sqb := r.Image()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.ImageReaderWriter
imageFilter := models.ImageFilterType{
PerformerCount: &performerCountCriterion,
}
- images := queryImages(t, sqb, &imageFilter, nil)
+ images := queryImages(ctx, t, sqb, &imageFilter, nil)
assert.Greater(t, len(images), 0)
for _, image := range images {
- ids, err := sqb.GetPerformerIDs(image.ID)
+ ids, err := sqb.GetPerformerIDs(ctx, image.ID)
if err != nil {
return err
}
@@ -1118,7 +1120,7 @@ func verifyImagesPerformerCount(t *testing.T, performerCountCriterion models.Int
}
func TestImageQuerySorting(t *testing.T) {
- withTxn(func(r models.Repository) error {
+ withTxn(func(ctx context.Context) error {
sort := titleField
direction := models.SortDirectionEnumAsc
findFilter := models.FindFilterType{
@@ -1126,8 +1128,8 @@ func TestImageQuerySorting(t *testing.T) {
Direction: &direction,
}
- sqb := r.Image()
- images, _, err := queryImagesWithCount(sqb, nil, &findFilter)
+ sqb := sqlite.ImageReaderWriter
+ images, _, err := queryImagesWithCount(ctx, sqb, nil, &findFilter)
if err != nil {
t.Errorf("Error querying image: %s", err.Error())
}
@@ -1142,7 +1144,7 @@ func TestImageQuerySorting(t *testing.T) {
// sort in descending order
direction = models.SortDirectionEnumDesc
- images, _, err = queryImagesWithCount(sqb, nil, &findFilter)
+ images, _, err = queryImagesWithCount(ctx, sqb, nil, &findFilter)
if err != nil {
t.Errorf("Error querying image: %s", err.Error())
}
@@ -1157,14 +1159,14 @@ func TestImageQuerySorting(t *testing.T) {
}
func TestImageQueryPagination(t *testing.T) {
- withTxn(func(r models.Repository) error {
+ withTxn(func(ctx context.Context) error {
perPage := 1
findFilter := models.FindFilterType{
PerPage: &perPage,
}
- sqb := r.Image()
- images, _, err := queryImagesWithCount(sqb, nil, &findFilter)
+ sqb := sqlite.ImageReaderWriter
+ images, _, err := queryImagesWithCount(ctx, sqb, nil, &findFilter)
if err != nil {
t.Errorf("Error querying image: %s", err.Error())
}
@@ -1175,7 +1177,7 @@ func TestImageQueryPagination(t *testing.T) {
page := 2
findFilter.Page = &page
- images, _, err = queryImagesWithCount(sqb, nil, &findFilter)
+ images, _, err = queryImagesWithCount(ctx, sqb, nil, &findFilter)
if err != nil {
t.Errorf("Error querying image: %s", err.Error())
}
@@ -1187,7 +1189,7 @@ func TestImageQueryPagination(t *testing.T) {
perPage = 2
page = 1
- images, _, err = queryImagesWithCount(sqb, nil, &findFilter)
+ images, _, err = queryImagesWithCount(ctx, sqb, nil, &findFilter)
if err != nil {
t.Errorf("Error querying image: %s", err.Error())
}
diff --git a/pkg/database/migrations/10_image_tables.up.sql b/pkg/sqlite/migrations/10_image_tables.up.sql
similarity index 100%
rename from pkg/database/migrations/10_image_tables.up.sql
rename to pkg/sqlite/migrations/10_image_tables.up.sql
diff --git a/pkg/database/migrations/11_tag_image.up.sql b/pkg/sqlite/migrations/11_tag_image.up.sql
similarity index 100%
rename from pkg/database/migrations/11_tag_image.up.sql
rename to pkg/sqlite/migrations/11_tag_image.up.sql
diff --git a/pkg/database/migrations/12_oshash.up.sql b/pkg/sqlite/migrations/12_oshash.up.sql
similarity index 100%
rename from pkg/database/migrations/12_oshash.up.sql
rename to pkg/sqlite/migrations/12_oshash.up.sql
diff --git a/pkg/database/migrations/13_images.up.sql b/pkg/sqlite/migrations/13_images.up.sql
similarity index 100%
rename from pkg/database/migrations/13_images.up.sql
rename to pkg/sqlite/migrations/13_images.up.sql
diff --git a/pkg/database/migrations/14_stash_box_ids.up.sql b/pkg/sqlite/migrations/14_stash_box_ids.up.sql
similarity index 100%
rename from pkg/database/migrations/14_stash_box_ids.up.sql
rename to pkg/sqlite/migrations/14_stash_box_ids.up.sql
diff --git a/pkg/database/migrations/15_file_mod_time.up.sql b/pkg/sqlite/migrations/15_file_mod_time.up.sql
similarity index 100%
rename from pkg/database/migrations/15_file_mod_time.up.sql
rename to pkg/sqlite/migrations/15_file_mod_time.up.sql
diff --git a/pkg/database/migrations/16_organized_flag.up.sql b/pkg/sqlite/migrations/16_organized_flag.up.sql
similarity index 100%
rename from pkg/database/migrations/16_organized_flag.up.sql
rename to pkg/sqlite/migrations/16_organized_flag.up.sql
diff --git a/pkg/database/migrations/17_reset_scene_size.up.sql b/pkg/sqlite/migrations/17_reset_scene_size.up.sql
similarity index 100%
rename from pkg/database/migrations/17_reset_scene_size.up.sql
rename to pkg/sqlite/migrations/17_reset_scene_size.up.sql
diff --git a/pkg/database/migrations/18_scene_galleries.up.sql b/pkg/sqlite/migrations/18_scene_galleries.up.sql
similarity index 100%
rename from pkg/database/migrations/18_scene_galleries.up.sql
rename to pkg/sqlite/migrations/18_scene_galleries.up.sql
diff --git a/pkg/database/migrations/19_performer_tags.up.sql b/pkg/sqlite/migrations/19_performer_tags.up.sql
similarity index 100%
rename from pkg/database/migrations/19_performer_tags.up.sql
rename to pkg/sqlite/migrations/19_performer_tags.up.sql
diff --git a/pkg/database/migrations/1_initial.down.sql b/pkg/sqlite/migrations/1_initial.down.sql
similarity index 100%
rename from pkg/database/migrations/1_initial.down.sql
rename to pkg/sqlite/migrations/1_initial.down.sql
diff --git a/pkg/database/migrations/1_initial.up.sql b/pkg/sqlite/migrations/1_initial.up.sql
similarity index 100%
rename from pkg/database/migrations/1_initial.up.sql
rename to pkg/sqlite/migrations/1_initial.up.sql
diff --git a/pkg/database/migrations/20_phash.up.sql b/pkg/sqlite/migrations/20_phash.up.sql
similarity index 100%
rename from pkg/database/migrations/20_phash.up.sql
rename to pkg/sqlite/migrations/20_phash.up.sql
diff --git a/pkg/database/migrations/21_performers_studios_details.up.sql b/pkg/sqlite/migrations/21_performers_studios_details.up.sql
similarity index 100%
rename from pkg/database/migrations/21_performers_studios_details.up.sql
rename to pkg/sqlite/migrations/21_performers_studios_details.up.sql
diff --git a/pkg/database/migrations/22_performers_studios_rating.up.sql b/pkg/sqlite/migrations/22_performers_studios_rating.up.sql
similarity index 100%
rename from pkg/database/migrations/22_performers_studios_rating.up.sql
rename to pkg/sqlite/migrations/22_performers_studios_rating.up.sql
diff --git a/pkg/database/migrations/23_scenes_interactive.up.sql b/pkg/sqlite/migrations/23_scenes_interactive.up.sql
similarity index 100%
rename from pkg/database/migrations/23_scenes_interactive.up.sql
rename to pkg/sqlite/migrations/23_scenes_interactive.up.sql
diff --git a/pkg/database/migrations/24_tag_aliases.up.sql b/pkg/sqlite/migrations/24_tag_aliases.up.sql
similarity index 100%
rename from pkg/database/migrations/24_tag_aliases.up.sql
rename to pkg/sqlite/migrations/24_tag_aliases.up.sql
diff --git a/pkg/database/migrations/25_saved_filters.up.sql b/pkg/sqlite/migrations/25_saved_filters.up.sql
similarity index 100%
rename from pkg/database/migrations/25_saved_filters.up.sql
rename to pkg/sqlite/migrations/25_saved_filters.up.sql
diff --git a/pkg/database/migrations/26_tag_hierarchy.up.sql b/pkg/sqlite/migrations/26_tag_hierarchy.up.sql
similarity index 100%
rename from pkg/database/migrations/26_tag_hierarchy.up.sql
rename to pkg/sqlite/migrations/26_tag_hierarchy.up.sql
diff --git a/pkg/database/migrations/27_studio_aliases.up.sql b/pkg/sqlite/migrations/27_studio_aliases.up.sql
similarity index 100%
rename from pkg/database/migrations/27_studio_aliases.up.sql
rename to pkg/sqlite/migrations/27_studio_aliases.up.sql
diff --git a/pkg/database/migrations/28_images_indexes.up.sql b/pkg/sqlite/migrations/28_images_indexes.up.sql
similarity index 100%
rename from pkg/database/migrations/28_images_indexes.up.sql
rename to pkg/sqlite/migrations/28_images_indexes.up.sql
diff --git a/pkg/database/migrations/29_interactive_speed.up.sql b/pkg/sqlite/migrations/29_interactive_speed.up.sql
similarity index 100%
rename from pkg/database/migrations/29_interactive_speed.up.sql
rename to pkg/sqlite/migrations/29_interactive_speed.up.sql
diff --git a/pkg/database/migrations/2_cover_image.up.sql b/pkg/sqlite/migrations/2_cover_image.up.sql
similarity index 100%
rename from pkg/database/migrations/2_cover_image.up.sql
rename to pkg/sqlite/migrations/2_cover_image.up.sql
diff --git a/pkg/database/migrations/30_ignore_autotag.up..sql b/pkg/sqlite/migrations/30_ignore_autotag.up..sql
similarity index 100%
rename from pkg/database/migrations/30_ignore_autotag.up..sql
rename to pkg/sqlite/migrations/30_ignore_autotag.up..sql
diff --git a/pkg/database/migrations/31_scenes_captions.up.sql b/pkg/sqlite/migrations/31_scenes_captions.up.sql
similarity index 100%
rename from pkg/database/migrations/31_scenes_captions.up.sql
rename to pkg/sqlite/migrations/31_scenes_captions.up.sql
diff --git a/pkg/database/migrations/3_o_counter.up.sql b/pkg/sqlite/migrations/3_o_counter.up.sql
similarity index 100%
rename from pkg/database/migrations/3_o_counter.up.sql
rename to pkg/sqlite/migrations/3_o_counter.up.sql
diff --git a/pkg/database/migrations/4_movie.up.sql b/pkg/sqlite/migrations/4_movie.up.sql
similarity index 100%
rename from pkg/database/migrations/4_movie.up.sql
rename to pkg/sqlite/migrations/4_movie.up.sql
diff --git a/pkg/database/migrations/5_performer_gender.down.sql b/pkg/sqlite/migrations/5_performer_gender.down.sql
similarity index 100%
rename from pkg/database/migrations/5_performer_gender.down.sql
rename to pkg/sqlite/migrations/5_performer_gender.down.sql
diff --git a/pkg/database/migrations/5_performer_gender.up.sql b/pkg/sqlite/migrations/5_performer_gender.up.sql
similarity index 100%
rename from pkg/database/migrations/5_performer_gender.up.sql
rename to pkg/sqlite/migrations/5_performer_gender.up.sql
diff --git a/pkg/database/migrations/6_scenes_format.up.sql b/pkg/sqlite/migrations/6_scenes_format.up.sql
similarity index 100%
rename from pkg/database/migrations/6_scenes_format.up.sql
rename to pkg/sqlite/migrations/6_scenes_format.up.sql
diff --git a/pkg/database/migrations/7_performer_optimization.up.sql b/pkg/sqlite/migrations/7_performer_optimization.up.sql
similarity index 100%
rename from pkg/database/migrations/7_performer_optimization.up.sql
rename to pkg/sqlite/migrations/7_performer_optimization.up.sql
diff --git a/pkg/database/migrations/8_movie_fix.up.sql b/pkg/sqlite/migrations/8_movie_fix.up.sql
similarity index 100%
rename from pkg/database/migrations/8_movie_fix.up.sql
rename to pkg/sqlite/migrations/8_movie_fix.up.sql
diff --git a/pkg/database/migrations/9_studios_parent_studio.up.sql b/pkg/sqlite/migrations/9_studios_parent_studio.up.sql
similarity index 100%
rename from pkg/database/migrations/9_studios_parent_studio.up.sql
rename to pkg/sqlite/migrations/9_studios_parent_studio.up.sql
diff --git a/pkg/sqlite/movies.go b/pkg/sqlite/movies.go
index eac02ae54..c52556e15 100644
--- a/pkg/sqlite/movies.go
+++ b/pkg/sqlite/movies.go
@@ -1,6 +1,7 @@
package sqlite
import (
+ "context"
"database/sql"
"errors"
"fmt"
@@ -15,50 +16,47 @@ type movieQueryBuilder struct {
repository
}
-func NewMovieReaderWriter(tx dbi) *movieQueryBuilder {
- return &movieQueryBuilder{
- repository{
- tx: tx,
- tableName: movieTable,
- idColumn: idColumn,
- },
- }
+var MovieReaderWriter = &movieQueryBuilder{
+ repository{
+ tableName: movieTable,
+ idColumn: idColumn,
+ },
}
-func (qb *movieQueryBuilder) Create(newObject models.Movie) (*models.Movie, error) {
+func (qb *movieQueryBuilder) Create(ctx context.Context, newObject models.Movie) (*models.Movie, error) {
var ret models.Movie
- if err := qb.insertObject(newObject, &ret); err != nil {
+ if err := qb.insertObject(ctx, newObject, &ret); err != nil {
return nil, err
}
return &ret, nil
}
-func (qb *movieQueryBuilder) Update(updatedObject models.MoviePartial) (*models.Movie, error) {
+func (qb *movieQueryBuilder) Update(ctx context.Context, updatedObject models.MoviePartial) (*models.Movie, error) {
const partial = true
- if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil {
+ if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil {
return nil, err
}
- return qb.Find(updatedObject.ID)
+ return qb.Find(ctx, updatedObject.ID)
}
-func (qb *movieQueryBuilder) UpdateFull(updatedObject models.Movie) (*models.Movie, error) {
+func (qb *movieQueryBuilder) UpdateFull(ctx context.Context, updatedObject models.Movie) (*models.Movie, error) {
const partial = false
- if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil {
+ if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil {
return nil, err
}
- return qb.Find(updatedObject.ID)
+ return qb.Find(ctx, updatedObject.ID)
}
-func (qb *movieQueryBuilder) Destroy(id int) error {
- return qb.destroyExisting([]int{id})
+func (qb *movieQueryBuilder) Destroy(ctx context.Context, id int) error {
+ return qb.destroyExisting(ctx, []int{id})
}
-func (qb *movieQueryBuilder) Find(id int) (*models.Movie, error) {
+func (qb *movieQueryBuilder) Find(ctx context.Context, id int) (*models.Movie, error) {
var ret models.Movie
- if err := qb.get(id, &ret); err != nil {
+ if err := qb.getByID(ctx, id, &ret); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
@@ -67,10 +65,10 @@ func (qb *movieQueryBuilder) Find(id int) (*models.Movie, error) {
return &ret, nil
}
-func (qb *movieQueryBuilder) FindMany(ids []int) ([]*models.Movie, error) {
+func (qb *movieQueryBuilder) FindMany(ctx context.Context, ids []int) ([]*models.Movie, error) {
var movies []*models.Movie
for _, id := range ids {
- movie, err := qb.Find(id)
+ movie, err := qb.Find(ctx, id)
if err != nil {
return nil, err
}
@@ -85,17 +83,17 @@ func (qb *movieQueryBuilder) FindMany(ids []int) ([]*models.Movie, error) {
return movies, nil
}
-func (qb *movieQueryBuilder) FindByName(name string, nocase bool) (*models.Movie, error) {
+func (qb *movieQueryBuilder) FindByName(ctx context.Context, name string, nocase bool) (*models.Movie, error) {
query := "SELECT * FROM movies WHERE name = ?"
if nocase {
query += " COLLATE NOCASE"
}
query += " LIMIT 1"
args := []interface{}{name}
- return qb.queryMovie(query, args)
+ return qb.queryMovie(ctx, query, args)
}
-func (qb *movieQueryBuilder) FindByNames(names []string, nocase bool) ([]*models.Movie, error) {
+func (qb *movieQueryBuilder) FindByNames(ctx context.Context, names []string, nocase bool) ([]*models.Movie, error) {
query := "SELECT * FROM movies WHERE name"
if nocase {
query += " COLLATE NOCASE"
@@ -105,34 +103,34 @@ func (qb *movieQueryBuilder) FindByNames(names []string, nocase bool) ([]*models
for _, name := range names {
args = append(args, name)
}
- return qb.queryMovies(query, args)
+ return qb.queryMovies(ctx, query, args)
}
-func (qb *movieQueryBuilder) Count() (int, error) {
- return qb.runCountQuery(qb.buildCountQuery("SELECT movies.id FROM movies"), nil)
+func (qb *movieQueryBuilder) Count(ctx context.Context) (int, error) {
+ return qb.runCountQuery(ctx, qb.buildCountQuery("SELECT movies.id FROM movies"), nil)
}
-func (qb *movieQueryBuilder) All() ([]*models.Movie, error) {
- return qb.queryMovies(selectAll("movies")+qb.getMovieSort(nil), nil)
+func (qb *movieQueryBuilder) All(ctx context.Context) ([]*models.Movie, error) {
+ return qb.queryMovies(ctx, selectAll("movies")+qb.getMovieSort(nil), nil)
}
-func (qb *movieQueryBuilder) makeFilter(movieFilter *models.MovieFilterType) *filterBuilder {
+func (qb *movieQueryBuilder) makeFilter(ctx context.Context, movieFilter *models.MovieFilterType) *filterBuilder {
query := &filterBuilder{}
- query.handleCriterion(stringCriterionHandler(movieFilter.Name, "movies.name"))
- query.handleCriterion(stringCriterionHandler(movieFilter.Director, "movies.director"))
- query.handleCriterion(stringCriterionHandler(movieFilter.Synopsis, "movies.synopsis"))
- query.handleCriterion(intCriterionHandler(movieFilter.Rating, "movies.rating"))
- query.handleCriterion(durationCriterionHandler(movieFilter.Duration, "movies.duration"))
- query.handleCriterion(movieIsMissingCriterionHandler(qb, movieFilter.IsMissing))
- query.handleCriterion(stringCriterionHandler(movieFilter.URL, "movies.url"))
- query.handleCriterion(movieStudioCriterionHandler(qb, movieFilter.Studios))
- query.handleCriterion(moviePerformersCriterionHandler(qb, movieFilter.Performers))
+ query.handleCriterion(ctx, stringCriterionHandler(movieFilter.Name, "movies.name"))
+ query.handleCriterion(ctx, stringCriterionHandler(movieFilter.Director, "movies.director"))
+ query.handleCriterion(ctx, stringCriterionHandler(movieFilter.Synopsis, "movies.synopsis"))
+ query.handleCriterion(ctx, intCriterionHandler(movieFilter.Rating, "movies.rating"))
+ query.handleCriterion(ctx, durationCriterionHandler(movieFilter.Duration, "movies.duration"))
+ query.handleCriterion(ctx, movieIsMissingCriterionHandler(qb, movieFilter.IsMissing))
+ query.handleCriterion(ctx, stringCriterionHandler(movieFilter.URL, "movies.url"))
+ query.handleCriterion(ctx, movieStudioCriterionHandler(qb, movieFilter.Studios))
+ query.handleCriterion(ctx, moviePerformersCriterionHandler(qb, movieFilter.Performers))
return query
}
-func (qb *movieQueryBuilder) Query(movieFilter *models.MovieFilterType, findFilter *models.FindFilterType) ([]*models.Movie, int, error) {
+func (qb *movieQueryBuilder) Query(ctx context.Context, movieFilter *models.MovieFilterType, findFilter *models.FindFilterType) ([]*models.Movie, int, error) {
if findFilter == nil {
findFilter = &models.FindFilterType{}
}
@@ -148,19 +146,19 @@ func (qb *movieQueryBuilder) Query(movieFilter *models.MovieFilterType, findFilt
query.parseQueryString(searchColumns, *q)
}
- filter := qb.makeFilter(movieFilter)
+ filter := qb.makeFilter(ctx, movieFilter)
query.addFilter(filter)
query.sortAndPagination = qb.getMovieSort(findFilter) + getPagination(findFilter)
- idsResult, countResult, err := query.executeFind()
+ idsResult, countResult, err := query.executeFind(ctx)
if err != nil {
return nil, 0, err
}
var movies []*models.Movie
for _, id := range idsResult {
- movie, err := qb.Find(id)
+ movie, err := qb.Find(ctx, id)
if err != nil {
return nil, 0, err
}
@@ -172,7 +170,7 @@ func (qb *movieQueryBuilder) Query(movieFilter *models.MovieFilterType, findFilt
}
func movieIsMissingCriterionHandler(qb *movieQueryBuilder, isMissing *string) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if isMissing != nil && *isMissing != "" {
switch *isMissing {
case "front_image":
@@ -206,7 +204,7 @@ func movieStudioCriterionHandler(qb *movieQueryBuilder, studios *models.Hierarch
}
func moviePerformersCriterionHandler(qb *movieQueryBuilder, performers *models.MultiCriterionInput) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if performers != nil {
if performers.Modifier == models.CriterionModifierIsNull || performers.Modifier == models.CriterionModifierNotNull {
var notClause string
@@ -273,30 +271,30 @@ func (qb *movieQueryBuilder) getMovieSort(findFilter *models.FindFilterType) str
}
}
-func (qb *movieQueryBuilder) queryMovie(query string, args []interface{}) (*models.Movie, error) {
- results, err := qb.queryMovies(query, args)
+func (qb *movieQueryBuilder) queryMovie(ctx context.Context, query string, args []interface{}) (*models.Movie, error) {
+ results, err := qb.queryMovies(ctx, query, args)
if err != nil || len(results) < 1 {
return nil, err
}
return results[0], nil
}
-func (qb *movieQueryBuilder) queryMovies(query string, args []interface{}) ([]*models.Movie, error) {
+func (qb *movieQueryBuilder) queryMovies(ctx context.Context, query string, args []interface{}) ([]*models.Movie, error) {
var ret models.Movies
- if err := qb.query(query, args, &ret); err != nil {
+ if err := qb.query(ctx, query, args, &ret); err != nil {
return nil, err
}
return []*models.Movie(ret), nil
}
-func (qb *movieQueryBuilder) UpdateImages(movieID int, frontImage []byte, backImage []byte) error {
+func (qb *movieQueryBuilder) UpdateImages(ctx context.Context, movieID int, frontImage []byte, backImage []byte) error {
// Delete the existing cover and then create new
- if err := qb.DestroyImages(movieID); err != nil {
+ if err := qb.DestroyImages(ctx, movieID); err != nil {
return err
}
- _, err := qb.tx.Exec(
+ _, err := qb.tx.Exec(ctx,
`INSERT INTO movies_images (movie_id, front_image, back_image) VALUES (?, ?, ?)`,
movieID,
frontImage,
@@ -306,26 +304,26 @@ func (qb *movieQueryBuilder) UpdateImages(movieID int, frontImage []byte, backIm
return err
}
-func (qb *movieQueryBuilder) DestroyImages(movieID int) error {
+func (qb *movieQueryBuilder) DestroyImages(ctx context.Context, movieID int) error {
// Delete the existing joins
- _, err := qb.tx.Exec("DELETE FROM movies_images WHERE movie_id = ?", movieID)
+ _, err := qb.tx.Exec(ctx, "DELETE FROM movies_images WHERE movie_id = ?", movieID)
if err != nil {
return err
}
return err
}
-func (qb *movieQueryBuilder) GetFrontImage(movieID int) ([]byte, error) {
+func (qb *movieQueryBuilder) GetFrontImage(ctx context.Context, movieID int) ([]byte, error) {
query := `SELECT front_image from movies_images WHERE movie_id = ?`
- return getImage(qb.tx, query, movieID)
+ return getImage(ctx, qb.tx, query, movieID)
}
-func (qb *movieQueryBuilder) GetBackImage(movieID int) ([]byte, error) {
+func (qb *movieQueryBuilder) GetBackImage(ctx context.Context, movieID int) ([]byte, error) {
query := `SELECT back_image from movies_images WHERE movie_id = ?`
- return getImage(qb.tx, query, movieID)
+ return getImage(ctx, qb.tx, query, movieID)
}
-func (qb *movieQueryBuilder) FindByPerformerID(performerID int) ([]*models.Movie, error) {
+func (qb *movieQueryBuilder) FindByPerformerID(ctx context.Context, performerID int) ([]*models.Movie, error) {
query := `SELECT DISTINCT movies.*
FROM movies
INNER JOIN movies_scenes ON movies.id = movies_scenes.movie_id
@@ -333,33 +331,33 @@ INNER JOIN performers_scenes ON performers_scenes.scene_id = movies_scenes.scene
WHERE performers_scenes.performer_id = ?
`
args := []interface{}{performerID}
- return qb.queryMovies(query, args)
+ return qb.queryMovies(ctx, query, args)
}
-func (qb *movieQueryBuilder) CountByPerformerID(performerID int) (int, error) {
+func (qb *movieQueryBuilder) CountByPerformerID(ctx context.Context, performerID int) (int, error) {
query := `SELECT COUNT(DISTINCT movies_scenes.movie_id) AS count
FROM movies_scenes
INNER JOIN performers_scenes ON performers_scenes.scene_id = movies_scenes.scene_id
WHERE performers_scenes.performer_id = ?
`
args := []interface{}{performerID}
- return qb.runCountQuery(query, args)
+ return qb.runCountQuery(ctx, query, args)
}
-func (qb *movieQueryBuilder) FindByStudioID(studioID int) ([]*models.Movie, error) {
+func (qb *movieQueryBuilder) FindByStudioID(ctx context.Context, studioID int) ([]*models.Movie, error) {
query := `SELECT movies.*
FROM movies
WHERE movies.studio_id = ?
`
args := []interface{}{studioID}
- return qb.queryMovies(query, args)
+ return qb.queryMovies(ctx, query, args)
}
-func (qb *movieQueryBuilder) CountByStudioID(studioID int) (int, error) {
+func (qb *movieQueryBuilder) CountByStudioID(ctx context.Context, studioID int) (int, error) {
query := `SELECT COUNT(1) AS count
FROM movies
WHERE movies.studio_id = ?
`
args := []interface{}{studioID}
- return qb.runCountQuery(query, args)
+ return qb.runCountQuery(ctx, query, args)
}
diff --git a/pkg/sqlite/movies_test.go b/pkg/sqlite/movies_test.go
index 75c6cc5bf..eff0cf50b 100644
--- a/pkg/sqlite/movies_test.go
+++ b/pkg/sqlite/movies_test.go
@@ -4,6 +4,7 @@
package sqlite_test
import (
+ "context"
"database/sql"
"fmt"
"strconv"
@@ -14,15 +15,16 @@ import (
"github.com/stashapp/stash/pkg/hash/md5"
"github.com/stashapp/stash/pkg/models"
+ "github.com/stashapp/stash/pkg/sqlite"
)
func TestMovieFindByName(t *testing.T) {
- withTxn(func(r models.Repository) error {
- mqb := r.Movie()
+ withTxn(func(ctx context.Context) error {
+ mqb := sqlite.MovieReaderWriter
name := movieNames[movieIdxWithScene] // find a movie by name
- movie, err := mqb.FindByName(name, false)
+ movie, err := mqb.FindByName(ctx, name, false)
if err != nil {
t.Errorf("Error finding movies: %s", err.Error())
@@ -32,7 +34,7 @@ func TestMovieFindByName(t *testing.T) {
name = movieNames[movieIdxWithDupName] // find a movie by name nocase
- movie, err = mqb.FindByName(name, true)
+ movie, err = mqb.FindByName(ctx, name, true)
if err != nil {
t.Errorf("Error finding movies: %s", err.Error())
@@ -48,21 +50,21 @@ func TestMovieFindByName(t *testing.T) {
}
func TestMovieFindByNames(t *testing.T) {
- withTxn(func(r models.Repository) error {
+ withTxn(func(ctx context.Context) error {
var names []string
- mqb := r.Movie()
+ mqb := sqlite.MovieReaderWriter
names = append(names, movieNames[movieIdxWithScene]) // find movies by names
- movies, err := mqb.FindByNames(names, false)
+ movies, err := mqb.FindByNames(ctx, names, false)
if err != nil {
t.Errorf("Error finding movies: %s", err.Error())
}
assert.Len(t, movies, 1)
assert.Equal(t, movieNames[movieIdxWithScene], movies[0].Name.String)
- movies, err = mqb.FindByNames(names, true) // find movies by names nocase
+ movies, err = mqb.FindByNames(ctx, names, true) // find movies by names nocase
if err != nil {
t.Errorf("Error finding movies: %s", err.Error())
}
@@ -75,8 +77,8 @@ func TestMovieFindByNames(t *testing.T) {
}
func TestMovieQueryStudio(t *testing.T) {
- withTxn(func(r models.Repository) error {
- mqb := r.Movie()
+ withTxn(func(ctx context.Context) error {
+ mqb := sqlite.MovieReaderWriter
studioCriterion := models.HierarchicalMultiCriterionInput{
Value: []string{
strconv.Itoa(studioIDs[studioIdxWithMovie]),
@@ -88,7 +90,7 @@ func TestMovieQueryStudio(t *testing.T) {
Studios: &studioCriterion,
}
- movies, _, err := mqb.Query(&movieFilter, nil)
+ movies, _, err := mqb.Query(ctx, &movieFilter, nil)
if err != nil {
t.Errorf("Error querying movie: %s", err.Error())
}
@@ -110,7 +112,7 @@ func TestMovieQueryStudio(t *testing.T) {
Q: &q,
}
- movies, _, err = mqb.Query(&movieFilter, &findFilter)
+ movies, _, err = mqb.Query(ctx, &movieFilter, &findFilter)
if err != nil {
t.Errorf("Error querying movie: %s", err.Error())
}
@@ -159,11 +161,11 @@ func TestMovieQueryURL(t *testing.T) {
}
func verifyMovieQuery(t *testing.T, filter models.MovieFilterType, verifyFn func(s *models.Movie)) {
- withTxn(func(r models.Repository) error {
+ withTxn(func(ctx context.Context) error {
t.Helper()
- sqb := r.Movie()
+ sqb := sqlite.MovieReaderWriter
- movies := queryMovie(t, sqb, &filter, nil)
+ movies := queryMovie(ctx, t, sqb, &filter, nil)
// assume it should find at least one
assert.Greater(t, len(movies), 0)
@@ -176,8 +178,8 @@ func verifyMovieQuery(t *testing.T, filter models.MovieFilterType, verifyFn func
})
}
-func queryMovie(t *testing.T, sqb models.MovieReader, movieFilter *models.MovieFilterType, findFilter *models.FindFilterType) []*models.Movie {
- movies, _, err := sqb.Query(movieFilter, findFilter)
+func queryMovie(ctx context.Context, t *testing.T, sqb models.MovieReader, movieFilter *models.MovieFilterType, findFilter *models.FindFilterType) []*models.Movie {
+ movies, _, err := sqb.Query(ctx, movieFilter, findFilter)
if err != nil {
t.Errorf("Error querying movie: %s", err.Error())
}
@@ -193,9 +195,9 @@ func TestMovieQuerySorting(t *testing.T) {
Direction: &direction,
}
- withTxn(func(r models.Repository) error {
- sqb := r.Movie()
- movies := queryMovie(t, sqb, nil, &findFilter)
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.MovieReaderWriter
+ movies := queryMovie(ctx, t, sqb, nil, &findFilter)
// scenes should be in same order as indexes
firstMovie := movies[0]
@@ -205,7 +207,7 @@ func TestMovieQuerySorting(t *testing.T) {
// sort in descending order
direction = models.SortDirectionEnumAsc
- movies = queryMovie(t, sqb, nil, &findFilter)
+ movies = queryMovie(ctx, t, sqb, nil, &findFilter)
lastMovie := movies[len(movies)-1]
assert.Equal(t, movieIDs[movieIdxWithScene], lastMovie.ID)
@@ -215,8 +217,8 @@ func TestMovieQuerySorting(t *testing.T) {
}
func TestMovieUpdateMovieImages(t *testing.T) {
- if err := withTxn(func(r models.Repository) error {
- mqb := r.Movie()
+ if err := withTxn(func(ctx context.Context) error {
+ mqb := sqlite.MovieReaderWriter
// create movie to test against
const name = "TestMovieUpdateMovieImages"
@@ -224,26 +226,26 @@ func TestMovieUpdateMovieImages(t *testing.T) {
Name: sql.NullString{String: name, Valid: true},
Checksum: md5.FromString(name),
}
- created, err := mqb.Create(movie)
+ created, err := mqb.Create(ctx, movie)
if err != nil {
return fmt.Errorf("Error creating movie: %s", err.Error())
}
frontImage := []byte("frontImage")
backImage := []byte("backImage")
- err = mqb.UpdateImages(created.ID, frontImage, backImage)
+ err = mqb.UpdateImages(ctx, created.ID, frontImage, backImage)
if err != nil {
return fmt.Errorf("Error updating movie images: %s", err.Error())
}
// ensure images are set
- storedFront, err := mqb.GetFrontImage(created.ID)
+ storedFront, err := mqb.GetFrontImage(ctx, created.ID)
if err != nil {
return fmt.Errorf("Error getting front image: %s", err.Error())
}
assert.Equal(t, storedFront, frontImage)
- storedBack, err := mqb.GetBackImage(created.ID)
+ storedBack, err := mqb.GetBackImage(ctx, created.ID)
if err != nil {
return fmt.Errorf("Error getting back image: %s", err.Error())
}
@@ -251,26 +253,26 @@ func TestMovieUpdateMovieImages(t *testing.T) {
// set front image only
newImage := []byte("newImage")
- err = mqb.UpdateImages(created.ID, newImage, nil)
+ err = mqb.UpdateImages(ctx, created.ID, newImage, nil)
if err != nil {
return fmt.Errorf("Error updating movie images: %s", err.Error())
}
- storedFront, err = mqb.GetFrontImage(created.ID)
+ storedFront, err = mqb.GetFrontImage(ctx, created.ID)
if err != nil {
return fmt.Errorf("Error getting front image: %s", err.Error())
}
assert.Equal(t, storedFront, newImage)
// back image should be nil
- storedBack, err = mqb.GetBackImage(created.ID)
+ storedBack, err = mqb.GetBackImage(ctx, created.ID)
if err != nil {
return fmt.Errorf("Error getting back image: %s", err.Error())
}
assert.Nil(t, nil)
// set back image only
- err = mqb.UpdateImages(created.ID, nil, newImage)
+ err = mqb.UpdateImages(ctx, created.ID, nil, newImage)
if err == nil {
return fmt.Errorf("Expected error setting nil front image")
}
@@ -282,8 +284,8 @@ func TestMovieUpdateMovieImages(t *testing.T) {
}
func TestMovieDestroyMovieImages(t *testing.T) {
- if err := withTxn(func(r models.Repository) error {
- mqb := r.Movie()
+ if err := withTxn(func(ctx context.Context) error {
+ mqb := sqlite.MovieReaderWriter
// create movie to test against
const name = "TestMovieDestroyMovieImages"
@@ -291,32 +293,32 @@ func TestMovieDestroyMovieImages(t *testing.T) {
Name: sql.NullString{String: name, Valid: true},
Checksum: md5.FromString(name),
}
- created, err := mqb.Create(movie)
+ created, err := mqb.Create(ctx, movie)
if err != nil {
return fmt.Errorf("Error creating movie: %s", err.Error())
}
frontImage := []byte("frontImage")
backImage := []byte("backImage")
- err = mqb.UpdateImages(created.ID, frontImage, backImage)
+ err = mqb.UpdateImages(ctx, created.ID, frontImage, backImage)
if err != nil {
return fmt.Errorf("Error updating movie images: %s", err.Error())
}
- err = mqb.DestroyImages(created.ID)
+ err = mqb.DestroyImages(ctx, created.ID)
if err != nil {
return fmt.Errorf("Error destroying movie images: %s", err.Error())
}
// front image should be nil
- storedFront, err := mqb.GetFrontImage(created.ID)
+ storedFront, err := mqb.GetFrontImage(ctx, created.ID)
if err != nil {
return fmt.Errorf("Error getting front image: %s", err.Error())
}
assert.Nil(t, storedFront)
// back image should be nil
- storedBack, err := mqb.GetBackImage(created.ID)
+ storedBack, err := mqb.GetBackImage(ctx, created.ID)
if err != nil {
return fmt.Errorf("Error getting back image: %s", err.Error())
}
diff --git a/pkg/sqlite/performer.go b/pkg/sqlite/performer.go
index 142be42ff..d81170f0c 100644
--- a/pkg/sqlite/performer.go
+++ b/pkg/sqlite/performer.go
@@ -1,6 +1,7 @@
package sqlite
import (
+ "context"
"database/sql"
"errors"
"fmt"
@@ -25,66 +26,63 @@ type performerQueryBuilder struct {
repository
}
-func NewPerformerReaderWriter(tx dbi) *performerQueryBuilder {
- return &performerQueryBuilder{
- repository{
- tx: tx,
- tableName: performerTable,
- idColumn: idColumn,
- },
- }
+var PerformerReaderWriter = &performerQueryBuilder{
+ repository{
+ tableName: performerTable,
+ idColumn: idColumn,
+ },
}
-func (qb *performerQueryBuilder) Create(newObject models.Performer) (*models.Performer, error) {
+func (qb *performerQueryBuilder) Create(ctx context.Context, newObject models.Performer) (*models.Performer, error) {
var ret models.Performer
- if err := qb.insertObject(newObject, &ret); err != nil {
+ if err := qb.insertObject(ctx, newObject, &ret); err != nil {
return nil, err
}
return &ret, nil
}
-func (qb *performerQueryBuilder) Update(updatedObject models.PerformerPartial) (*models.Performer, error) {
+func (qb *performerQueryBuilder) Update(ctx context.Context, updatedObject models.PerformerPartial) (*models.Performer, error) {
const partial = true
- if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil {
+ if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil {
return nil, err
}
var ret models.Performer
- if err := qb.get(updatedObject.ID, &ret); err != nil {
+ if err := qb.getByID(ctx, updatedObject.ID, &ret); err != nil {
return nil, err
}
return &ret, nil
}
-func (qb *performerQueryBuilder) UpdateFull(updatedObject models.Performer) (*models.Performer, error) {
+func (qb *performerQueryBuilder) UpdateFull(ctx context.Context, updatedObject models.Performer) (*models.Performer, error) {
const partial = false
- if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil {
+ if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil {
return nil, err
}
var ret models.Performer
- if err := qb.get(updatedObject.ID, &ret); err != nil {
+ if err := qb.getByID(ctx, updatedObject.ID, &ret); err != nil {
return nil, err
}
return &ret, nil
}
-func (qb *performerQueryBuilder) Destroy(id int) error {
+func (qb *performerQueryBuilder) Destroy(ctx context.Context, id int) error {
// TODO - add on delete cascade to performers_scenes
- _, err := qb.tx.Exec("DELETE FROM performers_scenes WHERE performer_id = ?", id)
+ _, err := qb.tx.Exec(ctx, "DELETE FROM performers_scenes WHERE performer_id = ?", id)
if err != nil {
return err
}
- return qb.destroyExisting([]int{id})
+ return qb.destroyExisting(ctx, []int{id})
}
-func (qb *performerQueryBuilder) Find(id int) (*models.Performer, error) {
+func (qb *performerQueryBuilder) Find(ctx context.Context, id int) (*models.Performer, error) {
var ret models.Performer
- if err := qb.get(id, &ret); err != nil {
+ if err := qb.getByID(ctx, id, &ret); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
@@ -93,10 +91,10 @@ func (qb *performerQueryBuilder) Find(id int) (*models.Performer, error) {
return &ret, nil
}
-func (qb *performerQueryBuilder) FindMany(ids []int) ([]*models.Performer, error) {
+func (qb *performerQueryBuilder) FindMany(ctx context.Context, ids []int) ([]*models.Performer, error) {
var performers []*models.Performer
for _, id := range ids {
- performer, err := qb.Find(id)
+ performer, err := qb.Find(ctx, id)
if err != nil {
return nil, err
}
@@ -111,44 +109,44 @@ func (qb *performerQueryBuilder) FindMany(ids []int) ([]*models.Performer, error
return performers, nil
}
-func (qb *performerQueryBuilder) FindBySceneID(sceneID int) ([]*models.Performer, error) {
+func (qb *performerQueryBuilder) FindBySceneID(ctx context.Context, sceneID int) ([]*models.Performer, error) {
query := selectAll("performers") + `
LEFT JOIN performers_scenes as scenes_join on scenes_join.performer_id = performers.id
WHERE scenes_join.scene_id = ?
`
args := []interface{}{sceneID}
- return qb.queryPerformers(query, args)
+ return qb.queryPerformers(ctx, query, args)
}
-func (qb *performerQueryBuilder) FindByImageID(imageID int) ([]*models.Performer, error) {
+func (qb *performerQueryBuilder) FindByImageID(ctx context.Context, imageID int) ([]*models.Performer, error) {
query := selectAll("performers") + `
LEFT JOIN performers_images as images_join on images_join.performer_id = performers.id
WHERE images_join.image_id = ?
`
args := []interface{}{imageID}
- return qb.queryPerformers(query, args)
+ return qb.queryPerformers(ctx, query, args)
}
-func (qb *performerQueryBuilder) FindByGalleryID(galleryID int) ([]*models.Performer, error) {
+func (qb *performerQueryBuilder) FindByGalleryID(ctx context.Context, galleryID int) ([]*models.Performer, error) {
query := selectAll("performers") + `
LEFT JOIN performers_galleries as galleries_join on galleries_join.performer_id = performers.id
WHERE galleries_join.gallery_id = ?
`
args := []interface{}{galleryID}
- return qb.queryPerformers(query, args)
+ return qb.queryPerformers(ctx, query, args)
}
-func (qb *performerQueryBuilder) FindNamesBySceneID(sceneID int) ([]*models.Performer, error) {
+func (qb *performerQueryBuilder) FindNamesBySceneID(ctx context.Context, sceneID int) ([]*models.Performer, error) {
query := `
SELECT performers.name FROM performers
LEFT JOIN performers_scenes as scenes_join on scenes_join.performer_id = performers.id
WHERE scenes_join.scene_id = ?
`
args := []interface{}{sceneID}
- return qb.queryPerformers(query, args)
+ return qb.queryPerformers(ctx, query, args)
}
-func (qb *performerQueryBuilder) FindByNames(names []string, nocase bool) ([]*models.Performer, error) {
+func (qb *performerQueryBuilder) FindByNames(ctx context.Context, names []string, nocase bool) ([]*models.Performer, error) {
query := "SELECT * FROM performers WHERE name"
if nocase {
query += " COLLATE NOCASE"
@@ -159,23 +157,23 @@ func (qb *performerQueryBuilder) FindByNames(names []string, nocase bool) ([]*mo
for _, name := range names {
args = append(args, name)
}
- return qb.queryPerformers(query, args)
+ return qb.queryPerformers(ctx, query, args)
}
-func (qb *performerQueryBuilder) CountByTagID(tagID int) (int, error) {
+func (qb *performerQueryBuilder) CountByTagID(ctx context.Context, tagID int) (int, error) {
args := []interface{}{tagID}
- return qb.runCountQuery(qb.buildCountQuery(countPerformersForTagQuery), args)
+ return qb.runCountQuery(ctx, qb.buildCountQuery(countPerformersForTagQuery), args)
}
-func (qb *performerQueryBuilder) Count() (int, error) {
- return qb.runCountQuery(qb.buildCountQuery("SELECT performers.id FROM performers"), nil)
+func (qb *performerQueryBuilder) Count(ctx context.Context) (int, error) {
+ return qb.runCountQuery(ctx, qb.buildCountQuery("SELECT performers.id FROM performers"), nil)
}
-func (qb *performerQueryBuilder) All() ([]*models.Performer, error) {
- return qb.queryPerformers(selectAll("performers")+qb.getPerformerSort(nil), nil)
+func (qb *performerQueryBuilder) All(ctx context.Context) ([]*models.Performer, error) {
+ return qb.queryPerformers(ctx, selectAll("performers")+qb.getPerformerSort(nil), nil)
}
-func (qb *performerQueryBuilder) QueryForAutoTag(words []string) ([]*models.Performer, error) {
+func (qb *performerQueryBuilder) QueryForAutoTag(ctx context.Context, words []string) ([]*models.Performer, error) {
// TODO - Query needs to be changed to support queries of this type, and
// this method should be removed
query := selectAll(performerTable)
@@ -196,7 +194,7 @@ func (qb *performerQueryBuilder) QueryForAutoTag(words []string) ([]*models.Perf
"ignore_auto_tag = 0",
whereOr,
}, " AND ")
- return qb.queryPerformers(query+" WHERE "+where, args)
+ return qb.queryPerformers(ctx, query+" WHERE "+where, args)
}
func (qb *performerQueryBuilder) validateFilter(filter *models.PerformerFilterType) error {
@@ -230,74 +228,74 @@ func (qb *performerQueryBuilder) validateFilter(filter *models.PerformerFilterTy
return nil
}
-func (qb *performerQueryBuilder) makeFilter(filter *models.PerformerFilterType) *filterBuilder {
+func (qb *performerQueryBuilder) makeFilter(ctx context.Context, filter *models.PerformerFilterType) *filterBuilder {
query := &filterBuilder{}
if filter.And != nil {
- query.and(qb.makeFilter(filter.And))
+ query.and(qb.makeFilter(ctx, filter.And))
}
if filter.Or != nil {
- query.or(qb.makeFilter(filter.Or))
+ query.or(qb.makeFilter(ctx, filter.Or))
}
if filter.Not != nil {
- query.not(qb.makeFilter(filter.Not))
+ query.not(qb.makeFilter(ctx, filter.Not))
}
const tableName = performerTable
- query.handleCriterion(stringCriterionHandler(filter.Name, tableName+".name"))
- query.handleCriterion(stringCriterionHandler(filter.Details, tableName+".details"))
+ query.handleCriterion(ctx, stringCriterionHandler(filter.Name, tableName+".name"))
+ query.handleCriterion(ctx, stringCriterionHandler(filter.Details, tableName+".details"))
- query.handleCriterion(boolCriterionHandler(filter.FilterFavorites, tableName+".favorite"))
- query.handleCriterion(boolCriterionHandler(filter.IgnoreAutoTag, tableName+".ignore_auto_tag"))
+ query.handleCriterion(ctx, boolCriterionHandler(filter.FilterFavorites, tableName+".favorite"))
+ query.handleCriterion(ctx, boolCriterionHandler(filter.IgnoreAutoTag, tableName+".ignore_auto_tag"))
- query.handleCriterion(yearFilterCriterionHandler(filter.BirthYear, tableName+".birthdate"))
- query.handleCriterion(yearFilterCriterionHandler(filter.DeathYear, tableName+".death_date"))
+ query.handleCriterion(ctx, yearFilterCriterionHandler(filter.BirthYear, tableName+".birthdate"))
+ query.handleCriterion(ctx, yearFilterCriterionHandler(filter.DeathYear, tableName+".death_date"))
- query.handleCriterion(performerAgeFilterCriterionHandler(filter.Age))
+ query.handleCriterion(ctx, performerAgeFilterCriterionHandler(filter.Age))
- query.handleCriterion(criterionHandlerFunc(func(f *filterBuilder) {
+ query.handleCriterion(ctx, criterionHandlerFunc(func(ctx context.Context, f *filterBuilder) {
if gender := filter.Gender; gender != nil {
f.addWhere(tableName+".gender = ?", gender.Value.String())
}
}))
- query.handleCriterion(performerIsMissingCriterionHandler(qb, filter.IsMissing))
- query.handleCriterion(stringCriterionHandler(filter.Ethnicity, tableName+".ethnicity"))
- query.handleCriterion(stringCriterionHandler(filter.Country, tableName+".country"))
- query.handleCriterion(stringCriterionHandler(filter.EyeColor, tableName+".eye_color"))
- query.handleCriterion(stringCriterionHandler(filter.Height, tableName+".height"))
- query.handleCriterion(stringCriterionHandler(filter.Measurements, tableName+".measurements"))
- query.handleCriterion(stringCriterionHandler(filter.FakeTits, tableName+".fake_tits"))
- query.handleCriterion(stringCriterionHandler(filter.CareerLength, tableName+".career_length"))
- query.handleCriterion(stringCriterionHandler(filter.Tattoos, tableName+".tattoos"))
- query.handleCriterion(stringCriterionHandler(filter.Piercings, tableName+".piercings"))
- query.handleCriterion(intCriterionHandler(filter.Rating, tableName+".rating"))
- query.handleCriterion(stringCriterionHandler(filter.HairColor, tableName+".hair_color"))
- query.handleCriterion(stringCriterionHandler(filter.URL, tableName+".url"))
- query.handleCriterion(intCriterionHandler(filter.Weight, tableName+".weight"))
- query.handleCriterion(criterionHandlerFunc(func(f *filterBuilder) {
+ query.handleCriterion(ctx, performerIsMissingCriterionHandler(qb, filter.IsMissing))
+ query.handleCriterion(ctx, stringCriterionHandler(filter.Ethnicity, tableName+".ethnicity"))
+ query.handleCriterion(ctx, stringCriterionHandler(filter.Country, tableName+".country"))
+ query.handleCriterion(ctx, stringCriterionHandler(filter.EyeColor, tableName+".eye_color"))
+ query.handleCriterion(ctx, stringCriterionHandler(filter.Height, tableName+".height"))
+ query.handleCriterion(ctx, stringCriterionHandler(filter.Measurements, tableName+".measurements"))
+ query.handleCriterion(ctx, stringCriterionHandler(filter.FakeTits, tableName+".fake_tits"))
+ query.handleCriterion(ctx, stringCriterionHandler(filter.CareerLength, tableName+".career_length"))
+ query.handleCriterion(ctx, stringCriterionHandler(filter.Tattoos, tableName+".tattoos"))
+ query.handleCriterion(ctx, stringCriterionHandler(filter.Piercings, tableName+".piercings"))
+ query.handleCriterion(ctx, intCriterionHandler(filter.Rating, tableName+".rating"))
+ query.handleCriterion(ctx, stringCriterionHandler(filter.HairColor, tableName+".hair_color"))
+ query.handleCriterion(ctx, stringCriterionHandler(filter.URL, tableName+".url"))
+ query.handleCriterion(ctx, intCriterionHandler(filter.Weight, tableName+".weight"))
+ query.handleCriterion(ctx, criterionHandlerFunc(func(ctx context.Context, f *filterBuilder) {
if filter.StashID != nil {
qb.stashIDRepository().join(f, "performer_stash_ids", "performers.id")
- stringCriterionHandler(filter.StashID, "performer_stash_ids.stash_id")(f)
+ stringCriterionHandler(filter.StashID, "performer_stash_ids.stash_id")(ctx, f)
}
}))
// TODO - need better handling of aliases
- query.handleCriterion(stringCriterionHandler(filter.Aliases, tableName+".aliases"))
+ query.handleCriterion(ctx, stringCriterionHandler(filter.Aliases, tableName+".aliases"))
- query.handleCriterion(performerTagsCriterionHandler(qb, filter.Tags))
+ query.handleCriterion(ctx, performerTagsCriterionHandler(qb, filter.Tags))
- query.handleCriterion(performerStudiosCriterionHandler(qb, filter.Studios))
+ query.handleCriterion(ctx, performerStudiosCriterionHandler(qb, filter.Studios))
- query.handleCriterion(performerTagCountCriterionHandler(qb, filter.TagCount))
- query.handleCriterion(performerSceneCountCriterionHandler(qb, filter.SceneCount))
- query.handleCriterion(performerImageCountCriterionHandler(qb, filter.ImageCount))
- query.handleCriterion(performerGalleryCountCriterionHandler(qb, filter.GalleryCount))
+ query.handleCriterion(ctx, performerTagCountCriterionHandler(qb, filter.TagCount))
+ query.handleCriterion(ctx, performerSceneCountCriterionHandler(qb, filter.SceneCount))
+ query.handleCriterion(ctx, performerImageCountCriterionHandler(qb, filter.ImageCount))
+ query.handleCriterion(ctx, performerGalleryCountCriterionHandler(qb, filter.GalleryCount))
return query
}
-func (qb *performerQueryBuilder) Query(performerFilter *models.PerformerFilterType, findFilter *models.FindFilterType) ([]*models.Performer, int, error) {
+func (qb *performerQueryBuilder) Query(ctx context.Context, performerFilter *models.PerformerFilterType, findFilter *models.FindFilterType) ([]*models.Performer, int, error) {
if performerFilter == nil {
performerFilter = &models.PerformerFilterType{}
}
@@ -316,19 +314,19 @@ func (qb *performerQueryBuilder) Query(performerFilter *models.PerformerFilterTy
if err := qb.validateFilter(performerFilter); err != nil {
return nil, 0, err
}
- filter := qb.makeFilter(performerFilter)
+ filter := qb.makeFilter(ctx, performerFilter)
query.addFilter(filter)
query.sortAndPagination = qb.getPerformerSort(findFilter) + getPagination(findFilter)
- idsResult, countResult, err := query.executeFind()
+ idsResult, countResult, err := query.executeFind(ctx)
if err != nil {
return nil, 0, err
}
var performers []*models.Performer
for _, id := range idsResult {
- performer, err := qb.Find(id)
+ performer, err := qb.Find(ctx, id)
if err != nil {
return nil, 0, err
}
@@ -339,7 +337,7 @@ func (qb *performerQueryBuilder) Query(performerFilter *models.PerformerFilterTy
}
func performerIsMissingCriterionHandler(qb *performerQueryBuilder, isMissing *string) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if isMissing != nil && *isMissing != "" {
switch *isMissing {
case "scenes": // Deprecated: use `scene_count == 0` filter instead
@@ -359,7 +357,7 @@ func performerIsMissingCriterionHandler(qb *performerQueryBuilder, isMissing *st
}
func yearFilterCriterionHandler(year *models.IntCriterionInput, col string) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if year != nil && year.Modifier.IsValid() {
clause, args := getIntCriterionWhereClause("cast(strftime('%Y', "+col+") as int)", *year)
f.addWhere(clause, args...)
@@ -368,7 +366,7 @@ func yearFilterCriterionHandler(year *models.IntCriterionInput, col string) crit
}
func performerAgeFilterCriterionHandler(age *models.IntCriterionInput) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if age != nil && age.Modifier.IsValid() {
clause, args := getIntCriterionWhereClause(
"cast(IFNULL(strftime('%Y.%m%d', performers.death_date), strftime('%Y.%m%d', 'now')) - strftime('%Y.%m%d', performers.birthdate) as int)",
@@ -437,7 +435,7 @@ func performerGalleryCountCriterionHandler(qb *performerQueryBuilder, count *mod
}
func performerStudiosCriterionHandler(qb *performerQueryBuilder, studios *models.HierarchicalMultiCriterionInput) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if studios != nil {
formatMaps := []utils.StrFormatMap{
{
@@ -493,7 +491,7 @@ func performerStudiosCriterionHandler(qb *performerQueryBuilder, studios *models
}
const derivedPerformerStudioTable = "performer_studio"
- valuesClause := getHierarchicalValues(qb.tx, studios.Value, studioTable, "", "parent_id", studios.Depth)
+ valuesClause := getHierarchicalValues(ctx, qb.tx, studios.Value, studioTable, "", "parent_id", studios.Depth)
f.addWith("studio(root_id, item_id) AS (" + valuesClause + ")")
templStr := `SELECT performer_id FROM {primaryTable}
@@ -540,9 +538,9 @@ func (qb *performerQueryBuilder) getPerformerSort(findFilter *models.FindFilterT
return getSort(sort, direction, "performers")
}
-func (qb *performerQueryBuilder) queryPerformers(query string, args []interface{}) ([]*models.Performer, error) {
+func (qb *performerQueryBuilder) queryPerformers(ctx context.Context, query string, args []interface{}) ([]*models.Performer, error) {
var ret models.Performers
- if err := qb.query(query, args, &ret); err != nil {
+ if err := qb.query(ctx, query, args, &ret); err != nil {
return nil, err
}
@@ -560,13 +558,13 @@ func (qb *performerQueryBuilder) tagsRepository() *joinRepository {
}
}
-func (qb *performerQueryBuilder) GetTagIDs(id int) ([]int, error) {
- return qb.tagsRepository().getIDs(id)
+func (qb *performerQueryBuilder) GetTagIDs(ctx context.Context, id int) ([]int, error) {
+ return qb.tagsRepository().getIDs(ctx, id)
}
-func (qb *performerQueryBuilder) UpdateTags(id int, tagIDs []int) error {
+func (qb *performerQueryBuilder) UpdateTags(ctx context.Context, id int, tagIDs []int) error {
// Delete the existing joins and then create new ones
- return qb.tagsRepository().replace(id, tagIDs)
+ return qb.tagsRepository().replace(ctx, id, tagIDs)
}
func (qb *performerQueryBuilder) imageRepository() *imageRepository {
@@ -580,16 +578,16 @@ func (qb *performerQueryBuilder) imageRepository() *imageRepository {
}
}
-func (qb *performerQueryBuilder) GetImage(performerID int) ([]byte, error) {
- return qb.imageRepository().get(performerID)
+func (qb *performerQueryBuilder) GetImage(ctx context.Context, performerID int) ([]byte, error) {
+ return qb.imageRepository().get(ctx, performerID)
}
-func (qb *performerQueryBuilder) UpdateImage(performerID int, image []byte) error {
- return qb.imageRepository().replace(performerID, image)
+func (qb *performerQueryBuilder) UpdateImage(ctx context.Context, performerID int, image []byte) error {
+ return qb.imageRepository().replace(ctx, performerID, image)
}
-func (qb *performerQueryBuilder) DestroyImage(performerID int) error {
- return qb.imageRepository().destroy([]int{performerID})
+func (qb *performerQueryBuilder) DestroyImage(ctx context.Context, performerID int) error {
+ return qb.imageRepository().destroy(ctx, []int{performerID})
}
func (qb *performerQueryBuilder) stashIDRepository() *stashIDRepository {
@@ -602,25 +600,25 @@ func (qb *performerQueryBuilder) stashIDRepository() *stashIDRepository {
}
}
-func (qb *performerQueryBuilder) GetStashIDs(performerID int) ([]*models.StashID, error) {
- return qb.stashIDRepository().get(performerID)
+func (qb *performerQueryBuilder) GetStashIDs(ctx context.Context, performerID int) ([]*models.StashID, error) {
+ return qb.stashIDRepository().get(ctx, performerID)
}
-func (qb *performerQueryBuilder) UpdateStashIDs(performerID int, stashIDs []models.StashID) error {
- return qb.stashIDRepository().replace(performerID, stashIDs)
+func (qb *performerQueryBuilder) UpdateStashIDs(ctx context.Context, performerID int, stashIDs []models.StashID) error {
+ return qb.stashIDRepository().replace(ctx, performerID, stashIDs)
}
-func (qb *performerQueryBuilder) FindByStashID(stashID models.StashID) ([]*models.Performer, error) {
+func (qb *performerQueryBuilder) FindByStashID(ctx context.Context, stashID models.StashID) ([]*models.Performer, error) {
query := selectAll("performers") + `
LEFT JOIN performer_stash_ids on performer_stash_ids.performer_id = performers.id
WHERE performer_stash_ids.stash_id = ?
AND performer_stash_ids.endpoint = ?
`
args := []interface{}{stashID.StashID, stashID.Endpoint}
- return qb.queryPerformers(query, args)
+ return qb.queryPerformers(ctx, query, args)
}
-func (qb *performerQueryBuilder) FindByStashIDStatus(hasStashID bool, stashboxEndpoint string) ([]*models.Performer, error) {
+func (qb *performerQueryBuilder) FindByStashIDStatus(ctx context.Context, hasStashID bool, stashboxEndpoint string) ([]*models.Performer, error) {
query := selectAll("performers") + `
LEFT JOIN performer_stash_ids on performer_stash_ids.performer_id = performers.id
`
@@ -637,5 +635,5 @@ func (qb *performerQueryBuilder) FindByStashIDStatus(hasStashID bool, stashboxEn
}
args := []interface{}{stashboxEndpoint}
- return qb.queryPerformers(query, args)
+ return qb.queryPerformers(ctx, query, args)
}
diff --git a/pkg/sqlite/performer_test.go b/pkg/sqlite/performer_test.go
index a6839f573..7be6eb4fd 100644
--- a/pkg/sqlite/performer_test.go
+++ b/pkg/sqlite/performer_test.go
@@ -4,6 +4,7 @@
package sqlite_test
import (
+ "context"
"database/sql"
"fmt"
"math"
@@ -16,14 +17,15 @@ import (
"github.com/stashapp/stash/pkg/hash/md5"
"github.com/stashapp/stash/pkg/models"
+ "github.com/stashapp/stash/pkg/sqlite"
)
func TestPerformerFindBySceneID(t *testing.T) {
- withTxn(func(r models.Repository) error {
- pqb := r.Performer()
+ withTxn(func(ctx context.Context) error {
+ pqb := sqlite.PerformerReaderWriter
sceneID := sceneIDs[sceneIdxWithPerformer]
- performers, err := pqb.FindBySceneID(sceneID)
+ performers, err := pqb.FindBySceneID(ctx, sceneID)
if err != nil {
t.Errorf("Error finding performer: %s", err.Error())
@@ -34,7 +36,7 @@ func TestPerformerFindBySceneID(t *testing.T) {
assert.Equal(t, getPerformerStringValue(performerIdxWithScene, "Name"), performer.Name.String)
- performers, err = pqb.FindBySceneID(0)
+ performers, err = pqb.FindBySceneID(ctx, 0)
if err != nil {
t.Errorf("Error finding performer: %s", err.Error())
@@ -55,21 +57,21 @@ func TestPerformerFindByNames(t *testing.T) {
return ret
}
- withTxn(func(r models.Repository) error {
+ withTxn(func(ctx context.Context) error {
var names []string
- pqb := r.Performer()
+ pqb := sqlite.PerformerReaderWriter
names = append(names, performerNames[performerIdxWithScene]) // find performers by names
- performers, err := pqb.FindByNames(names, false)
+ performers, err := pqb.FindByNames(ctx, names, false)
if err != nil {
t.Errorf("Error finding performers: %s", err.Error())
}
assert.Len(t, performers, 1)
assert.Equal(t, performerNames[performerIdxWithScene], performers[0].Name.String)
- performers, err = pqb.FindByNames(names, true) // find performers by names nocase
+ performers, err = pqb.FindByNames(ctx, names, true) // find performers by names nocase
if err != nil {
t.Errorf("Error finding performers: %s", err.Error())
}
@@ -79,14 +81,14 @@ func TestPerformerFindByNames(t *testing.T) {
names = append(names, performerNames[performerIdx1WithScene]) // find performers by names ( 2 names )
- performers, err = pqb.FindByNames(names, false)
+ performers, err = pqb.FindByNames(ctx, names, false)
if err != nil {
t.Errorf("Error finding performers: %s", err.Error())
}
retNames := getNames(performers)
assert.Equal(t, names, retNames)
- performers, err = pqb.FindByNames(names, true) // find performers by names ( 2 names nocase)
+ performers, err = pqb.FindByNames(ctx, names, true) // find performers by names ( 2 names nocase)
if err != nil {
t.Errorf("Error finding performers: %s", err.Error())
}
@@ -122,10 +124,10 @@ func TestPerformerQueryEthnicityOr(t *testing.T) {
},
}
- withTxn(func(r models.Repository) error {
- sqb := r.Performer()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.PerformerReaderWriter
- performers := queryPerformers(t, sqb, &performerFilter, nil)
+ performers := queryPerformers(ctx, t, sqb, &performerFilter, nil)
assert.Len(t, performers, 2)
assert.Equal(t, performer1Eth, performers[0].Ethnicity.String)
@@ -153,10 +155,10 @@ func TestPerformerQueryEthnicityAndRating(t *testing.T) {
},
}
- withTxn(func(r models.Repository) error {
- sqb := r.Performer()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.PerformerReaderWriter
- performers := queryPerformers(t, sqb, &performerFilter, nil)
+ performers := queryPerformers(ctx, t, sqb, &performerFilter, nil)
assert.Len(t, performers, 1)
assert.Equal(t, performerEth, performers[0].Ethnicity.String)
@@ -188,10 +190,10 @@ func TestPerformerQueryEthnicityNotRating(t *testing.T) {
},
}
- withTxn(func(r models.Repository) error {
- sqb := r.Performer()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.PerformerReaderWriter
- performers := queryPerformers(t, sqb, &performerFilter, nil)
+ performers := queryPerformers(ctx, t, sqb, &performerFilter, nil)
for _, performer := range performers {
verifyString(t, performer.Ethnicity.String, ethCriterion)
@@ -219,20 +221,20 @@ func TestPerformerIllegalQuery(t *testing.T) {
Or: &subFilter,
}
- withTxn(func(r models.Repository) error {
- sqb := r.Performer()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.PerformerReaderWriter
- _, _, err := sqb.Query(performerFilter, nil)
+ _, _, err := sqb.Query(ctx, performerFilter, nil)
assert.NotNil(err)
performerFilter.Or = nil
performerFilter.Not = &subFilter
- _, _, err = sqb.Query(performerFilter, nil)
+ _, _, err = sqb.Query(ctx, performerFilter, nil)
assert.NotNil(err)
performerFilter.And = nil
performerFilter.Or = &subFilter
- _, _, err = sqb.Query(performerFilter, nil)
+ _, _, err = sqb.Query(ctx, performerFilter, nil)
assert.NotNil(err)
return nil
@@ -240,15 +242,15 @@ func TestPerformerIllegalQuery(t *testing.T) {
}
func TestPerformerQueryIgnoreAutoTag(t *testing.T) {
- withTxn(func(r models.Repository) error {
+ withTxn(func(ctx context.Context) error {
ignoreAutoTag := true
performerFilter := models.PerformerFilterType{
IgnoreAutoTag: &ignoreAutoTag,
}
- sqb := r.Performer()
+ sqb := sqlite.PerformerReaderWriter
- performers := queryPerformers(t, sqb, &performerFilter, nil)
+ performers := queryPerformers(ctx, t, sqb, &performerFilter, nil)
assert.Len(t, performers, int(math.Ceil(float64(totalPerformers)/5)))
for _, p := range performers {
@@ -260,12 +262,12 @@ func TestPerformerQueryIgnoreAutoTag(t *testing.T) {
}
func TestPerformerQueryForAutoTag(t *testing.T) {
- withTxn(func(r models.Repository) error {
- tqb := r.Performer()
+ withTxn(func(ctx context.Context) error {
+ tqb := sqlite.PerformerReaderWriter
name := performerNames[performerIdx1WithScene] // find a performer by name
- performers, err := tqb.QueryForAutoTag([]string{name})
+ performers, err := tqb.QueryForAutoTag(ctx, []string{name})
if err != nil {
t.Errorf("Error finding performers: %s", err.Error())
@@ -280,8 +282,8 @@ func TestPerformerQueryForAutoTag(t *testing.T) {
}
func TestPerformerUpdatePerformerImage(t *testing.T) {
- if err := withTxn(func(r models.Repository) error {
- qb := r.Performer()
+ if err := withTxn(func(ctx context.Context) error {
+ qb := sqlite.PerformerReaderWriter
// create performer to test against
const name = "TestPerformerUpdatePerformerImage"
@@ -290,26 +292,26 @@ func TestPerformerUpdatePerformerImage(t *testing.T) {
Checksum: md5.FromString(name),
Favorite: sql.NullBool{Bool: false, Valid: true},
}
- created, err := qb.Create(performer)
+ created, err := qb.Create(ctx, performer)
if err != nil {
return fmt.Errorf("Error creating performer: %s", err.Error())
}
image := []byte("image")
- err = qb.UpdateImage(created.ID, image)
+ err = qb.UpdateImage(ctx, created.ID, image)
if err != nil {
return fmt.Errorf("Error updating performer image: %s", err.Error())
}
// ensure image set
- storedImage, err := qb.GetImage(created.ID)
+ storedImage, err := qb.GetImage(ctx, created.ID)
if err != nil {
return fmt.Errorf("Error getting image: %s", err.Error())
}
assert.Equal(t, storedImage, image)
// set nil image
- err = qb.UpdateImage(created.ID, nil)
+ err = qb.UpdateImage(ctx, created.ID, nil)
if err == nil {
return fmt.Errorf("Expected error setting nil image")
}
@@ -321,8 +323,8 @@ func TestPerformerUpdatePerformerImage(t *testing.T) {
}
func TestPerformerDestroyPerformerImage(t *testing.T) {
- if err := withTxn(func(r models.Repository) error {
- qb := r.Performer()
+ if err := withTxn(func(ctx context.Context) error {
+ qb := sqlite.PerformerReaderWriter
// create performer to test against
const name = "TestPerformerDestroyPerformerImage"
@@ -331,24 +333,24 @@ func TestPerformerDestroyPerformerImage(t *testing.T) {
Checksum: md5.FromString(name),
Favorite: sql.NullBool{Bool: false, Valid: true},
}
- created, err := qb.Create(performer)
+ created, err := qb.Create(ctx, performer)
if err != nil {
return fmt.Errorf("Error creating performer: %s", err.Error())
}
image := []byte("image")
- err = qb.UpdateImage(created.ID, image)
+ err = qb.UpdateImage(ctx, created.ID, image)
if err != nil {
return fmt.Errorf("Error updating performer image: %s", err.Error())
}
- err = qb.DestroyImage(created.ID)
+ err = qb.DestroyImage(ctx, created.ID)
if err != nil {
return fmt.Errorf("Error destroying performer image: %s", err.Error())
}
// image should be nil
- storedImage, err := qb.GetImage(created.ID)
+ storedImage, err := qb.GetImage(ctx, created.ID)
if err != nil {
return fmt.Errorf("Error getting image: %s", err.Error())
}
@@ -380,13 +382,13 @@ func TestPerformerQueryAge(t *testing.T) {
}
func verifyPerformerAge(t *testing.T, ageCriterion models.IntCriterionInput) {
- withTxn(func(r models.Repository) error {
- qb := r.Performer()
+ withTxn(func(ctx context.Context) error {
+ qb := sqlite.PerformerReaderWriter
performerFilter := models.PerformerFilterType{
Age: &ageCriterion,
}
- performers, _, err := qb.Query(&performerFilter, nil)
+ performers, _, err := qb.Query(ctx, &performerFilter, nil)
if err != nil {
t.Errorf("Error querying performer: %s", err.Error())
}
@@ -433,13 +435,13 @@ func TestPerformerQueryCareerLength(t *testing.T) {
}
func verifyPerformerCareerLength(t *testing.T, criterion models.StringCriterionInput) {
- withTxn(func(r models.Repository) error {
- qb := r.Performer()
+ withTxn(func(ctx context.Context) error {
+ qb := sqlite.PerformerReaderWriter
performerFilter := models.PerformerFilterType{
CareerLength: &criterion,
}
- performers, _, err := qb.Query(&performerFilter, nil)
+ performers, _, err := qb.Query(ctx, &performerFilter, nil)
if err != nil {
t.Errorf("Error querying performer: %s", err.Error())
}
@@ -492,11 +494,11 @@ func TestPerformerQueryURL(t *testing.T) {
}
func verifyPerformerQuery(t *testing.T, filter models.PerformerFilterType, verifyFn func(s *models.Performer)) {
- withTxn(func(r models.Repository) error {
+ withTxn(func(ctx context.Context) error {
t.Helper()
- sqb := r.Performer()
+ sqb := sqlite.PerformerReaderWriter
- performers := queryPerformers(t, sqb, &filter, nil)
+ performers := queryPerformers(ctx, t, sqb, &filter, nil)
// assume it should find at least one
assert.Greater(t, len(performers), 0)
@@ -509,8 +511,8 @@ func verifyPerformerQuery(t *testing.T, filter models.PerformerFilterType, verif
})
}
-func queryPerformers(t *testing.T, qb models.PerformerReader, performerFilter *models.PerformerFilterType, findFilter *models.FindFilterType) []*models.Performer {
- performers, _, err := qb.Query(performerFilter, findFilter)
+func queryPerformers(ctx context.Context, t *testing.T, qb models.PerformerReader, performerFilter *models.PerformerFilterType, findFilter *models.FindFilterType) []*models.Performer {
+ performers, _, err := qb.Query(ctx, performerFilter, findFilter)
if err != nil {
t.Errorf("Error querying performers: %s", err.Error())
}
@@ -519,8 +521,8 @@ func queryPerformers(t *testing.T, qb models.PerformerReader, performerFilter *m
}
func TestPerformerQueryTags(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Performer()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.PerformerReaderWriter
tagCriterion := models.HierarchicalMultiCriterionInput{
Value: []string{
strconv.Itoa(tagIDs[tagIdxWithPerformer]),
@@ -534,7 +536,7 @@ func TestPerformerQueryTags(t *testing.T) {
}
// ensure ids are correct
- performers := queryPerformers(t, sqb, &performerFilter, nil)
+ performers := queryPerformers(ctx, t, sqb, &performerFilter, nil)
assert.Len(t, performers, 2)
for _, performer := range performers {
assert.True(t, performer.ID == performerIDs[performerIdxWithTag] || performer.ID == performerIDs[performerIdxWithTwoTags])
@@ -548,7 +550,7 @@ func TestPerformerQueryTags(t *testing.T) {
Modifier: models.CriterionModifierIncludesAll,
}
- performers = queryPerformers(t, sqb, &performerFilter, nil)
+ performers = queryPerformers(ctx, t, sqb, &performerFilter, nil)
assert.Len(t, performers, 1)
assert.Equal(t, sceneIDs[performerIdxWithTwoTags], performers[0].ID)
@@ -565,7 +567,7 @@ func TestPerformerQueryTags(t *testing.T) {
Q: &q,
}
- performers = queryPerformers(t, sqb, &performerFilter, &findFilter)
+ performers = queryPerformers(ctx, t, sqb, &performerFilter, &findFilter)
assert.Len(t, performers, 0)
return nil
@@ -592,17 +594,17 @@ func TestPerformerQueryTagCount(t *testing.T) {
}
func verifyPerformersTagCount(t *testing.T, tagCountCriterion models.IntCriterionInput) {
- withTxn(func(r models.Repository) error {
- sqb := r.Performer()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.PerformerReaderWriter
performerFilter := models.PerformerFilterType{
TagCount: &tagCountCriterion,
}
- performers := queryPerformers(t, sqb, &performerFilter, nil)
+ performers := queryPerformers(ctx, t, sqb, &performerFilter, nil)
assert.Greater(t, len(performers), 0)
for _, performer := range performers {
- ids, err := sqb.GetTagIDs(performer.ID)
+ ids, err := sqb.GetTagIDs(ctx, performer.ID)
if err != nil {
return err
}
@@ -633,17 +635,17 @@ func TestPerformerQuerySceneCount(t *testing.T) {
}
func verifyPerformersSceneCount(t *testing.T, sceneCountCriterion models.IntCriterionInput) {
- withTxn(func(r models.Repository) error {
- sqb := r.Performer()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.PerformerReaderWriter
performerFilter := models.PerformerFilterType{
SceneCount: &sceneCountCriterion,
}
- performers := queryPerformers(t, sqb, &performerFilter, nil)
+ performers := queryPerformers(ctx, t, sqb, &performerFilter, nil)
assert.Greater(t, len(performers), 0)
for _, performer := range performers {
- ids, err := r.Scene().FindByPerformerID(performer.ID)
+ ids, err := sqlite.SceneReaderWriter.FindByPerformerID(ctx, performer.ID)
if err != nil {
return err
}
@@ -674,19 +676,19 @@ func TestPerformerQueryImageCount(t *testing.T) {
}
func verifyPerformersImageCount(t *testing.T, imageCountCriterion models.IntCriterionInput) {
- withTxn(func(r models.Repository) error {
- sqb := r.Performer()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.PerformerReaderWriter
performerFilter := models.PerformerFilterType{
ImageCount: &imageCountCriterion,
}
- performers := queryPerformers(t, sqb, &performerFilter, nil)
+ performers := queryPerformers(ctx, t, sqb, &performerFilter, nil)
assert.Greater(t, len(performers), 0)
for _, performer := range performers {
pp := 0
- result, err := r.Image().Query(models.ImageQueryOptions{
+ result, err := sqlite.ImageReaderWriter.Query(ctx, models.ImageQueryOptions{
QueryOptions: models.QueryOptions{
FindFilter: &models.FindFilterType{
PerPage: &pp,
@@ -730,19 +732,19 @@ func TestPerformerQueryGalleryCount(t *testing.T) {
}
func verifyPerformersGalleryCount(t *testing.T, galleryCountCriterion models.IntCriterionInput) {
- withTxn(func(r models.Repository) error {
- sqb := r.Performer()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.PerformerReaderWriter
performerFilter := models.PerformerFilterType{
GalleryCount: &galleryCountCriterion,
}
- performers := queryPerformers(t, sqb, &performerFilter, nil)
+ performers := queryPerformers(ctx, t, sqb, &performerFilter, nil)
assert.Greater(t, len(performers), 0)
for _, performer := range performers {
pp := 0
- _, count, err := r.Gallery().Query(&models.GalleryFilterType{
+ _, count, err := sqlite.GalleryReaderWriter.Query(ctx, &models.GalleryFilterType{
Performers: &models.MultiCriterionInput{
Value: []string{strconv.Itoa(performer.ID)},
Modifier: models.CriterionModifierIncludes,
@@ -761,7 +763,7 @@ func verifyPerformersGalleryCount(t *testing.T, galleryCountCriterion models.Int
}
func TestPerformerQueryStudio(t *testing.T) {
- withTxn(func(r models.Repository) error {
+ withTxn(func(ctx context.Context) error {
testCases := []struct {
studioIndex int
performerIndex int
@@ -771,7 +773,7 @@ func TestPerformerQueryStudio(t *testing.T) {
{studioIndex: studioIdxWithGalleryPerformer, performerIndex: performerIdxWithGalleryStudio},
}
- sqb := r.Performer()
+ sqb := sqlite.PerformerReaderWriter
for _, tc := range testCases {
studioCriterion := models.HierarchicalMultiCriterionInput{
@@ -785,7 +787,7 @@ func TestPerformerQueryStudio(t *testing.T) {
Studios: &studioCriterion,
}
- performers := queryPerformers(t, sqb, &performerFilter, nil)
+ performers := queryPerformers(ctx, t, sqb, &performerFilter, nil)
assert.Len(t, performers, 1)
@@ -804,7 +806,7 @@ func TestPerformerQueryStudio(t *testing.T) {
Q: &q,
}
- performers = queryPerformers(t, sqb, &performerFilter, &findFilter)
+ performers = queryPerformers(ctx, t, sqb, &performerFilter, &findFilter)
assert.Len(t, performers, 0)
}
@@ -819,21 +821,21 @@ func TestPerformerQueryStudio(t *testing.T) {
Q: &q,
}
- performers := queryPerformers(t, sqb, performerFilter, findFilter)
+ performers := queryPerformers(ctx, t, sqb, performerFilter, findFilter)
assert.Len(t, performers, 1)
assert.Equal(t, imageIDs[performerIdx1WithImage], performers[0].ID)
q = getPerformerStringValue(performerIdxWithSceneStudio, "Name")
- performers = queryPerformers(t, sqb, performerFilter, findFilter)
+ performers = queryPerformers(ctx, t, sqb, performerFilter, findFilter)
assert.Len(t, performers, 0)
performerFilter.Studios.Modifier = models.CriterionModifierNotNull
- performers = queryPerformers(t, sqb, performerFilter, findFilter)
+ performers = queryPerformers(ctx, t, sqb, performerFilter, findFilter)
assert.Len(t, performers, 1)
assert.Equal(t, imageIDs[performerIdxWithSceneStudio], performers[0].ID)
q = getPerformerStringValue(performerIdx1WithImage, "Name")
- performers = queryPerformers(t, sqb, performerFilter, findFilter)
+ performers = queryPerformers(ctx, t, sqb, performerFilter, findFilter)
assert.Len(t, performers, 0)
return nil
@@ -841,8 +843,8 @@ func TestPerformerQueryStudio(t *testing.T) {
}
func TestPerformerStashIDs(t *testing.T) {
- if err := withTxn(func(r models.Repository) error {
- qb := r.Performer()
+ if err := withTxn(func(ctx context.Context) error {
+ qb := sqlite.PerformerReaderWriter
// create performer to test against
const name = "TestStashIDs"
@@ -851,12 +853,12 @@ func TestPerformerStashIDs(t *testing.T) {
Checksum: md5.FromString(name),
Favorite: sql.NullBool{Bool: false, Valid: true},
}
- created, err := qb.Create(performer)
+ created, err := qb.Create(ctx, performer)
if err != nil {
return fmt.Errorf("Error creating performer: %s", err.Error())
}
- testStashIDReaderWriter(t, qb, created.ID)
+ testStashIDReaderWriter(ctx, t, qb, created.ID)
return nil
}); err != nil {
t.Error(err.Error())
@@ -888,13 +890,13 @@ func TestPerformerQueryRating(t *testing.T) {
}
func verifyPerformersRating(t *testing.T, ratingCriterion models.IntCriterionInput) {
- withTxn(func(r models.Repository) error {
- sqb := r.Performer()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.PerformerReaderWriter
performerFilter := models.PerformerFilterType{
Rating: &ratingCriterion,
}
- performers := queryPerformers(t, sqb, &performerFilter, nil)
+ performers := queryPerformers(ctx, t, sqb, &performerFilter, nil)
for _, performer := range performers {
verifyInt64(t, performer.Rating, ratingCriterion)
@@ -905,14 +907,14 @@ func verifyPerformersRating(t *testing.T, ratingCriterion models.IntCriterionInp
}
func TestPerformerQueryIsMissingRating(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Performer()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.PerformerReaderWriter
isMissing := "rating"
performerFilter := models.PerformerFilterType{
IsMissing: &isMissing,
}
- performers := queryPerformers(t, sqb, &performerFilter, nil)
+ performers := queryPerformers(ctx, t, sqb, &performerFilter, nil)
assert.True(t, len(performers) > 0)
@@ -925,14 +927,14 @@ func TestPerformerQueryIsMissingRating(t *testing.T) {
}
func TestPerformerQueryIsMissingImage(t *testing.T) {
- withTxn(func(r models.Repository) error {
+ withTxn(func(ctx context.Context) error {
isMissing := "image"
performerFilter := &models.PerformerFilterType{
IsMissing: &isMissing,
}
// ensure query does not error
- performers, _, err := r.Performer().Query(performerFilter, nil)
+ performers, _, err := sqlite.PerformerReaderWriter.Query(ctx, performerFilter, nil)
if err != nil {
t.Errorf("Error querying performers: %s", err.Error())
}
@@ -940,7 +942,7 @@ func TestPerformerQueryIsMissingImage(t *testing.T) {
assert.True(t, len(performers) > 0)
for _, performer := range performers {
- img, err := r.Performer().GetImage(performer.ID)
+ img, err := sqlite.PerformerReaderWriter.GetImage(ctx, performer.ID)
if err != nil {
t.Errorf("error getting performer image: %s", err.Error())
}
@@ -959,9 +961,9 @@ func TestPerformerQuerySortScenesCount(t *testing.T) {
Direction: &direction,
}
- withTxn(func(r models.Repository) error {
+ withTxn(func(ctx context.Context) error {
// just ensure it queries without error
- performers, _, err := r.Performer().Query(nil, findFilter)
+ performers, _, err := sqlite.PerformerReaderWriter.Query(ctx, nil, findFilter)
if err != nil {
t.Errorf("Error querying performers: %s", err.Error())
}
@@ -976,7 +978,7 @@ func TestPerformerQuerySortScenesCount(t *testing.T) {
// sort in ascending order
direction = models.SortDirectionEnumAsc
- performers, _, err = r.Performer().Query(nil, findFilter)
+ performers, _, err = sqlite.PerformerReaderWriter.Query(ctx, nil, findFilter)
if err != nil {
t.Errorf("Error querying performers: %s", err.Error())
}
diff --git a/pkg/sqlite/query.go b/pkg/sqlite/query.go
index 27ce213b5..25cb2b2b1 100644
--- a/pkg/sqlite/query.go
+++ b/pkg/sqlite/query.go
@@ -1,6 +1,7 @@
package sqlite
import (
+ "context"
"fmt"
"strings"
@@ -54,24 +55,24 @@ func (qb queryBuilder) toSQL(includeSortPagination bool) string {
return body
}
-func (qb queryBuilder) findIDs() ([]int, error) {
+func (qb queryBuilder) findIDs(ctx context.Context) ([]int, error) {
const includeSortPagination = true
sql := qb.toSQL(includeSortPagination)
logger.Tracef("SQL: %s, args: %v", sql, qb.args)
- return qb.repository.runIdsQuery(sql, qb.args)
+ return qb.repository.runIdsQuery(ctx, sql, qb.args)
}
-func (qb queryBuilder) executeFind() ([]int, int, error) {
+func (qb queryBuilder) executeFind(ctx context.Context) ([]int, int, error) {
if qb.err != nil {
return nil, 0, qb.err
}
body := qb.body()
- return qb.repository.executeFindQuery(body, qb.args, qb.sortAndPagination, qb.whereClauses, qb.havingClauses, qb.withClauses, qb.recursiveWith)
+ return qb.repository.executeFindQuery(ctx, body, qb.args, qb.sortAndPagination, qb.whereClauses, qb.havingClauses, qb.withClauses, qb.recursiveWith)
}
-func (qb queryBuilder) executeCount() (int, error) {
+func (qb queryBuilder) executeCount(ctx context.Context) (int, error) {
if qb.err != nil {
return 0, qb.err
}
@@ -89,7 +90,7 @@ func (qb queryBuilder) executeCount() (int, error) {
body = qb.repository.buildQueryBody(body, qb.whereClauses, qb.havingClauses)
countQuery := withClause + qb.repository.buildCountQuery(body)
- return qb.repository.runCountQuery(countQuery, qb.args)
+ return qb.repository.runCountQuery(ctx, countQuery, qb.args)
}
func (qb *queryBuilder) addWhere(clauses ...string) {
diff --git a/pkg/database/regex.go b/pkg/sqlite/regex.go
similarity index 98%
rename from pkg/database/regex.go
rename to pkg/sqlite/regex.go
index dc7b5feb5..bbf713ae3 100644
--- a/pkg/database/regex.go
+++ b/pkg/sqlite/regex.go
@@ -1,4 +1,4 @@
-package database
+package sqlite
import (
"regexp"
diff --git a/pkg/sqlite/repository.go b/pkg/sqlite/repository.go
index f195d9c7e..5bc1fe782 100644
--- a/pkg/sqlite/repository.go
+++ b/pkg/sqlite/repository.go
@@ -1,6 +1,7 @@
package sqlite
import (
+ "context"
"database/sql"
"errors"
"fmt"
@@ -26,23 +27,23 @@ type repository struct {
idColumn string
}
-func (r *repository) get(id int, dest interface{}) error {
+func (r *repository) getByID(ctx context.Context, id int, dest interface{}) error {
stmt := fmt.Sprintf("SELECT * FROM %s WHERE %s = ? LIMIT 1", r.tableName, r.idColumn)
- return r.tx.Get(dest, stmt, id)
+ return r.tx.Get(ctx, dest, stmt, id)
}
-func (r *repository) getAll(id int, f func(rows *sqlx.Rows) error) error {
+func (r *repository) getAll(ctx context.Context, id int, f func(rows *sqlx.Rows) error) error {
stmt := fmt.Sprintf("SELECT * FROM %s WHERE %s = ?", r.tableName, r.idColumn)
- return r.queryFunc(stmt, []interface{}{id}, false, f)
+ return r.queryFunc(ctx, stmt, []interface{}{id}, false, f)
}
-func (r *repository) insert(obj interface{}) (sql.Result, error) {
+func (r *repository) insert(ctx context.Context, obj interface{}) (sql.Result, error) {
stmt := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", r.tableName, listKeys(obj, false), listKeys(obj, true))
- return r.tx.NamedExec(stmt, obj)
+ return r.tx.NamedExec(ctx, stmt, obj)
}
-func (r *repository) insertObject(obj interface{}, out interface{}) error {
- result, err := r.insert(obj)
+func (r *repository) insertObject(ctx context.Context, obj interface{}, out interface{}) error {
+ result, err := r.insert(ctx, obj)
if err != nil {
return err
}
@@ -50,11 +51,11 @@ func (r *repository) insertObject(obj interface{}, out interface{}) error {
if err != nil {
return err
}
- return r.get(int(id), out)
+ return r.getByID(ctx, int(id), out)
}
-func (r *repository) update(id int, obj interface{}, partial bool) error {
- exists, err := r.exists(id)
+func (r *repository) update(ctx context.Context, id int, obj interface{}, partial bool) error {
+ exists, err := r.exists(ctx, id)
if err != nil {
return err
}
@@ -64,13 +65,13 @@ func (r *repository) update(id int, obj interface{}, partial bool) error {
}
stmt := fmt.Sprintf("UPDATE %s SET %s WHERE %s.%s = :id", r.tableName, updateSet(obj, partial), r.tableName, r.idColumn)
- _, err = r.tx.NamedExec(stmt, obj)
+ _, err = r.tx.NamedExec(ctx, stmt, obj)
return err
}
-func (r *repository) updateMap(id int, m map[string]interface{}) error {
- exists, err := r.exists(id)
+func (r *repository) updateMap(ctx context.Context, id int, m map[string]interface{}) error {
+ exists, err := r.exists(ctx, id)
if err != nil {
return err
}
@@ -80,14 +81,14 @@ func (r *repository) updateMap(id int, m map[string]interface{}) error {
}
stmt := fmt.Sprintf("UPDATE %s SET %s WHERE %s.%s = :id", r.tableName, updateSetMap(m), r.tableName, r.idColumn)
- _, err = r.tx.NamedExec(stmt, m)
+ _, err = r.tx.NamedExec(ctx, stmt, m)
return err
}
-func (r *repository) destroyExisting(ids []int) error {
+func (r *repository) destroyExisting(ctx context.Context, ids []int) error {
for _, id := range ids {
- exists, err := r.exists(id)
+ exists, err := r.exists(ctx, id)
if err != nil {
return err
}
@@ -97,13 +98,13 @@ func (r *repository) destroyExisting(ids []int) error {
}
}
- return r.destroy(ids)
+ return r.destroy(ctx, ids)
}
-func (r *repository) destroy(ids []int) error {
+func (r *repository) destroy(ctx context.Context, ids []int) error {
for _, id := range ids {
stmt := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", r.tableName, r.idColumn)
- if _, err := r.tx.Exec(stmt, id); err != nil {
+ if _, err := r.tx.Exec(ctx, stmt, id); err != nil {
return err
}
}
@@ -111,11 +112,11 @@ func (r *repository) destroy(ids []int) error {
return nil
}
-func (r *repository) exists(id int) (bool, error) {
+func (r *repository) exists(ctx context.Context, id int) (bool, error) {
stmt := fmt.Sprintf("SELECT %s FROM %s WHERE %s = ? LIMIT 1", r.idColumn, r.tableName, r.idColumn)
stmt = r.buildCountQuery(stmt)
- c, err := r.runCountQuery(stmt, []interface{}{id})
+ c, err := r.runCountQuery(ctx, stmt, []interface{}{id})
if err != nil {
return false, err
}
@@ -127,25 +128,25 @@ func (r *repository) buildCountQuery(query string) string {
return "SELECT COUNT(*) as count FROM (" + query + ") as temp"
}
-func (r *repository) runCountQuery(query string, args []interface{}) (int, error) {
+func (r *repository) runCountQuery(ctx context.Context, query string, args []interface{}) (int, error) {
result := struct {
Int int `db:"count"`
}{0}
// Perform query and fetch result
- if err := r.tx.Get(&result, query, args...); err != nil && !errors.Is(err, sql.ErrNoRows) {
+ if err := r.tx.Get(ctx, &result, query, args...); err != nil && !errors.Is(err, sql.ErrNoRows) {
return 0, err
}
return result.Int, nil
}
-func (r *repository) runIdsQuery(query string, args []interface{}) ([]int, error) {
+func (r *repository) runIdsQuery(ctx context.Context, query string, args []interface{}) ([]int, error) {
var result []struct {
Int int `db:"id"`
}
- if err := r.tx.Select(&result, query, args...); err != nil && !errors.Is(err, sql.ErrNoRows) {
+ if err := r.tx.Select(ctx, &result, query, args...); err != nil && !errors.Is(err, sql.ErrNoRows) {
return []int{}, err
}
@@ -156,24 +157,24 @@ func (r *repository) runIdsQuery(query string, args []interface{}) ([]int, error
return vsm, nil
}
-func (r *repository) runSumQuery(query string, args []interface{}) (float64, error) {
+func (r *repository) runSumQuery(ctx context.Context, query string, args []interface{}) (float64, error) {
// Perform query and fetch result
result := struct {
Float64 float64 `db:"sum"`
}{0}
// Perform query and fetch result
- if err := r.tx.Get(&result, query, args...); err != nil && !errors.Is(err, sql.ErrNoRows) {
+ if err := r.tx.Get(ctx, &result, query, args...); err != nil && !errors.Is(err, sql.ErrNoRows) {
return 0, err
}
return result.Float64, nil
}
-func (r *repository) queryFunc(query string, args []interface{}, single bool, f func(rows *sqlx.Rows) error) error {
+func (r *repository) queryFunc(ctx context.Context, query string, args []interface{}, single bool, f func(rows *sqlx.Rows) error) error {
logger.Tracef("SQL: %s, args: %v", query, args)
- rows, err := r.tx.Queryx(query, args...)
+ rows, err := r.tx.Queryx(ctx, query, args...)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
@@ -196,8 +197,8 @@ func (r *repository) queryFunc(query string, args []interface{}, single bool, f
return nil
}
-func (r *repository) query(query string, args []interface{}, out objectList) error {
- return r.queryFunc(query, args, false, func(rows *sqlx.Rows) error {
+func (r *repository) query(ctx context.Context, query string, args []interface{}, out objectList) error {
+ return r.queryFunc(ctx, query, args, false, func(rows *sqlx.Rows) error {
object := out.New()
if err := rows.StructScan(object); err != nil {
return err
@@ -207,8 +208,8 @@ func (r *repository) query(query string, args []interface{}, out objectList) err
})
}
-func (r *repository) queryStruct(query string, args []interface{}, out interface{}) error {
- return r.queryFunc(query, args, true, func(rows *sqlx.Rows) error {
+func (r *repository) queryStruct(ctx context.Context, query string, args []interface{}, out interface{}) error {
+ return r.queryFunc(ctx, query, args, true, func(rows *sqlx.Rows) error {
if err := rows.StructScan(out); err != nil {
return err
}
@@ -216,8 +217,8 @@ func (r *repository) queryStruct(query string, args []interface{}, out interface
})
}
-func (r *repository) querySimple(query string, args []interface{}, out interface{}) error {
- rows, err := r.tx.Queryx(query, args...)
+func (r *repository) querySimple(ctx context.Context, query string, args []interface{}, out interface{}) error {
+ rows, err := r.tx.Queryx(ctx, query, args...)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
@@ -249,7 +250,7 @@ func (r *repository) buildQueryBody(body string, whereClauses []string, havingCl
return body
}
-func (r *repository) executeFindQuery(body string, args []interface{}, sortAndPagination string, whereClauses []string, havingClauses []string, withClauses []string, recursiveWith bool) ([]int, int, error) {
+func (r *repository) executeFindQuery(ctx context.Context, body string, args []interface{}, sortAndPagination string, whereClauses []string, havingClauses []string, withClauses []string, recursiveWith bool) ([]int, int, error) {
body = r.buildQueryBody(body, whereClauses, havingClauses)
withClause := ""
@@ -272,8 +273,8 @@ func (r *repository) executeFindQuery(body string, args []interface{}, sortAndPa
var idsResult []int
var idsErr error
- countResult, countErr = r.runCountQuery(countQuery, args)
- idsResult, idsErr = r.runIdsQuery(idsQuery, args)
+ countResult, countErr = r.runCountQuery(ctx, countQuery, args)
+ idsResult, idsErr = r.runIdsQuery(ctx, idsQuery, args)
if countErr != nil {
return nil, 0, fmt.Errorf("error executing count query with SQL: %s, args: %v, error: %s", countQuery, args, countErr.Error())
@@ -318,23 +319,23 @@ type joinRepository struct {
fkColumn string
}
-func (r *joinRepository) getIDs(id int) ([]int, error) {
+func (r *joinRepository) getIDs(ctx context.Context, id int) ([]int, error) {
query := fmt.Sprintf(`SELECT %s as id from %s WHERE %s = ?`, r.fkColumn, r.tableName, r.idColumn)
- return r.runIdsQuery(query, []interface{}{id})
+ return r.runIdsQuery(ctx, query, []interface{}{id})
}
-func (r *joinRepository) insert(id, foreignID int) (sql.Result, error) {
+func (r *joinRepository) insert(ctx context.Context, id, foreignID int) (sql.Result, error) {
stmt := fmt.Sprintf("INSERT INTO %s (%s, %s) VALUES (?, ?)", r.tableName, r.idColumn, r.fkColumn)
- return r.tx.Exec(stmt, id, foreignID)
+ return r.tx.Exec(ctx, stmt, id, foreignID)
}
-func (r *joinRepository) replace(id int, foreignIDs []int) error {
- if err := r.destroy([]int{id}); err != nil {
+func (r *joinRepository) replace(ctx context.Context, id int, foreignIDs []int) error {
+ if err := r.destroy(ctx, []int{id}); err != nil {
return err
}
for _, fk := range foreignIDs {
- if _, err := r.insert(id, fk); err != nil {
+ if _, err := r.insert(ctx, id, fk); err != nil {
return err
}
}
@@ -347,20 +348,20 @@ type imageRepository struct {
imageColumn string
}
-func (r *imageRepository) get(id int) ([]byte, error) {
+func (r *imageRepository) get(ctx context.Context, id int) ([]byte, error) {
query := fmt.Sprintf("SELECT %s from %s WHERE %s = ?", r.imageColumn, r.tableName, r.idColumn)
var ret []byte
- err := r.querySimple(query, []interface{}{id}, &ret)
+ err := r.querySimple(ctx, query, []interface{}{id}, &ret)
return ret, err
}
-func (r *imageRepository) replace(id int, image []byte) error {
- if err := r.destroy([]int{id}); err != nil {
+func (r *imageRepository) replace(ctx context.Context, id int, image []byte) error {
+ if err := r.destroy(ctx, []int{id}); err != nil {
return err
}
stmt := fmt.Sprintf("INSERT INTO %s (%s, %s) VALUES (?, ?)", r.tableName, r.idColumn, r.imageColumn)
- _, err := r.tx.Exec(stmt, id, image)
+ _, err := r.tx.Exec(ctx, stmt, id, image)
return err
}
@@ -369,10 +370,10 @@ type captionRepository struct {
repository
}
-func (r *captionRepository) get(id int) ([]*models.SceneCaption, error) {
+func (r *captionRepository) get(ctx context.Context, id int) ([]*models.SceneCaption, error) {
query := fmt.Sprintf("SELECT %s, %s, %s from %s WHERE %s = ?", sceneCaptionCodeColumn, sceneCaptionFilenameColumn, sceneCaptionTypeColumn, r.tableName, r.idColumn)
var ret []*models.SceneCaption
- err := r.queryFunc(query, []interface{}{id}, false, func(rows *sqlx.Rows) error {
+ err := r.queryFunc(ctx, query, []interface{}{id}, false, func(rows *sqlx.Rows) error {
var captionCode string
var captionFilename string
var captionType string
@@ -392,18 +393,18 @@ func (r *captionRepository) get(id int) ([]*models.SceneCaption, error) {
return ret, err
}
-func (r *captionRepository) insert(id int, caption *models.SceneCaption) (sql.Result, error) {
+func (r *captionRepository) insert(ctx context.Context, id int, caption *models.SceneCaption) (sql.Result, error) {
stmt := fmt.Sprintf("INSERT INTO %s (%s, %s, %s, %s) VALUES (?, ?, ?, ?)", r.tableName, r.idColumn, sceneCaptionCodeColumn, sceneCaptionFilenameColumn, sceneCaptionTypeColumn)
- return r.tx.Exec(stmt, id, caption.LanguageCode, caption.Filename, caption.CaptionType)
+ return r.tx.Exec(ctx, stmt, id, caption.LanguageCode, caption.Filename, caption.CaptionType)
}
-func (r *captionRepository) replace(id int, captions []*models.SceneCaption) error {
- if err := r.destroy([]int{id}); err != nil {
+func (r *captionRepository) replace(ctx context.Context, id int, captions []*models.SceneCaption) error {
+ if err := r.destroy(ctx, []int{id}); err != nil {
return err
}
for _, caption := range captions {
- if _, err := r.insert(id, caption); err != nil {
+ if _, err := r.insert(ctx, id, caption); err != nil {
return err
}
}
@@ -416,10 +417,10 @@ type stringRepository struct {
stringColumn string
}
-func (r *stringRepository) get(id int) ([]string, error) {
+func (r *stringRepository) get(ctx context.Context, id int) ([]string, error) {
query := fmt.Sprintf("SELECT %s from %s WHERE %s = ?", r.stringColumn, r.tableName, r.idColumn)
var ret []string
- err := r.queryFunc(query, []interface{}{id}, false, func(rows *sqlx.Rows) error {
+ err := r.queryFunc(ctx, query, []interface{}{id}, false, func(rows *sqlx.Rows) error {
var out string
if err := rows.Scan(&out); err != nil {
return err
@@ -431,18 +432,18 @@ func (r *stringRepository) get(id int) ([]string, error) {
return ret, err
}
-func (r *stringRepository) insert(id int, s string) (sql.Result, error) {
+func (r *stringRepository) insert(ctx context.Context, id int, s string) (sql.Result, error) {
stmt := fmt.Sprintf("INSERT INTO %s (%s, %s) VALUES (?, ?)", r.tableName, r.idColumn, r.stringColumn)
- return r.tx.Exec(stmt, id, s)
+ return r.tx.Exec(ctx, stmt, id, s)
}
-func (r *stringRepository) replace(id int, newStrings []string) error {
- if err := r.destroy([]int{id}); err != nil {
+func (r *stringRepository) replace(ctx context.Context, id int, newStrings []string) error {
+ if err := r.destroy(ctx, []int{id}); err != nil {
return err
}
for _, s := range newStrings {
- if _, err := r.insert(id, s); err != nil {
+ if _, err := r.insert(ctx, id, s); err != nil {
return err
}
}
@@ -464,21 +465,21 @@ func (s *stashIDs) New() interface{} {
return &models.StashID{}
}
-func (r *stashIDRepository) get(id int) ([]*models.StashID, error) {
+func (r *stashIDRepository) get(ctx context.Context, id int) ([]*models.StashID, error) {
query := fmt.Sprintf("SELECT stash_id, endpoint from %s WHERE %s = ?", r.tableName, r.idColumn)
var ret stashIDs
- err := r.query(query, []interface{}{id}, &ret)
+ err := r.query(ctx, query, []interface{}{id}, &ret)
return []*models.StashID(ret), err
}
-func (r *stashIDRepository) replace(id int, newIDs []models.StashID) error {
- if err := r.destroy([]int{id}); err != nil {
+func (r *stashIDRepository) replace(ctx context.Context, id int, newIDs []models.StashID) error {
+ if err := r.destroy(ctx, []int{id}); err != nil {
return err
}
query := fmt.Sprintf("INSERT INTO %s (%s, endpoint, stash_id) VALUES (?, ?, ?)", r.tableName, r.idColumn)
for _, stashID := range newIDs {
- _, err := r.tx.Exec(query, id, stashID.Endpoint, stashID.StashID)
+ _, err := r.tx.Exec(ctx, query, id, stashID.Endpoint, stashID.StashID)
if err != nil {
return err
}
diff --git a/pkg/sqlite/saved_filter.go b/pkg/sqlite/saved_filter.go
index 6c507bee3..54fddcbc8 100644
--- a/pkg/sqlite/saved_filter.go
+++ b/pkg/sqlite/saved_filter.go
@@ -1,6 +1,7 @@
package sqlite
import (
+ "context"
"database/sql"
"errors"
"fmt"
@@ -15,42 +16,39 @@ type savedFilterQueryBuilder struct {
repository
}
-func NewSavedFilterReaderWriter(tx dbi) *savedFilterQueryBuilder {
- return &savedFilterQueryBuilder{
- repository{
- tx: tx,
- tableName: savedFilterTable,
- idColumn: idColumn,
- },
- }
+var SavedFilterReaderWriter = &savedFilterQueryBuilder{
+ repository{
+ tableName: savedFilterTable,
+ idColumn: idColumn,
+ },
}
-func (qb *savedFilterQueryBuilder) Create(newObject models.SavedFilter) (*models.SavedFilter, error) {
+func (qb *savedFilterQueryBuilder) Create(ctx context.Context, newObject models.SavedFilter) (*models.SavedFilter, error) {
var ret models.SavedFilter
- if err := qb.insertObject(newObject, &ret); err != nil {
+ if err := qb.insertObject(ctx, newObject, &ret); err != nil {
return nil, err
}
return &ret, nil
}
-func (qb *savedFilterQueryBuilder) Update(updatedObject models.SavedFilter) (*models.SavedFilter, error) {
+func (qb *savedFilterQueryBuilder) Update(ctx context.Context, updatedObject models.SavedFilter) (*models.SavedFilter, error) {
const partial = false
- if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil {
+ if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil {
return nil, err
}
var ret models.SavedFilter
- if err := qb.get(updatedObject.ID, &ret); err != nil {
+ if err := qb.getByID(ctx, updatedObject.ID, &ret); err != nil {
return nil, err
}
return &ret, nil
}
-func (qb *savedFilterQueryBuilder) SetDefault(obj models.SavedFilter) (*models.SavedFilter, error) {
+func (qb *savedFilterQueryBuilder) SetDefault(ctx context.Context, obj models.SavedFilter) (*models.SavedFilter, error) {
// find the existing default
- existing, err := qb.FindDefault(obj.Mode)
+ existing, err := qb.FindDefault(ctx, obj.Mode)
if err != nil {
return nil, err
@@ -60,19 +58,19 @@ func (qb *savedFilterQueryBuilder) SetDefault(obj models.SavedFilter) (*models.S
if existing != nil {
obj.ID = existing.ID
- return qb.Update(obj)
+ return qb.Update(ctx, obj)
}
- return qb.Create(obj)
+ return qb.Create(ctx, obj)
}
-func (qb *savedFilterQueryBuilder) Destroy(id int) error {
- return qb.destroyExisting([]int{id})
+func (qb *savedFilterQueryBuilder) Destroy(ctx context.Context, id int) error {
+ return qb.destroyExisting(ctx, []int{id})
}
-func (qb *savedFilterQueryBuilder) Find(id int) (*models.SavedFilter, error) {
+func (qb *savedFilterQueryBuilder) Find(ctx context.Context, id int) (*models.SavedFilter, error) {
var ret models.SavedFilter
- if err := qb.get(id, &ret); err != nil {
+ if err := qb.getByID(ctx, id, &ret); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
@@ -81,10 +79,10 @@ func (qb *savedFilterQueryBuilder) Find(id int) (*models.SavedFilter, error) {
return &ret, nil
}
-func (qb *savedFilterQueryBuilder) FindMany(ids []int, ignoreNotFound bool) ([]*models.SavedFilter, error) {
+func (qb *savedFilterQueryBuilder) FindMany(ctx context.Context, ids []int, ignoreNotFound bool) ([]*models.SavedFilter, error) {
var filters []*models.SavedFilter
for _, id := range ids {
- filter, err := qb.Find(id)
+ filter, err := qb.Find(ctx, id)
if err != nil {
return nil, err
}
@@ -99,24 +97,24 @@ func (qb *savedFilterQueryBuilder) FindMany(ids []int, ignoreNotFound bool) ([]*
return filters, nil
}
-func (qb *savedFilterQueryBuilder) FindByMode(mode models.FilterMode) ([]*models.SavedFilter, error) {
+func (qb *savedFilterQueryBuilder) FindByMode(ctx context.Context, mode models.FilterMode) ([]*models.SavedFilter, error) {
// exclude empty-named filters - these are the internal default filters
query := fmt.Sprintf(`SELECT * FROM %s WHERE mode = ? AND name != ?`, savedFilterTable)
var ret models.SavedFilters
- if err := qb.query(query, []interface{}{mode, savedFilterDefaultName}, &ret); err != nil {
+ if err := qb.query(ctx, query, []interface{}{mode, savedFilterDefaultName}, &ret); err != nil {
return nil, err
}
return []*models.SavedFilter(ret), nil
}
-func (qb *savedFilterQueryBuilder) FindDefault(mode models.FilterMode) (*models.SavedFilter, error) {
+func (qb *savedFilterQueryBuilder) FindDefault(ctx context.Context, mode models.FilterMode) (*models.SavedFilter, error) {
query := fmt.Sprintf(`SELECT * FROM %s WHERE mode = ? AND name = ?`, savedFilterTable)
var ret models.SavedFilters
- if err := qb.query(query, []interface{}{mode, savedFilterDefaultName}, &ret); err != nil {
+ if err := qb.query(ctx, query, []interface{}{mode, savedFilterDefaultName}, &ret); err != nil {
return nil, err
}
@@ -127,9 +125,9 @@ func (qb *savedFilterQueryBuilder) FindDefault(mode models.FilterMode) (*models.
return nil, nil
}
-func (qb *savedFilterQueryBuilder) All() ([]*models.SavedFilter, error) {
+func (qb *savedFilterQueryBuilder) All(ctx context.Context) ([]*models.SavedFilter, error) {
var ret models.SavedFilters
- if err := qb.query(selectAll(savedFilterTable), nil, &ret); err != nil {
+ if err := qb.query(ctx, selectAll(savedFilterTable), nil, &ret); err != nil {
return nil, err
}
diff --git a/pkg/sqlite/saved_filter_test.go b/pkg/sqlite/saved_filter_test.go
index 5ec049290..c22b374fb 100644
--- a/pkg/sqlite/saved_filter_test.go
+++ b/pkg/sqlite/saved_filter_test.go
@@ -4,15 +4,17 @@
package sqlite_test
import (
+ "context"
"testing"
"github.com/stashapp/stash/pkg/models"
+ "github.com/stashapp/stash/pkg/sqlite"
"github.com/stretchr/testify/assert"
)
func TestSavedFilterFind(t *testing.T) {
- withTxn(func(r models.Repository) error {
- savedFilter, err := r.SavedFilter().Find(savedFilterIDs[savedFilterIdxImage])
+ withTxn(func(ctx context.Context) error {
+ savedFilter, err := sqlite.SavedFilterReaderWriter.Find(ctx, savedFilterIDs[savedFilterIdxImage])
if err != nil {
t.Errorf("Error finding saved filter: %s", err.Error())
@@ -25,8 +27,8 @@ func TestSavedFilterFind(t *testing.T) {
}
func TestSavedFilterFindByMode(t *testing.T) {
- withTxn(func(r models.Repository) error {
- savedFilters, err := r.SavedFilter().FindByMode(models.FilterModeScenes)
+ withTxn(func(ctx context.Context) error {
+ savedFilters, err := sqlite.SavedFilterReaderWriter.FindByMode(ctx, models.FilterModeScenes)
if err != nil {
t.Errorf("Error finding saved filters: %s", err.Error())
@@ -45,8 +47,8 @@ func TestSavedFilterDestroy(t *testing.T) {
var id int
// create the saved filter to destroy
- withTxn(func(r models.Repository) error {
- created, err := r.SavedFilter().Create(models.SavedFilter{
+ withTxn(func(ctx context.Context) error {
+ created, err := sqlite.SavedFilterReaderWriter.Create(ctx, models.SavedFilter{
Name: filterName,
Mode: models.FilterModeScenes,
Filter: testFilter,
@@ -59,15 +61,15 @@ func TestSavedFilterDestroy(t *testing.T) {
return err
})
- withTxn(func(r models.Repository) error {
- qb := r.SavedFilter()
+ withTxn(func(ctx context.Context) error {
+ qb := sqlite.SavedFilterReaderWriter
- return qb.Destroy(id)
+ return qb.Destroy(ctx, id)
})
// now try to find it
- withTxn(func(r models.Repository) error {
- found, err := r.SavedFilter().Find(id)
+ withTxn(func(ctx context.Context) error {
+ found, err := sqlite.SavedFilterReaderWriter.Find(ctx, id)
if err == nil {
assert.Nil(t, found)
}
@@ -77,8 +79,8 @@ func TestSavedFilterDestroy(t *testing.T) {
}
func TestSavedFilterFindDefault(t *testing.T) {
- withTxn(func(r models.Repository) error {
- def, err := r.SavedFilter().FindDefault(models.FilterModeScenes)
+ withTxn(func(ctx context.Context) error {
+ def, err := sqlite.SavedFilterReaderWriter.FindDefault(ctx, models.FilterModeScenes)
if err == nil {
assert.Equal(t, savedFilterIDs[savedFilterIdxDefaultScene], def.ID)
}
@@ -90,8 +92,8 @@ func TestSavedFilterFindDefault(t *testing.T) {
func TestSavedFilterSetDefault(t *testing.T) {
const newFilter = "foo"
- withTxn(func(r models.Repository) error {
- _, err := r.SavedFilter().SetDefault(models.SavedFilter{
+ withTxn(func(ctx context.Context) error {
+ _, err := sqlite.SavedFilterReaderWriter.SetDefault(ctx, models.SavedFilter{
Mode: models.FilterModeMovies,
Filter: newFilter,
})
@@ -100,8 +102,8 @@ func TestSavedFilterSetDefault(t *testing.T) {
})
var defID int
- withTxn(func(r models.Repository) error {
- def, err := r.SavedFilter().FindDefault(models.FilterModeMovies)
+ withTxn(func(ctx context.Context) error {
+ def, err := sqlite.SavedFilterReaderWriter.FindDefault(ctx, models.FilterModeMovies)
if err == nil {
defID = def.ID
assert.Equal(t, newFilter, def.Filter)
@@ -111,8 +113,8 @@ func TestSavedFilterSetDefault(t *testing.T) {
})
// destroy it again
- withTxn(func(r models.Repository) error {
- return r.SavedFilter().Destroy(defID)
+ withTxn(func(ctx context.Context) error {
+ return sqlite.SavedFilterReaderWriter.Destroy(ctx, defID)
})
}
diff --git a/pkg/sqlite/scene.go b/pkg/sqlite/scene.go
index cb5085dfd..921e2f4c3 100644
--- a/pkg/sqlite/scene.go
+++ b/pkg/sqlite/scene.go
@@ -1,6 +1,7 @@
package sqlite
import (
+ "context"
"database/sql"
"errors"
"fmt"
@@ -89,45 +90,42 @@ type sceneQueryBuilder struct {
repository
}
-func NewSceneReaderWriter(tx dbi) *sceneQueryBuilder {
- return &sceneQueryBuilder{
- repository{
- tx: tx,
- tableName: sceneTable,
- idColumn: idColumn,
- },
- }
+var SceneReaderWriter = &sceneQueryBuilder{
+ repository{
+ tableName: sceneTable,
+ idColumn: idColumn,
+ },
}
-func (qb *sceneQueryBuilder) Create(newObject models.Scene) (*models.Scene, error) {
+func (qb *sceneQueryBuilder) Create(ctx context.Context, newObject models.Scene) (*models.Scene, error) {
var ret models.Scene
- if err := qb.insertObject(newObject, &ret); err != nil {
+ if err := qb.insertObject(ctx, newObject, &ret); err != nil {
return nil, err
}
return &ret, nil
}
-func (qb *sceneQueryBuilder) Update(updatedObject models.ScenePartial) (*models.Scene, error) {
+func (qb *sceneQueryBuilder) Update(ctx context.Context, updatedObject models.ScenePartial) (*models.Scene, error) {
const partial = true
- if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil {
+ if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil {
return nil, err
}
- return qb.find(updatedObject.ID)
+ return qb.find(ctx, updatedObject.ID)
}
-func (qb *sceneQueryBuilder) UpdateFull(updatedObject models.Scene) (*models.Scene, error) {
+func (qb *sceneQueryBuilder) UpdateFull(ctx context.Context, updatedObject models.Scene) (*models.Scene, error) {
const partial = false
- if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil {
+ if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil {
return nil, err
}
- return qb.find(updatedObject.ID)
+ return qb.find(ctx, updatedObject.ID)
}
-func (qb *sceneQueryBuilder) UpdateFileModTime(id int, modTime models.NullSQLiteTimestamp) error {
- return qb.updateMap(id, map[string]interface{}{
+func (qb *sceneQueryBuilder) UpdateFileModTime(ctx context.Context, id int, modTime models.NullSQLiteTimestamp) error {
+ return qb.updateMap(ctx, id, map[string]interface{}{
"file_mod_time": modTime,
})
}
@@ -142,17 +140,17 @@ func (qb *sceneQueryBuilder) captionRepository() *captionRepository {
}
}
-func (qb *sceneQueryBuilder) GetCaptions(sceneID int) ([]*models.SceneCaption, error) {
- return qb.captionRepository().get(sceneID)
+func (qb *sceneQueryBuilder) GetCaptions(ctx context.Context, sceneID int) ([]*models.SceneCaption, error) {
+ return qb.captionRepository().get(ctx, sceneID)
}
-func (qb *sceneQueryBuilder) UpdateCaptions(sceneID int, captions []*models.SceneCaption) error {
- return qb.captionRepository().replace(sceneID, captions)
+func (qb *sceneQueryBuilder) UpdateCaptions(ctx context.Context, sceneID int, captions []*models.SceneCaption) error {
+ return qb.captionRepository().replace(ctx, sceneID, captions)
}
-func (qb *sceneQueryBuilder) IncrementOCounter(id int) (int, error) {
- _, err := qb.tx.Exec(
+func (qb *sceneQueryBuilder) IncrementOCounter(ctx context.Context, id int) (int, error) {
+ _, err := qb.tx.Exec(ctx,
`UPDATE scenes SET o_counter = o_counter + 1 WHERE scenes.id = ?`,
id,
)
@@ -160,7 +158,7 @@ func (qb *sceneQueryBuilder) IncrementOCounter(id int) (int, error) {
return 0, err
}
- scene, err := qb.find(id)
+ scene, err := qb.find(ctx, id)
if err != nil {
return 0, err
}
@@ -168,8 +166,8 @@ func (qb *sceneQueryBuilder) IncrementOCounter(id int) (int, error) {
return scene.OCounter, nil
}
-func (qb *sceneQueryBuilder) DecrementOCounter(id int) (int, error) {
- _, err := qb.tx.Exec(
+func (qb *sceneQueryBuilder) DecrementOCounter(ctx context.Context, id int) (int, error) {
+ _, err := qb.tx.Exec(ctx,
`UPDATE scenes SET o_counter = o_counter - 1 WHERE scenes.id = ? and scenes.o_counter > 0`,
id,
)
@@ -177,7 +175,7 @@ func (qb *sceneQueryBuilder) DecrementOCounter(id int) (int, error) {
return 0, err
}
- scene, err := qb.find(id)
+ scene, err := qb.find(ctx, id)
if err != nil {
return 0, err
}
@@ -185,8 +183,8 @@ func (qb *sceneQueryBuilder) DecrementOCounter(id int) (int, error) {
return scene.OCounter, nil
}
-func (qb *sceneQueryBuilder) ResetOCounter(id int) (int, error) {
- _, err := qb.tx.Exec(
+func (qb *sceneQueryBuilder) ResetOCounter(ctx context.Context, id int) (int, error) {
+ _, err := qb.tx.Exec(ctx,
`UPDATE scenes SET o_counter = 0 WHERE scenes.id = ?`,
id,
)
@@ -194,7 +192,7 @@ func (qb *sceneQueryBuilder) ResetOCounter(id int) (int, error) {
return 0, err
}
- scene, err := qb.find(id)
+ scene, err := qb.find(ctx, id)
if err != nil {
return 0, err
}
@@ -202,27 +200,27 @@ func (qb *sceneQueryBuilder) ResetOCounter(id int) (int, error) {
return scene.OCounter, nil
}
-func (qb *sceneQueryBuilder) Destroy(id int) error {
+func (qb *sceneQueryBuilder) Destroy(ctx context.Context, id int) error {
// delete all related table rows
// TODO - this should be handled by a delete cascade
- if err := qb.performersRepository().destroy([]int{id}); err != nil {
+ if err := qb.performersRepository().destroy(ctx, []int{id}); err != nil {
return err
}
// scene markers should be handled prior to calling destroy
// galleries should be handled prior to calling destroy
- return qb.destroyExisting([]int{id})
+ return qb.destroyExisting(ctx, []int{id})
}
-func (qb *sceneQueryBuilder) Find(id int) (*models.Scene, error) {
- return qb.find(id)
+func (qb *sceneQueryBuilder) Find(ctx context.Context, id int) (*models.Scene, error) {
+ return qb.find(ctx, id)
}
-func (qb *sceneQueryBuilder) FindMany(ids []int) ([]*models.Scene, error) {
+func (qb *sceneQueryBuilder) FindMany(ctx context.Context, ids []int) ([]*models.Scene, error) {
var scenes []*models.Scene
for _, id := range ids {
- scene, err := qb.Find(id)
+ scene, err := qb.Find(ctx, id)
if err != nil {
return nil, err
}
@@ -237,9 +235,9 @@ func (qb *sceneQueryBuilder) FindMany(ids []int) ([]*models.Scene, error) {
return scenes, nil
}
-func (qb *sceneQueryBuilder) find(id int) (*models.Scene, error) {
+func (qb *sceneQueryBuilder) find(ctx context.Context, id int) (*models.Scene, error) {
var ret models.Scene
- if err := qb.get(id, &ret); err != nil {
+ if err := qb.getByID(ctx, id, &ret); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
@@ -248,92 +246,92 @@ func (qb *sceneQueryBuilder) find(id int) (*models.Scene, error) {
return &ret, nil
}
-func (qb *sceneQueryBuilder) FindByChecksum(checksum string) (*models.Scene, error) {
+func (qb *sceneQueryBuilder) FindByChecksum(ctx context.Context, checksum string) (*models.Scene, error) {
query := "SELECT * FROM scenes WHERE checksum = ? LIMIT 1"
args := []interface{}{checksum}
- return qb.queryScene(query, args)
+ return qb.queryScene(ctx, query, args)
}
-func (qb *sceneQueryBuilder) FindByOSHash(oshash string) (*models.Scene, error) {
+func (qb *sceneQueryBuilder) FindByOSHash(ctx context.Context, oshash string) (*models.Scene, error) {
query := "SELECT * FROM scenes WHERE oshash = ? LIMIT 1"
args := []interface{}{oshash}
- return qb.queryScene(query, args)
+ return qb.queryScene(ctx, query, args)
}
-func (qb *sceneQueryBuilder) FindByPath(path string) (*models.Scene, error) {
+func (qb *sceneQueryBuilder) FindByPath(ctx context.Context, path string) (*models.Scene, error) {
query := selectAll(sceneTable) + "WHERE path = ? LIMIT 1"
args := []interface{}{path}
- return qb.queryScene(query, args)
+ return qb.queryScene(ctx, query, args)
}
-func (qb *sceneQueryBuilder) FindByPerformerID(performerID int) ([]*models.Scene, error) {
+func (qb *sceneQueryBuilder) FindByPerformerID(ctx context.Context, performerID int) ([]*models.Scene, error) {
args := []interface{}{performerID}
- return qb.queryScenes(scenesForPerformerQuery, args)
+ return qb.queryScenes(ctx, scenesForPerformerQuery, args)
}
-func (qb *sceneQueryBuilder) FindByGalleryID(galleryID int) ([]*models.Scene, error) {
+func (qb *sceneQueryBuilder) FindByGalleryID(ctx context.Context, galleryID int) ([]*models.Scene, error) {
args := []interface{}{galleryID}
- return qb.queryScenes(scenesForGalleryQuery, args)
+ return qb.queryScenes(ctx, scenesForGalleryQuery, args)
}
-func (qb *sceneQueryBuilder) CountByPerformerID(performerID int) (int, error) {
+func (qb *sceneQueryBuilder) CountByPerformerID(ctx context.Context, performerID int) (int, error) {
args := []interface{}{performerID}
- return qb.runCountQuery(qb.buildCountQuery(countScenesForPerformerQuery), args)
+ return qb.runCountQuery(ctx, qb.buildCountQuery(countScenesForPerformerQuery), args)
}
-func (qb *sceneQueryBuilder) FindByMovieID(movieID int) ([]*models.Scene, error) {
+func (qb *sceneQueryBuilder) FindByMovieID(ctx context.Context, movieID int) ([]*models.Scene, error) {
args := []interface{}{movieID}
- return qb.queryScenes(scenesForMovieQuery, args)
+ return qb.queryScenes(ctx, scenesForMovieQuery, args)
}
-func (qb *sceneQueryBuilder) CountByMovieID(movieID int) (int, error) {
+func (qb *sceneQueryBuilder) CountByMovieID(ctx context.Context, movieID int) (int, error) {
args := []interface{}{movieID}
- return qb.runCountQuery(qb.buildCountQuery(scenesForMovieQuery), args)
+ return qb.runCountQuery(ctx, qb.buildCountQuery(scenesForMovieQuery), args)
}
-func (qb *sceneQueryBuilder) Count() (int, error) {
- return qb.runCountQuery(qb.buildCountQuery("SELECT scenes.id FROM scenes"), nil)
+func (qb *sceneQueryBuilder) Count(ctx context.Context) (int, error) {
+ return qb.runCountQuery(ctx, qb.buildCountQuery("SELECT scenes.id FROM scenes"), nil)
}
-func (qb *sceneQueryBuilder) Size() (float64, error) {
- return qb.runSumQuery("SELECT SUM(cast(size as double)) as sum FROM scenes", nil)
+func (qb *sceneQueryBuilder) Size(ctx context.Context) (float64, error) {
+ return qb.runSumQuery(ctx, "SELECT SUM(cast(size as double)) as sum FROM scenes", nil)
}
-func (qb *sceneQueryBuilder) Duration() (float64, error) {
- return qb.runSumQuery("SELECT SUM(cast(duration as double)) as sum FROM scenes", nil)
+func (qb *sceneQueryBuilder) Duration(ctx context.Context) (float64, error) {
+ return qb.runSumQuery(ctx, "SELECT SUM(cast(duration as double)) as sum FROM scenes", nil)
}
-func (qb *sceneQueryBuilder) CountByStudioID(studioID int) (int, error) {
+func (qb *sceneQueryBuilder) CountByStudioID(ctx context.Context, studioID int) (int, error) {
args := []interface{}{studioID}
- return qb.runCountQuery(qb.buildCountQuery(scenesForStudioQuery), args)
+ return qb.runCountQuery(ctx, qb.buildCountQuery(scenesForStudioQuery), args)
}
-func (qb *sceneQueryBuilder) CountByTagID(tagID int) (int, error) {
+func (qb *sceneQueryBuilder) CountByTagID(ctx context.Context, tagID int) (int, error) {
args := []interface{}{tagID}
- return qb.runCountQuery(qb.buildCountQuery(countScenesForTagQuery), args)
+ return qb.runCountQuery(ctx, qb.buildCountQuery(countScenesForTagQuery), args)
}
// CountMissingChecksum returns the number of scenes missing a checksum value.
-func (qb *sceneQueryBuilder) CountMissingChecksum() (int, error) {
- return qb.runCountQuery(qb.buildCountQuery(countScenesForMissingChecksumQuery), []interface{}{})
+func (qb *sceneQueryBuilder) CountMissingChecksum(ctx context.Context) (int, error) {
+ return qb.runCountQuery(ctx, qb.buildCountQuery(countScenesForMissingChecksumQuery), []interface{}{})
}
// CountMissingOSHash returns the number of scenes missing an oshash value.
-func (qb *sceneQueryBuilder) CountMissingOSHash() (int, error) {
- return qb.runCountQuery(qb.buildCountQuery(countScenesForMissingOSHashQuery), []interface{}{})
+func (qb *sceneQueryBuilder) CountMissingOSHash(ctx context.Context) (int, error) {
+ return qb.runCountQuery(ctx, qb.buildCountQuery(countScenesForMissingOSHashQuery), []interface{}{})
}
-func (qb *sceneQueryBuilder) Wall(q *string) ([]*models.Scene, error) {
+func (qb *sceneQueryBuilder) Wall(ctx context.Context, q *string) ([]*models.Scene, error) {
s := ""
if q != nil {
s = *q
}
query := selectAll(sceneTable) + "WHERE scenes.details LIKE '%" + s + "%' ORDER BY RANDOM() LIMIT 80"
- return qb.queryScenes(query, nil)
+ return qb.queryScenes(ctx, query, nil)
}
-func (qb *sceneQueryBuilder) All() ([]*models.Scene, error) {
- return qb.queryScenes(selectAll(sceneTable)+qb.getDefaultSceneSort(), nil)
+func (qb *sceneQueryBuilder) All(ctx context.Context) ([]*models.Scene, error) {
+ return qb.queryScenes(ctx, selectAll(sceneTable)+qb.getDefaultSceneSort(), nil)
}
func illegalFilterCombination(type1, type2 string) error {
@@ -371,61 +369,61 @@ func (qb *sceneQueryBuilder) validateFilter(sceneFilter *models.SceneFilterType)
return nil
}
-func (qb *sceneQueryBuilder) makeFilter(sceneFilter *models.SceneFilterType) *filterBuilder {
+func (qb *sceneQueryBuilder) makeFilter(ctx context.Context, sceneFilter *models.SceneFilterType) *filterBuilder {
query := &filterBuilder{}
if sceneFilter.And != nil {
- query.and(qb.makeFilter(sceneFilter.And))
+ query.and(qb.makeFilter(ctx, sceneFilter.And))
}
if sceneFilter.Or != nil {
- query.or(qb.makeFilter(sceneFilter.Or))
+ query.or(qb.makeFilter(ctx, sceneFilter.Or))
}
if sceneFilter.Not != nil {
- query.not(qb.makeFilter(sceneFilter.Not))
+ query.not(qb.makeFilter(ctx, sceneFilter.Not))
}
- query.handleCriterion(stringCriterionHandler(sceneFilter.Path, "scenes.path"))
- query.handleCriterion(stringCriterionHandler(sceneFilter.Title, "scenes.title"))
- query.handleCriterion(stringCriterionHandler(sceneFilter.Details, "scenes.details"))
- query.handleCriterion(stringCriterionHandler(sceneFilter.Oshash, "scenes.oshash"))
- query.handleCriterion(stringCriterionHandler(sceneFilter.Checksum, "scenes.checksum"))
- query.handleCriterion(phashCriterionHandler(sceneFilter.Phash))
- query.handleCriterion(intCriterionHandler(sceneFilter.Rating, "scenes.rating"))
- query.handleCriterion(intCriterionHandler(sceneFilter.OCounter, "scenes.o_counter"))
- query.handleCriterion(boolCriterionHandler(sceneFilter.Organized, "scenes.organized"))
- query.handleCriterion(durationCriterionHandler(sceneFilter.Duration, "scenes.duration"))
- query.handleCriterion(resolutionCriterionHandler(sceneFilter.Resolution, "scenes.height", "scenes.width"))
- query.handleCriterion(hasMarkersCriterionHandler(sceneFilter.HasMarkers))
- query.handleCriterion(sceneIsMissingCriterionHandler(qb, sceneFilter.IsMissing))
- query.handleCriterion(stringCriterionHandler(sceneFilter.URL, "scenes.url"))
+ query.handleCriterion(ctx, stringCriterionHandler(sceneFilter.Path, "scenes.path"))
+ query.handleCriterion(ctx, stringCriterionHandler(sceneFilter.Title, "scenes.title"))
+ query.handleCriterion(ctx, stringCriterionHandler(sceneFilter.Details, "scenes.details"))
+ query.handleCriterion(ctx, stringCriterionHandler(sceneFilter.Oshash, "scenes.oshash"))
+ query.handleCriterion(ctx, stringCriterionHandler(sceneFilter.Checksum, "scenes.checksum"))
+ query.handleCriterion(ctx, phashCriterionHandler(sceneFilter.Phash))
+ query.handleCriterion(ctx, intCriterionHandler(sceneFilter.Rating, "scenes.rating"))
+ query.handleCriterion(ctx, intCriterionHandler(sceneFilter.OCounter, "scenes.o_counter"))
+ query.handleCriterion(ctx, boolCriterionHandler(sceneFilter.Organized, "scenes.organized"))
+ query.handleCriterion(ctx, durationCriterionHandler(sceneFilter.Duration, "scenes.duration"))
+ query.handleCriterion(ctx, resolutionCriterionHandler(sceneFilter.Resolution, "scenes.height", "scenes.width"))
+ query.handleCriterion(ctx, hasMarkersCriterionHandler(sceneFilter.HasMarkers))
+ query.handleCriterion(ctx, sceneIsMissingCriterionHandler(qb, sceneFilter.IsMissing))
+ query.handleCriterion(ctx, stringCriterionHandler(sceneFilter.URL, "scenes.url"))
- query.handleCriterion(criterionHandlerFunc(func(f *filterBuilder) {
+ query.handleCriterion(ctx, criterionHandlerFunc(func(ctx context.Context, f *filterBuilder) {
if sceneFilter.StashID != nil {
qb.stashIDRepository().join(f, "scene_stash_ids", "scenes.id")
- stringCriterionHandler(sceneFilter.StashID, "scene_stash_ids.stash_id")(f)
+ stringCriterionHandler(sceneFilter.StashID, "scene_stash_ids.stash_id")(ctx, f)
}
}))
- query.handleCriterion(boolCriterionHandler(sceneFilter.Interactive, "scenes.interactive"))
- query.handleCriterion(intCriterionHandler(sceneFilter.InteractiveSpeed, "scenes.interactive_speed"))
+ query.handleCriterion(ctx, boolCriterionHandler(sceneFilter.Interactive, "scenes.interactive"))
+ query.handleCriterion(ctx, intCriterionHandler(sceneFilter.InteractiveSpeed, "scenes.interactive_speed"))
- query.handleCriterion(sceneCaptionCriterionHandler(qb, sceneFilter.Captions))
+ query.handleCriterion(ctx, sceneCaptionCriterionHandler(qb, sceneFilter.Captions))
- query.handleCriterion(sceneTagsCriterionHandler(qb, sceneFilter.Tags))
- query.handleCriterion(sceneTagCountCriterionHandler(qb, sceneFilter.TagCount))
- query.handleCriterion(scenePerformersCriterionHandler(qb, sceneFilter.Performers))
- query.handleCriterion(scenePerformerCountCriterionHandler(qb, sceneFilter.PerformerCount))
- query.handleCriterion(sceneStudioCriterionHandler(qb, sceneFilter.Studios))
- query.handleCriterion(sceneMoviesCriterionHandler(qb, sceneFilter.Movies))
- query.handleCriterion(scenePerformerTagsCriterionHandler(qb, sceneFilter.PerformerTags))
- query.handleCriterion(scenePerformerFavoriteCriterionHandler(sceneFilter.PerformerFavorite))
- query.handleCriterion(scenePerformerAgeCriterionHandler(sceneFilter.PerformerAge))
- query.handleCriterion(scenePhashDuplicatedCriterionHandler(sceneFilter.Duplicated))
+ query.handleCriterion(ctx, sceneTagsCriterionHandler(qb, sceneFilter.Tags))
+ query.handleCriterion(ctx, sceneTagCountCriterionHandler(qb, sceneFilter.TagCount))
+ query.handleCriterion(ctx, scenePerformersCriterionHandler(qb, sceneFilter.Performers))
+ query.handleCriterion(ctx, scenePerformerCountCriterionHandler(qb, sceneFilter.PerformerCount))
+ query.handleCriterion(ctx, sceneStudioCriterionHandler(qb, sceneFilter.Studios))
+ query.handleCriterion(ctx, sceneMoviesCriterionHandler(qb, sceneFilter.Movies))
+ query.handleCriterion(ctx, scenePerformerTagsCriterionHandler(qb, sceneFilter.PerformerTags))
+ query.handleCriterion(ctx, scenePerformerFavoriteCriterionHandler(sceneFilter.PerformerFavorite))
+ query.handleCriterion(ctx, scenePerformerAgeCriterionHandler(sceneFilter.PerformerAge))
+ query.handleCriterion(ctx, scenePhashDuplicatedCriterionHandler(sceneFilter.Duplicated))
return query
}
-func (qb *sceneQueryBuilder) Query(options models.SceneQueryOptions) (*models.SceneQueryResult, error) {
+func (qb *sceneQueryBuilder) Query(ctx context.Context, options models.SceneQueryOptions) (*models.SceneQueryResult, error) {
sceneFilter := options.SceneFilter
findFilter := options.FindFilter
@@ -448,19 +446,19 @@ func (qb *sceneQueryBuilder) Query(options models.SceneQueryOptions) (*models.Sc
if err := qb.validateFilter(sceneFilter); err != nil {
return nil, err
}
- filter := qb.makeFilter(sceneFilter)
+ filter := qb.makeFilter(ctx, sceneFilter)
query.addFilter(filter)
qb.setSceneSort(&query, findFilter)
query.sortAndPagination += getPagination(findFilter)
- result, err := qb.queryGroupedFields(options, query)
+ result, err := qb.queryGroupedFields(ctx, options, query)
if err != nil {
return nil, fmt.Errorf("error querying aggregate fields: %w", err)
}
- idsResult, err := query.findIDs()
+ idsResult, err := query.findIDs(ctx)
if err != nil {
return nil, fmt.Errorf("error finding IDs: %w", err)
}
@@ -469,7 +467,7 @@ func (qb *sceneQueryBuilder) Query(options models.SceneQueryOptions) (*models.Sc
return result, nil
}
-func (qb *sceneQueryBuilder) queryGroupedFields(options models.SceneQueryOptions, query queryBuilder) (*models.SceneQueryResult, error) {
+func (qb *sceneQueryBuilder) queryGroupedFields(ctx context.Context, options models.SceneQueryOptions, query queryBuilder) (*models.SceneQueryResult, error) {
if !options.Count && !options.TotalDuration && !options.TotalSize {
// nothing to do - return empty result
return models.NewSceneQueryResult(qb), nil
@@ -499,7 +497,7 @@ func (qb *sceneQueryBuilder) queryGroupedFields(options models.SceneQueryOptions
Duration float64
Size float64
}{}
- if err := qb.repository.queryStruct(aggregateQuery.toSQL(includeSortPagination), query.args, &out); err != nil {
+ if err := qb.repository.queryStruct(ctx, aggregateQuery.toSQL(includeSortPagination), query.args, &out); err != nil {
return nil, err
}
@@ -511,7 +509,7 @@ func (qb *sceneQueryBuilder) queryGroupedFields(options models.SceneQueryOptions
}
func phashCriterionHandler(phashFilter *models.StringCriterionInput) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if phashFilter != nil {
// convert value to int from hex
// ignore errors
@@ -534,7 +532,7 @@ func phashCriterionHandler(phashFilter *models.StringCriterionInput) criterionHa
}
func scenePhashDuplicatedCriterionHandler(duplicatedFilter *models.PHashDuplicationCriterionInput) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
// TODO: Wishlist item: Implement Distance matching
if duplicatedFilter != nil {
var v string
@@ -549,7 +547,7 @@ func scenePhashDuplicatedCriterionHandler(duplicatedFilter *models.PHashDuplicat
}
func durationCriterionHandler(durationFilter *models.IntCriterionInput, column string) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if durationFilter != nil {
clause, args := getIntCriterionWhereClause("cast("+column+" as int)", *durationFilter)
f.addWhere(clause, args...)
@@ -558,7 +556,7 @@ func durationCriterionHandler(durationFilter *models.IntCriterionInput, column s
}
func resolutionCriterionHandler(resolution *models.ResolutionCriterionInput, heightColumn string, widthColumn string) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if resolution != nil && resolution.Value.IsValid() {
min := resolution.Value.GetMinResolution()
max := resolution.Value.GetMaxResolution()
@@ -580,7 +578,7 @@ func resolutionCriterionHandler(resolution *models.ResolutionCriterionInput, hei
}
func hasMarkersCriterionHandler(hasMarkers *string) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if hasMarkers != nil {
f.addLeftJoin("scene_markers", "", "scene_markers.scene_id = scenes.id")
if *hasMarkers == "true" {
@@ -593,7 +591,7 @@ func hasMarkersCriterionHandler(hasMarkers *string) criterionHandlerFunc {
}
func sceneIsMissingCriterionHandler(qb *sceneQueryBuilder, isMissing *string) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if isMissing != nil && *isMissing != "" {
switch *isMissing {
case "galleries":
@@ -699,7 +697,7 @@ func scenePerformerCountCriterionHandler(qb *sceneQueryBuilder, performerCount *
}
func scenePerformerFavoriteCriterionHandler(performerfavorite *bool) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if performerfavorite != nil {
f.addLeftJoin("performers_scenes", "", "scenes.id = performers_scenes.scene_id")
@@ -719,7 +717,7 @@ GROUP BY performers_scenes.scene_id HAVING SUM(performers.favorite) = 0)`, "nofa
}
func scenePerformerAgeCriterionHandler(performerAge *models.IntCriterionInput) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if performerAge != nil {
f.addInnerJoin("performers_scenes", "", "scenes.id = performers_scenes.scene_id")
f.addInnerJoin("performers", "", "performers_scenes.performer_id = performers.id")
@@ -759,7 +757,7 @@ func sceneMoviesCriterionHandler(qb *sceneQueryBuilder, movies *models.MultiCrit
}
func scenePerformerTagsCriterionHandler(qb *sceneQueryBuilder, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if tags != nil {
if tags.Modifier == models.CriterionModifierIsNull || tags.Modifier == models.CriterionModifierNotNull {
var notClause string
@@ -778,7 +776,7 @@ func scenePerformerTagsCriterionHandler(qb *sceneQueryBuilder, tags *models.Hier
return
}
- valuesClause := getHierarchicalValues(qb.tx, tags.Value, tagTable, "tags_relations", "", tags.Depth)
+ valuesClause := getHierarchicalValues(ctx, qb.tx, tags.Value, tagTable, "tags_relations", "", tags.Depth)
f.addWith(`performer_tags AS (
SELECT ps.scene_id, t.column1 AS root_tag_id FROM performers_scenes ps
@@ -816,17 +814,17 @@ func (qb *sceneQueryBuilder) setSceneSort(query *queryBuilder, findFilter *model
}
}
-func (qb *sceneQueryBuilder) queryScene(query string, args []interface{}) (*models.Scene, error) {
- results, err := qb.queryScenes(query, args)
+func (qb *sceneQueryBuilder) queryScene(ctx context.Context, query string, args []interface{}) (*models.Scene, error) {
+ results, err := qb.queryScenes(ctx, query, args)
if err != nil || len(results) < 1 {
return nil, err
}
return results[0], nil
}
-func (qb *sceneQueryBuilder) queryScenes(query string, args []interface{}) ([]*models.Scene, error) {
+func (qb *sceneQueryBuilder) queryScenes(ctx context.Context, query string, args []interface{}) ([]*models.Scene, error) {
var ret models.Scenes
- if err := qb.query(query, args, &ret); err != nil {
+ if err := qb.query(ctx, query, args, &ret); err != nil {
return nil, err
}
@@ -844,16 +842,16 @@ func (qb *sceneQueryBuilder) imageRepository() *imageRepository {
}
}
-func (qb *sceneQueryBuilder) GetCover(sceneID int) ([]byte, error) {
- return qb.imageRepository().get(sceneID)
+func (qb *sceneQueryBuilder) GetCover(ctx context.Context, sceneID int) ([]byte, error) {
+ return qb.imageRepository().get(ctx, sceneID)
}
-func (qb *sceneQueryBuilder) UpdateCover(sceneID int, image []byte) error {
- return qb.imageRepository().replace(sceneID, image)
+func (qb *sceneQueryBuilder) UpdateCover(ctx context.Context, sceneID int, image []byte) error {
+ return qb.imageRepository().replace(ctx, sceneID, image)
}
-func (qb *sceneQueryBuilder) DestroyCover(sceneID int) error {
- return qb.imageRepository().destroy([]int{sceneID})
+func (qb *sceneQueryBuilder) DestroyCover(ctx context.Context, sceneID int) error {
+ return qb.imageRepository().destroy(ctx, []int{sceneID})
}
func (qb *sceneQueryBuilder) moviesRepository() *repository {
@@ -864,8 +862,8 @@ func (qb *sceneQueryBuilder) moviesRepository() *repository {
}
}
-func (qb *sceneQueryBuilder) GetMovies(id int) (ret []models.MoviesScenes, err error) {
- if err := qb.moviesRepository().getAll(id, func(rows *sqlx.Rows) error {
+func (qb *sceneQueryBuilder) GetMovies(ctx context.Context, id int) (ret []models.MoviesScenes, err error) {
+ if err := qb.moviesRepository().getAll(ctx, id, func(rows *sqlx.Rows) error {
var ms models.MoviesScenes
if err := rows.StructScan(&ms); err != nil {
return err
@@ -880,16 +878,16 @@ func (qb *sceneQueryBuilder) GetMovies(id int) (ret []models.MoviesScenes, err e
return ret, nil
}
-func (qb *sceneQueryBuilder) UpdateMovies(sceneID int, movies []models.MoviesScenes) error {
+func (qb *sceneQueryBuilder) UpdateMovies(ctx context.Context, sceneID int, movies []models.MoviesScenes) error {
// destroy existing joins
r := qb.moviesRepository()
- if err := r.destroy([]int{sceneID}); err != nil {
+ if err := r.destroy(ctx, []int{sceneID}); err != nil {
return err
}
for _, m := range movies {
m.SceneID = sceneID
- if _, err := r.insert(m); err != nil {
+ if _, err := r.insert(ctx, m); err != nil {
return err
}
}
@@ -908,13 +906,13 @@ func (qb *sceneQueryBuilder) performersRepository() *joinRepository {
}
}
-func (qb *sceneQueryBuilder) GetPerformerIDs(id int) ([]int, error) {
- return qb.performersRepository().getIDs(id)
+func (qb *sceneQueryBuilder) GetPerformerIDs(ctx context.Context, id int) ([]int, error) {
+ return qb.performersRepository().getIDs(ctx, id)
}
-func (qb *sceneQueryBuilder) UpdatePerformers(id int, performerIDs []int) error {
+func (qb *sceneQueryBuilder) UpdatePerformers(ctx context.Context, id int, performerIDs []int) error {
// Delete the existing joins and then create new ones
- return qb.performersRepository().replace(id, performerIDs)
+ return qb.performersRepository().replace(ctx, id, performerIDs)
}
func (qb *sceneQueryBuilder) tagsRepository() *joinRepository {
@@ -928,13 +926,13 @@ func (qb *sceneQueryBuilder) tagsRepository() *joinRepository {
}
}
-func (qb *sceneQueryBuilder) GetTagIDs(id int) ([]int, error) {
- return qb.tagsRepository().getIDs(id)
+func (qb *sceneQueryBuilder) GetTagIDs(ctx context.Context, id int) ([]int, error) {
+ return qb.tagsRepository().getIDs(ctx, id)
}
-func (qb *sceneQueryBuilder) UpdateTags(id int, tagIDs []int) error {
+func (qb *sceneQueryBuilder) UpdateTags(ctx context.Context, id int, tagIDs []int) error {
// Delete the existing joins and then create new ones
- return qb.tagsRepository().replace(id, tagIDs)
+ return qb.tagsRepository().replace(ctx, id, tagIDs)
}
func (qb *sceneQueryBuilder) galleriesRepository() *joinRepository {
@@ -948,13 +946,13 @@ func (qb *sceneQueryBuilder) galleriesRepository() *joinRepository {
}
}
-func (qb *sceneQueryBuilder) GetGalleryIDs(id int) ([]int, error) {
- return qb.galleriesRepository().getIDs(id)
+func (qb *sceneQueryBuilder) GetGalleryIDs(ctx context.Context, id int) ([]int, error) {
+ return qb.galleriesRepository().getIDs(ctx, id)
}
-func (qb *sceneQueryBuilder) UpdateGalleries(id int, galleryIDs []int) error {
+func (qb *sceneQueryBuilder) UpdateGalleries(ctx context.Context, id int, galleryIDs []int) error {
// Delete the existing joins and then create new ones
- return qb.galleriesRepository().replace(id, galleryIDs)
+ return qb.galleriesRepository().replace(ctx, id, galleryIDs)
}
func (qb *sceneQueryBuilder) stashIDRepository() *stashIDRepository {
@@ -967,19 +965,19 @@ func (qb *sceneQueryBuilder) stashIDRepository() *stashIDRepository {
}
}
-func (qb *sceneQueryBuilder) GetStashIDs(sceneID int) ([]*models.StashID, error) {
- return qb.stashIDRepository().get(sceneID)
+func (qb *sceneQueryBuilder) GetStashIDs(ctx context.Context, sceneID int) ([]*models.StashID, error) {
+ return qb.stashIDRepository().get(ctx, sceneID)
}
-func (qb *sceneQueryBuilder) UpdateStashIDs(sceneID int, stashIDs []models.StashID) error {
- return qb.stashIDRepository().replace(sceneID, stashIDs)
+func (qb *sceneQueryBuilder) UpdateStashIDs(ctx context.Context, sceneID int, stashIDs []models.StashID) error {
+ return qb.stashIDRepository().replace(ctx, sceneID, stashIDs)
}
-func (qb *sceneQueryBuilder) FindDuplicates(distance int) ([][]*models.Scene, error) {
+func (qb *sceneQueryBuilder) FindDuplicates(ctx context.Context, distance int) ([][]*models.Scene, error) {
var dupeIds [][]int
if distance == 0 {
var ids []string
- if err := qb.tx.Select(&ids, findExactDuplicateQuery); err != nil {
+ if err := qb.tx.Select(ctx, &ids, findExactDuplicateQuery); err != nil {
return nil, err
}
@@ -996,7 +994,7 @@ func (qb *sceneQueryBuilder) FindDuplicates(distance int) ([][]*models.Scene, er
} else {
var hashes []*utils.Phash
- if err := qb.queryFunc(findAllPhashesQuery, nil, false, func(rows *sqlx.Rows) error {
+ if err := qb.queryFunc(ctx, findAllPhashesQuery, nil, false, func(rows *sqlx.Rows) error {
phash := utils.Phash{
Bucket: -1,
}
@@ -1015,7 +1013,7 @@ func (qb *sceneQueryBuilder) FindDuplicates(distance int) ([][]*models.Scene, er
var duplicates [][]*models.Scene
for _, sceneIds := range dupeIds {
- if scenes, err := qb.FindMany(sceneIds); err == nil {
+ if scenes, err := qb.FindMany(ctx, sceneIds); err == nil {
duplicates = append(duplicates, scenes)
}
}
diff --git a/pkg/sqlite/scene_marker.go b/pkg/sqlite/scene_marker.go
index c79c1dc16..c8091d2a5 100644
--- a/pkg/sqlite/scene_marker.go
+++ b/pkg/sqlite/scene_marker.go
@@ -1,11 +1,11 @@
package sqlite
import (
+ "context"
"database/sql"
"errors"
"fmt"
- "github.com/stashapp/stash/pkg/database"
"github.com/stashapp/stash/pkg/models"
)
@@ -22,57 +22,54 @@ type sceneMarkerQueryBuilder struct {
repository
}
-func NewSceneMarkerReaderWriter(tx dbi) *sceneMarkerQueryBuilder {
- return &sceneMarkerQueryBuilder{
- repository{
- tx: tx,
- tableName: sceneMarkerTable,
- idColumn: idColumn,
- },
- }
+var SceneMarkerReaderWriter = &sceneMarkerQueryBuilder{
+ repository{
+ tableName: sceneMarkerTable,
+ idColumn: idColumn,
+ },
}
-func (qb *sceneMarkerQueryBuilder) Create(newObject models.SceneMarker) (*models.SceneMarker, error) {
+func (qb *sceneMarkerQueryBuilder) Create(ctx context.Context, newObject models.SceneMarker) (*models.SceneMarker, error) {
var ret models.SceneMarker
- if err := qb.insertObject(newObject, &ret); err != nil {
+ if err := qb.insertObject(ctx, newObject, &ret); err != nil {
return nil, err
}
return &ret, nil
}
-func (qb *sceneMarkerQueryBuilder) Update(updatedObject models.SceneMarker) (*models.SceneMarker, error) {
+func (qb *sceneMarkerQueryBuilder) Update(ctx context.Context, updatedObject models.SceneMarker) (*models.SceneMarker, error) {
const partial = false
- if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil {
+ if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil {
return nil, err
}
var ret models.SceneMarker
- if err := qb.get(updatedObject.ID, &ret); err != nil {
+ if err := qb.getByID(ctx, updatedObject.ID, &ret); err != nil {
return nil, err
}
return &ret, nil
}
-func (qb *sceneMarkerQueryBuilder) Destroy(id int) error {
- return qb.destroyExisting([]int{id})
+func (qb *sceneMarkerQueryBuilder) Destroy(ctx context.Context, id int) error {
+ return qb.destroyExisting(ctx, []int{id})
}
-func (qb *sceneMarkerQueryBuilder) Find(id int) (*models.SceneMarker, error) {
+func (qb *sceneMarkerQueryBuilder) Find(ctx context.Context, id int) (*models.SceneMarker, error) {
query := "SELECT * FROM scene_markers WHERE id = ? LIMIT 1"
args := []interface{}{id}
- results, err := qb.querySceneMarkers(query, args)
+ results, err := qb.querySceneMarkers(ctx, query, args)
if err != nil || len(results) < 1 {
return nil, err
}
return results[0], nil
}
-func (qb *sceneMarkerQueryBuilder) FindMany(ids []int) ([]*models.SceneMarker, error) {
+func (qb *sceneMarkerQueryBuilder) FindMany(ctx context.Context, ids []int) ([]*models.SceneMarker, error) {
var markers []*models.SceneMarker
for _, id := range ids {
- marker, err := qb.Find(id)
+ marker, err := qb.Find(ctx, id)
if err != nil {
return nil, err
}
@@ -87,7 +84,7 @@ func (qb *sceneMarkerQueryBuilder) FindMany(ids []int) ([]*models.SceneMarker, e
return markers, nil
}
-func (qb *sceneMarkerQueryBuilder) FindBySceneID(sceneID int) ([]*models.SceneMarker, error) {
+func (qb *sceneMarkerQueryBuilder) FindBySceneID(ctx context.Context, sceneID int) ([]*models.SceneMarker, error) {
query := `
SELECT scene_markers.* FROM scene_markers
WHERE scene_markers.scene_id = ?
@@ -95,15 +92,15 @@ func (qb *sceneMarkerQueryBuilder) FindBySceneID(sceneID int) ([]*models.SceneMa
ORDER BY scene_markers.seconds ASC
`
args := []interface{}{sceneID}
- return qb.querySceneMarkers(query, args)
+ return qb.querySceneMarkers(ctx, query, args)
}
-func (qb *sceneMarkerQueryBuilder) CountByTagID(tagID int) (int, error) {
+func (qb *sceneMarkerQueryBuilder) CountByTagID(ctx context.Context, tagID int) (int, error) {
args := []interface{}{tagID, tagID}
- return qb.runCountQuery(qb.buildCountQuery(countSceneMarkersForTagQuery), args)
+ return qb.runCountQuery(ctx, qb.buildCountQuery(countSceneMarkersForTagQuery), args)
}
-func (qb *sceneMarkerQueryBuilder) GetMarkerStrings(q *string, sort *string) ([]*models.MarkerStringsResultType, error) {
+func (qb *sceneMarkerQueryBuilder) GetMarkerStrings(ctx context.Context, q *string, sort *string) ([]*models.MarkerStringsResultType, error) {
query := "SELECT count(*) as `count`, scene_markers.id as id, scene_markers.title as title FROM scene_markers"
if q != nil {
query += " WHERE title LIKE '%" + *q + "%'"
@@ -115,30 +112,30 @@ func (qb *sceneMarkerQueryBuilder) GetMarkerStrings(q *string, sort *string) ([]
query += " ORDER BY title ASC"
}
var args []interface{}
- return qb.queryMarkerStringsResultType(query, args)
+ return qb.queryMarkerStringsResultType(ctx, query, args)
}
-func (qb *sceneMarkerQueryBuilder) Wall(q *string) ([]*models.SceneMarker, error) {
+func (qb *sceneMarkerQueryBuilder) Wall(ctx context.Context, q *string) ([]*models.SceneMarker, error) {
s := ""
if q != nil {
s = *q
}
query := "SELECT scene_markers.* FROM scene_markers WHERE scene_markers.title LIKE '%" + s + "%' ORDER BY RANDOM() LIMIT 80"
- return qb.querySceneMarkers(query, nil)
+ return qb.querySceneMarkers(ctx, query, nil)
}
-func (qb *sceneMarkerQueryBuilder) makeFilter(sceneMarkerFilter *models.SceneMarkerFilterType) *filterBuilder {
+func (qb *sceneMarkerQueryBuilder) makeFilter(ctx context.Context, sceneMarkerFilter *models.SceneMarkerFilterType) *filterBuilder {
query := &filterBuilder{}
- query.handleCriterion(sceneMarkerTagIDCriterionHandler(qb, sceneMarkerFilter.TagID))
- query.handleCriterion(sceneMarkerTagsCriterionHandler(qb, sceneMarkerFilter.Tags))
- query.handleCriterion(sceneMarkerSceneTagsCriterionHandler(qb, sceneMarkerFilter.SceneTags))
- query.handleCriterion(sceneMarkerPerformersCriterionHandler(qb, sceneMarkerFilter.Performers))
+ query.handleCriterion(ctx, sceneMarkerTagIDCriterionHandler(qb, sceneMarkerFilter.TagID))
+ query.handleCriterion(ctx, sceneMarkerTagsCriterionHandler(qb, sceneMarkerFilter.Tags))
+ query.handleCriterion(ctx, sceneMarkerSceneTagsCriterionHandler(qb, sceneMarkerFilter.SceneTags))
+ query.handleCriterion(ctx, sceneMarkerPerformersCriterionHandler(qb, sceneMarkerFilter.Performers))
return query
}
-func (qb *sceneMarkerQueryBuilder) Query(sceneMarkerFilter *models.SceneMarkerFilterType, findFilter *models.FindFilterType) ([]*models.SceneMarker, int, error) {
+func (qb *sceneMarkerQueryBuilder) Query(ctx context.Context, sceneMarkerFilter *models.SceneMarkerFilterType, findFilter *models.FindFilterType) ([]*models.SceneMarker, int, error) {
if sceneMarkerFilter == nil {
sceneMarkerFilter = &models.SceneMarkerFilterType{}
}
@@ -154,19 +151,19 @@ func (qb *sceneMarkerQueryBuilder) Query(sceneMarkerFilter *models.SceneMarkerFi
query.parseQueryString(searchColumns, *q)
}
- filter := qb.makeFilter(sceneMarkerFilter)
+ filter := qb.makeFilter(ctx, sceneMarkerFilter)
query.addFilter(filter)
query.sortAndPagination = qb.getSceneMarkerSort(&query, findFilter) + getPagination(findFilter)
- idsResult, countResult, err := query.executeFind()
+ idsResult, countResult, err := query.executeFind(ctx)
if err != nil {
return nil, 0, err
}
var sceneMarkers []*models.SceneMarker
for _, id := range idsResult {
- sceneMarker, err := qb.Find(id)
+ sceneMarker, err := qb.Find(ctx, id)
if err != nil {
return nil, 0, err
}
@@ -178,7 +175,7 @@ func (qb *sceneMarkerQueryBuilder) Query(sceneMarkerFilter *models.SceneMarkerFi
}
func sceneMarkerTagIDCriterionHandler(qb *sceneMarkerQueryBuilder, tagID *string) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if tagID != nil {
f.addLeftJoin("scene_markers_tags", "", "scene_markers_tags.scene_marker_id = scene_markers.id")
@@ -188,7 +185,7 @@ func sceneMarkerTagIDCriterionHandler(qb *sceneMarkerQueryBuilder, tagID *string
}
func sceneMarkerTagsCriterionHandler(qb *sceneMarkerQueryBuilder, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if tags != nil {
if tags.Modifier == models.CriterionModifierIsNull || tags.Modifier == models.CriterionModifierNotNull {
var notClause string
@@ -205,7 +202,7 @@ func sceneMarkerTagsCriterionHandler(qb *sceneMarkerQueryBuilder, tags *models.H
if len(tags.Value) == 0 {
return
}
- valuesClause := getHierarchicalValues(qb.tx, tags.Value, tagTable, "tags_relations", "", tags.Depth)
+ valuesClause := getHierarchicalValues(ctx, qb.tx, tags.Value, tagTable, "tags_relations", "", tags.Depth)
f.addWith(`marker_tags AS (
SELECT mt.scene_marker_id, t.column1 AS root_tag_id FROM scene_markers_tags mt
@@ -223,7 +220,7 @@ INNER JOIN (` + valuesClause + `) t ON t.column2 = m.primary_tag_id
}
func sceneMarkerSceneTagsCriterionHandler(qb *sceneMarkerQueryBuilder, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if tags != nil {
if tags.Modifier == models.CriterionModifierIsNull || tags.Modifier == models.CriterionModifierNotNull {
var notClause string
@@ -241,7 +238,7 @@ func sceneMarkerSceneTagsCriterionHandler(qb *sceneMarkerQueryBuilder, tags *mod
return
}
- valuesClause := getHierarchicalValues(qb.tx, tags.Value, tagTable, "tags_relations", "", tags.Depth)
+ valuesClause := getHierarchicalValues(ctx, qb.tx, tags.Value, tagTable, "tags_relations", "", tags.Depth)
f.addWith(`scene_tags AS (
SELECT st.scene_id, t.column1 AS root_tag_id FROM scenes_tags st
@@ -269,10 +266,10 @@ func sceneMarkerPerformersCriterionHandler(qb *sceneMarkerQueryBuilder, performe
}
handler := h.handler(performers)
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
// Make sure scenes is included, otherwise excludes filter fails
f.addLeftJoin(sceneTable, "", "scenes.id = scene_markers.scene_id")
- handler(f)
+ handler(ctx, f)
}
}
@@ -289,17 +286,17 @@ func (qb *sceneMarkerQueryBuilder) getSceneMarkerSort(query *queryBuilder, findF
return getSort(sort, direction, tableName)
}
-func (qb *sceneMarkerQueryBuilder) querySceneMarkers(query string, args []interface{}) ([]*models.SceneMarker, error) {
+func (qb *sceneMarkerQueryBuilder) querySceneMarkers(ctx context.Context, query string, args []interface{}) ([]*models.SceneMarker, error) {
var ret models.SceneMarkers
- if err := qb.query(query, args, &ret); err != nil {
+ if err := qb.query(ctx, query, args, &ret); err != nil {
return nil, err
}
return []*models.SceneMarker(ret), nil
}
-func (qb *sceneMarkerQueryBuilder) queryMarkerStringsResultType(query string, args []interface{}) ([]*models.MarkerStringsResultType, error) {
- rows, err := database.DB.Queryx(query, args...)
+func (qb *sceneMarkerQueryBuilder) queryMarkerStringsResultType(ctx context.Context, query string, args []interface{}) ([]*models.MarkerStringsResultType, error) {
+ rows, err := qb.tx.Queryx(ctx, query, args...)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return nil, err
}
@@ -332,11 +329,11 @@ func (qb *sceneMarkerQueryBuilder) tagsRepository() *joinRepository {
}
}
-func (qb *sceneMarkerQueryBuilder) GetTagIDs(id int) ([]int, error) {
- return qb.tagsRepository().getIDs(id)
+func (qb *sceneMarkerQueryBuilder) GetTagIDs(ctx context.Context, id int) ([]int, error) {
+ return qb.tagsRepository().getIDs(ctx, id)
}
-func (qb *sceneMarkerQueryBuilder) UpdateTags(id int, tagIDs []int) error {
+func (qb *sceneMarkerQueryBuilder) UpdateTags(ctx context.Context, id int, tagIDs []int) error {
// Delete the existing joins and then create new ones
- return qb.tagsRepository().replace(id, tagIDs)
+ return qb.tagsRepository().replace(ctx, id, tagIDs)
}
diff --git a/pkg/sqlite/scene_marker_test.go b/pkg/sqlite/scene_marker_test.go
index 2fa0d7501..c0d29162e 100644
--- a/pkg/sqlite/scene_marker_test.go
+++ b/pkg/sqlite/scene_marker_test.go
@@ -4,18 +4,20 @@
package sqlite_test
import (
+ "context"
"testing"
"github.com/stashapp/stash/pkg/models"
+ "github.com/stashapp/stash/pkg/sqlite"
"github.com/stretchr/testify/assert"
)
func TestMarkerFindBySceneID(t *testing.T) {
- withTxn(func(r models.Repository) error {
- mqb := r.SceneMarker()
+ withTxn(func(ctx context.Context) error {
+ mqb := sqlite.SceneMarkerReaderWriter
sceneID := sceneIDs[sceneIdxWithMarkers]
- markers, err := mqb.FindBySceneID(sceneID)
+ markers, err := mqb.FindBySceneID(ctx, sceneID)
if err != nil {
t.Errorf("Error finding markers: %s", err.Error())
@@ -26,7 +28,7 @@ func TestMarkerFindBySceneID(t *testing.T) {
assert.Equal(t, sceneIDs[sceneIdxWithMarkers], int(marker.SceneID.Int64))
}
- markers, err = mqb.FindBySceneID(0)
+ markers, err = mqb.FindBySceneID(ctx, 0)
if err != nil {
t.Errorf("Error finding marker: %s", err.Error())
@@ -39,10 +41,10 @@ func TestMarkerFindBySceneID(t *testing.T) {
}
func TestMarkerCountByTagID(t *testing.T) {
- withTxn(func(r models.Repository) error {
- mqb := r.SceneMarker()
+ withTxn(func(ctx context.Context) error {
+ mqb := sqlite.SceneMarkerReaderWriter
- markerCount, err := mqb.CountByTagID(tagIDs[tagIdxWithPrimaryMarkers])
+ markerCount, err := mqb.CountByTagID(ctx, tagIDs[tagIdxWithPrimaryMarkers])
if err != nil {
t.Errorf("error calling CountByTagID: %s", err.Error())
@@ -50,7 +52,7 @@ func TestMarkerCountByTagID(t *testing.T) {
assert.Equal(t, 3, markerCount)
- markerCount, err = mqb.CountByTagID(tagIDs[tagIdxWithMarkers])
+ markerCount, err = mqb.CountByTagID(ctx, tagIDs[tagIdxWithMarkers])
if err != nil {
t.Errorf("error calling CountByTagID: %s", err.Error())
@@ -58,7 +60,7 @@ func TestMarkerCountByTagID(t *testing.T) {
assert.Equal(t, 1, markerCount)
- markerCount, err = mqb.CountByTagID(0)
+ markerCount, err = mqb.CountByTagID(ctx, 0)
if err != nil {
t.Errorf("error calling CountByTagID: %s", err.Error())
@@ -71,9 +73,9 @@ func TestMarkerCountByTagID(t *testing.T) {
}
func TestMarkerQuerySortBySceneUpdated(t *testing.T) {
- withTxn(func(r models.Repository) error {
+ withTxn(func(ctx context.Context) error {
sort := "scenes_updated_at"
- _, _, err := r.SceneMarker().Query(nil, &models.FindFilterType{
+ _, _, err := sqlite.SceneMarkerReaderWriter.Query(ctx, nil, &models.FindFilterType{
Sort: &sort,
})
@@ -92,9 +94,9 @@ func TestMarkerQueryTags(t *testing.T) {
findFilter *models.FindFilterType
}
- withTxn(func(r models.Repository) error {
+ withTxn(func(ctx context.Context) error {
testTags := func(m *models.SceneMarker, markerFilter *models.SceneMarkerFilterType) {
- tagIDs, err := r.SceneMarker().GetTagIDs(m.ID)
+ tagIDs, err := sqlite.SceneMarkerReaderWriter.GetTagIDs(ctx, m.ID)
if err != nil {
t.Errorf("error getting marker tag ids: %v", err)
}
@@ -129,7 +131,7 @@ func TestMarkerQueryTags(t *testing.T) {
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
- markers := queryMarkers(t, r.SceneMarker(), tc.markerFilter, tc.findFilter)
+ markers := queryMarkers(ctx, t, sqlite.SceneMarkerReaderWriter, tc.markerFilter, tc.findFilter)
assert.Greater(t, len(markers), 0)
for _, m := range markers {
testTags(m, tc.markerFilter)
@@ -148,9 +150,9 @@ func TestMarkerQuerySceneTags(t *testing.T) {
findFilter *models.FindFilterType
}
- withTxn(func(r models.Repository) error {
+ withTxn(func(ctx context.Context) error {
testTags := func(m *models.SceneMarker, markerFilter *models.SceneMarkerFilterType) {
- tagIDs, err := r.Scene().GetTagIDs(int(m.SceneID.Int64))
+ tagIDs, err := sqlite.SceneReaderWriter.GetTagIDs(ctx, int(m.SceneID.Int64))
if err != nil {
t.Errorf("error getting marker tag ids: %v", err)
}
@@ -185,7 +187,7 @@ func TestMarkerQuerySceneTags(t *testing.T) {
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
- markers := queryMarkers(t, r.SceneMarker(), tc.markerFilter, tc.findFilter)
+ markers := queryMarkers(ctx, t, sqlite.SceneMarkerReaderWriter, tc.markerFilter, tc.findFilter)
assert.Greater(t, len(markers), 0)
for _, m := range markers {
testTags(m, tc.markerFilter)
@@ -197,9 +199,9 @@ func TestMarkerQuerySceneTags(t *testing.T) {
})
}
-func queryMarkers(t *testing.T, sqb models.SceneMarkerReader, markerFilter *models.SceneMarkerFilterType, findFilter *models.FindFilterType) []*models.SceneMarker {
+func queryMarkers(ctx context.Context, t *testing.T, sqb models.SceneMarkerReader, markerFilter *models.SceneMarkerFilterType, findFilter *models.FindFilterType) []*models.SceneMarker {
t.Helper()
- result, _, err := sqb.Query(markerFilter, findFilter)
+ result, _, err := sqb.Query(ctx, markerFilter, findFilter)
if err != nil {
t.Errorf("Error querying markers: %v", err)
}
diff --git a/pkg/sqlite/scene_test.go b/pkg/sqlite/scene_test.go
index dc70d7637..da88a0bdc 100644
--- a/pkg/sqlite/scene_test.go
+++ b/pkg/sqlite/scene_test.go
@@ -4,6 +4,7 @@
package sqlite_test
import (
+ "context"
"database/sql"
"fmt"
"math"
@@ -15,16 +16,17 @@ import (
"github.com/stashapp/stash/pkg/hash/md5"
"github.com/stashapp/stash/pkg/models"
+ "github.com/stashapp/stash/pkg/sqlite"
)
func TestSceneFind(t *testing.T) {
- withTxn(func(r models.Repository) error {
+ withTxn(func(ctx context.Context) error {
// assume that the first scene is sceneWithGalleryPath
- sqb := r.Scene()
+ sqb := sqlite.SceneReaderWriter
const sceneIdx = 0
sceneID := sceneIDs[sceneIdx]
- scene, err := sqb.Find(sceneID)
+ scene, err := sqb.Find(ctx, sceneID)
if err != nil {
t.Errorf("Error finding scene: %s", err.Error())
@@ -33,7 +35,7 @@ func TestSceneFind(t *testing.T) {
assert.Equal(t, getSceneStringValue(sceneIdx, "Path"), scene.Path)
sceneID = 0
- scene, err = sqb.Find(sceneID)
+ scene, err = sqb.Find(ctx, sceneID)
if err != nil {
t.Errorf("Error finding scene: %s", err.Error())
@@ -46,12 +48,12 @@ func TestSceneFind(t *testing.T) {
}
func TestSceneFindByPath(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
const sceneIdx = 1
scenePath := getSceneStringValue(sceneIdx, "Path")
- scene, err := sqb.FindByPath(scenePath)
+ scene, err := sqb.FindByPath(ctx, scenePath)
if err != nil {
t.Errorf("Error finding scene: %s", err.Error())
@@ -61,7 +63,7 @@ func TestSceneFindByPath(t *testing.T) {
assert.Equal(t, scenePath, scene.Path)
scenePath = "not exist"
- scene, err = sqb.FindByPath(scenePath)
+ scene, err = sqb.FindByPath(ctx, scenePath)
if err != nil {
t.Errorf("Error finding scene: %s", err.Error())
@@ -74,9 +76,9 @@ func TestSceneFindByPath(t *testing.T) {
}
func TestSceneCountByPerformerID(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
- count, err := sqb.CountByPerformerID(performerIDs[performerIdxWithScene])
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
+ count, err := sqb.CountByPerformerID(ctx, performerIDs[performerIdxWithScene])
if err != nil {
t.Errorf("Error counting scenes: %s", err.Error())
@@ -84,7 +86,7 @@ func TestSceneCountByPerformerID(t *testing.T) {
assert.Equal(t, 1, count)
- count, err = sqb.CountByPerformerID(0)
+ count, err = sqb.CountByPerformerID(ctx, 0)
if err != nil {
t.Errorf("Error counting scenes: %s", err.Error())
@@ -97,12 +99,12 @@ func TestSceneCountByPerformerID(t *testing.T) {
}
func TestSceneWall(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
const sceneIdx = 2
wallQuery := getSceneStringValue(sceneIdx, "Details")
- scenes, err := sqb.Wall(&wallQuery)
+ scenes, err := sqb.Wall(ctx, &wallQuery)
if err != nil {
t.Errorf("Error finding scenes: %s", err.Error())
@@ -114,7 +116,7 @@ func TestSceneWall(t *testing.T) {
assert.Equal(t, getSceneStringValue(sceneIdx, "Path"), scene.Path)
wallQuery = "not exist"
- scenes, err = sqb.Wall(&wallQuery)
+ scenes, err = sqb.Wall(ctx, &wallQuery)
if err != nil {
t.Errorf("Error finding scene: %s", err.Error())
@@ -131,18 +133,18 @@ func TestSceneQueryQ(t *testing.T) {
q := getSceneStringValue(sceneIdx, titleField)
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
- sceneQueryQ(t, sqb, q, sceneIdx)
+ sceneQueryQ(ctx, t, sqb, q, sceneIdx)
return nil
})
}
-func queryScene(t *testing.T, sqb models.SceneReader, sceneFilter *models.SceneFilterType, findFilter *models.FindFilterType) []*models.Scene {
+func queryScene(ctx context.Context, t *testing.T, sqb models.SceneReader, sceneFilter *models.SceneFilterType, findFilter *models.FindFilterType) []*models.Scene {
t.Helper()
- result, err := sqb.Query(models.SceneQueryOptions{
+ result, err := sqb.Query(ctx, models.SceneQueryOptions{
QueryOptions: models.QueryOptions{
FindFilter: findFilter,
},
@@ -152,7 +154,7 @@ func queryScene(t *testing.T, sqb models.SceneReader, sceneFilter *models.SceneF
t.Errorf("Error querying scene: %v", err)
}
- scenes, err := result.Resolve()
+ scenes, err := result.Resolve(ctx)
if err != nil {
t.Errorf("Error resolving scenes: %v", err)
}
@@ -160,11 +162,11 @@ func queryScene(t *testing.T, sqb models.SceneReader, sceneFilter *models.SceneF
return scenes
}
-func sceneQueryQ(t *testing.T, sqb models.SceneReader, q string, expectedSceneIdx int) {
+func sceneQueryQ(ctx context.Context, t *testing.T, sqb models.SceneReader, q string, expectedSceneIdx int) {
filter := models.FindFilterType{
Q: &q,
}
- scenes := queryScene(t, sqb, nil, &filter)
+ scenes := queryScene(ctx, t, sqb, nil, &filter)
assert.Len(t, scenes, 1)
scene := scenes[0]
@@ -172,7 +174,7 @@ func sceneQueryQ(t *testing.T, sqb models.SceneReader, q string, expectedSceneId
// no Q should return all results
filter.Q = nil
- scenes = queryScene(t, sqb, nil, &filter)
+ scenes = queryScene(ctx, t, sqb, nil, &filter)
assert.Len(t, scenes, totalScenes)
}
@@ -257,10 +259,10 @@ func TestSceneQueryPathOr(t *testing.T) {
},
}
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
- scenes := queryScene(t, sqb, &sceneFilter, nil)
+ scenes := queryScene(ctx, t, sqb, &sceneFilter, nil)
assert.Len(t, scenes, 2)
assert.Equal(t, scene1Path, scenes[0].Path)
@@ -288,10 +290,10 @@ func TestSceneQueryPathAndRating(t *testing.T) {
},
}
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
- scenes := queryScene(t, sqb, &sceneFilter, nil)
+ scenes := queryScene(ctx, t, sqb, &sceneFilter, nil)
assert.Len(t, scenes, 1)
assert.Equal(t, scenePath, scenes[0].Path)
@@ -323,10 +325,10 @@ func TestSceneQueryPathNotRating(t *testing.T) {
},
}
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
- scenes := queryScene(t, sqb, &sceneFilter, nil)
+ scenes := queryScene(ctx, t, sqb, &sceneFilter, nil)
for _, scene := range scenes {
verifyString(t, scene.Path, pathCriterion)
@@ -354,24 +356,24 @@ func TestSceneIllegalQuery(t *testing.T) {
Or: &subFilter,
}
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
queryOptions := models.SceneQueryOptions{
SceneFilter: sceneFilter,
}
- _, err := sqb.Query(queryOptions)
+ _, err := sqb.Query(ctx, queryOptions)
assert.NotNil(err)
sceneFilter.Or = nil
sceneFilter.Not = &subFilter
- _, err = sqb.Query(queryOptions)
+ _, err = sqb.Query(ctx, queryOptions)
assert.NotNil(err)
sceneFilter.And = nil
sceneFilter.Or = &subFilter
- _, err = sqb.Query(queryOptions)
+ _, err = sqb.Query(ctx, queryOptions)
assert.NotNil(err)
return nil
@@ -379,11 +381,11 @@ func TestSceneIllegalQuery(t *testing.T) {
}
func verifySceneQuery(t *testing.T, filter models.SceneFilterType, verifyFn func(s *models.Scene)) {
- withTxn(func(r models.Repository) error {
+ withTxn(func(ctx context.Context) error {
t.Helper()
- sqb := r.Scene()
+ sqb := sqlite.SceneReaderWriter
- scenes := queryScene(t, sqb, &filter, nil)
+ scenes := queryScene(ctx, t, sqb, &filter, nil)
// assume it should find at least one
assert.Greater(t, len(scenes), 0)
@@ -397,13 +399,13 @@ func verifySceneQuery(t *testing.T, filter models.SceneFilterType, verifyFn func
}
func verifyScenesPath(t *testing.T, pathCriterion models.StringCriterionInput) {
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
sceneFilter := models.SceneFilterType{
Path: &pathCriterion,
}
- scenes := queryScene(t, sqb, &sceneFilter, nil)
+ scenes := queryScene(ctx, t, sqb, &sceneFilter, nil)
for _, scene := range scenes {
verifyString(t, scene.Path, pathCriterion)
@@ -489,13 +491,13 @@ func TestSceneQueryRating(t *testing.T) {
}
func verifyScenesRating(t *testing.T, ratingCriterion models.IntCriterionInput) {
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
sceneFilter := models.SceneFilterType{
Rating: &ratingCriterion,
}
- scenes := queryScene(t, sqb, &sceneFilter, nil)
+ scenes := queryScene(ctx, t, sqb, &sceneFilter, nil)
for _, scene := range scenes {
verifyInt64(t, scene.Rating, ratingCriterion)
@@ -548,13 +550,13 @@ func TestSceneQueryOCounter(t *testing.T) {
}
func verifyScenesOCounter(t *testing.T, oCounterCriterion models.IntCriterionInput) {
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
sceneFilter := models.SceneFilterType{
OCounter: &oCounterCriterion,
}
- scenes := queryScene(t, sqb, &sceneFilter, nil)
+ scenes := queryScene(ctx, t, sqb, &sceneFilter, nil)
for _, scene := range scenes {
verifyInt(t, scene.OCounter, oCounterCriterion)
@@ -607,13 +609,13 @@ func TestSceneQueryDuration(t *testing.T) {
}
func verifyScenesDuration(t *testing.T, durationCriterion models.IntCriterionInput) {
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
sceneFilter := models.SceneFilterType{
Duration: &durationCriterion,
}
- scenes := queryScene(t, sqb, &sceneFilter, nil)
+ scenes := queryScene(ctx, t, sqb, &sceneFilter, nil)
for _, scene := range scenes {
if durationCriterion.Modifier == models.CriterionModifierEquals {
@@ -661,8 +663,8 @@ func TestSceneQueryResolution(t *testing.T) {
}
func verifyScenesResolution(t *testing.T, resolution models.ResolutionEnum) {
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
sceneFilter := models.SceneFilterType{
Resolution: &models.ResolutionCriterionInput{
Value: resolution,
@@ -670,7 +672,7 @@ func verifyScenesResolution(t *testing.T, resolution models.ResolutionEnum) {
},
}
- scenes := queryScene(t, sqb, &sceneFilter, nil)
+ scenes := queryScene(ctx, t, sqb, &sceneFilter, nil)
for _, scene := range scenes {
verifySceneResolution(t, scene.Height, resolution)
@@ -706,20 +708,20 @@ func TestAllResolutionsHaveResolutionRange(t *testing.T) {
}
func TestSceneQueryResolutionModifiers(t *testing.T) {
- if err := withRollbackTxn(func(r models.Repository) error {
- qb := r.Scene()
- sceneNoResolution, _ := createScene(qb, 0, 0)
- firstScene540P, _ := createScene(qb, 960, 540)
- secondScene540P, _ := createScene(qb, 1280, 719)
- firstScene720P, _ := createScene(qb, 1280, 720)
- secondScene720P, _ := createScene(qb, 1280, 721)
- thirdScene720P, _ := createScene(qb, 1920, 1079)
- scene1080P, _ := createScene(qb, 1920, 1080)
+ if err := withRollbackTxn(func(ctx context.Context) error {
+ qb := sqlite.SceneReaderWriter
+ sceneNoResolution, _ := createScene(ctx, qb, 0, 0)
+ firstScene540P, _ := createScene(ctx, qb, 960, 540)
+ secondScene540P, _ := createScene(ctx, qb, 1280, 719)
+ firstScene720P, _ := createScene(ctx, qb, 1280, 720)
+ secondScene720P, _ := createScene(ctx, qb, 1280, 721)
+ thirdScene720P, _ := createScene(ctx, qb, 1920, 1079)
+ scene1080P, _ := createScene(ctx, qb, 1920, 1080)
- scenesEqualTo720P := queryScenes(t, qb, models.ResolutionEnumStandardHd, models.CriterionModifierEquals)
- scenesNotEqualTo720P := queryScenes(t, qb, models.ResolutionEnumStandardHd, models.CriterionModifierNotEquals)
- scenesGreaterThan720P := queryScenes(t, qb, models.ResolutionEnumStandardHd, models.CriterionModifierGreaterThan)
- scenesLessThan720P := queryScenes(t, qb, models.ResolutionEnumStandardHd, models.CriterionModifierLessThan)
+ scenesEqualTo720P := queryScenes(ctx, t, qb, models.ResolutionEnumStandardHd, models.CriterionModifierEquals)
+ scenesNotEqualTo720P := queryScenes(ctx, t, qb, models.ResolutionEnumStandardHd, models.CriterionModifierNotEquals)
+ scenesGreaterThan720P := queryScenes(ctx, t, qb, models.ResolutionEnumStandardHd, models.CriterionModifierGreaterThan)
+ scenesLessThan720P := queryScenes(ctx, t, qb, models.ResolutionEnumStandardHd, models.CriterionModifierLessThan)
assert.Subset(t, scenesEqualTo720P, []*models.Scene{firstScene720P, secondScene720P, thirdScene720P})
assert.NotSubset(t, scenesEqualTo720P, []*models.Scene{sceneNoResolution, firstScene540P, secondScene540P, scene1080P})
@@ -739,7 +741,7 @@ func TestSceneQueryResolutionModifiers(t *testing.T) {
}
}
-func queryScenes(t *testing.T, queryBuilder models.SceneReaderWriter, resolution models.ResolutionEnum, modifier models.CriterionModifier) []*models.Scene {
+func queryScenes(ctx context.Context, t *testing.T, queryBuilder models.SceneReaderWriter, resolution models.ResolutionEnum, modifier models.CriterionModifier) []*models.Scene {
sceneFilter := models.SceneFilterType{
Resolution: &models.ResolutionCriterionInput{
Value: resolution,
@@ -747,10 +749,10 @@ func queryScenes(t *testing.T, queryBuilder models.SceneReaderWriter, resolution
},
}
- return queryScene(t, queryBuilder, &sceneFilter, nil)
+ return queryScene(ctx, t, queryBuilder, &sceneFilter, nil)
}
-func createScene(queryBuilder models.SceneReaderWriter, width int64, height int64) (*models.Scene, error) {
+func createScene(ctx context.Context, queryBuilder models.SceneReaderWriter, width int64, height int64) (*models.Scene, error) {
name := fmt.Sprintf("TestSceneQueryResolutionModifiers %d %d", width, height)
scene := models.Scene{
Path: name,
@@ -765,12 +767,12 @@ func createScene(queryBuilder models.SceneReaderWriter, width int64, height int6
Checksum: sql.NullString{String: md5.FromString(name), Valid: true},
}
- return queryBuilder.Create(scene)
+ return queryBuilder.Create(ctx, scene)
}
func TestSceneQueryHasMarkers(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
hasMarkers := "true"
sceneFilter := models.SceneFilterType{
HasMarkers: &hasMarkers,
@@ -781,17 +783,17 @@ func TestSceneQueryHasMarkers(t *testing.T) {
Q: &q,
}
- scenes := queryScene(t, sqb, &sceneFilter, &findFilter)
+ scenes := queryScene(ctx, t, sqb, &sceneFilter, &findFilter)
assert.Len(t, scenes, 1)
assert.Equal(t, sceneIDs[sceneIdxWithMarkers], scenes[0].ID)
hasMarkers = "false"
- scenes = queryScene(t, sqb, &sceneFilter, &findFilter)
+ scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter)
assert.Len(t, scenes, 0)
findFilter.Q = nil
- scenes = queryScene(t, sqb, &sceneFilter, &findFilter)
+ scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter)
assert.NotEqual(t, 0, len(scenes))
@@ -805,8 +807,8 @@ func TestSceneQueryHasMarkers(t *testing.T) {
}
func TestSceneQueryIsMissingGallery(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
isMissing := "galleries"
sceneFilter := models.SceneFilterType{
IsMissing: &isMissing,
@@ -817,12 +819,12 @@ func TestSceneQueryIsMissingGallery(t *testing.T) {
Q: &q,
}
- scenes := queryScene(t, sqb, &sceneFilter, &findFilter)
+ scenes := queryScene(ctx, t, sqb, &sceneFilter, &findFilter)
assert.Len(t, scenes, 0)
findFilter.Q = nil
- scenes = queryScene(t, sqb, &sceneFilter, &findFilter)
+ scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter)
// ensure non of the ids equal the one with gallery
for _, scene := range scenes {
@@ -834,8 +836,8 @@ func TestSceneQueryIsMissingGallery(t *testing.T) {
}
func TestSceneQueryIsMissingStudio(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
isMissing := "studio"
sceneFilter := models.SceneFilterType{
IsMissing: &isMissing,
@@ -846,12 +848,12 @@ func TestSceneQueryIsMissingStudio(t *testing.T) {
Q: &q,
}
- scenes := queryScene(t, sqb, &sceneFilter, &findFilter)
+ scenes := queryScene(ctx, t, sqb, &sceneFilter, &findFilter)
assert.Len(t, scenes, 0)
findFilter.Q = nil
- scenes = queryScene(t, sqb, &sceneFilter, &findFilter)
+ scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter)
// ensure non of the ids equal the one with studio
for _, scene := range scenes {
@@ -863,8 +865,8 @@ func TestSceneQueryIsMissingStudio(t *testing.T) {
}
func TestSceneQueryIsMissingMovies(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
isMissing := "movie"
sceneFilter := models.SceneFilterType{
IsMissing: &isMissing,
@@ -875,12 +877,12 @@ func TestSceneQueryIsMissingMovies(t *testing.T) {
Q: &q,
}
- scenes := queryScene(t, sqb, &sceneFilter, &findFilter)
+ scenes := queryScene(ctx, t, sqb, &sceneFilter, &findFilter)
assert.Len(t, scenes, 0)
findFilter.Q = nil
- scenes = queryScene(t, sqb, &sceneFilter, &findFilter)
+ scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter)
// ensure non of the ids equal the one with movies
for _, scene := range scenes {
@@ -892,8 +894,8 @@ func TestSceneQueryIsMissingMovies(t *testing.T) {
}
func TestSceneQueryIsMissingPerformers(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
isMissing := "performers"
sceneFilter := models.SceneFilterType{
IsMissing: &isMissing,
@@ -904,12 +906,12 @@ func TestSceneQueryIsMissingPerformers(t *testing.T) {
Q: &q,
}
- scenes := queryScene(t, sqb, &sceneFilter, &findFilter)
+ scenes := queryScene(ctx, t, sqb, &sceneFilter, &findFilter)
assert.Len(t, scenes, 0)
findFilter.Q = nil
- scenes = queryScene(t, sqb, &sceneFilter, &findFilter)
+ scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter)
assert.True(t, len(scenes) > 0)
@@ -923,14 +925,14 @@ func TestSceneQueryIsMissingPerformers(t *testing.T) {
}
func TestSceneQueryIsMissingDate(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
isMissing := "date"
sceneFilter := models.SceneFilterType{
IsMissing: &isMissing,
}
- scenes := queryScene(t, sqb, &sceneFilter, nil)
+ scenes := queryScene(ctx, t, sqb, &sceneFilter, nil)
// three in four scenes have no date
assert.Len(t, scenes, int(math.Ceil(float64(totalScenes)/4*3)))
@@ -945,8 +947,8 @@ func TestSceneQueryIsMissingDate(t *testing.T) {
}
func TestSceneQueryIsMissingTags(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
isMissing := "tags"
sceneFilter := models.SceneFilterType{
IsMissing: &isMissing,
@@ -957,12 +959,12 @@ func TestSceneQueryIsMissingTags(t *testing.T) {
Q: &q,
}
- scenes := queryScene(t, sqb, &sceneFilter, &findFilter)
+ scenes := queryScene(ctx, t, sqb, &sceneFilter, &findFilter)
assert.Len(t, scenes, 0)
findFilter.Q = nil
- scenes = queryScene(t, sqb, &sceneFilter, &findFilter)
+ scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter)
assert.True(t, len(scenes) > 0)
@@ -971,14 +973,14 @@ func TestSceneQueryIsMissingTags(t *testing.T) {
}
func TestSceneQueryIsMissingRating(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
isMissing := "rating"
sceneFilter := models.SceneFilterType{
IsMissing: &isMissing,
}
- scenes := queryScene(t, sqb, &sceneFilter, nil)
+ scenes := queryScene(ctx, t, sqb, &sceneFilter, nil)
assert.True(t, len(scenes) > 0)
@@ -992,8 +994,8 @@ func TestSceneQueryIsMissingRating(t *testing.T) {
}
func TestSceneQueryPerformers(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
performerCriterion := models.MultiCriterionInput{
Value: []string{
strconv.Itoa(performerIDs[performerIdxWithScene]),
@@ -1006,7 +1008,7 @@ func TestSceneQueryPerformers(t *testing.T) {
Performers: &performerCriterion,
}
- scenes := queryScene(t, sqb, &sceneFilter, nil)
+ scenes := queryScene(ctx, t, sqb, &sceneFilter, nil)
assert.Len(t, scenes, 2)
@@ -1023,7 +1025,7 @@ func TestSceneQueryPerformers(t *testing.T) {
Modifier: models.CriterionModifierIncludesAll,
}
- scenes = queryScene(t, sqb, &sceneFilter, nil)
+ scenes = queryScene(ctx, t, sqb, &sceneFilter, nil)
assert.Len(t, scenes, 1)
assert.Equal(t, sceneIDs[sceneIdxWithTwoPerformers], scenes[0].ID)
@@ -1040,7 +1042,7 @@ func TestSceneQueryPerformers(t *testing.T) {
Q: &q,
}
- scenes = queryScene(t, sqb, &sceneFilter, &findFilter)
+ scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter)
assert.Len(t, scenes, 0)
return nil
@@ -1048,8 +1050,8 @@ func TestSceneQueryPerformers(t *testing.T) {
}
func TestSceneQueryTags(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
tagCriterion := models.HierarchicalMultiCriterionInput{
Value: []string{
strconv.Itoa(tagIDs[tagIdxWithScene]),
@@ -1062,7 +1064,7 @@ func TestSceneQueryTags(t *testing.T) {
Tags: &tagCriterion,
}
- scenes := queryScene(t, sqb, &sceneFilter, nil)
+ scenes := queryScene(ctx, t, sqb, &sceneFilter, nil)
assert.Len(t, scenes, 2)
// ensure ids are correct
@@ -1078,7 +1080,7 @@ func TestSceneQueryTags(t *testing.T) {
Modifier: models.CriterionModifierIncludesAll,
}
- scenes = queryScene(t, sqb, &sceneFilter, nil)
+ scenes = queryScene(ctx, t, sqb, &sceneFilter, nil)
assert.Len(t, scenes, 1)
assert.Equal(t, sceneIDs[sceneIdxWithTwoTags], scenes[0].ID)
@@ -1095,7 +1097,7 @@ func TestSceneQueryTags(t *testing.T) {
Q: &q,
}
- scenes = queryScene(t, sqb, &sceneFilter, &findFilter)
+ scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter)
assert.Len(t, scenes, 0)
return nil
@@ -1103,8 +1105,8 @@ func TestSceneQueryTags(t *testing.T) {
}
func TestSceneQueryPerformerTags(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
tagCriterion := models.HierarchicalMultiCriterionInput{
Value: []string{
strconv.Itoa(tagIDs[tagIdxWithPerformer]),
@@ -1117,7 +1119,7 @@ func TestSceneQueryPerformerTags(t *testing.T) {
PerformerTags: &tagCriterion,
}
- scenes := queryScene(t, sqb, &sceneFilter, nil)
+ scenes := queryScene(ctx, t, sqb, &sceneFilter, nil)
assert.Len(t, scenes, 2)
// ensure ids are correct
@@ -1133,7 +1135,7 @@ func TestSceneQueryPerformerTags(t *testing.T) {
Modifier: models.CriterionModifierIncludesAll,
}
- scenes = queryScene(t, sqb, &sceneFilter, nil)
+ scenes = queryScene(ctx, t, sqb, &sceneFilter, nil)
assert.Len(t, scenes, 1)
assert.Equal(t, sceneIDs[sceneIdxWithPerformerTwoTags], scenes[0].ID)
@@ -1150,7 +1152,7 @@ func TestSceneQueryPerformerTags(t *testing.T) {
Q: &q,
}
- scenes = queryScene(t, sqb, &sceneFilter, &findFilter)
+ scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter)
assert.Len(t, scenes, 0)
tagCriterion = models.HierarchicalMultiCriterionInput{
@@ -1158,22 +1160,22 @@ func TestSceneQueryPerformerTags(t *testing.T) {
}
q = getSceneStringValue(sceneIdx1WithPerformer, titleField)
- scenes = queryScene(t, sqb, &sceneFilter, &findFilter)
+ scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter)
assert.Len(t, scenes, 1)
assert.Equal(t, sceneIDs[sceneIdx1WithPerformer], scenes[0].ID)
q = getSceneStringValue(sceneIdxWithPerformerTag, titleField)
- scenes = queryScene(t, sqb, &sceneFilter, &findFilter)
+ scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter)
assert.Len(t, scenes, 0)
tagCriterion.Modifier = models.CriterionModifierNotNull
- scenes = queryScene(t, sqb, &sceneFilter, &findFilter)
+ scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter)
assert.Len(t, scenes, 1)
assert.Equal(t, sceneIDs[sceneIdxWithPerformerTag], scenes[0].ID)
q = getSceneStringValue(sceneIdx1WithPerformer, titleField)
- scenes = queryScene(t, sqb, &sceneFilter, &findFilter)
+ scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter)
assert.Len(t, scenes, 0)
return nil
@@ -1181,8 +1183,8 @@ func TestSceneQueryPerformerTags(t *testing.T) {
}
func TestSceneQueryStudio(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
studioCriterion := models.HierarchicalMultiCriterionInput{
Value: []string{
strconv.Itoa(studioIDs[studioIdxWithScene]),
@@ -1194,7 +1196,7 @@ func TestSceneQueryStudio(t *testing.T) {
Studios: &studioCriterion,
}
- scenes := queryScene(t, sqb, &sceneFilter, nil)
+ scenes := queryScene(ctx, t, sqb, &sceneFilter, nil)
assert.Len(t, scenes, 1)
@@ -1213,7 +1215,7 @@ func TestSceneQueryStudio(t *testing.T) {
Q: &q,
}
- scenes = queryScene(t, sqb, &sceneFilter, &findFilter)
+ scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter)
assert.Len(t, scenes, 0)
return nil
@@ -1221,8 +1223,8 @@ func TestSceneQueryStudio(t *testing.T) {
}
func TestSceneQueryStudioDepth(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
depth := 2
studioCriterion := models.HierarchicalMultiCriterionInput{
Value: []string{
@@ -1236,16 +1238,16 @@ func TestSceneQueryStudioDepth(t *testing.T) {
Studios: &studioCriterion,
}
- scenes := queryScene(t, sqb, &sceneFilter, nil)
+ scenes := queryScene(ctx, t, sqb, &sceneFilter, nil)
assert.Len(t, scenes, 1)
depth = 1
- scenes = queryScene(t, sqb, &sceneFilter, nil)
+ scenes = queryScene(ctx, t, sqb, &sceneFilter, nil)
assert.Len(t, scenes, 0)
studioCriterion.Value = []string{strconv.Itoa(studioIDs[studioIdxWithParentAndChild])}
- scenes = queryScene(t, sqb, &sceneFilter, nil)
+ scenes = queryScene(ctx, t, sqb, &sceneFilter, nil)
assert.Len(t, scenes, 1)
// ensure id is correct
@@ -1265,15 +1267,15 @@ func TestSceneQueryStudioDepth(t *testing.T) {
Q: &q,
}
- scenes = queryScene(t, sqb, &sceneFilter, &findFilter)
+ scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter)
assert.Len(t, scenes, 0)
depth = 1
- scenes = queryScene(t, sqb, &sceneFilter, &findFilter)
+ scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter)
assert.Len(t, scenes, 1)
studioCriterion.Value = []string{strconv.Itoa(studioIDs[studioIdxWithParentAndChild])}
- scenes = queryScene(t, sqb, &sceneFilter, &findFilter)
+ scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter)
assert.Len(t, scenes, 0)
return nil
@@ -1281,8 +1283,8 @@ func TestSceneQueryStudioDepth(t *testing.T) {
}
func TestSceneQueryMovies(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
movieCriterion := models.MultiCriterionInput{
Value: []string{
strconv.Itoa(movieIDs[movieIdxWithScene]),
@@ -1294,7 +1296,7 @@ func TestSceneQueryMovies(t *testing.T) {
Movies: &movieCriterion,
}
- scenes := queryScene(t, sqb, &sceneFilter, nil)
+ scenes := queryScene(ctx, t, sqb, &sceneFilter, nil)
assert.Len(t, scenes, 1)
@@ -1313,7 +1315,7 @@ func TestSceneQueryMovies(t *testing.T) {
Q: &q,
}
- scenes = queryScene(t, sqb, &sceneFilter, &findFilter)
+ scenes = queryScene(ctx, t, sqb, &sceneFilter, &findFilter)
assert.Len(t, scenes, 0)
return nil
@@ -1328,9 +1330,9 @@ func TestSceneQuerySorting(t *testing.T) {
Direction: &direction,
}
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
- scenes := queryScene(t, sqb, nil, &findFilter)
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
+ scenes := queryScene(ctx, t, sqb, nil, &findFilter)
// scenes should be in same order as indexes
firstScene := scenes[0]
@@ -1342,7 +1344,7 @@ func TestSceneQuerySorting(t *testing.T) {
// sort in descending order
direction = models.SortDirectionEnumDesc
- scenes = queryScene(t, sqb, nil, &findFilter)
+ scenes = queryScene(ctx, t, sqb, nil, &findFilter)
firstScene = scenes[0]
lastScene = scenes[len(scenes)-1]
@@ -1359,9 +1361,9 @@ func TestSceneQueryPagination(t *testing.T) {
PerPage: &perPage,
}
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
- scenes := queryScene(t, sqb, nil, &findFilter)
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
+ scenes := queryScene(ctx, t, sqb, nil, &findFilter)
assert.Len(t, scenes, 1)
@@ -1369,7 +1371,7 @@ func TestSceneQueryPagination(t *testing.T) {
page := 2
findFilter.Page = &page
- scenes = queryScene(t, sqb, nil, &findFilter)
+ scenes = queryScene(ctx, t, sqb, nil, &findFilter)
assert.Len(t, scenes, 1)
secondID := scenes[0].ID
@@ -1378,7 +1380,7 @@ func TestSceneQueryPagination(t *testing.T) {
perPage = 2
page = 1
- scenes = queryScene(t, sqb, nil, &findFilter)
+ scenes = queryScene(ctx, t, sqb, nil, &findFilter)
assert.Len(t, scenes, 2)
assert.Equal(t, firstID, scenes[0].ID)
assert.Equal(t, secondID, scenes[1].ID)
@@ -1407,17 +1409,17 @@ func TestSceneQueryTagCount(t *testing.T) {
}
func verifyScenesTagCount(t *testing.T, tagCountCriterion models.IntCriterionInput) {
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
sceneFilter := models.SceneFilterType{
TagCount: &tagCountCriterion,
}
- scenes := queryScene(t, sqb, &sceneFilter, nil)
+ scenes := queryScene(ctx, t, sqb, &sceneFilter, nil)
assert.Greater(t, len(scenes), 0)
for _, scene := range scenes {
- ids, err := sqb.GetTagIDs(scene.ID)
+ ids, err := sqb.GetTagIDs(ctx, scene.ID)
if err != nil {
return err
}
@@ -1448,17 +1450,17 @@ func TestSceneQueryPerformerCount(t *testing.T) {
}
func verifyScenesPerformerCount(t *testing.T, performerCountCriterion models.IntCriterionInput) {
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
sceneFilter := models.SceneFilterType{
PerformerCount: &performerCountCriterion,
}
- scenes := queryScene(t, sqb, &sceneFilter, nil)
+ scenes := queryScene(ctx, t, sqb, &sceneFilter, nil)
assert.Greater(t, len(scenes), 0)
for _, scene := range scenes {
- ids, err := sqb.GetPerformerIDs(scene.ID)
+ ids, err := sqb.GetPerformerIDs(ctx, scene.ID)
if err != nil {
return err
}
@@ -1470,10 +1472,10 @@ func verifyScenesPerformerCount(t *testing.T, performerCountCriterion models.Int
}
func TestSceneCountByTagID(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
- sceneCount, err := sqb.CountByTagID(tagIDs[tagIdxWithScene])
+ sceneCount, err := sqb.CountByTagID(ctx, tagIDs[tagIdxWithScene])
if err != nil {
t.Errorf("error calling CountByTagID: %s", err.Error())
@@ -1481,7 +1483,7 @@ func TestSceneCountByTagID(t *testing.T) {
assert.Equal(t, 1, sceneCount)
- sceneCount, err = sqb.CountByTagID(0)
+ sceneCount, err = sqb.CountByTagID(ctx, 0)
if err != nil {
t.Errorf("error calling CountByTagID: %s", err.Error())
@@ -1494,10 +1496,10 @@ func TestSceneCountByTagID(t *testing.T) {
}
func TestSceneCountByMovieID(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
- sceneCount, err := sqb.CountByMovieID(movieIDs[movieIdxWithScene])
+ sceneCount, err := sqb.CountByMovieID(ctx, movieIDs[movieIdxWithScene])
if err != nil {
t.Errorf("error calling CountByMovieID: %s", err.Error())
@@ -1505,7 +1507,7 @@ func TestSceneCountByMovieID(t *testing.T) {
assert.Equal(t, 1, sceneCount)
- sceneCount, err = sqb.CountByMovieID(0)
+ sceneCount, err = sqb.CountByMovieID(ctx, 0)
if err != nil {
t.Errorf("error calling CountByMovieID: %s", err.Error())
@@ -1518,10 +1520,10 @@ func TestSceneCountByMovieID(t *testing.T) {
}
func TestSceneCountByStudioID(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
- sceneCount, err := sqb.CountByStudioID(studioIDs[studioIdxWithScene])
+ sceneCount, err := sqb.CountByStudioID(ctx, studioIDs[studioIdxWithScene])
if err != nil {
t.Errorf("error calling CountByStudioID: %s", err.Error())
@@ -1529,7 +1531,7 @@ func TestSceneCountByStudioID(t *testing.T) {
assert.Equal(t, 1, sceneCount)
- sceneCount, err = sqb.CountByStudioID(0)
+ sceneCount, err = sqb.CountByStudioID(ctx, 0)
if err != nil {
t.Errorf("error calling CountByStudioID: %s", err.Error())
@@ -1542,10 +1544,10 @@ func TestSceneCountByStudioID(t *testing.T) {
}
func TestFindByMovieID(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
- scenes, err := sqb.FindByMovieID(movieIDs[movieIdxWithScene])
+ scenes, err := sqb.FindByMovieID(ctx, movieIDs[movieIdxWithScene])
if err != nil {
t.Errorf("error calling FindByMovieID: %s", err.Error())
@@ -1554,7 +1556,7 @@ func TestFindByMovieID(t *testing.T) {
assert.Len(t, scenes, 1)
assert.Equal(t, sceneIDs[sceneIdxWithMovie], scenes[0].ID)
- scenes, err = sqb.FindByMovieID(0)
+ scenes, err = sqb.FindByMovieID(ctx, 0)
if err != nil {
t.Errorf("error calling FindByMovieID: %s", err.Error())
@@ -1567,10 +1569,10 @@ func TestFindByMovieID(t *testing.T) {
}
func TestFindByPerformerID(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Scene()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.SceneReaderWriter
- scenes, err := sqb.FindByPerformerID(performerIDs[performerIdxWithScene])
+ scenes, err := sqb.FindByPerformerID(ctx, performerIDs[performerIdxWithScene])
if err != nil {
t.Errorf("error calling FindByPerformerID: %s", err.Error())
@@ -1579,7 +1581,7 @@ func TestFindByPerformerID(t *testing.T) {
assert.Len(t, scenes, 1)
assert.Equal(t, sceneIDs[sceneIdxWithPerformer], scenes[0].ID)
- scenes, err = sqb.FindByPerformerID(0)
+ scenes, err = sqb.FindByPerformerID(ctx, 0)
if err != nil {
t.Errorf("error calling FindByPerformerID: %s", err.Error())
@@ -1592,8 +1594,8 @@ func TestFindByPerformerID(t *testing.T) {
}
func TestSceneUpdateSceneCover(t *testing.T) {
- if err := withTxn(func(r models.Repository) error {
- qb := r.Scene()
+ if err := withTxn(func(ctx context.Context) error {
+ qb := sqlite.SceneReaderWriter
// create performer to test against
const name = "TestSceneUpdateSceneCover"
@@ -1601,26 +1603,26 @@ func TestSceneUpdateSceneCover(t *testing.T) {
Path: name,
Checksum: sql.NullString{String: md5.FromString(name), Valid: true},
}
- created, err := qb.Create(scene)
+ created, err := qb.Create(ctx, scene)
if err != nil {
return fmt.Errorf("Error creating scene: %s", err.Error())
}
image := []byte("image")
- err = qb.UpdateCover(created.ID, image)
+ err = qb.UpdateCover(ctx, created.ID, image)
if err != nil {
return fmt.Errorf("Error updating scene cover: %s", err.Error())
}
// ensure image set
- storedImage, err := qb.GetCover(created.ID)
+ storedImage, err := qb.GetCover(ctx, created.ID)
if err != nil {
return fmt.Errorf("Error getting image: %s", err.Error())
}
assert.Equal(t, storedImage, image)
// set nil image
- err = qb.UpdateCover(created.ID, nil)
+ err = qb.UpdateCover(ctx, created.ID, nil)
if err == nil {
return fmt.Errorf("Expected error setting nil image")
}
@@ -1632,8 +1634,8 @@ func TestSceneUpdateSceneCover(t *testing.T) {
}
func TestSceneDestroySceneCover(t *testing.T) {
- if err := withTxn(func(r models.Repository) error {
- qb := r.Scene()
+ if err := withTxn(func(ctx context.Context) error {
+ qb := sqlite.SceneReaderWriter
// create performer to test against
const name = "TestSceneDestroySceneCover"
@@ -1641,24 +1643,24 @@ func TestSceneDestroySceneCover(t *testing.T) {
Path: name,
Checksum: sql.NullString{String: md5.FromString(name), Valid: true},
}
- created, err := qb.Create(scene)
+ created, err := qb.Create(ctx, scene)
if err != nil {
return fmt.Errorf("Error creating scene: %s", err.Error())
}
image := []byte("image")
- err = qb.UpdateCover(created.ID, image)
+ err = qb.UpdateCover(ctx, created.ID, image)
if err != nil {
return fmt.Errorf("Error updating scene image: %s", err.Error())
}
- err = qb.DestroyCover(created.ID)
+ err = qb.DestroyCover(ctx, created.ID)
if err != nil {
return fmt.Errorf("Error destroying scene cover: %s", err.Error())
}
// image should be nil
- storedImage, err := qb.GetCover(created.ID)
+ storedImage, err := qb.GetCover(ctx, created.ID)
if err != nil {
return fmt.Errorf("Error getting image: %s", err.Error())
}
@@ -1671,8 +1673,8 @@ func TestSceneDestroySceneCover(t *testing.T) {
}
func TestSceneStashIDs(t *testing.T) {
- if err := withTxn(func(r models.Repository) error {
- qb := r.Scene()
+ if err := withTxn(func(ctx context.Context) error {
+ qb := sqlite.SceneReaderWriter
// create scene to test against
const name = "TestSceneStashIDs"
@@ -1680,12 +1682,12 @@ func TestSceneStashIDs(t *testing.T) {
Path: name,
Checksum: sql.NullString{String: md5.FromString(name), Valid: true},
}
- created, err := qb.Create(scene)
+ created, err := qb.Create(ctx, scene)
if err != nil {
return fmt.Errorf("Error creating scene: %s", err.Error())
}
- testStashIDReaderWriter(t, qb, created.ID)
+ testStashIDReaderWriter(ctx, t, qb, created.ID)
return nil
}); err != nil {
t.Error(err.Error())
@@ -1693,8 +1695,8 @@ func TestSceneStashIDs(t *testing.T) {
}
func TestSceneQueryQTrim(t *testing.T) {
- if err := withTxn(func(r models.Repository) error {
- qb := r.Scene()
+ if err := withTxn(func(ctx context.Context) error {
+ qb := sqlite.SceneReaderWriter
expectedID := sceneIDs[sceneIdxWithSpacedName]
@@ -1717,7 +1719,7 @@ func TestSceneQueryQTrim(t *testing.T) {
f := models.FindFilterType{
Q: &tst.query,
}
- scenes := queryScene(t, qb, nil, &f)
+ scenes := queryScene(ctx, t, qb, nil, &f)
assert.Len(t, scenes, tst.count)
if len(scenes) > 0 {
@@ -1726,7 +1728,7 @@ func TestSceneQueryQTrim(t *testing.T) {
}
findFilter := models.FindFilterType{}
- scenes := queryScene(t, qb, nil, &findFilter)
+ scenes := queryScene(ctx, t, qb, nil, &findFilter)
assert.NotEqual(t, 0, len(scenes))
return nil
diff --git a/pkg/sqlite/scraped_item.go b/pkg/sqlite/scraped_item.go
index 1eafc98a5..1b8216dab 100644
--- a/pkg/sqlite/scraped_item.go
+++ b/pkg/sqlite/scraped_item.go
@@ -1,6 +1,7 @@
package sqlite
import (
+ "context"
"database/sql"
"errors"
@@ -13,41 +14,38 @@ type scrapedItemQueryBuilder struct {
repository
}
-func NewScrapedItemReaderWriter(tx dbi) *scrapedItemQueryBuilder {
- return &scrapedItemQueryBuilder{
- repository{
- tx: tx,
- tableName: scrapedItemTable,
- idColumn: idColumn,
- },
- }
+var ScrapedItemReaderWriter = &scrapedItemQueryBuilder{
+ repository{
+ tableName: scrapedItemTable,
+ idColumn: idColumn,
+ },
}
-func (qb *scrapedItemQueryBuilder) Create(newObject models.ScrapedItem) (*models.ScrapedItem, error) {
+func (qb *scrapedItemQueryBuilder) Create(ctx context.Context, newObject models.ScrapedItem) (*models.ScrapedItem, error) {
var ret models.ScrapedItem
- if err := qb.insertObject(newObject, &ret); err != nil {
+ if err := qb.insertObject(ctx, newObject, &ret); err != nil {
return nil, err
}
return &ret, nil
}
-func (qb *scrapedItemQueryBuilder) Update(updatedObject models.ScrapedItem) (*models.ScrapedItem, error) {
+func (qb *scrapedItemQueryBuilder) Update(ctx context.Context, updatedObject models.ScrapedItem) (*models.ScrapedItem, error) {
const partial = false
- if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil {
+ if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil {
return nil, err
}
- return qb.find(updatedObject.ID)
+ return qb.find(ctx, updatedObject.ID)
}
-func (qb *scrapedItemQueryBuilder) Find(id int) (*models.ScrapedItem, error) {
- return qb.find(id)
+func (qb *scrapedItemQueryBuilder) Find(ctx context.Context, id int) (*models.ScrapedItem, error) {
+ return qb.find(ctx, id)
}
-func (qb *scrapedItemQueryBuilder) find(id int) (*models.ScrapedItem, error) {
+func (qb *scrapedItemQueryBuilder) find(ctx context.Context, id int) (*models.ScrapedItem, error) {
var ret models.ScrapedItem
- if err := qb.get(id, &ret); err != nil {
+ if err := qb.getByID(ctx, id, &ret); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
@@ -56,8 +54,8 @@ func (qb *scrapedItemQueryBuilder) find(id int) (*models.ScrapedItem, error) {
return &ret, nil
}
-func (qb *scrapedItemQueryBuilder) All() ([]*models.ScrapedItem, error) {
- return qb.queryScrapedItems(selectAll("scraped_items")+qb.getScrapedItemsSort(nil), nil)
+func (qb *scrapedItemQueryBuilder) All(ctx context.Context) ([]*models.ScrapedItem, error) {
+ return qb.queryScrapedItems(ctx, selectAll("scraped_items")+qb.getScrapedItemsSort(nil), nil)
}
func (qb *scrapedItemQueryBuilder) getScrapedItemsSort(findFilter *models.FindFilterType) string {
@@ -73,9 +71,9 @@ func (qb *scrapedItemQueryBuilder) getScrapedItemsSort(findFilter *models.FindFi
return getSort(sort, direction, "scraped_items")
}
-func (qb *scrapedItemQueryBuilder) queryScrapedItems(query string, args []interface{}) ([]*models.ScrapedItem, error) {
+func (qb *scrapedItemQueryBuilder) queryScrapedItems(ctx context.Context, query string, args []interface{}) ([]*models.ScrapedItem, error) {
var ret models.ScrapedItems
- if err := qb.query(query, args, &ret); err != nil {
+ if err := qb.query(ctx, query, args, &ret); err != nil {
return nil, err
}
diff --git a/pkg/sqlite/setup_test.go b/pkg/sqlite/setup_test.go
index e07d8aebe..5cb123118 100644
--- a/pkg/sqlite/setup_test.go
+++ b/pkg/sqlite/setup_test.go
@@ -13,13 +13,13 @@ import (
"testing"
"time"
- "github.com/stashapp/stash/pkg/database"
"github.com/stashapp/stash/pkg/gallery"
"github.com/stashapp/stash/pkg/hash/md5"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene"
"github.com/stashapp/stash/pkg/sliceutil/intslice"
"github.com/stashapp/stash/pkg/sqlite"
+ "github.com/stashapp/stash/pkg/txn"
)
const (
@@ -386,20 +386,21 @@ var (
}
)
+var db *sqlite.Database
+
func TestMain(m *testing.M) {
ret := runTests(m)
os.Exit(ret)
}
-func withTxn(f func(r models.Repository) error) error {
- t := sqlite.NewTransactionManager()
- return t.WithTxn(context.TODO(), f)
+func withTxn(f func(ctx context.Context) error) error {
+ return txn.WithTxn(context.Background(), db, f)
}
-func withRollbackTxn(f func(r models.Repository) error) error {
+func withRollbackTxn(f func(ctx context.Context) error) error {
var ret error
- withTxn(func(repo models.Repository) error {
- ret = f(repo)
+ withTxn(func(ctx context.Context) error {
+ ret = f(ctx)
return errors.New("fake error for rollback")
})
@@ -407,7 +408,7 @@ func withRollbackTxn(f func(r models.Repository) error) error {
}
func testTeardown(databaseFile string) {
- err := database.DB.Close()
+ err := db.Close()
if err != nil {
panic(err)
@@ -428,7 +429,9 @@ func runTests(m *testing.M) int {
f.Close()
databaseFile := f.Name()
- if err := database.Initialize(databaseFile); err != nil {
+ db = &sqlite.Database{}
+
+ if err := db.Open(databaseFile); err != nil {
panic(fmt.Sprintf("Could not initialize database: %s", err.Error()))
}
@@ -445,109 +448,109 @@ func runTests(m *testing.M) int {
}
func populateDB() error {
- if err := withTxn(func(r models.Repository) error {
- if err := createScenes(r.Scene(), totalScenes); err != nil {
+ if err := withTxn(func(ctx context.Context) error {
+ if err := createScenes(ctx, sqlite.SceneReaderWriter, totalScenes); err != nil {
return fmt.Errorf("error creating scenes: %s", err.Error())
}
- if err := createImages(r.Image(), totalImages); err != nil {
+ if err := createImages(ctx, sqlite.ImageReaderWriter, totalImages); err != nil {
return fmt.Errorf("error creating images: %s", err.Error())
}
- if err := createGalleries(r.Gallery(), totalGalleries); err != nil {
+ if err := createGalleries(ctx, sqlite.GalleryReaderWriter, totalGalleries); err != nil {
return fmt.Errorf("error creating galleries: %s", err.Error())
}
- if err := createMovies(r.Movie(), moviesNameCase, moviesNameNoCase); err != nil {
+ if err := createMovies(ctx, sqlite.MovieReaderWriter, moviesNameCase, moviesNameNoCase); err != nil {
return fmt.Errorf("error creating movies: %s", err.Error())
}
- if err := createPerformers(r.Performer(), performersNameCase, performersNameNoCase); err != nil {
+ if err := createPerformers(ctx, sqlite.PerformerReaderWriter, performersNameCase, performersNameNoCase); err != nil {
return fmt.Errorf("error creating performers: %s", err.Error())
}
- if err := createTags(r.Tag(), tagsNameCase, tagsNameNoCase); err != nil {
+ if err := createTags(ctx, sqlite.TagReaderWriter, tagsNameCase, tagsNameNoCase); err != nil {
return fmt.Errorf("error creating tags: %s", err.Error())
}
- if err := addTagImage(r.Tag(), tagIdxWithCoverImage); err != nil {
+ if err := addTagImage(ctx, sqlite.TagReaderWriter, tagIdxWithCoverImage); err != nil {
return fmt.Errorf("error adding tag image: %s", err.Error())
}
- if err := createStudios(r.Studio(), studiosNameCase, studiosNameNoCase); err != nil {
+ if err := createStudios(ctx, sqlite.StudioReaderWriter, studiosNameCase, studiosNameNoCase); err != nil {
return fmt.Errorf("error creating studios: %s", err.Error())
}
- if err := createSavedFilters(r.SavedFilter(), totalSavedFilters); err != nil {
+ if err := createSavedFilters(ctx, sqlite.SavedFilterReaderWriter, totalSavedFilters); err != nil {
return fmt.Errorf("error creating saved filters: %s", err.Error())
}
- if err := linkPerformerTags(r.Performer()); err != nil {
+ if err := linkPerformerTags(ctx, sqlite.PerformerReaderWriter); err != nil {
return fmt.Errorf("error linking performer tags: %s", err.Error())
}
- if err := linkSceneGalleries(r.Scene()); err != nil {
+ if err := linkSceneGalleries(ctx, sqlite.SceneReaderWriter); err != nil {
return fmt.Errorf("error linking scenes to galleries: %s", err.Error())
}
- if err := linkSceneMovies(r.Scene()); err != nil {
+ if err := linkSceneMovies(ctx, sqlite.SceneReaderWriter); err != nil {
return fmt.Errorf("error linking scenes to movies: %s", err.Error())
}
- if err := linkScenePerformers(r.Scene()); err != nil {
+ if err := linkScenePerformers(ctx, sqlite.SceneReaderWriter); err != nil {
return fmt.Errorf("error linking scene performers: %s", err.Error())
}
- if err := linkSceneTags(r.Scene()); err != nil {
+ if err := linkSceneTags(ctx, sqlite.SceneReaderWriter); err != nil {
return fmt.Errorf("error linking scene tags: %s", err.Error())
}
- if err := linkSceneStudios(r.Scene()); err != nil {
+ if err := linkSceneStudios(ctx, sqlite.SceneReaderWriter); err != nil {
return fmt.Errorf("error linking scene studios: %s", err.Error())
}
- if err := linkImageGalleries(r.Gallery()); err != nil {
+ if err := linkImageGalleries(ctx, sqlite.GalleryReaderWriter); err != nil {
return fmt.Errorf("error linking gallery images: %s", err.Error())
}
- if err := linkImagePerformers(r.Image()); err != nil {
+ if err := linkImagePerformers(ctx, sqlite.ImageReaderWriter); err != nil {
return fmt.Errorf("error linking image performers: %s", err.Error())
}
- if err := linkImageTags(r.Image()); err != nil {
+ if err := linkImageTags(ctx, sqlite.ImageReaderWriter); err != nil {
return fmt.Errorf("error linking image tags: %s", err.Error())
}
- if err := linkImageStudios(r.Image()); err != nil {
+ if err := linkImageStudios(ctx, sqlite.ImageReaderWriter); err != nil {
return fmt.Errorf("error linking image studio: %s", err.Error())
}
- if err := linkMovieStudios(r.Movie()); err != nil {
+ if err := linkMovieStudios(ctx, sqlite.MovieReaderWriter); err != nil {
return fmt.Errorf("error linking movie studios: %s", err.Error())
}
- if err := linkStudiosParent(r.Studio()); err != nil {
+ if err := linkStudiosParent(ctx, sqlite.StudioReaderWriter); err != nil {
return fmt.Errorf("error linking studios parent: %s", err.Error())
}
- if err := linkGalleryPerformers(r.Gallery()); err != nil {
+ if err := linkGalleryPerformers(ctx, sqlite.GalleryReaderWriter); err != nil {
return fmt.Errorf("error linking gallery performers: %s", err.Error())
}
- if err := linkGalleryTags(r.Gallery()); err != nil {
+ if err := linkGalleryTags(ctx, sqlite.GalleryReaderWriter); err != nil {
return fmt.Errorf("error linking gallery tags: %s", err.Error())
}
- if err := linkGalleryStudios(r.Gallery()); err != nil {
+ if err := linkGalleryStudios(ctx, sqlite.GalleryReaderWriter); err != nil {
return fmt.Errorf("error linking gallery studios: %s", err.Error())
}
- if err := linkTagsParent(r.Tag()); err != nil {
+ if err := linkTagsParent(ctx, sqlite.TagReaderWriter); err != nil {
return fmt.Errorf("error linking tags parent: %s", err.Error())
}
for _, ms := range markerSpecs {
- if err := createMarker(r.SceneMarker(), ms); err != nil {
+ if err := createMarker(ctx, sqlite.SceneMarkerReaderWriter, ms); err != nil {
return fmt.Errorf("error creating scene marker: %s", err.Error())
}
}
@@ -642,7 +645,7 @@ func getObjectDate(index int) models.SQLiteDate {
}
}
-func createScenes(sqb models.SceneReaderWriter, n int) error {
+func createScenes(ctx context.Context, sqb models.SceneReaderWriter, n int) error {
for i := 0; i < n; i++ {
scene := models.Scene{
Path: getSceneStringValue(i, pathField),
@@ -658,7 +661,7 @@ func createScenes(sqb models.SceneReaderWriter, n int) error {
Date: getObjectDate(i),
}
- created, err := sqb.Create(scene)
+ created, err := sqb.Create(ctx, scene)
if err != nil {
return fmt.Errorf("Error creating scene %v+: %s", scene, err.Error())
@@ -683,7 +686,7 @@ func getImagePath(index int) string {
return getImageStringValue(index, pathField)
}
-func createImages(qb models.ImageReaderWriter, n int) error {
+func createImages(ctx context.Context, qb models.ImageReaderWriter, n int) error {
for i := 0; i < n; i++ {
image := models.Image{
Path: getImagePath(i),
@@ -695,7 +698,7 @@ func createImages(qb models.ImageReaderWriter, n int) error {
Width: getWidth(i),
}
- created, err := qb.Create(image)
+ created, err := qb.Create(ctx, image)
if err != nil {
return fmt.Errorf("Error creating image %v+: %s", image, err.Error())
@@ -715,7 +718,7 @@ func getGalleryNullStringValue(index int, field string) sql.NullString {
return getPrefixedNullStringValue("gallery", index, field)
}
-func createGalleries(gqb models.GalleryReaderWriter, n int) error {
+func createGalleries(ctx context.Context, gqb models.GalleryReaderWriter, n int) error {
for i := 0; i < n; i++ {
gallery := models.Gallery{
Path: models.NullString(getGalleryStringValue(i, pathField)),
@@ -726,7 +729,7 @@ func createGalleries(gqb models.GalleryReaderWriter, n int) error {
Date: getObjectDate(i),
}
- created, err := gqb.Create(gallery)
+ created, err := gqb.Create(ctx, gallery)
if err != nil {
return fmt.Errorf("Error creating gallery %v+: %s", gallery, err.Error())
@@ -747,7 +750,7 @@ func getMovieNullStringValue(index int, field string) sql.NullString {
}
// createMoviees creates n movies with plain Name and o movies with camel cased NaMe included
-func createMovies(mqb models.MovieReaderWriter, n int, o int) error {
+func createMovies(ctx context.Context, mqb models.MovieReaderWriter, n int, o int) error {
const namePlain = "Name"
const nameNoCase = "NaMe"
@@ -768,7 +771,7 @@ func createMovies(mqb models.MovieReaderWriter, n int, o int) error {
Checksum: md5.FromString(name),
}
- created, err := mqb.Create(movie)
+ created, err := mqb.Create(ctx, movie)
if err != nil {
return fmt.Errorf("Error creating movie [%d] %v+: %s", i, movie, err.Error())
@@ -828,7 +831,7 @@ func getIgnoreAutoTag(index int) bool {
}
// createPerformers creates n performers with plain Name and o performers with camel cased NaMe included
-func createPerformers(pqb models.PerformerReaderWriter, n int, o int) error {
+func createPerformers(ctx context.Context, pqb models.PerformerReaderWriter, n int, o int) error {
const namePlain = "Name"
const nameNoCase = "NaMe"
@@ -864,7 +867,7 @@ func createPerformers(pqb models.PerformerReaderWriter, n int, o int) error {
performer.CareerLength = models.NullString(*careerLength)
}
- created, err := pqb.Create(performer)
+ created, err := pqb.Create(ctx, performer)
if err != nil {
return fmt.Errorf("Error creating performer %v+: %s", performer, err.Error())
@@ -942,7 +945,7 @@ func getTagChildCount(id int) int {
}
//createTags creates n tags with plain Name and o tags with camel cased NaMe included
-func createTags(tqb models.TagReaderWriter, n int, o int) error {
+func createTags(ctx context.Context, tqb models.TagReaderWriter, n int, o int) error {
const namePlain = "Name"
const nameNoCase = "NaMe"
@@ -962,7 +965,7 @@ func createTags(tqb models.TagReaderWriter, n int, o int) error {
IgnoreAutoTag: getIgnoreAutoTag(i),
}
- created, err := tqb.Create(tag)
+ created, err := tqb.Create(ctx, tag)
if err != nil {
return fmt.Errorf("Error creating tag %v+: %s", tag, err.Error())
@@ -970,7 +973,7 @@ func createTags(tqb models.TagReaderWriter, n int, o int) error {
// add alias
alias := getTagStringValue(i, "Alias")
- if err := tqb.UpdateAliases(created.ID, []string{alias}); err != nil {
+ if err := tqb.UpdateAliases(ctx, created.ID, []string{alias}); err != nil {
return fmt.Errorf("error setting tag alias: %s", err.Error())
}
@@ -989,7 +992,7 @@ func getStudioNullStringValue(index int, field string) sql.NullString {
return getPrefixedNullStringValue("studio", index, field)
}
-func createStudio(sqb models.StudioReaderWriter, name string, parentID *int64) (*models.Studio, error) {
+func createStudio(ctx context.Context, sqb models.StudioReaderWriter, name string, parentID *int64) (*models.Studio, error) {
studio := models.Studio{
Name: sql.NullString{String: name, Valid: true},
Checksum: md5.FromString(name),
@@ -999,11 +1002,11 @@ func createStudio(sqb models.StudioReaderWriter, name string, parentID *int64) (
studio.ParentID = sql.NullInt64{Int64: *parentID, Valid: true}
}
- return createStudioFromModel(sqb, studio)
+ return createStudioFromModel(ctx, sqb, studio)
}
-func createStudioFromModel(sqb models.StudioReaderWriter, studio models.Studio) (*models.Studio, error) {
- created, err := sqb.Create(studio)
+func createStudioFromModel(ctx context.Context, sqb models.StudioReaderWriter, studio models.Studio) (*models.Studio, error) {
+ created, err := sqb.Create(ctx, studio)
if err != nil {
return nil, fmt.Errorf("Error creating studio %v+: %s", studio, err.Error())
@@ -1013,7 +1016,7 @@ func createStudioFromModel(sqb models.StudioReaderWriter, studio models.Studio)
}
// createStudios creates n studios with plain Name and o studios with camel cased NaMe included
-func createStudios(sqb models.StudioReaderWriter, n int, o int) error {
+func createStudios(ctx context.Context, sqb models.StudioReaderWriter, n int, o int) error {
const namePlain = "Name"
const nameNoCase = "NaMe"
@@ -1034,7 +1037,7 @@ func createStudios(sqb models.StudioReaderWriter, n int, o int) error {
URL: getStudioNullStringValue(index, urlField),
IgnoreAutoTag: getIgnoreAutoTag(i),
}
- created, err := createStudioFromModel(sqb, studio)
+ created, err := createStudioFromModel(ctx, sqb, studio)
if err != nil {
return err
@@ -1042,7 +1045,7 @@ func createStudios(sqb models.StudioReaderWriter, n int, o int) error {
// add alias
alias := getStudioStringValue(i, "Alias")
- if err := sqb.UpdateAliases(created.ID, []string{alias}); err != nil {
+ if err := sqb.UpdateAliases(ctx, created.ID, []string{alias}); err != nil {
return fmt.Errorf("error setting studio alias: %s", err.Error())
}
@@ -1053,13 +1056,13 @@ func createStudios(sqb models.StudioReaderWriter, n int, o int) error {
return nil
}
-func createMarker(mqb models.SceneMarkerReaderWriter, markerSpec markerSpec) error {
+func createMarker(ctx context.Context, mqb models.SceneMarkerReaderWriter, markerSpec markerSpec) error {
marker := models.SceneMarker{
SceneID: sql.NullInt64{Int64: int64(sceneIDs[markerSpec.sceneIdx]), Valid: true},
PrimaryTagID: tagIDs[markerSpec.primaryTagIdx],
}
- created, err := mqb.Create(marker)
+ created, err := mqb.Create(ctx, marker)
if err != nil {
return fmt.Errorf("error creating marker %v+: %w", marker, err)
@@ -1074,7 +1077,7 @@ func createMarker(mqb models.SceneMarkerReaderWriter, markerSpec markerSpec) err
newTagIDs = append(newTagIDs, tagIDs[tagIdx])
}
- if err := mqb.UpdateTags(created.ID, newTagIDs); err != nil {
+ if err := mqb.UpdateTags(ctx, created.ID, newTagIDs); err != nil {
return fmt.Errorf("error creating marker/tag join: %w", err)
}
}
@@ -1107,7 +1110,7 @@ func getSavedFilterName(index int) string {
return getPrefixedStringValue("savedFilter", index, "Name")
}
-func createSavedFilters(qb models.SavedFilterReaderWriter, n int) error {
+func createSavedFilters(ctx context.Context, qb models.SavedFilterReaderWriter, n int) error {
for i := 0; i < n; i++ {
savedFilter := models.SavedFilter{
Mode: getSavedFilterMode(i),
@@ -1115,7 +1118,7 @@ func createSavedFilters(qb models.SavedFilterReaderWriter, n int) error {
Filter: getPrefixedStringValue("savedFilter", i, "Filter"),
}
- created, err := qb.Create(savedFilter)
+ created, err := qb.Create(ctx, savedFilter)
if err != nil {
return fmt.Errorf("Error creating saved filter %v+: %s", savedFilter, err.Error())
@@ -1137,25 +1140,25 @@ func doLinks(links [][2]int, fn func(idx1, idx2 int) error) error {
return nil
}
-func linkPerformerTags(qb models.PerformerReaderWriter) error {
+func linkPerformerTags(ctx context.Context, qb models.PerformerReaderWriter) error {
return doLinks(performerTagLinks, func(performerIndex, tagIndex int) error {
performerID := performerIDs[performerIndex]
tagID := tagIDs[tagIndex]
- tagIDs, err := qb.GetTagIDs(performerID)
+ tagIDs, err := qb.GetTagIDs(ctx, performerID)
if err != nil {
return err
}
tagIDs = intslice.IntAppendUnique(tagIDs, tagID)
- return qb.UpdateTags(performerID, tagIDs)
+ return qb.UpdateTags(ctx, performerID, tagIDs)
})
}
-func linkSceneMovies(qb models.SceneReaderWriter) error {
+func linkSceneMovies(ctx context.Context, qb models.SceneReaderWriter) error {
return doLinks(sceneMovieLinks, func(sceneIndex, movieIndex int) error {
sceneID := sceneIDs[sceneIndex]
- movies, err := qb.GetMovies(sceneID)
+ movies, err := qb.GetMovies(ctx, sceneID)
if err != nil {
return err
}
@@ -1164,157 +1167,157 @@ func linkSceneMovies(qb models.SceneReaderWriter) error {
MovieID: movieIDs[movieIndex],
SceneID: sceneID,
})
- return qb.UpdateMovies(sceneID, movies)
+ return qb.UpdateMovies(ctx, sceneID, movies)
})
}
-func linkScenePerformers(qb models.SceneReaderWriter) error {
+func linkScenePerformers(ctx context.Context, qb models.SceneReaderWriter) error {
return doLinks(scenePerformerLinks, func(sceneIndex, performerIndex int) error {
- _, err := scene.AddPerformer(qb, sceneIDs[sceneIndex], performerIDs[performerIndex])
+ _, err := scene.AddPerformer(ctx, qb, sceneIDs[sceneIndex], performerIDs[performerIndex])
return err
})
}
-func linkSceneGalleries(qb models.SceneReaderWriter) error {
+func linkSceneGalleries(ctx context.Context, qb models.SceneReaderWriter) error {
return doLinks(sceneGalleryLinks, func(sceneIndex, galleryIndex int) error {
- _, err := scene.AddGallery(qb, sceneIDs[sceneIndex], galleryIDs[galleryIndex])
+ _, err := scene.AddGallery(ctx, qb, sceneIDs[sceneIndex], galleryIDs[galleryIndex])
return err
})
}
-func linkSceneTags(qb models.SceneReaderWriter) error {
+func linkSceneTags(ctx context.Context, qb models.SceneReaderWriter) error {
return doLinks(sceneTagLinks, func(sceneIndex, tagIndex int) error {
- _, err := scene.AddTag(qb, sceneIDs[sceneIndex], tagIDs[tagIndex])
+ _, err := scene.AddTag(ctx, qb, sceneIDs[sceneIndex], tagIDs[tagIndex])
return err
})
}
-func linkSceneStudios(sqb models.SceneWriter) error {
+func linkSceneStudios(ctx context.Context, sqb models.SceneWriter) error {
return doLinks(sceneStudioLinks, func(sceneIndex, studioIndex int) error {
scene := models.ScenePartial{
ID: sceneIDs[sceneIndex],
StudioID: &sql.NullInt64{Int64: int64(studioIDs[studioIndex]), Valid: true},
}
- _, err := sqb.Update(scene)
+ _, err := sqb.Update(ctx, scene)
return err
})
}
-func linkImageGalleries(gqb models.GalleryReaderWriter) error {
+func linkImageGalleries(ctx context.Context, gqb models.GalleryReaderWriter) error {
return doLinks(imageGalleryLinks, func(imageIndex, galleryIndex int) error {
- return gallery.AddImage(gqb, galleryIDs[galleryIndex], imageIDs[imageIndex])
+ return gallery.AddImage(ctx, gqb, galleryIDs[galleryIndex], imageIDs[imageIndex])
})
}
-func linkImageTags(iqb models.ImageReaderWriter) error {
+func linkImageTags(ctx context.Context, iqb models.ImageReaderWriter) error {
return doLinks(imageTagLinks, func(imageIndex, tagIndex int) error {
imageID := imageIDs[imageIndex]
- tags, err := iqb.GetTagIDs(imageID)
+ tags, err := iqb.GetTagIDs(ctx, imageID)
if err != nil {
return err
}
tags = append(tags, tagIDs[tagIndex])
- return iqb.UpdateTags(imageID, tags)
+ return iqb.UpdateTags(ctx, imageID, tags)
})
}
-func linkImageStudios(qb models.ImageWriter) error {
+func linkImageStudios(ctx context.Context, qb models.ImageWriter) error {
return doLinks(imageStudioLinks, func(imageIndex, studioIndex int) error {
image := models.ImagePartial{
ID: imageIDs[imageIndex],
StudioID: &sql.NullInt64{Int64: int64(studioIDs[studioIndex]), Valid: true},
}
- _, err := qb.Update(image)
+ _, err := qb.Update(ctx, image)
return err
})
}
-func linkImagePerformers(qb models.ImageReaderWriter) error {
+func linkImagePerformers(ctx context.Context, qb models.ImageReaderWriter) error {
return doLinks(imagePerformerLinks, func(imageIndex, performerIndex int) error {
imageID := imageIDs[imageIndex]
- performers, err := qb.GetPerformerIDs(imageID)
+ performers, err := qb.GetPerformerIDs(ctx, imageID)
if err != nil {
return err
}
performers = append(performers, performerIDs[performerIndex])
- return qb.UpdatePerformers(imageID, performers)
+ return qb.UpdatePerformers(ctx, imageID, performers)
})
}
-func linkGalleryPerformers(qb models.GalleryReaderWriter) error {
+func linkGalleryPerformers(ctx context.Context, qb models.GalleryReaderWriter) error {
return doLinks(galleryPerformerLinks, func(galleryIndex, performerIndex int) error {
galleryID := galleryIDs[galleryIndex]
- performers, err := qb.GetPerformerIDs(galleryID)
+ performers, err := qb.GetPerformerIDs(ctx, galleryID)
if err != nil {
return err
}
performers = append(performers, performerIDs[performerIndex])
- return qb.UpdatePerformers(galleryID, performers)
+ return qb.UpdatePerformers(ctx, galleryID, performers)
})
}
-func linkGalleryStudios(qb models.GalleryReaderWriter) error {
+func linkGalleryStudios(ctx context.Context, qb models.GalleryReaderWriter) error {
return doLinks(galleryStudioLinks, func(galleryIndex, studioIndex int) error {
gallery := models.GalleryPartial{
ID: galleryIDs[galleryIndex],
StudioID: &sql.NullInt64{Int64: int64(studioIDs[studioIndex]), Valid: true},
}
- _, err := qb.UpdatePartial(gallery)
+ _, err := qb.UpdatePartial(ctx, gallery)
return err
})
}
-func linkGalleryTags(qb models.GalleryReaderWriter) error {
+func linkGalleryTags(ctx context.Context, qb models.GalleryReaderWriter) error {
return doLinks(galleryTagLinks, func(galleryIndex, tagIndex int) error {
galleryID := galleryIDs[galleryIndex]
- tags, err := qb.GetTagIDs(galleryID)
+ tags, err := qb.GetTagIDs(ctx, galleryID)
if err != nil {
return err
}
tags = append(tags, tagIDs[tagIndex])
- return qb.UpdateTags(galleryID, tags)
+ return qb.UpdateTags(ctx, galleryID, tags)
})
}
-func linkMovieStudios(mqb models.MovieWriter) error {
+func linkMovieStudios(ctx context.Context, mqb models.MovieWriter) error {
return doLinks(movieStudioLinks, func(movieIndex, studioIndex int) error {
movie := models.MoviePartial{
ID: movieIDs[movieIndex],
StudioID: &sql.NullInt64{Int64: int64(studioIDs[studioIndex]), Valid: true},
}
- _, err := mqb.Update(movie)
+ _, err := mqb.Update(ctx, movie)
return err
})
}
-func linkStudiosParent(qb models.StudioWriter) error {
+func linkStudiosParent(ctx context.Context, qb models.StudioWriter) error {
return doLinks(studioParentLinks, func(parentIndex, childIndex int) error {
studio := models.StudioPartial{
ID: studioIDs[childIndex],
ParentID: &sql.NullInt64{Int64: int64(studioIDs[parentIndex]), Valid: true},
}
- _, err := qb.Update(studio)
+ _, err := qb.Update(ctx, studio)
return err
})
}
-func linkTagsParent(qb models.TagReaderWriter) error {
+func linkTagsParent(ctx context.Context, qb models.TagReaderWriter) error {
return doLinks(tagParentLinks, func(parentIndex, childIndex int) error {
tagID := tagIDs[childIndex]
- parentTags, err := qb.FindByChildTagID(tagID)
+ parentTags, err := qb.FindByChildTagID(ctx, tagID)
if err != nil {
return err
}
@@ -1326,10 +1329,10 @@ func linkTagsParent(qb models.TagReaderWriter) error {
parentIDs = append(parentIDs, tagIDs[parentIndex])
- return qb.UpdateParentTags(tagID, parentIDs)
+ return qb.UpdateParentTags(ctx, tagID, parentIDs)
})
}
-func addTagImage(qb models.TagWriter, tagIndex int) error {
- return qb.UpdateImage(tagIDs[tagIndex], models.DefaultTagImage)
+func addTagImage(ctx context.Context, qb models.TagWriter, tagIndex int) error {
+ return qb.UpdateImage(ctx, tagIDs[tagIndex], models.DefaultTagImage)
}
diff --git a/pkg/sqlite/sql.go b/pkg/sqlite/sql.go
index a83612b0b..353e1a130 100644
--- a/pkg/sqlite/sql.go
+++ b/pkg/sqlite/sql.go
@@ -1,6 +1,7 @@
package sqlite
import (
+ "context"
"database/sql"
"errors"
"fmt"
@@ -225,8 +226,8 @@ func getCountCriterionClause(primaryTable, joinTable, primaryFK string, criterio
return getIntCriterionWhereClause(lhs, criterion)
}
-func getImage(tx dbi, query string, args ...interface{}) ([]byte, error) {
- rows, err := tx.Queryx(query, args...)
+func getImage(ctx context.Context, tx dbi, query string, args ...interface{}) ([]byte, error) {
+ rows, err := tx.Queryx(ctx, query, args...)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return nil, err
diff --git a/pkg/sqlite/stash_id_test.go b/pkg/sqlite/stash_id_test.go
index 0f57bef19..44d08efbc 100644
--- a/pkg/sqlite/stash_id_test.go
+++ b/pkg/sqlite/stash_id_test.go
@@ -4,6 +4,7 @@
package sqlite_test
import (
+ "context"
"testing"
"github.com/stashapp/stash/pkg/models"
@@ -11,16 +12,16 @@ import (
)
type stashIDReaderWriter interface {
- GetStashIDs(performerID int) ([]*models.StashID, error)
- UpdateStashIDs(performerID int, stashIDs []models.StashID) error
+ GetStashIDs(ctx context.Context, performerID int) ([]*models.StashID, error)
+ UpdateStashIDs(ctx context.Context, performerID int, stashIDs []models.StashID) error
}
-func testStashIDReaderWriter(t *testing.T, r stashIDReaderWriter, id int) {
+func testStashIDReaderWriter(ctx context.Context, t *testing.T, r stashIDReaderWriter, id int) {
// ensure no stash IDs to begin with
- testNoStashIDs(t, r, id)
+ testNoStashIDs(ctx, t, r, id)
// ensure GetStashIDs with non-existing also returns none
- testNoStashIDs(t, r, -1)
+ testNoStashIDs(ctx, t, r, -1)
// add stash ids
const stashIDStr = "stashID"
@@ -31,28 +32,28 @@ func testStashIDReaderWriter(t *testing.T, r stashIDReaderWriter, id int) {
}
// update stash ids and ensure was updated
- if err := r.UpdateStashIDs(id, []models.StashID{stashID}); err != nil {
+ if err := r.UpdateStashIDs(ctx, id, []models.StashID{stashID}); err != nil {
t.Error(err.Error())
}
- testStashIDs(t, r, id, []*models.StashID{&stashID})
+ testStashIDs(ctx, t, r, id, []*models.StashID{&stashID})
// update non-existing id - should return error
- if err := r.UpdateStashIDs(-1, []models.StashID{stashID}); err == nil {
+ if err := r.UpdateStashIDs(ctx, -1, []models.StashID{stashID}); err == nil {
t.Error("expected error when updating non-existing id")
}
// remove stash ids and ensure was updated
- if err := r.UpdateStashIDs(id, []models.StashID{}); err != nil {
+ if err := r.UpdateStashIDs(ctx, id, []models.StashID{}); err != nil {
t.Error(err.Error())
}
- testNoStashIDs(t, r, id)
+ testNoStashIDs(ctx, t, r, id)
}
-func testNoStashIDs(t *testing.T, r stashIDReaderWriter, id int) {
+func testNoStashIDs(ctx context.Context, t *testing.T, r stashIDReaderWriter, id int) {
t.Helper()
- stashIDs, err := r.GetStashIDs(id)
+ stashIDs, err := r.GetStashIDs(ctx, id)
if err != nil {
t.Error(err.Error())
return
@@ -61,9 +62,9 @@ func testNoStashIDs(t *testing.T, r stashIDReaderWriter, id int) {
assert.Len(t, stashIDs, 0)
}
-func testStashIDs(t *testing.T, r stashIDReaderWriter, id int, expected []*models.StashID) {
+func testStashIDs(ctx context.Context, t *testing.T, r stashIDReaderWriter, id int, expected []*models.StashID) {
t.Helper()
- stashIDs, err := r.GetStashIDs(id)
+ stashIDs, err := r.GetStashIDs(ctx, id)
if err != nil {
t.Error(err.Error())
return
diff --git a/pkg/sqlite/studio.go b/pkg/sqlite/studio.go
index cc810fe0c..966257a1c 100644
--- a/pkg/sqlite/studio.go
+++ b/pkg/sqlite/studio.go
@@ -1,6 +1,7 @@
package sqlite
import (
+ "context"
"database/sql"
"errors"
"fmt"
@@ -18,57 +19,54 @@ type studioQueryBuilder struct {
repository
}
-func NewStudioReaderWriter(tx dbi) *studioQueryBuilder {
- return &studioQueryBuilder{
- repository{
- tx: tx,
- tableName: studioTable,
- idColumn: idColumn,
- },
- }
+var StudioReaderWriter = &studioQueryBuilder{
+ repository{
+ tableName: studioTable,
+ idColumn: idColumn,
+ },
}
-func (qb *studioQueryBuilder) Create(newObject models.Studio) (*models.Studio, error) {
+func (qb *studioQueryBuilder) Create(ctx context.Context, newObject models.Studio) (*models.Studio, error) {
var ret models.Studio
- if err := qb.insertObject(newObject, &ret); err != nil {
+ if err := qb.insertObject(ctx, newObject, &ret); err != nil {
return nil, err
}
return &ret, nil
}
-func (qb *studioQueryBuilder) Update(updatedObject models.StudioPartial) (*models.Studio, error) {
+func (qb *studioQueryBuilder) Update(ctx context.Context, updatedObject models.StudioPartial) (*models.Studio, error) {
const partial = true
- if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil {
+ if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil {
return nil, err
}
- return qb.Find(updatedObject.ID)
+ return qb.Find(ctx, updatedObject.ID)
}
-func (qb *studioQueryBuilder) UpdateFull(updatedObject models.Studio) (*models.Studio, error) {
+func (qb *studioQueryBuilder) UpdateFull(ctx context.Context, updatedObject models.Studio) (*models.Studio, error) {
const partial = false
- if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil {
+ if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil {
return nil, err
}
- return qb.Find(updatedObject.ID)
+ return qb.Find(ctx, updatedObject.ID)
}
-func (qb *studioQueryBuilder) Destroy(id int) error {
+func (qb *studioQueryBuilder) Destroy(ctx context.Context, id int) error {
// TODO - set null on foreign key in scraped items
// remove studio from scraped items
- _, err := qb.tx.Exec("UPDATE scraped_items SET studio_id = null WHERE studio_id = ?", id)
+ _, err := qb.tx.Exec(ctx, "UPDATE scraped_items SET studio_id = null WHERE studio_id = ?", id)
if err != nil {
return err
}
- return qb.destroyExisting([]int{id})
+ return qb.destroyExisting(ctx, []int{id})
}
-func (qb *studioQueryBuilder) Find(id int) (*models.Studio, error) {
+func (qb *studioQueryBuilder) Find(ctx context.Context, id int) (*models.Studio, error) {
var ret models.Studio
- if err := qb.get(id, &ret); err != nil {
+ if err := qb.getByID(ctx, id, &ret); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
@@ -77,10 +75,10 @@ func (qb *studioQueryBuilder) Find(id int) (*models.Studio, error) {
return &ret, nil
}
-func (qb *studioQueryBuilder) FindMany(ids []int) ([]*models.Studio, error) {
+func (qb *studioQueryBuilder) FindMany(ctx context.Context, ids []int) ([]*models.Studio, error) {
var studios []*models.Studio
for _, id := range ids {
- studio, err := qb.Find(id)
+ studio, err := qb.Find(ctx, id)
if err != nil {
return nil, err
}
@@ -95,47 +93,47 @@ func (qb *studioQueryBuilder) FindMany(ids []int) ([]*models.Studio, error) {
return studios, nil
}
-func (qb *studioQueryBuilder) FindChildren(id int) ([]*models.Studio, error) {
+func (qb *studioQueryBuilder) FindChildren(ctx context.Context, id int) ([]*models.Studio, error) {
query := "SELECT studios.* FROM studios WHERE studios.parent_id = ?"
args := []interface{}{id}
- return qb.queryStudios(query, args)
+ return qb.queryStudios(ctx, query, args)
}
-func (qb *studioQueryBuilder) FindBySceneID(sceneID int) (*models.Studio, error) {
+func (qb *studioQueryBuilder) FindBySceneID(ctx context.Context, sceneID int) (*models.Studio, error) {
query := "SELECT studios.* FROM studios JOIN scenes ON studios.id = scenes.studio_id WHERE scenes.id = ? LIMIT 1"
args := []interface{}{sceneID}
- return qb.queryStudio(query, args)
+ return qb.queryStudio(ctx, query, args)
}
-func (qb *studioQueryBuilder) FindByName(name string, nocase bool) (*models.Studio, error) {
+func (qb *studioQueryBuilder) FindByName(ctx context.Context, name string, nocase bool) (*models.Studio, error) {
query := "SELECT * FROM studios WHERE name = ?"
if nocase {
query += " COLLATE NOCASE"
}
query += " LIMIT 1"
args := []interface{}{name}
- return qb.queryStudio(query, args)
+ return qb.queryStudio(ctx, query, args)
}
-func (qb *studioQueryBuilder) FindByStashID(stashID models.StashID) ([]*models.Studio, error) {
+func (qb *studioQueryBuilder) FindByStashID(ctx context.Context, stashID models.StashID) ([]*models.Studio, error) {
query := selectAll("studios") + `
LEFT JOIN studio_stash_ids on studio_stash_ids.studio_id = studios.id
WHERE studio_stash_ids.stash_id = ?
AND studio_stash_ids.endpoint = ?
`
args := []interface{}{stashID.StashID, stashID.Endpoint}
- return qb.queryStudios(query, args)
+ return qb.queryStudios(ctx, query, args)
}
-func (qb *studioQueryBuilder) Count() (int, error) {
- return qb.runCountQuery(qb.buildCountQuery("SELECT studios.id FROM studios"), nil)
+func (qb *studioQueryBuilder) Count(ctx context.Context) (int, error) {
+ return qb.runCountQuery(ctx, qb.buildCountQuery("SELECT studios.id FROM studios"), nil)
}
-func (qb *studioQueryBuilder) All() ([]*models.Studio, error) {
- return qb.queryStudios(selectAll("studios")+qb.getStudioSort(nil), nil)
+func (qb *studioQueryBuilder) All(ctx context.Context) ([]*models.Studio, error) {
+ return qb.queryStudios(ctx, selectAll("studios")+qb.getStudioSort(nil), nil)
}
-func (qb *studioQueryBuilder) QueryForAutoTag(words []string) ([]*models.Studio, error) {
+func (qb *studioQueryBuilder) QueryForAutoTag(ctx context.Context, words []string) ([]*models.Studio, error) {
// TODO - Query needs to be changed to support queries of this type, and
// this method should be removed
query := selectAll(studioTable)
@@ -159,7 +157,7 @@ func (qb *studioQueryBuilder) QueryForAutoTag(words []string) ([]*models.Studio,
"studios.ignore_auto_tag = 0",
whereOr,
}, " AND ")
- return qb.queryStudios(query+" WHERE "+where, args)
+ return qb.queryStudios(ctx, query+" WHERE "+where, args)
}
func (qb *studioQueryBuilder) validateFilter(filter *models.StudioFilterType) error {
@@ -193,43 +191,43 @@ func (qb *studioQueryBuilder) validateFilter(filter *models.StudioFilterType) er
return nil
}
-func (qb *studioQueryBuilder) makeFilter(studioFilter *models.StudioFilterType) *filterBuilder {
+func (qb *studioQueryBuilder) makeFilter(ctx context.Context, studioFilter *models.StudioFilterType) *filterBuilder {
query := &filterBuilder{}
if studioFilter.And != nil {
- query.and(qb.makeFilter(studioFilter.And))
+ query.and(qb.makeFilter(ctx, studioFilter.And))
}
if studioFilter.Or != nil {
- query.or(qb.makeFilter(studioFilter.Or))
+ query.or(qb.makeFilter(ctx, studioFilter.Or))
}
if studioFilter.Not != nil {
- query.not(qb.makeFilter(studioFilter.Not))
+ query.not(qb.makeFilter(ctx, studioFilter.Not))
}
- query.handleCriterion(stringCriterionHandler(studioFilter.Name, studioTable+".name"))
- query.handleCriterion(stringCriterionHandler(studioFilter.Details, studioTable+".details"))
- query.handleCriterion(stringCriterionHandler(studioFilter.URL, studioTable+".url"))
- query.handleCriterion(intCriterionHandler(studioFilter.Rating, studioTable+".rating"))
- query.handleCriterion(boolCriterionHandler(studioFilter.IgnoreAutoTag, studioTable+".ignore_auto_tag"))
+ query.handleCriterion(ctx, stringCriterionHandler(studioFilter.Name, studioTable+".name"))
+ query.handleCriterion(ctx, stringCriterionHandler(studioFilter.Details, studioTable+".details"))
+ query.handleCriterion(ctx, stringCriterionHandler(studioFilter.URL, studioTable+".url"))
+ query.handleCriterion(ctx, intCriterionHandler(studioFilter.Rating, studioTable+".rating"))
+ query.handleCriterion(ctx, boolCriterionHandler(studioFilter.IgnoreAutoTag, studioTable+".ignore_auto_tag"))
- query.handleCriterion(criterionHandlerFunc(func(f *filterBuilder) {
+ query.handleCriterion(ctx, criterionHandlerFunc(func(ctx context.Context, f *filterBuilder) {
if studioFilter.StashID != nil {
qb.stashIDRepository().join(f, "studio_stash_ids", "studios.id")
- stringCriterionHandler(studioFilter.StashID, "studio_stash_ids.stash_id")(f)
+ stringCriterionHandler(studioFilter.StashID, "studio_stash_ids.stash_id")(ctx, f)
}
}))
- query.handleCriterion(studioIsMissingCriterionHandler(qb, studioFilter.IsMissing))
- query.handleCriterion(studioSceneCountCriterionHandler(qb, studioFilter.SceneCount))
- query.handleCriterion(studioImageCountCriterionHandler(qb, studioFilter.ImageCount))
- query.handleCriterion(studioGalleryCountCriterionHandler(qb, studioFilter.GalleryCount))
- query.handleCriterion(studioParentCriterionHandler(qb, studioFilter.Parents))
- query.handleCriterion(studioAliasCriterionHandler(qb, studioFilter.Aliases))
+ query.handleCriterion(ctx, studioIsMissingCriterionHandler(qb, studioFilter.IsMissing))
+ query.handleCriterion(ctx, studioSceneCountCriterionHandler(qb, studioFilter.SceneCount))
+ query.handleCriterion(ctx, studioImageCountCriterionHandler(qb, studioFilter.ImageCount))
+ query.handleCriterion(ctx, studioGalleryCountCriterionHandler(qb, studioFilter.GalleryCount))
+ query.handleCriterion(ctx, studioParentCriterionHandler(qb, studioFilter.Parents))
+ query.handleCriterion(ctx, studioAliasCriterionHandler(qb, studioFilter.Aliases))
return query
}
-func (qb *studioQueryBuilder) Query(studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) ([]*models.Studio, int, error) {
+func (qb *studioQueryBuilder) Query(ctx context.Context, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) ([]*models.Studio, int, error) {
if studioFilter == nil {
studioFilter = &models.StudioFilterType{}
}
@@ -250,19 +248,19 @@ func (qb *studioQueryBuilder) Query(studioFilter *models.StudioFilterType, findF
if err := qb.validateFilter(studioFilter); err != nil {
return nil, 0, err
}
- filter := qb.makeFilter(studioFilter)
+ filter := qb.makeFilter(ctx, studioFilter)
query.addFilter(filter)
query.sortAndPagination = qb.getStudioSort(findFilter) + getPagination(findFilter)
- idsResult, countResult, err := query.executeFind()
+ idsResult, countResult, err := query.executeFind(ctx)
if err != nil {
return nil, 0, err
}
var studios []*models.Studio
for _, id := range idsResult {
- studio, err := qb.Find(id)
+ studio, err := qb.Find(ctx, id)
if err != nil {
return nil, 0, err
}
@@ -274,7 +272,7 @@ func (qb *studioQueryBuilder) Query(studioFilter *models.StudioFilterType, findF
}
func studioIsMissingCriterionHandler(qb *studioQueryBuilder, isMissing *string) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if isMissing != nil && *isMissing != "" {
switch *isMissing {
case "image":
@@ -291,7 +289,7 @@ func studioIsMissingCriterionHandler(qb *studioQueryBuilder, isMissing *string)
}
func studioSceneCountCriterionHandler(qb *studioQueryBuilder, sceneCount *models.IntCriterionInput) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if sceneCount != nil {
f.addLeftJoin("scenes", "", "scenes.studio_id = studios.id")
clause, args := getIntCriterionWhereClause("count(distinct scenes.id)", *sceneCount)
@@ -302,7 +300,7 @@ func studioSceneCountCriterionHandler(qb *studioQueryBuilder, sceneCount *models
}
func studioImageCountCriterionHandler(qb *studioQueryBuilder, imageCount *models.IntCriterionInput) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if imageCount != nil {
f.addLeftJoin("images", "", "images.studio_id = studios.id")
clause, args := getIntCriterionWhereClause("count(distinct images.id)", *imageCount)
@@ -313,7 +311,7 @@ func studioImageCountCriterionHandler(qb *studioQueryBuilder, imageCount *models
}
func studioGalleryCountCriterionHandler(qb *studioQueryBuilder, galleryCount *models.IntCriterionInput) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if galleryCount != nil {
f.addLeftJoin("galleries", "", "galleries.studio_id = studios.id")
clause, args := getIntCriterionWhereClause("count(distinct galleries.id)", *galleryCount)
@@ -373,17 +371,17 @@ func (qb *studioQueryBuilder) getStudioSort(findFilter *models.FindFilterType) s
}
}
-func (qb *studioQueryBuilder) queryStudio(query string, args []interface{}) (*models.Studio, error) {
- results, err := qb.queryStudios(query, args)
+func (qb *studioQueryBuilder) queryStudio(ctx context.Context, query string, args []interface{}) (*models.Studio, error) {
+ results, err := qb.queryStudios(ctx, query, args)
if err != nil || len(results) < 1 {
return nil, err
}
return results[0], nil
}
-func (qb *studioQueryBuilder) queryStudios(query string, args []interface{}) ([]*models.Studio, error) {
+func (qb *studioQueryBuilder) queryStudios(ctx context.Context, query string, args []interface{}) ([]*models.Studio, error) {
var ret models.Studios
- if err := qb.query(query, args, &ret); err != nil {
+ if err := qb.query(ctx, query, args, &ret); err != nil {
return nil, err
}
@@ -401,20 +399,20 @@ func (qb *studioQueryBuilder) imageRepository() *imageRepository {
}
}
-func (qb *studioQueryBuilder) GetImage(studioID int) ([]byte, error) {
- return qb.imageRepository().get(studioID)
+func (qb *studioQueryBuilder) GetImage(ctx context.Context, studioID int) ([]byte, error) {
+ return qb.imageRepository().get(ctx, studioID)
}
-func (qb *studioQueryBuilder) HasImage(studioID int) (bool, error) {
- return qb.imageRepository().exists(studioID)
+func (qb *studioQueryBuilder) HasImage(ctx context.Context, studioID int) (bool, error) {
+ return qb.imageRepository().exists(ctx, studioID)
}
-func (qb *studioQueryBuilder) UpdateImage(studioID int, image []byte) error {
- return qb.imageRepository().replace(studioID, image)
+func (qb *studioQueryBuilder) UpdateImage(ctx context.Context, studioID int, image []byte) error {
+ return qb.imageRepository().replace(ctx, studioID, image)
}
-func (qb *studioQueryBuilder) DestroyImage(studioID int) error {
- return qb.imageRepository().destroy([]int{studioID})
+func (qb *studioQueryBuilder) DestroyImage(ctx context.Context, studioID int) error {
+ return qb.imageRepository().destroy(ctx, []int{studioID})
}
func (qb *studioQueryBuilder) stashIDRepository() *stashIDRepository {
@@ -427,12 +425,12 @@ func (qb *studioQueryBuilder) stashIDRepository() *stashIDRepository {
}
}
-func (qb *studioQueryBuilder) GetStashIDs(studioID int) ([]*models.StashID, error) {
- return qb.stashIDRepository().get(studioID)
+func (qb *studioQueryBuilder) GetStashIDs(ctx context.Context, studioID int) ([]*models.StashID, error) {
+ return qb.stashIDRepository().get(ctx, studioID)
}
-func (qb *studioQueryBuilder) UpdateStashIDs(studioID int, stashIDs []models.StashID) error {
- return qb.stashIDRepository().replace(studioID, stashIDs)
+func (qb *studioQueryBuilder) UpdateStashIDs(ctx context.Context, studioID int, stashIDs []models.StashID) error {
+ return qb.stashIDRepository().replace(ctx, studioID, stashIDs)
}
func (qb *studioQueryBuilder) aliasRepository() *stringRepository {
@@ -446,10 +444,10 @@ func (qb *studioQueryBuilder) aliasRepository() *stringRepository {
}
}
-func (qb *studioQueryBuilder) GetAliases(studioID int) ([]string, error) {
- return qb.aliasRepository().get(studioID)
+func (qb *studioQueryBuilder) GetAliases(ctx context.Context, studioID int) ([]string, error) {
+ return qb.aliasRepository().get(ctx, studioID)
}
-func (qb *studioQueryBuilder) UpdateAliases(studioID int, aliases []string) error {
- return qb.aliasRepository().replace(studioID, aliases)
+func (qb *studioQueryBuilder) UpdateAliases(ctx context.Context, studioID int, aliases []string) error {
+ return qb.aliasRepository().replace(ctx, studioID, aliases)
}
diff --git a/pkg/sqlite/studio_test.go b/pkg/sqlite/studio_test.go
index 08e6a30da..28b162328 100644
--- a/pkg/sqlite/studio_test.go
+++ b/pkg/sqlite/studio_test.go
@@ -4,6 +4,7 @@
package sqlite_test
import (
+ "context"
"database/sql"
"errors"
"fmt"
@@ -13,16 +14,17 @@ import (
"testing"
"github.com/stashapp/stash/pkg/models"
+ "github.com/stashapp/stash/pkg/sqlite"
"github.com/stretchr/testify/assert"
)
func TestStudioFindByName(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Studio()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.StudioReaderWriter
name := studioNames[studioIdxWithScene] // find a studio by name
- studio, err := sqb.FindByName(name, false)
+ studio, err := sqb.FindByName(ctx, name, false)
if err != nil {
t.Errorf("Error finding studios: %s", err.Error())
@@ -32,7 +34,7 @@ func TestStudioFindByName(t *testing.T) {
name = studioNames[studioIdxWithDupName] // find a studio by name nocase
- studio, err = sqb.FindByName(name, true)
+ studio, err = sqb.FindByName(ctx, name, true)
if err != nil {
t.Errorf("Error finding studios: %s", err.Error())
@@ -67,10 +69,10 @@ func TestStudioQueryNameOr(t *testing.T) {
},
}
- withTxn(func(r models.Repository) error {
- sqb := r.Studio()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.StudioReaderWriter
- studios := queryStudio(t, sqb, &studioFilter, nil)
+ studios := queryStudio(ctx, t, sqb, &studioFilter, nil)
assert.Len(t, studios, 2)
assert.Equal(t, studio1Name, studios[0].Name.String)
@@ -98,10 +100,10 @@ func TestStudioQueryNameAndUrl(t *testing.T) {
},
}
- withTxn(func(r models.Repository) error {
- sqb := r.Studio()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.StudioReaderWriter
- studios := queryStudio(t, sqb, &studioFilter, nil)
+ studios := queryStudio(ctx, t, sqb, &studioFilter, nil)
assert.Len(t, studios, 1)
assert.Equal(t, studioName, studios[0].Name.String)
@@ -133,10 +135,10 @@ func TestStudioQueryNameNotUrl(t *testing.T) {
},
}
- withTxn(func(r models.Repository) error {
- sqb := r.Studio()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.StudioReaderWriter
- studios := queryStudio(t, sqb, &studioFilter, nil)
+ studios := queryStudio(ctx, t, sqb, &studioFilter, nil)
for _, studio := range studios {
verifyString(t, studio.Name.String, nameCriterion)
@@ -164,20 +166,20 @@ func TestStudioIllegalQuery(t *testing.T) {
Or: &subFilter,
}
- withTxn(func(r models.Repository) error {
- sqb := r.Studio()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.StudioReaderWriter
- _, _, err := sqb.Query(studioFilter, nil)
+ _, _, err := sqb.Query(ctx, studioFilter, nil)
assert.NotNil(err)
studioFilter.Or = nil
studioFilter.Not = &subFilter
- _, _, err = sqb.Query(studioFilter, nil)
+ _, _, err = sqb.Query(ctx, studioFilter, nil)
assert.NotNil(err)
studioFilter.And = nil
studioFilter.Or = &subFilter
- _, _, err = sqb.Query(studioFilter, nil)
+ _, _, err = sqb.Query(ctx, studioFilter, nil)
assert.NotNil(err)
return nil
@@ -185,15 +187,15 @@ func TestStudioIllegalQuery(t *testing.T) {
}
func TestStudioQueryIgnoreAutoTag(t *testing.T) {
- withTxn(func(r models.Repository) error {
+ withTxn(func(ctx context.Context) error {
ignoreAutoTag := true
studioFilter := models.StudioFilterType{
IgnoreAutoTag: &ignoreAutoTag,
}
- sqb := r.Studio()
+ sqb := sqlite.StudioReaderWriter
- studios := queryStudio(t, sqb, &studioFilter, nil)
+ studios := queryStudio(ctx, t, sqb, &studioFilter, nil)
assert.Len(t, studios, int(math.Ceil(float64(totalStudios)/5)))
for _, s := range studios {
@@ -205,12 +207,12 @@ func TestStudioQueryIgnoreAutoTag(t *testing.T) {
}
func TestStudioQueryForAutoTag(t *testing.T) {
- withTxn(func(r models.Repository) error {
- tqb := r.Studio()
+ withTxn(func(ctx context.Context) error {
+ tqb := sqlite.StudioReaderWriter
name := studioNames[studioIdxWithMovie] // find a studio by name
- studios, err := tqb.QueryForAutoTag([]string{name})
+ studios, err := tqb.QueryForAutoTag(ctx, []string{name})
if err != nil {
t.Errorf("Error finding studios: %s", err.Error())
@@ -221,7 +223,7 @@ func TestStudioQueryForAutoTag(t *testing.T) {
// find by alias
name = getStudioStringValue(studioIdxWithMovie, "Alias")
- studios, err = tqb.QueryForAutoTag([]string{name})
+ studios, err = tqb.QueryForAutoTag(ctx, []string{name})
if err != nil {
t.Errorf("Error finding studios: %s", err.Error())
@@ -235,8 +237,8 @@ func TestStudioQueryForAutoTag(t *testing.T) {
}
func TestStudioQueryParent(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Studio()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.StudioReaderWriter
studioCriterion := models.MultiCriterionInput{
Value: []string{
strconv.Itoa(studioIDs[studioIdxWithChildStudio]),
@@ -248,7 +250,7 @@ func TestStudioQueryParent(t *testing.T) {
Parents: &studioCriterion,
}
- studios, _, err := sqb.Query(&studioFilter, nil)
+ studios, _, err := sqb.Query(ctx, &studioFilter, nil)
if err != nil {
t.Errorf("Error querying studio: %s", err.Error())
}
@@ -270,7 +272,7 @@ func TestStudioQueryParent(t *testing.T) {
Q: &q,
}
- studios, _, err = sqb.Query(&studioFilter, &findFilter)
+ studios, _, err = sqb.Query(ctx, &studioFilter, &findFilter)
if err != nil {
t.Errorf("Error querying studio: %s", err.Error())
}
@@ -285,28 +287,28 @@ func TestStudioDestroyParent(t *testing.T) {
const childName = "child"
// create parent and child studios
- if err := withTxn(func(r models.Repository) error {
- createdParent, err := createStudio(r.Studio(), parentName, nil)
+ if err := withTxn(func(ctx context.Context) error {
+ createdParent, err := createStudio(ctx, sqlite.StudioReaderWriter, parentName, nil)
if err != nil {
return fmt.Errorf("Error creating parent studio: %s", err.Error())
}
parentID := int64(createdParent.ID)
- createdChild, err := createStudio(r.Studio(), childName, &parentID)
+ createdChild, err := createStudio(ctx, sqlite.StudioReaderWriter, childName, &parentID)
if err != nil {
return fmt.Errorf("Error creating child studio: %s", err.Error())
}
- sqb := r.Studio()
+ sqb := sqlite.StudioReaderWriter
// destroy the parent
- err = sqb.Destroy(createdParent.ID)
+ err = sqb.Destroy(ctx, createdParent.ID)
if err != nil {
return fmt.Errorf("Error destroying parent studio: %s", err.Error())
}
// destroy the child
- err = sqb.Destroy(createdChild.ID)
+ err = sqb.Destroy(ctx, createdChild.ID)
if err != nil {
return fmt.Errorf("Error destroying child studio: %s", err.Error())
}
@@ -318,10 +320,10 @@ func TestStudioDestroyParent(t *testing.T) {
}
func TestStudioFindChildren(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Studio()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.StudioReaderWriter
- studios, err := sqb.FindChildren(studioIDs[studioIdxWithChildStudio])
+ studios, err := sqb.FindChildren(ctx, studioIDs[studioIdxWithChildStudio])
if err != nil {
t.Errorf("error calling FindChildren: %s", err.Error())
@@ -330,7 +332,7 @@ func TestStudioFindChildren(t *testing.T) {
assert.Len(t, studios, 1)
assert.Equal(t, studioIDs[studioIdxWithParentStudio], studios[0].ID)
- studios, err = sqb.FindChildren(0)
+ studios, err = sqb.FindChildren(ctx, 0)
if err != nil {
t.Errorf("error calling FindChildren: %s", err.Error())
@@ -347,19 +349,19 @@ func TestStudioUpdateClearParent(t *testing.T) {
const childName = "clearParent_child"
// create parent and child studios
- if err := withTxn(func(r models.Repository) error {
- createdParent, err := createStudio(r.Studio(), parentName, nil)
+ if err := withTxn(func(ctx context.Context) error {
+ createdParent, err := createStudio(ctx, sqlite.StudioReaderWriter, parentName, nil)
if err != nil {
return fmt.Errorf("Error creating parent studio: %s", err.Error())
}
parentID := int64(createdParent.ID)
- createdChild, err := createStudio(r.Studio(), childName, &parentID)
+ createdChild, err := createStudio(ctx, sqlite.StudioReaderWriter, childName, &parentID)
if err != nil {
return fmt.Errorf("Error creating child studio: %s", err.Error())
}
- sqb := r.Studio()
+ sqb := sqlite.StudioReaderWriter
// clear the parent id from the child
updatePartial := models.StudioPartial{
@@ -367,7 +369,7 @@ func TestStudioUpdateClearParent(t *testing.T) {
ParentID: &sql.NullInt64{Valid: false},
}
- updatedStudio, err := sqb.Update(updatePartial)
+ updatedStudio, err := sqb.Update(ctx, updatePartial)
if err != nil {
return fmt.Errorf("Error updated studio: %s", err.Error())
@@ -384,31 +386,31 @@ func TestStudioUpdateClearParent(t *testing.T) {
}
func TestStudioUpdateStudioImage(t *testing.T) {
- if err := withTxn(func(r models.Repository) error {
- qb := r.Studio()
+ if err := withTxn(func(ctx context.Context) error {
+ qb := sqlite.StudioReaderWriter
// create performer to test against
const name = "TestStudioUpdateStudioImage"
- created, err := createStudio(r.Studio(), name, nil)
+ created, err := createStudio(ctx, sqlite.StudioReaderWriter, name, nil)
if err != nil {
return fmt.Errorf("Error creating studio: %s", err.Error())
}
image := []byte("image")
- err = qb.UpdateImage(created.ID, image)
+ err = qb.UpdateImage(ctx, created.ID, image)
if err != nil {
return fmt.Errorf("Error updating studio image: %s", err.Error())
}
// ensure image set
- storedImage, err := qb.GetImage(created.ID)
+ storedImage, err := qb.GetImage(ctx, created.ID)
if err != nil {
return fmt.Errorf("Error getting image: %s", err.Error())
}
assert.Equal(t, storedImage, image)
// set nil image
- err = qb.UpdateImage(created.ID, nil)
+ err = qb.UpdateImage(ctx, created.ID, nil)
if err == nil {
return fmt.Errorf("Expected error setting nil image")
}
@@ -420,29 +422,29 @@ func TestStudioUpdateStudioImage(t *testing.T) {
}
func TestStudioDestroyStudioImage(t *testing.T) {
- if err := withTxn(func(r models.Repository) error {
- qb := r.Studio()
+ if err := withTxn(func(ctx context.Context) error {
+ qb := sqlite.StudioReaderWriter
// create performer to test against
const name = "TestStudioDestroyStudioImage"
- created, err := createStudio(r.Studio(), name, nil)
+ created, err := createStudio(ctx, sqlite.StudioReaderWriter, name, nil)
if err != nil {
return fmt.Errorf("Error creating studio: %s", err.Error())
}
image := []byte("image")
- err = qb.UpdateImage(created.ID, image)
+ err = qb.UpdateImage(ctx, created.ID, image)
if err != nil {
return fmt.Errorf("Error updating studio image: %s", err.Error())
}
- err = qb.DestroyImage(created.ID)
+ err = qb.DestroyImage(ctx, created.ID)
if err != nil {
return fmt.Errorf("Error destroying studio image: %s", err.Error())
}
// image should be nil
- storedImage, err := qb.GetImage(created.ID)
+ storedImage, err := qb.GetImage(ctx, created.ID)
if err != nil {
return fmt.Errorf("Error getting image: %s", err.Error())
}
@@ -474,17 +476,17 @@ func TestStudioQuerySceneCount(t *testing.T) {
}
func verifyStudiosSceneCount(t *testing.T, sceneCountCriterion models.IntCriterionInput) {
- withTxn(func(r models.Repository) error {
- sqb := r.Studio()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.StudioReaderWriter
studioFilter := models.StudioFilterType{
SceneCount: &sceneCountCriterion,
}
- studios := queryStudio(t, sqb, &studioFilter, nil)
+ studios := queryStudio(ctx, t, sqb, &studioFilter, nil)
assert.Greater(t, len(studios), 0)
for _, studio := range studios {
- sceneCount, err := r.Scene().CountByStudioID(studio.ID)
+ sceneCount, err := sqlite.SceneReaderWriter.CountByStudioID(ctx, studio.ID)
if err != nil {
return err
}
@@ -515,19 +517,19 @@ func TestStudioQueryImageCount(t *testing.T) {
}
func verifyStudiosImageCount(t *testing.T, imageCountCriterion models.IntCriterionInput) {
- withTxn(func(r models.Repository) error {
- sqb := r.Studio()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.StudioReaderWriter
studioFilter := models.StudioFilterType{
ImageCount: &imageCountCriterion,
}
- studios := queryStudio(t, sqb, &studioFilter, nil)
+ studios := queryStudio(ctx, t, sqb, &studioFilter, nil)
assert.Greater(t, len(studios), 0)
for _, studio := range studios {
pp := 0
- result, err := r.Image().Query(models.ImageQueryOptions{
+ result, err := sqlite.ImageReaderWriter.Query(ctx, models.ImageQueryOptions{
QueryOptions: models.QueryOptions{
FindFilter: &models.FindFilterType{
PerPage: &pp,
@@ -571,19 +573,19 @@ func TestStudioQueryGalleryCount(t *testing.T) {
}
func verifyStudiosGalleryCount(t *testing.T, galleryCountCriterion models.IntCriterionInput) {
- withTxn(func(r models.Repository) error {
- sqb := r.Studio()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.StudioReaderWriter
studioFilter := models.StudioFilterType{
GalleryCount: &galleryCountCriterion,
}
- studios := queryStudio(t, sqb, &studioFilter, nil)
+ studios := queryStudio(ctx, t, sqb, &studioFilter, nil)
assert.Greater(t, len(studios), 0)
for _, studio := range studios {
pp := 0
- _, count, err := r.Gallery().Query(&models.GalleryFilterType{
+ _, count, err := sqlite.GalleryReaderWriter.Query(ctx, &models.GalleryFilterType{
Studios: &models.HierarchicalMultiCriterionInput{
Value: []string{strconv.Itoa(studio.ID)},
Modifier: models.CriterionModifierIncludes,
@@ -602,17 +604,17 @@ func verifyStudiosGalleryCount(t *testing.T, galleryCountCriterion models.IntCri
}
func TestStudioStashIDs(t *testing.T) {
- if err := withTxn(func(r models.Repository) error {
- qb := r.Studio()
+ if err := withTxn(func(ctx context.Context) error {
+ qb := sqlite.StudioReaderWriter
// create studio to test against
const name = "TestStudioStashIDs"
- created, err := createStudio(r.Studio(), name, nil)
+ created, err := createStudio(ctx, sqlite.StudioReaderWriter, name, nil)
if err != nil {
return fmt.Errorf("Error creating studio: %s", err.Error())
}
- testStashIDReaderWriter(t, qb, created.ID)
+ testStashIDReaderWriter(ctx, t, qb, created.ID)
return nil
}); err != nil {
t.Error(err.Error())
@@ -632,7 +634,7 @@ func TestStudioQueryURL(t *testing.T) {
URL: &urlCriterion,
}
- verifyFn := func(g *models.Studio, r models.Repository) {
+ verifyFn := func(ctx context.Context, g *models.Studio) {
t.Helper()
verifyNullString(t, g.URL, urlCriterion)
}
@@ -682,18 +684,18 @@ func TestStudioQueryRating(t *testing.T) {
verifyStudiosRating(t, ratingCriterion)
}
-func verifyStudioQuery(t *testing.T, filter models.StudioFilterType, verifyFn func(s *models.Studio, r models.Repository)) {
- withTxn(func(r models.Repository) error {
+func verifyStudioQuery(t *testing.T, filter models.StudioFilterType, verifyFn func(ctx context.Context, s *models.Studio)) {
+ withTxn(func(ctx context.Context) error {
t.Helper()
- sqb := r.Studio()
+ sqb := sqlite.StudioReaderWriter
- studios := queryStudio(t, sqb, &filter, nil)
+ studios := queryStudio(ctx, t, sqb, &filter, nil)
// assume it should find at least one
assert.Greater(t, len(studios), 0)
for _, studio := range studios {
- verifyFn(studio, r)
+ verifyFn(ctx, studio)
}
return nil
@@ -701,13 +703,13 @@ func verifyStudioQuery(t *testing.T, filter models.StudioFilterType, verifyFn fu
}
func verifyStudiosRating(t *testing.T, ratingCriterion models.IntCriterionInput) {
- withTxn(func(r models.Repository) error {
- sqb := r.Studio()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.StudioReaderWriter
studioFilter := models.StudioFilterType{
Rating: &ratingCriterion,
}
- studios, _, err := sqb.Query(&studioFilter, nil)
+ studios, _, err := sqb.Query(ctx, &studioFilter, nil)
if err != nil {
t.Errorf("Error querying studio: %s", err.Error())
@@ -722,14 +724,14 @@ func verifyStudiosRating(t *testing.T, ratingCriterion models.IntCriterionInput)
}
func TestStudioQueryIsMissingRating(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Studio()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.StudioReaderWriter
isMissing := "rating"
studioFilter := models.StudioFilterType{
IsMissing: &isMissing,
}
- studios, _, err := sqb.Query(&studioFilter, nil)
+ studios, _, err := sqb.Query(ctx, &studioFilter, nil)
if err != nil {
t.Errorf("Error querying studio: %s", err.Error())
@@ -745,8 +747,8 @@ func TestStudioQueryIsMissingRating(t *testing.T) {
})
}
-func queryStudio(t *testing.T, sqb models.StudioReader, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) []*models.Studio {
- studios, _, err := sqb.Query(studioFilter, findFilter)
+func queryStudio(ctx context.Context, t *testing.T, sqb models.StudioReader, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) []*models.Studio {
+ studios, _, err := sqb.Query(ctx, studioFilter, findFilter)
if err != nil {
t.Errorf("Error querying studio: %s", err.Error())
}
@@ -767,7 +769,7 @@ func TestStudioQueryName(t *testing.T) {
Name: nameCriterion,
}
- verifyFn := func(studio *models.Studio, r models.Repository) {
+ verifyFn := func(ctx context.Context, studio *models.Studio) {
verifyNullString(t, studio.Name, *nameCriterion)
}
@@ -797,8 +799,8 @@ func TestStudioQueryAlias(t *testing.T) {
Aliases: aliasCriterion,
}
- verifyFn := func(studio *models.Studio, r models.Repository) {
- aliases, err := r.Studio().GetAliases(studio.ID)
+ verifyFn := func(ctx context.Context, studio *models.Studio) {
+ aliases, err := sqlite.StudioReaderWriter.GetAliases(ctx, studio.ID)
if err != nil {
t.Errorf("Error querying studios: %s", err.Error())
}
@@ -825,24 +827,24 @@ func TestStudioQueryAlias(t *testing.T) {
}
func TestStudioUpdateAlias(t *testing.T) {
- if err := withTxn(func(r models.Repository) error {
- qb := r.Studio()
+ if err := withTxn(func(ctx context.Context) error {
+ qb := sqlite.StudioReaderWriter
// create studio to test against
const name = "TestStudioUpdateAlias"
- created, err := createStudio(qb, name, nil)
+ created, err := createStudio(ctx, qb, name, nil)
if err != nil {
return fmt.Errorf("Error creating studio: %s", err.Error())
}
aliases := []string{"alias1", "alias2"}
- err = qb.UpdateAliases(created.ID, aliases)
+ err = qb.UpdateAliases(ctx, created.ID, aliases)
if err != nil {
return fmt.Errorf("Error updating studio aliases: %s", err.Error())
}
// ensure aliases set
- storedAliases, err := qb.GetAliases(created.ID)
+ storedAliases, err := qb.GetAliases(ctx, created.ID)
if err != nil {
return fmt.Errorf("Error getting aliases: %s", err.Error())
}
@@ -922,11 +924,11 @@ func TestStudioQueryFast(t *testing.T) {
}
- withTxn(func(r models.Repository) error {
- sqb := r.Studio()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.StudioReaderWriter
for _, f := range filters {
for _, ff := range findFilters {
- _, _, err := sqb.Query(&f, &ff)
+ _, _, err := sqb.Query(ctx, &f, &ff)
if err != nil {
t.Errorf("Error querying studio: %s", err.Error())
}
diff --git a/pkg/sqlite/tag.go b/pkg/sqlite/tag.go
index 9513a269b..02e853fdc 100644
--- a/pkg/sqlite/tag.go
+++ b/pkg/sqlite/tag.go
@@ -1,6 +1,7 @@
package sqlite
import (
+ "context"
"database/sql"
"errors"
"fmt"
@@ -18,53 +19,50 @@ type tagQueryBuilder struct {
repository
}
-func NewTagReaderWriter(tx dbi) *tagQueryBuilder {
- return &tagQueryBuilder{
- repository{
- tx: tx,
- tableName: tagTable,
- idColumn: idColumn,
- },
- }
+var TagReaderWriter = &tagQueryBuilder{
+ repository{
+ tableName: tagTable,
+ idColumn: idColumn,
+ },
}
-func (qb *tagQueryBuilder) Create(newObject models.Tag) (*models.Tag, error) {
+func (qb *tagQueryBuilder) Create(ctx context.Context, newObject models.Tag) (*models.Tag, error) {
var ret models.Tag
- if err := qb.insertObject(newObject, &ret); err != nil {
+ if err := qb.insertObject(ctx, newObject, &ret); err != nil {
return nil, err
}
return &ret, nil
}
-func (qb *tagQueryBuilder) Update(updatedObject models.TagPartial) (*models.Tag, error) {
+func (qb *tagQueryBuilder) Update(ctx context.Context, updatedObject models.TagPartial) (*models.Tag, error) {
const partial = true
- if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil {
+ if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil {
return nil, err
}
- return qb.Find(updatedObject.ID)
+ return qb.Find(ctx, updatedObject.ID)
}
-func (qb *tagQueryBuilder) UpdateFull(updatedObject models.Tag) (*models.Tag, error) {
+func (qb *tagQueryBuilder) UpdateFull(ctx context.Context, updatedObject models.Tag) (*models.Tag, error) {
const partial = false
- if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil {
+ if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil {
return nil, err
}
- return qb.Find(updatedObject.ID)
+ return qb.Find(ctx, updatedObject.ID)
}
-func (qb *tagQueryBuilder) Destroy(id int) error {
+func (qb *tagQueryBuilder) Destroy(ctx context.Context, id int) error {
// TODO - add delete cascade to foreign key
// delete tag from scenes and markers first
- _, err := qb.tx.Exec("DELETE FROM scenes_tags WHERE tag_id = ?", id)
+ _, err := qb.tx.Exec(ctx, "DELETE FROM scenes_tags WHERE tag_id = ?", id)
if err != nil {
return err
}
// TODO - add delete cascade to foreign key
- _, err = qb.tx.Exec("DELETE FROM scene_markers_tags WHERE tag_id = ?", id)
+ _, err = qb.tx.Exec(ctx, "DELETE FROM scene_markers_tags WHERE tag_id = ?", id)
if err != nil {
return err
}
@@ -72,7 +70,7 @@ func (qb *tagQueryBuilder) Destroy(id int) error {
// cannot unset primary_tag_id in scene_markers because it is not nullable
countQuery := "SELECT COUNT(*) as count FROM scene_markers where primary_tag_id = ?"
args := []interface{}{id}
- primaryMarkers, err := qb.runCountQuery(countQuery, args)
+ primaryMarkers, err := qb.runCountQuery(ctx, countQuery, args)
if err != nil {
return err
}
@@ -81,12 +79,12 @@ func (qb *tagQueryBuilder) Destroy(id int) error {
return errors.New("cannot delete tag used as a primary tag in scene markers")
}
- return qb.destroyExisting([]int{id})
+ return qb.destroyExisting(ctx, []int{id})
}
-func (qb *tagQueryBuilder) Find(id int) (*models.Tag, error) {
+func (qb *tagQueryBuilder) Find(ctx context.Context, id int) (*models.Tag, error) {
var ret models.Tag
- if err := qb.get(id, &ret); err != nil {
+ if err := qb.getByID(ctx, id, &ret); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
@@ -95,10 +93,10 @@ func (qb *tagQueryBuilder) Find(id int) (*models.Tag, error) {
return &ret, nil
}
-func (qb *tagQueryBuilder) FindMany(ids []int) ([]*models.Tag, error) {
+func (qb *tagQueryBuilder) FindMany(ctx context.Context, ids []int) ([]*models.Tag, error) {
var tags []*models.Tag
for _, id := range ids {
- tag, err := qb.Find(id)
+ tag, err := qb.Find(ctx, id)
if err != nil {
return nil, err
}
@@ -113,7 +111,7 @@ func (qb *tagQueryBuilder) FindMany(ids []int) ([]*models.Tag, error) {
return tags, nil
}
-func (qb *tagQueryBuilder) FindBySceneID(sceneID int) ([]*models.Tag, error) {
+func (qb *tagQueryBuilder) FindBySceneID(ctx context.Context, sceneID int) ([]*models.Tag, error) {
query := `
SELECT tags.* FROM tags
LEFT JOIN scenes_tags as scenes_join on scenes_join.tag_id = tags.id
@@ -122,10 +120,10 @@ func (qb *tagQueryBuilder) FindBySceneID(sceneID int) ([]*models.Tag, error) {
`
query += qb.getDefaultTagSort()
args := []interface{}{sceneID}
- return qb.queryTags(query, args)
+ return qb.queryTags(ctx, query, args)
}
-func (qb *tagQueryBuilder) FindByPerformerID(performerID int) ([]*models.Tag, error) {
+func (qb *tagQueryBuilder) FindByPerformerID(ctx context.Context, performerID int) ([]*models.Tag, error) {
query := `
SELECT tags.* FROM tags
LEFT JOIN performers_tags as performers_join on performers_join.tag_id = tags.id
@@ -134,10 +132,10 @@ func (qb *tagQueryBuilder) FindByPerformerID(performerID int) ([]*models.Tag, er
`
query += qb.getDefaultTagSort()
args := []interface{}{performerID}
- return qb.queryTags(query, args)
+ return qb.queryTags(ctx, query, args)
}
-func (qb *tagQueryBuilder) FindByImageID(imageID int) ([]*models.Tag, error) {
+func (qb *tagQueryBuilder) FindByImageID(ctx context.Context, imageID int) ([]*models.Tag, error) {
query := `
SELECT tags.* FROM tags
LEFT JOIN images_tags as images_join on images_join.tag_id = tags.id
@@ -146,10 +144,10 @@ func (qb *tagQueryBuilder) FindByImageID(imageID int) ([]*models.Tag, error) {
`
query += qb.getDefaultTagSort()
args := []interface{}{imageID}
- return qb.queryTags(query, args)
+ return qb.queryTags(ctx, query, args)
}
-func (qb *tagQueryBuilder) FindByGalleryID(galleryID int) ([]*models.Tag, error) {
+func (qb *tagQueryBuilder) FindByGalleryID(ctx context.Context, galleryID int) ([]*models.Tag, error) {
query := `
SELECT tags.* FROM tags
LEFT JOIN galleries_tags as galleries_join on galleries_join.tag_id = tags.id
@@ -158,10 +156,10 @@ func (qb *tagQueryBuilder) FindByGalleryID(galleryID int) ([]*models.Tag, error)
`
query += qb.getDefaultTagSort()
args := []interface{}{galleryID}
- return qb.queryTags(query, args)
+ return qb.queryTags(ctx, query, args)
}
-func (qb *tagQueryBuilder) FindBySceneMarkerID(sceneMarkerID int) ([]*models.Tag, error) {
+func (qb *tagQueryBuilder) FindBySceneMarkerID(ctx context.Context, sceneMarkerID int) ([]*models.Tag, error) {
query := `
SELECT tags.* FROM tags
LEFT JOIN scene_markers_tags as scene_markers_join on scene_markers_join.tag_id = tags.id
@@ -170,20 +168,20 @@ func (qb *tagQueryBuilder) FindBySceneMarkerID(sceneMarkerID int) ([]*models.Tag
`
query += qb.getDefaultTagSort()
args := []interface{}{sceneMarkerID}
- return qb.queryTags(query, args)
+ return qb.queryTags(ctx, query, args)
}
-func (qb *tagQueryBuilder) FindByName(name string, nocase bool) (*models.Tag, error) {
+func (qb *tagQueryBuilder) FindByName(ctx context.Context, name string, nocase bool) (*models.Tag, error) {
query := "SELECT * FROM tags WHERE name = ?"
if nocase {
query += " COLLATE NOCASE"
}
query += " LIMIT 1"
args := []interface{}{name}
- return qb.queryTag(query, args)
+ return qb.queryTag(ctx, query, args)
}
-func (qb *tagQueryBuilder) FindByNames(names []string, nocase bool) ([]*models.Tag, error) {
+func (qb *tagQueryBuilder) FindByNames(ctx context.Context, names []string, nocase bool) ([]*models.Tag, error) {
query := "SELECT * FROM tags WHERE name"
if nocase {
query += " COLLATE NOCASE"
@@ -193,10 +191,10 @@ func (qb *tagQueryBuilder) FindByNames(names []string, nocase bool) ([]*models.T
for _, name := range names {
args = append(args, name)
}
- return qb.queryTags(query, args)
+ return qb.queryTags(ctx, query, args)
}
-func (qb *tagQueryBuilder) FindByParentTagID(parentID int) ([]*models.Tag, error) {
+func (qb *tagQueryBuilder) FindByParentTagID(ctx context.Context, parentID int) ([]*models.Tag, error) {
query := `
SELECT tags.* FROM tags
INNER JOIN tags_relations ON tags_relations.child_id = tags.id
@@ -204,10 +202,10 @@ func (qb *tagQueryBuilder) FindByParentTagID(parentID int) ([]*models.Tag, error
`
query += qb.getDefaultTagSort()
args := []interface{}{parentID}
- return qb.queryTags(query, args)
+ return qb.queryTags(ctx, query, args)
}
-func (qb *tagQueryBuilder) FindByChildTagID(parentID int) ([]*models.Tag, error) {
+func (qb *tagQueryBuilder) FindByChildTagID(ctx context.Context, parentID int) ([]*models.Tag, error) {
query := `
SELECT tags.* FROM tags
INNER JOIN tags_relations ON tags_relations.parent_id = tags.id
@@ -215,18 +213,18 @@ func (qb *tagQueryBuilder) FindByChildTagID(parentID int) ([]*models.Tag, error)
`
query += qb.getDefaultTagSort()
args := []interface{}{parentID}
- return qb.queryTags(query, args)
+ return qb.queryTags(ctx, query, args)
}
-func (qb *tagQueryBuilder) Count() (int, error) {
- return qb.runCountQuery(qb.buildCountQuery("SELECT tags.id FROM tags"), nil)
+func (qb *tagQueryBuilder) Count(ctx context.Context) (int, error) {
+ return qb.runCountQuery(ctx, qb.buildCountQuery("SELECT tags.id FROM tags"), nil)
}
-func (qb *tagQueryBuilder) All() ([]*models.Tag, error) {
- return qb.queryTags(selectAll("tags")+qb.getDefaultTagSort(), nil)
+func (qb *tagQueryBuilder) All(ctx context.Context) ([]*models.Tag, error) {
+ return qb.queryTags(ctx, selectAll("tags")+qb.getDefaultTagSort(), nil)
}
-func (qb *tagQueryBuilder) QueryForAutoTag(words []string) ([]*models.Tag, error) {
+func (qb *tagQueryBuilder) QueryForAutoTag(ctx context.Context, words []string) ([]*models.Tag, error) {
// TODO - Query needs to be changed to support queries of this type, and
// this method should be removed
query := selectAll(tagTable)
@@ -250,7 +248,7 @@ func (qb *tagQueryBuilder) QueryForAutoTag(words []string) ([]*models.Tag, error
"tags.ignore_auto_tag = 0",
whereOr,
}, " AND ")
- return qb.queryTags(query+" WHERE "+where, args)
+ return qb.queryTags(ctx, query+" WHERE "+where, args)
}
func (qb *tagQueryBuilder) validateFilter(tagFilter *models.TagFilterType) error {
@@ -284,38 +282,38 @@ func (qb *tagQueryBuilder) validateFilter(tagFilter *models.TagFilterType) error
return nil
}
-func (qb *tagQueryBuilder) makeFilter(tagFilter *models.TagFilterType) *filterBuilder {
+func (qb *tagQueryBuilder) makeFilter(ctx context.Context, tagFilter *models.TagFilterType) *filterBuilder {
query := &filterBuilder{}
if tagFilter.And != nil {
- query.and(qb.makeFilter(tagFilter.And))
+ query.and(qb.makeFilter(ctx, tagFilter.And))
}
if tagFilter.Or != nil {
- query.or(qb.makeFilter(tagFilter.Or))
+ query.or(qb.makeFilter(ctx, tagFilter.Or))
}
if tagFilter.Not != nil {
- query.not(qb.makeFilter(tagFilter.Not))
+ query.not(qb.makeFilter(ctx, tagFilter.Not))
}
- query.handleCriterion(stringCriterionHandler(tagFilter.Name, tagTable+".name"))
- query.handleCriterion(tagAliasCriterionHandler(qb, tagFilter.Aliases))
- query.handleCriterion(boolCriterionHandler(tagFilter.IgnoreAutoTag, tagTable+".ignore_auto_tag"))
+ query.handleCriterion(ctx, stringCriterionHandler(tagFilter.Name, tagTable+".name"))
+ query.handleCriterion(ctx, tagAliasCriterionHandler(qb, tagFilter.Aliases))
+ query.handleCriterion(ctx, boolCriterionHandler(tagFilter.IgnoreAutoTag, tagTable+".ignore_auto_tag"))
- query.handleCriterion(tagIsMissingCriterionHandler(qb, tagFilter.IsMissing))
- query.handleCriterion(tagSceneCountCriterionHandler(qb, tagFilter.SceneCount))
- query.handleCriterion(tagImageCountCriterionHandler(qb, tagFilter.ImageCount))
- query.handleCriterion(tagGalleryCountCriterionHandler(qb, tagFilter.GalleryCount))
- query.handleCriterion(tagPerformerCountCriterionHandler(qb, tagFilter.PerformerCount))
- query.handleCriterion(tagMarkerCountCriterionHandler(qb, tagFilter.MarkerCount))
- query.handleCriterion(tagParentsCriterionHandler(qb, tagFilter.Parents))
- query.handleCriterion(tagChildrenCriterionHandler(qb, tagFilter.Children))
- query.handleCriterion(tagParentCountCriterionHandler(qb, tagFilter.ParentCount))
- query.handleCriterion(tagChildCountCriterionHandler(qb, tagFilter.ChildCount))
+ query.handleCriterion(ctx, tagIsMissingCriterionHandler(qb, tagFilter.IsMissing))
+ query.handleCriterion(ctx, tagSceneCountCriterionHandler(qb, tagFilter.SceneCount))
+ query.handleCriterion(ctx, tagImageCountCriterionHandler(qb, tagFilter.ImageCount))
+ query.handleCriterion(ctx, tagGalleryCountCriterionHandler(qb, tagFilter.GalleryCount))
+ query.handleCriterion(ctx, tagPerformerCountCriterionHandler(qb, tagFilter.PerformerCount))
+ query.handleCriterion(ctx, tagMarkerCountCriterionHandler(qb, tagFilter.MarkerCount))
+ query.handleCriterion(ctx, tagParentsCriterionHandler(qb, tagFilter.Parents))
+ query.handleCriterion(ctx, tagChildrenCriterionHandler(qb, tagFilter.Children))
+ query.handleCriterion(ctx, tagParentCountCriterionHandler(qb, tagFilter.ParentCount))
+ query.handleCriterion(ctx, tagChildCountCriterionHandler(qb, tagFilter.ChildCount))
return query
}
-func (qb *tagQueryBuilder) Query(tagFilter *models.TagFilterType, findFilter *models.FindFilterType) ([]*models.Tag, int, error) {
+func (qb *tagQueryBuilder) Query(ctx context.Context, tagFilter *models.TagFilterType, findFilter *models.FindFilterType) ([]*models.Tag, int, error) {
if tagFilter == nil {
tagFilter = &models.TagFilterType{}
}
@@ -335,19 +333,19 @@ func (qb *tagQueryBuilder) Query(tagFilter *models.TagFilterType, findFilter *mo
if err := qb.validateFilter(tagFilter); err != nil {
return nil, 0, err
}
- filter := qb.makeFilter(tagFilter)
+ filter := qb.makeFilter(ctx, tagFilter)
query.addFilter(filter)
query.sortAndPagination = qb.getTagSort(&query, findFilter) + getPagination(findFilter)
- idsResult, countResult, err := query.executeFind()
+ idsResult, countResult, err := query.executeFind(ctx)
if err != nil {
return nil, 0, err
}
var tags []*models.Tag
for _, id := range idsResult {
- tag, err := qb.Find(id)
+ tag, err := qb.Find(ctx, id)
if err != nil {
return nil, 0, err
}
@@ -370,7 +368,7 @@ func tagAliasCriterionHandler(qb *tagQueryBuilder, alias *models.StringCriterion
}
func tagIsMissingCriterionHandler(qb *tagQueryBuilder, isMissing *string) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if isMissing != nil && *isMissing != "" {
switch *isMissing {
case "image":
@@ -384,7 +382,7 @@ func tagIsMissingCriterionHandler(qb *tagQueryBuilder, isMissing *string) criter
}
func tagSceneCountCriterionHandler(qb *tagQueryBuilder, sceneCount *models.IntCriterionInput) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if sceneCount != nil {
f.addLeftJoin("scenes_tags", "", "scenes_tags.tag_id = tags.id")
clause, args := getIntCriterionWhereClause("count(distinct scenes_tags.scene_id)", *sceneCount)
@@ -395,7 +393,7 @@ func tagSceneCountCriterionHandler(qb *tagQueryBuilder, sceneCount *models.IntCr
}
func tagImageCountCriterionHandler(qb *tagQueryBuilder, imageCount *models.IntCriterionInput) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if imageCount != nil {
f.addLeftJoin("images_tags", "", "images_tags.tag_id = tags.id")
clause, args := getIntCriterionWhereClause("count(distinct images_tags.image_id)", *imageCount)
@@ -406,7 +404,7 @@ func tagImageCountCriterionHandler(qb *tagQueryBuilder, imageCount *models.IntCr
}
func tagGalleryCountCriterionHandler(qb *tagQueryBuilder, galleryCount *models.IntCriterionInput) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if galleryCount != nil {
f.addLeftJoin("galleries_tags", "", "galleries_tags.tag_id = tags.id")
clause, args := getIntCriterionWhereClause("count(distinct galleries_tags.gallery_id)", *galleryCount)
@@ -417,7 +415,7 @@ func tagGalleryCountCriterionHandler(qb *tagQueryBuilder, galleryCount *models.I
}
func tagPerformerCountCriterionHandler(qb *tagQueryBuilder, performerCount *models.IntCriterionInput) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if performerCount != nil {
f.addLeftJoin("performers_tags", "", "performers_tags.tag_id = tags.id")
clause, args := getIntCriterionWhereClause("count(distinct performers_tags.performer_id)", *performerCount)
@@ -428,7 +426,7 @@ func tagPerformerCountCriterionHandler(qb *tagQueryBuilder, performerCount *mode
}
func tagMarkerCountCriterionHandler(qb *tagQueryBuilder, markerCount *models.IntCriterionInput) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if markerCount != nil {
f.addLeftJoin("scene_markers_tags", "", "scene_markers_tags.tag_id = tags.id")
f.addLeftJoin("scene_markers", "", "scene_markers_tags.scene_marker_id = scene_markers.id OR scene_markers.primary_tag_id = tags.id")
@@ -440,7 +438,7 @@ func tagMarkerCountCriterionHandler(qb *tagQueryBuilder, markerCount *models.Int
}
func tagParentsCriterionHandler(qb *tagQueryBuilder, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if tags != nil {
if tags.Modifier == models.CriterionModifierIsNull || tags.Modifier == models.CriterionModifierNotNull {
var notClause string
@@ -489,7 +487,7 @@ func tagParentsCriterionHandler(qb *tagQueryBuilder, tags *models.HierarchicalMu
}
func tagChildrenCriterionHandler(qb *tagQueryBuilder, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if tags != nil {
if tags.Modifier == models.CriterionModifierIsNull || tags.Modifier == models.CriterionModifierNotNull {
var notClause string
@@ -538,7 +536,7 @@ func tagChildrenCriterionHandler(qb *tagQueryBuilder, tags *models.HierarchicalM
}
func tagParentCountCriterionHandler(qb *tagQueryBuilder, parentCount *models.IntCriterionInput) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if parentCount != nil {
f.addLeftJoin("tags_relations", "parents_count", "parents_count.child_id = tags.id")
clause, args := getIntCriterionWhereClause("count(distinct parents_count.parent_id)", *parentCount)
@@ -549,7 +547,7 @@ func tagParentCountCriterionHandler(qb *tagQueryBuilder, parentCount *models.Int
}
func tagChildCountCriterionHandler(qb *tagQueryBuilder, childCount *models.IntCriterionInput) criterionHandlerFunc {
- return func(f *filterBuilder) {
+ return func(ctx context.Context, f *filterBuilder) {
if childCount != nil {
f.addLeftJoin("tags_relations", "children_count", "children_count.parent_id = tags.id")
clause, args := getIntCriterionWhereClause("count(distinct children_count.child_id)", *childCount)
@@ -592,17 +590,17 @@ func (qb *tagQueryBuilder) getTagSort(query *queryBuilder, findFilter *models.Fi
return getSort(sort, direction, "tags")
}
-func (qb *tagQueryBuilder) queryTag(query string, args []interface{}) (*models.Tag, error) {
- results, err := qb.queryTags(query, args)
+func (qb *tagQueryBuilder) queryTag(ctx context.Context, query string, args []interface{}) (*models.Tag, error) {
+ results, err := qb.queryTags(ctx, query, args)
if err != nil || len(results) < 1 {
return nil, err
}
return results[0], nil
}
-func (qb *tagQueryBuilder) queryTags(query string, args []interface{}) ([]*models.Tag, error) {
+func (qb *tagQueryBuilder) queryTags(ctx context.Context, query string, args []interface{}) ([]*models.Tag, error) {
var ret models.Tags
- if err := qb.query(query, args, &ret); err != nil {
+ if err := qb.query(ctx, query, args, &ret); err != nil {
return nil, err
}
@@ -620,20 +618,20 @@ func (qb *tagQueryBuilder) imageRepository() *imageRepository {
}
}
-func (qb *tagQueryBuilder) GetImage(tagID int) ([]byte, error) {
- return qb.imageRepository().get(tagID)
+func (qb *tagQueryBuilder) GetImage(ctx context.Context, tagID int) ([]byte, error) {
+ return qb.imageRepository().get(ctx, tagID)
}
-func (qb *tagQueryBuilder) HasImage(tagID int) (bool, error) {
- return qb.imageRepository().exists(tagID)
+func (qb *tagQueryBuilder) HasImage(ctx context.Context, tagID int) (bool, error) {
+ return qb.imageRepository().exists(ctx, tagID)
}
-func (qb *tagQueryBuilder) UpdateImage(tagID int, image []byte) error {
- return qb.imageRepository().replace(tagID, image)
+func (qb *tagQueryBuilder) UpdateImage(ctx context.Context, tagID int, image []byte) error {
+ return qb.imageRepository().replace(ctx, tagID, image)
}
-func (qb *tagQueryBuilder) DestroyImage(tagID int) error {
- return qb.imageRepository().destroy([]int{tagID})
+func (qb *tagQueryBuilder) DestroyImage(ctx context.Context, tagID int) error {
+ return qb.imageRepository().destroy(ctx, []int{tagID})
}
func (qb *tagQueryBuilder) aliasRepository() *stringRepository {
@@ -647,15 +645,15 @@ func (qb *tagQueryBuilder) aliasRepository() *stringRepository {
}
}
-func (qb *tagQueryBuilder) GetAliases(tagID int) ([]string, error) {
- return qb.aliasRepository().get(tagID)
+func (qb *tagQueryBuilder) GetAliases(ctx context.Context, tagID int) ([]string, error) {
+ return qb.aliasRepository().get(ctx, tagID)
}
-func (qb *tagQueryBuilder) UpdateAliases(tagID int, aliases []string) error {
- return qb.aliasRepository().replace(tagID, aliases)
+func (qb *tagQueryBuilder) UpdateAliases(ctx context.Context, tagID int, aliases []string) error {
+ return qb.aliasRepository().replace(ctx, tagID, aliases)
}
-func (qb *tagQueryBuilder) Merge(source []int, destination int) error {
+func (qb *tagQueryBuilder) Merge(ctx context.Context, source []int, destination int) error {
if len(source) == 0 {
return nil
}
@@ -680,7 +678,7 @@ func (qb *tagQueryBuilder) Merge(source []int, destination int) error {
args = append(args, destination)
for table, idColumn := range tagTables {
- _, err := qb.tx.Exec(`UPDATE `+table+`
+ _, err := qb.tx.Exec(ctx, `UPDATE `+table+`
SET tag_id = ?
WHERE tag_id IN `+inBinding+`
AND NOT EXISTS(SELECT 1 FROM `+table+` o WHERE o.`+idColumn+` = `+table+`.`+idColumn+` AND o.tag_id = ?)`,
@@ -691,23 +689,23 @@ AND NOT EXISTS(SELECT 1 FROM `+table+` o WHERE o.`+idColumn+` = `+table+`.`+idCo
}
}
- _, err := qb.tx.Exec("UPDATE "+sceneMarkerTable+" SET primary_tag_id = ? WHERE primary_tag_id IN "+inBinding, args...)
+ _, err := qb.tx.Exec(ctx, "UPDATE "+sceneMarkerTable+" SET primary_tag_id = ? WHERE primary_tag_id IN "+inBinding, args...)
if err != nil {
return err
}
- _, err = qb.tx.Exec("INSERT INTO "+tagAliasesTable+" (tag_id, alias) SELECT ?, name FROM "+tagTable+" WHERE id IN "+inBinding, args...)
+ _, err = qb.tx.Exec(ctx, "INSERT INTO "+tagAliasesTable+" (tag_id, alias) SELECT ?, name FROM "+tagTable+" WHERE id IN "+inBinding, args...)
if err != nil {
return err
}
- _, err = qb.tx.Exec("UPDATE "+tagAliasesTable+" SET tag_id = ? WHERE tag_id IN "+inBinding, args...)
+ _, err = qb.tx.Exec(ctx, "UPDATE "+tagAliasesTable+" SET tag_id = ? WHERE tag_id IN "+inBinding, args...)
if err != nil {
return err
}
for _, id := range source {
- err = qb.Destroy(id)
+ err = qb.Destroy(ctx, id)
if err != nil {
return err
}
@@ -716,9 +714,9 @@ AND NOT EXISTS(SELECT 1 FROM `+table+` o WHERE o.`+idColumn+` = `+table+`.`+idCo
return nil
}
-func (qb *tagQueryBuilder) UpdateParentTags(tagID int, parentIDs []int) error {
+func (qb *tagQueryBuilder) UpdateParentTags(ctx context.Context, tagID int, parentIDs []int) error {
tx := qb.tx
- if _, err := tx.Exec("DELETE FROM tags_relations WHERE child_id = ?", tagID); err != nil {
+ if _, err := tx.Exec(ctx, "DELETE FROM tags_relations WHERE child_id = ?", tagID); err != nil {
return err
}
@@ -731,7 +729,7 @@ func (qb *tagQueryBuilder) UpdateParentTags(tagID int, parentIDs []int) error {
}
query := "INSERT INTO tags_relations (parent_id, child_id) VALUES " + strings.Join(values, ", ")
- if _, err := tx.Exec(query, args...); err != nil {
+ if _, err := tx.Exec(ctx, query, args...); err != nil {
return err
}
}
@@ -739,9 +737,9 @@ func (qb *tagQueryBuilder) UpdateParentTags(tagID int, parentIDs []int) error {
return nil
}
-func (qb *tagQueryBuilder) UpdateChildTags(tagID int, childIDs []int) error {
+func (qb *tagQueryBuilder) UpdateChildTags(ctx context.Context, tagID int, childIDs []int) error {
tx := qb.tx
- if _, err := tx.Exec("DELETE FROM tags_relations WHERE parent_id = ?", tagID); err != nil {
+ if _, err := tx.Exec(ctx, "DELETE FROM tags_relations WHERE parent_id = ?", tagID); err != nil {
return err
}
@@ -754,7 +752,7 @@ func (qb *tagQueryBuilder) UpdateChildTags(tagID int, childIDs []int) error {
}
query := "INSERT INTO tags_relations (parent_id, child_id) VALUES " + strings.Join(values, ", ")
- if _, err := tx.Exec(query, args...); err != nil {
+ if _, err := tx.Exec(ctx, query, args...); err != nil {
return err
}
}
@@ -764,7 +762,7 @@ func (qb *tagQueryBuilder) UpdateChildTags(tagID int, childIDs []int) error {
// FindAllAncestors returns a slice of TagPath objects, representing all
// ancestors of the tag with the provided id.
-func (qb *tagQueryBuilder) FindAllAncestors(tagID int, excludeIDs []int) ([]*models.TagPath, error) {
+func (qb *tagQueryBuilder) FindAllAncestors(ctx context.Context, tagID int, excludeIDs []int) ([]*models.TagPath, error) {
inBinding := getInBinding(len(excludeIDs) + 1)
query := `WITH RECURSIVE
@@ -783,7 +781,7 @@ SELECT t.*, p.path FROM tags t INNER JOIN parents p ON t.id = p.parent_id
}
args := []interface{}{tagID}
args = append(args, append(append(excludeArgs, excludeArgs...), excludeArgs...)...)
- if err := qb.query(query, args, &ret); err != nil {
+ if err := qb.query(ctx, query, args, &ret); err != nil {
return nil, err
}
@@ -792,7 +790,7 @@ SELECT t.*, p.path FROM tags t INNER JOIN parents p ON t.id = p.parent_id
// FindAllDescendants returns a slice of TagPath objects, representing all
// descendants of the tag with the provided id.
-func (qb *tagQueryBuilder) FindAllDescendants(tagID int, excludeIDs []int) ([]*models.TagPath, error) {
+func (qb *tagQueryBuilder) FindAllDescendants(ctx context.Context, tagID int, excludeIDs []int) ([]*models.TagPath, error) {
inBinding := getInBinding(len(excludeIDs) + 1)
query := `WITH RECURSIVE
@@ -811,7 +809,7 @@ SELECT t.*, c.path FROM tags t INNER JOIN children c ON t.id = c.child_id
}
args := []interface{}{tagID}
args = append(args, append(append(excludeArgs, excludeArgs...), excludeArgs...)...)
- if err := qb.query(query, args, &ret); err != nil {
+ if err := qb.query(ctx, query, args, &ret); err != nil {
return nil, err
}
diff --git a/pkg/sqlite/tag_test.go b/pkg/sqlite/tag_test.go
index a5ed8d966..0bc0fb5c0 100644
--- a/pkg/sqlite/tag_test.go
+++ b/pkg/sqlite/tag_test.go
@@ -4,6 +4,7 @@
package sqlite_test
import (
+ "context"
"database/sql"
"fmt"
"math"
@@ -12,16 +13,17 @@ import (
"testing"
"github.com/stashapp/stash/pkg/models"
+ "github.com/stashapp/stash/pkg/sqlite"
"github.com/stretchr/testify/assert"
)
func TestMarkerFindBySceneMarkerID(t *testing.T) {
- withTxn(func(r models.Repository) error {
- tqb := r.Tag()
+ withTxn(func(ctx context.Context) error {
+ tqb := sqlite.TagReaderWriter
markerID := markerIDs[markerIdxWithTag]
- tags, err := tqb.FindBySceneMarkerID(markerID)
+ tags, err := tqb.FindBySceneMarkerID(ctx, markerID)
if err != nil {
t.Errorf("Error finding tags: %s", err.Error())
@@ -30,7 +32,7 @@ func TestMarkerFindBySceneMarkerID(t *testing.T) {
assert.Len(t, tags, 1)
assert.Equal(t, tagIDs[tagIdxWithMarkers], tags[0].ID)
- tags, err = tqb.FindBySceneMarkerID(0)
+ tags, err = tqb.FindBySceneMarkerID(ctx, 0)
if err != nil {
t.Errorf("Error finding tags: %s", err.Error())
@@ -43,12 +45,12 @@ func TestMarkerFindBySceneMarkerID(t *testing.T) {
}
func TestTagFindByName(t *testing.T) {
- withTxn(func(r models.Repository) error {
- tqb := r.Tag()
+ withTxn(func(ctx context.Context) error {
+ tqb := sqlite.TagReaderWriter
name := tagNames[tagIdxWithScene] // find a tag by name
- tag, err := tqb.FindByName(name, false)
+ tag, err := tqb.FindByName(ctx, name, false)
if err != nil {
t.Errorf("Error finding tags: %s", err.Error())
@@ -58,7 +60,7 @@ func TestTagFindByName(t *testing.T) {
name = tagNames[tagIdxWithDupName] // find a tag by name nocase
- tag, err = tqb.FindByName(name, true)
+ tag, err = tqb.FindByName(ctx, name, true)
if err != nil {
t.Errorf("Error finding tags: %s", err.Error())
@@ -74,15 +76,15 @@ func TestTagFindByName(t *testing.T) {
}
func TestTagQueryIgnoreAutoTag(t *testing.T) {
- withTxn(func(r models.Repository) error {
+ withTxn(func(ctx context.Context) error {
ignoreAutoTag := true
tagFilter := models.TagFilterType{
IgnoreAutoTag: &ignoreAutoTag,
}
- sqb := r.Tag()
+ sqb := sqlite.TagReaderWriter
- tags := queryTags(t, sqb, &tagFilter, nil)
+ tags := queryTags(ctx, t, sqb, &tagFilter, nil)
assert.Len(t, tags, int(math.Ceil(float64(totalTags)/5)))
for _, s := range tags {
@@ -94,12 +96,12 @@ func TestTagQueryIgnoreAutoTag(t *testing.T) {
}
func TestTagQueryForAutoTag(t *testing.T) {
- withTxn(func(r models.Repository) error {
- tqb := r.Tag()
+ withTxn(func(ctx context.Context) error {
+ tqb := sqlite.TagReaderWriter
name := tagNames[tagIdx1WithScene] // find a tag by name
- tags, err := tqb.QueryForAutoTag([]string{name})
+ tags, err := tqb.QueryForAutoTag(ctx, []string{name})
if err != nil {
t.Errorf("Error finding tags: %s", err.Error())
@@ -112,7 +114,7 @@ func TestTagQueryForAutoTag(t *testing.T) {
// find by alias
name = getTagStringValue(tagIdx1WithScene, "Alias")
- tags, err = tqb.QueryForAutoTag([]string{name})
+ tags, err = tqb.QueryForAutoTag(ctx, []string{name})
if err != nil {
t.Errorf("Error finding tags: %s", err.Error())
@@ -128,19 +130,19 @@ func TestTagQueryForAutoTag(t *testing.T) {
func TestTagFindByNames(t *testing.T) {
var names []string
- withTxn(func(r models.Repository) error {
- tqb := r.Tag()
+ withTxn(func(ctx context.Context) error {
+ tqb := sqlite.TagReaderWriter
names = append(names, tagNames[tagIdxWithScene]) // find tags by names
- tags, err := tqb.FindByNames(names, false)
+ tags, err := tqb.FindByNames(ctx, names, false)
if err != nil {
t.Errorf("Error finding tags: %s", err.Error())
}
assert.Len(t, tags, 1)
assert.Equal(t, tagNames[tagIdxWithScene], tags[0].Name)
- tags, err = tqb.FindByNames(names, true) // find tags by names nocase
+ tags, err = tqb.FindByNames(ctx, names, true) // find tags by names nocase
if err != nil {
t.Errorf("Error finding tags: %s", err.Error())
}
@@ -150,7 +152,7 @@ func TestTagFindByNames(t *testing.T) {
names = append(names, tagNames[tagIdx1WithScene]) // find tags by names ( 2 names )
- tags, err = tqb.FindByNames(names, false)
+ tags, err = tqb.FindByNames(ctx, names, false)
if err != nil {
t.Errorf("Error finding tags: %s", err.Error())
}
@@ -158,7 +160,7 @@ func TestTagFindByNames(t *testing.T) {
assert.Equal(t, tagNames[tagIdxWithScene], tags[0].Name)
assert.Equal(t, tagNames[tagIdx1WithScene], tags[1].Name)
- tags, err = tqb.FindByNames(names, true) // find tags by names ( 2 names nocase)
+ tags, err = tqb.FindByNames(ctx, names, true) // find tags by names ( 2 names nocase)
if err != nil {
t.Errorf("Error finding tags: %s", err.Error())
}
@@ -173,8 +175,8 @@ func TestTagFindByNames(t *testing.T) {
}
func TestTagQuerySort(t *testing.T) {
- withTxn(func(r models.Repository) error {
- sqb := r.Tag()
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.TagReaderWriter
sortBy := "scenes_count"
dir := models.SortDirectionEnumDesc
@@ -183,24 +185,24 @@ func TestTagQuerySort(t *testing.T) {
Direction: &dir,
}
- tags := queryTags(t, sqb, nil, findFilter)
+ tags := queryTags(ctx, t, sqb, nil, findFilter)
assert := assert.New(t)
assert.Equal(tagIDs[tagIdxWithScene], tags[0].ID)
sortBy = "scene_markers_count"
- tags = queryTags(t, sqb, nil, findFilter)
+ tags = queryTags(ctx, t, sqb, nil, findFilter)
assert.Equal(tagIDs[tagIdxWithMarkers], tags[0].ID)
sortBy = "images_count"
- tags = queryTags(t, sqb, nil, findFilter)
+ tags = queryTags(ctx, t, sqb, nil, findFilter)
assert.Equal(tagIDs[tagIdxWithImage], tags[0].ID)
sortBy = "galleries_count"
- tags = queryTags(t, sqb, nil, findFilter)
+ tags = queryTags(ctx, t, sqb, nil, findFilter)
assert.Equal(tagIDs[tagIdxWithGallery], tags[0].ID)
sortBy = "performers_count"
- tags = queryTags(t, sqb, nil, findFilter)
+ tags = queryTags(ctx, t, sqb, nil, findFilter)
assert.Equal(tagIDs[tagIdxWithPerformer], tags[0].ID)
return nil
@@ -220,7 +222,7 @@ func TestTagQueryName(t *testing.T) {
Name: nameCriterion,
}
- verifyFn := func(tag *models.Tag, r models.Repository) {
+ verifyFn := func(ctx context.Context, tag *models.Tag) {
verifyString(t, tag.Name, *nameCriterion)
}
@@ -250,8 +252,8 @@ func TestTagQueryAlias(t *testing.T) {
Aliases: aliasCriterion,
}
- verifyFn := func(tag *models.Tag, r models.Repository) {
- aliases, err := r.Tag().GetAliases(tag.ID)
+ verifyFn := func(ctx context.Context, tag *models.Tag) {
+ aliases, err := sqlite.TagReaderWriter.GetAliases(ctx, tag.ID)
if err != nil {
t.Errorf("Error querying tags: %s", err.Error())
}
@@ -277,23 +279,23 @@ func TestTagQueryAlias(t *testing.T) {
verifyTagQuery(t, tagFilter, nil, verifyFn)
}
-func verifyTagQuery(t *testing.T, tagFilter *models.TagFilterType, findFilter *models.FindFilterType, verifyFn func(t *models.Tag, r models.Repository)) {
- withTxn(func(r models.Repository) error {
- sqb := r.Tag()
+func verifyTagQuery(t *testing.T, tagFilter *models.TagFilterType, findFilter *models.FindFilterType, verifyFn func(ctx context.Context, t *models.Tag)) {
+ withTxn(func(ctx context.Context) error {
+ sqb := sqlite.TagReaderWriter
- tags := queryTags(t, sqb, tagFilter, findFilter)
+ tags := queryTags(ctx, t, sqb, tagFilter, findFilter)
for _, tag := range tags {
- verifyFn(tag, r)
+ verifyFn(ctx, tag)
}
return nil
})
}
-func queryTags(t *testing.T, qb models.TagReader, tagFilter *models.TagFilterType, findFilter *models.FindFilterType) []*models.Tag {
+func queryTags(ctx context.Context, t *testing.T, qb models.TagReader, tagFilter *models.TagFilterType, findFilter *models.FindFilterType) []*models.Tag {
t.Helper()
- tags, _, err := qb.Query(tagFilter, findFilter)
+ tags, _, err := qb.Query(ctx, tagFilter, findFilter)
if err != nil {
t.Errorf("Error querying tags: %s", err.Error())
}
@@ -302,8 +304,8 @@ func queryTags(t *testing.T, qb models.TagReader, tagFilter *models.TagFilterTyp
}
func TestTagQueryIsMissingImage(t *testing.T) {
- withTxn(func(r models.Repository) error {
- qb := r.Tag()
+ withTxn(func(ctx context.Context) error {
+ qb := sqlite.TagReaderWriter
isMissing := "image"
tagFilter := models.TagFilterType{
IsMissing: &isMissing,
@@ -314,7 +316,7 @@ func TestTagQueryIsMissingImage(t *testing.T) {
Q: &q,
}
- tags, _, err := qb.Query(&tagFilter, &findFilter)
+ tags, _, err := qb.Query(ctx, &tagFilter, &findFilter)
if err != nil {
t.Errorf("Error querying tag: %s", err.Error())
}
@@ -322,7 +324,7 @@ func TestTagQueryIsMissingImage(t *testing.T) {
assert.Len(t, tags, 0)
findFilter.Q = nil
- tags, _, err = qb.Query(&tagFilter, &findFilter)
+ tags, _, err = qb.Query(ctx, &tagFilter, &findFilter)
if err != nil {
t.Errorf("Error querying tag: %s", err.Error())
}
@@ -356,13 +358,13 @@ func TestTagQuerySceneCount(t *testing.T) {
}
func verifyTagSceneCount(t *testing.T, sceneCountCriterion models.IntCriterionInput) {
- withTxn(func(r models.Repository) error {
- qb := r.Tag()
+ withTxn(func(ctx context.Context) error {
+ qb := sqlite.TagReaderWriter
tagFilter := models.TagFilterType{
SceneCount: &sceneCountCriterion,
}
- tags, _, err := qb.Query(&tagFilter, nil)
+ tags, _, err := qb.Query(ctx, &tagFilter, nil)
if err != nil {
t.Errorf("Error querying tag: %s", err.Error())
}
@@ -398,13 +400,13 @@ func TestTagQueryMarkerCount(t *testing.T) {
}
func verifyTagMarkerCount(t *testing.T, markerCountCriterion models.IntCriterionInput) {
- withTxn(func(r models.Repository) error {
- qb := r.Tag()
+ withTxn(func(ctx context.Context) error {
+ qb := sqlite.TagReaderWriter
tagFilter := models.TagFilterType{
MarkerCount: &markerCountCriterion,
}
- tags, _, err := qb.Query(&tagFilter, nil)
+ tags, _, err := qb.Query(ctx, &tagFilter, nil)
if err != nil {
t.Errorf("Error querying tag: %s", err.Error())
}
@@ -440,13 +442,13 @@ func TestTagQueryImageCount(t *testing.T) {
}
func verifyTagImageCount(t *testing.T, imageCountCriterion models.IntCriterionInput) {
- withTxn(func(r models.Repository) error {
- qb := r.Tag()
+ withTxn(func(ctx context.Context) error {
+ qb := sqlite.TagReaderWriter
tagFilter := models.TagFilterType{
ImageCount: &imageCountCriterion,
}
- tags, _, err := qb.Query(&tagFilter, nil)
+ tags, _, err := qb.Query(ctx, &tagFilter, nil)
if err != nil {
t.Errorf("Error querying tag: %s", err.Error())
}
@@ -482,13 +484,13 @@ func TestTagQueryGalleryCount(t *testing.T) {
}
func verifyTagGalleryCount(t *testing.T, imageCountCriterion models.IntCriterionInput) {
- withTxn(func(r models.Repository) error {
- qb := r.Tag()
+ withTxn(func(ctx context.Context) error {
+ qb := sqlite.TagReaderWriter
tagFilter := models.TagFilterType{
GalleryCount: &imageCountCriterion,
}
- tags, _, err := qb.Query(&tagFilter, nil)
+ tags, _, err := qb.Query(ctx, &tagFilter, nil)
if err != nil {
t.Errorf("Error querying tag: %s", err.Error())
}
@@ -524,13 +526,13 @@ func TestTagQueryPerformerCount(t *testing.T) {
}
func verifyTagPerformerCount(t *testing.T, imageCountCriterion models.IntCriterionInput) {
- withTxn(func(r models.Repository) error {
- qb := r.Tag()
+ withTxn(func(ctx context.Context) error {
+ qb := sqlite.TagReaderWriter
tagFilter := models.TagFilterType{
PerformerCount: &imageCountCriterion,
}
- tags, _, err := qb.Query(&tagFilter, nil)
+ tags, _, err := qb.Query(ctx, &tagFilter, nil)
if err != nil {
t.Errorf("Error querying tag: %s", err.Error())
}
@@ -566,13 +568,13 @@ func TestTagQueryParentCount(t *testing.T) {
}
func verifyTagParentCount(t *testing.T, sceneCountCriterion models.IntCriterionInput) {
- withTxn(func(r models.Repository) error {
- qb := r.Tag()
+ withTxn(func(ctx context.Context) error {
+ qb := sqlite.TagReaderWriter
tagFilter := models.TagFilterType{
ParentCount: &sceneCountCriterion,
}
- tags := queryTags(t, qb, &tagFilter, nil)
+ tags := queryTags(ctx, t, qb, &tagFilter, nil)
if len(tags) == 0 {
t.Error("Expected at least one tag")
@@ -609,13 +611,13 @@ func TestTagQueryChildCount(t *testing.T) {
}
func verifyTagChildCount(t *testing.T, sceneCountCriterion models.IntCriterionInput) {
- withTxn(func(r models.Repository) error {
- qb := r.Tag()
+ withTxn(func(ctx context.Context) error {
+ qb := sqlite.TagReaderWriter
tagFilter := models.TagFilterType{
ChildCount: &sceneCountCriterion,
}
- tags := queryTags(t, qb, &tagFilter, nil)
+ tags := queryTags(ctx, t, qb, &tagFilter, nil)
if len(tags) == 0 {
t.Error("Expected at least one tag")
@@ -633,9 +635,9 @@ func verifyTagChildCount(t *testing.T, sceneCountCriterion models.IntCriterionIn
}
func TestTagQueryParent(t *testing.T) {
- withTxn(func(r models.Repository) error {
+ withTxn(func(ctx context.Context) error {
const nameField = "Name"
- sqb := r.Tag()
+ sqb := sqlite.TagReaderWriter
tagCriterion := models.HierarchicalMultiCriterionInput{
Value: []string{
strconv.Itoa(tagIDs[tagIdxWithChildTag]),
@@ -647,7 +649,7 @@ func TestTagQueryParent(t *testing.T) {
Parents: &tagCriterion,
}
- tags := queryTags(t, sqb, &tagFilter, nil)
+ tags := queryTags(ctx, t, sqb, &tagFilter, nil)
assert.Len(t, tags, 1)
@@ -661,7 +663,7 @@ func TestTagQueryParent(t *testing.T) {
Q: &q,
}
- tags = queryTags(t, sqb, &tagFilter, &findFilter)
+ tags = queryTags(ctx, t, sqb, &tagFilter, &findFilter)
assert.Len(t, tags, 0)
depth := -1
@@ -674,12 +676,12 @@ func TestTagQueryParent(t *testing.T) {
Depth: &depth,
}
- tags = queryTags(t, sqb, &tagFilter, nil)
+ tags = queryTags(ctx, t, sqb, &tagFilter, nil)
assert.Len(t, tags, 2)
depth = 1
- tags = queryTags(t, sqb, &tagFilter, nil)
+ tags = queryTags(ctx, t, sqb, &tagFilter, nil)
assert.Len(t, tags, 2)
tagCriterion = models.HierarchicalMultiCriterionInput{
@@ -687,22 +689,22 @@ func TestTagQueryParent(t *testing.T) {
}
q = getTagStringValue(tagIdxWithGallery, nameField)
- tags = queryTags(t, sqb, &tagFilter, &findFilter)
+ tags = queryTags(ctx, t, sqb, &tagFilter, &findFilter)
assert.Len(t, tags, 1)
assert.Equal(t, tagIDs[tagIdxWithGallery], tags[0].ID)
q = getTagStringValue(tagIdxWithParentTag, nameField)
- tags = queryTags(t, sqb, &tagFilter, &findFilter)
+ tags = queryTags(ctx, t, sqb, &tagFilter, &findFilter)
assert.Len(t, tags, 0)
tagCriterion.Modifier = models.CriterionModifierNotNull
- tags = queryTags(t, sqb, &tagFilter, &findFilter)
+ tags = queryTags(ctx, t, sqb, &tagFilter, &findFilter)
assert.Len(t, tags, 1)
assert.Equal(t, tagIDs[tagIdxWithParentTag], tags[0].ID)
q = getTagStringValue(tagIdxWithGallery, nameField)
- tags = queryTags(t, sqb, &tagFilter, &findFilter)
+ tags = queryTags(ctx, t, sqb, &tagFilter, &findFilter)
assert.Len(t, tags, 0)
return nil
@@ -710,10 +712,10 @@ func TestTagQueryParent(t *testing.T) {
}
func TestTagQueryChild(t *testing.T) {
- withTxn(func(r models.Repository) error {
+ withTxn(func(ctx context.Context) error {
const nameField = "Name"
- sqb := r.Tag()
+ sqb := sqlite.TagReaderWriter
tagCriterion := models.HierarchicalMultiCriterionInput{
Value: []string{
strconv.Itoa(tagIDs[tagIdxWithParentTag]),
@@ -725,7 +727,7 @@ func TestTagQueryChild(t *testing.T) {
Children: &tagCriterion,
}
- tags := queryTags(t, sqb, &tagFilter, nil)
+ tags := queryTags(ctx, t, sqb, &tagFilter, nil)
assert.Len(t, tags, 1)
@@ -739,7 +741,7 @@ func TestTagQueryChild(t *testing.T) {
Q: &q,
}
- tags = queryTags(t, sqb, &tagFilter, &findFilter)
+ tags = queryTags(ctx, t, sqb, &tagFilter, &findFilter)
assert.Len(t, tags, 0)
depth := -1
@@ -752,12 +754,12 @@ func TestTagQueryChild(t *testing.T) {
Depth: &depth,
}
- tags = queryTags(t, sqb, &tagFilter, nil)
+ tags = queryTags(ctx, t, sqb, &tagFilter, nil)
assert.Len(t, tags, 2)
depth = 1
- tags = queryTags(t, sqb, &tagFilter, nil)
+ tags = queryTags(ctx, t, sqb, &tagFilter, nil)
assert.Len(t, tags, 2)
tagCriterion = models.HierarchicalMultiCriterionInput{
@@ -765,22 +767,22 @@ func TestTagQueryChild(t *testing.T) {
}
q = getTagStringValue(tagIdxWithGallery, nameField)
- tags = queryTags(t, sqb, &tagFilter, &findFilter)
+ tags = queryTags(ctx, t, sqb, &tagFilter, &findFilter)
assert.Len(t, tags, 1)
assert.Equal(t, tagIDs[tagIdxWithGallery], tags[0].ID)
q = getTagStringValue(tagIdxWithChildTag, nameField)
- tags = queryTags(t, sqb, &tagFilter, &findFilter)
+ tags = queryTags(ctx, t, sqb, &tagFilter, &findFilter)
assert.Len(t, tags, 0)
tagCriterion.Modifier = models.CriterionModifierNotNull
- tags = queryTags(t, sqb, &tagFilter, &findFilter)
+ tags = queryTags(ctx, t, sqb, &tagFilter, &findFilter)
assert.Len(t, tags, 1)
assert.Equal(t, tagIDs[tagIdxWithChildTag], tags[0].ID)
q = getTagStringValue(tagIdxWithGallery, nameField)
- tags = queryTags(t, sqb, &tagFilter, &findFilter)
+ tags = queryTags(ctx, t, sqb, &tagFilter, &findFilter)
assert.Len(t, tags, 0)
return nil
@@ -788,34 +790,34 @@ func TestTagQueryChild(t *testing.T) {
}
func TestTagUpdateTagImage(t *testing.T) {
- if err := withTxn(func(r models.Repository) error {
- qb := r.Tag()
+ if err := withTxn(func(ctx context.Context) error {
+ qb := sqlite.TagReaderWriter
// create tag to test against
const name = "TestTagUpdateTagImage"
tag := models.Tag{
Name: name,
}
- created, err := qb.Create(tag)
+ created, err := qb.Create(ctx, tag)
if err != nil {
return fmt.Errorf("Error creating tag: %s", err.Error())
}
image := []byte("image")
- err = qb.UpdateImage(created.ID, image)
+ err = qb.UpdateImage(ctx, created.ID, image)
if err != nil {
return fmt.Errorf("Error updating studio image: %s", err.Error())
}
// ensure image set
- storedImage, err := qb.GetImage(created.ID)
+ storedImage, err := qb.GetImage(ctx, created.ID)
if err != nil {
return fmt.Errorf("Error getting image: %s", err.Error())
}
assert.Equal(t, storedImage, image)
// set nil image
- err = qb.UpdateImage(created.ID, nil)
+ err = qb.UpdateImage(ctx, created.ID, nil)
if err == nil {
return fmt.Errorf("Expected error setting nil image")
}
@@ -827,32 +829,32 @@ func TestTagUpdateTagImage(t *testing.T) {
}
func TestTagDestroyTagImage(t *testing.T) {
- if err := withTxn(func(r models.Repository) error {
- qb := r.Tag()
+ if err := withTxn(func(ctx context.Context) error {
+ qb := sqlite.TagReaderWriter
// create performer to test against
const name = "TestTagDestroyTagImage"
tag := models.Tag{
Name: name,
}
- created, err := qb.Create(tag)
+ created, err := qb.Create(ctx, tag)
if err != nil {
return fmt.Errorf("Error creating tag: %s", err.Error())
}
image := []byte("image")
- err = qb.UpdateImage(created.ID, image)
+ err = qb.UpdateImage(ctx, created.ID, image)
if err != nil {
return fmt.Errorf("Error updating studio image: %s", err.Error())
}
- err = qb.DestroyImage(created.ID)
+ err = qb.DestroyImage(ctx, created.ID)
if err != nil {
return fmt.Errorf("Error destroying studio image: %s", err.Error())
}
// image should be nil
- storedImage, err := qb.GetImage(created.ID)
+ storedImage, err := qb.GetImage(ctx, created.ID)
if err != nil {
return fmt.Errorf("Error getting image: %s", err.Error())
}
@@ -865,27 +867,27 @@ func TestTagDestroyTagImage(t *testing.T) {
}
func TestTagUpdateAlias(t *testing.T) {
- if err := withTxn(func(r models.Repository) error {
- qb := r.Tag()
+ if err := withTxn(func(ctx context.Context) error {
+ qb := sqlite.TagReaderWriter
// create tag to test against
const name = "TestTagUpdateAlias"
tag := models.Tag{
Name: name,
}
- created, err := qb.Create(tag)
+ created, err := qb.Create(ctx, tag)
if err != nil {
return fmt.Errorf("Error creating tag: %s", err.Error())
}
aliases := []string{"alias1", "alias2"}
- err = qb.UpdateAliases(created.ID, aliases)
+ err = qb.UpdateAliases(ctx, created.ID, aliases)
if err != nil {
return fmt.Errorf("Error updating tag aliases: %s", err.Error())
}
// ensure aliases set
- storedAliases, err := qb.GetAliases(created.ID)
+ storedAliases, err := qb.GetAliases(ctx, created.ID)
if err != nil {
return fmt.Errorf("Error getting aliases: %s", err.Error())
}
@@ -901,11 +903,11 @@ func TestTagMerge(t *testing.T) {
assert := assert.New(t)
// merge tests - perform these in a transaction that we'll rollback
- if err := withRollbackTxn(func(r models.Repository) error {
- qb := r.Tag()
+ if err := withRollbackTxn(func(ctx context.Context) error {
+ qb := sqlite.TagReaderWriter
// try merging into same tag
- err := qb.Merge([]int{tagIDs[tagIdx1WithScene]}, tagIDs[tagIdx1WithScene])
+ err := qb.Merge(ctx, []int{tagIDs[tagIdx1WithScene]}, tagIDs[tagIdx1WithScene])
assert.NotNil(err)
// merge everything into tagIdxWithScene
@@ -931,13 +933,13 @@ func TestTagMerge(t *testing.T) {
}
destID := tagIDs[tagIdxWithScene]
- if err = qb.Merge(srcIDs, destID); err != nil {
+ if err = qb.Merge(ctx, srcIDs, destID); err != nil {
return err
}
// ensure other tags are deleted
for _, tagId := range srcIDs {
- t, err := qb.Find(tagId)
+ t, err := qb.Find(ctx, tagId)
if err != nil {
return err
}
@@ -946,7 +948,7 @@ func TestTagMerge(t *testing.T) {
}
// ensure aliases are set on the destination
- destAliases, err := qb.GetAliases(destID)
+ destAliases, err := qb.GetAliases(ctx, destID)
if err != nil {
return err
}
@@ -955,7 +957,7 @@ func TestTagMerge(t *testing.T) {
}
// ensure scene points to new tag
- sceneTagIDs, err := r.Scene().GetTagIDs(sceneIDs[sceneIdxWithTwoTags])
+ sceneTagIDs, err := sqlite.SceneReaderWriter.GetTagIDs(ctx, sceneIDs[sceneIdxWithTwoTags])
if err != nil {
return err
}
@@ -963,14 +965,14 @@ func TestTagMerge(t *testing.T) {
assert.Contains(sceneTagIDs, destID)
// ensure marker points to new tag
- marker, err := r.SceneMarker().Find(markerIDs[markerIdxWithTag])
+ marker, err := sqlite.SceneMarkerReaderWriter.Find(ctx, markerIDs[markerIdxWithTag])
if err != nil {
return err
}
assert.Equal(destID, marker.PrimaryTagID)
- markerTagIDs, err := r.SceneMarker().GetTagIDs(marker.ID)
+ markerTagIDs, err := sqlite.SceneMarkerReaderWriter.GetTagIDs(ctx, marker.ID)
if err != nil {
return err
}
@@ -978,7 +980,7 @@ func TestTagMerge(t *testing.T) {
assert.Contains(markerTagIDs, destID)
// ensure image points to new tag
- imageTagIDs, err := r.Image().GetTagIDs(imageIDs[imageIdxWithTwoTags])
+ imageTagIDs, err := sqlite.ImageReaderWriter.GetTagIDs(ctx, imageIDs[imageIdxWithTwoTags])
if err != nil {
return err
}
@@ -986,7 +988,7 @@ func TestTagMerge(t *testing.T) {
assert.Contains(imageTagIDs, destID)
// ensure gallery points to new tag
- galleryTagIDs, err := r.Gallery().GetTagIDs(galleryIDs[galleryIdxWithTwoTags])
+ galleryTagIDs, err := sqlite.GalleryReaderWriter.GetTagIDs(ctx, galleryIDs[galleryIdxWithTwoTags])
if err != nil {
return err
}
@@ -994,7 +996,7 @@ func TestTagMerge(t *testing.T) {
assert.Contains(galleryTagIDs, destID)
// ensure performer points to new tag
- performerTagIDs, err := r.Gallery().GetTagIDs(performerIDs[performerIdxWithTwoTags])
+ performerTagIDs, err := sqlite.GalleryReaderWriter.GetTagIDs(ctx, performerIDs[performerIdxWithTwoTags])
if err != nil {
return err
}
diff --git a/pkg/sqlite/transaction.go b/pkg/sqlite/transaction.go
index 50486d01e..23f9b27b3 100644
--- a/pkg/sqlite/transaction.go
+++ b/pkg/sqlite/transaction.go
@@ -2,209 +2,67 @@ package sqlite
import (
"context"
- "database/sql"
- "errors"
"fmt"
"github.com/jmoiron/sqlx"
- "github.com/stashapp/stash/pkg/database"
"github.com/stashapp/stash/pkg/models"
)
-type dbi interface {
- Get(dest interface{}, query string, args ...interface{}) error
- Select(dest interface{}, query string, args ...interface{}) error
- Queryx(query string, args ...interface{}) (*sqlx.Rows, error)
- NamedExec(query string, arg interface{}) (sql.Result, error)
- Exec(query string, args ...interface{}) (sql.Result, error)
-}
+type key int
-type transaction struct {
- Ctx context.Context
- tx *sqlx.Tx
-}
+const (
+ txnKey key = iota + 1
+)
-func (t *transaction) Begin() error {
- if t.tx != nil {
- return errors.New("transaction already begun")
+func (db *Database) Begin(ctx context.Context) (context.Context, error) {
+ if tx, _ := getTx(ctx); tx != nil {
+ return nil, fmt.Errorf("already in transaction")
}
- if err := database.Ready(); err != nil {
+ tx, err := db.db.BeginTxx(ctx, nil)
+ if err != nil {
+ return nil, fmt.Errorf("beginning transaction: %w", err)
+ }
+
+ return context.WithValue(ctx, txnKey, tx), nil
+}
+
+func (db *Database) Commit(ctx context.Context) error {
+ tx, err := getTx(ctx)
+ if err != nil {
return err
}
+ return tx.Commit()
+}
- var err error
- t.tx, err = database.DB.BeginTxx(t.Ctx, nil)
+func (db *Database) Rollback(ctx context.Context) error {
+ tx, err := getTx(ctx)
if err != nil {
- return fmt.Errorf("error starting transaction: %v", err)
- }
-
- return nil
-}
-
-func (t *transaction) Rollback() error {
- if t.tx == nil {
- return errors.New("not in transaction")
- }
-
- err := t.tx.Rollback()
- if err != nil {
- return fmt.Errorf("error rolling back transaction: %v", err)
- }
- t.tx = nil
-
- return nil
-}
-
-func (t *transaction) Commit() error {
- if t.tx == nil {
- return errors.New("not in transaction")
- }
-
- err := t.tx.Commit()
- if err != nil {
- return fmt.Errorf("error committing transaction: %v", err)
- }
- t.tx = nil
-
- return nil
-}
-
-func (t *transaction) Repository() models.Repository {
- return t
-}
-
-func (t *transaction) ensureTx() {
- if t.tx == nil {
- panic("tx is nil")
- }
-}
-
-func (t *transaction) Gallery() models.GalleryReaderWriter {
- t.ensureTx()
- return NewGalleryReaderWriter(t.tx)
-}
-
-func (t *transaction) Image() models.ImageReaderWriter {
- t.ensureTx()
- return NewImageReaderWriter(t.tx)
-}
-
-func (t *transaction) Movie() models.MovieReaderWriter {
- t.ensureTx()
- return NewMovieReaderWriter(t.tx)
-}
-
-func (t *transaction) Performer() models.PerformerReaderWriter {
- t.ensureTx()
- return NewPerformerReaderWriter(t.tx)
-}
-
-func (t *transaction) SceneMarker() models.SceneMarkerReaderWriter {
- t.ensureTx()
- return NewSceneMarkerReaderWriter(t.tx)
-}
-
-func (t *transaction) Scene() models.SceneReaderWriter {
- t.ensureTx()
- return NewSceneReaderWriter(t.tx)
-}
-
-func (t *transaction) ScrapedItem() models.ScrapedItemReaderWriter {
- t.ensureTx()
- return NewScrapedItemReaderWriter(t.tx)
-}
-
-func (t *transaction) Studio() models.StudioReaderWriter {
- t.ensureTx()
- return NewStudioReaderWriter(t.tx)
-}
-
-func (t *transaction) Tag() models.TagReaderWriter {
- t.ensureTx()
- return NewTagReaderWriter(t.tx)
-}
-
-func (t *transaction) SavedFilter() models.SavedFilterReaderWriter {
- t.ensureTx()
- return NewSavedFilterReaderWriter(t.tx)
-}
-
-type ReadTransaction struct{}
-
-func (t *ReadTransaction) Begin() error {
- if err := database.Ready(); err != nil {
return err
}
-
- return nil
+ return tx.Rollback()
}
-func (t *ReadTransaction) Rollback() error {
- return nil
+func getTx(ctx context.Context) (*sqlx.Tx, error) {
+ tx, ok := ctx.Value(txnKey).(*sqlx.Tx)
+ if !ok || tx == nil {
+ return nil, fmt.Errorf("not in transaction")
+ }
+ return tx, nil
}
-func (t *ReadTransaction) Commit() error {
- return nil
-}
-
-func (t *ReadTransaction) Repository() models.ReaderRepository {
- return t
-}
-
-func (t *ReadTransaction) Gallery() models.GalleryReader {
- return NewGalleryReaderWriter(database.DB)
-}
-
-func (t *ReadTransaction) Image() models.ImageReader {
- return NewImageReaderWriter(database.DB)
-}
-
-func (t *ReadTransaction) Movie() models.MovieReader {
- return NewMovieReaderWriter(database.DB)
-}
-
-func (t *ReadTransaction) Performer() models.PerformerReader {
- return NewPerformerReaderWriter(database.DB)
-}
-
-func (t *ReadTransaction) SceneMarker() models.SceneMarkerReader {
- return NewSceneMarkerReaderWriter(database.DB)
-}
-
-func (t *ReadTransaction) Scene() models.SceneReader {
- return NewSceneReaderWriter(database.DB)
-}
-
-func (t *ReadTransaction) ScrapedItem() models.ScrapedItemReader {
- return NewScrapedItemReaderWriter(database.DB)
-}
-
-func (t *ReadTransaction) Studio() models.StudioReader {
- return NewStudioReaderWriter(database.DB)
-}
-
-func (t *ReadTransaction) Tag() models.TagReader {
- return NewTagReaderWriter(database.DB)
-}
-
-func (t *ReadTransaction) SavedFilter() models.SavedFilterReader {
- return NewSavedFilterReaderWriter(database.DB)
-}
-
-type TransactionManager struct {
-}
-
-func NewTransactionManager() *TransactionManager {
- return &TransactionManager{}
-}
-
-func (t *TransactionManager) WithTxn(ctx context.Context, fn func(r models.Repository) error) error {
- database.WriteMu.Lock()
- defer database.WriteMu.Unlock()
- return models.WithTxn(&transaction{Ctx: ctx}, fn)
-}
-
-func (t *TransactionManager) WithReadTxn(ctx context.Context, fn func(r models.ReaderRepository) error) error {
- return models.WithROTxn(&ReadTransaction{}, fn)
+func (db *Database) TxnRepository() models.Repository {
+ return models.Repository{
+ TxnManager: db,
+ Gallery: GalleryReaderWriter,
+ Image: ImageReaderWriter,
+ Movie: MovieReaderWriter,
+ Performer: PerformerReaderWriter,
+ Scene: SceneReaderWriter,
+ SceneMarker: SceneMarkerReaderWriter,
+ ScrapedItem: ScrapedItemReaderWriter,
+ Studio: StudioReaderWriter,
+ Tag: TagReaderWriter,
+ SavedFilter: SavedFilterReaderWriter,
+ }
}
diff --git a/pkg/sqlite/tx.go b/pkg/sqlite/tx.go
new file mode 100644
index 000000000..f12b1dd4a
--- /dev/null
+++ b/pkg/sqlite/tx.go
@@ -0,0 +1,55 @@
+package sqlite
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/jmoiron/sqlx"
+)
+
+type dbi struct{}
+
+func (*dbi) Get(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
+ tx, err := getTx(ctx)
+ if err != nil {
+ return err
+ }
+
+ return tx.Get(dest, query, args...)
+}
+
+func (*dbi) Select(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
+ tx, err := getTx(ctx)
+ if err != nil {
+ return err
+ }
+
+ return tx.Select(dest, query, args...)
+}
+
+func (*dbi) Queryx(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) {
+ tx, err := getTx(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ return tx.Queryx(query, args...)
+}
+
+func (*dbi) NamedExec(ctx context.Context, query string, arg interface{}) (sql.Result, error) {
+ tx, err := getTx(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ return tx.NamedExec(query, arg)
+}
+
+func (*dbi) Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
+ tx, err := getTx(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ return tx.Exec(query, args...)
+}
diff --git a/pkg/studio/export.go b/pkg/studio/export.go
index 951a60417..21272ecc4 100644
--- a/pkg/studio/export.go
+++ b/pkg/studio/export.go
@@ -1,6 +1,7 @@
package studio
import (
+ "context"
"fmt"
"github.com/stashapp/stash/pkg/models"
@@ -9,8 +10,15 @@ import (
"github.com/stashapp/stash/pkg/utils"
)
+type FinderImageStashIDGetter interface {
+ Finder
+ GetAliases(ctx context.Context, studioID int) ([]string, error)
+ GetImage(ctx context.Context, studioID int) ([]byte, error)
+ GetStashIDs(ctx context.Context, studioID int) ([]*models.StashID, error)
+}
+
// ToJSON converts a Studio object into its JSON equivalent.
-func ToJSON(reader models.StudioReader, studio *models.Studio) (*jsonschema.Studio, error) {
+func ToJSON(ctx context.Context, reader FinderImageStashIDGetter, studio *models.Studio) (*jsonschema.Studio, error) {
newStudioJSON := jsonschema.Studio{
IgnoreAutoTag: studio.IgnoreAutoTag,
CreatedAt: json.JSONTime{Time: studio.CreatedAt.Timestamp},
@@ -30,7 +38,7 @@ func ToJSON(reader models.StudioReader, studio *models.Studio) (*jsonschema.Stud
}
if studio.ParentID.Valid {
- parent, err := reader.Find(int(studio.ParentID.Int64))
+ parent, err := reader.Find(ctx, int(studio.ParentID.Int64))
if err != nil {
return nil, fmt.Errorf("error getting parent studio: %v", err)
}
@@ -44,14 +52,14 @@ func ToJSON(reader models.StudioReader, studio *models.Studio) (*jsonschema.Stud
newStudioJSON.Rating = int(studio.Rating.Int64)
}
- aliases, err := reader.GetAliases(studio.ID)
+ aliases, err := reader.GetAliases(ctx, studio.ID)
if err != nil {
return nil, fmt.Errorf("error getting studio aliases: %v", err)
}
newStudioJSON.Aliases = aliases
- image, err := reader.GetImage(studio.ID)
+ image, err := reader.GetImage(ctx, studio.ID)
if err != nil {
return nil, fmt.Errorf("error getting studio image: %v", err)
}
@@ -60,7 +68,7 @@ func ToJSON(reader models.StudioReader, studio *models.Studio) (*jsonschema.Stud
newStudioJSON.Image = utils.GetBase64StringFromData(image)
}
- stashIDs, _ := reader.GetStashIDs(studio.ID)
+ stashIDs, _ := reader.GetStashIDs(ctx, studio.ID)
var ret []models.StashID
for _, stashID := range stashIDs {
newJoin := models.StashID{
diff --git a/pkg/studio/export_test.go b/pkg/studio/export_test.go
index a1f261254..b15fbc018 100644
--- a/pkg/studio/export_test.go
+++ b/pkg/studio/export_test.go
@@ -1,6 +1,7 @@
package studio
import (
+ "context"
"errors"
"github.com/stashapp/stash/pkg/models"
@@ -169,39 +170,40 @@ func initTestTable() {
func TestToJSON(t *testing.T) {
initTestTable()
+ ctx := context.Background()
mockStudioReader := &mocks.StudioReaderWriter{}
imageErr := errors.New("error getting image")
- mockStudioReader.On("GetImage", studioID).Return(imageBytes, nil).Once()
- mockStudioReader.On("GetImage", noImageID).Return(nil, nil).Once()
- mockStudioReader.On("GetImage", errImageID).Return(nil, imageErr).Once()
- mockStudioReader.On("GetImage", missingParentStudioID).Return(imageBytes, nil).Maybe()
- mockStudioReader.On("GetImage", errStudioID).Return(imageBytes, nil).Maybe()
- mockStudioReader.On("GetImage", errAliasID).Return(imageBytes, nil).Maybe()
+ mockStudioReader.On("GetImage", ctx, studioID).Return(imageBytes, nil).Once()
+ mockStudioReader.On("GetImage", ctx, noImageID).Return(nil, nil).Once()
+ mockStudioReader.On("GetImage", ctx, errImageID).Return(nil, imageErr).Once()
+ mockStudioReader.On("GetImage", ctx, missingParentStudioID).Return(imageBytes, nil).Maybe()
+ mockStudioReader.On("GetImage", ctx, errStudioID).Return(imageBytes, nil).Maybe()
+ mockStudioReader.On("GetImage", ctx, errAliasID).Return(imageBytes, nil).Maybe()
parentStudioErr := errors.New("error getting parent studio")
- mockStudioReader.On("Find", parentStudioID).Return(&parentStudio, nil)
- mockStudioReader.On("Find", missingStudioID).Return(nil, nil)
- mockStudioReader.On("Find", errParentStudioID).Return(nil, parentStudioErr)
+ mockStudioReader.On("Find", ctx, parentStudioID).Return(&parentStudio, nil)
+ mockStudioReader.On("Find", ctx, missingStudioID).Return(nil, nil)
+ mockStudioReader.On("Find", ctx, errParentStudioID).Return(nil, parentStudioErr)
aliasErr := errors.New("error getting aliases")
- mockStudioReader.On("GetAliases", studioID).Return([]string{"alias"}, nil).Once()
- mockStudioReader.On("GetAliases", noImageID).Return(nil, nil).Once()
- mockStudioReader.On("GetAliases", errImageID).Return(nil, nil).Once()
- mockStudioReader.On("GetAliases", missingParentStudioID).Return(nil, nil).Once()
- mockStudioReader.On("GetAliases", errAliasID).Return(nil, aliasErr).Once()
+ mockStudioReader.On("GetAliases", ctx, studioID).Return([]string{"alias"}, nil).Once()
+ mockStudioReader.On("GetAliases", ctx, noImageID).Return(nil, nil).Once()
+ mockStudioReader.On("GetAliases", ctx, errImageID).Return(nil, nil).Once()
+ mockStudioReader.On("GetAliases", ctx, missingParentStudioID).Return(nil, nil).Once()
+ mockStudioReader.On("GetAliases", ctx, errAliasID).Return(nil, aliasErr).Once()
- mockStudioReader.On("GetStashIDs", studioID).Return(stashIDs, nil).Once()
- mockStudioReader.On("GetStashIDs", noImageID).Return(nil, nil).Once()
- mockStudioReader.On("GetStashIDs", missingParentStudioID).Return(stashIDs, nil).Once()
+ mockStudioReader.On("GetStashIDs", ctx, studioID).Return(stashIDs, nil).Once()
+ mockStudioReader.On("GetStashIDs", ctx, noImageID).Return(nil, nil).Once()
+ mockStudioReader.On("GetStashIDs", ctx, missingParentStudioID).Return(stashIDs, nil).Once()
for i, s := range scenarios {
studio := s.input
- json, err := ToJSON(mockStudioReader, &studio)
+ json, err := ToJSON(ctx, mockStudioReader, &studio)
switch {
case !s.err && err != nil:
diff --git a/pkg/studio/import.go b/pkg/studio/import.go
index a44481982..627d81272 100644
--- a/pkg/studio/import.go
+++ b/pkg/studio/import.go
@@ -1,6 +1,7 @@
package studio
import (
+ "context"
"database/sql"
"errors"
"fmt"
@@ -11,10 +12,19 @@ import (
"github.com/stashapp/stash/pkg/utils"
)
+type NameFinderCreatorUpdater interface {
+ FindByName(ctx context.Context, name string, nocase bool) (*models.Studio, error)
+ Create(ctx context.Context, newStudio models.Studio) (*models.Studio, error)
+ UpdateFull(ctx context.Context, updatedStudio models.Studio) (*models.Studio, error)
+ UpdateImage(ctx context.Context, studioID int, image []byte) error
+ UpdateAliases(ctx context.Context, studioID int, aliases []string) error
+ UpdateStashIDs(ctx context.Context, studioID int, stashIDs []models.StashID) error
+}
+
var ErrParentStudioNotExist = errors.New("parent studio does not exist")
type Importer struct {
- ReaderWriter models.StudioReaderWriter
+ ReaderWriter NameFinderCreatorUpdater
Input jsonschema.Studio
MissingRefBehaviour models.ImportMissingRefEnum
@@ -22,7 +32,7 @@ type Importer struct {
imageData []byte
}
-func (i *Importer) PreImport() error {
+func (i *Importer) PreImport(ctx context.Context) error {
checksum := md5.FromString(i.Input.Name)
i.studio = models.Studio{
@@ -36,7 +46,7 @@ func (i *Importer) PreImport() error {
Rating: sql.NullInt64{Int64: int64(i.Input.Rating), Valid: true},
}
- if err := i.populateParentStudio(); err != nil {
+ if err := i.populateParentStudio(ctx); err != nil {
return err
}
@@ -51,9 +61,9 @@ func (i *Importer) PreImport() error {
return nil
}
-func (i *Importer) populateParentStudio() error {
+func (i *Importer) populateParentStudio(ctx context.Context) error {
if i.Input.ParentStudio != "" {
- studio, err := i.ReaderWriter.FindByName(i.Input.ParentStudio, false)
+ studio, err := i.ReaderWriter.FindByName(ctx, i.Input.ParentStudio, false)
if err != nil {
return fmt.Errorf("error finding studio by name: %v", err)
}
@@ -68,7 +78,7 @@ func (i *Importer) populateParentStudio() error {
}
if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate {
- parentID, err := i.createParentStudio(i.Input.ParentStudio)
+ parentID, err := i.createParentStudio(ctx, i.Input.ParentStudio)
if err != nil {
return err
}
@@ -85,10 +95,10 @@ func (i *Importer) populateParentStudio() error {
return nil
}
-func (i *Importer) createParentStudio(name string) (int, error) {
+func (i *Importer) createParentStudio(ctx context.Context, name string) (int, error) {
newStudio := *models.NewStudio(name)
- created, err := i.ReaderWriter.Create(newStudio)
+ created, err := i.ReaderWriter.Create(ctx, newStudio)
if err != nil {
return 0, err
}
@@ -96,20 +106,20 @@ func (i *Importer) createParentStudio(name string) (int, error) {
return created.ID, nil
}
-func (i *Importer) PostImport(id int) error {
+func (i *Importer) PostImport(ctx context.Context, id int) error {
if len(i.imageData) > 0 {
- if err := i.ReaderWriter.UpdateImage(id, i.imageData); err != nil {
+ if err := i.ReaderWriter.UpdateImage(ctx, id, i.imageData); err != nil {
return fmt.Errorf("error setting studio image: %v", err)
}
}
if len(i.Input.StashIDs) > 0 {
- if err := i.ReaderWriter.UpdateStashIDs(id, i.Input.StashIDs); err != nil {
+ if err := i.ReaderWriter.UpdateStashIDs(ctx, id, i.Input.StashIDs); err != nil {
return fmt.Errorf("error setting stash id: %v", err)
}
}
- if err := i.ReaderWriter.UpdateAliases(id, i.Input.Aliases); err != nil {
+ if err := i.ReaderWriter.UpdateAliases(ctx, id, i.Input.Aliases); err != nil {
return fmt.Errorf("error setting tag aliases: %v", err)
}
@@ -120,9 +130,9 @@ func (i *Importer) Name() string {
return i.Input.Name
}
-func (i *Importer) FindExistingID() (*int, error) {
+func (i *Importer) FindExistingID(ctx context.Context) (*int, error) {
const nocase = false
- existing, err := i.ReaderWriter.FindByName(i.Name(), nocase)
+ existing, err := i.ReaderWriter.FindByName(ctx, i.Name(), nocase)
if err != nil {
return nil, err
}
@@ -135,8 +145,8 @@ func (i *Importer) FindExistingID() (*int, error) {
return nil, nil
}
-func (i *Importer) Create() (*int, error) {
- created, err := i.ReaderWriter.Create(i.studio)
+func (i *Importer) Create(ctx context.Context) (*int, error) {
+ created, err := i.ReaderWriter.Create(ctx, i.studio)
if err != nil {
return nil, fmt.Errorf("error creating studio: %v", err)
}
@@ -145,10 +155,10 @@ func (i *Importer) Create() (*int, error) {
return &id, nil
}
-func (i *Importer) Update(id int) error {
+func (i *Importer) Update(ctx context.Context, id int) error {
studio := i.studio
studio.ID = id
- _, err := i.ReaderWriter.UpdateFull(studio)
+ _, err := i.ReaderWriter.UpdateFull(ctx, studio)
if err != nil {
return fmt.Errorf("error updating existing studio: %v", err)
}
diff --git a/pkg/studio/import_test.go b/pkg/studio/import_test.go
index 87b22519b..fc2ae402b 100644
--- a/pkg/studio/import_test.go
+++ b/pkg/studio/import_test.go
@@ -1,6 +1,7 @@
package studio
import (
+ "context"
"errors"
"testing"
@@ -43,21 +44,22 @@ func TestImporterPreImport(t *testing.T) {
IgnoreAutoTag: autoTagIgnored,
},
}
+ ctx := context.Background()
- err := i.PreImport()
+ err := i.PreImport(ctx)
assert.NotNil(t, err)
i.Input.Image = image
- err = i.PreImport()
+ err = i.PreImport(ctx)
assert.Nil(t, err)
i.Input = *createFullJSONStudio(studioName, image, []string{"alias"})
i.Input.ParentStudio = ""
- err = i.PreImport()
+ err = i.PreImport(ctx)
assert.Nil(t, err)
expectedStudio := createFullStudio(0, 0)
@@ -68,6 +70,7 @@ func TestImporterPreImport(t *testing.T) {
func TestImporterPreImportWithParent(t *testing.T) {
readerWriter := &mocks.StudioReaderWriter{}
+ ctx := context.Background()
i := Importer{
ReaderWriter: readerWriter,
@@ -78,17 +81,17 @@ func TestImporterPreImportWithParent(t *testing.T) {
},
}
- readerWriter.On("FindByName", existingParentStudioName, false).Return(&models.Studio{
+ readerWriter.On("FindByName", ctx, existingParentStudioName, false).Return(&models.Studio{
ID: existingStudioID,
}, nil).Once()
- readerWriter.On("FindByName", existingParentStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
+ readerWriter.On("FindByName", ctx, existingParentStudioErr, false).Return(nil, errors.New("FindByName error")).Once()
- err := i.PreImport()
+ err := i.PreImport(ctx)
assert.Nil(t, err)
assert.Equal(t, int64(existingStudioID), i.studio.ParentID.Int64)
i.Input.ParentStudio = existingParentStudioErr
- err = i.PreImport()
+ err = i.PreImport(ctx)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
@@ -96,6 +99,7 @@ func TestImporterPreImportWithParent(t *testing.T) {
func TestImporterPreImportWithMissingParent(t *testing.T) {
readerWriter := &mocks.StudioReaderWriter{}
+ ctx := context.Background()
i := Importer{
ReaderWriter: readerWriter,
@@ -107,20 +111,20 @@ func TestImporterPreImportWithMissingParent(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumFail,
}
- readerWriter.On("FindByName", missingParentStudioName, false).Return(nil, nil).Times(3)
- readerWriter.On("Create", mock.AnythingOfType("models.Studio")).Return(&models.Studio{
+ readerWriter.On("FindByName", ctx, missingParentStudioName, false).Return(nil, nil).Times(3)
+ readerWriter.On("Create", ctx, mock.AnythingOfType("models.Studio")).Return(&models.Studio{
ID: existingStudioID,
}, nil)
- err := i.PreImport()
+ err := i.PreImport(ctx)
assert.NotNil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore
- err = i.PreImport()
+ err = i.PreImport(ctx)
assert.Nil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumCreate
- err = i.PreImport()
+ err = i.PreImport(ctx)
assert.Nil(t, err)
assert.Equal(t, int64(existingStudioID), i.studio.ParentID.Int64)
@@ -129,6 +133,7 @@ func TestImporterPreImportWithMissingParent(t *testing.T) {
func TestImporterPreImportWithMissingParentCreateErr(t *testing.T) {
readerWriter := &mocks.StudioReaderWriter{}
+ ctx := context.Background()
i := Importer{
ReaderWriter: readerWriter,
@@ -140,15 +145,16 @@ func TestImporterPreImportWithMissingParentCreateErr(t *testing.T) {
MissingRefBehaviour: models.ImportMissingRefEnumCreate,
}
- readerWriter.On("FindByName", missingParentStudioName, false).Return(nil, nil).Once()
- readerWriter.On("Create", mock.AnythingOfType("models.Studio")).Return(nil, errors.New("Create error"))
+ readerWriter.On("FindByName", ctx, missingParentStudioName, false).Return(nil, nil).Once()
+ readerWriter.On("Create", ctx, mock.AnythingOfType("models.Studio")).Return(nil, errors.New("Create error"))
- err := i.PreImport()
+ err := i.PreImport(ctx)
assert.NotNil(t, err)
}
func TestImporterPostImport(t *testing.T) {
readerWriter := &mocks.StudioReaderWriter{}
+ ctx := context.Background()
i := Importer{
ReaderWriter: readerWriter,
@@ -161,21 +167,21 @@ func TestImporterPostImport(t *testing.T) {
updateStudioImageErr := errors.New("UpdateImage error")
updateTagAliasErr := errors.New("UpdateAlias error")
- readerWriter.On("UpdateImage", studioID, imageBytes).Return(nil).Once()
- readerWriter.On("UpdateImage", errImageID, imageBytes).Return(updateStudioImageErr).Once()
- readerWriter.On("UpdateImage", errAliasID, imageBytes).Return(nil).Once()
+ readerWriter.On("UpdateImage", ctx, studioID, imageBytes).Return(nil).Once()
+ readerWriter.On("UpdateImage", ctx, errImageID, imageBytes).Return(updateStudioImageErr).Once()
+ readerWriter.On("UpdateImage", ctx, errAliasID, imageBytes).Return(nil).Once()
- readerWriter.On("UpdateAliases", studioID, i.Input.Aliases).Return(nil).Once()
- readerWriter.On("UpdateAliases", errImageID, i.Input.Aliases).Return(nil).Maybe()
- readerWriter.On("UpdateAliases", errAliasID, i.Input.Aliases).Return(updateTagAliasErr).Once()
+ readerWriter.On("UpdateAliases", ctx, studioID, i.Input.Aliases).Return(nil).Once()
+ readerWriter.On("UpdateAliases", ctx, errImageID, i.Input.Aliases).Return(nil).Maybe()
+ readerWriter.On("UpdateAliases", ctx, errAliasID, i.Input.Aliases).Return(updateTagAliasErr).Once()
- err := i.PostImport(studioID)
+ err := i.PostImport(ctx, studioID)
assert.Nil(t, err)
- err = i.PostImport(errImageID)
+ err = i.PostImport(ctx, errImageID)
assert.NotNil(t, err)
- err = i.PostImport(errAliasID)
+ err = i.PostImport(ctx, errAliasID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
@@ -183,6 +189,7 @@ func TestImporterPostImport(t *testing.T) {
func TestImporterFindExistingID(t *testing.T) {
readerWriter := &mocks.StudioReaderWriter{}
+ ctx := context.Background()
i := Importer{
ReaderWriter: readerWriter,
@@ -192,23 +199,23 @@ func TestImporterFindExistingID(t *testing.T) {
}
errFindByName := errors.New("FindByName error")
- readerWriter.On("FindByName", studioName, false).Return(nil, nil).Once()
- readerWriter.On("FindByName", existingStudioName, false).Return(&models.Studio{
+ readerWriter.On("FindByName", ctx, studioName, false).Return(nil, nil).Once()
+ readerWriter.On("FindByName", ctx, existingStudioName, false).Return(&models.Studio{
ID: existingStudioID,
}, nil).Once()
- readerWriter.On("FindByName", studioNameErr, false).Return(nil, errFindByName).Once()
+ readerWriter.On("FindByName", ctx, studioNameErr, false).Return(nil, errFindByName).Once()
- id, err := i.FindExistingID()
+ id, err := i.FindExistingID(ctx)
assert.Nil(t, id)
assert.Nil(t, err)
i.Input.Name = existingStudioName
- id, err = i.FindExistingID()
+ id, err = i.FindExistingID(ctx)
assert.Equal(t, existingStudioID, *id)
assert.Nil(t, err)
i.Input.Name = studioNameErr
- id, err = i.FindExistingID()
+ id, err = i.FindExistingID(ctx)
assert.Nil(t, id)
assert.NotNil(t, err)
@@ -217,6 +224,7 @@ func TestImporterFindExistingID(t *testing.T) {
func TestCreate(t *testing.T) {
readerWriter := &mocks.StudioReaderWriter{}
+ ctx := context.Background()
studio := models.Studio{
Name: models.NullString(studioName),
@@ -232,17 +240,17 @@ func TestCreate(t *testing.T) {
}
errCreate := errors.New("Create error")
- readerWriter.On("Create", studio).Return(&models.Studio{
+ readerWriter.On("Create", ctx, studio).Return(&models.Studio{
ID: studioID,
}, nil).Once()
- readerWriter.On("Create", studioErr).Return(nil, errCreate).Once()
+ readerWriter.On("Create", ctx, studioErr).Return(nil, errCreate).Once()
- id, err := i.Create()
+ id, err := i.Create(ctx)
assert.Equal(t, studioID, *id)
assert.Nil(t, err)
i.studio = studioErr
- id, err = i.Create()
+ id, err = i.Create(ctx)
assert.Nil(t, id)
assert.NotNil(t, err)
@@ -251,6 +259,7 @@ func TestCreate(t *testing.T) {
func TestUpdate(t *testing.T) {
readerWriter := &mocks.StudioReaderWriter{}
+ ctx := context.Background()
studio := models.Studio{
Name: models.NullString(studioName),
@@ -269,18 +278,18 @@ func TestUpdate(t *testing.T) {
// id needs to be set for the mock input
studio.ID = studioID
- readerWriter.On("UpdateFull", studio).Return(nil, nil).Once()
+ readerWriter.On("UpdateFull", ctx, studio).Return(nil, nil).Once()
- err := i.Update(studioID)
+ err := i.Update(ctx, studioID)
assert.Nil(t, err)
i.studio = studioErr
// need to set id separately
studioErr.ID = errImageID
- readerWriter.On("UpdateFull", studioErr).Return(nil, errUpdate).Once()
+ readerWriter.On("UpdateFull", ctx, studioErr).Return(nil, errUpdate).Once()
- err = i.Update(errImageID)
+ err = i.Update(ctx, errImageID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
diff --git a/pkg/studio/query.go b/pkg/studio/query.go
index 5b2f68896..dee499a1b 100644
--- a/pkg/studio/query.go
+++ b/pkg/studio/query.go
@@ -1,8 +1,20 @@
package studio
-import "github.com/stashapp/stash/pkg/models"
+import (
+ "context"
-func ByName(qb models.StudioReader, name string) (*models.Studio, error) {
+ "github.com/stashapp/stash/pkg/models"
+)
+
+type Finder interface {
+ Find(ctx context.Context, id int) (*models.Studio, error)
+}
+
+type Queryer interface {
+ Query(ctx context.Context, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) ([]*models.Studio, int, error)
+}
+
+func ByName(ctx context.Context, qb Queryer, name string) (*models.Studio, error) {
f := &models.StudioFilterType{
Name: &models.StringCriterionInput{
Value: name,
@@ -11,7 +23,7 @@ func ByName(qb models.StudioReader, name string) (*models.Studio, error) {
}
pp := 1
- ret, count, err := qb.Query(f, &models.FindFilterType{
+ ret, count, err := qb.Query(ctx, f, &models.FindFilterType{
PerPage: &pp,
})
@@ -26,7 +38,7 @@ func ByName(qb models.StudioReader, name string) (*models.Studio, error) {
return nil, nil
}
-func ByAlias(qb models.StudioReader, alias string) (*models.Studio, error) {
+func ByAlias(ctx context.Context, qb Queryer, alias string) (*models.Studio, error) {
f := &models.StudioFilterType{
Aliases: &models.StringCriterionInput{
Value: alias,
@@ -35,7 +47,7 @@ func ByAlias(qb models.StudioReader, alias string) (*models.Studio, error) {
}
pp := 1
- ret, count, err := qb.Query(f, &models.FindFilterType{
+ ret, count, err := qb.Query(ctx, f, &models.FindFilterType{
PerPage: &pp,
})
diff --git a/pkg/studio/update.go b/pkg/studio/update.go
index 35a655a73..addae5c94 100644
--- a/pkg/studio/update.go
+++ b/pkg/studio/update.go
@@ -1,11 +1,17 @@
package studio
import (
+ "context"
"fmt"
"github.com/stashapp/stash/pkg/models"
)
+type NameFinderCreator interface {
+ FindByName(ctx context.Context, name string, nocase bool) (*models.Studio, error)
+ Create(ctx context.Context, newStudio models.Studio) (*models.Studio, error)
+}
+
type NameExistsError struct {
Name string
}
@@ -25,9 +31,9 @@ func (e *NameUsedByAliasError) Error() string {
// EnsureStudioNameUnique returns an error if the studio name provided
// is used as a name or alias of another existing tag.
-func EnsureStudioNameUnique(id int, name string, qb models.StudioReader) error {
+func EnsureStudioNameUnique(ctx context.Context, id int, name string, qb Queryer) error {
// ensure name is unique
- sameNameStudio, err := ByName(qb, name)
+ sameNameStudio, err := ByName(ctx, qb, name)
if err != nil {
return err
}
@@ -39,7 +45,7 @@ func EnsureStudioNameUnique(id int, name string, qb models.StudioReader) error {
}
// query by alias
- sameNameStudio, err = ByAlias(qb, name)
+ sameNameStudio, err = ByAlias(ctx, qb, name)
if err != nil {
return err
}
@@ -54,9 +60,9 @@ func EnsureStudioNameUnique(id int, name string, qb models.StudioReader) error {
return nil
}
-func EnsureAliasesUnique(id int, aliases []string, qb models.StudioReader) error {
+func EnsureAliasesUnique(ctx context.Context, id int, aliases []string, qb Queryer) error {
for _, a := range aliases {
- if err := EnsureStudioNameUnique(id, a, qb); err != nil {
+ if err := EnsureStudioNameUnique(ctx, id, a, qb); err != nil {
return err
}
}
diff --git a/pkg/tag/export.go b/pkg/tag/export.go
index e70392379..20c1b4adc 100644
--- a/pkg/tag/export.go
+++ b/pkg/tag/export.go
@@ -1,6 +1,7 @@
package tag
import (
+ "context"
"fmt"
"github.com/stashapp/stash/pkg/models"
@@ -9,8 +10,14 @@ import (
"github.com/stashapp/stash/pkg/utils"
)
+type FinderAliasImageGetter interface {
+ GetAliases(ctx context.Context, studioID int) ([]string, error)
+ GetImage(ctx context.Context, tagID int) ([]byte, error)
+ FindByChildTagID(ctx context.Context, childID int) ([]*models.Tag, error)
+}
+
// ToJSON converts a Tag object into its JSON equivalent.
-func ToJSON(reader models.TagReader, tag *models.Tag) (*jsonschema.Tag, error) {
+func ToJSON(ctx context.Context, reader FinderAliasImageGetter, tag *models.Tag) (*jsonschema.Tag, error) {
newTagJSON := jsonschema.Tag{
Name: tag.Name,
IgnoreAutoTag: tag.IgnoreAutoTag,
@@ -18,14 +25,14 @@ func ToJSON(reader models.TagReader, tag *models.Tag) (*jsonschema.Tag, error) {
UpdatedAt: json.JSONTime{Time: tag.UpdatedAt.Timestamp},
}
- aliases, err := reader.GetAliases(tag.ID)
+ aliases, err := reader.GetAliases(ctx, tag.ID)
if err != nil {
return nil, fmt.Errorf("error getting tag aliases: %v", err)
}
newTagJSON.Aliases = aliases
- image, err := reader.GetImage(tag.ID)
+ image, err := reader.GetImage(ctx, tag.ID)
if err != nil {
return nil, fmt.Errorf("error getting tag image: %v", err)
}
@@ -34,7 +41,7 @@ func ToJSON(reader models.TagReader, tag *models.Tag) (*jsonschema.Tag, error) {
newTagJSON.Image = utils.GetBase64StringFromData(image)
}
- parents, err := reader.FindByChildTagID(tag.ID)
+ parents, err := reader.FindByChildTagID(ctx, tag.ID)
if err != nil {
return nil, fmt.Errorf("error getting parents: %v", err)
}
diff --git a/pkg/tag/export_test.go b/pkg/tag/export_test.go
index 930c0fdb1..255c940dd 100644
--- a/pkg/tag/export_test.go
+++ b/pkg/tag/export_test.go
@@ -1,6 +1,7 @@
package tag
import (
+ "context"
"errors"
"github.com/stashapp/stash/pkg/models"
@@ -106,33 +107,34 @@ func initTestTable() {
func TestToJSON(t *testing.T) {
initTestTable()
+ ctx := context.Background()
mockTagReader := &mocks.TagReaderWriter{}
imageErr := errors.New("error getting image")
aliasErr := errors.New("error getting aliases")
parentsErr := errors.New("error getting parents")
- mockTagReader.On("GetAliases", tagID).Return([]string{"alias"}, nil).Once()
- mockTagReader.On("GetAliases", noImageID).Return(nil, nil).Once()
- mockTagReader.On("GetAliases", errImageID).Return(nil, nil).Once()
- mockTagReader.On("GetAliases", errAliasID).Return(nil, aliasErr).Once()
- mockTagReader.On("GetAliases", withParentsID).Return(nil, nil).Once()
- mockTagReader.On("GetAliases", errParentsID).Return(nil, nil).Once()
+ mockTagReader.On("GetAliases", ctx, tagID).Return([]string{"alias"}, nil).Once()
+ mockTagReader.On("GetAliases", ctx, noImageID).Return(nil, nil).Once()
+ mockTagReader.On("GetAliases", ctx, errImageID).Return(nil, nil).Once()
+ mockTagReader.On("GetAliases", ctx, errAliasID).Return(nil, aliasErr).Once()
+ mockTagReader.On("GetAliases", ctx, withParentsID).Return(nil, nil).Once()
+ mockTagReader.On("GetAliases", ctx, errParentsID).Return(nil, nil).Once()
- mockTagReader.On("GetImage", tagID).Return(imageBytes, nil).Once()
- mockTagReader.On("GetImage", noImageID).Return(nil, nil).Once()
- mockTagReader.On("GetImage", errImageID).Return(nil, imageErr).Once()
- mockTagReader.On("GetImage", withParentsID).Return(imageBytes, nil).Once()
- mockTagReader.On("GetImage", errParentsID).Return(nil, nil).Once()
+ mockTagReader.On("GetImage", ctx, tagID).Return(imageBytes, nil).Once()
+ mockTagReader.On("GetImage", ctx, noImageID).Return(nil, nil).Once()
+ mockTagReader.On("GetImage", ctx, errImageID).Return(nil, imageErr).Once()
+ mockTagReader.On("GetImage", ctx, withParentsID).Return(imageBytes, nil).Once()
+ mockTagReader.On("GetImage", ctx, errParentsID).Return(nil, nil).Once()
- mockTagReader.On("FindByChildTagID", tagID).Return(nil, nil).Once()
- mockTagReader.On("FindByChildTagID", noImageID).Return(nil, nil).Once()
- mockTagReader.On("FindByChildTagID", withParentsID).Return([]*models.Tag{{Name: "parent"}}, nil).Once()
- mockTagReader.On("FindByChildTagID", errParentsID).Return(nil, parentsErr).Once()
+ mockTagReader.On("FindByChildTagID", ctx, tagID).Return(nil, nil).Once()
+ mockTagReader.On("FindByChildTagID", ctx, noImageID).Return(nil, nil).Once()
+ mockTagReader.On("FindByChildTagID", ctx, withParentsID).Return([]*models.Tag{{Name: "parent"}}, nil).Once()
+ mockTagReader.On("FindByChildTagID", ctx, errParentsID).Return(nil, parentsErr).Once()
for i, s := range scenarios {
tag := s.tag
- json, err := ToJSON(mockTagReader, &tag)
+ json, err := ToJSON(ctx, mockTagReader, &tag)
switch {
case !s.err && err != nil:
diff --git a/pkg/tag/import.go b/pkg/tag/import.go
index 66028946c..937ea2359 100644
--- a/pkg/tag/import.go
+++ b/pkg/tag/import.go
@@ -1,6 +1,7 @@
package tag
import (
+ "context"
"fmt"
"github.com/stashapp/stash/pkg/models"
@@ -8,6 +9,15 @@ import (
"github.com/stashapp/stash/pkg/utils"
)
+type NameFinderCreatorUpdater interface {
+ FindByName(ctx context.Context, name string, nocase bool) (*models.Tag, error)
+ Create(ctx context.Context, newTag models.Tag) (*models.Tag, error)
+ UpdateFull(ctx context.Context, updatedTag models.Tag) (*models.Tag, error)
+ UpdateImage(ctx context.Context, tagID int, image []byte) error
+ UpdateAliases(ctx context.Context, tagID int, aliases []string) error
+ UpdateParentTags(ctx context.Context, tagID int, parentIDs []int) error
+}
+
type ParentTagNotExistError struct {
missingParent string
}
@@ -21,7 +31,7 @@ func (e ParentTagNotExistError) MissingParent() string {
}
type Importer struct {
- ReaderWriter models.TagReaderWriter
+ ReaderWriter NameFinderCreatorUpdater
Input jsonschema.Tag
MissingRefBehaviour models.ImportMissingRefEnum
@@ -29,7 +39,7 @@ type Importer struct {
imageData []byte
}
-func (i *Importer) PreImport() error {
+func (i *Importer) PreImport(ctx context.Context) error {
i.tag = models.Tag{
Name: i.Input.Name,
IgnoreAutoTag: i.Input.IgnoreAutoTag,
@@ -48,23 +58,23 @@ func (i *Importer) PreImport() error {
return nil
}
-func (i *Importer) PostImport(id int) error {
+func (i *Importer) PostImport(ctx context.Context, id int) error {
if len(i.imageData) > 0 {
- if err := i.ReaderWriter.UpdateImage(id, i.imageData); err != nil {
+ if err := i.ReaderWriter.UpdateImage(ctx, id, i.imageData); err != nil {
return fmt.Errorf("error setting tag image: %v", err)
}
}
- if err := i.ReaderWriter.UpdateAliases(id, i.Input.Aliases); err != nil {
+ if err := i.ReaderWriter.UpdateAliases(ctx, id, i.Input.Aliases); err != nil {
return fmt.Errorf("error setting tag aliases: %v", err)
}
- parents, err := i.getParents()
+ parents, err := i.getParents(ctx)
if err != nil {
return err
}
- if err := i.ReaderWriter.UpdateParentTags(id, parents); err != nil {
+ if err := i.ReaderWriter.UpdateParentTags(ctx, id, parents); err != nil {
return fmt.Errorf("error setting parents: %v", err)
}
@@ -75,9 +85,9 @@ func (i *Importer) Name() string {
return i.Input.Name
}
-func (i *Importer) FindExistingID() (*int, error) {
+func (i *Importer) FindExistingID(ctx context.Context) (*int, error) {
const nocase = false
- existing, err := i.ReaderWriter.FindByName(i.Name(), nocase)
+ existing, err := i.ReaderWriter.FindByName(ctx, i.Name(), nocase)
if err != nil {
return nil, err
}
@@ -90,8 +100,8 @@ func (i *Importer) FindExistingID() (*int, error) {
return nil, nil
}
-func (i *Importer) Create() (*int, error) {
- created, err := i.ReaderWriter.Create(i.tag)
+func (i *Importer) Create(ctx context.Context) (*int, error) {
+ created, err := i.ReaderWriter.Create(ctx, i.tag)
if err != nil {
return nil, fmt.Errorf("error creating tag: %v", err)
}
@@ -100,10 +110,10 @@ func (i *Importer) Create() (*int, error) {
return &id, nil
}
-func (i *Importer) Update(id int) error {
+func (i *Importer) Update(ctx context.Context, id int) error {
tag := i.tag
tag.ID = id
- _, err := i.ReaderWriter.UpdateFull(tag)
+ _, err := i.ReaderWriter.UpdateFull(ctx, tag)
if err != nil {
return fmt.Errorf("error updating existing tag: %v", err)
}
@@ -111,10 +121,10 @@ func (i *Importer) Update(id int) error {
return nil
}
-func (i *Importer) getParents() ([]int, error) {
+func (i *Importer) getParents(ctx context.Context) ([]int, error) {
var parents []int
for _, parent := range i.Input.Parents {
- tag, err := i.ReaderWriter.FindByName(parent, false)
+ tag, err := i.ReaderWriter.FindByName(ctx, parent, false)
if err != nil {
return nil, fmt.Errorf("error finding parent by name: %v", err)
}
@@ -129,7 +139,7 @@ func (i *Importer) getParents() ([]int, error) {
}
if i.MissingRefBehaviour == models.ImportMissingRefEnumCreate {
- parentID, err := i.createParent(parent)
+ parentID, err := i.createParent(ctx, parent)
if err != nil {
return nil, err
}
@@ -143,10 +153,10 @@ func (i *Importer) getParents() ([]int, error) {
return parents, nil
}
-func (i *Importer) createParent(name string) (int, error) {
+func (i *Importer) createParent(ctx context.Context, name string) (int, error) {
newTag := *models.NewTag(name)
- created, err := i.ReaderWriter.Create(newTag)
+ created, err := i.ReaderWriter.Create(ctx, newTag)
if err != nil {
return 0, err
}
diff --git a/pkg/tag/import_test.go b/pkg/tag/import_test.go
index fb6f3c58f..e4fb3ce8d 100644
--- a/pkg/tag/import_test.go
+++ b/pkg/tag/import_test.go
@@ -1,6 +1,7 @@
package tag
import (
+ "context"
"errors"
"testing"
@@ -23,6 +24,8 @@ const (
existingTagID = 100
)
+var testCtx = context.Background()
+
func TestImporterName(t *testing.T) {
i := Importer{
Input: jsonschema.Tag{
@@ -42,13 +45,13 @@ func TestImporterPreImport(t *testing.T) {
},
}
- err := i.PreImport()
+ err := i.PreImport(testCtx)
assert.NotNil(t, err)
i.Input.Image = image
- err = i.PreImport()
+ err = i.PreImport(testCtx)
assert.Nil(t, err)
}
@@ -68,38 +71,38 @@ func TestImporterPostImport(t *testing.T) {
updateTagAliasErr := errors.New("UpdateAlias error")
updateTagParentsErr := errors.New("UpdateParentTags error")
- readerWriter.On("UpdateAliases", tagID, i.Input.Aliases).Return(nil).Once()
- readerWriter.On("UpdateAliases", errAliasID, i.Input.Aliases).Return(updateTagAliasErr).Once()
- readerWriter.On("UpdateAliases", withParentsID, i.Input.Aliases).Return(nil).Once()
- readerWriter.On("UpdateAliases", errParentsID, i.Input.Aliases).Return(nil).Once()
+ readerWriter.On("UpdateAliases", testCtx, tagID, i.Input.Aliases).Return(nil).Once()
+ readerWriter.On("UpdateAliases", testCtx, errAliasID, i.Input.Aliases).Return(updateTagAliasErr).Once()
+ readerWriter.On("UpdateAliases", testCtx, withParentsID, i.Input.Aliases).Return(nil).Once()
+ readerWriter.On("UpdateAliases", testCtx, errParentsID, i.Input.Aliases).Return(nil).Once()
- readerWriter.On("UpdateImage", tagID, imageBytes).Return(nil).Once()
- readerWriter.On("UpdateImage", errAliasID, imageBytes).Return(nil).Once()
- readerWriter.On("UpdateImage", errImageID, imageBytes).Return(updateTagImageErr).Once()
- readerWriter.On("UpdateImage", withParentsID, imageBytes).Return(nil).Once()
- readerWriter.On("UpdateImage", errParentsID, imageBytes).Return(nil).Once()
+ readerWriter.On("UpdateImage", testCtx, tagID, imageBytes).Return(nil).Once()
+ readerWriter.On("UpdateImage", testCtx, errAliasID, imageBytes).Return(nil).Once()
+ readerWriter.On("UpdateImage", testCtx, errImageID, imageBytes).Return(updateTagImageErr).Once()
+ readerWriter.On("UpdateImage", testCtx, withParentsID, imageBytes).Return(nil).Once()
+ readerWriter.On("UpdateImage", testCtx, errParentsID, imageBytes).Return(nil).Once()
var parentTags []int
- readerWriter.On("UpdateParentTags", tagID, parentTags).Return(nil).Once()
- readerWriter.On("UpdateParentTags", withParentsID, []int{100}).Return(nil).Once()
- readerWriter.On("UpdateParentTags", errParentsID, []int{100}).Return(updateTagParentsErr).Once()
+ readerWriter.On("UpdateParentTags", testCtx, tagID, parentTags).Return(nil).Once()
+ readerWriter.On("UpdateParentTags", testCtx, withParentsID, []int{100}).Return(nil).Once()
+ readerWriter.On("UpdateParentTags", testCtx, errParentsID, []int{100}).Return(updateTagParentsErr).Once()
- readerWriter.On("FindByName", "Parent", false).Return(&models.Tag{ID: 100}, nil)
+ readerWriter.On("FindByName", testCtx, "Parent", false).Return(&models.Tag{ID: 100}, nil)
- err := i.PostImport(tagID)
+ err := i.PostImport(testCtx, tagID)
assert.Nil(t, err)
- err = i.PostImport(errImageID)
+ err = i.PostImport(testCtx, errImageID)
assert.NotNil(t, err)
- err = i.PostImport(errAliasID)
+ err = i.PostImport(testCtx, errAliasID)
assert.NotNil(t, err)
i.Input.Parents = []string{"Parent"}
- err = i.PostImport(withParentsID)
+ err = i.PostImport(testCtx, withParentsID)
assert.Nil(t, err)
- err = i.PostImport(errParentsID)
+ err = i.PostImport(testCtx, errParentsID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
@@ -129,70 +132,70 @@ func TestImporterPostImportParentMissing(t *testing.T) {
var emptyParents []int
- readerWriter.On("UpdateImage", mock.Anything, mock.Anything).Return(nil)
- readerWriter.On("UpdateAliases", mock.Anything, mock.Anything).Return(nil)
+ readerWriter.On("UpdateImage", testCtx, mock.Anything, mock.Anything).Return(nil)
+ readerWriter.On("UpdateAliases", testCtx, mock.Anything, mock.Anything).Return(nil)
- readerWriter.On("FindByName", "Create", false).Return(nil, nil).Once()
- readerWriter.On("FindByName", "CreateError", false).Return(nil, nil).Once()
- readerWriter.On("FindByName", "CreateFindError", false).Return(nil, findError).Once()
- readerWriter.On("FindByName", "CreateFound", false).Return(&models.Tag{ID: 101}, nil).Once()
- readerWriter.On("FindByName", "Fail", false).Return(nil, nil).Once()
- readerWriter.On("FindByName", "FailFindError", false).Return(nil, findError)
- readerWriter.On("FindByName", "FailFound", false).Return(&models.Tag{ID: 102}, nil).Once()
- readerWriter.On("FindByName", "Ignore", false).Return(nil, nil).Once()
- readerWriter.On("FindByName", "IgnoreFindError", false).Return(nil, findError)
- readerWriter.On("FindByName", "IgnoreFound", false).Return(&models.Tag{ID: 103}, nil).Once()
+ readerWriter.On("FindByName", testCtx, "Create", false).Return(nil, nil).Once()
+ readerWriter.On("FindByName", testCtx, "CreateError", false).Return(nil, nil).Once()
+ readerWriter.On("FindByName", testCtx, "CreateFindError", false).Return(nil, findError).Once()
+ readerWriter.On("FindByName", testCtx, "CreateFound", false).Return(&models.Tag{ID: 101}, nil).Once()
+ readerWriter.On("FindByName", testCtx, "Fail", false).Return(nil, nil).Once()
+ readerWriter.On("FindByName", testCtx, "FailFindError", false).Return(nil, findError)
+ readerWriter.On("FindByName", testCtx, "FailFound", false).Return(&models.Tag{ID: 102}, nil).Once()
+ readerWriter.On("FindByName", testCtx, "Ignore", false).Return(nil, nil).Once()
+ readerWriter.On("FindByName", testCtx, "IgnoreFindError", false).Return(nil, findError)
+ readerWriter.On("FindByName", testCtx, "IgnoreFound", false).Return(&models.Tag{ID: 103}, nil).Once()
- readerWriter.On("UpdateParentTags", createID, []int{100}).Return(nil).Once()
- readerWriter.On("UpdateParentTags", createFoundID, []int{101}).Return(nil).Once()
- readerWriter.On("UpdateParentTags", failFoundID, []int{102}).Return(nil).Once()
- readerWriter.On("UpdateParentTags", ignoreID, emptyParents).Return(nil).Once()
- readerWriter.On("UpdateParentTags", ignoreFoundID, []int{103}).Return(nil).Once()
+ readerWriter.On("UpdateParentTags", testCtx, createID, []int{100}).Return(nil).Once()
+ readerWriter.On("UpdateParentTags", testCtx, createFoundID, []int{101}).Return(nil).Once()
+ readerWriter.On("UpdateParentTags", testCtx, failFoundID, []int{102}).Return(nil).Once()
+ readerWriter.On("UpdateParentTags", testCtx, ignoreID, emptyParents).Return(nil).Once()
+ readerWriter.On("UpdateParentTags", testCtx, ignoreFoundID, []int{103}).Return(nil).Once()
- readerWriter.On("Create", mock.MatchedBy(func(t models.Tag) bool { return t.Name == "Create" })).Return(&models.Tag{ID: 100}, nil).Once()
- readerWriter.On("Create", mock.MatchedBy(func(t models.Tag) bool { return t.Name == "CreateError" })).Return(nil, errors.New("failed creating parent")).Once()
+ readerWriter.On("Create", testCtx, mock.MatchedBy(func(t models.Tag) bool { return t.Name == "Create" })).Return(&models.Tag{ID: 100}, nil).Once()
+ readerWriter.On("Create", testCtx, mock.MatchedBy(func(t models.Tag) bool { return t.Name == "CreateError" })).Return(nil, errors.New("failed creating parent")).Once()
i.MissingRefBehaviour = models.ImportMissingRefEnumCreate
i.Input.Parents = []string{"Create"}
- err := i.PostImport(createID)
+ err := i.PostImport(testCtx, createID)
assert.Nil(t, err)
i.Input.Parents = []string{"CreateError"}
- err = i.PostImport(createErrorID)
+ err = i.PostImport(testCtx, createErrorID)
assert.NotNil(t, err)
i.Input.Parents = []string{"CreateFindError"}
- err = i.PostImport(createFindErrorID)
+ err = i.PostImport(testCtx, createFindErrorID)
assert.NotNil(t, err)
i.Input.Parents = []string{"CreateFound"}
- err = i.PostImport(createFoundID)
+ err = i.PostImport(testCtx, createFoundID)
assert.Nil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumFail
i.Input.Parents = []string{"Fail"}
- err = i.PostImport(failID)
+ err = i.PostImport(testCtx, failID)
assert.NotNil(t, err)
i.Input.Parents = []string{"FailFindError"}
- err = i.PostImport(failFindErrorID)
+ err = i.PostImport(testCtx, failFindErrorID)
assert.NotNil(t, err)
i.Input.Parents = []string{"FailFound"}
- err = i.PostImport(failFoundID)
+ err = i.PostImport(testCtx, failFoundID)
assert.Nil(t, err)
i.MissingRefBehaviour = models.ImportMissingRefEnumIgnore
i.Input.Parents = []string{"Ignore"}
- err = i.PostImport(ignoreID)
+ err = i.PostImport(testCtx, ignoreID)
assert.Nil(t, err)
i.Input.Parents = []string{"IgnoreFindError"}
- err = i.PostImport(ignoreFindErrorID)
+ err = i.PostImport(testCtx, ignoreFindErrorID)
assert.NotNil(t, err)
i.Input.Parents = []string{"IgnoreFound"}
- err = i.PostImport(ignoreFoundID)
+ err = i.PostImport(testCtx, ignoreFoundID)
assert.Nil(t, err)
readerWriter.AssertExpectations(t)
@@ -209,23 +212,23 @@ func TestImporterFindExistingID(t *testing.T) {
}
errFindByName := errors.New("FindByName error")
- readerWriter.On("FindByName", tagName, false).Return(nil, nil).Once()
- readerWriter.On("FindByName", existingTagName, false).Return(&models.Tag{
+ readerWriter.On("FindByName", testCtx, tagName, false).Return(nil, nil).Once()
+ readerWriter.On("FindByName", testCtx, existingTagName, false).Return(&models.Tag{
ID: existingTagID,
}, nil).Once()
- readerWriter.On("FindByName", tagNameErr, false).Return(nil, errFindByName).Once()
+ readerWriter.On("FindByName", testCtx, tagNameErr, false).Return(nil, errFindByName).Once()
- id, err := i.FindExistingID()
+ id, err := i.FindExistingID(testCtx)
assert.Nil(t, id)
assert.Nil(t, err)
i.Input.Name = existingTagName
- id, err = i.FindExistingID()
+ id, err = i.FindExistingID(testCtx)
assert.Equal(t, existingTagID, *id)
assert.Nil(t, err)
i.Input.Name = tagNameErr
- id, err = i.FindExistingID()
+ id, err = i.FindExistingID(testCtx)
assert.Nil(t, id)
assert.NotNil(t, err)
@@ -249,17 +252,17 @@ func TestCreate(t *testing.T) {
}
errCreate := errors.New("Create error")
- readerWriter.On("Create", tag).Return(&models.Tag{
+ readerWriter.On("Create", testCtx, tag).Return(&models.Tag{
ID: tagID,
}, nil).Once()
- readerWriter.On("Create", tagErr).Return(nil, errCreate).Once()
+ readerWriter.On("Create", testCtx, tagErr).Return(nil, errCreate).Once()
- id, err := i.Create()
+ id, err := i.Create(testCtx)
assert.Equal(t, tagID, *id)
assert.Nil(t, err)
i.tag = tagErr
- id, err = i.Create()
+ id, err = i.Create(testCtx)
assert.Nil(t, id)
assert.NotNil(t, err)
@@ -286,18 +289,18 @@ func TestUpdate(t *testing.T) {
// id needs to be set for the mock input
tag.ID = tagID
- readerWriter.On("UpdateFull", tag).Return(nil, nil).Once()
+ readerWriter.On("UpdateFull", testCtx, tag).Return(nil, nil).Once()
- err := i.Update(tagID)
+ err := i.Update(testCtx, tagID)
assert.Nil(t, err)
i.tag = tagErr
// need to set id separately
tagErr.ID = errImageID
- readerWriter.On("UpdateFull", tagErr).Return(nil, errUpdate).Once()
+ readerWriter.On("UpdateFull", testCtx, tagErr).Return(nil, errUpdate).Once()
- err = i.Update(errImageID)
+ err = i.Update(testCtx, errImageID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
diff --git a/pkg/tag/query.go b/pkg/tag/query.go
index ce7406403..a048054d7 100644
--- a/pkg/tag/query.go
+++ b/pkg/tag/query.go
@@ -1,8 +1,20 @@
package tag
-import "github.com/stashapp/stash/pkg/models"
+import (
+ "context"
-func ByName(qb models.TagReader, name string) (*models.Tag, error) {
+ "github.com/stashapp/stash/pkg/models"
+)
+
+type Finder interface {
+ Find(ctx context.Context, id int) (*models.Tag, error)
+}
+
+type Queryer interface {
+ Query(ctx context.Context, tagFilter *models.TagFilterType, findFilter *models.FindFilterType) ([]*models.Tag, int, error)
+}
+
+func ByName(ctx context.Context, qb Queryer, name string) (*models.Tag, error) {
f := &models.TagFilterType{
Name: &models.StringCriterionInput{
Value: name,
@@ -11,7 +23,7 @@ func ByName(qb models.TagReader, name string) (*models.Tag, error) {
}
pp := 1
- ret, count, err := qb.Query(f, &models.FindFilterType{
+ ret, count, err := qb.Query(ctx, f, &models.FindFilterType{
PerPage: &pp,
})
@@ -26,7 +38,7 @@ func ByName(qb models.TagReader, name string) (*models.Tag, error) {
return nil, nil
}
-func ByAlias(qb models.TagReader, alias string) (*models.Tag, error) {
+func ByAlias(ctx context.Context, qb Queryer, alias string) (*models.Tag, error) {
f := &models.TagFilterType{
Aliases: &models.StringCriterionInput{
Value: alias,
@@ -35,7 +47,7 @@ func ByAlias(qb models.TagReader, alias string) (*models.Tag, error) {
}
pp := 1
- ret, count, err := qb.Query(f, &models.FindFilterType{
+ ret, count, err := qb.Query(ctx, f, &models.FindFilterType{
PerPage: &pp,
})
diff --git a/pkg/tag/update.go b/pkg/tag/update.go
index dfee55154..0c219b26c 100644
--- a/pkg/tag/update.go
+++ b/pkg/tag/update.go
@@ -1,11 +1,17 @@
package tag
import (
+ "context"
"fmt"
"github.com/stashapp/stash/pkg/models"
)
+type NameFinderCreator interface {
+ FindByNames(ctx context.Context, names []string, nocase bool) ([]*models.Tag, error)
+ Create(ctx context.Context, newTag models.Tag) (*models.Tag, error)
+}
+
type NameExistsError struct {
Name string
}
@@ -37,9 +43,9 @@ func (e *InvalidTagHierarchyError) Error() string {
// EnsureTagNameUnique returns an error if the tag name provided
// is used as a name or alias of another existing tag.
-func EnsureTagNameUnique(id int, name string, qb models.TagReader) error {
+func EnsureTagNameUnique(ctx context.Context, id int, name string, qb Queryer) error {
// ensure name is unique
- sameNameTag, err := ByName(qb, name)
+ sameNameTag, err := ByName(ctx, qb, name)
if err != nil {
return err
}
@@ -51,7 +57,7 @@ func EnsureTagNameUnique(id int, name string, qb models.TagReader) error {
}
// query by alias
- sameNameTag, err = ByAlias(qb, name)
+ sameNameTag, err = ByAlias(ctx, qb, name)
if err != nil {
return err
}
@@ -66,9 +72,9 @@ func EnsureTagNameUnique(id int, name string, qb models.TagReader) error {
return nil
}
-func EnsureAliasesUnique(id int, aliases []string, qb models.TagReader) error {
+func EnsureAliasesUnique(ctx context.Context, id int, aliases []string, qb Queryer) error {
for _, a := range aliases {
- if err := EnsureTagNameUnique(id, a, qb); err != nil {
+ if err := EnsureTagNameUnique(ctx, id, a, qb); err != nil {
return err
}
}
@@ -76,12 +82,19 @@ func EnsureAliasesUnique(id int, aliases []string, qb models.TagReader) error {
return nil
}
-func ValidateHierarchy(tag *models.Tag, parentIDs, childIDs []int, qb models.TagReader) error {
+type RelationshipGetter interface {
+ FindAllAncestors(ctx context.Context, tagID int, excludeIDs []int) ([]*models.TagPath, error)
+ FindAllDescendants(ctx context.Context, tagID int, excludeIDs []int) ([]*models.TagPath, error)
+ FindByChildTagID(ctx context.Context, childID int) ([]*models.Tag, error)
+ FindByParentTagID(ctx context.Context, parentID int) ([]*models.Tag, error)
+}
+
+func ValidateHierarchy(ctx context.Context, tag *models.Tag, parentIDs, childIDs []int, qb RelationshipGetter) error {
id := tag.ID
allAncestors := make(map[int]*models.TagPath)
allDescendants := make(map[int]*models.TagPath)
- parentsAncestors, err := qb.FindAllAncestors(id, nil)
+ parentsAncestors, err := qb.FindAllAncestors(ctx, id, nil)
if err != nil {
return err
}
@@ -90,7 +103,7 @@ func ValidateHierarchy(tag *models.Tag, parentIDs, childIDs []int, qb models.Tag
allAncestors[ancestorTag.ID] = ancestorTag
}
- childsDescendants, err := qb.FindAllDescendants(id, nil)
+ childsDescendants, err := qb.FindAllDescendants(ctx, id, nil)
if err != nil {
return err
}
@@ -128,7 +141,7 @@ func ValidateHierarchy(tag *models.Tag, parentIDs, childIDs []int, qb models.Tag
}
if parentIDs == nil {
- parentTags, err := qb.FindByChildTagID(id)
+ parentTags, err := qb.FindByChildTagID(ctx, id)
if err != nil {
return err
}
@@ -139,7 +152,7 @@ func ValidateHierarchy(tag *models.Tag, parentIDs, childIDs []int, qb models.Tag
}
if childIDs == nil {
- childTags, err := qb.FindByParentTagID(id)
+ childTags, err := qb.FindByParentTagID(ctx, id)
if err != nil {
return err
}
@@ -164,7 +177,7 @@ func ValidateHierarchy(tag *models.Tag, parentIDs, childIDs []int, qb models.Tag
return nil
}
-func MergeHierarchy(destination int, sources []int, qb models.TagReader) ([]int, []int, error) {
+func MergeHierarchy(ctx context.Context, destination int, sources []int, qb RelationshipGetter) ([]int, []int, error) {
var mergedParents, mergedChildren []int
allIds := append([]int{destination}, sources...)
@@ -192,14 +205,14 @@ func MergeHierarchy(destination int, sources []int, qb models.TagReader) ([]int,
}
for _, id := range allIds {
- parents, err := qb.FindByChildTagID(id)
+ parents, err := qb.FindByChildTagID(ctx, id)
if err != nil {
return nil, nil, err
}
mergedParents = addTo(mergedParents, parents)
- children, err := qb.FindByParentTagID(id)
+ children, err := qb.FindByParentTagID(ctx, id)
if err != nil {
return nil, nil, err
}
diff --git a/pkg/tag/update_test.go b/pkg/tag/update_test.go
index f7338da23..4cc14e961 100644
--- a/pkg/tag/update_test.go
+++ b/pkg/tag/update_test.go
@@ -1,6 +1,7 @@
package tag
import (
+ "context"
"fmt"
"testing"
@@ -219,6 +220,7 @@ func TestEnsureHierarchy(t *testing.T) {
func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents, queryChildren bool) {
mockTagReader := &mocks.TagReaderWriter{}
+ ctx := context.Background()
var parentIDs, childIDs []int
find := make(map[int]*models.Tag)
@@ -245,33 +247,33 @@ func testEnsureHierarchy(t *testing.T, tc testUniqueHierarchyCase, queryParents,
if queryParents {
parentIDs = nil
- mockTagReader.On("FindByChildTagID", tc.id).Return(tc.parents, nil).Once()
+ mockTagReader.On("FindByChildTagID", ctx, tc.id).Return(tc.parents, nil).Once()
}
if queryChildren {
childIDs = nil
- mockTagReader.On("FindByParentTagID", tc.id).Return(tc.children, nil).Once()
+ mockTagReader.On("FindByParentTagID", ctx, tc.id).Return(tc.children, nil).Once()
}
- mockTagReader.On("FindAllAncestors", mock.AnythingOfType("int"), []int(nil)).Return(func(tagID int, excludeIDs []int) []*models.TagPath {
+ mockTagReader.On("FindAllAncestors", ctx, mock.AnythingOfType("int"), []int(nil)).Return(func(ctx context.Context, tagID int, excludeIDs []int) []*models.TagPath {
return tc.onFindAllAncestors
- }, func(tagID int, excludeIDs []int) error {
+ }, func(ctx context.Context, tagID int, excludeIDs []int) error {
if tc.onFindAllAncestors != nil {
return nil
}
return fmt.Errorf("undefined ancestors for: %d", tagID)
}).Maybe()
- mockTagReader.On("FindAllDescendants", mock.AnythingOfType("int"), []int(nil)).Return(func(tagID int, excludeIDs []int) []*models.TagPath {
+ mockTagReader.On("FindAllDescendants", ctx, mock.AnythingOfType("int"), []int(nil)).Return(func(ctx context.Context, tagID int, excludeIDs []int) []*models.TagPath {
return tc.onFindAllDescendants
- }, func(tagID int, excludeIDs []int) error {
+ }, func(ctx context.Context, tagID int, excludeIDs []int) error {
if tc.onFindAllDescendants != nil {
return nil
}
return fmt.Errorf("undefined descendants for: %d", tagID)
}).Maybe()
- res := ValidateHierarchy(testUniqueHierarchyTags[tc.id], parentIDs, childIDs, mockTagReader)
+ res := ValidateHierarchy(ctx, testUniqueHierarchyTags[tc.id], parentIDs, childIDs, mockTagReader)
assert := assert.New(t)
diff --git a/pkg/txn/transaction.go b/pkg/txn/transaction.go
new file mode 100644
index 000000000..6939828b4
--- /dev/null
+++ b/pkg/txn/transaction.go
@@ -0,0 +1,38 @@
+package txn
+
+import "context"
+
+type Manager interface {
+ Begin(ctx context.Context) (context.Context, error)
+ Commit(ctx context.Context) error
+ Rollback(ctx context.Context) error
+}
+
+type TxnFunc func(ctx context.Context) error
+
+func WithTxn(ctx context.Context, m Manager, fn TxnFunc) error {
+ var err error
+ ctx, err = m.Begin(ctx)
+ if err != nil {
+ return err
+ }
+
+ defer func() {
+ if p := recover(); p != nil {
+ // a panic occurred, rollback and repanic
+ _ = m.Rollback(ctx)
+ panic(p)
+ }
+
+ if err != nil {
+ // something went wrong, rollback
+ _ = m.Rollback(ctx)
+ } else {
+ // all good, commit
+ err = m.Commit(ctx)
+ }
+ }()
+
+ err = fn(ctx)
+ return err
+}
From 30877c75fbca03992d08b3e617230fcdde2ef1bc Mon Sep 17 00:00:00 2001
From: WithoutPants <53250216+WithoutPants@users.noreply.github.com>
Date: Wed, 13 Jul 2022 12:57:53 +1000
Subject: [PATCH 03/30] Release notes dialog (#2726)
* Move manual docs
* Move changelog docs
* Add migration notes
* Move changelog to settings
* Add release notes dialog
* Add new changelog
---
ui/v2.5/src/App.tsx | 38 ++++++++++++-
.../src/components/Changelog/Changelog.tsx | 54 ++++++++++---------
ui/v2.5/src/components/Changelog/styles.scss | 1 -
.../components/Dialogs/ReleaseNotesDialog.tsx | 39 ++++++++++++++
.../src/components/FrontPage/FrontPage.tsx | 1 +
ui/v2.5/src/components/Help/Manual.tsx | 44 +++++++--------
ui/v2.5/src/components/Settings/Settings.tsx | 9 ++++
ui/v2.5/src/components/Setup/Migrate.tsx | 48 +++++++++++++++--
ui/v2.5/src/components/Setup/styles.scss | 9 ++++
ui/v2.5/src/components/Stats.tsx | 4 --
ui/v2.5/src/core/config.ts | 1 +
.../versions => docs/en/Changelog}/v010.md | 0
.../versions => docs/en/Changelog}/v0100.md | 0
.../versions => docs/en/Changelog}/v011.md | 0
.../versions => docs/en/Changelog}/v0110.md | 0
.../versions => docs/en/Changelog}/v0120.md | 0
.../versions => docs/en/Changelog}/v0130.md | 0
.../versions => docs/en/Changelog}/v0131.md | 0
.../versions => docs/en/Changelog}/v0140.md | 0
.../versions => docs/en/Changelog}/v0150.md | 0
.../versions => docs/en/Changelog}/v0160.md | 0
.../versions => docs/en/Changelog}/v0161.md | 0
ui/v2.5/src/docs/en/Changelog/v0170.md | 5 ++
.../versions => docs/en/Changelog}/v020.md | 0
.../versions => docs/en/Changelog}/v021.md | 0
.../versions => docs/en/Changelog}/v030.md | 0
.../versions => docs/en/Changelog}/v040.md | 0
.../versions => docs/en/Changelog}/v050.md | 0
.../versions => docs/en/Changelog}/v060.md | 0
.../versions => docs/en/Changelog}/v070.md | 0
.../versions => docs/en/Changelog}/v080.md | 0
.../versions => docs/en/Changelog}/v090.md | 0
.../src/docs/en/{ => Manual}/AutoTagging.md | 0
ui/v2.5/src/docs/en/{ => Manual}/Browsing.md | 0
ui/v2.5/src/docs/en/{ => Manual}/Captions.md | 0
.../src/docs/en/{ => Manual}/Configuration.md | 0
.../src/docs/en/{ => Manual}/Contributing.md | 0
.../src/docs/en/{ => Manual}/Deduplication.md | 0
.../docs/en/{ => Manual}/EmbeddedPlugins.md | 0
.../docs/en/{ => Manual}/ExternalPlugins.md | 0
ui/v2.5/src/docs/en/{ => Manual}/Galleries.md | 0
ui/v2.5/src/docs/en/{ => Manual}/Help.md | 0
ui/v2.5/src/docs/en/{ => Manual}/Identify.md | 0
.../src/docs/en/{ => Manual}/Interactive.md | 0
ui/v2.5/src/docs/en/{ => Manual}/Interface.md | 0
.../src/docs/en/{ => Manual}/Introduction.md | 0
ui/v2.5/src/docs/en/{ => Manual}/JSONSpec.md | 0
.../docs/en/{ => Manual}/KeyboardShortcuts.md | 0
ui/v2.5/src/docs/en/{ => Manual}/Plugins.md | 0
.../en/{ => Manual}/SceneFilenameParser.md | 0
.../en/{ => Manual}/ScraperDevelopment.md | 0
ui/v2.5/src/docs/en/{ => Manual}/Scraping.md | 0
ui/v2.5/src/docs/en/{ => Manual}/Tagger.md | 0
ui/v2.5/src/docs/en/{ => Manual}/Tasks.md | 0
ui/v2.5/src/docs/en/MigrationNotes/index.ts | 9 ++++
ui/v2.5/src/docs/en/ReleaseNotes/index.ts | 16 ++++++
ui/v2.5/src/docs/en/ReleaseNotes/v0170.md | 1 +
ui/v2.5/src/index.scss | 1 +
ui/v2.5/src/locales/en-GB.json | 6 ++-
59 files changed, 229 insertions(+), 57 deletions(-)
create mode 100644 ui/v2.5/src/components/Dialogs/ReleaseNotesDialog.tsx
create mode 100644 ui/v2.5/src/components/Setup/styles.scss
rename ui/v2.5/src/{components/Changelog/versions => docs/en/Changelog}/v010.md (100%)
rename ui/v2.5/src/{components/Changelog/versions => docs/en/Changelog}/v0100.md (100%)
rename ui/v2.5/src/{components/Changelog/versions => docs/en/Changelog}/v011.md (100%)
rename ui/v2.5/src/{components/Changelog/versions => docs/en/Changelog}/v0110.md (100%)
rename ui/v2.5/src/{components/Changelog/versions => docs/en/Changelog}/v0120.md (100%)
rename ui/v2.5/src/{components/Changelog/versions => docs/en/Changelog}/v0130.md (100%)
rename ui/v2.5/src/{components/Changelog/versions => docs/en/Changelog}/v0131.md (100%)
rename ui/v2.5/src/{components/Changelog/versions => docs/en/Changelog}/v0140.md (100%)
rename ui/v2.5/src/{components/Changelog/versions => docs/en/Changelog}/v0150.md (100%)
rename ui/v2.5/src/{components/Changelog/versions => docs/en/Changelog}/v0160.md (100%)
rename ui/v2.5/src/{components/Changelog/versions => docs/en/Changelog}/v0161.md (100%)
create mode 100644 ui/v2.5/src/docs/en/Changelog/v0170.md
rename ui/v2.5/src/{components/Changelog/versions => docs/en/Changelog}/v020.md (100%)
rename ui/v2.5/src/{components/Changelog/versions => docs/en/Changelog}/v021.md (100%)
rename ui/v2.5/src/{components/Changelog/versions => docs/en/Changelog}/v030.md (100%)
rename ui/v2.5/src/{components/Changelog/versions => docs/en/Changelog}/v040.md (100%)
rename ui/v2.5/src/{components/Changelog/versions => docs/en/Changelog}/v050.md (100%)
rename ui/v2.5/src/{components/Changelog/versions => docs/en/Changelog}/v060.md (100%)
rename ui/v2.5/src/{components/Changelog/versions => docs/en/Changelog}/v070.md (100%)
rename ui/v2.5/src/{components/Changelog/versions => docs/en/Changelog}/v080.md (100%)
rename ui/v2.5/src/{components/Changelog/versions => docs/en/Changelog}/v090.md (100%)
rename ui/v2.5/src/docs/en/{ => Manual}/AutoTagging.md (100%)
rename ui/v2.5/src/docs/en/{ => Manual}/Browsing.md (100%)
rename ui/v2.5/src/docs/en/{ => Manual}/Captions.md (100%)
rename ui/v2.5/src/docs/en/{ => Manual}/Configuration.md (100%)
rename ui/v2.5/src/docs/en/{ => Manual}/Contributing.md (100%)
rename ui/v2.5/src/docs/en/{ => Manual}/Deduplication.md (100%)
rename ui/v2.5/src/docs/en/{ => Manual}/EmbeddedPlugins.md (100%)
rename ui/v2.5/src/docs/en/{ => Manual}/ExternalPlugins.md (100%)
rename ui/v2.5/src/docs/en/{ => Manual}/Galleries.md (100%)
rename ui/v2.5/src/docs/en/{ => Manual}/Help.md (100%)
rename ui/v2.5/src/docs/en/{ => Manual}/Identify.md (100%)
rename ui/v2.5/src/docs/en/{ => Manual}/Interactive.md (100%)
rename ui/v2.5/src/docs/en/{ => Manual}/Interface.md (100%)
rename ui/v2.5/src/docs/en/{ => Manual}/Introduction.md (100%)
rename ui/v2.5/src/docs/en/{ => Manual}/JSONSpec.md (100%)
rename ui/v2.5/src/docs/en/{ => Manual}/KeyboardShortcuts.md (100%)
rename ui/v2.5/src/docs/en/{ => Manual}/Plugins.md (100%)
rename ui/v2.5/src/docs/en/{ => Manual}/SceneFilenameParser.md (100%)
rename ui/v2.5/src/docs/en/{ => Manual}/ScraperDevelopment.md (100%)
rename ui/v2.5/src/docs/en/{ => Manual}/Scraping.md (100%)
rename ui/v2.5/src/docs/en/{ => Manual}/Tagger.md (100%)
rename ui/v2.5/src/docs/en/{ => Manual}/Tasks.md (100%)
create mode 100644 ui/v2.5/src/docs/en/MigrationNotes/index.ts
create mode 100644 ui/v2.5/src/docs/en/ReleaseNotes/index.ts
create mode 100644 ui/v2.5/src/docs/en/ReleaseNotes/v0170.md
diff --git a/ui/v2.5/src/App.tsx b/ui/v2.5/src/App.tsx
index 31a6a400e..db81aeb98 100755
--- a/ui/v2.5/src/App.tsx
+++ b/ui/v2.5/src/App.tsx
@@ -9,7 +9,11 @@ import LightboxProvider from "src/hooks/Lightbox/context";
import { initPolyfills } from "src/polyfills";
import locales from "src/locales";
-import { useConfiguration, useSystemStatus } from "src/core/StashService";
+import {
+ useConfiguration,
+ useConfigureUI,
+ useSystemStatus,
+} from "src/core/StashService";
import { flattenMessages } from "src/utils";
import Mousetrap from "mousetrap";
import MousetrapPause from "mousetrap-pause";
@@ -22,6 +26,9 @@ import { LoadingIndicator, TITLE_SUFFIX } from "./components/Shared";
import { ConfigurationProvider } from "./hooks/Config";
import { ManualProvider } from "./components/Help/context";
import { InteractiveProvider } from "./hooks/Interactive/context";
+import { ReleaseNotesDialog } from "./components/Dialogs/ReleaseNotesDialog";
+import { IUIConfig } from "./core/config";
+import { releaseNotes } from "./docs/en/ReleaseNotes";
const Performers = lazy(() => import("./components/Performers/Performers"));
const FrontPage = lazy(() => import("./components/FrontPage/FrontPage"));
@@ -62,6 +69,8 @@ function languageMessageString(language: string) {
export const App: React.FC = () => {
const config = useConfiguration();
+ const [saveUI] = useConfigureUI();
+
const { data: systemStatusData } = useSystemStatus();
const language =
@@ -161,6 +170,32 @@ export const App: React.FC = () => {
);
}
+ function maybeRenderReleaseNotes() {
+ const lastNoteSeen = (config.data?.configuration.ui as IUIConfig)
+ ?.lastNoteSeen;
+ const notes = releaseNotes.filter((n) => {
+ return !lastNoteSeen || n.date > lastNoteSeen;
+ });
+
+ if (notes.length === 0) return;
+
+ return (
+ Changelog:
{releases.map((r) => (
+
+
+ + {scene.title + ? scene.title + : TextUtils.fileNameFromPath(file?.path ?? "")} + +
+{file?.path ?? ""}
+- - {scene.title ?? - TextUtils.fileNameFromPath(scene.path)} - -
-{scene.path}
-Originally this option disabled all triggers. ^(However, since +** SQLite version 3.35.0, TEMP triggers are still allowed even if +** this option is off. So, in other words, this option now only disables +** triggers in the main database schema or in the schemas of ATTACH-ed +** databases.)^ ** ** [[SQLITE_DBCONFIG_ENABLE_VIEW]] **
Originally this option disabled all views. ^(However, since +** SQLite version 3.35.0, TEMP views are still allowed even if +** this option is off. So, in other words, this option now only disables +** views in the main database schema or in the schemas of ATTACH-ed +** databases.)^ ** ** [[SQLITE_DBCONFIG_ENABLE_FTS3_TOKENIZER]] **
| file:data.db?mode=readonly | ** An error. "readonly" is not a valid option for the "mode" parameter. +** Use "ro" instead: "file:data.db?mode=ro". ** |