diff --git a/pkg/api/resolver.go b/pkg/api/resolver.go index f9a5237a3..b253a07df 100644 --- a/pkg/api/resolver.go +++ b/pkg/api/resolver.go @@ -5,12 +5,13 @@ import ( "sort" "strconv" - "github.com/99designs/gqlgen/graphql" "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" ) -type Resolver struct{} +type Resolver struct { + txnManager models.TransactionManager +} func (r *Resolver) Gallery() models.GalleryResolver { return &galleryResolver{r} @@ -79,35 +80,75 @@ type scrapedSceneMovieResolver struct{ *Resolver } type scrapedScenePerformerResolver struct{ *Resolver } type scrapedSceneStudioResolver struct{ *Resolver } -func (r *queryResolver) MarkerWall(ctx context.Context, q *string) ([]*models.SceneMarker, error) { - qb := models.NewSceneMarkerQueryBuilder() - return qb.Wall(q) +func (r *Resolver) withTxn(ctx context.Context, fn func(r models.Repository) error) error { + return r.txnManager.WithTxn(ctx, fn) } -func (r *queryResolver) SceneWall(ctx context.Context, q *string) ([]*models.Scene, error) { - qb := models.NewSceneQueryBuilder() - return qb.Wall(q) +func (r *Resolver) withReadTxn(ctx context.Context, fn func(r models.ReaderRepository) error) error { + return r.txnManager.WithReadTxn(ctx, fn) } -func (r *queryResolver) MarkerStrings(ctx context.Context, q *string, sort *string) ([]*models.MarkerStringsResultType, error) { - qb := models.NewSceneMarkerQueryBuilder() - return qb.GetMarkerStrings(q, sort) +func (r *queryResolver) MarkerWall(ctx context.Context, q *string) (ret []*models.SceneMarker, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.SceneMarker().Wall(q) + return err + }); err != nil { + return nil, err + } + return ret, nil +} + +func (r *queryResolver) SceneWall(ctx context.Context, q *string) (ret []*models.Scene, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Scene().Wall(q) + return err + }); err != nil { + return nil, err + } + + return ret, nil +} + +func (r *queryResolver) MarkerStrings(ctx context.Context, q *string, sort *string) (ret []*models.MarkerStringsResultType, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.SceneMarker().GetMarkerStrings(q, sort) + return err + }); err != nil { + return nil, err + } + + return ret, nil } func (r *queryResolver) ValidGalleriesForScene(ctx context.Context, scene_id *string) ([]*models.Gallery, error) { if scene_id == nil { panic("nil scene id") // TODO make scene_id mandatory } - sceneID, _ := strconv.Atoi(*scene_id) - sqb := models.NewSceneQueryBuilder() - scene, err := sqb.Find(sceneID) + sceneID, err := strconv.Atoi(*scene_id) if err != nil { return nil, err } - qb := models.NewGalleryQueryBuilder() - validGalleries, err := qb.ValidGalleriesForScenePath(scene.Path) - sceneGallery, _ := qb.FindBySceneID(sceneID, nil) + var validGalleries []*models.Gallery + var sceneGallery *models.Gallery + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + sqb := repo.Scene() + scene, err := sqb.Find(sceneID) + if err != nil { + return err + } + + qb := repo.Gallery() + validGalleries, err = qb.ValidGalleriesForScenePath(scene.Path) + if err != nil { + return err + } + sceneGallery, err = qb.FindBySceneID(sceneID) + return err + }); err != nil { + return nil, err + } + if sceneGallery != nil { validGalleries = append(validGalleries, sceneGallery) } @@ -115,33 +156,43 @@ func (r *queryResolver) 
ValidGalleriesForScene(ctx context.Context, scene_id *st } func (r *queryResolver) Stats(ctx context.Context) (*models.StatsResultType, error) { - scenesQB := models.NewSceneQueryBuilder() - scenesCount, _ := scenesQB.Count() - scenesSize, _ := scenesQB.Size() - imageQB := models.NewImageQueryBuilder() - imageCount, _ := imageQB.Count() - imageSize, _ := imageQB.Size() - galleryQB := models.NewGalleryQueryBuilder() - galleryCount, _ := galleryQB.Count() - performersQB := models.NewPerformerQueryBuilder() - performersCount, _ := performersQB.Count() - studiosQB := models.NewStudioQueryBuilder() - studiosCount, _ := studiosQB.Count() - moviesQB := models.NewMovieQueryBuilder() - moviesCount, _ := moviesQB.Count() - tagsQB := models.NewTagQueryBuilder() - tagsCount, _ := tagsQB.Count() - return &models.StatsResultType{ - SceneCount: scenesCount, - ScenesSize: scenesSize, - ImageCount: imageCount, - ImagesSize: imageSize, - GalleryCount: galleryCount, - PerformerCount: performersCount, - StudioCount: studiosCount, - MovieCount: moviesCount, - TagCount: tagsCount, - }, nil + var ret models.StatsResultType + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + scenesQB := repo.Scene() + imageQB := repo.Image() + galleryQB := repo.Gallery() + studiosQB := repo.Studio() + performersQB := repo.Performer() + moviesQB := repo.Movie() + tagsQB := repo.Tag() + scenesCount, _ := scenesQB.Count() + scenesSize, _ := scenesQB.Size() + imageCount, _ := imageQB.Count() + imageSize, _ := imageQB.Size() + galleryCount, _ := galleryQB.Count() + performersCount, _ := performersQB.Count() + studiosCount, _ := studiosQB.Count() + moviesCount, _ := moviesQB.Count() + tagsCount, _ := tagsQB.Count() + + ret = models.StatsResultType{ + SceneCount: scenesCount, + ScenesSize: scenesSize, + ImageCount: imageCount, + ImagesSize: imageSize, + GalleryCount: galleryCount, + PerformerCount: performersCount, + StudioCount: studiosCount, + MovieCount: moviesCount, + TagCount: tagsCount, + } + + return nil + }); err != nil { + return nil, err + } + + return &ret, nil } func (r *queryResolver) Version(ctx context.Context) (*models.Version, error) { @@ -171,31 +222,41 @@ func (r *queryResolver) Latestversion(ctx context.Context) (*models.ShortVersion // Get scene marker tags which show up under the video. 
func (r *queryResolver) SceneMarkerTags(ctx context.Context, scene_id string) ([]*models.SceneMarkerTag, error) { - sceneID, _ := strconv.Atoi(scene_id) - sqb := models.NewSceneMarkerQueryBuilder() - sceneMarkers, err := sqb.FindBySceneID(sceneID, nil) + sceneID, err := strconv.Atoi(scene_id) if err != nil { return nil, err } - tags := make(map[int]*models.SceneMarkerTag) var keys []int - tqb := models.NewTagQueryBuilder() - for _, sceneMarker := range sceneMarkers { - markerPrimaryTag, err := tqb.Find(sceneMarker.PrimaryTagID, nil) + tags := make(map[int]*models.SceneMarkerTag) + + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + sceneMarkers, err := repo.SceneMarker().FindBySceneID(sceneID) if err != nil { - return nil, err + return err } - _, hasKey := tags[markerPrimaryTag.ID] - var sceneMarkerTag *models.SceneMarkerTag - if !hasKey { - sceneMarkerTag = &models.SceneMarkerTag{Tag: markerPrimaryTag} - tags[markerPrimaryTag.ID] = sceneMarkerTag - keys = append(keys, markerPrimaryTag.ID) - } else { - sceneMarkerTag = tags[markerPrimaryTag.ID] + + tqb := repo.Tag() + for _, sceneMarker := range sceneMarkers { + markerPrimaryTag, err := tqb.Find(sceneMarker.PrimaryTagID) + if err != nil { + return err + } + _, hasKey := tags[markerPrimaryTag.ID] + var sceneMarkerTag *models.SceneMarkerTag + if !hasKey { + sceneMarkerTag = &models.SceneMarkerTag{Tag: markerPrimaryTag} + tags[markerPrimaryTag.ID] = sceneMarkerTag + keys = append(keys, markerPrimaryTag.ID) + } else { + sceneMarkerTag = tags[markerPrimaryTag.ID] + } + tags[markerPrimaryTag.ID].SceneMarkers = append(tags[markerPrimaryTag.ID].SceneMarkers, sceneMarker) } - tags[markerPrimaryTag.ID].SceneMarkers = append(tags[markerPrimaryTag.ID].SceneMarkers, sceneMarker) + + return nil + }); err != nil { + return nil, err } // Sort so that primary tags that show up earlier in the video are first. @@ -212,13 +273,3 @@ func (r *queryResolver) SceneMarkerTags(ctx context.Context, scene_id string) ([ return result, nil } - -// wasFieldIncluded returns true if the given field was included in the request. -// Slices are unmarshalled to empty slices even if the field was omitted. This -// method determines if it was omitted altogether. 
-func wasFieldIncluded(ctx context.Context, field string) bool { - rctx := graphql.GetRequestContext(ctx) - - _, ret := rctx.Variables[field] - return ret -} diff --git a/pkg/api/resolver_model_gallery.go b/pkg/api/resolver_model_gallery.go index 0d73f8219..4b646cd02 100644 --- a/pkg/api/resolver_model_gallery.go +++ b/pkg/api/resolver_model_gallery.go @@ -22,30 +22,39 @@ func (r *galleryResolver) Title(ctx context.Context, obj *models.Gallery) (*stri return nil, nil } -func (r *galleryResolver) Images(ctx context.Context, obj *models.Gallery) ([]*models.Image, error) { - qb := models.NewImageQueryBuilder() - - return qb.FindByGalleryID(obj.ID) -} - -func (r *galleryResolver) Cover(ctx context.Context, obj *models.Gallery) (*models.Image, error) { - qb := models.NewImageQueryBuilder() - - imgs, err := qb.FindByGalleryID(obj.ID) - if err != nil { +func (r *galleryResolver) Images(ctx context.Context, obj *models.Gallery) (ret []*models.Image, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + var err error + ret, err = repo.Image().FindByGalleryID(obj.ID) + return err + }); err != nil { return nil, err } - var ret *models.Image - if len(imgs) > 0 { - ret = imgs[0] - } + return ret, nil +} - for _, img := range imgs { - if image.IsCover(img) { - ret = img - break +func (r *galleryResolver) Cover(ctx context.Context, obj *models.Gallery) (ret *models.Image, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + imgs, err := repo.Image().FindByGalleryID(obj.ID) + if err != nil { + return err } + + if len(imgs) > 0 { + ret = imgs[0] + } + + for _, img := range imgs { + if image.IsCover(img) { + ret = img + break + } + } + + return nil + }); err != nil { + return nil, err } return ret, nil @@ -81,35 +90,71 @@ func (r *galleryResolver) Rating(ctx context.Context, obj *models.Gallery) (*int return nil, nil } -func (r *galleryResolver) Scene(ctx context.Context, obj *models.Gallery) (*models.Scene, error) { +func (r *galleryResolver) Scene(ctx context.Context, obj *models.Gallery) (ret *models.Scene, err error) { if !obj.SceneID.Valid { return nil, nil } - qb := models.NewSceneQueryBuilder() - return qb.Find(int(obj.SceneID.Int64)) + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + var err error + ret, err = repo.Scene().Find(int(obj.SceneID.Int64)) + + return err + }); err != nil { + return nil, err + } + + return ret, nil } -func (r *galleryResolver) Studio(ctx context.Context, obj *models.Gallery) (*models.Studio, error) { +func (r *galleryResolver) Studio(ctx context.Context, obj *models.Gallery) (ret *models.Studio, err error) { if !obj.StudioID.Valid { return nil, nil } - qb := models.NewStudioQueryBuilder() - return qb.Find(int(obj.StudioID.Int64), nil) + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + var err error + ret, err = repo.Studio().Find(int(obj.StudioID.Int64)) + return err + }); err != nil { + return nil, err + } + + return ret, nil } -func (r *galleryResolver) Tags(ctx context.Context, obj *models.Gallery) ([]*models.Tag, error) { - qb := models.NewTagQueryBuilder() - return qb.FindByGalleryID(obj.ID, nil) +func (r *galleryResolver) Tags(ctx context.Context, obj *models.Gallery) (ret []*models.Tag, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + var err error + ret, err = repo.Tag().FindByGalleryID(obj.ID) + return err + }); err != nil { + return nil, err + } + + return ret, nil } -func (r *galleryResolver) 
Performers(ctx context.Context, obj *models.Gallery) ([]*models.Performer, error) { - qb := models.NewPerformerQueryBuilder() - return qb.FindByGalleryID(obj.ID, nil) +func (r *galleryResolver) Performers(ctx context.Context, obj *models.Gallery) (ret []*models.Performer, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + var err error + ret, err = repo.Performer().FindByGalleryID(obj.ID) + return err + }); err != nil { + return nil, err + } + + return ret, nil } -func (r *galleryResolver) ImageCount(ctx context.Context, obj *models.Gallery) (int, error) { - qb := models.NewImageQueryBuilder() - return qb.CountByGalleryID(obj.ID) +func (r *galleryResolver) ImageCount(ctx context.Context, obj *models.Gallery) (ret int, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + var err error + ret, err = repo.Image().CountByGalleryID(obj.ID) + return err + }); err != nil { + return 0, err + } + + return ret, nil } diff --git a/pkg/api/resolver_model_image.go b/pkg/api/resolver_model_image.go index 6ca679f89..d3cf4c066 100644 --- a/pkg/api/resolver_model_image.go +++ b/pkg/api/resolver_model_image.go @@ -43,26 +43,51 @@ func (r *imageResolver) Paths(ctx context.Context, obj *models.Image) (*models.I }, nil } -func (r *imageResolver) Galleries(ctx context.Context, obj *models.Image) ([]*models.Gallery, error) { - qb := models.NewGalleryQueryBuilder() - return qb.FindByImageID(obj.ID, nil) +func (r *imageResolver) Galleries(ctx context.Context, obj *models.Image) (ret []*models.Gallery, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + var err error + ret, err = repo.Gallery().FindByImageID(obj.ID) + return err + }); err != nil { + return nil, err + } + + return ret, nil } -func (r *imageResolver) Studio(ctx context.Context, obj *models.Image) (*models.Studio, error) { +func (r *imageResolver) Studio(ctx context.Context, obj *models.Image) (ret *models.Studio, err error) { if !obj.StudioID.Valid { return nil, nil } - qb := models.NewStudioQueryBuilder() - return qb.Find(int(obj.StudioID.Int64), nil) + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Studio().Find(int(obj.StudioID.Int64)) + return err + }); err != nil { + return nil, err + } + + return ret, nil } -func (r *imageResolver) Tags(ctx context.Context, obj *models.Image) ([]*models.Tag, error) { - qb := models.NewTagQueryBuilder() - return qb.FindByImageID(obj.ID, nil) +func (r *imageResolver) Tags(ctx context.Context, obj *models.Image) (ret []*models.Tag, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Tag().FindByImageID(obj.ID) + return err + }); err != nil { + return nil, err + } + + return ret, nil } -func (r *imageResolver) Performers(ctx context.Context, obj *models.Image) ([]*models.Performer, error) { - qb := models.NewPerformerQueryBuilder() - return qb.FindByImageID(obj.ID, nil) +func (r *imageResolver) Performers(ctx context.Context, obj *models.Image) (ret []*models.Performer, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Performer().FindByImageID(obj.ID) + return err + }); err != nil { + return nil, err + } + + return ret, nil } diff --git a/pkg/api/resolver_model_movie.go b/pkg/api/resolver_model_movie.go index 6ab444a64..137113ed7 100644 --- a/pkg/api/resolver_model_movie.go +++ b/pkg/api/resolver_model_movie.go @@ -53,10 +53,16 @@ func (r *movieResolver) 
Rating(ctx context.Context, obj *models.Movie) (*int, er return nil, nil } -func (r *movieResolver) Studio(ctx context.Context, obj *models.Movie) (*models.Studio, error) { - qb := models.NewStudioQueryBuilder() +func (r *movieResolver) Studio(ctx context.Context, obj *models.Movie) (ret *models.Studio, err error) { if obj.StudioID.Valid { - return qb.Find(int(obj.StudioID.Int64), nil) + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Studio().Find(int(obj.StudioID.Int64)) + return err + }); err != nil { + return nil, err + } + + return ret, nil } return nil, nil @@ -88,8 +94,14 @@ func (r *movieResolver) BackImagePath(ctx context.Context, obj *models.Movie) (* return &backimagePath, nil } -func (r *movieResolver) SceneCount(ctx context.Context, obj *models.Movie) (*int, error) { - qb := models.NewSceneQueryBuilder() - res, err := qb.CountByMovieID(obj.ID) +func (r *movieResolver) SceneCount(ctx context.Context, obj *models.Movie) (ret *int, err error) { + var res int + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + res, err = repo.Scene().CountByMovieID(obj.ID) + return err + }); err != nil { + return nil, err + } + return &res, err } diff --git a/pkg/api/resolver_model_performer.go b/pkg/api/resolver_model_performer.go index c322e4943..ccbcbc58a 100644 --- a/pkg/api/resolver_model_performer.go +++ b/pkg/api/resolver_model_performer.go @@ -138,18 +138,36 @@ func (r *performerResolver) ImagePath(ctx context.Context, obj *models.Performer return &imagePath, nil } -func (r *performerResolver) SceneCount(ctx context.Context, obj *models.Performer) (*int, error) { - qb := models.NewSceneQueryBuilder() - res, err := qb.CountByPerformerID(obj.ID) - return &res, err +func (r *performerResolver) SceneCount(ctx context.Context, obj *models.Performer) (ret *int, err error) { + var res int + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + res, err = repo.Scene().CountByPerformerID(obj.ID) + return err + }); err != nil { + return nil, err + } + + return &res, nil } -func (r *performerResolver) Scenes(ctx context.Context, obj *models.Performer) ([]*models.Scene, error) { - qb := models.NewSceneQueryBuilder() - return qb.FindByPerformerID(obj.ID) +func (r *performerResolver) Scenes(ctx context.Context, obj *models.Performer) (ret []*models.Scene, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Scene().FindByPerformerID(obj.ID) + return err + }); err != nil { + return nil, err + } + + return ret, nil } -func (r *performerResolver) StashIds(ctx context.Context, obj *models.Performer) ([]*models.StashID, error) { - qb := models.NewJoinsQueryBuilder() - return qb.GetPerformerStashIDs(obj.ID) +func (r *performerResolver) StashIds(ctx context.Context, obj *models.Performer) (ret []*models.StashID, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Performer().GetStashIDs(obj.ID) + return err + }); err != nil { + return nil, err + } + + return ret, nil } diff --git a/pkg/api/resolver_model_scene.go b/pkg/api/resolver_model_scene.go index 656e2ef66..c81e82489 100644 --- a/pkg/api/resolver_model_scene.go +++ b/pkg/api/resolver_model_scene.go @@ -94,64 +94,109 @@ func (r *sceneResolver) Paths(ctx context.Context, obj *models.Scene) (*models.S }, nil } -func (r *sceneResolver) SceneMarkers(ctx context.Context, obj *models.Scene) ([]*models.SceneMarker, error) { - qb := models.NewSceneMarkerQueryBuilder() - 
return qb.FindBySceneID(obj.ID, nil) -} - -func (r *sceneResolver) Gallery(ctx context.Context, obj *models.Scene) (*models.Gallery, error) { - qb := models.NewGalleryQueryBuilder() - return qb.FindBySceneID(obj.ID, nil) -} - -func (r *sceneResolver) Studio(ctx context.Context, obj *models.Scene) (*models.Studio, error) { - qb := models.NewStudioQueryBuilder() - return qb.FindBySceneID(obj.ID) -} - -func (r *sceneResolver) Movies(ctx context.Context, obj *models.Scene) ([]*models.SceneMovie, error) { - joinQB := models.NewJoinsQueryBuilder() - qb := models.NewMovieQueryBuilder() - - sceneMovies, err := joinQB.GetSceneMovies(obj.ID, nil) - if err != nil { +func (r *sceneResolver) SceneMarkers(ctx context.Context, obj *models.Scene) (ret []*models.SceneMarker, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.SceneMarker().FindBySceneID(obj.ID) + return err + }); err != nil { return nil, err } - var ret []*models.SceneMovie - for _, sm := range sceneMovies { - movie, err := qb.Find(sm.MovieID, nil) + return ret, nil +} + +func (r *sceneResolver) Gallery(ctx context.Context, obj *models.Scene) (ret *models.Gallery, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Gallery().FindBySceneID(obj.ID) + return err + }); err != nil { + return nil, err + } + + return ret, nil +} + +func (r *sceneResolver) Studio(ctx context.Context, obj *models.Scene) (ret *models.Studio, err error) { + if !obj.StudioID.Valid { + return nil, nil + } + + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Studio().Find(int(obj.StudioID.Int64)) + return err + }); err != nil { + return nil, err + } + + return ret, nil +} + +func (r *sceneResolver) Movies(ctx context.Context, obj *models.Scene) (ret []*models.SceneMovie, err error) { + if err := r.withTxn(ctx, func(repo models.Repository) error { + qb := repo.Scene() + mqb := repo.Movie() + + sceneMovies, err := qb.GetMovies(obj.ID) if err != nil { - return nil, err + return err } - sceneIdx := sm.SceneIndex - sceneMovie := &models.SceneMovie{ - Movie: movie, + for _, sm := range sceneMovies { + movie, err := mqb.Find(sm.MovieID) + if err != nil { + return err + } + + sceneIdx := sm.SceneIndex + sceneMovie := &models.SceneMovie{ + Movie: movie, + } + + if sceneIdx.Valid { + var idx int + idx = int(sceneIdx.Int64) + sceneMovie.SceneIndex = &idx + } + + ret = append(ret, sceneMovie) } - if sceneIdx.Valid { - var idx int - idx = int(sceneIdx.Int64) - sceneMovie.SceneIndex = &idx - } - - ret = append(ret, sceneMovie) + return nil + }); err != nil { + return nil, err } return ret, nil } -func (r *sceneResolver) Tags(ctx context.Context, obj *models.Scene) ([]*models.Tag, error) { - qb := models.NewTagQueryBuilder() - return qb.FindBySceneID(obj.ID, nil) +func (r *sceneResolver) Tags(ctx context.Context, obj *models.Scene) (ret []*models.Tag, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Tag().FindBySceneID(obj.ID) + return err + }); err != nil { + return nil, err + } + + return ret, nil } -func (r *sceneResolver) Performers(ctx context.Context, obj *models.Scene) ([]*models.Performer, error) { - qb := models.NewPerformerQueryBuilder() - return qb.FindBySceneID(obj.ID, nil) +func (r *sceneResolver) Performers(ctx context.Context, obj *models.Scene) (ret []*models.Performer, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = 
repo.Performer().FindBySceneID(obj.ID) + return err + }); err != nil { + return nil, err + } + + return ret, nil } -func (r *sceneResolver) StashIds(ctx context.Context, obj *models.Scene) ([]*models.StashID, error) { - qb := models.NewJoinsQueryBuilder() - return qb.GetSceneStashIDs(obj.ID) +func (r *sceneResolver) StashIds(ctx context.Context, obj *models.Scene) (ret []*models.StashID, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Scene().GetStashIDs(obj.ID) + return err + }); err != nil { + return nil, err + } + + return ret, nil } diff --git a/pkg/api/resolver_model_scene_marker.go b/pkg/api/resolver_model_scene_marker.go index 7324c84a0..98073871b 100644 --- a/pkg/api/resolver_model_scene_marker.go +++ b/pkg/api/resolver_model_scene_marker.go @@ -2,29 +2,47 @@ package api import ( "context" + "github.com/stashapp/stash/pkg/api/urlbuilders" "github.com/stashapp/stash/pkg/models" ) -func (r *sceneMarkerResolver) Scene(ctx context.Context, obj *models.SceneMarker) (*models.Scene, error) { +func (r *sceneMarkerResolver) Scene(ctx context.Context, obj *models.SceneMarker) (ret *models.Scene, err error) { if !obj.SceneID.Valid { panic("Invalid scene id") } - qb := models.NewSceneQueryBuilder() - sceneID := int(obj.SceneID.Int64) - scene, err := qb.Find(sceneID) - return scene, err + + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + sceneID := int(obj.SceneID.Int64) + ret, err = repo.Scene().Find(sceneID) + return err + }); err != nil { + return nil, err + } + + return ret, nil } -func (r *sceneMarkerResolver) PrimaryTag(ctx context.Context, obj *models.SceneMarker) (*models.Tag, error) { - qb := models.NewTagQueryBuilder() - tag, err := qb.Find(obj.PrimaryTagID, nil) - return tag, err +func (r *sceneMarkerResolver) PrimaryTag(ctx context.Context, obj *models.SceneMarker) (ret *models.Tag, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Tag().Find(obj.PrimaryTagID) + return err + }); err != nil { + return nil, err + } + + return ret, err } -func (r *sceneMarkerResolver) Tags(ctx context.Context, obj *models.SceneMarker) ([]*models.Tag, error) { - qb := models.NewTagQueryBuilder() - return qb.FindBySceneMarkerID(obj.ID, nil) +func (r *sceneMarkerResolver) Tags(ctx context.Context, obj *models.SceneMarker) (ret []*models.Tag, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Tag().FindBySceneMarkerID(obj.ID) + return err + }); err != nil { + return nil, err + } + + return ret, err } func (r *sceneMarkerResolver) Stream(ctx context.Context, obj *models.SceneMarker) (string, error) { diff --git a/pkg/api/resolver_model_studio.go b/pkg/api/resolver_model_studio.go index d701a884f..dac4e265e 100644 --- a/pkg/api/resolver_model_studio.go +++ b/pkg/api/resolver_model_studio.go @@ -25,10 +25,12 @@ func (r *studioResolver) ImagePath(ctx context.Context, obj *models.Studio) (*st baseURL, _ := ctx.Value(BaseURLCtxKey).(string) imagePath := urlbuilders.NewStudioURLBuilder(baseURL, obj.ID).GetStudioImageURL() - qb := models.NewStudioQueryBuilder() - hasImage, err := qb.HasStudioImage(obj.ID) - - if err != nil { + var hasImage bool + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + var err error + hasImage, err = repo.Studio().HasImage(obj.ID) + return err + }); err != nil { return nil, err } @@ -40,27 +42,51 @@ func (r *studioResolver) ImagePath(ctx context.Context, obj *models.Studio) 
(*st return &imagePath, nil } -func (r *studioResolver) SceneCount(ctx context.Context, obj *models.Studio) (*int, error) { - qb := models.NewSceneQueryBuilder() - res, err := qb.CountByStudioID(obj.ID) +func (r *studioResolver) SceneCount(ctx context.Context, obj *models.Studio) (ret *int, err error) { + var res int + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + res, err = repo.Scene().CountByStudioID(obj.ID) + return err + }); err != nil { + return nil, err + } + return &res, err } -func (r *studioResolver) ParentStudio(ctx context.Context, obj *models.Studio) (*models.Studio, error) { +func (r *studioResolver) ParentStudio(ctx context.Context, obj *models.Studio) (ret *models.Studio, err error) { if !obj.ParentID.Valid { return nil, nil } - qb := models.NewStudioQueryBuilder() - return qb.Find(int(obj.ParentID.Int64), nil) + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Studio().Find(int(obj.ParentID.Int64)) + return err + }); err != nil { + return nil, err + } + + return ret, nil } -func (r *studioResolver) ChildStudios(ctx context.Context, obj *models.Studio) ([]*models.Studio, error) { - qb := models.NewStudioQueryBuilder() - return qb.FindChildren(obj.ID, nil) +func (r *studioResolver) ChildStudios(ctx context.Context, obj *models.Studio) (ret []*models.Studio, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Studio().FindChildren(obj.ID) + return err + }); err != nil { + return nil, err + } + + return ret, nil } -func (r *studioResolver) StashIds(ctx context.Context, obj *models.Studio) ([]*models.StashID, error) { - qb := models.NewJoinsQueryBuilder() - return qb.GetStudioStashIDs(obj.ID) +func (r *studioResolver) StashIds(ctx context.Context, obj *models.Studio) (ret []*models.StashID, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Studio().GetStashIDs(obj.ID) + return err + }); err != nil { + return nil, err + } + + return ret, nil } diff --git a/pkg/api/resolver_model_tag.go b/pkg/api/resolver_model_tag.go index 24e44fdb2..08cb50819 100644 --- a/pkg/api/resolver_model_tag.go +++ b/pkg/api/resolver_model_tag.go @@ -7,21 +7,27 @@ import ( "github.com/stashapp/stash/pkg/models" ) -func (r *tagResolver) SceneCount(ctx context.Context, obj *models.Tag) (*int, error) { - qb := models.NewSceneQueryBuilder() - if obj == nil { - return nil, nil +func (r *tagResolver) SceneCount(ctx context.Context, obj *models.Tag) (ret *int, err error) { + var count int + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + count, err = repo.Scene().CountByTagID(obj.ID) + return err + }); err != nil { + return nil, err } - count, err := qb.CountByTagID(obj.ID) + return &count, err } -func (r *tagResolver) SceneMarkerCount(ctx context.Context, obj *models.Tag) (*int, error) { - qb := models.NewSceneMarkerQueryBuilder() - if obj == nil { - return nil, nil +func (r *tagResolver) SceneMarkerCount(ctx context.Context, obj *models.Tag) (ret *int, err error) { + var count int + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + count, err = repo.SceneMarker().CountByTagID(obj.ID) + return err + }); err != nil { + return nil, err } - count, err := qb.CountByTagID(obj.ID) + return &count, err } diff --git a/pkg/api/resolver_mutation_configure.go b/pkg/api/resolver_mutation_configure.go index 94a9c5210..a1051e218 100644 --- a/pkg/api/resolver_mutation_configure.go +++ 
b/pkg/api/resolver_mutation_configure.go @@ -52,7 +52,7 @@ func (r *mutationResolver) ConfigureGeneral(ctx context.Context, input models.Co if input.VideoFileNamingAlgorithm != config.GetVideoFileNamingAlgorithm() { // validate changing VideoFileNamingAlgorithm - if err := manager.ValidateVideoFileNamingAlgorithm(input.VideoFileNamingAlgorithm); err != nil { + if err := manager.ValidateVideoFileNamingAlgorithm(r.txnManager, input.VideoFileNamingAlgorithm); err != nil { return makeConfigGeneralResult(), err } diff --git a/pkg/api/resolver_mutation_gallery.go b/pkg/api/resolver_mutation_gallery.go index 29c8c2d4d..68ed48e54 100644 --- a/pkg/api/resolver_mutation_gallery.go +++ b/pkg/api/resolver_mutation_gallery.go @@ -4,11 +4,10 @@ import ( "context" "database/sql" "errors" + "fmt" "strconv" "time" - "github.com/jmoiron/sqlx" - "github.com/stashapp/stash/pkg/database" "github.com/stashapp/stash/pkg/manager" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/utils" @@ -69,108 +68,102 @@ func (r *mutationResolver) GalleryCreate(ctx context.Context, input models.Galle newGallery.SceneID = sql.NullInt64{Valid: false} } - // Start the transaction and save the performer - tx := database.DB.MustBeginTx(ctx, nil) - qb := models.NewGalleryQueryBuilder() - jqb := models.NewJoinsQueryBuilder() - gallery, err := qb.Create(newGallery, tx) - if err != nil { - _ = tx.Rollback() - return nil, err - } - - // Save the performers - var performerJoins []models.PerformersGalleries - for _, pid := range input.PerformerIds { - performerID, _ := strconv.Atoi(pid) - performerJoin := models.PerformersGalleries{ - PerformerID: performerID, - GalleryID: gallery.ID, + // Start the transaction and save the gallery + var gallery *models.Gallery + if err := r.withTxn(ctx, func(repo models.Repository) error { + qb := repo.Gallery() + var err error + gallery, err = qb.Create(newGallery) + if err != nil { + return err } - performerJoins = append(performerJoins, performerJoin) - } - if err := jqb.UpdatePerformersGalleries(gallery.ID, performerJoins, tx); err != nil { - return nil, err - } - // Save the tags - var tagJoins []models.GalleriesTags - for _, tid := range input.TagIds { - tagID, _ := strconv.Atoi(tid) - tagJoin := models.GalleriesTags{ - GalleryID: gallery.ID, - TagID: tagID, + // Save the performers + if err := r.updateGalleryPerformers(qb, gallery.ID, input.PerformerIds); err != nil { + return err } - tagJoins = append(tagJoins, tagJoin) - } - if err := jqb.UpdateGalleriesTags(gallery.ID, tagJoins, tx); err != nil { - return nil, err - } - // Commit - if err := tx.Commit(); err != nil { + // Save the tags + if err := r.updateGalleryTags(qb, gallery.ID, input.TagIds); err != nil { + return err + } + + return nil + }); err != nil { return nil, err } return gallery, nil } -func (r *mutationResolver) GalleryUpdate(ctx context.Context, input models.GalleryUpdateInput) (*models.Gallery, error) { - // Start the transaction and save the gallery - tx := database.DB.MustBeginTx(ctx, nil) +func (r *mutationResolver) updateGalleryPerformers(qb models.GalleryReaderWriter, galleryID int, performerIDs []string) error { + ids, err := utils.StringSliceToIntSlice(performerIDs) + if err != nil { + return err + } + return qb.UpdatePerformers(galleryID, ids) +} +func (r *mutationResolver) updateGalleryTags(qb models.GalleryReaderWriter, galleryID int, tagIDs []string) error { + ids, err := utils.StringSliceToIntSlice(tagIDs) + if err != nil { + return err + } + return qb.UpdateTags(galleryID, ids) +} + +func (r 
*mutationResolver) GalleryUpdate(ctx context.Context, input models.GalleryUpdateInput) (ret *models.Gallery, err error) { translator := changesetTranslator{ inputMap: getUpdateInputMap(ctx), } - ret, err := r.galleryUpdate(input, translator, tx) - if err != nil { - _ = tx.Rollback() - return nil, err - } - - // Commit - if err := tx.Commit(); err != nil { + // Start the transaction and save the gallery + if err := r.withTxn(ctx, func(repo models.Repository) error { + ret, err = r.galleryUpdate(input, translator, repo) + return err + }); err != nil { return nil, err } return ret, nil } -func (r *mutationResolver) GalleriesUpdate(ctx context.Context, input []*models.GalleryUpdateInput) ([]*models.Gallery, error) { - // Start the transaction and save the gallery - tx := database.DB.MustBeginTx(ctx, nil) +func (r *mutationResolver) GalleriesUpdate(ctx context.Context, input []*models.GalleryUpdateInput) (ret []*models.Gallery, err error) { inputMaps := getUpdateInputMaps(ctx) - var ret []*models.Gallery + // Start the transaction and save the gallery + if err := r.withTxn(ctx, func(repo models.Repository) error { + for i, gallery := range input { + translator := changesetTranslator{ + inputMap: inputMaps[i], + } - for i, gallery := range input { - translator := changesetTranslator{ - inputMap: inputMaps[i], + thisGallery, err := r.galleryUpdate(*gallery, translator, repo) + if err != nil { + return err + } + + ret = append(ret, thisGallery) } - thisGallery, err := r.galleryUpdate(*gallery, translator, tx) - ret = append(ret, thisGallery) - - if err != nil { - _ = tx.Rollback() - return nil, err - } - } - - // Commit - if err := tx.Commit(); err != nil { + return nil + }); err != nil { return nil, err } return ret, nil } -func (r *mutationResolver) galleryUpdate(input models.GalleryUpdateInput, translator changesetTranslator, tx *sqlx.Tx) (*models.Gallery, error) { - qb := models.NewGalleryQueryBuilder() +func (r *mutationResolver) galleryUpdate(input models.GalleryUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Gallery, error) { + qb := repo.Gallery() + // Populate gallery from the input - galleryID, _ := strconv.Atoi(input.ID) - originalGallery, err := qb.Find(galleryID, nil) + galleryID, err := strconv.Atoi(input.ID) + if err != nil { + return nil, err + } + + originalGallery, err := qb.Find(galleryID) if err != nil { return nil, err } @@ -209,40 +202,21 @@ func (r *mutationResolver) galleryUpdate(input models.GalleryUpdateInput, transl // gallery scene is set from the scene only - jqb := models.NewJoinsQueryBuilder() - gallery, err := qb.UpdatePartial(updatedGallery, tx) + gallery, err := qb.UpdatePartial(updatedGallery) if err != nil { return nil, err } // Save the performers if translator.hasField("performer_ids") { - var performerJoins []models.PerformersGalleries - for _, pid := range input.PerformerIds { - performerID, _ := strconv.Atoi(pid) - performerJoin := models.PerformersGalleries{ - PerformerID: performerID, - GalleryID: galleryID, - } - performerJoins = append(performerJoins, performerJoin) - } - if err := jqb.UpdatePerformersGalleries(galleryID, performerJoins, tx); err != nil { + if err := r.updateGalleryPerformers(qb, galleryID, input.PerformerIds); err != nil { return nil, err } } // Save the tags if translator.hasField("tag_ids") { - var tagJoins []models.GalleriesTags - for _, tid := range input.TagIds { - tagID, _ := strconv.Atoi(tid) - tagJoin := models.GalleriesTags{ - GalleryID: galleryID, - TagID: tagID, - } - tagJoins = append(tagJoins, 
tagJoin) - } - if err := jqb.UpdateGalleriesTags(galleryID, tagJoins, tx); err != nil { + if err := r.updateGalleryTags(qb, galleryID, input.TagIds); err != nil { return nil, err } } @@ -254,11 +228,6 @@ func (r *mutationResolver) BulkGalleryUpdate(ctx context.Context, input models.B // Populate gallery from the input updatedTime := time.Now() - // Start the transaction and save the gallery marker - tx := database.DB.MustBeginTx(ctx, nil) - qb := models.NewGalleryQueryBuilder() - jqb := models.NewJoinsQueryBuilder() - translator := changesetTranslator{ inputMap: getUpdateInputMap(ctx), } @@ -277,181 +246,143 @@ func (r *mutationResolver) BulkGalleryUpdate(ctx context.Context, input models.B ret := []*models.Gallery{} - for _, galleryIDStr := range input.Ids { - galleryID, _ := strconv.Atoi(galleryIDStr) - updatedGallery.ID = galleryID + // Start the transaction and save the galleries + if err := r.withTxn(ctx, func(repo models.Repository) error { + qb := repo.Gallery() - gallery, err := qb.UpdatePartial(updatedGallery, tx) - if err != nil { - _ = tx.Rollback() - return nil, err - } + for _, galleryIDStr := range input.Ids { + galleryID, _ := strconv.Atoi(galleryIDStr) + updatedGallery.ID = galleryID - ret = append(ret, gallery) - - // Save the performers - if translator.hasField("performer_ids") { - performerIDs, err := adjustGalleryPerformerIDs(tx, galleryID, *input.PerformerIds) + gallery, err := qb.UpdatePartial(updatedGallery) if err != nil { - _ = tx.Rollback() - return nil, err + return err } - var performerJoins []models.PerformersGalleries - for _, performerID := range performerIDs { - performerJoin := models.PerformersGalleries{ - PerformerID: performerID, - GalleryID: galleryID, + ret = append(ret, gallery) + + // Save the performers + if translator.hasField("performer_ids") { + performerIDs, err := adjustGalleryPerformerIDs(qb, galleryID, *input.PerformerIds) + if err != nil { + return err + } + + if err := qb.UpdatePerformers(galleryID, performerIDs); err != nil { + return err } - performerJoins = append(performerJoins, performerJoin) } - if err := jqb.UpdatePerformersGalleries(galleryID, performerJoins, tx); err != nil { - _ = tx.Rollback() - return nil, err + + // Save the tags + if translator.hasField("tag_ids") { + tagIDs, err := adjustGalleryTagIDs(qb, galleryID, *input.TagIds) + if err != nil { + return err + } + + if err := qb.UpdateTags(galleryID, tagIDs); err != nil { + return err + } } } - // Save the tags - if translator.hasField("tag_ids") { - tagIDs, err := adjustGalleryTagIDs(tx, galleryID, *input.TagIds) - if err != nil { - _ = tx.Rollback() - return nil, err - } - - var tagJoins []models.GalleriesTags - for _, tagID := range tagIDs { - tagJoin := models.GalleriesTags{ - GalleryID: galleryID, - TagID: tagID, - } - tagJoins = append(tagJoins, tagJoin) - } - if err := jqb.UpdateGalleriesTags(galleryID, tagJoins, tx); err != nil { - _ = tx.Rollback() - return nil, err - } - } - } - - // Commit - if err := tx.Commit(); err != nil { + return nil + }); err != nil { return nil, err } return ret, nil } -func adjustGalleryPerformerIDs(tx *sqlx.Tx, galleryID int, ids models.BulkUpdateIds) ([]int, error) { - var ret []int - - jqb := models.NewJoinsQueryBuilder() - if ids.Mode == models.BulkUpdateIDModeAdd || ids.Mode == models.BulkUpdateIDModeRemove { - // adding to the joins - performerJoins, err := jqb.GetGalleryPerformers(galleryID, tx) - - if err != nil { - return nil, err - } - - for _, join := range performerJoins { - ret = append(ret, join.PerformerID) - } +func 
adjustGalleryPerformerIDs(qb models.GalleryReader, galleryID int, ids models.BulkUpdateIds) (ret []int, err error) { + ret, err = qb.GetPerformerIDs(galleryID) + if err != nil { + return nil, err } return adjustIDs(ret, ids), nil } -func adjustGalleryTagIDs(tx *sqlx.Tx, galleryID int, ids models.BulkUpdateIds) ([]int, error) { - var ret []int - - jqb := models.NewJoinsQueryBuilder() - if ids.Mode == models.BulkUpdateIDModeAdd || ids.Mode == models.BulkUpdateIDModeRemove { - // adding to the joins - tagJoins, err := jqb.GetGalleryTags(galleryID, tx) - - if err != nil { - return nil, err - } - - for _, join := range tagJoins { - ret = append(ret, join.TagID) - } +func adjustGalleryTagIDs(qb models.GalleryReader, galleryID int, ids models.BulkUpdateIds) (ret []int, err error) { + ret, err = qb.GetTagIDs(galleryID) + if err != nil { + return nil, err } return adjustIDs(ret, ids), nil } func (r *mutationResolver) GalleryDestroy(ctx context.Context, input models.GalleryDestroyInput) (bool, error) { - qb := models.NewGalleryQueryBuilder() - iqb := models.NewImageQueryBuilder() - tx := database.DB.MustBeginTx(ctx, nil) + galleryIDs, err := utils.StringSliceToIntSlice(input.Ids) + if err != nil { + return false, err + } var galleries []*models.Gallery var imgsToPostProcess []*models.Image var imgsToDelete []*models.Image - for _, id := range input.Ids { - galleryID, _ := strconv.Atoi(id) + if err := r.withTxn(ctx, func(repo models.Repository) error { + qb := repo.Gallery() + iqb := repo.Image() + + for _, id := range galleryIDs { + gallery, err := qb.Find(id) + if err != nil { + return err + } + + if gallery == nil { + return fmt.Errorf("gallery with id %d not found", id) + } - gallery, err := qb.Find(galleryID, tx) - if gallery != nil { galleries = append(galleries, gallery) - } - err = qb.Destroy(galleryID, tx) - if err != nil { - tx.Rollback() - return false, err - } - - // if this is a zip-based gallery, delete the images as well - if gallery.Zip { - imgs, err := iqb.FindByGalleryID(galleryID) - if err != nil { - tx.Rollback() - return false, err - } - - for _, img := range imgs { - err = iqb.Destroy(img.ID, tx) + // if this is a zip-based gallery, delete the images as well first + if gallery.Zip { + imgs, err := iqb.FindByGalleryID(id) if err != nil { - tx.Rollback() - return false, err + return err } - imgsToPostProcess = append(imgsToPostProcess, img) - } - } else if input.DeleteFile != nil && *input.DeleteFile { - // Delete image if it is only attached to this gallery - imgs, err := iqb.FindByGalleryID(galleryID) - if err != nil { - tx.Rollback() - return false, err - } - - for _, img := range imgs { - imgGalleries, err := qb.FindByImageID(img.ID, tx) - if err != nil { - tx.Rollback() - return false, err - } - - if len(imgGalleries) == 0 { - err = iqb.Destroy(img.ID, tx) - if err != nil { - tx.Rollback() - return false, err + for _, img := range imgs { + if err := iqb.Destroy(img.ID); err != nil { + return err } - imgsToDelete = append(imgsToDelete, img) imgsToPostProcess = append(imgsToPostProcess, img) } + } else if input.DeleteFile != nil && *input.DeleteFile { + // Delete image if it is only attached to this gallery + imgs, err := iqb.FindByGalleryID(id) + if err != nil { + return err + } + + for _, img := range imgs { + imgGalleries, err := qb.FindByImageID(img.ID) + if err != nil { + return err + } + + if len(imgGalleries) == 0 { + if err := iqb.Destroy(img.ID); err != nil { + return err + } + + imgsToDelete = append(imgsToDelete, img) + imgsToPostProcess = append(imgsToPostProcess, 
img) + } + } + } + + if err := qb.Destroy(id); err != nil { + return err } } - } - if err := tx.Commit(); err != nil { + return nil + }); err != nil { return false, err } @@ -479,34 +410,39 @@ func (r *mutationResolver) GalleryDestroy(ctx context.Context, input models.Gall } func (r *mutationResolver) AddGalleryImages(ctx context.Context, input models.GalleryAddInput) (bool, error) { - galleryID, _ := strconv.Atoi(input.GalleryID) - qb := models.NewGalleryQueryBuilder() - gallery, err := qb.Find(galleryID, nil) + galleryID, err := strconv.Atoi(input.GalleryID) if err != nil { return false, err } - if gallery == nil { - return false, errors.New("gallery not found") + imageIDs, err := utils.StringSliceToIntSlice(input.ImageIds) + if err != nil { + return false, err } - if gallery.Zip { - return false, errors.New("cannot modify zip gallery images") - } - - jqb := models.NewJoinsQueryBuilder() - tx := database.DB.MustBeginTx(ctx, nil) - - for _, id := range input.ImageIds { - imageID, _ := strconv.Atoi(id) - _, err := jqb.AddImageGallery(imageID, galleryID, tx) + if err := r.withTxn(ctx, func(repo models.Repository) error { + qb := repo.Gallery() + gallery, err := qb.Find(galleryID) if err != nil { - tx.Rollback() - return false, err + return err } - } - if err := tx.Commit(); err != nil { + if gallery == nil { + return errors.New("gallery not found") + } + + if gallery.Zip { + return errors.New("cannot modify zip gallery images") + } + + newIDs, err := qb.GetImageIDs(galleryID) + if err != nil { + return err + } + + newIDs = utils.IntAppendUniques(newIDs, imageIDs) + return qb.UpdateImages(galleryID, newIDs) + }); err != nil { return false, err } @@ -514,34 +450,39 @@ func (r *mutationResolver) AddGalleryImages(ctx context.Context, input models.Ga } func (r *mutationResolver) RemoveGalleryImages(ctx context.Context, input models.GalleryRemoveInput) (bool, error) { - galleryID, _ := strconv.Atoi(input.GalleryID) - qb := models.NewGalleryQueryBuilder() - gallery, err := qb.Find(galleryID, nil) + galleryID, err := strconv.Atoi(input.GalleryID) if err != nil { return false, err } - if gallery == nil { - return false, errors.New("gallery not found") + imageIDs, err := utils.StringSliceToIntSlice(input.ImageIds) + if err != nil { + return false, err } - if gallery.Zip { - return false, errors.New("cannot modify zip gallery images") - } - - jqb := models.NewJoinsQueryBuilder() - tx := database.DB.MustBeginTx(ctx, nil) - - for _, id := range input.ImageIds { - imageID, _ := strconv.Atoi(id) - _, err := jqb.RemoveImageGallery(imageID, galleryID, tx) + if err := r.withTxn(ctx, func(repo models.Repository) error { + qb := repo.Gallery() + gallery, err := qb.Find(galleryID) if err != nil { - tx.Rollback() - return false, err + return err } - } - if err := tx.Commit(); err != nil { + if gallery == nil { + return errors.New("gallery not found") + } + + if gallery.Zip { + return errors.New("cannot modify zip gallery images") + } + + newIDs, err := qb.GetImageIDs(galleryID) + if err != nil { + return err + } + + newIDs = utils.IntExclude(newIDs, imageIDs) + return qb.UpdateImages(galleryID, newIDs) + }); err != nil { return false, err } diff --git a/pkg/api/resolver_mutation_image.go b/pkg/api/resolver_mutation_image.go index 2d2ad4ada..ac5131b1e 100644 --- a/pkg/api/resolver_mutation_image.go +++ b/pkg/api/resolver_mutation_image.go @@ -2,71 +2,63 @@ package api import ( "context" + "fmt" "strconv" "time" - "github.com/jmoiron/sqlx" - - "github.com/stashapp/stash/pkg/database" 
"github.com/stashapp/stash/pkg/manager" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/utils" ) -func (r *mutationResolver) ImageUpdate(ctx context.Context, input models.ImageUpdateInput) (*models.Image, error) { - // Start the transaction and save the image - tx := database.DB.MustBeginTx(ctx, nil) - +func (r *mutationResolver) ImageUpdate(ctx context.Context, input models.ImageUpdateInput) (ret *models.Image, err error) { translator := changesetTranslator{ inputMap: getUpdateInputMap(ctx), } - ret, err := r.imageUpdate(input, translator, tx) - - if err != nil { - _ = tx.Rollback() - return nil, err - } - - // Commit - if err := tx.Commit(); err != nil { + // Start the transaction and save the image + if err := r.withTxn(ctx, func(repo models.Repository) error { + ret, err = r.imageUpdate(input, translator, repo) + return err + }); err != nil { return nil, err } return ret, nil } -func (r *mutationResolver) ImagesUpdate(ctx context.Context, input []*models.ImageUpdateInput) ([]*models.Image, error) { - // Start the transaction and save the image - tx := database.DB.MustBeginTx(ctx, nil) +func (r *mutationResolver) ImagesUpdate(ctx context.Context, input []*models.ImageUpdateInput) (ret []*models.Image, err error) { inputMaps := getUpdateInputMaps(ctx) - var ret []*models.Image + // Start the transaction and save the image + if err := r.withTxn(ctx, func(repo models.Repository) error { + for i, image := range input { + translator := changesetTranslator{ + inputMap: inputMaps[i], + } - for i, image := range input { - translator := changesetTranslator{ - inputMap: inputMaps[i], + thisImage, err := r.imageUpdate(*image, translator, repo) + if err != nil { + return err + } + + ret = append(ret, thisImage) } - thisImage, err := r.imageUpdate(*image, translator, tx) - ret = append(ret, thisImage) - - if err != nil { - _ = tx.Rollback() - return nil, err - } - } - - // Commit - if err := tx.Commit(); err != nil { + return nil + }); err != nil { return nil, err } return ret, nil } -func (r *mutationResolver) imageUpdate(input models.ImageUpdateInput, translator changesetTranslator, tx *sqlx.Tx) (*models.Image, error) { +func (r *mutationResolver) imageUpdate(input models.ImageUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Image, error) { // Populate image from the input - imageID, _ := strconv.Atoi(input.ID) + imageID, err := strconv.Atoi(input.ID) + if err != nil { + return nil, err + } updatedTime := time.Now() updatedImage := models.ImagePartial{ @@ -79,43 +71,28 @@ func (r *mutationResolver) imageUpdate(input models.ImageUpdateInput, translator updatedImage.StudioID = translator.nullInt64FromString(input.StudioID, "studio_id") updatedImage.Organized = input.Organized - qb := models.NewImageQueryBuilder() - jqb := models.NewJoinsQueryBuilder() - image, err := qb.Update(updatedImage, tx) + qb := repo.Image() + image, err := qb.Update(updatedImage) if err != nil { return nil, err } - // don't set the galleries directly. 
Use add/remove gallery images interface instead + if translator.hasField("gallery_ids") { + if err := r.updateImageGalleries(qb, imageID, input.GalleryIds); err != nil { + return nil, err + } + } // Save the performers if translator.hasField("performer_ids") { - var performerJoins []models.PerformersImages - for _, pid := range input.PerformerIds { - performerID, _ := strconv.Atoi(pid) - performerJoin := models.PerformersImages{ - PerformerID: performerID, - ImageID: imageID, - } - performerJoins = append(performerJoins, performerJoin) - } - if err := jqb.UpdatePerformersImages(imageID, performerJoins, tx); err != nil { + if err := r.updateImagePerformers(qb, imageID, input.PerformerIds); err != nil { return nil, err } } // Save the tags if translator.hasField("tag_ids") { - var tagJoins []models.ImagesTags - for _, tid := range input.TagIds { - tagID, _ := strconv.Atoi(tid) - tagJoin := models.ImagesTags{ - ImageID: imageID, - TagID: tagID, - } - tagJoins = append(tagJoins, tagJoin) - } - if err := jqb.UpdateImagesTags(imageID, tagJoins, tx); err != nil { + if err := r.updateImageTags(qb, imageID, input.TagIds); err != nil { return nil, err } } @@ -123,15 +100,39 @@ func (r *mutationResolver) imageUpdate(input models.ImageUpdateInput, translator return image, nil } -func (r *mutationResolver) BulkImageUpdate(ctx context.Context, input models.BulkImageUpdateInput) ([]*models.Image, error) { +func (r *mutationResolver) updateImageGalleries(qb models.ImageReaderWriter, imageID int, galleryIDs []string) error { + ids, err := utils.StringSliceToIntSlice(galleryIDs) + if err != nil { + return err + } + return qb.UpdateGalleries(imageID, ids) +} + +func (r *mutationResolver) updateImagePerformers(qb models.ImageReaderWriter, imageID int, performerIDs []string) error { + ids, err := utils.StringSliceToIntSlice(performerIDs) + if err != nil { + return err + } + return qb.UpdatePerformers(imageID, ids) +} + +func (r *mutationResolver) updateImageTags(qb models.ImageReaderWriter, imageID int, tagsIDs []string) error { + ids, err := utils.StringSliceToIntSlice(tagsIDs) + if err != nil { + return err + } + return qb.UpdateTags(imageID, ids) +} + +func (r *mutationResolver) BulkImageUpdate(ctx context.Context, input models.BulkImageUpdateInput) (ret []*models.Image, err error) { + imageIDs, err := utils.StringSliceToIntSlice(input.Ids) + if err != nil { + return nil, err + } + // Populate image from the input updatedTime := time.Now() - // Start the transaction and save the image marker - tx := database.DB.MustBeginTx(ctx, nil) - qb := models.NewImageQueryBuilder() - jqb := models.NewJoinsQueryBuilder() - updatedImage := models.ImagePartial{ UpdatedAt: &models.SQLiteTimestamp{Timestamp: updatedTime}, } @@ -145,168 +146,113 @@ func (r *mutationResolver) BulkImageUpdate(ctx context.Context, input models.Bul updatedImage.StudioID = translator.nullInt64FromString(input.StudioID, "studio_id") updatedImage.Organized = input.Organized - ret := []*models.Image{} + // Start the transaction and save the image marker + if err := r.withTxn(ctx, func(repo models.Repository) error { + qb := repo.Image() - for _, imageIDStr := range input.Ids { - imageID, _ := strconv.Atoi(imageIDStr) - updatedImage.ID = imageID + for _, imageID := range imageIDs { + updatedImage.ID = imageID - image, err := qb.Update(updatedImage, tx) - if err != nil { - _ = tx.Rollback() - return nil, err - } - - ret = append(ret, image) - - // Save the galleries - if translator.hasField("gallery_ids") { - galleryIDs, err := 
adjustImageGalleryIDs(tx, imageID, *input.GalleryIds)
+		image, err := qb.Update(updatedImage)
 		if err != nil {
-			_ = tx.Rollback()
-			return nil, err
+			return err
 		}
-		var galleryJoins []models.GalleriesImages
-		for _, gid := range galleryIDs {
-			galleryJoin := models.GalleriesImages{
-				GalleryID: gid,
-				ImageID: imageID,
+		ret = append(ret, image)
+
+		// Save the galleries
+		if translator.hasField("gallery_ids") {
+			galleryIDs, err := adjustImageGalleryIDs(qb, imageID, *input.GalleryIds)
+			if err != nil {
+				return err
+			}
+
+			if err := qb.UpdateGalleries(imageID, galleryIDs); err != nil {
+				return err
 			}
-			galleryJoins = append(galleryJoins, galleryJoin)
 		}
-		if err := jqb.UpdateGalleriesImages(imageID, galleryJoins, tx); err != nil {
-			return nil, err
+
+		// Save the performers
+		if translator.hasField("performer_ids") {
+			performerIDs, err := adjustImagePerformerIDs(qb, imageID, *input.PerformerIds)
+			if err != nil {
+				return err
+			}
+
+			if err := qb.UpdatePerformers(imageID, performerIDs); err != nil {
+				return err
+			}
+		}
+
+		// Save the tags
+		if translator.hasField("tag_ids") {
+			tagIDs, err := adjustImageTagIDs(qb, imageID, *input.TagIds)
+			if err != nil {
+				return err
+			}
+
+			if err := qb.UpdateTags(imageID, tagIDs); err != nil {
+				return err
+			}
 		}
 	}
-		// Save the performers
-		if translator.hasField("performer_ids") {
-			performerIDs, err := adjustImagePerformerIDs(tx, imageID, *input.PerformerIds)
-			if err != nil {
-				_ = tx.Rollback()
-				return nil, err
-			}
-
-			var performerJoins []models.PerformersImages
-			for _, performerID := range performerIDs {
-				performerJoin := models.PerformersImages{
-					PerformerID: performerID,
-					ImageID: imageID,
-				}
-				performerJoins = append(performerJoins, performerJoin)
-			}
-			if err := jqb.UpdatePerformersImages(imageID, performerJoins, tx); err != nil {
-				_ = tx.Rollback()
-				return nil, err
-			}
-		}
-
-		// Save the tags
-		if translator.hasField("tag_ids") {
-			tagIDs, err := adjustImageTagIDs(tx, imageID, *input.TagIds)
-			if err != nil {
-				_ = tx.Rollback()
-				return nil, err
-			}
-
-			var tagJoins []models.ImagesTags
-			for _, tagID := range tagIDs {
-				tagJoin := models.ImagesTags{
-					ImageID: imageID,
-					TagID: tagID,
-				}
-				tagJoins = append(tagJoins, tagJoin)
-			}
-			if err := jqb.UpdateImagesTags(imageID, tagJoins, tx); err != nil {
-				_ = tx.Rollback()
-				return nil, err
-			}
-		}
-	}
-
-	// Commit
-	if err := tx.Commit(); err != nil {
+		return nil
+	}); err != nil {
 		return nil, err
 	}
 	return ret, nil
 }
-func adjustImageGalleryIDs(tx *sqlx.Tx, imageID int, ids models.BulkUpdateIds) ([]int, error) {
-	var ret []int
-
-	jqb := models.NewJoinsQueryBuilder()
-	if ids.Mode == models.BulkUpdateIDModeAdd || ids.Mode == models.BulkUpdateIDModeRemove {
-		// adding to the joins
-		galleryJoins, err := jqb.GetImageGalleries(imageID, tx)
-
-		if err != nil {
-			return nil, err
-		}
-
-		for _, join := range galleryJoins {
-			ret = append(ret, join.GalleryID)
-		}
-	}
-
-	return adjustIDs(ret, ids), nil
-}
-
-func adjustImagePerformerIDs(tx *sqlx.Tx, imageID int, ids models.BulkUpdateIds) ([]int, error) {
-	var ret []int
-
-	jqb := models.NewJoinsQueryBuilder()
-	if ids.Mode == models.BulkUpdateIDModeAdd || ids.Mode == models.BulkUpdateIDModeRemove {
-		// adding to the joins
-		performerJoins, err := jqb.GetImagePerformers(imageID, tx)
-
-		if err != nil {
-			return nil, err
-		}
-
-		for _, join := range performerJoins {
-			ret = append(ret, join.PerformerID)
-		}
-	}
-
-	return adjustIDs(ret, ids), nil
-}
-
-func adjustImageTagIDs(tx *sqlx.Tx, imageID int, ids models.BulkUpdateIds) ([]int, error) {
-	var ret []int
-
-	jqb := models.NewJoinsQueryBuilder()
-	if ids.Mode == models.BulkUpdateIDModeAdd || ids.Mode == models.BulkUpdateIDModeRemove {
-		// adding to the joins
-		tagJoins, err := jqb.GetImageTags(imageID, tx)
-
-		if err != nil {
-			return nil, err
-		}
-
-		for _, join := range tagJoins {
-			ret = append(ret, join.TagID)
-		}
-	}
-
-	return adjustIDs(ret, ids), nil
-}
-
-func (r *mutationResolver) ImageDestroy(ctx context.Context, input models.ImageDestroyInput) (bool, error) {
-	qb := models.NewImageQueryBuilder()
-	tx := database.DB.MustBeginTx(ctx, nil)
-
-	imageID, _ := strconv.Atoi(input.ID)
-	image, err := qb.Find(imageID)
-	err = qb.Destroy(imageID, tx)
-
+func adjustImageGalleryIDs(qb models.ImageReader, imageID int, ids models.BulkUpdateIds) (ret []int, err error) {
+	ret, err = qb.GetGalleryIDs(imageID)
+	if err != nil {
+		return nil, err
+	}
+
+	return adjustIDs(ret, ids), nil
+}
+
+func adjustImagePerformerIDs(qb models.ImageReader, imageID int, ids models.BulkUpdateIds) (ret []int, err error) {
+	ret, err = qb.GetPerformerIDs(imageID)
+	if err != nil {
+		return nil, err
+	}
+
+	return adjustIDs(ret, ids), nil
+}
+
+func adjustImageTagIDs(qb models.ImageReader, imageID int, ids models.BulkUpdateIds) (ret []int, err error) {
+	ret, err = qb.GetTagIDs(imageID)
+	if err != nil {
+		return nil, err
+	}
+
+	return adjustIDs(ret, ids), nil
+}
+
+func (r *mutationResolver) ImageDestroy(ctx context.Context, input models.ImageDestroyInput) (ret bool, err error) {
+	imageID, err := strconv.Atoi(input.ID)
 	if err != nil {
-		tx.Rollback()
 		return false, err
 	}
-	if err := tx.Commit(); err != nil {
+	var image *models.Image
+	if err := r.withTxn(ctx, func(repo models.Repository) error {
+		qb := repo.Image()
+
+		image, err = qb.Find(imageID)
+		if err != nil {
+			return err
+		}
+
+		if image == nil {
+			return fmt.Errorf("image with id %d not found", imageID)
+		}
+
+		return qb.Destroy(imageID)
+	}); err != nil {
 		return false, err
 	}
@@ -325,27 +271,35 @@ func (r *mutationResolver) ImageDestroy(ctx context.Context, input models.ImageD
 	return true, nil
 }
-func (r *mutationResolver) ImagesDestroy(ctx context.Context, input models.ImagesDestroyInput) (bool, error) {
-	qb := models.NewImageQueryBuilder()
-	tx := database.DB.MustBeginTx(ctx, nil)
-
-	var images []*models.Image
-	for _, id := range input.Ids {
-		imageID, _ := strconv.Atoi(id)
-
-		image, err := qb.Find(imageID)
-		if image != nil {
-			images = append(images, image)
-		}
-		err = qb.Destroy(imageID, tx)
-
-		if err != nil {
-			tx.Rollback()
-			return false, err
-		}
+func (r *mutationResolver) ImagesDestroy(ctx context.Context, input models.ImagesDestroyInput) (ret bool, err error) {
+	imageIDs, err := utils.StringSliceToIntSlice(input.Ids)
+	if err != nil {
+		return false, err
 	}
-	if err := tx.Commit(); err != nil {
+	var images []*models.Image
+	if err := r.withTxn(ctx, func(repo models.Repository) error {
+		qb := repo.Image()
+
+		for _, imageID := range imageIDs {
+
+			image, err := qb.Find(imageID)
+			if err != nil {
+				return err
+			}
+
+			if image == nil {
+				return fmt.Errorf("image with id %d not found", imageID)
+			}
+
+			images = append(images, image)
+			if err := qb.Destroy(imageID); err != nil {
+				return err
+			}
+		}
+
+		return nil
+	}); err != nil {
 		return false, err
 	}
@@ -366,62 +320,56 @@ func (r *mutationResolver) ImagesDestroy(ctx context.Context, input models.Image
 	return true, nil
 }
-func (r *mutationResolver) ImageIncrementO(ctx context.Context, id string) (int, error) {
-	imageID, _ := strconv.Atoi(id)
-
-	tx := database.DB.MustBeginTx(ctx, nil)
-	qb := models.NewImageQueryBuilder()
-
-	newVal, err := qb.IncrementOCounter(imageID, tx)
+func (r *mutationResolver) ImageIncrementO(ctx context.Context, id string) (ret int, err error) {
+	imageID, err := strconv.Atoi(id)
 	if err != nil {
-		_ = tx.Rollback()
 		return 0, err
 	}
-	// Commit
-	if err := tx.Commit(); err != nil {
+	if err := r.withTxn(ctx, func(repo models.Repository) error {
+		qb := repo.Image()
+
+		ret, err = qb.IncrementOCounter(imageID)
+		return err
+	}); err != nil {
 		return 0, err
 	}
-	return newVal, nil
+	return ret, nil
 }
-func (r *mutationResolver) ImageDecrementO(ctx context.Context, id string) (int, error) {
-	imageID, _ := strconv.Atoi(id)
-
-	tx := database.DB.MustBeginTx(ctx, nil)
-	qb := models.NewImageQueryBuilder()
-
-	newVal, err := qb.DecrementOCounter(imageID, tx)
+func (r *mutationResolver) ImageDecrementO(ctx context.Context, id string) (ret int, err error) {
+	imageID, err := strconv.Atoi(id)
 	if err != nil {
-		_ = tx.Rollback()
 		return 0, err
 	}
-	// Commit
-	if err := tx.Commit(); err != nil {
+	if err := r.withTxn(ctx, func(repo models.Repository) error {
+		qb := repo.Image()
+
+		ret, err = qb.DecrementOCounter(imageID)
+		return err
+	}); err != nil {
 		return 0, err
 	}
-	return newVal, nil
+	return ret, nil
 }
-func (r *mutationResolver) ImageResetO(ctx context.Context, id string) (int, error) {
-	imageID, _ := strconv.Atoi(id)
-
-	tx := database.DB.MustBeginTx(ctx, nil)
-	qb := models.NewImageQueryBuilder()
-
-	newVal, err := qb.ResetOCounter(imageID, tx)
+func (r *mutationResolver) ImageResetO(ctx context.Context, id string) (ret int, err error) {
	imageID, err := strconv.Atoi(id)
 	if err != nil {
-		_ = tx.Rollback()
 		return 0, err
 	}
-	// Commit
-	if err := tx.Commit(); err != nil {
+	if err := r.withTxn(ctx, func(repo models.Repository) error {
+		qb := repo.Image()
+
+		ret, err = qb.ResetOCounter(imageID)
+		return err
+	}); err != nil {
 		return 0, err
 	}
-	return newVal, nil
+	return ret, nil
 }
diff --git a/pkg/api/resolver_mutation_movie.go b/pkg/api/resolver_mutation_movie.go
index 23756054e..320ed49da 100644
--- a/pkg/api/resolver_mutation_movie.go
+++ b/pkg/api/resolver_mutation_movie.go
@@ -6,7 +6,6 @@ import (
 	"strconv"
 	"time"
 
-	"github.com/stashapp/stash/pkg/database"
 	"github.com/stashapp/stash/pkg/models"
 	"github.com/stashapp/stash/pkg/utils"
 )
@@ -85,24 +84,23 @@ func (r *mutationResolver) MovieCreate(ctx context.Context, input models.MovieCr
 	}
 	// Start the transaction and save the movie
-	tx := database.DB.MustBeginTx(ctx, nil)
-	qb := models.NewMovieQueryBuilder()
-	movie, err := qb.Create(newMovie, tx)
-	if err != nil {
-		_ = tx.Rollback()
-		return nil, err
-	}
-
-	// update image table
-	if len(frontimageData) > 0 {
-		if err := qb.UpdateMovieImages(movie.ID, frontimageData, backimageData, tx); err != nil {
-			_ = tx.Rollback()
-			return nil, err
+	var movie *models.Movie
+	if err := r.withTxn(ctx, func(repo models.Repository) error {
+		qb := repo.Movie()
+		movie, err = qb.Create(newMovie)
+		if err != nil {
+			return err
 		}
-	}
-	// Commit
-	if err := tx.Commit(); err != nil {
+		// update image table
+		if len(frontimageData) > 0 {
+			if err := qb.UpdateImages(movie.ID, frontimageData, backimageData); err != nil {
+				return err
+			}
+		}
+
+		return nil
+	}); err != nil {
 		return nil, err
 	}
@@ -111,7 +109,10 @@ func (r *mutationResolver) MovieCreate(ctx context.Context, input models.MovieCr
 func (r *mutationResolver) MovieUpdate(ctx context.Context, input models.MovieUpdateInput) (*models.Movie, error) {
 	// Populate movie from the input
-	movieID, _ := strconv.Atoi(input.ID)
+	movieID, err := strconv.Atoi(input.ID)
+	if err != nil {
+		return nil, err
+	}
 	updatedMovie := models.MoviePartial{
 		ID: movieID,
@@ -123,7 +124,6 @@ func (r *mutationResolver) MovieUpdate(ctx context.Context, input models.MovieUp
 	}
 	var frontimageData []byte
-	var err error
 	frontImageIncluded := translator.hasField("front_image")
 	if input.FrontImage != nil {
 		_, frontimageData, err = utils.ProcessBase64Image(*input.FrontImage)
@@ -157,53 +157,49 @@ func (r *mutationResolver) MovieUpdate(ctx context.Context, input models.MovieUp
 	updatedMovie.URL = translator.nullString(input.URL, "url")
 	// Start the transaction and save the movie
-	tx := database.DB.MustBeginTx(ctx, nil)
-	qb := models.NewMovieQueryBuilder()
-	movie, err := qb.Update(updatedMovie, tx)
-	if err != nil {
-		_ = tx.Rollback()
-		return nil, err
-	}
-
-	// update image table
-	if frontImageIncluded || backImageIncluded {
-		if !frontImageIncluded {
-			frontimageData, err = qb.GetFrontImage(updatedMovie.ID, tx)
-			if err != nil {
-				tx.Rollback()
-				return nil, err
-			}
+	var movie *models.Movie
+	if err := r.withTxn(ctx, func(repo models.Repository) error {
+		qb := repo.Movie()
+		movie, err = qb.Update(updatedMovie)
+		if err != nil {
+			return err
 		}
-		if !backImageIncluded {
-			backimageData, err = qb.GetBackImage(updatedMovie.ID, tx)
-			if err != nil {
-				tx.Rollback()
-				return nil, err
+
+		// update image table
+		if frontImageIncluded || backImageIncluded {
+			if !frontImageIncluded {
+				frontimageData, err = qb.GetFrontImage(updatedMovie.ID)
+				if err != nil {
+					return err
+				}
+			}
+			if !backImageIncluded {
+				backimageData, err = qb.GetBackImage(updatedMovie.ID)
+				if err != nil {
+					return err
+				}
+			}
+
+			if len(frontimageData) == 0 && len(backimageData) == 0 {
+				// both images are being nulled. Destroy them.
+				if err := qb.DestroyImages(movie.ID); err != nil {
+					return err
+				}
+			} else {
+				// HACK - if front image is null and back image is not null, then set the front image
+				// to the default image since we can't have a null front image and a non-null back image
+				if frontimageData == nil && backimageData != nil {
+					_, frontimageData, _ = utils.ProcessBase64Image(models.DefaultMovieImage)
+				}
+
+				if err := qb.UpdateImages(movie.ID, frontimageData, backimageData); err != nil {
+					return err
+				}
 			}
 		}
-		if len(frontimageData) == 0 && len(backimageData) == 0 {
-			// both images are being nulled. Destroy them.
- if err := qb.DestroyMovieImages(movie.ID, tx); err != nil { - tx.Rollback() - return nil, err - } - } else { - // HACK - if front image is null and back image is not null, then set the front image - // to the default image since we can't have a null front image and a non-null back image - if frontimageData == nil && backimageData != nil { - _, frontimageData, _ = utils.ProcessBase64Image(models.DefaultMovieImage) - } - - if err := qb.UpdateMovieImages(movie.ID, frontimageData, backimageData, tx); err != nil { - _ = tx.Rollback() - return nil, err - } - } - } - - // Commit - if err := tx.Commit(); err != nil { + return nil + }); err != nil { return nil, err } @@ -211,28 +207,35 @@ func (r *mutationResolver) MovieUpdate(ctx context.Context, input models.MovieUp } func (r *mutationResolver) MovieDestroy(ctx context.Context, input models.MovieDestroyInput) (bool, error) { - qb := models.NewMovieQueryBuilder() - tx := database.DB.MustBeginTx(ctx, nil) - if err := qb.Destroy(input.ID, tx); err != nil { - _ = tx.Rollback() + id, err := strconv.Atoi(input.ID) + if err != nil { return false, err } - if err := tx.Commit(); err != nil { + + if err := r.withTxn(ctx, func(repo models.Repository) error { + return repo.Movie().Destroy(id) + }); err != nil { return false, err } return true, nil } -func (r *mutationResolver) MoviesDestroy(ctx context.Context, ids []string) (bool, error) { - qb := models.NewMovieQueryBuilder() - tx := database.DB.MustBeginTx(ctx, nil) - for _, id := range ids { - if err := qb.Destroy(id, tx); err != nil { - _ = tx.Rollback() - return false, err - } +func (r *mutationResolver) MoviesDestroy(ctx context.Context, movieIDs []string) (bool, error) { + ids, err := utils.StringSliceToIntSlice(movieIDs) + if err != nil { + return false, err } - if err := tx.Commit(); err != nil { + + if err := r.withTxn(ctx, func(repo models.Repository) error { + qb := repo.Movie() + for _, id := range ids { + if err := qb.Destroy(id); err != nil { + return err + } + } + + return nil + }); err != nil { return false, err } return true, nil diff --git a/pkg/api/resolver_mutation_performer.go b/pkg/api/resolver_mutation_performer.go index d249891cf..790ce4cda 100644 --- a/pkg/api/resolver_mutation_performer.go +++ b/pkg/api/resolver_mutation_performer.go @@ -6,7 +6,6 @@ import ( "strconv" "time" - "github.com/stashapp/stash/pkg/database" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/utils" ) @@ -86,41 +85,32 @@ func (r *mutationResolver) PerformerCreate(ctx context.Context, input models.Per } // Start the transaction and save the performer - tx := database.DB.MustBeginTx(ctx, nil) - qb := models.NewPerformerQueryBuilder() - jqb := models.NewJoinsQueryBuilder() + var performer *models.Performer + if err := r.withTxn(ctx, func(repo models.Repository) error { + qb := repo.Performer() - performer, err := qb.Create(newPerformer, tx) - if err != nil { - _ = tx.Rollback() - return nil, err - } - - // update image table - if len(imageData) > 0 { - if err := qb.UpdatePerformerImage(performer.ID, imageData, tx); err != nil { - _ = tx.Rollback() - return nil, err + performer, err = qb.Create(newPerformer) + if err != nil { + return err } - } - // Save the stash_ids - if input.StashIds != nil { - var stashIDJoins []models.StashID - for _, stashID := range input.StashIds { - newJoin := models.StashID{ - StashID: stashID.StashID, - Endpoint: stashID.Endpoint, + // update image table + if len(imageData) > 0 { + if err := qb.UpdateImage(performer.ID, imageData); err != nil { + return err } - 
stashIDJoins = append(stashIDJoins, newJoin) } - if err := jqb.UpdatePerformerStashIDs(performer.ID, stashIDJoins, tx); err != nil { - return nil, err - } - } - // Commit - if err := tx.Commit(); err != nil { + // Save the stash_ids + if input.StashIds != nil { + stashIDJoins := models.StashIDsFromInput(input.StashIds) + if err := qb.UpdateStashIDs(performer.ID, stashIDJoins); err != nil { + return err + } + } + + return nil + }); err != nil { return nil, err } @@ -183,47 +173,38 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input models.Per updatedPerformer.Favorite = translator.nullBool(input.Favorite, "favorite") // Start the transaction and save the performer - tx := database.DB.MustBeginTx(ctx, nil) - qb := models.NewPerformerQueryBuilder() - jqb := models.NewJoinsQueryBuilder() + var performer *models.Performer + if err := r.withTxn(ctx, func(repo models.Repository) error { + qb := repo.Performer() - performer, err := qb.Update(updatedPerformer, tx) - if err != nil { - tx.Rollback() - return nil, err - } - - // update image table - if len(imageData) > 0 { - if err := qb.UpdatePerformerImage(performer.ID, imageData, tx); err != nil { - tx.Rollback() - return nil, err + var err error + performer, err = qb.Update(updatedPerformer) + if err != nil { + return err } - } else if imageIncluded { - // must be unsetting - if err := qb.DestroyPerformerImage(performer.ID, tx); err != nil { - tx.Rollback() - return nil, err - } - } - // Save the stash_ids - if translator.hasField("stash_ids") { - var stashIDJoins []models.StashID - for _, stashID := range input.StashIds { - newJoin := models.StashID{ - StashID: stashID.StashID, - Endpoint: stashID.Endpoint, + // update image table + if len(imageData) > 0 { + if err := qb.UpdateImage(performer.ID, imageData); err != nil { + return err + } + } else if imageIncluded { + // must be unsetting + if err := qb.DestroyImage(performer.ID); err != nil { + return err } - stashIDJoins = append(stashIDJoins, newJoin) } - if err := jqb.UpdatePerformerStashIDs(performerID, stashIDJoins, tx); err != nil { - return nil, err - } - } - // Commit - if err := tx.Commit(); err != nil { + // Save the stash_ids + if translator.hasField("stash_ids") { + stashIDJoins := models.StashIDsFromInput(input.StashIds) + if err := qb.UpdateStashIDs(performerID, stashIDJoins); err != nil { + return err + } + } + + return nil + }); err != nil { return nil, err } @@ -231,28 +212,35 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input models.Per } func (r *mutationResolver) PerformerDestroy(ctx context.Context, input models.PerformerDestroyInput) (bool, error) { - qb := models.NewPerformerQueryBuilder() - tx := database.DB.MustBeginTx(ctx, nil) - if err := qb.Destroy(input.ID, tx); err != nil { - _ = tx.Rollback() + id, err := strconv.Atoi(input.ID) + if err != nil { return false, err } - if err := tx.Commit(); err != nil { + + if err := r.withTxn(ctx, func(repo models.Repository) error { + return repo.Performer().Destroy(id) + }); err != nil { return false, err } return true, nil } -func (r *mutationResolver) PerformersDestroy(ctx context.Context, ids []string) (bool, error) { - qb := models.NewPerformerQueryBuilder() - tx := database.DB.MustBeginTx(ctx, nil) - for _, id := range ids { - if err := qb.Destroy(id, tx); err != nil { - _ = tx.Rollback() - return false, err - } +func (r *mutationResolver) PerformersDestroy(ctx context.Context, performerIDs []string) (bool, error) { + ids, err := utils.StringSliceToIntSlice(performerIDs) + if err != 
nil { + return false, err } - if err := tx.Commit(); err != nil { + + if err := r.withTxn(ctx, func(repo models.Repository) error { + qb := repo.Performer() + for _, id := range ids { + if err := qb.Destroy(id); err != nil { + return err + } + } + + return nil + }); err != nil { return false, err } return true, nil diff --git a/pkg/api/resolver_mutation_scene.go b/pkg/api/resolver_mutation_scene.go index aa0ff47a5..02b3920d1 100644 --- a/pkg/api/resolver_mutation_scene.go +++ b/pkg/api/resolver_mutation_scene.go @@ -3,73 +3,64 @@ package api import ( "context" "database/sql" + "fmt" "strconv" "time" - "github.com/jmoiron/sqlx" - - "github.com/stashapp/stash/pkg/database" "github.com/stashapp/stash/pkg/manager" "github.com/stashapp/stash/pkg/manager/config" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/utils" ) -func (r *mutationResolver) SceneUpdate(ctx context.Context, input models.SceneUpdateInput) (*models.Scene, error) { - // Start the transaction and save the scene - tx := database.DB.MustBeginTx(ctx, nil) - +func (r *mutationResolver) SceneUpdate(ctx context.Context, input models.SceneUpdateInput) (ret *models.Scene, err error) { translator := changesetTranslator{ inputMap: getUpdateInputMap(ctx), } - ret, err := r.sceneUpdate(input, translator, tx) - if err != nil { - _ = tx.Rollback() - return nil, err - } - - // Commit - if err := tx.Commit(); err != nil { + // Start the transaction and save the scene + if err := r.withTxn(ctx, func(repo models.Repository) error { + ret, err = r.sceneUpdate(input, translator, repo) + return err + }); err != nil { return nil, err } return ret, nil } -func (r *mutationResolver) ScenesUpdate(ctx context.Context, input []*models.SceneUpdateInput) ([]*models.Scene, error) { - // Start the transaction and save the scene - tx := database.DB.MustBeginTx(ctx, nil) - - var ret []*models.Scene - +func (r *mutationResolver) ScenesUpdate(ctx context.Context, input []*models.SceneUpdateInput) (ret []*models.Scene, err error) { inputMaps := getUpdateInputMaps(ctx) - for i, scene := range input { - translator := changesetTranslator{ - inputMap: inputMaps[i], + // Start the transaction and save the scene + if err := r.withTxn(ctx, func(repo models.Repository) error { + for i, scene := range input { + translator := changesetTranslator{ + inputMap: inputMaps[i], + } + + thisScene, err := r.sceneUpdate(*scene, translator, repo) + ret = append(ret, thisScene) + + if err != nil { + return err + } } - thisScene, err := r.sceneUpdate(*scene, translator, tx) - ret = append(ret, thisScene) - - if err != nil { - _ = tx.Rollback() - return nil, err - } - } - - // Commit - if err := tx.Commit(); err != nil { + return nil + }); err != nil { return nil, err } return ret, nil } -func (r *mutationResolver) sceneUpdate(input models.SceneUpdateInput, translator changesetTranslator, tx *sqlx.Tx) (*models.Scene, error) { +func (r *mutationResolver) sceneUpdate(input models.SceneUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Scene, error) { // Populate scene from the input - sceneID, _ := strconv.Atoi(input.ID) + sceneID, err := strconv.Atoi(input.ID) + if err != nil { + return nil, err + } var coverImageData []byte @@ -97,24 +88,23 @@ func (r *mutationResolver) sceneUpdate(input models.SceneUpdateInput, translator // update the cover after updating the scene } - qb := models.NewSceneQueryBuilder() - jqb := models.NewJoinsQueryBuilder() - scene, err := qb.Update(updatedScene, tx) + qb := repo.Scene() + scene, err := 
qb.Update(updatedScene) if err != nil { return nil, err } // update cover table if len(coverImageData) > 0 { - if err := qb.UpdateSceneCover(sceneID, coverImageData, tx); err != nil { + if err := qb.UpdateCover(sceneID, coverImageData); err != nil { return nil, err } } // Clear the existing gallery value if translator.hasField("gallery_id") { - gqb := models.NewGalleryQueryBuilder() - err = gqb.ClearGalleryId(sceneID, tx) + gqb := repo.Gallery() + err = gqb.ClearGalleryId(sceneID) if err != nil { return nil, err } @@ -122,13 +112,12 @@ func (r *mutationResolver) sceneUpdate(input models.SceneUpdateInput, translator if input.GalleryID != nil { // Save the gallery galleryID, _ := strconv.Atoi(*input.GalleryID) - updatedGallery := models.Gallery{ + updatedGallery := models.GalleryPartial{ ID: galleryID, - SceneID: sql.NullInt64{Int64: int64(sceneID), Valid: true}, - UpdatedAt: models.SQLiteTimestamp{Timestamp: updatedTime}, + SceneID: &sql.NullInt64{Int64: int64(sceneID), Valid: true}, + UpdatedAt: &models.SQLiteTimestamp{Timestamp: updatedTime}, } - gqb := models.NewGalleryQueryBuilder() - _, err := gqb.Update(updatedGallery, tx) + _, err := gqb.UpdatePartial(updatedGallery) if err != nil { return nil, err } @@ -137,59 +126,29 @@ func (r *mutationResolver) sceneUpdate(input models.SceneUpdateInput, translator // Save the performers if translator.hasField("performer_ids") { - var performerJoins []models.PerformersScenes - for _, pid := range input.PerformerIds { - performerID, _ := strconv.Atoi(pid) - performerJoin := models.PerformersScenes{ - PerformerID: performerID, - SceneID: sceneID, - } - performerJoins = append(performerJoins, performerJoin) - } - if err := jqb.UpdatePerformersScenes(sceneID, performerJoins, tx); err != nil { + if err := r.updateScenePerformers(qb, sceneID, input.PerformerIds); err != nil { return nil, err } } // Save the movies if translator.hasField("movies") { - var movieJoins []models.MoviesScenes - - for _, movie := range input.Movies { - - movieID, _ := strconv.Atoi(movie.MovieID) - - movieJoin := models.MoviesScenes{ - MovieID: movieID, - SceneID: sceneID, - } - - if movie.SceneIndex != nil { - movieJoin.SceneIndex = sql.NullInt64{ - Int64: int64(*movie.SceneIndex), - Valid: true, - } - } - - movieJoins = append(movieJoins, movieJoin) - } - if err := jqb.UpdateMoviesScenes(sceneID, movieJoins, tx); err != nil { + if err := r.updateSceneMovies(qb, sceneID, input.Movies); err != nil { return nil, err } } // Save the tags if translator.hasField("tag_ids") { - var tagJoins []models.ScenesTags - for _, tid := range input.TagIds { - tagID, _ := strconv.Atoi(tid) - tagJoin := models.ScenesTags{ - SceneID: sceneID, - TagID: tagID, - } - tagJoins = append(tagJoins, tagJoin) + if err := r.updateSceneTags(qb, sceneID, input.TagIds); err != nil { + return nil, err } - if err := jqb.UpdateScenesTags(sceneID, tagJoins, tx); err != nil { + } + + // Save the stash_ids + if translator.hasField("stash_ids") { + stashIDJoins := models.StashIDsFromInput(input.StashIds) + if err := qb.UpdateStashIDs(sceneID, stashIDJoins); err != nil { return nil, err } } @@ -202,25 +161,57 @@ func (r *mutationResolver) sceneUpdate(input models.SceneUpdateInput, translator } } - // Save the stash_ids - if translator.hasField("stash_ids") { - var stashIDJoins []models.StashID - for _, stashID := range input.StashIds { - newJoin := models.StashID{ - StashID: stashID.StashID, - Endpoint: stashID.Endpoint, - } - stashIDJoins = append(stashIDJoins, newJoin) - } - if err := 
jqb.UpdateSceneStashIDs(sceneID, stashIDJoins, tx); err != nil { - return nil, err - } - } - return scene, nil } +func (r *mutationResolver) updateScenePerformers(qb models.SceneReaderWriter, sceneID int, performerIDs []string) error { + ids, err := utils.StringSliceToIntSlice(performerIDs) + if err != nil { + return err + } + return qb.UpdatePerformers(sceneID, ids) +} + +func (r *mutationResolver) updateSceneMovies(qb models.SceneReaderWriter, sceneID int, movies []*models.SceneMovieInput) error { + var movieJoins []models.MoviesScenes + + for _, movie := range movies { + movieID, err := strconv.Atoi(movie.MovieID) + if err != nil { + return err + } + + movieJoin := models.MoviesScenes{ + MovieID: movieID, + } + + if movie.SceneIndex != nil { + movieJoin.SceneIndex = sql.NullInt64{ + Int64: int64(*movie.SceneIndex), + Valid: true, + } + } + + movieJoins = append(movieJoins, movieJoin) + } + + return qb.UpdateMovies(sceneID, movieJoins) +} + +func (r *mutationResolver) updateSceneTags(qb models.SceneReaderWriter, sceneID int, tagsIDs []string) error { + ids, err := utils.StringSliceToIntSlice(tagsIDs) + if err != nil { + return err + } + return qb.UpdateTags(sceneID, ids) +} + func (r *mutationResolver) BulkSceneUpdate(ctx context.Context, input models.BulkSceneUpdateInput) ([]*models.Scene, error) { + sceneIDs, err := utils.StringSliceToIntSlice(input.Ids) + if err != nil { + return nil, err + } + // Populate scene from the input updatedTime := time.Now() @@ -228,11 +219,6 @@ func (r *mutationResolver) BulkSceneUpdate(ctx context.Context, input models.Bul inputMap: getUpdateInputMap(ctx), } - // Start the transaction and save the scene marker - tx := database.DB.MustBeginTx(ctx, nil) - qb := models.NewSceneQueryBuilder() - jqb := models.NewJoinsQueryBuilder() - updatedScene := models.ScenePartial{ UpdatedAt: &models.SQLiteTimestamp{Timestamp: updatedTime}, } @@ -247,84 +233,62 @@ func (r *mutationResolver) BulkSceneUpdate(ctx context.Context, input models.Bul ret := []*models.Scene{} - for _, sceneIDStr := range input.Ids { - sceneID, _ := strconv.Atoi(sceneIDStr) - updatedScene.ID = sceneID + // Start the transaction and save the scene marker + if err := r.withTxn(ctx, func(repo models.Repository) error { + qb := repo.Scene() + gqb := repo.Gallery() - scene, err := qb.Update(updatedScene, tx) - if err != nil { - _ = tx.Rollback() - return nil, err - } + for _, sceneID := range sceneIDs { + updatedScene.ID = sceneID - ret = append(ret, scene) - - if translator.hasField("gallery_id") { - // Save the gallery - var galleryID int - if input.GalleryID != nil { - galleryID, _ = strconv.Atoi(*input.GalleryID) - } - updatedGallery := models.Gallery{ - ID: galleryID, - SceneID: sql.NullInt64{Int64: int64(sceneID), Valid: true}, - UpdatedAt: models.SQLiteTimestamp{Timestamp: updatedTime}, - } - gqb := models.NewGalleryQueryBuilder() - _, err := gqb.Update(updatedGallery, tx) + scene, err := qb.Update(updatedScene) if err != nil { - _ = tx.Rollback() - return nil, err - } - } - - // Save the performers - if translator.hasField("performer_ids") { - performerIDs, err := adjustScenePerformerIDs(tx, sceneID, *input.PerformerIds) - if err != nil { - _ = tx.Rollback() - return nil, err + return err } - var performerJoins []models.PerformersScenes - for _, performerID := range performerIDs { - performerJoin := models.PerformersScenes{ - PerformerID: performerID, - SceneID: sceneID, + ret = append(ret, scene) + + if translator.hasField("gallery_id") { + // Save the gallery + galleryID, _ := 
strconv.Atoi(*input.GalleryID) + updatedGallery := models.GalleryPartial{ + ID: galleryID, + SceneID: &sql.NullInt64{Int64: int64(sceneID), Valid: true}, + UpdatedAt: &models.SQLiteTimestamp{Timestamp: updatedTime}, } - performerJoins = append(performerJoins, performerJoin) - } - if err := jqb.UpdatePerformersScenes(sceneID, performerJoins, tx); err != nil { - _ = tx.Rollback() - return nil, err - } - } - // Save the tags - if translator.hasField("tag_ids") { - tagIDs, err := adjustSceneTagIDs(tx, sceneID, *input.TagIds) - if err != nil { - _ = tx.Rollback() - return nil, err - } - - var tagJoins []models.ScenesTags - for _, tagID := range tagIDs { - tagJoin := models.ScenesTags{ - SceneID: sceneID, - TagID: tagID, + if _, err := gqb.UpdatePartial(updatedGallery); err != nil { + return err } - tagJoins = append(tagJoins, tagJoin) } - if err := jqb.UpdateScenesTags(sceneID, tagJoins, tx); err != nil { - _ = tx.Rollback() - return nil, err + + // Save the performers + if translator.hasField("performer_ids") { + performerIDs, err := adjustScenePerformerIDs(qb, sceneID, *input.PerformerIds) + if err != nil { + return err + } + + if err := qb.UpdatePerformers(sceneID, performerIDs); err != nil { + return err + } + } + + // Save the tags + if translator.hasField("tag_ids") { + tagIDs, err := adjustSceneTagIDs(qb, sceneID, *input.TagIds) + if err != nil { + return err + } + + if err := qb.UpdateTags(sceneID, tagIDs); err != nil { + return err + } } } - } - // Commit - if err := tx.Commit(); err != nil { + return nil + }); err != nil { return nil, err } @@ -332,6 +296,17 @@ func (r *mutationResolver) BulkSceneUpdate(ctx context.Context, input models.Bul } func adjustIDs(existingIDs []int, updateIDs models.BulkUpdateIds) []int { + // if we are setting the ids, just return the ids + if updateIDs.Mode == models.BulkUpdateIDModeSet { + existingIDs = []int{} + for _, idStr := range updateIDs.Ids { + id, _ := strconv.Atoi(idStr) + existingIDs = append(existingIDs, id) + } + + return existingIDs + } + for _, idStr := range updateIDs.Ids { id, _ := strconv.Atoi(idStr) @@ -357,63 +332,53 @@ func adjustIDs(existingIDs []int, updateIDs models.BulkUpdateIds) []int { return existingIDs } -func adjustScenePerformerIDs(tx *sqlx.Tx, sceneID int, ids models.BulkUpdateIds) ([]int, error) { - var ret []int - - jqb := models.NewJoinsQueryBuilder() - if ids.Mode == models.BulkUpdateIDModeAdd || ids.Mode == models.BulkUpdateIDModeRemove { - // adding to the joins - performerJoins, err := jqb.GetScenePerformers(sceneID, tx) - - if err != nil { - return nil, err - } - - for _, join := range performerJoins { - ret = append(ret, join.PerformerID) - } +func adjustScenePerformerIDs(qb models.SceneReader, sceneID int, ids models.BulkUpdateIds) (ret []int, err error) { + ret, err = qb.GetPerformerIDs(sceneID) + if err != nil { + return nil, err } return adjustIDs(ret, ids), nil } -func adjustSceneTagIDs(tx *sqlx.Tx, sceneID int, ids models.BulkUpdateIds) ([]int, error) { - var ret []int - - jqb := models.NewJoinsQueryBuilder() - if ids.Mode == models.BulkUpdateIDModeAdd || ids.Mode == models.BulkUpdateIDModeRemove { - // adding to the joins - tagJoins, err := jqb.GetSceneTags(sceneID, tx) - - if err != nil { - return nil, err - } - - for _, join := range tagJoins { - ret = append(ret, join.TagID) - } +func adjustSceneTagIDs(qb models.SceneReader, sceneID int, ids models.BulkUpdateIds) (ret []int, err error) { + ret, err = qb.GetTagIDs(sceneID) + if err != nil { + return nil, err } return adjustIDs(ret, ids), nil } func (r 
*mutationResolver) SceneDestroy(ctx context.Context, input models.SceneDestroyInput) (bool, error) { - qb := models.NewSceneQueryBuilder() - tx := database.DB.MustBeginTx(ctx, nil) - - sceneID, _ := strconv.Atoi(input.ID) - scene, err := qb.Find(sceneID) - err = manager.DestroyScene(sceneID, tx) - + sceneID, err := strconv.Atoi(input.ID) if err != nil { - tx.Rollback() return false, err } - if err := tx.Commit(); err != nil { + var scene *models.Scene + var postCommitFunc func() + if err := r.withTxn(ctx, func(repo models.Repository) error { + qb := repo.Scene() + var err error + scene, err = qb.Find(sceneID) + if err != nil { + return err + } + + if scene == nil { + return fmt.Errorf("scene with id %d not found", sceneID) + } + + postCommitFunc, err = manager.DestroyScene(scene, repo) + return err + }); err != nil { return false, err } + // perform the post-commit actions + postCommitFunc() + // if delete generated is true, then delete the generated files // for the scene if input.DeleteGenerated != nil && *input.DeleteGenerated { @@ -430,27 +395,33 @@ func (r *mutationResolver) SceneDestroy(ctx context.Context, input models.SceneD } func (r *mutationResolver) ScenesDestroy(ctx context.Context, input models.ScenesDestroyInput) (bool, error) { - qb := models.NewSceneQueryBuilder() - tx := database.DB.MustBeginTx(ctx, nil) - var scenes []*models.Scene - for _, id := range input.Ids { - sceneID, _ := strconv.Atoi(id) + var postCommitFuncs []func() + if err := r.withTxn(ctx, func(repo models.Repository) error { + qb := repo.Scene() - scene, err := qb.Find(sceneID) - if scene != nil { - scenes = append(scenes, scene) - } - err = manager.DestroyScene(sceneID, tx) + for _, id := range input.Ids { + sceneID, _ := strconv.Atoi(id) - if err != nil { - tx.Rollback() - return false, err + scene, err := qb.Find(sceneID) + if scene != nil { + scenes = append(scenes, scene) + } + f, err := manager.DestroyScene(scene, repo) + if err != nil { + return err + } + + postCommitFuncs = append(postCommitFuncs, f) } + + return nil + }); err != nil { + return false, err } - if err := tx.Commit(); err != nil { - return false, err + for _, f := range postCommitFuncs { + f() } fileNamingAlgo := config.GetVideoFileNamingAlgorithm() @@ -472,8 +443,16 @@ func (r *mutationResolver) ScenesDestroy(ctx context.Context, input models.Scene } func (r *mutationResolver) SceneMarkerCreate(ctx context.Context, input models.SceneMarkerCreateInput) (*models.SceneMarker, error) { - primaryTagID, _ := strconv.Atoi(input.PrimaryTagID) - sceneID, _ := strconv.Atoi(input.SceneID) + primaryTagID, err := strconv.Atoi(input.PrimaryTagID) + if err != nil { + return nil, err + } + + sceneID, err := strconv.Atoi(input.SceneID) + if err != nil { + return nil, err + } + currentTime := time.Now() newSceneMarker := models.SceneMarker{ Title: input.Title, @@ -484,14 +463,31 @@ func (r *mutationResolver) SceneMarkerCreate(ctx context.Context, input models.S UpdatedAt: models.SQLiteTimestamp{Timestamp: currentTime}, } - return changeMarker(ctx, create, newSceneMarker, input.TagIds) + tagIDs, err := utils.StringSliceToIntSlice(input.TagIds) + if err != nil { + return nil, err + } + + return r.changeMarker(ctx, create, newSceneMarker, tagIDs) } func (r *mutationResolver) SceneMarkerUpdate(ctx context.Context, input models.SceneMarkerUpdateInput) (*models.SceneMarker, error) { // Populate scene marker from the input - sceneMarkerID, _ := strconv.Atoi(input.ID) - sceneID, _ := strconv.Atoi(input.SceneID) - primaryTagID, _ := 
strconv.Atoi(input.PrimaryTagID) + sceneMarkerID, err := strconv.Atoi(input.ID) + if err != nil { + return nil, err + } + + primaryTagID, err := strconv.Atoi(input.PrimaryTagID) + if err != nil { + return nil, err + } + + sceneID, err := strconv.Atoi(input.SceneID) + if err != nil { + return nil, err + } + updatedSceneMarker := models.SceneMarker{ ID: sceneMarkerID, Title: input.Title, @@ -501,168 +497,151 @@ func (r *mutationResolver) SceneMarkerUpdate(ctx context.Context, input models.S UpdatedAt: models.SQLiteTimestamp{Timestamp: time.Now()}, } - return changeMarker(ctx, update, updatedSceneMarker, input.TagIds) + tagIDs, err := utils.StringSliceToIntSlice(input.TagIds) + if err != nil { + return nil, err + } + + return r.changeMarker(ctx, update, updatedSceneMarker, tagIDs) } func (r *mutationResolver) SceneMarkerDestroy(ctx context.Context, id string) (bool, error) { - qb := models.NewSceneMarkerQueryBuilder() - tx := database.DB.MustBeginTx(ctx, nil) - - markerID, _ := strconv.Atoi(id) - marker, err := qb.Find(markerID) - + markerID, err := strconv.Atoi(id) if err != nil { return false, err } - if err := qb.Destroy(id, tx); err != nil { - _ = tx.Rollback() - return false, err - } - if err := tx.Commit(); err != nil { + var postCommitFunc func() + if err := r.withTxn(ctx, func(repo models.Repository) error { + qb := repo.SceneMarker() + sqb := repo.Scene() + + marker, err := qb.Find(markerID) + + if err != nil { + return err + } + + if marker == nil { + return fmt.Errorf("scene marker with id %d not found", markerID) + } + + scene, err := sqb.Find(int(marker.SceneID.Int64)) + if err != nil { + return err + } + + postCommitFunc, err = manager.DestroySceneMarker(scene, marker, qb) + return err + }); err != nil { return false, err } - // delete the preview for the marker - sqb := models.NewSceneQueryBuilder() - scene, _ := sqb.Find(int(marker.SceneID.Int64)) - - if scene != nil { - seconds := int(marker.Seconds) - manager.DeleteSceneMarkerFiles(scene, seconds, config.GetVideoFileNamingAlgorithm()) - } + postCommitFunc() return true, nil } -func changeMarker(ctx context.Context, changeType int, changedMarker models.SceneMarker, tagIds []string) (*models.SceneMarker, error) { - // Start the transaction and save the scene marker - tx := database.DB.MustBeginTx(ctx, nil) - qb := models.NewSceneMarkerQueryBuilder() - jqb := models.NewJoinsQueryBuilder() - +func (r *mutationResolver) changeMarker(ctx context.Context, changeType int, changedMarker models.SceneMarker, tagIDs []int) (*models.SceneMarker, error) { var existingMarker *models.SceneMarker var sceneMarker *models.SceneMarker - var err error - switch changeType { - case create: - sceneMarker, err = qb.Create(changedMarker, tx) - case update: - // check to see if timestamp was changed - existingMarker, err = qb.Find(changedMarker.ID) - if err == nil { - sceneMarker, err = qb.Update(changedMarker, tx) - } - } - if err != nil { - _ = tx.Rollback() - return nil, err - } + var scene *models.Scene - // Save the marker tags - var markerTagJoins []models.SceneMarkersTags - for _, tid := range tagIds { - tagID, _ := strconv.Atoi(tid) - if tagID == changedMarker.PrimaryTagID { - continue // If this tag is the primary tag, then let's not add it. 
- } - markerTag := models.SceneMarkersTags{ - SceneMarkerID: sceneMarker.ID, - TagID: tagID, - } - markerTagJoins = append(markerTagJoins, markerTag) - } - switch changeType { - case create: - if err := jqb.CreateSceneMarkersTags(markerTagJoins, tx); err != nil { - _ = tx.Rollback() - return nil, err - } - case update: - if err := jqb.UpdateSceneMarkersTags(changedMarker.ID, markerTagJoins, tx); err != nil { - _ = tx.Rollback() - return nil, err - } - } + // Start the transaction and save the scene marker + if err := r.withTxn(ctx, func(repo models.Repository) error { + qb := repo.SceneMarker() + sqb := repo.Scene() - // Commit - if err := tx.Commit(); err != nil { + var err error + switch changeType { + case create: + sceneMarker, err = qb.Create(changedMarker) + case update: + // check to see if timestamp was changed + existingMarker, err = qb.Find(changedMarker.ID) + if err != nil { + return err + } + sceneMarker, err = qb.Update(changedMarker) + if err != nil { + return err + } + + scene, err = sqb.Find(int(existingMarker.SceneID.Int64)) + } + if err != nil { + return err + } + + // Save the marker tags + // If this tag is the primary tag, then let's not add it. + tagIDs = utils.IntExclude(tagIDs, []int{changedMarker.PrimaryTagID}) + return qb.UpdateTags(sceneMarker.ID, tagIDs) + }); err != nil { return nil, err } // remove the marker preview if the timestamp was changed - if existingMarker != nil && existingMarker.Seconds != changedMarker.Seconds { - sqb := models.NewSceneQueryBuilder() - - scene, _ := sqb.Find(int(existingMarker.SceneID.Int64)) - - if scene != nil { - seconds := int(existingMarker.Seconds) - manager.DeleteSceneMarkerFiles(scene, seconds, config.GetVideoFileNamingAlgorithm()) - } + if scene != nil && existingMarker != nil && existingMarker.Seconds != changedMarker.Seconds { + seconds := int(existingMarker.Seconds) + manager.DeleteSceneMarkerFiles(scene, seconds, config.GetVideoFileNamingAlgorithm()) } return sceneMarker, nil } -func (r *mutationResolver) SceneIncrementO(ctx context.Context, id string) (int, error) { - sceneID, _ := strconv.Atoi(id) - - tx := database.DB.MustBeginTx(ctx, nil) - qb := models.NewSceneQueryBuilder() - - newVal, err := qb.IncrementOCounter(sceneID, tx) +func (r *mutationResolver) SceneIncrementO(ctx context.Context, id string) (ret int, err error) { + sceneID, err := strconv.Atoi(id) if err != nil { - _ = tx.Rollback() return 0, err } - // Commit - if err := tx.Commit(); err != nil { + if err := r.withTxn(ctx, func(repo models.Repository) error { + qb := repo.Scene() + + ret, err = qb.IncrementOCounter(sceneID) + return err + }); err != nil { return 0, err } - return newVal, nil + return ret, nil } -func (r *mutationResolver) SceneDecrementO(ctx context.Context, id string) (int, error) { - sceneID, _ := strconv.Atoi(id) - - tx := database.DB.MustBeginTx(ctx, nil) - qb := models.NewSceneQueryBuilder() - - newVal, err := qb.DecrementOCounter(sceneID, tx) +func (r *mutationResolver) SceneDecrementO(ctx context.Context, id string) (ret int, err error) { + sceneID, err := strconv.Atoi(id) if err != nil { - _ = tx.Rollback() return 0, err } - // Commit - if err := tx.Commit(); err != nil { + if err := r.withTxn(ctx, func(repo models.Repository) error { + qb := repo.Scene() + + ret, err = qb.DecrementOCounter(sceneID) + return err + }); err != nil { return 0, err } - return newVal, nil + return ret, nil } -func (r *mutationResolver) SceneResetO(ctx context.Context, id string) (int, error) { - sceneID, _ := strconv.Atoi(id) - - tx := 
database.DB.MustBeginTx(ctx, nil) - qb := models.NewSceneQueryBuilder() - - newVal, err := qb.ResetOCounter(sceneID, tx) +func (r *mutationResolver) SceneResetO(ctx context.Context, id string) (ret int, err error) { + sceneID, err := strconv.Atoi(id) if err != nil { - _ = tx.Rollback() return 0, err } - // Commit - if err := tx.Commit(); err != nil { + if err := r.withTxn(ctx, func(repo models.Repository) error { + qb := repo.Scene() + + ret, err = qb.ResetOCounter(sceneID) + return err + }); err != nil { return 0, err } - return newVal, nil + return ret, nil } func (r *mutationResolver) SceneGenerateScreenshot(ctx context.Context, id string, at *float64) (string, error) { diff --git a/pkg/api/resolver_mutation_stash_box.go b/pkg/api/resolver_mutation_stash_box.go index 5dd1acfde..303b66dae 100644 --- a/pkg/api/resolver_mutation_stash_box.go +++ b/pkg/api/resolver_mutation_stash_box.go @@ -16,7 +16,7 @@ func (r *mutationResolver) SubmitStashBoxFingerprints(ctx context.Context, input return false, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex) } - client := stashbox.NewClient(*boxes[input.StashBoxIndex]) + client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager) return client.SubmitStashBoxFingerprints(input.SceneIds, boxes[input.StashBoxIndex].Endpoint) } diff --git a/pkg/api/resolver_mutation_studio.go b/pkg/api/resolver_mutation_studio.go index c685d3155..bf1b5f312 100644 --- a/pkg/api/resolver_mutation_studio.go +++ b/pkg/api/resolver_mutation_studio.go @@ -6,7 +6,6 @@ import ( "strconv" "time" - "github.com/stashapp/stash/pkg/database" "github.com/stashapp/stash/pkg/manager" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/utils" @@ -44,41 +43,33 @@ func (r *mutationResolver) StudioCreate(ctx context.Context, input models.Studio } // Start the transaction and save the studio - tx := database.DB.MustBeginTx(ctx, nil) - qb := models.NewStudioQueryBuilder() - jqb := models.NewJoinsQueryBuilder() + var studio *models.Studio + if err := r.withTxn(ctx, func(repo models.Repository) error { + qb := repo.Studio() - studio, err := qb.Create(newStudio, tx) - if err != nil { - _ = tx.Rollback() - return nil, err - } - - // update image table - if len(imageData) > 0 { - if err := qb.UpdateStudioImage(studio.ID, imageData, tx); err != nil { - _ = tx.Rollback() - return nil, err + var err error + studio, err = qb.Create(newStudio) + if err != nil { + return err } - } - // Save the stash_ids - if input.StashIds != nil { - var stashIDJoins []models.StashID - for _, stashID := range input.StashIds { - newJoin := models.StashID{ - StashID: stashID.StashID, - Endpoint: stashID.Endpoint, + // update image table + if len(imageData) > 0 { + if err := qb.UpdateImage(studio.ID, imageData); err != nil { + return err } - stashIDJoins = append(stashIDJoins, newJoin) } - if err := jqb.UpdateStudioStashIDs(studio.ID, stashIDJoins, tx); err != nil { - return nil, err - } - } - // Commit - if err := tx.Commit(); err != nil { + // Save the stash_ids + if input.StashIds != nil { + stashIDJoins := models.StashIDsFromInput(input.StashIds) + if err := qb.UpdateStashIDs(studio.ID, stashIDJoins); err != nil { + return err + } + } + + return nil + }); err != nil { return nil, err } @@ -87,7 +78,10 @@ func (r *mutationResolver) StudioCreate(ctx context.Context, input models.Studio func (r *mutationResolver) StudioUpdate(ctx context.Context, input models.StudioUpdateInput) (*models.Studio, error) { // Populate studio from the input - studioID, _ := strconv.Atoi(input.ID) + 
studioID, err := strconv.Atoi(input.ID) + if err != nil { + return nil, err + } translator := changesetTranslator{ inputMap: getUpdateInputMap(ctx), @@ -118,52 +112,42 @@ func (r *mutationResolver) StudioUpdate(ctx context.Context, input models.Studio updatedStudio.ParentID = translator.nullInt64FromString(input.ParentID, "parent_id") // Start the transaction and save the studio - tx := database.DB.MustBeginTx(ctx, nil) - qb := models.NewStudioQueryBuilder() - jqb := models.NewJoinsQueryBuilder() + var studio *models.Studio + if err := r.withTxn(ctx, func(repo models.Repository) error { + qb := repo.Studio() - if err := manager.ValidateModifyStudio(updatedStudio, tx); err != nil { - tx.Rollback() - return nil, err - } - - studio, err := qb.Update(updatedStudio, tx) - if err != nil { - tx.Rollback() - return nil, err - } - - // update image table - if len(imageData) > 0 { - if err := qb.UpdateStudioImage(studio.ID, imageData, tx); err != nil { - tx.Rollback() - return nil, err + if err := manager.ValidateModifyStudio(updatedStudio, qb); err != nil { + return err } - } else if imageIncluded { - // must be unsetting - if err := qb.DestroyStudioImage(studio.ID, tx); err != nil { - tx.Rollback() - return nil, err - } - } - // Save the stash_ids - if translator.hasField("stash_ids") { - var stashIDJoins []models.StashID - for _, stashID := range input.StashIds { - newJoin := models.StashID{ - StashID: stashID.StashID, - Endpoint: stashID.Endpoint, + var err error + studio, err = qb.Update(updatedStudio) + if err != nil { + return err + } + + // update image table + if len(imageData) > 0 { + if err := qb.UpdateImage(studio.ID, imageData); err != nil { + return err + } + } else if imageIncluded { + // must be unsetting + if err := qb.DestroyImage(studio.ID); err != nil { + return err } - stashIDJoins = append(stashIDJoins, newJoin) } - if err := jqb.UpdateStudioStashIDs(studioID, stashIDJoins, tx); err != nil { - return nil, err - } - } - // Commit - if err := tx.Commit(); err != nil { + // Save the stash_ids + if translator.hasField("stash_ids") { + stashIDJoins := models.StashIDsFromInput(input.StashIds) + if err := qb.UpdateStashIDs(studioID, stashIDJoins); err != nil { + return err + } + } + + return nil + }); err != nil { return nil, err } @@ -171,28 +155,35 @@ func (r *mutationResolver) StudioUpdate(ctx context.Context, input models.Studio } func (r *mutationResolver) StudioDestroy(ctx context.Context, input models.StudioDestroyInput) (bool, error) { - qb := models.NewStudioQueryBuilder() - tx := database.DB.MustBeginTx(ctx, nil) - if err := qb.Destroy(input.ID, tx); err != nil { - _ = tx.Rollback() + id, err := strconv.Atoi(input.ID) + if err != nil { return false, err } - if err := tx.Commit(); err != nil { + + if err := r.withTxn(ctx, func(repo models.Repository) error { + return repo.Studio().Destroy(id) + }); err != nil { return false, err } return true, nil } -func (r *mutationResolver) StudiosDestroy(ctx context.Context, ids []string) (bool, error) { - qb := models.NewStudioQueryBuilder() - tx := database.DB.MustBeginTx(ctx, nil) - for _, id := range ids { - if err := qb.Destroy(id, tx); err != nil { - _ = tx.Rollback() - return false, err - } +func (r *mutationResolver) StudiosDestroy(ctx context.Context, studioIDs []string) (bool, error) { + ids, err := utils.StringSliceToIntSlice(studioIDs) + if err != nil { + return false, err } - if err := tx.Commit(); err != nil { + + if err := r.withTxn(ctx, func(repo models.Repository) error { + qb := repo.Studio() + for _, id := range ids { + 
if err := qb.Destroy(id); err != nil { + return err + } + } + + return nil + }); err != nil { return false, err } return true, nil diff --git a/pkg/api/resolver_mutation_tag.go b/pkg/api/resolver_mutation_tag.go index 28fe22de6..90ee479d1 100644 --- a/pkg/api/resolver_mutation_tag.go +++ b/pkg/api/resolver_mutation_tag.go @@ -6,7 +6,6 @@ import ( "strconv" "time" - "github.com/stashapp/stash/pkg/database" "github.com/stashapp/stash/pkg/manager" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/utils" @@ -33,31 +32,29 @@ func (r *mutationResolver) TagCreate(ctx context.Context, input models.TagCreate } // Start the transaction and save the tag - tx := database.DB.MustBeginTx(ctx, nil) - qb := models.NewTagQueryBuilder() + var tag *models.Tag + if err := r.withTxn(ctx, func(repo models.Repository) error { + qb := repo.Tag() - // ensure name is unique - if err := manager.EnsureTagNameUnique(newTag, tx); err != nil { - tx.Rollback() - return nil, err - } - - tag, err := qb.Create(newTag, tx) - if err != nil { - tx.Rollback() - return nil, err - } - - // update image table - if len(imageData) > 0 { - if err := qb.UpdateTagImage(tag.ID, imageData, tx); err != nil { - _ = tx.Rollback() - return nil, err + // ensure name is unique + if err := manager.EnsureTagNameUnique(newTag, qb); err != nil { + return err } - } - // Commit - if err := tx.Commit(); err != nil { + tag, err = qb.Create(newTag) + if err != nil { + return err + } + + // update image table + if len(imageData) > 0 { + if err := qb.UpdateImage(tag.ID, imageData); err != nil { + return err + } + } + + return nil + }); err != nil { return nil, err } @@ -66,7 +63,11 @@ func (r *mutationResolver) TagCreate(ctx context.Context, input models.TagCreate func (r *mutationResolver) TagUpdate(ctx context.Context, input models.TagUpdateInput) (*models.Tag, error) { // Populate tag from the input - tagID, _ := strconv.Atoi(input.ID) + tagID, err := strconv.Atoi(input.ID) + if err != nil { + return nil, err + } + updatedTag := models.Tag{ ID: tagID, Name: input.Name, @@ -74,7 +75,6 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input models.TagUpdate } var imageData []byte - var err error translator := changesetTranslator{ inputMap: getUpdateInputMap(ctx), @@ -90,50 +90,45 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input models.TagUpdate } // Start the transaction and save the tag - tx := database.DB.MustBeginTx(ctx, nil) - qb := models.NewTagQueryBuilder() + var tag *models.Tag + if err := r.withTxn(ctx, func(repo models.Repository) error { + qb := repo.Tag() - // ensure name is unique - existing, err := qb.Find(tagID, tx) - if err != nil { - tx.Rollback() - return nil, err - } - - if existing == nil { - tx.Rollback() - return nil, fmt.Errorf("Tag with ID %d not found", tagID) - } - - if existing.Name != updatedTag.Name { - if err := manager.EnsureTagNameUnique(updatedTag, tx); err != nil { - tx.Rollback() - return nil, err + // ensure name is unique + existing, err := qb.Find(tagID) + if err != nil { + return err } - } - tag, err := qb.Update(updatedTag, tx) - if err != nil { - _ = tx.Rollback() - return nil, err - } - - // update image table - if len(imageData) > 0 { - if err := qb.UpdateTagImage(tag.ID, imageData, tx); err != nil { - tx.Rollback() - return nil, err + if existing == nil { + return fmt.Errorf("Tag with ID %d not found", tagID) } - } else if imageIncluded { - // must be unsetting - if err := qb.DestroyTagImage(tag.ID, tx); err != nil { - tx.Rollback() - return nil, err - } - } - 
// Commit - if err := tx.Commit(); err != nil { + if existing.Name != updatedTag.Name { + if err := manager.EnsureTagNameUnique(updatedTag, qb); err != nil { + return err + } + } + + tag, err = qb.Update(updatedTag) + if err != nil { + return err + } + + // update image table + if len(imageData) > 0 { + if err := qb.UpdateImage(tag.ID, imageData); err != nil { + return err + } + } else if imageIncluded { + // must be unsetting + if err := qb.DestroyImage(tag.ID); err != nil { + return err + } + } + + return nil + }); err != nil { return nil, err } @@ -141,29 +136,35 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input models.TagUpdate } func (r *mutationResolver) TagDestroy(ctx context.Context, input models.TagDestroyInput) (bool, error) { - qb := models.NewTagQueryBuilder() - tx := database.DB.MustBeginTx(ctx, nil) - if err := qb.Destroy(input.ID, tx); err != nil { - _ = tx.Rollback() + tagID, err := strconv.Atoi(input.ID) + if err != nil { return false, err } - if err := tx.Commit(); err != nil { + + if err := r.withTxn(ctx, func(repo models.Repository) error { + return repo.Tag().Destroy(tagID) + }); err != nil { return false, err } return true, nil } -func (r *mutationResolver) TagsDestroy(ctx context.Context, ids []string) (bool, error) { - qb := models.NewTagQueryBuilder() - tx := database.DB.MustBeginTx(ctx, nil) +func (r *mutationResolver) TagsDestroy(ctx context.Context, tagIDs []string) (bool, error) { + ids, err := utils.StringSliceToIntSlice(tagIDs) + if err != nil { + return false, err + } - for _, id := range ids { - if err := qb.Destroy(id, tx); err != nil { - _ = tx.Rollback() - return false, err + if err := r.withTxn(ctx, func(repo models.Repository) error { + qb := repo.Tag() + for _, id := range ids { + if err := qb.Destroy(id); err != nil { + return err + } } - } - if err := tx.Commit(); err != nil { + + return nil + }); err != nil { return false, err } return true, nil diff --git a/pkg/api/resolver_mutation_tag_test.go b/pkg/api/resolver_mutation_tag_test.go new file mode 100644 index 000000000..d32490d9f --- /dev/null +++ b/pkg/api/resolver_mutation_tag_test.go @@ -0,0 +1,70 @@ +package api + +import ( + "context" + "errors" + "testing" + + "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/models/mocks" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +// TODO - move this into a common area +func newResolver() *Resolver { + return &Resolver{ + txnManager: mocks.NewTransactionManager(), + } +} + +const tagName = "tagName" +const errTagName = "errTagName" + +const existingTagID = 1 +const existingTagName = "existingTagName" +const newTagID = 2 + +func TestTagCreate(t *testing.T) { + r := newResolver() + + tagRW := r.txnManager.(*mocks.TransactionManager).Tag().(*mocks.TagReaderWriter) + tagRW.On("FindByName", existingTagName, true).Return(&models.Tag{ + ID: existingTagID, + Name: existingTagName, + }, nil).Once() + tagRW.On("FindByName", errTagName, true).Return(nil, nil).Once() + + expectedErr := errors.New("TagCreate error") + tagRW.On("Create", mock.AnythingOfType("models.Tag")).Return(nil, expectedErr) + + _, err := r.Mutation().TagCreate(context.TODO(), models.TagCreateInput{ + Name: existingTagName, + }) + + assert.NotNil(t, err) + + _, err = r.Mutation().TagCreate(context.TODO(), models.TagCreateInput{ + Name: errTagName, + }) + + assert.Equal(t, expectedErr, err) + tagRW.AssertExpectations(t) + + r = newResolver() + tagRW = 
r.txnManager.(*mocks.TransactionManager).Tag().(*mocks.TagReaderWriter) + + tagRW.On("FindByName", tagName, true).Return(nil, nil).Once() + tagRW.On("Create", mock.AnythingOfType("models.Tag")).Return(&models.Tag{ + ID: newTagID, + Name: tagName, + }, nil) + + tag, err := r.Mutation().TagCreate(context.TODO(), models.TagCreateInput{ + Name: tagName, + }) + + assert.Nil(t, err) + assert.NotNil(t, tag) +} diff --git a/pkg/api/resolver_query_find_gallery.go b/pkg/api/resolver_query_find_gallery.go index de5437efe..c00332a4f 100644 --- a/pkg/api/resolver_query_find_gallery.go +++ b/pkg/api/resolver_query_find_gallery.go @@ -7,17 +7,37 @@ import ( "github.com/stashapp/stash/pkg/models" ) -func (r *queryResolver) FindGallery(ctx context.Context, id string) (*models.Gallery, error) { - qb := models.NewGalleryQueryBuilder() - idInt, _ := strconv.Atoi(id) - return qb.Find(idInt, nil) +func (r *queryResolver) FindGallery(ctx context.Context, id string) (ret *models.Gallery, err error) { + idInt, err := strconv.Atoi(id) + if err != nil { + return nil, err + } + + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Gallery().Find(idInt) + return err + }); err != nil { + return nil, err + } + + return ret, nil } -func (r *queryResolver) FindGalleries(ctx context.Context, galleryFilter *models.GalleryFilterType, filter *models.FindFilterType) (*models.FindGalleriesResultType, error) { - qb := models.NewGalleryQueryBuilder() - galleries, total := qb.Query(galleryFilter, filter) - return &models.FindGalleriesResultType{ - Count: total, - Galleries: galleries, - }, nil +func (r *queryResolver) FindGalleries(ctx context.Context, galleryFilter *models.GalleryFilterType, filter *models.FindFilterType) (ret *models.FindGalleriesResultType, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + galleries, total, err := repo.Gallery().Query(galleryFilter, filter) + if err != nil { + return err + } + + ret = &models.FindGalleriesResultType{ + Count: total, + Galleries: galleries, + } + return nil + }); err != nil { + return nil, err + } + + return ret, nil } diff --git a/pkg/api/resolver_query_find_image.go b/pkg/api/resolver_query_find_image.go index 471e80853..d3112e440 100644 --- a/pkg/api/resolver_query_find_image.go +++ b/pkg/api/resolver_query_find_image.go @@ -8,23 +8,48 @@ import ( ) func (r *queryResolver) FindImage(ctx context.Context, id *string, checksum *string) (*models.Image, error) { - qb := models.NewImageQueryBuilder() var image *models.Image - var err error - if id != nil { - idInt, _ := strconv.Atoi(*id) - image, err = qb.Find(idInt) - } else if checksum != nil { - image, err = qb.FindByChecksum(*checksum) + + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + qb := repo.Image() + var err error + + if id != nil { + idInt, err := strconv.Atoi(*id) + if err != nil { + return err + } + + image, err = qb.Find(idInt) + } else if checksum != nil { + image, err = qb.FindByChecksum(*checksum) + } + + return err + }); err != nil { + return nil, err } - return image, err + + return image, nil } -func (r *queryResolver) FindImages(ctx context.Context, imageFilter *models.ImageFilterType, imageIds []int, filter *models.FindFilterType) (*models.FindImagesResultType, error) { - qb := models.NewImageQueryBuilder() - images, total := qb.Query(imageFilter, filter) - return &models.FindImagesResultType{ - Count: total, - Images: images, - }, nil +func (r *queryResolver) FindImages(ctx context.Context, imageFilter 
*models.ImageFilterType, imageIds []int, filter *models.FindFilterType) (ret *models.FindImagesResultType, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + qb := repo.Image() + images, total, err := qb.Query(imageFilter, filter) + if err != nil { + return err + } + + ret = &models.FindImagesResultType{ + Count: total, + Images: images, + } + + return nil + }); err != nil { + return nil, err + } + + return ret, nil } diff --git a/pkg/api/resolver_query_find_movie.go b/pkg/api/resolver_query_find_movie.go index 9d23eddfc..16a024260 100644 --- a/pkg/api/resolver_query_find_movie.go +++ b/pkg/api/resolver_query_find_movie.go @@ -7,27 +7,60 @@ import ( "github.com/stashapp/stash/pkg/models" ) -func (r *queryResolver) FindMovie(ctx context.Context, id string) (*models.Movie, error) { - qb := models.NewMovieQueryBuilder() - idInt, _ := strconv.Atoi(id) - return qb.Find(idInt, nil) +func (r *queryResolver) FindMovie(ctx context.Context, id string) (ret *models.Movie, err error) { + idInt, err := strconv.Atoi(id) + if err != nil { + return nil, err + } + + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Movie().Find(idInt) + return err + }); err != nil { + return nil, err + } + + return ret, nil } -func (r *queryResolver) FindMovies(ctx context.Context, movieFilter *models.MovieFilterType, filter *models.FindFilterType) (*models.FindMoviesResultType, error) { - qb := models.NewMovieQueryBuilder() - movies, total := qb.Query(movieFilter, filter) - return &models.FindMoviesResultType{ - Count: total, - Movies: movies, - }, nil +func (r *queryResolver) FindMovies(ctx context.Context, movieFilter *models.MovieFilterType, filter *models.FindFilterType) (ret *models.FindMoviesResultType, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + movies, total, err := repo.Movie().Query(movieFilter, filter) + if err != nil { + return err + } + + ret = &models.FindMoviesResultType{ + Count: total, + Movies: movies, + } + + return nil + }); err != nil { + return nil, err + } + + return ret, nil } -func (r *queryResolver) AllMovies(ctx context.Context) ([]*models.Movie, error) { - qb := models.NewMovieQueryBuilder() - return qb.All() +func (r *queryResolver) AllMovies(ctx context.Context) (ret []*models.Movie, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Movie().All() + return err + }); err != nil { + return nil, err + } + + return ret, nil } -func (r *queryResolver) AllMoviesSlim(ctx context.Context) ([]*models.Movie, error) { - qb := models.NewMovieQueryBuilder() - return qb.AllSlim() +func (r *queryResolver) AllMoviesSlim(ctx context.Context) (ret []*models.Movie, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Movie().AllSlim() + return err + }); err != nil { + return nil, err + } + + return ret, nil } diff --git a/pkg/api/resolver_query_find_performer.go b/pkg/api/resolver_query_find_performer.go index efb694910..741d5f094 100644 --- a/pkg/api/resolver_query_find_performer.go +++ b/pkg/api/resolver_query_find_performer.go @@ -2,31 +2,64 @@ package api import ( "context" - "github.com/stashapp/stash/pkg/models" "strconv" + + "github.com/stashapp/stash/pkg/models" ) -func (r *queryResolver) FindPerformer(ctx context.Context, id string) (*models.Performer, error) { - qb := models.NewPerformerQueryBuilder() - idInt, _ := strconv.Atoi(id) - return qb.Find(idInt) +func (r 
*queryResolver) FindPerformer(ctx context.Context, id string) (ret *models.Performer, err error) { + idInt, err := strconv.Atoi(id) + if err != nil { + return nil, err + } + + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Performer().Find(idInt) + return err + }); err != nil { + return nil, err + } + + return ret, nil } -func (r *queryResolver) FindPerformers(ctx context.Context, performerFilter *models.PerformerFilterType, filter *models.FindFilterType) (*models.FindPerformersResultType, error) { - qb := models.NewPerformerQueryBuilder() - performers, total := qb.Query(performerFilter, filter) - return &models.FindPerformersResultType{ - Count: total, - Performers: performers, - }, nil +func (r *queryResolver) FindPerformers(ctx context.Context, performerFilter *models.PerformerFilterType, filter *models.FindFilterType) (ret *models.FindPerformersResultType, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + performers, total, err := repo.Performer().Query(performerFilter, filter) + if err != nil { + return err + } + + ret = &models.FindPerformersResultType{ + Count: total, + Performers: performers, + } + return nil + }); err != nil { + return nil, err + } + + return ret, nil } -func (r *queryResolver) AllPerformers(ctx context.Context) ([]*models.Performer, error) { - qb := models.NewPerformerQueryBuilder() - return qb.All() +func (r *queryResolver) AllPerformers(ctx context.Context) (ret []*models.Performer, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Performer().All() + return err + }); err != nil { + return nil, err + } + + return ret, nil } -func (r *queryResolver) AllPerformersSlim(ctx context.Context) ([]*models.Performer, error) { - qb := models.NewPerformerQueryBuilder() - return qb.AllSlim() +func (r *queryResolver) AllPerformersSlim(ctx context.Context) (ret []*models.Performer, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Performer().AllSlim() + return err + }); err != nil { + return nil, err + } + + return ret, nil } diff --git a/pkg/api/resolver_query_find_scene.go b/pkg/api/resolver_query_find_scene.go index 781e6c5b3..dcb59842a 100644 --- a/pkg/api/resolver_query_find_scene.go +++ b/pkg/api/resolver_query_find_scene.go @@ -9,70 +9,114 @@ import ( ) func (r *queryResolver) FindScene(ctx context.Context, id *string, checksum *string) (*models.Scene, error) { - qb := models.NewSceneQueryBuilder() var scene *models.Scene - var err error - if id != nil { - idInt, _ := strconv.Atoi(*id) - scene, err = qb.Find(idInt) - } else if checksum != nil { - scene, err = qb.FindByChecksum(*checksum) - } - return scene, err -} - -func (r *queryResolver) FindSceneByHash(ctx context.Context, input models.SceneHashInput) (*models.Scene, error) { - qb := models.NewSceneQueryBuilder() - var scene *models.Scene - var err error - - if input.Checksum != nil { - scene, err = qb.FindByChecksum(*input.Checksum) - if err != nil { - return nil, err + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + qb := repo.Scene() + var err error + if id != nil { + idInt, err := strconv.Atoi(*id) + if err != nil { + return err + } + scene, err = qb.Find(idInt) + } else if checksum != nil { + scene, err = qb.FindByChecksum(*checksum) } - } - if scene == nil && input.Oshash != nil { - scene, err = qb.FindByOSHash(*input.Oshash) - if err != nil { - return nil, err - } - } - - return scene, err -} 
- -func (r *queryResolver) FindScenes(ctx context.Context, sceneFilter *models.SceneFilterType, sceneIds []int, filter *models.FindFilterType) (*models.FindScenesResultType, error) { - qb := models.NewSceneQueryBuilder() - scenes, total := qb.Query(sceneFilter, filter) - return &models.FindScenesResultType{ - Count: total, - Scenes: scenes, - }, nil -} - -func (r *queryResolver) FindScenesByPathRegex(ctx context.Context, filter *models.FindFilterType) (*models.FindScenesResultType, error) { - qb := models.NewSceneQueryBuilder() - - scenes, total := qb.QueryByPathRegex(filter) - return &models.FindScenesResultType{ - Count: total, - Scenes: scenes, - }, nil -} - -func (r *queryResolver) ParseSceneFilenames(ctx context.Context, filter *models.FindFilterType, config models.SceneParserInput) (*models.SceneParserResultType, error) { - parser := manager.NewSceneFilenameParser(filter, config) - - result, count, err := parser.Parse() - - if err != nil { + return err + }); err != nil { return nil, err } - return &models.SceneParserResultType{ - Count: count, - Results: result, - }, nil + return scene, nil +} + +func (r *queryResolver) FindSceneByHash(ctx context.Context, input models.SceneHashInput) (*models.Scene, error) { + var scene *models.Scene + + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + qb := repo.Scene() + var err error + if input.Checksum != nil { + scene, err = qb.FindByChecksum(*input.Checksum) + if err != nil { + return err + } + } + + if scene == nil && input.Oshash != nil { + scene, err = qb.FindByOSHash(*input.Oshash) + if err != nil { + return err + } + } + + return nil + }); err != nil { + return nil, err + } + + return scene, nil +} + +func (r *queryResolver) FindScenes(ctx context.Context, sceneFilter *models.SceneFilterType, sceneIds []int, filter *models.FindFilterType) (ret *models.FindScenesResultType, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + scenes, total, err := repo.Scene().Query(sceneFilter, filter) + if err != nil { + return err + } + ret = &models.FindScenesResultType{ + Count: total, + Scenes: scenes, + } + + return nil + }); err != nil { + return nil, err + } + + return ret, nil +} + +func (r *queryResolver) FindScenesByPathRegex(ctx context.Context, filter *models.FindFilterType) (ret *models.FindScenesResultType, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + scenes, total, err := repo.Scene().QueryByPathRegex(filter) + if err != nil { + return err + } + + ret = &models.FindScenesResultType{ + Count: total, + Scenes: scenes, + } + + return nil + }); err != nil { + return nil, err + } + + return ret, nil +} + +func (r *queryResolver) ParseSceneFilenames(ctx context.Context, filter *models.FindFilterType, config models.SceneParserInput) (ret *models.SceneParserResultType, err error) { + parser := manager.NewSceneFilenameParser(filter, config) + + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + result, count, err := parser.Parse(repo) + + if err != nil { + return err + } + + ret = &models.SceneParserResultType{ + Count: count, + Results: result, + } + + return nil + }); err != nil { + return nil, err + } + + return ret, nil } diff --git a/pkg/api/resolver_query_find_scene_marker.go b/pkg/api/resolver_query_find_scene_marker.go index 71b9b927b..a3d6b6058 100644 --- a/pkg/api/resolver_query_find_scene_marker.go +++ b/pkg/api/resolver_query_find_scene_marker.go @@ -2,14 +2,25 @@ package api import ( "context" + 
"github.com/stashapp/stash/pkg/models" ) -func (r *queryResolver) FindSceneMarkers(ctx context.Context, sceneMarkerFilter *models.SceneMarkerFilterType, filter *models.FindFilterType) (*models.FindSceneMarkersResultType, error) { - qb := models.NewSceneMarkerQueryBuilder() - sceneMarkers, total := qb.Query(sceneMarkerFilter, filter) - return &models.FindSceneMarkersResultType{ - Count: total, - SceneMarkers: sceneMarkers, - }, nil +func (r *queryResolver) FindSceneMarkers(ctx context.Context, sceneMarkerFilter *models.SceneMarkerFilterType, filter *models.FindFilterType) (ret *models.FindSceneMarkersResultType, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + sceneMarkers, total, err := repo.SceneMarker().Query(sceneMarkerFilter, filter) + if err != nil { + return err + } + ret = &models.FindSceneMarkersResultType{ + Count: total, + SceneMarkers: sceneMarkers, + } + + return nil + }); err != nil { + return nil, err + } + + return ret, nil } diff --git a/pkg/api/resolver_query_find_studio.go b/pkg/api/resolver_query_find_studio.go index 25528c517..6df2ecae7 100644 --- a/pkg/api/resolver_query_find_studio.go +++ b/pkg/api/resolver_query_find_studio.go @@ -2,31 +2,66 @@ package api import ( "context" - "github.com/stashapp/stash/pkg/models" "strconv" + + "github.com/stashapp/stash/pkg/models" ) -func (r *queryResolver) FindStudio(ctx context.Context, id string) (*models.Studio, error) { - qb := models.NewStudioQueryBuilder() - idInt, _ := strconv.Atoi(id) - return qb.Find(idInt, nil) +func (r *queryResolver) FindStudio(ctx context.Context, id string) (ret *models.Studio, err error) { + idInt, err := strconv.Atoi(id) + if err != nil { + return nil, err + } + + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + var err error + ret, err = repo.Studio().Find(idInt) + return err + }); err != nil { + return nil, err + } + + return ret, nil } -func (r *queryResolver) FindStudios(ctx context.Context, studioFilter *models.StudioFilterType, filter *models.FindFilterType) (*models.FindStudiosResultType, error) { - qb := models.NewStudioQueryBuilder() - studios, total := qb.Query(studioFilter, filter) - return &models.FindStudiosResultType{ - Count: total, - Studios: studios, - }, nil +func (r *queryResolver) FindStudios(ctx context.Context, studioFilter *models.StudioFilterType, filter *models.FindFilterType) (ret *models.FindStudiosResultType, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + studios, total, err := repo.Studio().Query(studioFilter, filter) + if err != nil { + return err + } + + ret = &models.FindStudiosResultType{ + Count: total, + Studios: studios, + } + + return nil + }); err != nil { + return nil, err + } + + return ret, nil } -func (r *queryResolver) AllStudios(ctx context.Context) ([]*models.Studio, error) { - qb := models.NewStudioQueryBuilder() - return qb.All() +func (r *queryResolver) AllStudios(ctx context.Context) (ret []*models.Studio, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Studio().All() + return err + }); err != nil { + return nil, err + } + + return ret, nil } -func (r *queryResolver) AllStudiosSlim(ctx context.Context) ([]*models.Studio, error) { - qb := models.NewStudioQueryBuilder() - return qb.AllSlim() +func (r *queryResolver) AllStudiosSlim(ctx context.Context) (ret []*models.Studio, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = 
repo.Studio().AllSlim() + return err + }); err != nil { + return nil, err + } + + return ret, nil } diff --git a/pkg/api/resolver_query_find_tag.go b/pkg/api/resolver_query_find_tag.go index 7e8727b9b..c373c42bd 100644 --- a/pkg/api/resolver_query_find_tag.go +++ b/pkg/api/resolver_query_find_tag.go @@ -7,27 +7,60 @@ import ( "github.com/stashapp/stash/pkg/models" ) -func (r *queryResolver) FindTag(ctx context.Context, id string) (*models.Tag, error) { - qb := models.NewTagQueryBuilder() - idInt, _ := strconv.Atoi(id) - return qb.Find(idInt, nil) +func (r *queryResolver) FindTag(ctx context.Context, id string) (ret *models.Tag, err error) { + idInt, err := strconv.Atoi(id) + if err != nil { + return nil, err + } + + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Tag().Find(idInt) + return err + }); err != nil { + return nil, err + } + + return ret, nil } -func (r *queryResolver) FindTags(ctx context.Context, tagFilter *models.TagFilterType, filter *models.FindFilterType) (*models.FindTagsResultType, error) { - qb := models.NewTagQueryBuilder() - tags, total := qb.Query(tagFilter, filter) - return &models.FindTagsResultType{ - Count: total, - Tags: tags, - }, nil +func (r *queryResolver) FindTags(ctx context.Context, tagFilter *models.TagFilterType, filter *models.FindFilterType) (ret *models.FindTagsResultType, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + tags, total, err := repo.Tag().Query(tagFilter, filter) + if err != nil { + return err + } + + ret = &models.FindTagsResultType{ + Count: total, + Tags: tags, + } + + return nil + }); err != nil { + return nil, err + } + + return ret, nil } -func (r *queryResolver) AllTags(ctx context.Context) ([]*models.Tag, error) { - qb := models.NewTagQueryBuilder() - return qb.All() +func (r *queryResolver) AllTags(ctx context.Context) (ret []*models.Tag, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Tag().All() + return err + }); err != nil { + return nil, err + } + + return ret, nil } -func (r *queryResolver) AllTagsSlim(ctx context.Context) ([]*models.Tag, error) { - qb := models.NewTagQueryBuilder() - return qb.AllSlim() +func (r *queryResolver) AllTagsSlim(ctx context.Context) (ret []*models.Tag, err error) { + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + ret, err = repo.Tag().AllSlim() + return err + }); err != nil { + return nil, err + } + + return ret, nil } diff --git a/pkg/api/resolver_query_scene.go b/pkg/api/resolver_query_scene.go index faa745f5e..bffbcfd6f 100644 --- a/pkg/api/resolver_query_scene.go +++ b/pkg/api/resolver_query_scene.go @@ -12,11 +12,13 @@ import ( func (r *queryResolver) SceneStreams(ctx context.Context, id *string) ([]*models.SceneStreamEndpoint, error) { // find the scene - qb := models.NewSceneQueryBuilder() - idInt, _ := strconv.Atoi(*id) - scene, err := qb.Find(idInt) - - if err != nil { + var scene *models.Scene + if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error { + idInt, _ := strconv.Atoi(*id) + var err error + scene, err = repo.Scene().Find(idInt) + return err + }); err != nil { return nil, err } diff --git a/pkg/api/resolver_query_scraper.go b/pkg/api/resolver_query_scraper.go index 04be0d9ee..7a197a025 100644 --- a/pkg/api/resolver_query_scraper.go +++ b/pkg/api/resolver_query_scraper.go @@ -95,7 +95,7 @@ func (r *queryResolver) QueryStashBoxScene(ctx context.Context, input models.Sta return nil, 
fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex) } - client := stashbox.NewClient(*boxes[input.StashBoxIndex]) + client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager) if len(input.SceneIds) > 0 { return client.FindStashBoxScenesByFingerprints(input.SceneIds) diff --git a/pkg/api/routes_image.go b/pkg/api/routes_image.go index cc182fb0a..3e64887a8 100644 --- a/pkg/api/routes_image.go +++ b/pkg/api/routes_image.go @@ -12,7 +12,9 @@ import ( "github.com/stashapp/stash/pkg/utils" ) -type imageRoutes struct{} +type imageRoutes struct { + txnManager models.TransactionManager +} func (rs imageRoutes) Routes() chi.Router { r := chi.NewRouter() @@ -57,12 +59,16 @@ func ImageCtx(next http.Handler) http.Handler { imageID, _ := strconv.Atoi(imageIdentifierQueryParam) var image *models.Image - qb := models.NewImageQueryBuilder() - if imageID == 0 { - image, _ = qb.FindByChecksum(imageIdentifierQueryParam) - } else { - image, _ = qb.Find(imageID) - } + manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { + qb := repo.Image() + if imageID == 0 { + image, _ = qb.FindByChecksum(imageIdentifierQueryParam) + } else { + image, _ = qb.Find(imageID) + } + + return nil + }) if image == nil { http.Error(w, http.StatusText(404), 404) diff --git a/pkg/api/routes_movie.go b/pkg/api/routes_movie.go index 00c882d6e..988d4ddf9 100644 --- a/pkg/api/routes_movie.go +++ b/pkg/api/routes_movie.go @@ -6,11 +6,14 @@ import ( "strconv" "github.com/go-chi/chi" + "github.com/stashapp/stash/pkg/manager" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/utils" ) -type movieRoutes struct{} +type movieRoutes struct { + txnManager models.TransactionManager +} func (rs movieRoutes) Routes() chi.Router { r := chi.NewRouter() @@ -26,11 +29,16 @@ func (rs movieRoutes) Routes() chi.Router { func (rs movieRoutes) FrontImage(w http.ResponseWriter, r *http.Request) { movie := r.Context().Value(movieKey).(*models.Movie) - qb := models.NewMovieQueryBuilder() - image, _ := qb.GetFrontImage(movie.ID, nil) - defaultParam := r.URL.Query().Get("default") - if len(image) == 0 || defaultParam == "true" { + var image []byte + if defaultParam != "true" { + rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { + image, _ = repo.Movie().GetFrontImage(movie.ID) + return nil + }) + } + + if len(image) == 0 { _, image, _ = utils.ProcessBase64Image(models.DefaultMovieImage) } @@ -39,11 +47,16 @@ func (rs movieRoutes) FrontImage(w http.ResponseWriter, r *http.Request) { func (rs movieRoutes) BackImage(w http.ResponseWriter, r *http.Request) { movie := r.Context().Value(movieKey).(*models.Movie) - qb := models.NewMovieQueryBuilder() - image, _ := qb.GetBackImage(movie.ID, nil) - defaultParam := r.URL.Query().Get("default") - if len(image) == 0 || defaultParam == "true" { + var image []byte + if defaultParam != "true" { + rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { + image, _ = repo.Movie().GetBackImage(movie.ID) + return nil + }) + } + + if len(image) == 0 { _, image, _ = utils.ProcessBase64Image(models.DefaultMovieImage) } @@ -58,9 +71,12 @@ func MovieCtx(next http.Handler) http.Handler { return } - qb := models.NewMovieQueryBuilder() - movie, err := qb.Find(movieID, nil) - if err != nil { + var movie *models.Movie + if err := manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { + var err error + movie, err = repo.Movie().Find(movieID) + return err 
+ }); err != nil { http.Error(w, http.StatusText(404), 404) return } diff --git a/pkg/api/routes_performer.go b/pkg/api/routes_performer.go index d552b2e89..7e50c4551 100644 --- a/pkg/api/routes_performer.go +++ b/pkg/api/routes_performer.go @@ -6,11 +6,14 @@ import ( "strconv" "github.com/go-chi/chi" + "github.com/stashapp/stash/pkg/manager" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/utils" ) -type performerRoutes struct{} +type performerRoutes struct { + txnManager models.TransactionManager +} func (rs performerRoutes) Routes() chi.Router { r := chi.NewRouter() @@ -25,10 +28,16 @@ func (rs performerRoutes) Routes() chi.Router { func (rs performerRoutes) Image(w http.ResponseWriter, r *http.Request) { performer := r.Context().Value(performerKey).(*models.Performer) - qb := models.NewPerformerQueryBuilder() - image, _ := qb.GetPerformerImage(performer.ID, nil) - defaultParam := r.URL.Query().Get("default") + + var image []byte + if defaultParam != "true" { + rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { + image, _ = repo.Performer().GetImage(performer.ID) + return nil + }) + } + if len(image) == 0 || defaultParam == "true" { image, _ = getRandomPerformerImageUsingName(performer.Name.String, performer.Gender.String) } @@ -44,9 +53,12 @@ func PerformerCtx(next http.Handler) http.Handler { return } - qb := models.NewPerformerQueryBuilder() - performer, err := qb.Find(performerID) - if err != nil { + var performer *models.Performer + if err := manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { + var err error + performer, err = repo.Performer().Find(performerID) + return err + }); err != nil { http.Error(w, http.StatusText(404), 404) return } diff --git a/pkg/api/routes_scene.go b/pkg/api/routes_scene.go index e979ecf4b..328983be9 100644 --- a/pkg/api/routes_scene.go +++ b/pkg/api/routes_scene.go @@ -15,7 +15,9 @@ import ( "github.com/stashapp/stash/pkg/utils" ) -type sceneRoutes struct{} +type sceneRoutes struct { + txnManager models.TransactionManager +} func (rs sceneRoutes) Routes() chi.Router { r := chi.NewRouter() @@ -183,8 +185,11 @@ func (rs sceneRoutes) Screenshot(w http.ResponseWriter, r *http.Request) { if screenshotExists { http.ServeFile(w, r, filepath) } else { - qb := models.NewSceneQueryBuilder() - cover, _ := qb.GetSceneCover(scene.ID, nil) + var cover []byte + rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { + cover, _ = repo.Scene().GetCover(scene.ID) + return nil + }) utils.ServeImage(cover, w, r) } } @@ -201,39 +206,48 @@ func (rs sceneRoutes) Webp(w http.ResponseWriter, r *http.Request) { http.ServeFile(w, r, filepath) } -func getChapterVttTitle(marker *models.SceneMarker) string { +func (rs sceneRoutes) getChapterVttTitle(ctx context.Context, marker *models.SceneMarker) string { if marker.Title != "" { return marker.Title } - qb := models.NewTagQueryBuilder() - primaryTag, err := qb.Find(marker.PrimaryTagID, nil) - if err != nil { - // should not happen + var ret string + if err := rs.txnManager.WithReadTxn(ctx, func(repo models.ReaderRepository) error { + qb := repo.Tag() + primaryTag, err := qb.Find(marker.PrimaryTagID) + if err != nil { + return err + } + + ret = primaryTag.Name + + tags, err := qb.FindBySceneMarkerID(marker.ID) + if err != nil { + return err + } + + for _, t := range tags { + ret += ", " + t.Name + } + + return nil + }); err != nil { panic(err) } - ret := primaryTag.Name - - tags, err := 
qb.FindBySceneMarkerID(marker.ID, nil) - if err != nil { - // should not happen - panic(err) - } - - for _, t := range tags { - ret += ", " + t.Name - } - return ret } func (rs sceneRoutes) ChapterVtt(w http.ResponseWriter, r *http.Request) { scene := r.Context().Value(sceneKey).(*models.Scene) - qb := models.NewSceneMarkerQueryBuilder() - sceneMarkers, err := qb.FindBySceneID(scene.ID, nil) - if err != nil { - panic("invalid scene markers for chapter vtt") + var sceneMarkers []*models.SceneMarker + if err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { + var err error + sceneMarkers, err = repo.SceneMarker().FindBySceneID(scene.ID) + return err + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return } vttLines := []string{"WEBVTT", ""} @@ -241,7 +255,7 @@ func (rs sceneRoutes) ChapterVtt(w http.ResponseWriter, r *http.Request) { vttLines = append(vttLines, strconv.Itoa(i+1)) time := utils.GetVTTTime(marker.Seconds) vttLines = append(vttLines, time+" --> "+time) - vttLines = append(vttLines, getChapterVttTitle(marker)) + vttLines = append(vttLines, rs.getChapterVttTitle(r.Context(), marker)) vttLines = append(vttLines, "") } vtt := strings.Join(vttLines, "\n") @@ -267,11 +281,14 @@ func (rs sceneRoutes) VttSprite(w http.ResponseWriter, r *http.Request) { func (rs sceneRoutes) SceneMarkerStream(w http.ResponseWriter, r *http.Request) { scene := r.Context().Value(sceneKey).(*models.Scene) sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId")) - qb := models.NewSceneMarkerQueryBuilder() - sceneMarker, err := qb.Find(sceneMarkerID) - if err != nil { - logger.Warn("Error when getting scene marker for stream") - http.Error(w, http.StatusText(404), 404) + var sceneMarker *models.SceneMarker + if err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { + var err error + sceneMarker, err = repo.SceneMarker().Find(sceneMarkerID) + return err + }); err != nil { + logger.Warnf("Error when getting scene marker for stream: %s", err.Error()) + http.Error(w, http.StatusText(500), 500) return } filepath := manager.GetInstance().Paths.SceneMarkers.GetStreamPath(scene.GetHash(config.GetVideoFileNamingAlgorithm()), int(sceneMarker.Seconds)) @@ -281,11 +298,14 @@ func (rs sceneRoutes) SceneMarkerStream(w http.ResponseWriter, r *http.Request) func (rs sceneRoutes) SceneMarkerPreview(w http.ResponseWriter, r *http.Request) { scene := r.Context().Value(sceneKey).(*models.Scene) sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId")) - qb := models.NewSceneMarkerQueryBuilder() - sceneMarker, err := qb.Find(sceneMarkerID) - if err != nil { - logger.Warn("Error when getting scene marker for stream") - http.Error(w, http.StatusText(404), 404) + var sceneMarker *models.SceneMarker + if err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { + var err error + sceneMarker, err = repo.SceneMarker().Find(sceneMarkerID) + return err + }); err != nil { + logger.Warnf("Error when getting scene marker for stream: %s", err.Error()) + http.Error(w, http.StatusText(500), 500) return } filepath := manager.GetInstance().Paths.SceneMarkers.GetStreamPreviewImagePath(scene.GetHash(config.GetVideoFileNamingAlgorithm()), int(sceneMarker.Seconds)) @@ -310,17 +330,21 @@ func SceneCtx(next http.Handler) http.Handler { sceneID, _ := strconv.Atoi(sceneIdentifierQueryParam) var scene *models.Scene - qb := models.NewSceneQueryBuilder() - if sceneID == 0 { - // determine checksum/os by 
the length of the query param - if len(sceneIdentifierQueryParam) == 32 { - scene, _ = qb.FindByChecksum(sceneIdentifierQueryParam) + manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { + qb := repo.Scene() + if sceneID == 0 { + // determine checksum/os by the length of the query param + if len(sceneIdentifierQueryParam) == 32 { + scene, _ = qb.FindByChecksum(sceneIdentifierQueryParam) + } else { + scene, _ = qb.FindByOSHash(sceneIdentifierQueryParam) + } } else { - scene, _ = qb.FindByOSHash(sceneIdentifierQueryParam) + scene, _ = qb.Find(sceneID) } - } else { - scene, _ = qb.Find(sceneID) - } + + return nil + }) if scene == nil { http.Error(w, http.StatusText(404), 404) diff --git a/pkg/api/routes_studio.go b/pkg/api/routes_studio.go index e8c72e5c1..840e060ad 100644 --- a/pkg/api/routes_studio.go +++ b/pkg/api/routes_studio.go @@ -6,11 +6,14 @@ import ( "strconv" "github.com/go-chi/chi" + "github.com/stashapp/stash/pkg/manager" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/utils" ) -type studioRoutes struct{} +type studioRoutes struct { + txnManager models.TransactionManager +} func (rs studioRoutes) Routes() chi.Router { r := chi.NewRouter() @@ -25,12 +28,14 @@ func (rs studioRoutes) Routes() chi.Router { func (rs studioRoutes) Image(w http.ResponseWriter, r *http.Request) { studio := r.Context().Value(studioKey).(*models.Studio) - qb := models.NewStudioQueryBuilder() - var image []byte defaultParam := r.URL.Query().Get("default") + var image []byte if defaultParam != "true" { - image, _ = qb.GetStudioImage(studio.ID, nil) + rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { + image, _ = repo.Studio().GetImage(studio.ID) + return nil + }) } if len(image) == 0 { @@ -48,9 +53,12 @@ func StudioCtx(next http.Handler) http.Handler { return } - qb := models.NewStudioQueryBuilder() - studio, err := qb.Find(studioID, nil) - if err != nil { + var studio *models.Studio + if err := manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { + var err error + studio, err = repo.Studio().Find(studioID) + return err + }); err != nil { http.Error(w, http.StatusText(404), 404) return } diff --git a/pkg/api/routes_tag.go b/pkg/api/routes_tag.go index 0364eb95d..42d64271c 100644 --- a/pkg/api/routes_tag.go +++ b/pkg/api/routes_tag.go @@ -6,11 +6,14 @@ import ( "strconv" "github.com/go-chi/chi" + "github.com/stashapp/stash/pkg/manager" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/utils" ) -type tagRoutes struct{} +type tagRoutes struct { + txnManager models.TransactionManager +} func (rs tagRoutes) Routes() chi.Router { r := chi.NewRouter() @@ -25,12 +28,17 @@ func (rs tagRoutes) Routes() chi.Router { func (rs tagRoutes) Image(w http.ResponseWriter, r *http.Request) { tag := r.Context().Value(tagKey).(*models.Tag) - qb := models.NewTagQueryBuilder() - image, _ := qb.GetTagImage(tag.ID, nil) - - // use default image if not present defaultParam := r.URL.Query().Get("default") - if len(image) == 0 || defaultParam == "true" { + + var image []byte + if defaultParam != "true" { + rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { + image, _ = repo.Tag().GetImage(tag.ID) + return nil + }) + } + + if len(image) == 0 { image = models.DefaultTagImage } @@ -45,9 +53,12 @@ func TagCtx(next http.Handler) http.Handler { return } - qb := models.NewTagQueryBuilder() - tag, err := qb.Find(tagID, nil) - if err != nil { + var 
tag *models.Tag + if err := manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error { + var err error + tag, err = repo.Tag().Find(tagID) + return err + }); err != nil { http.Error(w, http.StatusText(404), 404) return } diff --git a/pkg/api/server.go b/pkg/api/server.go index ab34e8353..5fe5b2784 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -134,7 +134,12 @@ func Start() { return true }, }) - gqlHandler := handler.GraphQL(models.NewExecutableSchema(models.Config{Resolvers: &Resolver{}}), recoverFunc, websocketUpgrader) + + txnManager := manager.GetInstance().TxnManager + resolver := &Resolver{ + txnManager: txnManager, + } + gqlHandler := handler.GraphQL(models.NewExecutableSchema(models.Config{Resolvers: resolver}), recoverFunc, websocketUpgrader) r.Handle("/graphql", gqlHandler) r.Handle("/playground", handler.Playground("GraphQL playground", "/graphql")) @@ -145,12 +150,24 @@ func Start() { r.Get(loginEndPoint, getLoginHandler) - r.Mount("/performer", performerRoutes{}.Routes()) - r.Mount("/scene", sceneRoutes{}.Routes()) - r.Mount("/image", imageRoutes{}.Routes()) - r.Mount("/studio", studioRoutes{}.Routes()) - r.Mount("/movie", movieRoutes{}.Routes()) - r.Mount("/tag", tagRoutes{}.Routes()) + r.Mount("/performer", performerRoutes{ + txnManager: txnManager, + }.Routes()) + r.Mount("/scene", sceneRoutes{ + txnManager: txnManager, + }.Routes()) + r.Mount("/image", imageRoutes{ + txnManager: txnManager, + }.Routes()) + r.Mount("/studio", studioRoutes{ + txnManager: txnManager, + }.Routes()) + r.Mount("/movie", movieRoutes{ + txnManager: txnManager, + }.Routes()) + r.Mount("/tag", tagRoutes{ + txnManager: txnManager, + }.Routes()) r.Mount("/downloads", downloadsRoutes{}.Routes()) r.HandleFunc("/css", func(w http.ResponseWriter, r *http.Request) { diff --git a/pkg/database/database.go b/pkg/database/database.go index 231fe22b7..4fef0cedd 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -68,9 +68,9 @@ func Initialize(databasePath string) bool { func open(databasePath string, disableForeignKeys bool) *sqlx.DB { // https://github.com/mattn/go-sqlite3 - url := "file:" + databasePath + url := "file:" + databasePath + "?_journal=WAL" if !disableForeignKeys { - url += "?_fk=true" + url += "&_fk=true" } conn, err := sqlx.Open(sqlite3Driver, url) diff --git a/pkg/gallery/export_test.go b/pkg/gallery/export_test.go index e0b29038a..56b3e540a 100644 --- a/pkg/gallery/export_test.go +++ b/pkg/gallery/export_test.go @@ -6,7 +6,6 @@ import ( "github.com/stashapp/stash/pkg/manager/jsonschema" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models/mocks" - "github.com/stashapp/stash/pkg/models/modelstest" "github.com/stretchr/testify/assert" "testing" @@ -51,18 +50,18 @@ var updateTime time.Time = time.Date(2002, 01, 01, 0, 0, 0, 0, time.UTC) func createFullGallery(id int) models.Gallery { return models.Gallery{ ID: id, - Path: modelstest.NullString(path), + Path: models.NullString(path), Zip: zip, - Title: modelstest.NullString(title), + Title: models.NullString(title), Checksum: checksum, Date: models.SQLiteDate{ String: date, Valid: true, }, - Details: modelstest.NullString(details), - Rating: modelstest.NullInt64(rating), + Details: models.NullString(details), + Rating: models.NullInt64(rating), Organized: organized, - URL: modelstest.NullString(url), + URL: models.NullString(url), CreatedAt: models.SQLiteTimestamp{ Timestamp: createTime, }, @@ -146,7 +145,7 @@ func TestToJSON(t *testing.T) { func 
createStudioGallery(studioID int) models.Gallery { return models.Gallery{ - StudioID: modelstest.NullInt64(int64(studioID)), + StudioID: models.NullInt64(int64(studioID)), } } @@ -180,7 +179,7 @@ func TestGetStudioName(t *testing.T) { studioErr := errors.New("error getting image") mockStudioReader.On("Find", studioID).Return(&models.Studio{ - Name: modelstest.NullString(studioName), + Name: models.NullString(studioName), }, nil).Once() mockStudioReader.On("Find", missingStudioID).Return(nil, nil).Once() mockStudioReader.On("Find", errStudioID).Return(nil, studioErr).Once() diff --git a/pkg/gallery/images.go b/pkg/gallery/images.go deleted file mode 100644 index 5487c797e..000000000 --- a/pkg/gallery/images.go +++ /dev/null @@ -1,30 +0,0 @@ -package gallery - -import ( - "github.com/stashapp/stash/pkg/api/urlbuilders" - "github.com/stashapp/stash/pkg/models" -) - -func GetFiles(g *models.Gallery, baseURL string) []*models.GalleryFilesType { - var galleryFiles []*models.GalleryFilesType - - qb := models.NewImageQueryBuilder() - images, err := qb.FindByGalleryID(g.ID) - if err != nil { - return nil - } - - for i, img := range images { - builder := urlbuilders.NewImageURLBuilder(baseURL, img.ID) - imageURL := builder.GetImageURL() - - galleryFile := models.GalleryFilesType{ - Index: i, - Name: &img.Title.String, - Path: &imageURL, - } - galleryFiles = append(galleryFiles, &galleryFile) - } - - return galleryFiles -} diff --git a/pkg/gallery/import.go b/pkg/gallery/import.go index edc75d3b7..7ad87cdbe 100644 --- a/pkg/gallery/import.go +++ b/pkg/gallery/import.go @@ -15,7 +15,6 @@ type Importer struct { StudioWriter models.StudioReaderWriter PerformerWriter models.PerformerReaderWriter TagWriter models.TagReaderWriter - JoinWriter models.JoinReaderWriter Input jsonschema.Gallery MissingRefBehaviour models.ImportMissingRefEnum @@ -237,29 +236,22 @@ func (i *Importer) createTags(names []string) ([]*models.Tag, error) { func (i *Importer) PostImport(id int) error { if len(i.performers) > 0 { - var performerJoins []models.PerformersGalleries + var performerIDs []int for _, performer := range i.performers { - join := models.PerformersGalleries{ - PerformerID: performer.ID, - GalleryID: id, - } - performerJoins = append(performerJoins, join) + performerIDs = append(performerIDs, performer.ID) } - if err := i.JoinWriter.UpdatePerformersGalleries(id, performerJoins); err != nil { + + if err := i.ReaderWriter.UpdatePerformers(id, performerIDs); err != nil { return fmt.Errorf("failed to associate performers: %s", err.Error()) } } if len(i.tags) > 0 { - var tagJoins []models.GalleriesTags - for _, tag := range i.tags { - join := models.GalleriesTags{ - GalleryID: id, - TagID: tag.ID, - } - tagJoins = append(tagJoins, join) + var tagIDs []int + for _, t := range i.tags { + tagIDs = append(tagIDs, t.ID) } - if err := i.JoinWriter.UpdateGalleriesTags(id, tagJoins); err != nil { + if err := i.ReaderWriter.UpdateTags(id, tagIDs); err != nil { return fmt.Errorf("failed to associate tags: %s", err.Error()) } } diff --git a/pkg/gallery/import_test.go b/pkg/gallery/import_test.go index d88738190..19de14f35 100644 --- a/pkg/gallery/import_test.go +++ b/pkg/gallery/import_test.go @@ -8,7 +8,6 @@ import ( "github.com/stashapp/stash/pkg/manager/jsonschema" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models/mocks" - "github.com/stashapp/stash/pkg/models/modelstest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) @@ -77,17 +76,17 @@ func TestImporterPreImport(t 
*testing.T) { assert.Nil(t, err) expectedGallery := models.Gallery{ - Path: modelstest.NullString(path), + Path: models.NullString(path), Checksum: checksum, - Title: modelstest.NullString(title), + Title: models.NullString(title), Date: models.SQLiteDate{ String: date, Valid: true, }, - Details: modelstest.NullString(details), - Rating: modelstest.NullInt64(rating), + Details: models.NullString(details), + Rating: models.NullInt64(rating), Organized: organized, - URL: modelstest.NullString(url), + URL: models.NullString(url), CreatedAt: models.SQLiteTimestamp{ Timestamp: createdAt, }, @@ -194,7 +193,7 @@ func TestImporterPreImportWithPerformer(t *testing.T) { performerReaderWriter.On("FindByNames", []string{existingPerformerName}, false).Return([]*models.Performer{ { ID: existingPerformerID, - Name: modelstest.NullString(existingPerformerName), + Name: models.NullString(existingPerformerName), }, }, nil).Once() performerReaderWriter.On("FindByNames", []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once() @@ -354,10 +353,10 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) { } func TestImporterPostImportUpdatePerformers(t *testing.T) { - joinReaderWriter := &mocks.JoinReaderWriter{} + galleryReaderWriter := &mocks.GalleryReaderWriter{} i := Importer{ - JoinWriter: joinReaderWriter, + ReaderWriter: galleryReaderWriter, performers: []*models.Performer{ { ID: existingPerformerID, @@ -365,15 +364,10 @@ func TestImporterPostImportUpdatePerformers(t *testing.T) { }, } - updateErr := errors.New("UpdatePerformersGalleries error") + updateErr := errors.New("UpdatePerformers error") - joinReaderWriter.On("UpdatePerformersGalleries", galleryID, []models.PerformersGalleries{ - { - PerformerID: existingPerformerID, - GalleryID: galleryID, - }, - }).Return(nil).Once() - joinReaderWriter.On("UpdatePerformersGalleries", errPerformersID, mock.AnythingOfType("[]models.PerformersGalleries")).Return(updateErr).Once() + galleryReaderWriter.On("UpdatePerformers", galleryID, []int{existingPerformerID}).Return(nil).Once() + galleryReaderWriter.On("UpdatePerformers", errPerformersID, mock.AnythingOfType("[]int")).Return(updateErr).Once() err := i.PostImport(galleryID) assert.Nil(t, err) @@ -381,14 +375,14 @@ func TestImporterPostImportUpdatePerformers(t *testing.T) { err = i.PostImport(errPerformersID) assert.NotNil(t, err) - joinReaderWriter.AssertExpectations(t) + galleryReaderWriter.AssertExpectations(t) } func TestImporterPostImportUpdateTags(t *testing.T) { - joinReaderWriter := &mocks.JoinReaderWriter{} + galleryReaderWriter := &mocks.GalleryReaderWriter{} i := Importer{ - JoinWriter: joinReaderWriter, + ReaderWriter: galleryReaderWriter, tags: []*models.Tag{ { ID: existingTagID, @@ -396,15 +390,10 @@ func TestImporterPostImportUpdateTags(t *testing.T) { }, } - updateErr := errors.New("UpdateGalleriesTags error") + updateErr := errors.New("UpdateTags error") - joinReaderWriter.On("UpdateGalleriesTags", galleryID, []models.GalleriesTags{ - { - TagID: existingTagID, - GalleryID: galleryID, - }, - }).Return(nil).Once() - joinReaderWriter.On("UpdateGalleriesTags", errTagsID, mock.AnythingOfType("[]models.GalleriesTags")).Return(updateErr).Once() + galleryReaderWriter.On("UpdateTags", galleryID, []int{existingTagID}).Return(nil).Once() + galleryReaderWriter.On("UpdateTags", errTagsID, mock.AnythingOfType("[]int")).Return(updateErr).Once() err := i.PostImport(galleryID) assert.Nil(t, err) @@ -412,7 +401,7 @@ func TestImporterPostImportUpdateTags(t *testing.T) { 
err = i.PostImport(errTagsID) assert.NotNil(t, err) - joinReaderWriter.AssertExpectations(t) + galleryReaderWriter.AssertExpectations(t) } func TestImporterFindExistingID(t *testing.T) { @@ -454,11 +443,11 @@ func TestCreate(t *testing.T) { readerWriter := &mocks.GalleryReaderWriter{} gallery := models.Gallery{ - Title: modelstest.NullString(title), + Title: models.NullString(title), } galleryErr := models.Gallery{ - Title: modelstest.NullString(galleryNameErr), + Title: models.NullString(galleryNameErr), } i := Importer{ @@ -488,7 +477,7 @@ func TestUpdate(t *testing.T) { readerWriter := &mocks.GalleryReaderWriter{} gallery := models.Gallery{ - Title: modelstest.NullString(title), + Title: models.NullString(title), } i := Importer{ diff --git a/pkg/gallery/update.go b/pkg/gallery/update.go new file mode 100644 index 000000000..6befc5b1c --- /dev/null +++ b/pkg/gallery/update.go @@ -0,0 +1,23 @@ +package gallery + +import ( + "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/utils" +) + +func UpdateFileModTime(qb models.GalleryWriter, id int, modTime models.NullSQLiteTimestamp) (*models.Gallery, error) { + return qb.UpdatePartial(models.GalleryPartial{ + ID: id, + FileModTime: &modTime, + }) +} + +func AddImage(qb models.GalleryReaderWriter, galleryID int, imageID int) error { + imageIDs, err := qb.GetImageIDs(galleryID) + if err != nil { + return err + } + + imageIDs = utils.IntAppendUnique(imageIDs, imageID) + return qb.UpdateImages(galleryID, imageIDs) +} diff --git a/pkg/image/export_test.go b/pkg/image/export_test.go index b1283298e..9c7ab5ed1 100644 --- a/pkg/image/export_test.go +++ b/pkg/image/export_test.go @@ -6,7 +6,6 @@ import ( "github.com/stashapp/stash/pkg/manager/jsonschema" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models/mocks" - "github.com/stashapp/stash/pkg/models/modelstest" "github.com/stretchr/testify/assert" "testing" @@ -65,14 +64,14 @@ var updateTime time.Time = time.Date(2002, 01, 01, 0, 0, 0, 0, time.UTC) func createFullImage(id int) models.Image { return models.Image{ ID: id, - Title: modelstest.NullString(title), + Title: models.NullString(title), Checksum: checksum, - Height: modelstest.NullInt64(height), + Height: models.NullInt64(height), OCounter: ocounter, - Rating: modelstest.NullInt64(rating), + Rating: models.NullInt64(rating), + Size: models.NullInt64(int64(size)), Organized: organized, - Size: modelstest.NullInt64(int64(size)), - Width: modelstest.NullInt64(width), + Width: models.NullInt64(width), CreatedAt: models.SQLiteTimestamp{ Timestamp: createTime, }, @@ -150,7 +149,7 @@ func TestToJSON(t *testing.T) { func createStudioImage(studioID int) models.Image { return models.Image{ - StudioID: modelstest.NullInt64(int64(studioID)), + StudioID: models.NullInt64(int64(studioID)), } } @@ -184,7 +183,7 @@ func TestGetStudioName(t *testing.T) { studioErr := errors.New("error getting image") mockStudioReader.On("Find", studioID).Return(&models.Studio{ - Name: modelstest.NullString(studioName), + Name: models.NullString(studioName), }, nil).Once() mockStudioReader.On("Find", missingStudioID).Return(nil, nil).Once() mockStudioReader.On("Find", errStudioID).Return(nil, studioErr).Once() diff --git a/pkg/image/import.go b/pkg/image/import.go index a74b561b1..7b0640595 100644 --- a/pkg/image/import.go +++ b/pkg/image/import.go @@ -16,7 +16,6 @@ type Importer struct { GalleryWriter models.GalleryReaderWriter PerformerWriter models.PerformerReaderWriter TagWriter models.TagReaderWriter - JoinWriter 
models.JoinReaderWriter Input jsonschema.Image Path string MissingRefBehaviour models.ImportMissingRefEnum @@ -227,43 +226,33 @@ func (i *Importer) populateTags() error { func (i *Importer) PostImport(id int) error { if len(i.galleries) > 0 { - var galleryJoins []models.GalleriesImages - for _, gallery := range i.galleries { - join := models.GalleriesImages{ - GalleryID: gallery.ID, - ImageID: id, - } - galleryJoins = append(galleryJoins, join) + var galleryIDs []int + for _, g := range i.galleries { + galleryIDs = append(galleryIDs, g.ID) } - if err := i.JoinWriter.UpdateGalleriesImages(id, galleryJoins); err != nil { + + if err := i.ReaderWriter.UpdateGalleries(id, galleryIDs); err != nil { return fmt.Errorf("failed to associate galleries: %s", err.Error()) } } if len(i.performers) > 0 { - var performerJoins []models.PerformersImages + var performerIDs []int for _, performer := range i.performers { - join := models.PerformersImages{ - PerformerID: performer.ID, - ImageID: id, - } - performerJoins = append(performerJoins, join) + performerIDs = append(performerIDs, performer.ID) } - if err := i.JoinWriter.UpdatePerformersImages(id, performerJoins); err != nil { + + if err := i.ReaderWriter.UpdatePerformers(id, performerIDs); err != nil { return fmt.Errorf("failed to associate performers: %s", err.Error()) } } if len(i.tags) > 0 { - var tagJoins []models.ImagesTags - for _, tag := range i.tags { - join := models.ImagesTags{ - ImageID: id, - TagID: tag.ID, - } - tagJoins = append(tagJoins, join) + var tagIDs []int + for _, t := range i.tags { + tagIDs = append(tagIDs, t.ID) } - if err := i.JoinWriter.UpdateImagesTags(id, tagJoins); err != nil { + if err := i.ReaderWriter.UpdateTags(id, tagIDs); err != nil { return fmt.Errorf("failed to associate tags: %s", err.Error()) } } diff --git a/pkg/image/import_test.go b/pkg/image/import_test.go index 53414666e..91978dac8 100644 --- a/pkg/image/import_test.go +++ b/pkg/image/import_test.go @@ -7,7 +7,6 @@ import ( "github.com/stashapp/stash/pkg/manager/jsonschema" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models/mocks" - "github.com/stashapp/stash/pkg/models/modelstest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) @@ -227,7 +226,7 @@ func TestImporterPreImportWithPerformer(t *testing.T) { performerReaderWriter.On("FindByNames", []string{existingPerformerName}, false).Return([]*models.Performer{ { ID: existingPerformerID, - Name: modelstest.NullString(existingPerformerName), + Name: models.NullString(existingPerformerName), }, }, nil).Once() performerReaderWriter.On("FindByNames", []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once() @@ -387,10 +386,10 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) { } func TestImporterPostImportUpdateGallery(t *testing.T) { - joinReaderWriter := &mocks.JoinReaderWriter{} + readerWriter := &mocks.ImageReaderWriter{} i := Importer{ - JoinWriter: joinReaderWriter, + ReaderWriter: readerWriter, galleries: []*models.Gallery{ { ID: existingGalleryID, @@ -398,15 +397,10 @@ func TestImporterPostImportUpdateGallery(t *testing.T) { }, } - updateErr := errors.New("UpdateGalleriesImages error") + updateErr := errors.New("UpdateGalleries error") - joinReaderWriter.On("UpdateGalleriesImages", imageID, []models.GalleriesImages{ - { - GalleryID: existingGalleryID, - ImageID: imageID, - }, - }).Return(nil).Once() - joinReaderWriter.On("UpdateGalleriesImages", errGalleriesID, 
mock.AnythingOfType("[]models.GalleriesImages")).Return(updateErr).Once() + readerWriter.On("UpdateGalleries", imageID, []int{existingGalleryID}).Return(nil).Once() + readerWriter.On("UpdateGalleries", errGalleriesID, mock.AnythingOfType("[]int")).Return(updateErr).Once() err := i.PostImport(imageID) assert.Nil(t, err) @@ -414,14 +408,14 @@ func TestImporterPostImportUpdateGallery(t *testing.T) { err = i.PostImport(errGalleriesID) assert.NotNil(t, err) - joinReaderWriter.AssertExpectations(t) + readerWriter.AssertExpectations(t) } func TestImporterPostImportUpdatePerformers(t *testing.T) { - joinReaderWriter := &mocks.JoinReaderWriter{} + readerWriter := &mocks.ImageReaderWriter{} i := Importer{ - JoinWriter: joinReaderWriter, + ReaderWriter: readerWriter, performers: []*models.Performer{ { ID: existingPerformerID, @@ -429,15 +423,10 @@ func TestImporterPostImportUpdatePerformers(t *testing.T) { }, } - updateErr := errors.New("UpdatePerformersImages error") + updateErr := errors.New("UpdatePerformers error") - joinReaderWriter.On("UpdatePerformersImages", imageID, []models.PerformersImages{ - { - PerformerID: existingPerformerID, - ImageID: imageID, - }, - }).Return(nil).Once() - joinReaderWriter.On("UpdatePerformersImages", errPerformersID, mock.AnythingOfType("[]models.PerformersImages")).Return(updateErr).Once() + readerWriter.On("UpdatePerformers", imageID, []int{existingPerformerID}).Return(nil).Once() + readerWriter.On("UpdatePerformers", errPerformersID, mock.AnythingOfType("[]int")).Return(updateErr).Once() err := i.PostImport(imageID) assert.Nil(t, err) @@ -445,14 +434,14 @@ func TestImporterPostImportUpdatePerformers(t *testing.T) { err = i.PostImport(errPerformersID) assert.NotNil(t, err) - joinReaderWriter.AssertExpectations(t) + readerWriter.AssertExpectations(t) } func TestImporterPostImportUpdateTags(t *testing.T) { - joinReaderWriter := &mocks.JoinReaderWriter{} + readerWriter := &mocks.ImageReaderWriter{} i := Importer{ - JoinWriter: joinReaderWriter, + ReaderWriter: readerWriter, tags: []*models.Tag{ { ID: existingTagID, @@ -460,15 +449,10 @@ func TestImporterPostImportUpdateTags(t *testing.T) { }, } - updateErr := errors.New("UpdateImagesTags error") + updateErr := errors.New("UpdateTags error") - joinReaderWriter.On("UpdateImagesTags", imageID, []models.ImagesTags{ - { - TagID: existingTagID, - ImageID: imageID, - }, - }).Return(nil).Once() - joinReaderWriter.On("UpdateImagesTags", errTagsID, mock.AnythingOfType("[]models.ImagesTags")).Return(updateErr).Once() + readerWriter.On("UpdateTags", imageID, []int{existingTagID}).Return(nil).Once() + readerWriter.On("UpdateTags", errTagsID, mock.AnythingOfType("[]int")).Return(updateErr).Once() err := i.PostImport(imageID) assert.Nil(t, err) @@ -476,7 +460,7 @@ func TestImporterPostImportUpdateTags(t *testing.T) { err = i.PostImport(errTagsID) assert.NotNil(t, err) - joinReaderWriter.AssertExpectations(t) + readerWriter.AssertExpectations(t) } func TestImporterFindExistingID(t *testing.T) { @@ -518,11 +502,11 @@ func TestCreate(t *testing.T) { readerWriter := &mocks.ImageReaderWriter{} image := models.Image{ - Title: modelstest.NullString(title), + Title: models.NullString(title), } imageErr := models.Image{ - Title: modelstest.NullString(imageNameErr), + Title: models.NullString(imageNameErr), } i := Importer{ @@ -553,11 +537,11 @@ func TestUpdate(t *testing.T) { readerWriter := &mocks.ImageReaderWriter{} image := models.Image{ - Title: modelstest.NullString(title), + Title: models.NullString(title), } imageErr := 
models.Image{ - Title: modelstest.NullString(imageNameErr), + Title: models.NullString(imageNameErr), } i := Importer{ diff --git a/pkg/image/update.go b/pkg/image/update.go new file mode 100644 index 000000000..3728d3187 --- /dev/null +++ b/pkg/image/update.go @@ -0,0 +1,10 @@ +package image + +import "github.com/stashapp/stash/pkg/models" + +func UpdateFileModTime(qb models.ImageWriter, id int, modTime models.NullSQLiteTimestamp) (*models.Image, error) { + return qb.Update(models.ImagePartial{ + ID: id, + FileModTime: &modTime, + }) +} diff --git a/pkg/manager/checksum.go b/pkg/manager/checksum.go index 87fd47962..244549dab 100644 --- a/pkg/manager/checksum.go +++ b/pkg/manager/checksum.go @@ -1,6 +1,7 @@ package manager import ( + "context" "errors" "github.com/spf13/viper" @@ -9,13 +10,16 @@ import ( "github.com/stashapp/stash/pkg/models" ) -func setInitialMD5Config() { +func setInitialMD5Config(txnManager models.TransactionManager) { // if there are no scene files in the database, then default the // VideoFileNamingAlgorithm config setting to oshash and calculateMD5 to // false, otherwise set them to true for backwards compatibility purposes - sqb := models.NewSceneQueryBuilder() - count, err := sqb.Count() - if err != nil { + var count int + if err := txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + var err error + count, err = r.Scene().Count() + return err + }); err != nil { logger.Errorf("Error while counting scenes: %s", err.Error()) return } @@ -43,28 +47,30 @@ func setInitialMD5Config() { // // Likewise, if VideoFileNamingAlgorithm is set to oshash, then this function // will ensure that all oshash values are set on all scenes. -func ValidateVideoFileNamingAlgorithm(newValue models.HashAlgorithm) error { +func ValidateVideoFileNamingAlgorithm(txnManager models.TransactionManager, newValue models.HashAlgorithm) error { // if algorithm is being set to MD5, then all checksums must be present - qb := models.NewSceneQueryBuilder() - if newValue == models.HashAlgorithmMd5 { - missingMD5, err := qb.CountMissingChecksum() - if err != nil { - return err + return txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + qb := r.Scene() + if newValue == models.HashAlgorithmMd5 { + missingMD5, err := qb.CountMissingChecksum() + if err != nil { + return err + } + + if missingMD5 > 0 { + return errors.New("some checksums are missing on scenes. Run Scan with calculateMD5 set to true") + } + } else if newValue == models.HashAlgorithmOshash { + missingOSHash, err := qb.CountMissingOSHash() + if err != nil { + return err + } + + if missingOSHash > 0 { + return errors.New("some oshash values are missing on scenes. Run Scan to populate") + } } - if missingMD5 > 0 { - return errors.New("some checksums are missing on scenes. Run Scan with calculateMD5 set to true") - } - } else if newValue == models.HashAlgorithmOshash { - missingOSHash, err := qb.CountMissingOSHash() - if err != nil { - return err - } - - if missingOSHash > 0 { - return errors.New("some oshash values are missing on scenes. 
Run Scan to populate") - } - } - - return nil + return nil + }) } diff --git a/pkg/manager/filename_parser.go b/pkg/manager/filename_parser.go index eae6ba365..35f8c6297 100644 --- a/pkg/manager/filename_parser.go +++ b/pkg/manager/filename_parser.go @@ -433,12 +433,6 @@ type SceneFilenameParser struct { studioCache map[string]*models.Studio movieCache map[string]*models.Movie tagCache map[string]*models.Tag - - performerQuery performerQueryer - sceneQuery sceneQueryer - tagQuery tagQueryer - studioQuery studioQueryer - movieQuery movieQueryer } func NewSceneFilenameParser(filter *models.FindFilterType, config models.SceneParserInput) *SceneFilenameParser { @@ -455,21 +449,6 @@ func NewSceneFilenameParser(filter *models.FindFilterType, config models.ScenePa p.initWhiteSpaceRegex() - performerQuery := models.NewPerformerQueryBuilder() - p.performerQuery = &performerQuery - - sceneQuery := models.NewSceneQueryBuilder() - p.sceneQuery = &sceneQuery - - tagQuery := models.NewTagQueryBuilder() - p.tagQuery = &tagQuery - - studioQuery := models.NewStudioQueryBuilder() - p.studioQuery = &studioQuery - - movieQuery := models.NewMovieQueryBuilder() - p.movieQuery = &movieQuery - return p } @@ -489,7 +468,7 @@ func (p *SceneFilenameParser) initWhiteSpaceRegex() { } } -func (p *SceneFilenameParser) Parse() ([]*models.SceneParserResult, int, error) { +func (p *SceneFilenameParser) Parse(repo models.ReaderRepository) ([]*models.SceneParserResult, int, error) { // perform the query to find the scenes mapper, err := newParseMapper(p.Pattern, p.ParserInput.IgnoreWords) @@ -499,14 +478,17 @@ func (p *SceneFilenameParser) Parse() ([]*models.SceneParserResult, int, error) p.Filter.Q = &mapper.regexString - scenes, total := p.sceneQuery.QueryByPathRegex(p.Filter) + scenes, total, err := repo.Scene().QueryByPathRegex(p.Filter) + if err != nil { + return nil, 0, err + } - ret := p.parseScenes(scenes, mapper) + ret := p.parseScenes(repo, scenes, mapper) return ret, total, nil } -func (p *SceneFilenameParser) parseScenes(scenes []*models.Scene, mapper *parseMapper) []*models.SceneParserResult { +func (p *SceneFilenameParser) parseScenes(repo models.ReaderRepository, scenes []*models.Scene, mapper *parseMapper) []*models.SceneParserResult { var ret []*models.SceneParserResult for _, scene := range scenes { sceneHolder := mapper.parse(scene) @@ -515,7 +497,7 @@ func (p *SceneFilenameParser) parseScenes(scenes []*models.Scene, mapper *parseM r := &models.SceneParserResult{ Scene: scene, } - p.setParserResult(*sceneHolder, r) + p.setParserResult(repo, *sceneHolder, r) if r != nil { ret = append(ret, r) @@ -536,7 +518,7 @@ func (p SceneFilenameParser) replaceWhitespaceCharacters(value string) string { return value } -func (p *SceneFilenameParser) queryPerformer(performerName string) *models.Performer { +func (p *SceneFilenameParser) queryPerformer(qb models.PerformerReader, performerName string) *models.Performer { // massage the performer name performerName = delimiterRE.ReplaceAllString(performerName, " ") @@ -546,7 +528,7 @@ func (p *SceneFilenameParser) queryPerformer(performerName string) *models.Perfo } // perform an exact match and grab the first - performers, _ := p.performerQuery.FindByNames([]string{performerName}, nil, true) + performers, _ := qb.FindByNames([]string{performerName}, true) var ret *models.Performer if len(performers) > 0 { @@ -559,7 +541,7 @@ func (p *SceneFilenameParser) queryPerformer(performerName string) *models.Perfo return ret } -func (p *SceneFilenameParser) queryStudio(studioName 
string) *models.Studio { +func (p *SceneFilenameParser) queryStudio(qb models.StudioReader, studioName string) *models.Studio { // massage the performer name studioName = delimiterRE.ReplaceAllString(studioName, " ") @@ -568,7 +550,7 @@ func (p *SceneFilenameParser) queryStudio(studioName string) *models.Studio { return ret } - ret, _ := p.studioQuery.FindByName(studioName, nil, true) + ret, _ := qb.FindByName(studioName, true) // add result to cache p.studioCache[studioName] = ret @@ -576,7 +558,7 @@ func (p *SceneFilenameParser) queryStudio(studioName string) *models.Studio { return ret } -func (p *SceneFilenameParser) queryMovie(movieName string) *models.Movie { +func (p *SceneFilenameParser) queryMovie(qb models.MovieReader, movieName string) *models.Movie { // massage the movie name movieName = delimiterRE.ReplaceAllString(movieName, " ") @@ -585,7 +567,7 @@ func (p *SceneFilenameParser) queryMovie(movieName string) *models.Movie { return ret } - ret, _ := p.movieQuery.FindByName(movieName, nil, true) + ret, _ := qb.FindByName(movieName, true) // add result to cache p.movieCache[movieName] = ret @@ -593,7 +575,7 @@ func (p *SceneFilenameParser) queryMovie(movieName string) *models.Movie { return ret } -func (p *SceneFilenameParser) queryTag(tagName string) *models.Tag { +func (p *SceneFilenameParser) queryTag(qb models.TagReader, tagName string) *models.Tag { // massage the performer name tagName = delimiterRE.ReplaceAllString(tagName, " ") @@ -603,7 +585,7 @@ func (p *SceneFilenameParser) queryTag(tagName string) *models.Tag { } // match tag name exactly - ret, _ := p.tagQuery.FindByName(tagName, nil, true) + ret, _ := qb.FindByName(tagName, true) // add result to cache p.tagCache[tagName] = ret @@ -611,12 +593,12 @@ func (p *SceneFilenameParser) queryTag(tagName string) *models.Tag { return ret } -func (p *SceneFilenameParser) setPerformers(h sceneHolder, result *models.SceneParserResult) { +func (p *SceneFilenameParser) setPerformers(qb models.PerformerReader, h sceneHolder, result *models.SceneParserResult) { // query for each performer performersSet := make(map[int]bool) for _, performerName := range h.performers { if performerName != "" { - performer := p.queryPerformer(performerName) + performer := p.queryPerformer(qb, performerName) if performer != nil { if _, found := performersSet[performer.ID]; !found { result.PerformerIds = append(result.PerformerIds, strconv.Itoa(performer.ID)) @@ -627,12 +609,12 @@ func (p *SceneFilenameParser) setPerformers(h sceneHolder, result *models.SceneP } } -func (p *SceneFilenameParser) setTags(h sceneHolder, result *models.SceneParserResult) { +func (p *SceneFilenameParser) setTags(qb models.TagReader, h sceneHolder, result *models.SceneParserResult) { // query for each performer tagsSet := make(map[int]bool) for _, tagName := range h.tags { if tagName != "" { - tag := p.queryTag(tagName) + tag := p.queryTag(qb, tagName) if tag != nil { if _, found := tagsSet[tag.ID]; !found { result.TagIds = append(result.TagIds, strconv.Itoa(tag.ID)) @@ -643,23 +625,23 @@ func (p *SceneFilenameParser) setTags(h sceneHolder, result *models.SceneParserR } } -func (p *SceneFilenameParser) setStudio(h sceneHolder, result *models.SceneParserResult) { +func (p *SceneFilenameParser) setStudio(qb models.StudioReader, h sceneHolder, result *models.SceneParserResult) { // query for each performer if h.studio != "" { - studio := p.queryStudio(h.studio) + studio := p.queryStudio(qb, h.studio) if studio != nil { - studioId := strconv.Itoa(studio.ID) - result.StudioID = 
&studioId + studioID := strconv.Itoa(studio.ID) + result.StudioID = &studioID } } } -func (p *SceneFilenameParser) setMovies(h sceneHolder, result *models.SceneParserResult) { +func (p *SceneFilenameParser) setMovies(qb models.MovieReader, h sceneHolder, result *models.SceneParserResult) { // query for each movie moviesSet := make(map[int]bool) for _, movieName := range h.movies { if movieName != "" { - movie := p.queryMovie(movieName) + movie := p.queryMovie(qb, movieName) if movie != nil { if _, found := moviesSet[movie.ID]; !found { result.Movies = append(result.Movies, &models.SceneMovieID{ @@ -672,7 +654,7 @@ func (p *SceneFilenameParser) setMovies(h sceneHolder, result *models.SceneParse } } -func (p *SceneFilenameParser) setParserResult(h sceneHolder, result *models.SceneParserResult) { +func (p *SceneFilenameParser) setParserResult(repo models.ReaderRepository, h sceneHolder, result *models.SceneParserResult) { if h.result.Title.Valid { title := h.result.Title.String title = p.replaceWhitespaceCharacters(title) @@ -694,15 +676,15 @@ func (p *SceneFilenameParser) setParserResult(h sceneHolder, result *models.Scen } if len(h.performers) > 0 { - p.setPerformers(h, result) + p.setPerformers(repo.Performer(), h, result) } if len(h.tags) > 0 { - p.setTags(h, result) + p.setTags(repo.Tag(), h, result) } - p.setStudio(h, result) + p.setStudio(repo.Studio(), h, result) if len(h.movies) > 0 { - p.setMovies(h, result) + p.setMovies(repo.Movie(), h, result) } } diff --git a/pkg/manager/image.go b/pkg/manager/image.go index 4c6b3682b..8af987233 100644 --- a/pkg/manager/image.go +++ b/pkg/manager/image.go @@ -5,43 +5,11 @@ import ( "os" "strings" - "github.com/jmoiron/sqlx" - "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/utils" ) -// DestroyImage deletes an image and its associated relationships from the -// database. -func DestroyImage(imageID int, tx *sqlx.Tx) error { - qb := models.NewImageQueryBuilder() - jqb := models.NewJoinsQueryBuilder() - - _, err := qb.Find(imageID) - if err != nil { - return err - } - - if err := jqb.DestroyImagesTags(imageID, tx); err != nil { - return err - } - - if err := jqb.DestroyPerformersImages(imageID, tx); err != nil { - return err - } - - if err := jqb.DestroyImageGalleries(imageID, tx); err != nil { - return err - } - - if err := qb.Destroy(imageID, tx); err != nil { - return err - } - - return nil -} - // DeleteGeneratedImageFiles deletes generated files for the provided image. 
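// --- illustrative sketch (editor's note, not part of the patch) ---
// With this change SceneFilenameParser no longer holds its own query builders;
// Parse now takes a models.ReaderRepository, so a caller is expected to open a
// read transaction and pass the repository in. The helper name and its
// parameters below are assumptions for illustration only.
package manager

import (
	"context"

	"github.com/stashapp/stash/pkg/models"
)

func parseSceneFilenames(ctx context.Context, txnManager models.TransactionManager,
	filter *models.FindFilterType, config models.SceneParserInput) ([]*models.SceneParserResult, int, error) {
	p := NewSceneFilenameParser(filter, config)

	var ret []*models.SceneParserResult
	var count int
	if err := txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
		// all reader access (Scene, Performer, Studio, Movie, Tag) now goes
		// through the repository bound to this transaction
		var err error
		ret, count, err = p.Parse(r)
		return err
	}); err != nil {
		return nil, 0, err
	}
	return ret, count, nil
}
// --- end sketch ---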
func DeleteGeneratedImageFiles(image *models.Image) { thumbPath := GetInstance().Paths.Generated.GetThumbnailPath(image.Checksum, models.DefaultGthumbWidth) diff --git a/pkg/manager/manager.go b/pkg/manager/manager.go index 5f0679fc3..65deba48b 100644 --- a/pkg/manager/manager.go +++ b/pkg/manager/manager.go @@ -10,8 +10,10 @@ import ( "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/manager/config" "github.com/stashapp/stash/pkg/manager/paths" + "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/plugin" "github.com/stashapp/stash/pkg/scraper" + "github.com/stashapp/stash/pkg/sqlite" "github.com/stashapp/stash/pkg/utils" ) @@ -26,6 +28,8 @@ type singleton struct { ScraperCache *scraper.Cache DownloadStore *DownloadStore + + TxnManager models.TransactionManager } var instance *singleton @@ -53,11 +57,12 @@ func Initialize() *singleton { Status: TaskStatus{Status: Idle, Progress: -1}, Paths: paths.NewPaths(), - PluginCache: initPluginCache(), - ScraperCache: initScraperCache(), + PluginCache: initPluginCache(), DownloadStore: NewDownloadStore(), + TxnManager: sqlite.NewTransactionManager(), } + instance.ScraperCache = instance.initScraperCache() instance.RefreshConfig() @@ -180,13 +185,13 @@ func initPluginCache() *plugin.Cache { } // initScraperCache initializes a new scraper cache and returns it. -func initScraperCache() *scraper.Cache { +func (s *singleton) initScraperCache() *scraper.Cache { scraperConfig := scraper.GlobalConfig{ Path: config.GetScrapersPath(), UserAgent: config.GetScraperUserAgent(), CDPPath: config.GetScraperCDPPath(), } - ret, err := scraper.NewCache(scraperConfig) + ret, err := scraper.NewCache(scraperConfig, s.TxnManager) if err != nil { logger.Errorf("Error reading scraper configs: %s", err.Error()) @@ -210,5 +215,5 @@ func (s *singleton) RefreshConfig() { // RefreshScraperCache refreshes the scraper cache. Call this when scraper // configuration changes. 
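// --- illustrative sketch (editor's note, not part of the patch) ---
// The manager singleton now owns a models.TransactionManager (created with
// sqlite.NewTransactionManager in Initialize), and tasks read from the
// database through it instead of constructing query builders directly. A
// minimal example of that pattern follows; countScenes is a hypothetical
// helper, mirroring how setInitialMD5Config is rewritten above.
package manager

import (
	"context"

	"github.com/stashapp/stash/pkg/models"
)

func countScenes(s *singleton) (int, error) {
	var count int
	err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
		// the repository is only valid for the lifetime of the transaction
		var err error
		count, err = r.Scene().Count()
		return err
	})
	return count, err
}
// --- end sketch ---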
func (s *singleton) RefreshScraperCache() { - s.ScraperCache = initScraperCache() + s.ScraperCache = s.initScraperCache() } diff --git a/pkg/manager/manager_tasks.go b/pkg/manager/manager_tasks.go index ace48ffc0..f5332d1b7 100644 --- a/pkg/manager/manager_tasks.go +++ b/pkg/manager/manager_tasks.go @@ -1,6 +1,7 @@ package manager import ( + "context" "errors" "fmt" "os" @@ -119,7 +120,7 @@ func (s *singleton) neededScan(paths []*models.StashConfig) (total *int, newFile for _, sp := range paths { err := walkFilesToScan(sp, func(path string, info os.FileInfo, err error) error { t++ - task := ScanTask{FilePath: path} + task := ScanTask{FilePath: path, TxnManager: s.TxnManager} if !task.doesPathExist() { n++ } @@ -211,7 +212,17 @@ func (s *singleton) Scan(input models.ScanMetadataInput) { instance.Paths.Generated.EnsureTmpDir() wg.Add() - task := ScanTask{FilePath: path, UseFileMetadata: input.UseFileMetadata, StripFileExtension: input.StripFileExtension, fileNamingAlgorithm: fileNamingAlgo, calculateMD5: calculateMD5, GeneratePreview: input.ScanGeneratePreviews, GenerateImagePreview: input.ScanGenerateImagePreviews, GenerateSprite: input.ScanGenerateSprites} + task := ScanTask{ + TxnManager: s.TxnManager, + FilePath: path, + UseFileMetadata: input.UseFileMetadata, + StripFileExtension: input.StripFileExtension, + fileNamingAlgorithm: fileNamingAlgo, + calculateMD5: calculateMD5, + GeneratePreview: input.ScanGeneratePreviews, + GenerateImagePreview: input.ScanGenerateImagePreviews, + GenerateSprite: input.ScanGenerateSprites, + } go task.Start(&wg) return nil @@ -240,7 +251,11 @@ func (s *singleton) Scan(input models.ScanMetadataInput) { for _, path := range galleries { wg.Add() - task := ScanTask{FilePath: path, UseFileMetadata: false} + task := ScanTask{ + TxnManager: s.TxnManager, + FilePath: path, + UseFileMetadata: false, + } go task.associateGallery(&wg) wg.Wait() } @@ -261,6 +276,7 @@ func (s *singleton) Import() { var wg sync.WaitGroup wg.Add(1) task := ImportTask{ + txnManager: s.TxnManager, BaseDir: config.GetMetadataPath(), Reset: true, DuplicateBehaviour: models.ImportDuplicateEnumFail, @@ -284,7 +300,11 @@ func (s *singleton) Export() { var wg sync.WaitGroup wg.Add(1) - task := ExportTask{full: true, fileNamingAlgorithm: config.GetVideoFileNamingAlgorithm()} + task := ExportTask{ + txnManager: s.TxnManager, + full: true, + fileNamingAlgorithm: config.GetVideoFileNamingAlgorithm(), + } go task.Start(&wg) wg.Wait() }() @@ -344,29 +364,47 @@ func (s *singleton) Generate(input models.GenerateMetadataInput) { s.Status.SetStatus(Generate) s.Status.indefiniteProgress() - qb := models.NewSceneQueryBuilder() - mqb := models.NewSceneMarkerQueryBuilder() - //this.job.total = await ObjectionUtils.getCount(Scene); instance.Paths.Generated.EnsureTmpDir() - sceneIDs := utils.StringSliceToIntSlice(input.SceneIDs) - markerIDs := utils.StringSliceToIntSlice(input.MarkerIDs) + sceneIDs, err := utils.StringSliceToIntSlice(input.SceneIDs) + if err != nil { + logger.Error(err.Error()) + } + markerIDs, err := utils.StringSliceToIntSlice(input.MarkerIDs) + if err != nil { + logger.Error(err.Error()) + } go func() { defer s.returnToIdleState() var scenes []*models.Scene var err error + var markers []*models.SceneMarker - if len(sceneIDs) > 0 { - scenes, err = qb.FindMany(sceneIDs) - } else { - scenes, err = qb.All() - } + if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + qb := r.Scene() + if len(sceneIDs) > 0 { + scenes, err = qb.FindMany(sceneIDs) + } else { + 
scenes, err = qb.All() + } - if err != nil { - logger.Errorf("failed to get scenes for generate") + if err != nil { + return err + } + + if len(markerIDs) > 0 { + markers, err = r.SceneMarker().FindMany(markerIDs) + if err != nil { + return err + } + } + + return nil + }); err != nil { + logger.Error(err.Error()) return } @@ -377,14 +415,7 @@ func (s *singleton) Generate(input models.GenerateMetadataInput) { s.Status.Progress = 0 lenScenes := len(scenes) - total := lenScenes - - var markers []*models.SceneMarker - if len(markerIDs) > 0 { - markers, err = mqb.FindMany(markerIDs) - - total += len(markers) - } + total := lenScenes + len(markers) if s.Status.stopping { logger.Info("Stopping due to user request") @@ -429,7 +460,11 @@ func (s *singleton) Generate(input models.GenerateMetadataInput) { } if input.Sprites { - task := GenerateSpriteTask{Scene: *scene, Overwrite: overwrite, fileNamingAlgorithm: fileNamingAlgo} + task := GenerateSpriteTask{ + Scene: *scene, + Overwrite: overwrite, + fileNamingAlgorithm: fileNamingAlgo, + } wg.Add() go task.Start(&wg) } @@ -448,13 +483,22 @@ func (s *singleton) Generate(input models.GenerateMetadataInput) { if input.Markers { wg.Add() - task := GenerateMarkersTask{Scene: scene, Overwrite: overwrite, fileNamingAlgorithm: fileNamingAlgo} + task := GenerateMarkersTask{ + TxnManager: s.TxnManager, + Scene: scene, + Overwrite: overwrite, + fileNamingAlgorithm: fileNamingAlgo, + } go task.Start(&wg) } if input.Transcodes { wg.Add() - task := GenerateTranscodeTask{Scene: *scene, Overwrite: overwrite, fileNamingAlgorithm: fileNamingAlgo} + task := GenerateTranscodeTask{ + Scene: *scene, + Overwrite: overwrite, + fileNamingAlgorithm: fileNamingAlgo, + } go task.Start(&wg) } } @@ -474,7 +518,12 @@ func (s *singleton) Generate(input models.GenerateMetadataInput) { } wg.Add() - task := GenerateMarkersTask{Marker: marker, Overwrite: overwrite, fileNamingAlgorithm: fileNamingAlgo} + task := GenerateMarkersTask{ + TxnManager: s.TxnManager, + Marker: marker, + Overwrite: overwrite, + fileNamingAlgorithm: fileNamingAlgo, + } go task.Start(&wg) } @@ -502,7 +551,6 @@ func (s *singleton) generateScreenshot(sceneId string, at *float64) { s.Status.SetStatus(Generate) s.Status.indefiniteProgress() - qb := models.NewSceneQueryBuilder() instance.Paths.Generated.EnsureTmpDir() go func() { @@ -514,13 +562,18 @@ func (s *singleton) generateScreenshot(sceneId string, at *float64) { return } - scene, err := qb.Find(sceneIdInt) - if err != nil || scene == nil { - logger.Errorf("failed to get scene for generate") + var scene *models.Scene + if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + var err error + scene, err = r.Scene().Find(sceneIdInt) + return err + }); err != nil || scene == nil { + logger.Errorf("failed to get scene for generate: %s", err.Error()) return } task := GenerateScreenshotTask{ + txnManager: s.TxnManager, Scene: *scene, ScreenshotAt: at, fileNamingAlgorithm: config.GetVideoFileNamingAlgorithm(), @@ -551,29 +604,36 @@ func (s *singleton) AutoTag(performerIds []string, studioIds []string, tagIds [] studioCount := len(studioIds) tagCount := len(tagIds) - performerQuery := models.NewPerformerQueryBuilder() - studioQuery := models.NewTagQueryBuilder() - tagQuery := models.NewTagQueryBuilder() + if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + performerQuery := r.Performer() + studioQuery := r.Studio() + tagQuery := r.Tag() - const wildcard = "*" - var err error - if performerCount 
== 1 && performerIds[0] == wildcard { - performerCount, err = performerQuery.Count() - if err != nil { - logger.Errorf("Error getting performer count: %s", err.Error()) + const wildcard = "*" + var err error + if performerCount == 1 && performerIds[0] == wildcard { + performerCount, err = performerQuery.Count() + if err != nil { + return fmt.Errorf("Error getting performer count: %s", err.Error()) + } } - } - if studioCount == 1 && studioIds[0] == wildcard { - studioCount, err = studioQuery.Count() - if err != nil { - logger.Errorf("Error getting studio count: %s", err.Error()) + if studioCount == 1 && studioIds[0] == wildcard { + studioCount, err = studioQuery.Count() + if err != nil { + return fmt.Errorf("Error getting studio count: %s", err.Error()) + } } - } - if tagCount == 1 && tagIds[0] == wildcard { - tagCount, err = tagQuery.Count() - if err != nil { - logger.Errorf("Error getting tag count: %s", err.Error()) + if tagCount == 1 && tagIds[0] == wildcard { + tagCount, err = tagQuery.Count() + if err != nil { + return fmt.Errorf("Error getting tag count: %s", err.Error()) + } } + + return nil + }); err != nil { + logger.Error(err.Error()) + return } total := performerCount + studioCount + tagCount @@ -586,36 +646,44 @@ func (s *singleton) AutoTag(performerIds []string, studioIds []string, tagIds [] } func (s *singleton) autoTagPerformers(performerIds []string) { - performerQuery := models.NewPerformerQueryBuilder() - var wg sync.WaitGroup for _, performerId := range performerIds { var performers []*models.Performer - if performerId == "*" { - var err error - performers, err = performerQuery.All() - if err != nil { - logger.Errorf("Error querying performers: %s", err.Error()) - continue - } - } else { - performerIdInt, err := strconv.Atoi(performerId) - if err != nil { - logger.Errorf("Error parsing performer id %s: %s", performerId, err.Error()) - continue + + if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + performerQuery := r.Performer() + + if performerId == "*" { + var err error + performers, err = performerQuery.All() + if err != nil { + return fmt.Errorf("Error querying performers: %s", err.Error()) + } + } else { + performerIdInt, err := strconv.Atoi(performerId) + if err != nil { + return fmt.Errorf("Error parsing performer id %s: %s", performerId, err.Error()) + } + + performer, err := performerQuery.Find(performerIdInt) + if err != nil { + return fmt.Errorf("Error finding performer id %s: %s", performerId, err.Error()) + } + performers = append(performers, performer) } - performer, err := performerQuery.Find(performerIdInt) - if err != nil { - logger.Errorf("Error finding performer id %s: %s", performerId, err.Error()) - continue - } - performers = append(performers, performer) + return nil + }); err != nil { + logger.Error(err.Error()) + continue } for _, performer := range performers { wg.Add(1) - task := AutoTagPerformerTask{performer: performer} + task := AutoTagPerformerTask{ + txnManager: s.TxnManager, + performer: performer, + } go task.Start(&wg) wg.Wait() @@ -625,36 +693,43 @@ func (s *singleton) autoTagPerformers(performerIds []string) { } func (s *singleton) autoTagStudios(studioIds []string) { - studioQuery := models.NewStudioQueryBuilder() - var wg sync.WaitGroup for _, studioId := range studioIds { var studios []*models.Studio - if studioId == "*" { - var err error - studios, err = studioQuery.All() - if err != nil { - logger.Errorf("Error querying studios: %s", err.Error()) - continue - } - } else { - studioIdInt, 
err := strconv.Atoi(studioId) - if err != nil { - logger.Errorf("Error parsing studio id %s: %s", studioId, err.Error()) - continue + + if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + studioQuery := r.Studio() + if studioId == "*" { + var err error + studios, err = studioQuery.All() + if err != nil { + return fmt.Errorf("Error querying studios: %s", err.Error()) + } + } else { + studioIdInt, err := strconv.Atoi(studioId) + if err != nil { + return fmt.Errorf("Error parsing studio id %s: %s", studioId, err.Error()) + } + + studio, err := studioQuery.Find(studioIdInt) + if err != nil { + return fmt.Errorf("Error finding studio id %s: %s", studioId, err.Error()) + } + studios = append(studios, studio) } - studio, err := studioQuery.Find(studioIdInt, nil) - if err != nil { - logger.Errorf("Error finding studio id %s: %s", studioId, err.Error()) - continue - } - studios = append(studios, studio) + return nil + }); err != nil { + logger.Error(err.Error()) + continue } for _, studio := range studios { wg.Add(1) - task := AutoTagStudioTask{studio: studio} + task := AutoTagStudioTask{ + studio: studio, + txnManager: s.TxnManager, + } go task.Start(&wg) wg.Wait() @@ -664,36 +739,42 @@ func (s *singleton) autoTagStudios(studioIds []string) { } func (s *singleton) autoTagTags(tagIds []string) { - tagQuery := models.NewTagQueryBuilder() - var wg sync.WaitGroup for _, tagId := range tagIds { var tags []*models.Tag - if tagId == "*" { - var err error - tags, err = tagQuery.All() - if err != nil { - logger.Errorf("Error querying tags: %s", err.Error()) - continue - } - } else { - tagIdInt, err := strconv.Atoi(tagId) - if err != nil { - logger.Errorf("Error parsing tag id %s: %s", tagId, err.Error()) - continue + if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + tagQuery := r.Tag() + if tagId == "*" { + var err error + tags, err = tagQuery.All() + if err != nil { + return fmt.Errorf("Error querying tags: %s", err.Error()) + } + } else { + tagIdInt, err := strconv.Atoi(tagId) + if err != nil { + return fmt.Errorf("Error parsing tag id %s: %s", tagId, err.Error()) + } + + tag, err := tagQuery.Find(tagIdInt) + if err != nil { + return fmt.Errorf("Error finding tag id %s: %s", tagId, err.Error()) + } + tags = append(tags, tag) } - tag, err := tagQuery.Find(tagIdInt, nil) - if err != nil { - logger.Errorf("Error finding tag id %s: %s", tagId, err.Error()) - continue - } - tags = append(tags, tag) + return nil + }); err != nil { + logger.Error(err.Error()) + continue } for _, tag := range tags { wg.Add(1) - task := AutoTagTagTask{tag: tag} + task := AutoTagTagTask{ + txnManager: s.TxnManager, + tag: tag, + } go task.Start(&wg) wg.Wait() @@ -709,28 +790,38 @@ func (s *singleton) Clean() { s.Status.SetStatus(Clean) s.Status.indefiniteProgress() - qb := models.NewSceneQueryBuilder() - iqb := models.NewImageQueryBuilder() - gqb := models.NewGalleryQueryBuilder() go func() { defer s.returnToIdleState() - logger.Infof("Starting cleaning of tracked files") - scenes, err := qb.All() - if err != nil { - logger.Errorf("failed to fetch list of scenes for cleaning") - return - } + var scenes []*models.Scene + var images []*models.Image + var galleries []*models.Gallery - images, err := iqb.All() - if err != nil { - logger.Errorf("failed to fetch list of images for cleaning") - return - } + if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + qb := r.Scene() + iqb := r.Image() + gqb := r.Gallery() - 
galleries, err := gqb.All() - if err != nil { - logger.Errorf("failed to fetch list of galleries for cleaning") + logger.Infof("Starting cleaning of tracked files") + var err error + scenes, err = qb.All() + if err != nil { + return errors.New("failed to fetch list of scenes for cleaning") + } + + images, err = iqb.All() + if err != nil { + return errors.New("failed to fetch list of images for cleaning") + } + + galleries, err = gqb.All() + if err != nil { + return errors.New("failed to fetch list of galleries for cleaning") + } + + return nil + }); err != nil { + logger.Error(err.Error()) return } @@ -757,7 +848,11 @@ func (s *singleton) Clean() { wg.Add(1) - task := CleanTask{Scene: scene, fileNamingAlgorithm: fileNamingAlgo} + task := CleanTask{ + TxnManager: s.TxnManager, + Scene: scene, + fileNamingAlgorithm: fileNamingAlgo, + } go task.Start(&wg) wg.Wait() } @@ -776,7 +871,10 @@ func (s *singleton) Clean() { wg.Add(1) - task := CleanTask{Image: img} + task := CleanTask{ + TxnManager: s.TxnManager, + Image: img, + } go task.Start(&wg) wg.Wait() } @@ -795,7 +893,10 @@ func (s *singleton) Clean() { wg.Add(1) - task := CleanTask{Gallery: gallery} + task := CleanTask{ + TxnManager: s.TxnManager, + Gallery: gallery, + } go task.Start(&wg) wg.Wait() } @@ -811,17 +912,19 @@ func (s *singleton) MigrateHash() { s.Status.SetStatus(Migrate) s.Status.indefiniteProgress() - qb := models.NewSceneQueryBuilder() - go func() { defer s.returnToIdleState() fileNamingAlgo := config.GetVideoFileNamingAlgorithm() logger.Infof("Migrating generated files for %s naming hash", fileNamingAlgo.String()) - scenes, err := qb.All() - if err != nil { - logger.Errorf("failed to fetch list of scenes for migration") + var scenes []*models.Scene + if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + var err error + scenes, err = r.Scene().All() + return err + }); err != nil { + logger.Errorf("failed to fetch list of scenes for migration: %s", err.Error()) return } @@ -926,6 +1029,7 @@ func (s *singleton) neededGenerate(scenes []*models.Scene, input models.Generate if input.Markers { task := GenerateMarkersTask{ + TxnManager: s.TxnManager, Scene: scene, Overwrite: overwrite, fileNamingAlgorithm: fileNamingAlgo, diff --git a/pkg/manager/post_migrate.go b/pkg/manager/post_migrate.go index c6c8e1419..400ae39a0 100644 --- a/pkg/manager/post_migrate.go +++ b/pkg/manager/post_migrate.go @@ -2,5 +2,5 @@ package manager // PostMigrate is executed after migrations have been executed. func (s *singleton) PostMigrate() { - setInitialMD5Config() + setInitialMD5Config(s.TxnManager) } diff --git a/pkg/manager/scene.go b/pkg/manager/scene.go index d46be1dd3..e5ddb0696 100644 --- a/pkg/manager/scene.go +++ b/pkg/manager/scene.go @@ -4,9 +4,6 @@ import ( "fmt" "os" "path/filepath" - "strconv" - - "github.com/jmoiron/sqlx" "github.com/stashapp/stash/pkg/ffmpeg" "github.com/stashapp/stash/pkg/logger" @@ -16,37 +13,54 @@ import ( ) // DestroyScene deletes a scene and its associated relationships from the -// database. -func DestroyScene(sceneID int, tx *sqlx.Tx) error { - qb := models.NewSceneQueryBuilder() - jqb := models.NewJoinsQueryBuilder() +// database. Returns a function to perform any post-commit actions. 
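// --- illustrative sketch (editor's note, not part of the patch) ---
// DestroyScene now performs only database work inside the caller's transaction
// and returns a closure for file cleanup, so generated files are removed only
// after the transaction commits. A minimal caller might look like this;
// deleteSceneByID is a hypothetical wrapper, the real equivalent appears in
// CleanTask further down in this diff.
package manager

import (
	"context"

	"github.com/stashapp/stash/pkg/models"
)

func deleteSceneByID(ctx context.Context, txnManager models.TransactionManager, sceneID int) error {
	var postCommit func()
	if err := txnManager.WithTxn(ctx, func(repo models.Repository) error {
		scene, err := repo.Scene().Find(sceneID)
		if err != nil {
			return err
		}
		postCommit, err = DestroyScene(scene, repo)
		return err
	}); err != nil {
		return err
	}

	// remove generated marker files only once the database changes are committed
	postCommit()
	return nil
}
// --- end sketch ---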
+func DestroyScene(scene *models.Scene, repo models.Repository) (func(), error) { + qb := repo.Scene() + mqb := repo.SceneMarker() + gqb := repo.Gallery() - _, err := qb.Find(sceneID) + if err := gqb.ClearGalleryId(scene.ID); err != nil { + return nil, err + } + + markers, err := mqb.FindBySceneID(scene.ID) if err != nil { - return err + return nil, err } - if err := jqb.DestroyScenesTags(sceneID, tx); err != nil { - return err + var funcs []func() + for _, m := range markers { + f, err := DestroySceneMarker(scene, m, mqb) + if err != nil { + return nil, err + } + funcs = append(funcs, f) } - if err := jqb.DestroyPerformersScenes(sceneID, tx); err != nil { - return err + if err := qb.Destroy(scene.ID); err != nil { + return nil, err } - if err := jqb.DestroyScenesMarkers(sceneID, tx); err != nil { - return err + return func() { + for _, f := range funcs { + f() + } + }, nil +} + +// DestroySceneMarker deletes the scene marker from the database and returns a +// function that removes the generated files, to be executed after the +// transaction is successfully committed. +func DestroySceneMarker(scene *models.Scene, sceneMarker *models.SceneMarker, qb models.SceneMarkerWriter) (func(), error) { + if err := qb.Destroy(sceneMarker.ID); err != nil { + return nil, err } - if err := jqb.DestroyScenesGalleries(sceneID, tx); err != nil { - return err - } - - if err := qb.Destroy(strconv.Itoa(sceneID), tx); err != nil { - return err - } - - return nil + // delete the preview for the marker + return func() { + seconds := int(sceneMarker.Seconds) + DeleteSceneMarkerFiles(scene, seconds, config.GetVideoFileNamingAlgorithm()) + }, nil } // DeleteGeneratedSceneFiles deletes generated files for the provided scene. diff --git a/pkg/manager/studio.go b/pkg/manager/studio.go index 6de1ee572..9602f50b6 100644 --- a/pkg/manager/studio.go +++ b/pkg/manager/studio.go @@ -4,18 +4,16 @@ import ( "errors" "fmt" - "github.com/jmoiron/sqlx" "github.com/stashapp/stash/pkg/models" ) -func ValidateModifyStudio(studio models.StudioPartial, tx *sqlx.Tx) error { +func ValidateModifyStudio(studio models.StudioPartial, qb models.StudioReader) error { if studio.ParentID == nil || !studio.ParentID.Valid { return nil } // ensure there is no cyclic dependency thisID := studio.ID - qb := models.NewStudioQueryBuilder() currentParentID := *studio.ParentID @@ -24,7 +22,7 @@ func ValidateModifyStudio(studio models.StudioPartial, tx *sqlx.Tx) error { return errors.New("studio cannot be an ancestor of itself") } - currentStudio, err := qb.Find(int(currentParentID.Int64), tx) + currentStudio, err := qb.Find(int(currentParentID.Int64)) if err != nil { return fmt.Errorf("error finding parent studio: %s", err.Error()) } diff --git a/pkg/manager/tag.go b/pkg/manager/tag.go index d26591f8c..8447b69f1 100644 --- a/pkg/manager/tag.go +++ b/pkg/manager/tag.go @@ -3,17 +3,13 @@ package manager import ( "fmt" - "github.com/jmoiron/sqlx" "github.com/stashapp/stash/pkg/models" ) -func EnsureTagNameUnique(tag models.Tag, tx *sqlx.Tx) error { - qb := models.NewTagQueryBuilder() - +func EnsureTagNameUnique(tag models.Tag, qb models.TagReader) error { // ensure name is unique - sameNameTag, err := qb.FindByName(tag.Name, tx, true) + sameNameTag, err := qb.FindByName(tag.Name, true) if err != nil { - _ = tx.Rollback() return err } diff --git a/pkg/manager/task_autotag.go b/pkg/manager/task_autotag.go index 8e4c4ab36..ac47c404e 100644 --- a/pkg/manager/task_autotag.go +++ b/pkg/manager/task_autotag.go @@ -3,16 +3,18 @@ package manager import ( "context" 
"database/sql" + "fmt" "strings" "sync" - "github.com/stashapp/stash/pkg/database" "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/scene" ) type AutoTagPerformerTask struct { - performer *models.Performer + performer *models.Performer + txnManager models.TransactionManager } func (t *AutoTagPerformerTask) Start(wg *sync.WaitGroup) { @@ -32,44 +34,38 @@ func getQueryRegex(name string) string { } func (t *AutoTagPerformerTask) autoTagPerformer() { - qb := models.NewSceneQueryBuilder() - jqb := models.NewJoinsQueryBuilder() - regex := getQueryRegex(t.performer.Name.String) - const ignoreOrganized = true - scenes, err := qb.QueryAllByPathRegex(regex, ignoreOrganized) - - if err != nil { - logger.Infof("Error querying scenes with regex '%s': %s", regex, err.Error()) - return - } - - ctx := context.TODO() - tx := database.DB.MustBeginTx(ctx, nil) - - for _, scene := range scenes { - added, err := jqb.AddPerformerScene(scene.ID, t.performer.ID, tx) + if err := t.txnManager.WithTxn(context.TODO(), func(r models.Repository) error { + qb := r.Scene() + const ignoreOrganized = true + scenes, err := qb.QueryAllByPathRegex(regex, ignoreOrganized) if err != nil { - logger.Infof("Error adding performer '%s' to scene '%s': %s", t.performer.Name.String, scene.GetTitle(), err.Error()) - tx.Rollback() - return + return fmt.Errorf("Error querying scenes with regex '%s': %s", regex, err.Error()) } - if added { - logger.Infof("Added performer '%s' to scene '%s'", t.performer.Name.String, scene.GetTitle()) - } - } + for _, s := range scenes { + added, err := scene.AddPerformer(qb, s.ID, t.performer.ID) - if err := tx.Commit(); err != nil { - logger.Infof("Error adding performer to scene: %s", err.Error()) - return + if err != nil { + return fmt.Errorf("Error adding performer '%s' to scene '%s': %s", t.performer.Name.String, s.GetTitle(), err.Error()) + } + + if added { + logger.Infof("Added performer '%s' to scene '%s'", t.performer.Name.String, s.GetTitle()) + } + } + + return nil + }); err != nil { + logger.Error(err.Error()) } } type AutoTagStudioTask struct { - studio *models.Studio + studio *models.Studio + txnManager models.TransactionManager } func (t *AutoTagStudioTask) Start(wg *sync.WaitGroup) { @@ -79,54 +75,47 @@ func (t *AutoTagStudioTask) Start(wg *sync.WaitGroup) { } func (t *AutoTagStudioTask) autoTagStudio() { - qb := models.NewSceneQueryBuilder() - regex := getQueryRegex(t.studio.Name.String) - const ignoreOrganized = true - scenes, err := qb.QueryAllByPathRegex(regex, ignoreOrganized) - - if err != nil { - logger.Infof("Error querying scenes with regex '%s': %s", regex, err.Error()) - return - } - - ctx := context.TODO() - tx := database.DB.MustBeginTx(ctx, nil) - - for _, scene := range scenes { - // #306 - don't overwrite studio if already present - if scene.StudioID.Valid { - // don't modify - continue - } - - logger.Infof("Adding studio '%s' to scene '%s'", t.studio.Name.String, scene.GetTitle()) - - // set the studio id - studioID := sql.NullInt64{Int64: int64(t.studio.ID), Valid: true} - scenePartial := models.ScenePartial{ - ID: scene.ID, - StudioID: &studioID, - } - - _, err := qb.Update(scenePartial, tx) + if err := t.txnManager.WithTxn(context.TODO(), func(r models.Repository) error { + qb := r.Scene() + const ignoreOrganized = true + scenes, err := qb.QueryAllByPathRegex(regex, ignoreOrganized) if err != nil { - logger.Infof("Error adding studio to scene: %s", err.Error()) - tx.Rollback() - return + return 
fmt.Errorf("Error querying scenes with regex '%s': %s", regex, err.Error()) } - } - if err := tx.Commit(); err != nil { - logger.Infof("Error adding studio to scene: %s", err.Error()) - return + for _, scene := range scenes { + // #306 - don't overwrite studio if already present + if scene.StudioID.Valid { + // don't modify + continue + } + + logger.Infof("Adding studio '%s' to scene '%s'", t.studio.Name.String, scene.GetTitle()) + + // set the studio id + studioID := sql.NullInt64{Int64: int64(t.studio.ID), Valid: true} + scenePartial := models.ScenePartial{ + ID: scene.ID, + StudioID: &studioID, + } + + if _, err := qb.Update(scenePartial); err != nil { + return fmt.Errorf("Error adding studio to scene: %s", err.Error()) + } + } + + return nil + }); err != nil { + logger.Error(err.Error()) } } type AutoTagTagTask struct { - tag *models.Tag + tag *models.Tag + txnManager models.TransactionManager } func (t *AutoTagTagTask) Start(wg *sync.WaitGroup) { @@ -136,38 +125,31 @@ func (t *AutoTagTagTask) Start(wg *sync.WaitGroup) { } func (t *AutoTagTagTask) autoTagTag() { - qb := models.NewSceneQueryBuilder() - jqb := models.NewJoinsQueryBuilder() - regex := getQueryRegex(t.tag.Name) - const ignoreOrganized = true - scenes, err := qb.QueryAllByPathRegex(regex, ignoreOrganized) - - if err != nil { - logger.Infof("Error querying scenes with regex '%s': %s", regex, err.Error()) - return - } - - ctx := context.TODO() - tx := database.DB.MustBeginTx(ctx, nil) - - for _, scene := range scenes { - added, err := jqb.AddSceneTag(scene.ID, t.tag.ID, tx) + if err := t.txnManager.WithTxn(context.TODO(), func(r models.Repository) error { + qb := r.Scene() + const ignoreOrganized = true + scenes, err := qb.QueryAllByPathRegex(regex, ignoreOrganized) if err != nil { - logger.Infof("Error adding tag '%s' to scene '%s': %s", t.tag.Name, scene.GetTitle(), err.Error()) - tx.Rollback() - return + return fmt.Errorf("Error querying scenes with regex '%s': %s", regex, err.Error()) } - if added { - logger.Infof("Added tag '%s' to scene '%s'", t.tag.Name, scene.GetTitle()) - } - } + for _, s := range scenes { + added, err := scene.AddTag(qb, s.ID, t.tag.ID) - if err := tx.Commit(); err != nil { - logger.Infof("Error adding tag to scene: %s", err.Error()) - return + if err != nil { + return fmt.Errorf("Error adding tag '%s' to scene '%s': %s", t.tag.Name, s.GetTitle(), err.Error()) + } + + if added { + logger.Infof("Added tag '%s' to scene '%s'", t.tag.Name, s.GetTitle()) + } + } + + return nil + }); err != nil { + logger.Error(err.Error()) } } diff --git a/pkg/manager/task_autotag_test.go b/pkg/manager/task_autotag_test.go index 67e428bbe..feaeb43c3 100644 --- a/pkg/manager/task_autotag_test.go +++ b/pkg/manager/task_autotag_test.go @@ -14,11 +14,11 @@ import ( "github.com/stashapp/stash/pkg/database" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/sqlite" "github.com/stashapp/stash/pkg/utils" _ "github.com/golang-migrate/migrate/v4/database/sqlite3" _ "github.com/golang-migrate/migrate/v4/source/file" - "github.com/jmoiron/sqlx" ) const testName = "Foo's Bar" @@ -106,17 +106,15 @@ func TestMain(m *testing.M) { os.Exit(ret) } -func createPerformer(tx *sqlx.Tx) error { +func createPerformer(pqb models.PerformerWriter) error { // create the performer - pqb := models.NewPerformerQueryBuilder() - performer := models.Performer{ Checksum: testName, Name: sql.NullString{Valid: true, String: testName}, Favorite: sql.NullBool{Valid: true, Bool: false}, } - _, err := pqb.Create(performer, tx) + _, err := 
pqb.Create(performer) if err != nil { return err } @@ -124,27 +122,23 @@ func createPerformer(tx *sqlx.Tx) error { return nil } -func createStudio(tx *sqlx.Tx, name string) (*models.Studio, error) { +func createStudio(qb models.StudioWriter, name string) (*models.Studio, error) { // create the studio - qb := models.NewStudioQueryBuilder() - studio := models.Studio{ Checksum: name, Name: sql.NullString{Valid: true, String: testName}, } - return qb.Create(studio, tx) + return qb.Create(studio) } -func createTag(tx *sqlx.Tx) error { +func createTag(qb models.TagWriter) error { // create the studio - qb := models.NewTagQueryBuilder() - tag := models.Tag{ Name: testName, } - _, err := qb.Create(tag, tx) + _, err := qb.Create(tag) if err != nil { return err } @@ -152,9 +146,7 @@ func createTag(tx *sqlx.Tx) error { return nil } -func createScenes(tx *sqlx.Tx) error { - sqb := models.NewSceneQueryBuilder() - +func createScenes(sqb models.SceneReaderWriter) error { // create the scenes var scenePatterns []string var falseScenePatterns []string @@ -175,13 +167,13 @@ func createScenes(tx *sqlx.Tx) error { } for _, fn := range scenePatterns { - err := createScene(sqb, tx, makeScene(fn, true)) + err := createScene(sqb, makeScene(fn, true)) if err != nil { return err } } for _, fn := range falseScenePatterns { - err := createScene(sqb, tx, makeScene(fn, false)) + err := createScene(sqb, makeScene(fn, false)) if err != nil { return err } @@ -191,7 +183,7 @@ func createScenes(tx *sqlx.Tx) error { for _, fn := range scenePatterns { s := makeScene("organized"+fn, false) s.Organized = true - err := createScene(sqb, tx, s) + err := createScene(sqb, s) if err != nil { return err } @@ -200,7 +192,7 @@ func createScenes(tx *sqlx.Tx) error { // create scene with existing studio io studioScene := makeScene(existingStudioSceneName, true) studioScene.StudioID = sql.NullInt64{Valid: true, Int64: int64(existingStudioID)} - err := createScene(sqb, tx, studioScene) + err := createScene(sqb, studioScene) if err != nil { return err } @@ -222,8 +214,8 @@ func makeScene(name string, expectedResult bool) *models.Scene { return scene } -func createScene(sqb models.SceneQueryBuilder, tx *sqlx.Tx, scene *models.Scene) error { - _, err := sqb.Create(*scene, tx) +func createScene(sqb models.SceneWriter, scene *models.Scene) error { + _, err := sqb.Create(*scene) if err != nil { return fmt.Errorf("Failed to create scene with name '%s': %s", scene.Path, err.Error()) @@ -232,39 +224,43 @@ func createScene(sqb models.SceneQueryBuilder, tx *sqlx.Tx, scene *models.Scene) return nil } +func withTxn(f func(r models.Repository) error) error { + t := sqlite.NewTransactionManager() + return t.WithTxn(context.TODO(), f) +} + func populateDB() error { - ctx := context.TODO() - tx := database.DB.MustBeginTx(ctx, nil) + if err := withTxn(func(r models.Repository) error { + err := createPerformer(r.Performer()) + if err != nil { + return err + } - err := createPerformer(tx) - if err != nil { - return err - } + _, err = createStudio(r.Studio(), testName) + if err != nil { + return err + } - _, err = createStudio(tx, testName) - if err != nil { - return err - } + // create existing studio + existingStudio, err := createStudio(r.Studio(), existingStudioName) + if err != nil { + return err + } - // create existing studio - existingStudio, err := createStudio(tx, existingStudioName) - if err != nil { - return err - } + existingStudioID = existingStudio.ID - existingStudioID = existingStudio.ID + err = createTag(r.Tag()) + if err != nil { + return err 
+ } - err = createTag(tx) - if err != nil { - return err - } + err = createScenes(r.Scene()) + if err != nil { + return err + } - err = createScenes(tx) - if err != nil { - return err - } - - if err := tx.Commit(); err != nil { + return nil + }); err != nil { return err } @@ -272,16 +268,19 @@ func populateDB() error { } func TestParsePerformers(t *testing.T) { - pqb := models.NewPerformerQueryBuilder() - performers, err := pqb.All() - - if err != nil { + var performers []*models.Performer + if err := withTxn(func(r models.Repository) error { + var err error + performers, err = r.Performer().All() + return err + }); err != nil { t.Errorf("Error getting performer: %s", err) return } task := AutoTagPerformerTask{ - performer: performers[0], + performer: performers[0], + txnManager: sqlite.NewTransactionManager(), } var wg sync.WaitGroup @@ -289,38 +288,47 @@ func TestParsePerformers(t *testing.T) { task.Start(&wg) // verify that scenes were tagged correctly - sqb := models.NewSceneQueryBuilder() - - scenes, err := sqb.All() - - for _, scene := range scenes { - performers, err := pqb.FindBySceneID(scene.ID, nil) + withTxn(func(r models.Repository) error { + pqb := r.Performer() + scenes, err := r.Scene().All() if err != nil { - t.Errorf("Error getting scene performers: %s", err.Error()) - return + t.Error(err.Error()) } - // title is only set on scenes where we expect performer to be set - if scene.Title.String == scene.Path && len(performers) == 0 { - t.Errorf("Did not set performer '%s' for path '%s'", testName, scene.Path) - } else if scene.Title.String != scene.Path && len(performers) > 0 { - t.Errorf("Incorrectly set performer '%s' for path '%s'", testName, scene.Path) + for _, scene := range scenes { + performers, err := pqb.FindBySceneID(scene.ID) + + if err != nil { + t.Errorf("Error getting scene performers: %s", err.Error()) + } + + // title is only set on scenes where we expect performer to be set + if scene.Title.String == scene.Path && len(performers) == 0 { + t.Errorf("Did not set performer '%s' for path '%s'", testName, scene.Path) + } else if scene.Title.String != scene.Path && len(performers) > 0 { + t.Errorf("Incorrectly set performer '%s' for path '%s'", testName, scene.Path) + } } - } + + return nil + }) } func TestParseStudios(t *testing.T) { - studioQuery := models.NewStudioQueryBuilder() - studios, err := studioQuery.All() - - if err != nil { + var studios []*models.Studio + if err := withTxn(func(r models.Repository) error { + var err error + studios, err = r.Studio().All() + return err + }); err != nil { t.Errorf("Error getting studio: %s", err) return } task := AutoTagStudioTask{ - studio: studios[0], + studio: studios[0], + txnManager: sqlite.NewTransactionManager(), } var wg sync.WaitGroup @@ -328,38 +336,46 @@ func TestParseStudios(t *testing.T) { task.Start(&wg) // verify that scenes were tagged correctly - sqb := models.NewSceneQueryBuilder() + withTxn(func(r models.Repository) error { + scenes, err := r.Scene().All() + if err != nil { + t.Error(err.Error()) + } - scenes, err := sqb.All() - - for _, scene := range scenes { - // check for existing studio id scene first - if scene.Path == existingStudioSceneName { - if scene.StudioID.Int64 != int64(existingStudioID) { - t.Error("Incorrectly overwrote studio ID for scene with existing studio ID") - } - } else { - // title is only set on scenes where we expect studio to be set - if scene.Title.String == scene.Path && scene.StudioID.Int64 != int64(studios[0].ID) { - t.Errorf("Did not set studio '%s' for path '%s'", 
testName, scene.Path) - } else if scene.Title.String != scene.Path && scene.StudioID.Int64 == int64(studios[0].ID) { - t.Errorf("Incorrectly set studio '%s' for path '%s'", testName, scene.Path) + for _, scene := range scenes { + // check for existing studio id scene first + if scene.Path == existingStudioSceneName { + if scene.StudioID.Int64 != int64(existingStudioID) { + t.Error("Incorrectly overwrote studio ID for scene with existing studio ID") + } + } else { + // title is only set on scenes where we expect studio to be set + if scene.Title.String == scene.Path && scene.StudioID.Int64 != int64(studios[0].ID) { + t.Errorf("Did not set studio '%s' for path '%s'", testName, scene.Path) + } else if scene.Title.String != scene.Path && scene.StudioID.Int64 == int64(studios[0].ID) { + t.Errorf("Incorrectly set studio '%s' for path '%s'", testName, scene.Path) + } } } - } + + return nil + }) } func TestParseTags(t *testing.T) { - tagQuery := models.NewTagQueryBuilder() - tags, err := tagQuery.All() - - if err != nil { + var tags []*models.Tag + if err := withTxn(func(r models.Repository) error { + var err error + tags, err = r.Tag().All() + return err + }); err != nil { t.Errorf("Error getting performer: %s", err) return } task := AutoTagTagTask{ - tag: tags[0], + tag: tags[0], + txnManager: sqlite.NewTransactionManager(), } var wg sync.WaitGroup @@ -367,23 +383,29 @@ func TestParseTags(t *testing.T) { task.Start(&wg) // verify that scenes were tagged correctly - sqb := models.NewSceneQueryBuilder() - - scenes, err := sqb.All() - - for _, scene := range scenes { - tags, err := tagQuery.FindBySceneID(scene.ID, nil) - + withTxn(func(r models.Repository) error { + scenes, err := r.Scene().All() if err != nil { - t.Errorf("Error getting scene tags: %s", err.Error()) - return + t.Error(err.Error()) } - // title is only set on scenes where we expect performer to be set - if scene.Title.String == scene.Path && len(tags) == 0 { - t.Errorf("Did not set tag '%s' for path '%s'", testName, scene.Path) - } else if scene.Title.String != scene.Path && len(tags) > 0 { - t.Errorf("Incorrectly set tag '%s' for path '%s'", testName, scene.Path) + tqb := r.Tag() + + for _, scene := range scenes { + tags, err := tqb.FindBySceneID(scene.ID) + + if err != nil { + t.Errorf("Error getting scene tags: %s", err.Error()) + } + + // title is only set on scenes where we expect performer to be set + if scene.Title.String == scene.Path && len(tags) == 0 { + t.Errorf("Did not set tag '%s' for path '%s'", testName, scene.Path) + } else if scene.Title.String != scene.Path && len(tags) > 0 { + t.Errorf("Incorrectly set tag '%s' for path '%s'", testName, scene.Path) + } } - } + + return nil + }) } diff --git a/pkg/manager/task_clean.go b/pkg/manager/task_clean.go index dc6091d17..672745e6e 100644 --- a/pkg/manager/task_clean.go +++ b/pkg/manager/task_clean.go @@ -7,7 +7,6 @@ import ( "strings" "sync" - "github.com/stashapp/stash/pkg/database" "github.com/stashapp/stash/pkg/image" "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/manager/config" @@ -15,6 +14,7 @@ import ( ) type CleanTask struct { + TxnManager models.TransactionManager Scene *models.Scene Gallery *models.Gallery Image *models.Image @@ -133,60 +133,45 @@ func (t *CleanTask) shouldCleanImage(s *models.Image) bool { } func (t *CleanTask) deleteScene(sceneID int) { - ctx := context.TODO() - qb := models.NewSceneQueryBuilder() - tx := database.DB.MustBeginTx(ctx, nil) + var postCommitFunc func() + var scene *models.Scene + if err := 
t.TxnManager.WithTxn(context.TODO(), func(repo models.Repository) error { + qb := repo.Scene() - scene, err := qb.Find(sceneID) - err = DestroyScene(sceneID, tx) - - if err != nil { - logger.Errorf("Error deleting scene from database: %s", err.Error()) - tx.Rollback() - return - } - - if err := tx.Commit(); err != nil { + var err error + scene, err = qb.Find(sceneID) + if err != nil { + return err + } + postCommitFunc, err = DestroyScene(scene, repo) + return err + }); err != nil { logger.Errorf("Error deleting scene from database: %s", err.Error()) return } + postCommitFunc() + DeleteGeneratedSceneFiles(scene, t.fileNamingAlgorithm) } func (t *CleanTask) deleteGallery(galleryID int) { - ctx := context.TODO() - qb := models.NewGalleryQueryBuilder() - tx := database.DB.MustBeginTx(ctx, nil) - - err := qb.Destroy(galleryID, tx) - - if err != nil { - logger.Errorf("Error deleting gallery from database: %s", err.Error()) - tx.Rollback() - return - } - - if err := tx.Commit(); err != nil { + if err := t.TxnManager.WithTxn(context.TODO(), func(repo models.Repository) error { + qb := repo.Gallery() + return qb.Destroy(galleryID) + }); err != nil { logger.Errorf("Error deleting gallery from database: %s", err.Error()) return } } func (t *CleanTask) deleteImage(imageID int) { - ctx := context.TODO() - qb := models.NewImageQueryBuilder() - tx := database.DB.MustBeginTx(ctx, nil) - err := qb.Destroy(imageID, tx) + if err := t.TxnManager.WithTxn(context.TODO(), func(repo models.Repository) error { + qb := repo.Image() - if err != nil { - logger.Errorf("Error deleting image from database: %s", err.Error()) - tx.Rollback() - return - } - - if err := tx.Commit(); err != nil { + return qb.Destroy(imageID) + }); err != nil { logger.Errorf("Error deleting image from database: %s", err.Error()) return } diff --git a/pkg/manager/task_export.go b/pkg/manager/task_export.go index 5e654876c..011d739b8 100644 --- a/pkg/manager/task_export.go +++ b/pkg/manager/task_export.go @@ -2,6 +2,7 @@ package manager import ( "archive/zip" + "context" "fmt" "io" "io/ioutil" @@ -27,7 +28,8 @@ import ( ) type ExportTask struct { - full bool + txnManager models.TransactionManager + full bool baseDir string json jsonUtils @@ -58,8 +60,10 @@ func newExportSpec(input *models.ExportObjectTypeInput) *exportSpec { return &exportSpec{} } + ids, _ := utils.StringSliceToIntSlice(input.Ids) + ret := &exportSpec{ - IDs: utils.StringSliceToIntSlice(input.Ids), + IDs: ids, } if input.All != nil { @@ -76,6 +80,7 @@ func CreateExportTask(a models.HashAlgorithm, input models.ExportObjectsInput) * } return &ExportTask{ + txnManager: GetInstance().TxnManager, fileNamingAlgorithm: a, scenes: newExportSpec(input.Scenes), images: newExportSpec(input.Images), @@ -125,34 +130,40 @@ func (t *ExportTask) Start(wg *sync.WaitGroup) { paths.EnsureJSONDirs(t.baseDir) - // include movie scenes and gallery images - if !t.full { - // only include movie scenes if includeDependencies is also set - if !t.scenes.all && t.includeDependencies { - t.populateMovieScenes() + t.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + // include movie scenes and gallery images + if !t.full { + // only include movie scenes if includeDependencies is also set + if !t.scenes.all && t.includeDependencies { + t.populateMovieScenes(r) + } + + // always export gallery images + if !t.images.all { + t.populateGalleryImages(r) + } } - // always export gallery images - if !t.images.all { - t.populateGalleryImages() - } - } + t.ExportScenes(workerCount, r) + 
t.ExportImages(workerCount, r) + t.ExportGalleries(workerCount, r) + t.ExportMovies(workerCount, r) + t.ExportPerformers(workerCount, r) + t.ExportStudios(workerCount, r) + t.ExportTags(workerCount, r) - t.ExportScenes(workerCount) - t.ExportImages(workerCount) - t.ExportGalleries(workerCount) - t.ExportMovies(workerCount) - t.ExportPerformers(workerCount) - t.ExportStudios(workerCount) - t.ExportTags(workerCount) + if t.full { + t.ExportScrapedItems(r) + } + + return nil + }) if err := t.json.saveMappings(t.Mappings); err != nil { logger.Errorf("[mappings] failed to save json: %s", err.Error()) } - if t.full { - t.ExportScrapedItems() - } else { + if !t.full { err := t.generateDownload() if err != nil { logger.Errorf("error generating download link: %s", err.Error()) @@ -242,9 +253,9 @@ func (t *ExportTask) zipFile(fn, outDir string, z *zip.Writer) error { return nil } -func (t *ExportTask) populateMovieScenes() { - reader := models.NewMovieReaderWriter(nil) - sceneReader := models.NewSceneReaderWriter(nil) +func (t *ExportTask) populateMovieScenes(repo models.ReaderRepository) { + reader := repo.Movie() + sceneReader := repo.Scene() var movies []*models.Movie var err error @@ -272,9 +283,9 @@ func (t *ExportTask) populateMovieScenes() { } } -func (t *ExportTask) populateGalleryImages() { - reader := models.NewGalleryReaderWriter(nil) - imageReader := models.NewImageReaderWriter(nil) +func (t *ExportTask) populateGalleryImages(repo models.ReaderRepository) { + reader := repo.Gallery() + imageReader := repo.Image() var galleries []*models.Gallery var err error @@ -302,10 +313,10 @@ func (t *ExportTask) populateGalleryImages() { } } -func (t *ExportTask) ExportScenes(workers int) { +func (t *ExportTask) ExportScenes(workers int, repo models.ReaderRepository) { var scenesWg sync.WaitGroup - sceneReader := models.NewSceneReaderWriter(nil) + sceneReader := repo.Scene() var scenes []*models.Scene var err error @@ -327,7 +338,7 @@ func (t *ExportTask) ExportScenes(workers int) { for w := 0; w < workers; w++ { // create export Scene workers scenesWg.Add(1) - go exportScene(&scenesWg, jobCh, t) + go exportScene(&scenesWg, jobCh, repo, t) } for i, scene := range scenes { @@ -346,16 +357,15 @@ func (t *ExportTask) ExportScenes(workers int) { logger.Infof("[scenes] export complete in %s. 
%d workers used.", time.Since(startTime), workers) } -func exportScene(wg *sync.WaitGroup, jobChan <-chan *models.Scene, t *ExportTask) { +func exportScene(wg *sync.WaitGroup, jobChan <-chan *models.Scene, repo models.ReaderRepository, t *ExportTask) { defer wg.Done() - sceneReader := models.NewSceneReaderWriter(nil) - studioReader := models.NewStudioReaderWriter(nil) - movieReader := models.NewMovieReaderWriter(nil) - galleryReader := models.NewGalleryReaderWriter(nil) - performerReader := models.NewPerformerReaderWriter(nil) - tagReader := models.NewTagReaderWriter(nil) - sceneMarkerReader := models.NewSceneMarkerReaderWriter(nil) - joinReader := models.NewJoinReaderWriter(nil) + sceneReader := repo.Scene() + studioReader := repo.Studio() + movieReader := repo.Movie() + galleryReader := repo.Gallery() + performerReader := repo.Performer() + tagReader := repo.Tag() + sceneMarkerReader := repo.SceneMarker() for s := range jobChan { sceneHash := s.GetHash(t.fileNamingAlgorithm) @@ -402,7 +412,7 @@ func exportScene(wg *sync.WaitGroup, jobChan <-chan *models.Scene, t *ExportTask continue } - newSceneJSON.Movies, err = scene.GetSceneMoviesJSON(movieReader, joinReader, s) + newSceneJSON.Movies, err = scene.GetSceneMoviesJSON(movieReader, sceneReader, s) if err != nil { logger.Errorf("[scenes] <%s> error getting scene movies JSON: %s", sceneHash, err.Error()) continue @@ -417,14 +427,14 @@ func exportScene(wg *sync.WaitGroup, jobChan <-chan *models.Scene, t *ExportTask t.galleries.IDs = utils.IntAppendUnique(t.galleries.IDs, sceneGallery.ID) } - tagIDs, err := scene.GetDependentTagIDs(tagReader, joinReader, sceneMarkerReader, s) + tagIDs, err := scene.GetDependentTagIDs(tagReader, sceneMarkerReader, s) if err != nil { logger.Errorf("[scenes] <%s> error getting scene tags: %s", sceneHash, err.Error()) continue } t.tags.IDs = utils.IntAppendUniques(t.tags.IDs, tagIDs) - movieIDs, err := scene.GetDependentMovieIDs(joinReader, s) + movieIDs, err := scene.GetDependentMovieIDs(sceneReader, s) if err != nil { logger.Errorf("[scenes] <%s> error getting scene movies: %s", sceneHash, err.Error()) continue @@ -445,10 +455,10 @@ func exportScene(wg *sync.WaitGroup, jobChan <-chan *models.Scene, t *ExportTask } } -func (t *ExportTask) ExportImages(workers int) { +func (t *ExportTask) ExportImages(workers int, repo models.ReaderRepository) { var imagesWg sync.WaitGroup - imageReader := models.NewImageReaderWriter(nil) + imageReader := repo.Image() var images []*models.Image var err error @@ -470,7 +480,7 @@ func (t *ExportTask) ExportImages(workers int) { for w := 0; w < workers; w++ { // create export Image workers imagesWg.Add(1) - go exportImage(&imagesWg, jobCh, t) + go exportImage(&imagesWg, jobCh, repo, t) } for i, image := range images { @@ -489,12 +499,12 @@ func (t *ExportTask) ExportImages(workers int) { logger.Infof("[images] export complete in %s. 
%d workers used.", time.Since(startTime), workers) } -func exportImage(wg *sync.WaitGroup, jobChan <-chan *models.Image, t *ExportTask) { +func exportImage(wg *sync.WaitGroup, jobChan <-chan *models.Image, repo models.ReaderRepository, t *ExportTask) { defer wg.Done() - studioReader := models.NewStudioReaderWriter(nil) - galleryReader := models.NewGalleryReaderWriter(nil) - performerReader := models.NewPerformerReaderWriter(nil) - tagReader := models.NewTagReaderWriter(nil) + studioReader := repo.Studio() + galleryReader := repo.Gallery() + performerReader := repo.Performer() + tagReader := repo.Tag() for s := range jobChan { imageHash := s.Checksum @@ -560,10 +570,10 @@ func (t *ExportTask) getGalleryChecksums(galleries []*models.Gallery) (ret []str return } -func (t *ExportTask) ExportGalleries(workers int) { +func (t *ExportTask) ExportGalleries(workers int, repo models.ReaderRepository) { var galleriesWg sync.WaitGroup - reader := models.NewGalleryReaderWriter(nil) + reader := repo.Gallery() var galleries []*models.Gallery var err error @@ -585,7 +595,7 @@ func (t *ExportTask) ExportGalleries(workers int) { for w := 0; w < workers; w++ { // create export Scene workers galleriesWg.Add(1) - go exportGallery(&galleriesWg, jobCh, t) + go exportGallery(&galleriesWg, jobCh, repo, t) } for i, gallery := range galleries { @@ -609,11 +619,11 @@ func (t *ExportTask) ExportGalleries(workers int) { logger.Infof("[galleries] export complete in %s. %d workers used.", time.Since(startTime), workers) } -func exportGallery(wg *sync.WaitGroup, jobChan <-chan *models.Gallery, t *ExportTask) { +func exportGallery(wg *sync.WaitGroup, jobChan <-chan *models.Gallery, repo models.ReaderRepository, t *ExportTask) { defer wg.Done() - studioReader := models.NewStudioReaderWriter(nil) - performerReader := models.NewPerformerReaderWriter(nil) - tagReader := models.NewTagReaderWriter(nil) + studioReader := repo.Studio() + performerReader := repo.Performer() + tagReader := repo.Tag() for g := range jobChan { galleryHash := g.Checksum @@ -666,10 +676,10 @@ func exportGallery(wg *sync.WaitGroup, jobChan <-chan *models.Gallery, t *Export } } -func (t *ExportTask) ExportPerformers(workers int) { +func (t *ExportTask) ExportPerformers(workers int, repo models.ReaderRepository) { var performersWg sync.WaitGroup - reader := models.NewPerformerReaderWriter(nil) + reader := repo.Performer() var performers []*models.Performer var err error all := t.full || (t.performers != nil && t.performers.all) @@ -689,7 +699,7 @@ func (t *ExportTask) ExportPerformers(workers int) { for w := 0; w < workers; w++ { // create export Performer workers performersWg.Add(1) - go t.exportPerformer(&performersWg, jobCh) + go t.exportPerformer(&performersWg, jobCh, repo) } for i, performer := range performers { @@ -706,10 +716,10 @@ func (t *ExportTask) ExportPerformers(workers int) { logger.Infof("[performers] export complete in %s. %d workers used.", time.Since(startTime), workers) } -func (t *ExportTask) exportPerformer(wg *sync.WaitGroup, jobChan <-chan *models.Performer) { +func (t *ExportTask) exportPerformer(wg *sync.WaitGroup, jobChan <-chan *models.Performer, repo models.ReaderRepository) { defer wg.Done() - performerReader := models.NewPerformerReaderWriter(nil) + performerReader := repo.Performer() for p := range jobChan { newPerformerJSON, err := performer.ToJSON(performerReader, p) @@ -732,10 +742,10 @@ func (t *ExportTask) exportPerformer(wg *sync.WaitGroup, jobChan <-chan *models. 
} } -func (t *ExportTask) ExportStudios(workers int) { +func (t *ExportTask) ExportStudios(workers int, repo models.ReaderRepository) { var studiosWg sync.WaitGroup - reader := models.NewStudioReaderWriter(nil) + reader := repo.Studio() var studios []*models.Studio var err error all := t.full || (t.studios != nil && t.studios.all) @@ -756,7 +766,7 @@ func (t *ExportTask) ExportStudios(workers int) { for w := 0; w < workers; w++ { // create export Studio workers studiosWg.Add(1) - go t.exportStudio(&studiosWg, jobCh) + go t.exportStudio(&studiosWg, jobCh, repo) } for i, studio := range studios { @@ -773,10 +783,10 @@ func (t *ExportTask) ExportStudios(workers int) { logger.Infof("[studios] export complete in %s. %d workers used.", time.Since(startTime), workers) } -func (t *ExportTask) exportStudio(wg *sync.WaitGroup, jobChan <-chan *models.Studio) { +func (t *ExportTask) exportStudio(wg *sync.WaitGroup, jobChan <-chan *models.Studio, repo models.ReaderRepository) { defer wg.Done() - studioReader := models.NewStudioReaderWriter(nil) + studioReader := repo.Studio() for s := range jobChan { newStudioJSON, err := studio.ToJSON(studioReader, s) @@ -797,10 +807,10 @@ func (t *ExportTask) exportStudio(wg *sync.WaitGroup, jobChan <-chan *models.Stu } } -func (t *ExportTask) ExportTags(workers int) { +func (t *ExportTask) ExportTags(workers int, repo models.ReaderRepository) { var tagsWg sync.WaitGroup - reader := models.NewTagReaderWriter(nil) + reader := repo.Tag() var tags []*models.Tag var err error all := t.full || (t.tags != nil && t.tags.all) @@ -821,7 +831,7 @@ func (t *ExportTask) ExportTags(workers int) { for w := 0; w < workers; w++ { // create export Tag workers tagsWg.Add(1) - go t.exportTag(&tagsWg, jobCh) + go t.exportTag(&tagsWg, jobCh, repo) } for i, tag := range tags { @@ -841,10 +851,10 @@ func (t *ExportTask) ExportTags(workers int) { logger.Infof("[tags] export complete in %s. %d workers used.", time.Since(startTime), workers) } -func (t *ExportTask) exportTag(wg *sync.WaitGroup, jobChan <-chan *models.Tag) { +func (t *ExportTask) exportTag(wg *sync.WaitGroup, jobChan <-chan *models.Tag, repo models.ReaderRepository) { defer wg.Done() - tagReader := models.NewTagReaderWriter(nil) + tagReader := repo.Tag() for thisTag := range jobChan { newTagJSON, err := tag.ToJSON(tagReader, thisTag) @@ -868,10 +878,10 @@ func (t *ExportTask) exportTag(wg *sync.WaitGroup, jobChan <-chan *models.Tag) { } } -func (t *ExportTask) ExportMovies(workers int) { +func (t *ExportTask) ExportMovies(workers int, repo models.ReaderRepository) { var moviesWg sync.WaitGroup - reader := models.NewMovieReaderWriter(nil) + reader := repo.Movie() var movies []*models.Movie var err error all := t.full || (t.movies != nil && t.movies.all) @@ -892,7 +902,7 @@ func (t *ExportTask) ExportMovies(workers int) { for w := 0; w < workers; w++ { // create export Studio workers moviesWg.Add(1) - go t.exportMovie(&moviesWg, jobCh) + go t.exportMovie(&moviesWg, jobCh, repo) } for i, movie := range movies { @@ -909,11 +919,11 @@ func (t *ExportTask) ExportMovies(workers int) { logger.Infof("[movies] export complete in %s. 
%d workers used.", time.Since(startTime), workers) } -func (t *ExportTask) exportMovie(wg *sync.WaitGroup, jobChan <-chan *models.Movie) { +func (t *ExportTask) exportMovie(wg *sync.WaitGroup, jobChan <-chan *models.Movie, repo models.ReaderRepository) { defer wg.Done() - movieReader := models.NewMovieReaderWriter(nil) - studioReader := models.NewStudioReaderWriter(nil) + movieReader := repo.Movie() + studioReader := repo.Studio() for m := range jobChan { newMovieJSON, err := movie.ToJSON(movieReader, studioReader, m) @@ -942,9 +952,9 @@ func (t *ExportTask) exportMovie(wg *sync.WaitGroup, jobChan <-chan *models.Movi } } -func (t *ExportTask) ExportScrapedItems() { - qb := models.NewScrapedItemQueryBuilder() - sqb := models.NewStudioQueryBuilder() +func (t *ExportTask) ExportScrapedItems(repo models.ReaderRepository) { + qb := repo.ScrapedItem() + sqb := repo.Studio() scrapedItems, err := qb.All() if err != nil { logger.Errorf("[scraped sites] failed to fetch all items: %s", err.Error()) @@ -960,7 +970,7 @@ func (t *ExportTask) ExportScrapedItems() { var studioName string if scrapedItem.StudioID.Valid { - studio, _ := sqb.Find(int(scrapedItem.StudioID.Int64), nil) + studio, _ := sqb.Find(int(scrapedItem.StudioID.Int64)) if studio != nil { studioName = studio.Name.String } diff --git a/pkg/manager/task_generate_markers.go b/pkg/manager/task_generate_markers.go index e5fdf7683..0e11fb853 100644 --- a/pkg/manager/task_generate_markers.go +++ b/pkg/manager/task_generate_markers.go @@ -1,6 +1,7 @@ package manager import ( + "context" "path/filepath" "strconv" @@ -13,6 +14,7 @@ import ( ) type GenerateMarkersTask struct { + TxnManager models.TransactionManager Scene *models.Scene Marker *models.SceneMarker Overwrite bool @@ -27,13 +29,21 @@ func (t *GenerateMarkersTask) Start(wg *sizedwaitgroup.SizedWaitGroup) { } if t.Marker != nil { - qb := models.NewSceneQueryBuilder() - scene, err := qb.Find(int(t.Marker.SceneID.Int64)) - if err != nil { + var scene *models.Scene + if err := t.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + var err error + scene, err = r.Scene().Find(int(t.Marker.SceneID.Int64)) + return err + }); err != nil { logger.Errorf("error finding scene for marker: %s", err.Error()) return } + if scene == nil { + logger.Errorf("scene not found for id %d", t.Marker.SceneID.Int64) + return + } + videoFile, err := ffmpeg.NewVideoFile(instance.FFProbePath, t.Scene.Path, false) if err != nil { logger.Errorf("error reading video file: %s", err.Error()) @@ -45,8 +55,16 @@ func (t *GenerateMarkersTask) Start(wg *sizedwaitgroup.SizedWaitGroup) { } func (t *GenerateMarkersTask) generateSceneMarkers() { - qb := models.NewSceneMarkerQueryBuilder() - sceneMarkers, _ := qb.FindBySceneID(t.Scene.ID, nil) + var sceneMarkers []*models.SceneMarker + if err := t.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + var err error + sceneMarkers, err = r.SceneMarker().FindBySceneID(t.Scene.ID) + return err + }); err != nil { + logger.Errorf("error getting scene markers: %s", err.Error()) + return + } + if len(sceneMarkers) == 0 { return } @@ -117,8 +135,16 @@ func (t *GenerateMarkersTask) generateMarker(videoFile *ffmpeg.VideoFile, scene func (t *GenerateMarkersTask) isMarkerNeeded() int { markers := 0 - qb := models.NewSceneMarkerQueryBuilder() - sceneMarkers, _ := qb.FindBySceneID(t.Scene.ID, nil) + var sceneMarkers []*models.SceneMarker + if err := t.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + var err 
error + sceneMarkers, err = r.SceneMarker().FindBySceneID(t.Scene.ID) + return err + }); err != nil { + logger.Errorf("errror finding scene markers: %s", err.Error()) + return 0 + } + if len(sceneMarkers) == 0 { return 0 } diff --git a/pkg/manager/task_generate_screenshot.go b/pkg/manager/task_generate_screenshot.go index 73cdd3499..21a340a92 100644 --- a/pkg/manager/task_generate_screenshot.go +++ b/pkg/manager/task_generate_screenshot.go @@ -2,12 +2,12 @@ package manager import ( "context" + "fmt" "io/ioutil" "os" "sync" "time" - "github.com/stashapp/stash/pkg/database" "github.com/stashapp/stash/pkg/ffmpeg" "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" @@ -17,6 +17,7 @@ type GenerateScreenshotTask struct { Scene models.Scene ScreenshotAt *float64 fileNamingAlgorithm models.HashAlgorithm + txnManager models.TransactionManager } func (t *GenerateScreenshotTask) Start(wg *sync.WaitGroup) { @@ -60,39 +61,31 @@ func (t *GenerateScreenshotTask) Start(wg *sync.WaitGroup) { return } - ctx := context.TODO() - tx := database.DB.MustBeginTx(ctx, nil) + if err := t.txnManager.WithTxn(context.TODO(), func(r models.Repository) error { + qb := r.Scene() + updatedTime := time.Now() + updatedScene := models.ScenePartial{ + ID: t.Scene.ID, + UpdatedAt: &models.SQLiteTimestamp{Timestamp: updatedTime}, + } - qb := models.NewSceneQueryBuilder() - updatedTime := time.Now() - updatedScene := models.ScenePartial{ - ID: t.Scene.ID, - UpdatedAt: &models.SQLiteTimestamp{Timestamp: updatedTime}, - } + if err := SetSceneScreenshot(checksum, coverImageData); err != nil { + return fmt.Errorf("Error writing screenshot: %s", err.Error()) + } - if err := SetSceneScreenshot(checksum, coverImageData); err != nil { - logger.Errorf("Error writing screenshot: %s", err.Error()) - tx.Rollback() - return - } + // update the scene cover table + if err := qb.UpdateCover(t.Scene.ID, coverImageData); err != nil { + return fmt.Errorf("Error setting screenshot: %s", err.Error()) + } - // update the scene cover table - if err := qb.UpdateSceneCover(t.Scene.ID, coverImageData, tx); err != nil { - logger.Errorf("Error setting screenshot: %s", err.Error()) - tx.Rollback() - return - } + // update the scene with the update date + _, err = qb.Update(updatedScene) + if err != nil { + return fmt.Errorf("Error updating scene: %s", err.Error()) + } - // update the scene with the update date - _, err = qb.Update(updatedScene, tx) - if err != nil { - logger.Errorf("Error updating scene: %s", err.Error()) - tx.Rollback() - return - } - - if err := tx.Commit(); err != nil { - logger.Errorf("Error setting screenshot: %s", err.Error()) - return + return nil + }); err != nil { + logger.Error(err.Error()) } } diff --git a/pkg/manager/task_import.go b/pkg/manager/task_import.go index 600497e48..6e9a4489f 100644 --- a/pkg/manager/task_import.go +++ b/pkg/manager/task_import.go @@ -11,7 +11,6 @@ import ( "sync" "time" - "github.com/jmoiron/sqlx" "github.com/stashapp/stash/pkg/database" "github.com/stashapp/stash/pkg/gallery" "github.com/stashapp/stash/pkg/image" @@ -29,7 +28,8 @@ import ( ) type ImportTask struct { - json jsonUtils + txnManager models.TransactionManager + json jsonUtils BaseDir string ZipFile io.Reader @@ -44,6 +44,7 @@ type ImportTask struct { func CreateImportTask(a models.HashAlgorithm, input models.ImportObjectsInput) *ImportTask { return &ImportTask{ + txnManager: GetInstance().TxnManager, ZipFile: input.File.File, Reset: false, DuplicateBehaviour: input.DuplicateBehaviour, @@ -200,22 +201,16 @@ func 
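
The screenshot task is the clearest example of the change running through this whole diff: hand-rolled database.DB.MustBeginTx / tx.Rollback() / tx.Commit() sequences become a closure passed to the transaction manager, which commits when the closure returns nil and rolls back when it returns an error. A minimal sketch of that contract over database/sql, assuming a plain *sql.DB; the real models.TransactionManager wraps the stash sqlx layer and passes the closure a models.Repository rather than the raw transaction:

    package txn

    import (
        "context"
        "database/sql"
        "fmt"
    )

    // WithTxn runs fn inside a transaction: commit on nil, rollback on error.
    // This mirrors the contract the refactor relies on.
    func WithTxn(ctx context.Context, db *sql.DB, fn func(tx *sql.Tx) error) error {
        tx, err := db.BeginTx(ctx, nil)
        if err != nil {
            return err
        }

        if err := fn(tx); err != nil {
            // roll back, but surface the original error to the caller
            if rbErr := tx.Rollback(); rbErr != nil {
                return fmt.Errorf("rollback failed (%v) after error: %w", rbErr, err)
            }
            return err
        }

        return tx.Commit()
    }

With this shape the task code only has to return errors; logging, rollback and commit all live in one place, which is why the screenshot task above can wrap its screenshot write, cover update and scene update in a single closure.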
(t *ImportTask) ImportPerformers(ctx context.Context) { logger.Progressf("[performers] %d of %d", index, len(t.mappings.Performers)) - tx := database.DB.MustBeginTx(ctx, nil) - readerWriter := models.NewPerformerReaderWriter(tx) - importer := &performer.Importer{ - ReaderWriter: readerWriter, - Input: *performerJSON, - } + if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { + readerWriter := r.Performer() + importer := &performer.Importer{ + ReaderWriter: readerWriter, + Input: *performerJSON, + } - if err := performImport(importer, t.DuplicateBehaviour); err != nil { - tx.Rollback() - logger.Errorf("[performers] <%s> failed to import: %s", mappingJSON.Checksum, err.Error()) - continue - } - - if err := tx.Commit(); err != nil { - tx.Rollback() - logger.Errorf("[performers] <%s> import failed to commit: %s", mappingJSON.Checksum, err.Error()) + return performImport(importer, t.DuplicateBehaviour) + }); err != nil { + logger.Errorf("[performers] <%s> import failed: %s", mappingJSON.Checksum, err.Error()) } } @@ -237,12 +232,9 @@ func (t *ImportTask) ImportStudios(ctx context.Context) { logger.Progressf("[studios] %d of %d", index, len(t.mappings.Studios)) - tx := database.DB.MustBeginTx(ctx, nil) - - // fail on missing parent studio to begin with - if err := t.ImportStudio(studioJSON, pendingParent, tx); err != nil { - tx.Rollback() - + if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { + return t.ImportStudio(studioJSON, pendingParent, r.Studio()) + }); err != nil { if err == studio.ErrParentStudioNotExist { // add to the pending parent list so that it is created after the parent s := pendingParent[studioJSON.ParentStudio] @@ -254,11 +246,6 @@ func (t *ImportTask) ImportStudios(ctx context.Context) { logger.Errorf("[studios] <%s> failed to create: %s", mappingJSON.Checksum, err.Error()) continue } - - if err := tx.Commit(); err != nil { - logger.Errorf("[studios] import failed to commit: %s", err.Error()) - continue - } } // create the leftover studios, warning for missing parents @@ -267,18 +254,12 @@ func (t *ImportTask) ImportStudios(ctx context.Context) { for _, s := range pendingParent { for _, orphanStudioJSON := range s { - tx := database.DB.MustBeginTx(ctx, nil) - - if err := t.ImportStudio(orphanStudioJSON, nil, tx); err != nil { - tx.Rollback() + if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { + return t.ImportStudio(orphanStudioJSON, nil, r.Studio()) + }); err != nil { logger.Errorf("[studios] <%s> failed to create: %s", orphanStudioJSON.Name, err.Error()) continue } - - if err := tx.Commit(); err != nil { - logger.Errorf("[studios] import failed to commit: %s", err.Error()) - continue - } } } } @@ -286,8 +267,7 @@ func (t *ImportTask) ImportStudios(ctx context.Context) { logger.Info("[studios] import complete") } -func (t *ImportTask) ImportStudio(studioJSON *jsonschema.Studio, pendingParent map[string][]*jsonschema.Studio, tx *sqlx.Tx) error { - readerWriter := models.NewStudioReaderWriter(tx) +func (t *ImportTask) ImportStudio(studioJSON *jsonschema.Studio, pendingParent map[string][]*jsonschema.Studio, readerWriter models.StudioReaderWriter) error { importer := &studio.Importer{ ReaderWriter: readerWriter, Input: *studioJSON, @@ -307,7 +287,7 @@ func (t *ImportTask) ImportStudio(studioJSON *jsonschema.Studio, pendingParent m s := pendingParent[studioJSON.Name] for _, childStudioJSON := range s { // map is nil since we're not checking parent studios at this point - if err := t.ImportStudio(childStudioJSON, nil, tx); 
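
ImportStudios keeps its two-pass dependency handling under the new API: a studio whose parent has not been imported yet fails with studio.ErrParentStudioNotExist and is queued under the parent's name; when the parent is later created its queued children are imported, and whatever is still queued at the end is created without the parent check. A rough sketch of that loop; errParentNotExist, studioJSON and the create callback are placeholders for the stash importer, not its actual signature:

    package studioimport

    import "errors"

    var errParentNotExist = errors.New("parent studio does not exist")

    type studioJSON struct {
        Name         string
        ParentStudio string
    }

    // create is expected to return errParentNotExist when failOnMissingParent
    // is true and the named parent has not been created yet; with it false,
    // the real importer just warns and creates the studio without a parent.
    func importStudios(studios []studioJSON, create func(s studioJSON, failOnMissingParent bool) error) {
        pendingParent := make(map[string][]studioJSON)

        var importOne func(s studioJSON, failOnMissingParent bool)
        importOne = func(s studioJSON, failOnMissingParent bool) {
            if err := create(s, failOnMissingParent); err != nil {
                if errors.Is(err, errParentNotExist) {
                    // parent not imported yet: queue under the parent's name
                    pendingParent[s.ParentStudio] = append(pendingParent[s.ParentStudio], s)
                }
                return
            }
            // this studio may itself be the parent someone was waiting on
            for _, child := range pendingParent[s.Name] {
                importOne(child, failOnMissingParent)
            }
            delete(pendingParent, s.Name)
        }

        for _, s := range studios {
            importOne(s, true)
        }

        // second pass: whatever is left has a parent that never appeared
        for _, children := range pendingParent {
            for _, child := range children {
                importOne(child, false)
            }
        }
    }

In the diff each create call is additionally wrapped in its own WithTxn, so a failed studio no longer needs an explicit rollback before moving on.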
err != nil { + if err := t.ImportStudio(childStudioJSON, nil, readerWriter); err != nil { return fmt.Errorf("failed to create child studio <%s>: %s", childStudioJSON.Name, err.Error()) } } @@ -331,26 +311,20 @@ func (t *ImportTask) ImportMovies(ctx context.Context) { logger.Progressf("[movies] %d of %d", index, len(t.mappings.Movies)) - tx := database.DB.MustBeginTx(ctx, nil) - readerWriter := models.NewMovieReaderWriter(tx) - studioReaderWriter := models.NewStudioReaderWriter(tx) + if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { + readerWriter := r.Movie() + studioReaderWriter := r.Studio() - movieImporter := &movie.Importer{ - ReaderWriter: readerWriter, - StudioWriter: studioReaderWriter, - Input: *movieJSON, - MissingRefBehaviour: t.MissingRefBehaviour, - } + movieImporter := &movie.Importer{ + ReaderWriter: readerWriter, + StudioWriter: studioReaderWriter, + Input: *movieJSON, + MissingRefBehaviour: t.MissingRefBehaviour, + } - if err := performImport(movieImporter, t.DuplicateBehaviour); err != nil { - tx.Rollback() - logger.Errorf("[movies] <%s> failed to import: %s", mappingJSON.Checksum, err.Error()) - continue - } - - if err := tx.Commit(); err != nil { - tx.Rollback() - logger.Errorf("[movies] <%s> import failed to commit: %s", mappingJSON.Checksum, err.Error()) + return performImport(movieImporter, t.DuplicateBehaviour) + }); err != nil { + logger.Errorf("[movies] <%s> import failed: %s", mappingJSON.Checksum, err.Error()) continue } } @@ -371,31 +345,23 @@ func (t *ImportTask) ImportGalleries(ctx context.Context) { logger.Progressf("[galleries] %d of %d", index, len(t.mappings.Galleries)) - tx := database.DB.MustBeginTx(ctx, nil) - readerWriter := models.NewGalleryReaderWriter(tx) - tagWriter := models.NewTagReaderWriter(tx) - joinWriter := models.NewJoinReaderWriter(tx) - performerWriter := models.NewPerformerReaderWriter(tx) - studioWriter := models.NewStudioReaderWriter(tx) + if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { + readerWriter := r.Gallery() + tagWriter := r.Tag() + performerWriter := r.Performer() + studioWriter := r.Studio() - galleryImporter := &gallery.Importer{ - ReaderWriter: readerWriter, - PerformerWriter: performerWriter, - StudioWriter: studioWriter, - TagWriter: tagWriter, - JoinWriter: joinWriter, - Input: *galleryJSON, - MissingRefBehaviour: t.MissingRefBehaviour, - } + galleryImporter := &gallery.Importer{ + ReaderWriter: readerWriter, + PerformerWriter: performerWriter, + StudioWriter: studioWriter, + TagWriter: tagWriter, + Input: *galleryJSON, + MissingRefBehaviour: t.MissingRefBehaviour, + } - if err := performImport(galleryImporter, t.DuplicateBehaviour); err != nil { - tx.Rollback() - logger.Errorf("[galleries] <%s> failed to import: %s", mappingJSON.Checksum, err.Error()) - continue - } - - if err := tx.Commit(); err != nil { - tx.Rollback() + return performImport(galleryImporter, t.DuplicateBehaviour) + }); err != nil { logger.Errorf("[galleries] <%s> import failed to commit: %s", mappingJSON.Checksum, err.Error()) continue } @@ -417,74 +383,71 @@ func (t *ImportTask) ImportTags(ctx context.Context) { logger.Progressf("[tags] %d of %d", index, len(t.mappings.Tags)) - tx := database.DB.MustBeginTx(ctx, nil) - readerWriter := models.NewTagReaderWriter(tx) + if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { + readerWriter := r.Tag() - tagImporter := &tag.Importer{ - ReaderWriter: readerWriter, - Input: *tagJSON, - } + tagImporter := &tag.Importer{ + ReaderWriter: readerWriter, + 
Input: *tagJSON, + } - if err := performImport(tagImporter, t.DuplicateBehaviour); err != nil { - tx.Rollback() + return performImport(tagImporter, t.DuplicateBehaviour) + }); err != nil { logger.Errorf("[tags] <%s> failed to import: %s", mappingJSON.Checksum, err.Error()) continue } - - if err := tx.Commit(); err != nil { - tx.Rollback() - logger.Errorf("[tags] <%s> import failed to commit: %s", mappingJSON.Checksum, err.Error()) - } } logger.Info("[tags] import complete") } func (t *ImportTask) ImportScrapedItems(ctx context.Context) { - tx := database.DB.MustBeginTx(ctx, nil) - qb := models.NewScrapedItemQueryBuilder() - sqb := models.NewStudioQueryBuilder() - currentTime := time.Now() + if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { + logger.Info("[scraped sites] importing") + qb := r.ScrapedItem() + sqb := r.Studio() + currentTime := time.Now() - for i, mappingJSON := range t.scraped { - index := i + 1 - logger.Progressf("[scraped sites] %d of %d", index, len(t.mappings.Scenes)) + for i, mappingJSON := range t.scraped { + index := i + 1 + logger.Progressf("[scraped sites] %d of %d", index, len(t.mappings.Scenes)) - newScrapedItem := models.ScrapedItem{ - Title: sql.NullString{String: mappingJSON.Title, Valid: true}, - Description: sql.NullString{String: mappingJSON.Description, Valid: true}, - URL: sql.NullString{String: mappingJSON.URL, Valid: true}, - Date: models.SQLiteDate{String: mappingJSON.Date, Valid: true}, - Rating: sql.NullString{String: mappingJSON.Rating, Valid: true}, - Tags: sql.NullString{String: mappingJSON.Tags, Valid: true}, - Models: sql.NullString{String: mappingJSON.Models, Valid: true}, - Episode: sql.NullInt64{Int64: int64(mappingJSON.Episode), Valid: true}, - GalleryFilename: sql.NullString{String: mappingJSON.GalleryFilename, Valid: true}, - GalleryURL: sql.NullString{String: mappingJSON.GalleryURL, Valid: true}, - VideoFilename: sql.NullString{String: mappingJSON.VideoFilename, Valid: true}, - VideoURL: sql.NullString{String: mappingJSON.VideoURL, Valid: true}, - CreatedAt: models.SQLiteTimestamp{Timestamp: currentTime}, - UpdatedAt: models.SQLiteTimestamp{Timestamp: t.getTimeFromJSONTime(mappingJSON.UpdatedAt)}, + newScrapedItem := models.ScrapedItem{ + Title: sql.NullString{String: mappingJSON.Title, Valid: true}, + Description: sql.NullString{String: mappingJSON.Description, Valid: true}, + URL: sql.NullString{String: mappingJSON.URL, Valid: true}, + Date: models.SQLiteDate{String: mappingJSON.Date, Valid: true}, + Rating: sql.NullString{String: mappingJSON.Rating, Valid: true}, + Tags: sql.NullString{String: mappingJSON.Tags, Valid: true}, + Models: sql.NullString{String: mappingJSON.Models, Valid: true}, + Episode: sql.NullInt64{Int64: int64(mappingJSON.Episode), Valid: true}, + GalleryFilename: sql.NullString{String: mappingJSON.GalleryFilename, Valid: true}, + GalleryURL: sql.NullString{String: mappingJSON.GalleryURL, Valid: true}, + VideoFilename: sql.NullString{String: mappingJSON.VideoFilename, Valid: true}, + VideoURL: sql.NullString{String: mappingJSON.VideoURL, Valid: true}, + CreatedAt: models.SQLiteTimestamp{Timestamp: currentTime}, + UpdatedAt: models.SQLiteTimestamp{Timestamp: t.getTimeFromJSONTime(mappingJSON.UpdatedAt)}, + } + + studio, err := sqb.FindByName(mappingJSON.Studio, false) + if err != nil { + logger.Errorf("[scraped sites] failed to fetch studio: %s", err.Error()) + } + if studio != nil { + newScrapedItem.StudioID = sql.NullInt64{Int64: int64(studio.ID), Valid: true} + } + + _, err = 
qb.Create(newScrapedItem) + if err != nil { + logger.Errorf("[scraped sites] <%s> failed to create: %s", newScrapedItem.Title.String, err.Error()) + } } - studio, err := sqb.FindByName(mappingJSON.Studio, tx, false) - if err != nil { - logger.Errorf("[scraped sites] failed to fetch studio: %s", err.Error()) - } - if studio != nil { - newScrapedItem.StudioID = sql.NullInt64{Int64: int64(studio.ID), Valid: true} - } - - _, err = qb.Create(newScrapedItem, tx) - if err != nil { - logger.Errorf("[scraped sites] <%s> failed to create: %s", newScrapedItem.Title.String, err.Error()) - } - } - - logger.Info("[scraped sites] importing") - if err := tx.Commit(); err != nil { + return nil + }); err != nil { logger.Errorf("[scraped sites] import failed to commit: %s", err.Error()) } + logger.Info("[scraped sites] import complete") } @@ -504,65 +467,52 @@ func (t *ImportTask) ImportScenes(ctx context.Context) { sceneHash := mappingJSON.Checksum - tx := database.DB.MustBeginTx(ctx, nil) - readerWriter := models.NewSceneReaderWriter(tx) - tagWriter := models.NewTagReaderWriter(tx) - galleryWriter := models.NewGalleryReaderWriter(tx) - joinWriter := models.NewJoinReaderWriter(tx) - movieWriter := models.NewMovieReaderWriter(tx) - performerWriter := models.NewPerformerReaderWriter(tx) - studioWriter := models.NewStudioReaderWriter(tx) - markerWriter := models.NewSceneMarkerReaderWriter(tx) + if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { + readerWriter := r.Scene() + tagWriter := r.Tag() + galleryWriter := r.Gallery() + movieWriter := r.Movie() + performerWriter := r.Performer() + studioWriter := r.Studio() + markerWriter := r.SceneMarker() - sceneImporter := &scene.Importer{ - ReaderWriter: readerWriter, - Input: *sceneJSON, - Path: mappingJSON.Path, + sceneImporter := &scene.Importer{ + ReaderWriter: readerWriter, + Input: *sceneJSON, + Path: mappingJSON.Path, - FileNamingAlgorithm: t.fileNamingAlgorithm, - MissingRefBehaviour: t.MissingRefBehaviour, - - GalleryWriter: galleryWriter, - JoinWriter: joinWriter, - MovieWriter: movieWriter, - PerformerWriter: performerWriter, - StudioWriter: studioWriter, - TagWriter: tagWriter, - } - - if err := performImport(sceneImporter, t.DuplicateBehaviour); err != nil { - tx.Rollback() - logger.Errorf("[scenes] <%s> failed to import: %s", sceneHash, err.Error()) - continue - } - - // import the scene markers - failedMarkers := false - for _, m := range sceneJSON.Markers { - markerImporter := &scene.MarkerImporter{ - SceneID: sceneImporter.ID, - Input: m, + FileNamingAlgorithm: t.fileNamingAlgorithm, MissingRefBehaviour: t.MissingRefBehaviour, - ReaderWriter: markerWriter, - JoinWriter: joinWriter, - TagWriter: tagWriter, + + GalleryWriter: galleryWriter, + MovieWriter: movieWriter, + PerformerWriter: performerWriter, + StudioWriter: studioWriter, + TagWriter: tagWriter, } - if err := performImport(markerImporter, t.DuplicateBehaviour); err != nil { - failedMarkers = true - logger.Errorf("[scenes] <%s> failed to import markers: %s", sceneHash, err.Error()) - break + if err := performImport(sceneImporter, t.DuplicateBehaviour); err != nil { + return err } - } - if failedMarkers { - tx.Rollback() - continue - } + // import the scene markers + for _, m := range sceneJSON.Markers { + markerImporter := &scene.MarkerImporter{ + SceneID: sceneImporter.ID, + Input: m, + MissingRefBehaviour: t.MissingRefBehaviour, + ReaderWriter: markerWriter, + TagWriter: tagWriter, + } - if err := tx.Commit(); err != nil { - tx.Rollback() - logger.Errorf("[scenes] <%s> 
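
One behavioural detail of the scene import rewrite is worth calling out: a scene and its markers are now imported inside a single transaction closure, so a failing marker import returns an error from the closure and rolls back the scene as well, where the old code tracked a failedMarkers flag and rolled back by hand. A compressed sketch of that shape; importer and withTxn below stand in for the stash performImport helper and transaction manager:

    package sceneimport

    import "fmt"

    type importer interface {
        Import() error
    }

    // withTxn is assumed to commit when fn returns nil and roll back otherwise,
    // matching the transaction-manager contract used throughout the refactor.
    func importSceneWithMarkers(withTxn func(fn func() error) error, scene importer, markers []importer) error {
        return withTxn(func() error {
            if err := scene.Import(); err != nil {
                return fmt.Errorf("importing scene: %w", err)
            }
            for i, m := range markers {
                if err := m.Import(); err != nil {
                    // returning here discards the scene and any markers already written
                    return fmt.Errorf("importing marker %d: %w", i, err)
                }
            }
            return nil
        })
    }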
import failed to commit: %s", sceneHash, err.Error()) + if err := performImport(markerImporter, t.DuplicateBehaviour); err != nil { + return err + } + } + + return nil + }); err != nil { + logger.Errorf("[scenes] <%s> import failed: %s", sceneHash, err.Error()) } } @@ -585,46 +535,37 @@ func (t *ImportTask) ImportImages(ctx context.Context) { imageHash := mappingJSON.Checksum - tx := database.DB.MustBeginTx(ctx, nil) - readerWriter := models.NewImageReaderWriter(tx) - tagWriter := models.NewTagReaderWriter(tx) - galleryWriter := models.NewGalleryReaderWriter(tx) - joinWriter := models.NewJoinReaderWriter(tx) - performerWriter := models.NewPerformerReaderWriter(tx) - studioWriter := models.NewStudioReaderWriter(tx) + if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error { + readerWriter := r.Image() + tagWriter := r.Tag() + galleryWriter := r.Gallery() + performerWriter := r.Performer() + studioWriter := r.Studio() - imageImporter := &image.Importer{ - ReaderWriter: readerWriter, - Input: *imageJSON, - Path: mappingJSON.Path, + imageImporter := &image.Importer{ + ReaderWriter: readerWriter, + Input: *imageJSON, + Path: mappingJSON.Path, - MissingRefBehaviour: t.MissingRefBehaviour, + MissingRefBehaviour: t.MissingRefBehaviour, - GalleryWriter: galleryWriter, - JoinWriter: joinWriter, - PerformerWriter: performerWriter, - StudioWriter: studioWriter, - TagWriter: tagWriter, - } + GalleryWriter: galleryWriter, + PerformerWriter: performerWriter, + StudioWriter: studioWriter, + TagWriter: tagWriter, + } - if err := performImport(imageImporter, t.DuplicateBehaviour); err != nil { - tx.Rollback() - logger.Errorf("[images] <%s> failed to import: %s", imageHash, err.Error()) - continue - } - - if err := tx.Commit(); err != nil { - tx.Rollback() - logger.Errorf("[images] <%s> import failed to commit: %s", imageHash, err.Error()) + return performImport(imageImporter, t.DuplicateBehaviour) + }); err != nil { + logger.Errorf("[images] <%s> import failed: %s", imageHash, err.Error()) } } logger.Info("[images] import complete") } -func (t *ImportTask) getPerformers(names []string, tx *sqlx.Tx) ([]*models.Performer, error) { - pqb := models.NewPerformerQueryBuilder() - performers, err := pqb.FindByNames(names, tx, false) +func (t *ImportTask) getPerformers(names []string, qb models.PerformerReader) ([]*models.Performer, error) { + performers, err := qb.FindByNames(names, false) if err != nil { return nil, err } @@ -648,12 +589,10 @@ func (t *ImportTask) getPerformers(names []string, tx *sqlx.Tx) ([]*models.Perfo return performers, nil } -func (t *ImportTask) getMoviesScenes(input []jsonschema.SceneMovie, sceneID int, tx *sqlx.Tx) ([]models.MoviesScenes, error) { - mqb := models.NewMovieQueryBuilder() - +func (t *ImportTask) getMoviesScenes(input []jsonschema.SceneMovie, sceneID int, mqb models.MovieReader) ([]models.MoviesScenes, error) { var movies []models.MoviesScenes for _, inputMovie := range input { - movie, err := mqb.FindByName(inputMovie.MovieName, tx, false) + movie, err := mqb.FindByName(inputMovie.MovieName, false) if err != nil { return nil, err } @@ -680,9 +619,8 @@ func (t *ImportTask) getMoviesScenes(input []jsonschema.SceneMovie, sceneID int, return movies, nil } -func (t *ImportTask) getTags(sceneChecksum string, names []string, tx *sqlx.Tx) ([]*models.Tag, error) { - tqb := models.NewTagQueryBuilder() - tags, err := tqb.FindByNames(names, tx, false) +func (t *ImportTask) getTags(sceneChecksum string, names []string, tqb models.TagReader) ([]*models.Tag, error) { + tags, 
err := tqb.FindByNames(names, false) if err != nil { return nil, err } diff --git a/pkg/manager/task_scan.go b/pkg/manager/task_scan.go index f3ac5299c..e3ba74ec0 100644 --- a/pkg/manager/task_scan.go +++ b/pkg/manager/task_scan.go @@ -11,19 +11,20 @@ import ( "strings" "time" - "github.com/jmoiron/sqlx" "github.com/remeh/sizedwaitgroup" - "github.com/stashapp/stash/pkg/database" "github.com/stashapp/stash/pkg/ffmpeg" + "github.com/stashapp/stash/pkg/gallery" "github.com/stashapp/stash/pkg/image" "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/manager/config" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/scene" "github.com/stashapp/stash/pkg/utils" ) type ScanTask struct { + TxnManager models.TransactionManager FilePath string UseFileMetadata bool StripFileExtension bool @@ -39,14 +40,18 @@ func (t *ScanTask) Start(wg *sizedwaitgroup.SizedWaitGroup) { if isGallery(t.FilePath) { t.scanGallery() } else if isVideo(t.FilePath) { - scene := t.scanScene() + s := t.scanScene() - if scene != nil { + if s != nil { iwg := sizedwaitgroup.New(2) if t.GenerateSprite { iwg.Add() - taskSprite := GenerateSpriteTask{Scene: *scene, Overwrite: false, fileNamingAlgorithm: t.fileNamingAlgorithm} + taskSprite := GenerateSpriteTask{ + Scene: *s, + Overwrite: false, + fileNamingAlgorithm: t.fileNamingAlgorithm, + } go taskSprite.Start(&iwg) } @@ -69,7 +74,7 @@ func (t *ScanTask) Start(wg *sizedwaitgroup.SizedWaitGroup) { } taskPreview := GeneratePreviewTask{ - Scene: *scene, + Scene: *s, ImagePreview: t.GenerateImagePreview, Options: previewOptions, Overwrite: false, @@ -88,8 +93,26 @@ func (t *ScanTask) Start(wg *sizedwaitgroup.SizedWaitGroup) { } func (t *ScanTask) scanGallery() { - qb := models.NewGalleryQueryBuilder() - gallery, _ := qb.FindByPath(t.FilePath) + var g *models.Gallery + images := 0 + scanImages := false + + if err := t.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + var err error + g, err = r.Gallery().FindByPath(t.FilePath) + + if g != nil && err != nil { + images, err = r.Image().CountByGalleryID(g.ID) + if err != nil { + return fmt.Errorf("error getting images for zip gallery %s: %s", t.FilePath, err.Error()) + } + } + + return err + }); err != nil { + logger.Error(err.Error()) + return + } fileModTime, err := t.getFileModTime() if err != nil { @@ -97,20 +120,29 @@ func (t *ScanTask) scanGallery() { return } - if gallery != nil { + if g != nil { // We already have this item in the database, keep going // if file mod time is not set, set it now - // we will also need to rescan the zip contents - updateModTime := false - if !gallery.FileModTime.Valid { - updateModTime = true - t.updateFileModTime(gallery.ID, fileModTime, &qb) + if !g.FileModTime.Valid { + // we will also need to rescan the zip contents + scanImages = true + logger.Infof("setting file modification time on %s", t.FilePath) - // update our copy of the gallery - var err error - gallery, err = qb.Find(gallery.ID, nil) - if err != nil { + if err := t.TxnManager.WithTxn(context.TODO(), func(r models.Repository) error { + qb := r.Gallery() + if _, err := gallery.UpdateFileModTime(qb, g.ID, models.NullSQLiteTimestamp{ + Timestamp: fileModTime, + Valid: true, + }); err != nil { + return err + } + + // update our copy of the gallery + var err error + g, err = qb.Find(g.ID) + return err + }); err != nil { logger.Error(err.Error()) return } @@ -118,8 +150,9 @@ func (t *ScanTask) scanGallery() { // if the mod time of the zip file is different than that of the 
associated // gallery, then recalculate the checksum - modified := t.isFileModified(fileModTime, gallery.FileModTime) + modified := t.isFileModified(fileModTime, g.FileModTime) if modified { + scanImages = true logger.Infof("%s has been updated: rescanning", t.FilePath) // update the checksum and the modification time @@ -131,7 +164,7 @@ func (t *ScanTask) scanGallery() { currentTime := time.Now() galleryPartial := models.GalleryPartial{ - ID: gallery.ID, + ID: g.ID, Checksum: &checksum, FileModTime: &models.NullSQLiteTimestamp{ Timestamp: fileModTime, @@ -140,126 +173,97 @@ func (t *ScanTask) scanGallery() { UpdatedAt: &models.SQLiteTimestamp{Timestamp: currentTime}, } - err = database.WithTxn(func(tx *sqlx.Tx) error { - _, err := qb.UpdatePartial(galleryPartial, tx) + if err := t.TxnManager.WithTxn(context.TODO(), func(r models.Repository) error { + _, err := r.Gallery().UpdatePartial(galleryPartial) return err - }) - if err != nil { + }); err != nil { logger.Error(err.Error()) return } } // scan the zip files if the gallery has no images - iqb := models.NewImageQueryBuilder() - images, err := iqb.CountByGalleryID(gallery.ID) - if err != nil { - logger.Errorf("error getting images for zip gallery %s: %s", t.FilePath, err.Error()) + scanImages = scanImages || images == 0 + } else { + // Ignore directories. + if isDir, _ := utils.DirExists(t.FilePath); isDir { + return } - if images == 0 || modified || updateModTime { - t.scanZipImages(gallery) + checksum, err := t.calculateChecksum() + if err != nil { + logger.Error(err.Error()) + return + } + + if err := t.TxnManager.WithTxn(context.TODO(), func(r models.Repository) error { + qb := r.Gallery() + g, _ = qb.FindByChecksum(checksum) + if g != nil { + exists, _ := utils.FileExists(g.Path.String) + if exists { + logger.Infof("%s already exists. Duplicate of %s ", t.FilePath, g.Path.String) + } else { + logger.Infof("%s already exists. Updating path...", t.FilePath) + g.Path = sql.NullString{ + String: t.FilePath, + Valid: true, + } + g, err = qb.Update(*g) + if err != nil { + return err + } + } + } else { + currentTime := time.Now() + + newGallery := models.Gallery{ + Checksum: checksum, + Zip: true, + Path: sql.NullString{ + String: t.FilePath, + Valid: true, + }, + FileModTime: models.NullSQLiteTimestamp{ + Timestamp: fileModTime, + Valid: true, + }, + CreatedAt: models.SQLiteTimestamp{Timestamp: currentTime}, + UpdatedAt: models.SQLiteTimestamp{Timestamp: currentTime}, + } + + // don't create gallery if it has no images + if countImagesInZip(t.FilePath) > 0 { + // only warn when creating the gallery + ok, err := utils.IsZipFileUncompressed(t.FilePath) + if err == nil && !ok { + logger.Warnf("%s is using above store (0) level compression.", t.FilePath) + } + + logger.Infof("%s doesn't exist. Creating new item...", t.FilePath) + g, err = qb.Create(newGallery) + if err != nil { + return err + } + scanImages = true + } + } + + return nil + }); err != nil { + logger.Error(err.Error()) + return + } + } + + if g != nil { + if scanImages { + t.scanZipImages(g) } else { // in case thumbnails have been deleted, regenerate them - t.regenerateZipImages(gallery) - } - return - } - - // Ignore directories. 
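
The scanGallery rewrite folds the separate rescan triggers into a single scanImages flag: the zip contents are rescanned when the gallery was just created, when its file modification time was only just backfilled, when the zip on disk is newer than the stored time, or when the gallery has no images yet; otherwise only thumbnails are regenerated. A small sketch of that decision, with illustrative field names rather than the actual model fields:

    package scan

    import "time"

    // galleryState captures the facts the scan bases its decision on.
    // Field names here are illustrative, not the stash model fields.
    type galleryState struct {
        isNew          bool      // gallery was just created from the zip
        hadModTime     bool      // a file mod time was already stored
        storedModTime  time.Time // previously recorded zip mod time
        currentModTime time.Time // mod time of the zip on disk now
        imageCount     int       // images already associated with the gallery
    }

    // shouldScanZipImages mirrors the scanImages flag: rescan when the gallery
    // is new, its mod time was just backfilled, the zip changed, or it holds
    // no images yet.
    func shouldScanZipImages(g galleryState) bool {
        modified := g.hadModTime && !g.currentModTime.Equal(g.storedModTime)
        return g.isNew || !g.hadModTime || modified || g.imageCount == 0
    }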
- if isDir, _ := utils.DirExists(t.FilePath); isDir { - return - } - - checksum, err := t.calculateChecksum() - if err != nil { - logger.Error(err.Error()) - return - } - - ctx := context.TODO() - tx := database.DB.MustBeginTx(ctx, nil) - gallery, _ = qb.FindByChecksum(checksum, tx) - if gallery != nil { - exists, _ := utils.FileExists(gallery.Path.String) - if exists { - logger.Infof("%s already exists. Duplicate of %s ", t.FilePath, gallery.Path.String) - } else { - logger.Infof("%s already exists. Updating path...", t.FilePath) - gallery.Path = sql.NullString{ - String: t.FilePath, - Valid: true, - } - gallery, err = qb.Update(*gallery, tx) - } - } else { - currentTime := time.Now() - - newGallery := models.Gallery{ - Checksum: checksum, - Zip: true, - Path: sql.NullString{ - String: t.FilePath, - Valid: true, - }, - FileModTime: models.NullSQLiteTimestamp{ - Timestamp: fileModTime, - Valid: true, - }, - CreatedAt: models.SQLiteTimestamp{Timestamp: currentTime}, - UpdatedAt: models.SQLiteTimestamp{Timestamp: currentTime}, - } - - // don't create gallery if it has no images - if countImagesInZip(t.FilePath) > 0 { - // only warn when creating the gallery - ok, err := utils.IsZipFileUncompressed(t.FilePath) - if err == nil && !ok { - logger.Warnf("%s is using above store (0) level compression.", t.FilePath) - } - - logger.Infof("%s doesn't exist. Creating new item...", t.FilePath) - gallery, err = qb.Create(newGallery, tx) + t.regenerateZipImages(g) } } - - if err != nil { - logger.Error(err.Error()) - tx.Rollback() - return - } - - err = tx.Commit() - if err != nil { - logger.Error(err.Error()) - return - } - - // if the gallery has no associated images, then scan the zip for images - if gallery != nil { - t.scanZipImages(gallery) - } -} - -type fileModTimeUpdater interface { - UpdateFileModTime(id int, modTime models.NullSQLiteTimestamp, tx *sqlx.Tx) error -} - -func (t *ScanTask) updateFileModTime(id int, fileModTime time.Time, updater fileModTimeUpdater) error { - logger.Infof("setting file modification time on %s", t.FilePath) - - err := database.WithTxn(func(tx *sqlx.Tx) error { - return updater.UpdateFileModTime(id, models.NullSQLiteTimestamp{ - Timestamp: fileModTime, - Valid: true, - }, tx) - }) - - if err != nil { - return err - } - - return nil } func (t *ScanTask) getFileModTime() (time.Time, error) { @@ -281,170 +285,185 @@ func (t *ScanTask) isFileModified(fileModTime time.Time, modTime models.NullSQLi // associates a gallery to a scene with the same basename func (t *ScanTask) associateGallery(wg *sizedwaitgroup.SizedWaitGroup) { - qb := models.NewGalleryQueryBuilder() - gallery, _ := qb.FindByPath(t.FilePath) - if gallery == nil { - // associate is run after scan is finished - // should only happen if gallery is a directory or an io error occurs during hashing - logger.Warnf("associate: gallery %s not found in DB", t.FilePath) - wg.Done() - return - } - - // gallery has no SceneID - if !gallery.SceneID.Valid { - basename := strings.TrimSuffix(t.FilePath, filepath.Ext(t.FilePath)) - var relatedFiles []string - vExt := config.GetVideoExtensions() - // make a list of media files that can be related to the gallery - for _, ext := range vExt { - related := basename + "." 
+ ext - // exclude gallery extensions from the related files - if !isGallery(related) { - relatedFiles = append(relatedFiles, related) - } + if err := t.TxnManager.WithTxn(context.TODO(), func(r models.Repository) error { + qb := r.Gallery() + sqb := r.Scene() + g, err := qb.FindByPath(t.FilePath) + if err != nil { + return err } - for _, scenePath := range relatedFiles { - qbScene := models.NewSceneQueryBuilder() - scene, _ := qbScene.FindByPath(scenePath) - // found related Scene - if scene != nil { - logger.Infof("associate: Gallery %s is related to scene: %d", t.FilePath, scene.ID) - gallery.SceneID.Int64 = int64(scene.ID) - gallery.SceneID.Valid = true + if g == nil { + // associate is run after scan is finished + // should only happen if gallery is a directory or an io error occurs during hashing + logger.Warnf("associate: gallery %s not found in DB", t.FilePath) + return nil + } - ctx := context.TODO() - tx := database.DB.MustBeginTx(ctx, nil) - - _, err := qb.Update(*gallery, tx) + // gallery has no SceneID + if !g.SceneID.Valid { + basename := strings.TrimSuffix(t.FilePath, filepath.Ext(t.FilePath)) + var relatedFiles []string + vExt := config.GetVideoExtensions() + // make a list of media files that can be related to the gallery + for _, ext := range vExt { + related := basename + "." + ext + // exclude gallery extensions from the related files + if !isGallery(related) { + relatedFiles = append(relatedFiles, related) + } + } + for _, scenePath := range relatedFiles { + s, err := sqb.FindByPath(scenePath) if err != nil { - logger.Errorf("associate: Error updating gallery sceneId %s", err) - _ = tx.Rollback() - } else if err := tx.Commit(); err != nil { - logger.Error(err.Error()) + return err } - // since a gallery can have only one related scene - // only first found is associated - break + // found related Scene + if s != nil { + logger.Infof("associate: Gallery %s is related to scene: %d", t.FilePath, s.ID) + + g.SceneID.Int64 = int64(s.ID) + g.SceneID.Valid = true + + _, err = qb.Update(*g) + if err != nil { + return fmt.Errorf("associate: Error updating gallery sceneId %s", err) + } + + // since a gallery can have only one related scene + // only first found is associated + break + } } } + + return nil + }); err != nil { + logger.Error(err.Error()) } wg.Done() } func (t *ScanTask) scanScene() *models.Scene { - qb := models.NewSceneQueryBuilder() - scene, _ := qb.FindByPath(t.FilePath) - - fileModTime, err := t.getFileModTime() - if err != nil { + logError := func(err error) *models.Scene { logger.Error(err.Error()) return nil } - if scene != nil { - // if file mod time is not set, set it now - if !scene.FileModTime.Valid { - t.updateFileModTime(scene.ID, fileModTime, &qb) + var retScene *models.Scene + var s *models.Scene - // update our copy of the scene - var err error - scene, err = qb.Find(scene.ID) - if err != nil { - logger.Error(err.Error()) - return nil + if err := t.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + var err error + s, err = r.Scene().FindByPath(t.FilePath) + return err + }); err != nil { + logger.Error(err.Error()) + return nil + } + + fileModTime, err := t.getFileModTime() + if err != nil { + return logError(err) + } + + if s != nil { + // if file mod time is not set, set it now + if !s.FileModTime.Valid { + logger.Infof("setting file modification time on %s", t.FilePath) + + if err := t.TxnManager.WithTxn(context.TODO(), func(r models.Repository) error { + qb := r.Scene() + if _, err := scene.UpdateFileModTime(qb, 
s.ID, models.NullSQLiteTimestamp{ + Timestamp: fileModTime, + Valid: true, + }); err != nil { + return err + } + + // update our copy of the scene + var err error + s, err = qb.Find(s.ID) + return err + }); err != nil { + return logError(err) } } // if the mod time of the file is different than that of the associated // scene, then recalculate the checksum and regenerate the thumbnail - modified := t.isFileModified(fileModTime, scene.FileModTime) - if modified || !scene.Size.Valid { - scene, err = t.rescanScene(scene, fileModTime) + modified := t.isFileModified(fileModTime, s.FileModTime) + if modified || !s.Size.Valid { + s, err = t.rescanScene(s, fileModTime) if err != nil { - logger.Error(err.Error()) - return nil + return logError(err) } } // We already have this item in the database // check for thumbnails,screenshots - t.makeScreenshots(nil, scene.GetHash(t.fileNamingAlgorithm)) + t.makeScreenshots(nil, s.GetHash(t.fileNamingAlgorithm)) // check for container - if !scene.Format.Valid { + if !s.Format.Valid { videoFile, err := ffmpeg.NewVideoFile(instance.FFProbePath, t.FilePath, t.StripFileExtension) if err != nil { - logger.Error(err.Error()) - return nil + return logError(err) } container := ffmpeg.MatchContainer(videoFile.Container, t.FilePath) logger.Infof("Adding container %s to file %s", container, t.FilePath) - ctx := context.TODO() - tx := database.DB.MustBeginTx(ctx, nil) - err = qb.UpdateFormat(scene.ID, string(container), tx) - if err != nil { - logger.Error(err.Error()) - _ = tx.Rollback() - } else if err := tx.Commit(); err != nil { - logger.Error(err.Error()) + if err := t.TxnManager.WithTxn(context.TODO(), func(r models.Repository) error { + _, err := scene.UpdateFormat(r.Scene(), s.ID, string(container)) + return err + }); err != nil { + return logError(err) } } // check if oshash is set - if !scene.OSHash.Valid { + if !s.OSHash.Valid { logger.Infof("Calculating oshash for existing file %s ...", t.FilePath) oshash, err := utils.OSHashFromFilePath(t.FilePath) if err != nil { - logger.Error(err.Error()) return nil } - // check if oshash clashes with existing scene - dupe, _ := qb.FindByOSHash(oshash) - if dupe != nil { - logger.Errorf("OSHash for file %s is the same as that of %s", t.FilePath, dupe.Path) - return nil - } + if err := t.TxnManager.WithTxn(context.TODO(), func(r models.Repository) error { + qb := r.Scene() + // check if oshash clashes with existing scene + dupe, _ := qb.FindByOSHash(oshash) + if dupe != nil { + return fmt.Errorf("OSHash for file %s is the same as that of %s", t.FilePath, dupe.Path) + } - ctx := context.TODO() - tx := database.DB.MustBeginTx(ctx, nil) - err = qb.UpdateOSHash(scene.ID, oshash, tx) - if err != nil { - logger.Error(err.Error()) - tx.Rollback() - return nil - } else if err := tx.Commit(); err != nil { - logger.Error(err.Error()) + _, err := scene.UpdateOSHash(qb, s.ID, oshash) + return err + }); err != nil { + return logError(err) } } // check if MD5 is set, if calculateMD5 is true - if t.calculateMD5 && !scene.Checksum.Valid { + if t.calculateMD5 && !s.Checksum.Valid { checksum, err := t.calculateChecksum() if err != nil { - logger.Error(err.Error()) - return nil + return logError(err) } - // check if checksum clashes with existing scene - dupe, _ := qb.FindByChecksum(checksum) - if dupe != nil { - logger.Errorf("MD5 for file %s is the same as that of %s", t.FilePath, dupe.Path) - return nil - } + if err := t.TxnManager.WithTxn(context.TODO(), func(r models.Repository) error { + qb := r.Scene() + // check if checksum clashes 
with existing scene + dupe, _ := qb.FindByChecksum(checksum) + if dupe != nil { + return fmt.Errorf("MD5 for file %s is the same as that of %s", t.FilePath, dupe.Path) + } - ctx := context.TODO() - tx := database.DB.MustBeginTx(ctx, nil) - err = qb.UpdateChecksum(scene.ID, checksum, tx) - if err != nil { - logger.Error(err.Error()) - _ = tx.Rollback() - } else if err := tx.Commit(); err != nil { - logger.Error(err.Error()) + _, err := scene.UpdateChecksum(qb, s.ID, checksum) + return err + }); err != nil { + return logError(err) } } @@ -473,27 +492,30 @@ func (t *ScanTask) scanScene() *models.Scene { logger.Infof("%s not found. Calculating oshash...", t.FilePath) oshash, err := utils.OSHashFromFilePath(t.FilePath) if err != nil { - logger.Error(err.Error()) - return nil + return logError(err) } if t.fileNamingAlgorithm == models.HashAlgorithmMd5 || t.calculateMD5 { checksum, err = t.calculateChecksum() if err != nil { - logger.Error(err.Error()) - return nil + return logError(err) } } // check for scene by checksum and oshash - MD5 should be // redundant, but check both - if checksum != "" { - scene, _ = qb.FindByChecksum(checksum) - } + t.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + qb := r.Scene() + if checksum != "" { + s, _ = qb.FindByChecksum(checksum) + } - if scene == nil { - scene, _ = qb.FindByOSHash(oshash) - } + if s == nil { + s, _ = qb.FindByOSHash(oshash) + } + + return nil + }) sceneHash := oshash @@ -503,21 +525,22 @@ func (t *ScanTask) scanScene() *models.Scene { t.makeScreenshots(videoFile, sceneHash) - var retScene *models.Scene - - ctx := context.TODO() - tx := database.DB.MustBeginTx(ctx, nil) - if scene != nil { - exists, _ := utils.FileExists(scene.Path) + if s != nil { + exists, _ := utils.FileExists(s.Path) if exists { - logger.Infof("%s already exists. Duplicate of %s", t.FilePath, scene.Path) + logger.Infof("%s already exists. Duplicate of %s", t.FilePath, s.Path) } else { logger.Infof("%s already exists. Updating path...", t.FilePath) scenePartial := models.ScenePartial{ - ID: scene.ID, + ID: s.ID, Path: &t.FilePath, } - _, err = qb.Update(scenePartial, tx) + if err := t.TxnManager.WithTxn(context.TODO(), func(r models.Repository) error { + _, err := r.Scene().Update(scenePartial) + return err + }); err != nil { + return logError(err) + } } } else { logger.Infof("%s doesn't exist. 
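
When scanScene backfills an oshash or MD5 for an existing scene, the duplicate check and the update now run inside the same write transaction, rather than checking first and opening a transaction afterwards. A sketch of that check-then-update step using a stand-in store interface; the real code goes through the scene.UpdateOSHash and scene.UpdateChecksum helpers shown above:

    package scanhash

    import "fmt"

    type sceneHashStore interface {
        FindByOSHash(oshash string) (scenePath string, found bool, err error)
        UpdateOSHash(sceneID int, oshash string) error
    }

    // backfillOSHash is meant to run inside a write transaction so the
    // duplicate check and the update cannot race with another scan of the
    // same file.
    func backfillOSHash(qb sceneHashStore, sceneID int, path, oshash string) error {
        dupePath, found, err := qb.FindByOSHash(oshash)
        if err != nil {
            return err
        }
        if found {
            return fmt.Errorf("OSHash for file %s is the same as that of %s", path, dupePath)
        }
        return qb.UpdateOSHash(sceneID, oshash)
    }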
Creating new item...", t.FilePath) @@ -549,23 +572,19 @@ func (t *ScanTask) scanScene() *models.Scene { newScene.Date = models.SQLiteDate{String: videoFile.CreationTime.Format("2006-01-02")} } - retScene, err = qb.Create(newScene, tx) - } - - if err != nil { - logger.Error(err.Error()) - _ = tx.Rollback() - return nil - - } else if err := tx.Commit(); err != nil { - logger.Error(err.Error()) - return nil + if err := t.TxnManager.WithTxn(context.TODO(), func(r models.Repository) error { + var err error + retScene, err = r.Scene().Create(newScene) + return err + }); err != nil { + return logError(err) + } } return retScene } -func (t *ScanTask) rescanScene(scene *models.Scene, fileModTime time.Time) (*models.Scene, error) { +func (t *ScanTask) rescanScene(s *models.Scene, fileModTime time.Time) (*models.Scene, error) { logger.Infof("%s has been updated: rescanning", t.FilePath) // update the oshash/checksum and the modification time @@ -597,7 +616,7 @@ func (t *ScanTask) rescanScene(scene *models.Scene, fileModTime time.Time) (*mod currentTime := time.Now() scenePartial := models.ScenePartial{ - ID: scene.ID, + ID: s.ID, Checksum: checksum, OSHash: &sql.NullString{ String: oshash, @@ -620,13 +639,11 @@ func (t *ScanTask) rescanScene(scene *models.Scene, fileModTime time.Time) (*mod } var ret *models.Scene - err = database.WithTxn(func(tx *sqlx.Tx) error { - qb := models.NewSceneQueryBuilder() - var txnErr error - ret, txnErr = qb.Update(scenePartial, tx) - return txnErr - }) - if err != nil { + if err := t.TxnManager.WithTxn(context.TODO(), func(r models.Repository) error { + var err error + ret, err = r.Scene().Update(scenePartial) + return err + }); err != nil { logger.Error(err.Error()) return nil, err } @@ -692,10 +709,14 @@ func (t *ScanTask) scanZipImages(zipGallery *models.Gallery) { } func (t *ScanTask) regenerateZipImages(zipGallery *models.Gallery) { - iqb := models.NewImageQueryBuilder() + var images []*models.Image + if err := t.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + iqb := r.Image() - images, err := iqb.FindByGalleryID(zipGallery.ID) - if err != nil { + var err error + images, err = iqb.FindByGalleryID(zipGallery.ID) + return err + }); err != nil { logger.Warnf("failed to find gallery images: %s", err.Error()) return } @@ -706,8 +727,16 @@ func (t *ScanTask) regenerateZipImages(zipGallery *models.Gallery) { } func (t *ScanTask) scanImage() { - qb := models.NewImageQueryBuilder() - i, _ := qb.FindByPath(t.FilePath) + var i *models.Image + + if err := t.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + var err error + i, err = r.Image().FindByPath(t.FilePath) + return err + }); err != nil { + logger.Error(err.Error()) + return + } fileModTime, err := image.GetFileModTime(t.FilePath) if err != nil { @@ -718,12 +747,22 @@ func (t *ScanTask) scanImage() { if i != nil { // if file mod time is not set, set it now if !i.FileModTime.Valid { - t.updateFileModTime(i.ID, fileModTime, &qb) + logger.Infof("setting file modification time on %s", t.FilePath) - // update our copy of the gallery - var err error - i, err = qb.Find(i.ID) - if err != nil { + if err := t.TxnManager.WithTxn(context.TODO(), func(r models.Repository) error { + qb := r.Image() + if _, err := image.UpdateFileModTime(qb, i.ID, models.NullSQLiteTimestamp{ + Timestamp: fileModTime, + Valid: true, + }); err != nil { + return err + } + + // update our copy of the gallery + var err error + i, err = qb.Find(i.ID) + return err + }); err != nil { 
logger.Error(err.Error()) return } @@ -743,83 +782,102 @@ func (t *ScanTask) scanImage() { // We already have this item in the database // check for thumbnails t.generateThumbnail(i) - - return - } - - // Ignore directories. - if isDir, _ := utils.DirExists(t.FilePath); isDir { - return - } - - var checksum string - - logger.Infof("%s not found. Calculating checksum...", t.FilePath) - checksum, err = t.calculateImageChecksum() - if err != nil { - logger.Errorf("error calculating checksum for %s: %s", t.FilePath, err.Error()) - return - } - - // check for scene by checksum and oshash - MD5 should be - // redundant, but check both - i, _ = qb.FindByChecksum(checksum) - - ctx := context.TODO() - tx := database.DB.MustBeginTx(ctx, nil) - if i != nil { - exists := image.FileExists(i.Path) - if exists { - logger.Infof("%s already exists. Duplicate of %s ", image.PathDisplayName(t.FilePath), image.PathDisplayName(i.Path)) - } else { - logger.Infof("%s already exists. Updating path...", image.PathDisplayName(t.FilePath)) - imagePartial := models.ImagePartial{ - ID: i.ID, - Path: &t.FilePath, - } - _, err = qb.Update(imagePartial, tx) - } } else { - logger.Infof("%s doesn't exist. Creating new item...", image.PathDisplayName(t.FilePath)) - currentTime := time.Now() - newImage := models.Image{ - Checksum: checksum, - Path: t.FilePath, - FileModTime: models.NullSQLiteTimestamp{ - Timestamp: fileModTime, - Valid: true, - }, - CreatedAt: models.SQLiteTimestamp{Timestamp: currentTime}, - UpdatedAt: models.SQLiteTimestamp{Timestamp: currentTime}, + // Ignore directories. + if isDir, _ := utils.DirExists(t.FilePath); isDir { + return + } + + var checksum string + + logger.Infof("%s not found. Calculating checksum...", t.FilePath) + checksum, err = t.calculateImageChecksum() + if err != nil { + logger.Errorf("error calculating checksum for %s: %s", t.FilePath, err.Error()) + return + } + + // check for scene by checksum and oshash - MD5 should be + // redundant, but check both + if err := t.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + var err error + i, err = r.Image().FindByChecksum(checksum) + return err + }); err != nil { + logger.Error(err.Error()) + return + } + + if i != nil { + exists := image.FileExists(i.Path) + if exists { + logger.Infof("%s already exists. Duplicate of %s ", image.PathDisplayName(t.FilePath), image.PathDisplayName(i.Path)) + } else { + logger.Infof("%s already exists. Updating path...", image.PathDisplayName(t.FilePath)) + imagePartial := models.ImagePartial{ + ID: i.ID, + Path: &t.FilePath, + } + + if err := t.TxnManager.WithTxn(context.TODO(), func(r models.Repository) error { + _, err := r.Image().Update(imagePartial) + return err + }); err != nil { + logger.Error(err.Error()) + return + } + } + } else { + logger.Infof("%s doesn't exist. 
Creating new item...", image.PathDisplayName(t.FilePath)) + currentTime := time.Now() + newImage := models.Image{ + Checksum: checksum, + Path: t.FilePath, + FileModTime: models.NullSQLiteTimestamp{ + Timestamp: fileModTime, + Valid: true, + }, + CreatedAt: models.SQLiteTimestamp{Timestamp: currentTime}, + UpdatedAt: models.SQLiteTimestamp{Timestamp: currentTime}, + } + if err := image.SetFileDetails(&newImage); err != nil { + logger.Error(err.Error()) + return + } + + if err := t.TxnManager.WithTxn(context.TODO(), func(r models.Repository) error { + var err error + i, err = r.Image().Create(newImage) + return err + }); err != nil { + logger.Error(err.Error()) + return + } } - err = image.SetFileDetails(&newImage) - if err == nil { - i, err = qb.Create(newImage, tx) - } - } - if err == nil { - jqb := models.NewJoinsQueryBuilder() if t.zipGallery != nil { // associate with gallery - _, err = jqb.AddImageGallery(i.ID, t.zipGallery.ID, tx) + if err := t.TxnManager.WithTxn(context.TODO(), func(r models.Repository) error { + return gallery.AddImage(r.Gallery(), t.zipGallery.ID, i.ID) + }); err != nil { + logger.Error(err.Error()) + return + } } else if config.GetCreateGalleriesFromFolders() { // create gallery from folder or associate with existing gallery logger.Infof("Associating image %s with folder gallery", i.Path) - err = t.associateImageWithFolderGallery(i.ID, tx) + if err := t.TxnManager.WithTxn(context.TODO(), func(r models.Repository) error { + return t.associateImageWithFolderGallery(i.ID, r.Gallery()) + }); err != nil { + logger.Error(err.Error()) + return + } } } - if err != nil { - logger.Error(err.Error()) - _ = tx.Rollback() - return - } else if err := tx.Commit(); err != nil { - logger.Error(err.Error()) - return + if i != nil { + t.generateThumbnail(i) } - - t.generateThumbnail(i) } func (t *ScanTask) rescanImage(i *models.Image, fileModTime time.Time) (*models.Image, error) { @@ -854,13 +912,11 @@ func (t *ScanTask) rescanImage(i *models.Image, fileModTime time.Time) (*models. } var ret *models.Image - err = database.WithTxn(func(tx *sqlx.Tx) error { - qb := models.NewImageQueryBuilder() - var txnErr error - ret, txnErr = qb.Update(imagePartial, tx) - return txnErr - }) - if err != nil { + if err := t.TxnManager.WithTxn(context.TODO(), func(r models.Repository) error { + var err error + ret, err = r.Image().Update(imagePartial) + return err + }); err != nil { return nil, err } @@ -875,12 +931,10 @@ func (t *ScanTask) rescanImage(i *models.Image, fileModTime time.Time) (*models. 
return ret, nil } -func (t *ScanTask) associateImageWithFolderGallery(imageID int, tx *sqlx.Tx) error { +func (t *ScanTask) associateImageWithFolderGallery(imageID int, qb models.GalleryReaderWriter) error { // find a gallery with the path specified path := filepath.Dir(t.FilePath) - gqb := models.NewGalleryQueryBuilder() - jqb := models.NewJoinsQueryBuilder() - g, err := gqb.FindByPath(path) + g, err := qb.FindByPath(path) if err != nil { return err } @@ -902,14 +956,14 @@ func (t *ScanTask) associateImageWithFolderGallery(imageID int, tx *sqlx.Tx) err } logger.Infof("Creating gallery for folder %s", path) - g, err = gqb.Create(newGallery, tx) + g, err = qb.Create(newGallery) if err != nil { return err } } // associate image with gallery - _, err = jqb.AddImageGallery(imageID, g.ID, tx) + err = gallery.AddImage(qb, g.ID, imageID) return err } @@ -966,27 +1020,29 @@ func (t *ScanTask) doesPathExist() bool { imgExt := config.GetImageExtensions() gExt := config.GetGalleryExtensions() - if matchExtension(t.FilePath, gExt) { - qb := models.NewGalleryQueryBuilder() - gallery, _ := qb.FindByPath(t.FilePath) - if gallery != nil { - return true + ret := false + t.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + if matchExtension(t.FilePath, gExt) { + gallery, _ := r.Gallery().FindByPath(t.FilePath) + if gallery != nil { + ret = true + } + } else if matchExtension(t.FilePath, vidExt) { + s, _ := r.Scene().FindByPath(t.FilePath) + if s != nil { + ret = true + } + } else if matchExtension(t.FilePath, imgExt) { + i, _ := r.Image().FindByPath(t.FilePath) + if i != nil { + ret = true + } } - } else if matchExtension(t.FilePath, vidExt) { - qb := models.NewSceneQueryBuilder() - scene, _ := qb.FindByPath(t.FilePath) - if scene != nil { - return true - } - } else if matchExtension(t.FilePath, imgExt) { - qb := models.NewImageQueryBuilder() - i, _ := qb.FindByPath(t.FilePath) - if i != nil { - return true - } - } - return false + return nil + }) + + return ret } func walkFilesToScan(s *models.StashConfig, f filepath.WalkFunc) error { diff --git a/pkg/models/gallery.go b/pkg/models/gallery.go index 0aa5f3cc3..f4c1b006d 100644 --- a/pkg/models/gallery.go +++ b/pkg/models/gallery.go @@ -1,74 +1,34 @@ package models -import ( - "github.com/jmoiron/sqlx" -) - type GalleryReader interface { - // Find(id int) (*Gallery, error) + Find(id int) (*Gallery, error) FindMany(ids []int) ([]*Gallery, error) FindByChecksum(checksum string) (*Gallery, error) FindByPath(path string) (*Gallery, error) FindBySceneID(sceneID int) (*Gallery, error) FindByImageID(imageID int) ([]*Gallery, error) - // ValidGalleriesForScenePath(scenePath string) ([]*Gallery, error) - // Count() (int, error) + ValidGalleriesForScenePath(scenePath string) ([]*Gallery, error) + Count() (int, error) All() ([]*Gallery, error) - // Query(galleryFilter *GalleryFilterType, findFilter *FindFilterType) ([]*Gallery, int) + Query(galleryFilter *GalleryFilterType, findFilter *FindFilterType) ([]*Gallery, int, error) + GetPerformerIDs(galleryID int) ([]int, error) + GetTagIDs(galleryID int) ([]int, error) + GetImageIDs(galleryID int) ([]int, error) } type GalleryWriter interface { Create(newGallery Gallery) (*Gallery, error) Update(updatedGallery Gallery) (*Gallery, error) - // Destroy(id int) error - // ClearGalleryId(sceneID int) error + UpdatePartial(updatedGallery GalleryPartial) (*Gallery, error) + UpdateFileModTime(id int, modTime NullSQLiteTimestamp) error + Destroy(id int) error + ClearGalleryId(sceneID int) error + 
UpdatePerformers(galleryID int, performerIDs []int) error + UpdateTags(galleryID int, tagIDs []int) error + UpdateImages(galleryID int, imageIDs []int) error } type GalleryReaderWriter interface { GalleryReader GalleryWriter } - -func NewGalleryReaderWriter(tx *sqlx.Tx) GalleryReaderWriter { - return &galleryReaderWriter{ - tx: tx, - qb: NewGalleryQueryBuilder(), - } -} - -type galleryReaderWriter struct { - tx *sqlx.Tx - qb GalleryQueryBuilder -} - -func (t *galleryReaderWriter) FindMany(ids []int) ([]*Gallery, error) { - return t.qb.FindMany(ids) -} - -func (t *galleryReaderWriter) FindByChecksum(checksum string) (*Gallery, error) { - return t.qb.FindByChecksum(checksum, t.tx) -} - -func (t *galleryReaderWriter) All() ([]*Gallery, error) { - return t.qb.All() -} - -func (t *galleryReaderWriter) FindByPath(path string) (*Gallery, error) { - return t.qb.FindByPath(path) -} - -func (t *galleryReaderWriter) FindBySceneID(sceneID int) (*Gallery, error) { - return t.qb.FindBySceneID(sceneID, t.tx) -} - -func (t *galleryReaderWriter) FindByImageID(imageID int) ([]*Gallery, error) { - return t.qb.FindByImageID(imageID, t.tx) -} - -func (t *galleryReaderWriter) Create(newGallery Gallery) (*Gallery, error) { - return t.qb.Create(newGallery, t.tx) -} - -func (t *galleryReaderWriter) Update(updatedGallery Gallery) (*Gallery, error) { - return t.qb.Update(updatedGallery, t.tx) -} diff --git a/pkg/models/image.go b/pkg/models/image.go index 5ea398bed..d160aeba5 100644 --- a/pkg/models/image.go +++ b/pkg/models/image.go @@ -1,77 +1,41 @@ package models -import ( - "github.com/jmoiron/sqlx" -) - type ImageReader interface { - // Find(id int) (*Image, error) + Find(id int) (*Image, error) FindMany(ids []int) ([]*Image, error) FindByChecksum(checksum string) (*Image, error) FindByGalleryID(galleryID int) ([]*Image, error) - // FindByPath(path string) (*Image, error) + CountByGalleryID(galleryID int) (int, error) + FindByPath(path string) (*Image, error) // FindByPerformerID(performerID int) ([]*Image, error) // CountByPerformerID(performerID int) (int, error) // FindByStudioID(studioID int) ([]*Image, error) - // Count() (int, error) + Count() (int, error) + Size() (float64, error) // SizeCount() (string, error) // CountByStudioID(studioID int) (int, error) // CountByTagID(tagID int) (int, error) All() ([]*Image, error) - // Query(imageFilter *ImageFilterType, findFilter *FindFilterType) ([]*Image, int) + Query(imageFilter *ImageFilterType, findFilter *FindFilterType) ([]*Image, int, error) + GetGalleryIDs(imageID int) ([]int, error) + GetTagIDs(imageID int) ([]int, error) + GetPerformerIDs(imageID int) ([]int, error) } type ImageWriter interface { Create(newImage Image) (*Image, error) Update(updatedImage ImagePartial) (*Image, error) UpdateFull(updatedImage Image) (*Image, error) - // IncrementOCounter(id int) (int, error) - // DecrementOCounter(id int) (int, error) - // ResetOCounter(id int) (int, error) - // Destroy(id string) error + IncrementOCounter(id int) (int, error) + DecrementOCounter(id int) (int, error) + ResetOCounter(id int) (int, error) + Destroy(id int) error + UpdateGalleries(imageID int, galleryIDs []int) error + UpdatePerformers(imageID int, performerIDs []int) error + UpdateTags(imageID int, tagIDs []int) error } type ImageReaderWriter interface { ImageReader ImageWriter } - -func NewImageReaderWriter(tx *sqlx.Tx) ImageReaderWriter { - return &imageReaderWriter{ - tx: tx, - qb: NewImageQueryBuilder(), - } -} - -type imageReaderWriter struct { - tx *sqlx.Tx - qb ImageQueryBuilder -} - 
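// --- Illustrative sketch, not part of the diff ---
// A minimal example of how the reworked ImageReader interface above is meant to be
// consumed after this change: callers no longer construct query builders around a
// *sqlx.Tx, but receive repository interfaces inside a transaction callback.
// The helper name countImagesInGallery is hypothetical; models.TransactionManager,
// models.ReaderRepository, the Image() accessor and CountByGalleryID are taken from
// this changeset, and the standard "context" import is assumed.
func countImagesInGallery(ctx context.Context, txn models.TransactionManager, galleryID int) (count int, err error) {
	err = txn.WithReadTxn(ctx, func(r models.ReaderRepository) error {
		var txnErr error
		count, txnErr = r.Image().CountByGalleryID(galleryID)
		return txnErr
	})
	return count, err
}
// --- end sketch ---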
-func (t *imageReaderWriter) FindMany(ids []int) ([]*Image, error) { - return t.qb.FindMany(ids) -} - -func (t *imageReaderWriter) FindByChecksum(checksum string) (*Image, error) { - return t.qb.FindByChecksum(checksum) -} - -func (t *imageReaderWriter) FindByGalleryID(galleryID int) ([]*Image, error) { - return t.qb.FindByGalleryID(galleryID) -} - -func (t *imageReaderWriter) All() ([]*Image, error) { - return t.qb.All() -} - -func (t *imageReaderWriter) Create(newImage Image) (*Image, error) { - return t.qb.Create(newImage, t.tx) -} - -func (t *imageReaderWriter) Update(updatedImage ImagePartial) (*Image, error) { - return t.qb.Update(updatedImage, t.tx) -} - -func (t *imageReaderWriter) UpdateFull(updatedImage Image) (*Image, error) { - return t.qb.UpdateFull(updatedImage, t.tx) -} diff --git a/pkg/models/join.go b/pkg/models/join.go deleted file mode 100644 index 3b0d259ba..000000000 --- a/pkg/models/join.go +++ /dev/null @@ -1,101 +0,0 @@ -package models - -import ( - "github.com/jmoiron/sqlx" -) - -type JoinReader interface { - // GetScenePerformers(sceneID int) ([]PerformersScenes, error) - GetSceneMovies(sceneID int) ([]MoviesScenes, error) - // GetSceneTags(sceneID int) ([]ScenesTags, error) -} - -type JoinWriter interface { - CreatePerformersScenes(newJoins []PerformersScenes) error - // AddPerformerScene(sceneID int, performerID int) (bool, error) - UpdatePerformersScenes(sceneID int, updatedJoins []PerformersScenes) error - // DestroyPerformersScenes(sceneID int) error - CreateMoviesScenes(newJoins []MoviesScenes) error - // AddMoviesScene(sceneID int, movieID int, sceneIdx *int) (bool, error) - UpdateMoviesScenes(sceneID int, updatedJoins []MoviesScenes) error - // DestroyMoviesScenes(sceneID int) error - // CreateScenesTags(newJoins []ScenesTags) error - UpdateScenesTags(sceneID int, updatedJoins []ScenesTags) error - // AddSceneTag(sceneID int, tagID int) (bool, error) - // DestroyScenesTags(sceneID int) error - // CreateSceneMarkersTags(newJoins []SceneMarkersTags) error - UpdateSceneMarkersTags(sceneMarkerID int, updatedJoins []SceneMarkersTags) error - // DestroySceneMarkersTags(sceneMarkerID int, updatedJoins []SceneMarkersTags) error - // DestroyScenesGalleries(sceneID int) error - // DestroyScenesMarkers(sceneID int) error - UpdatePerformersGalleries(galleryID int, updatedJoins []PerformersGalleries) error - UpdateGalleriesTags(galleryID int, updatedJoins []GalleriesTags) error - UpdateGalleriesImages(imageID int, updatedJoins []GalleriesImages) error - UpdatePerformersImages(imageID int, updatedJoins []PerformersImages) error - UpdateImagesTags(imageID int, updatedJoins []ImagesTags) error -} - -type JoinReaderWriter interface { - JoinReader - JoinWriter -} - -func NewJoinReaderWriter(tx *sqlx.Tx) JoinReaderWriter { - return &joinReaderWriter{ - tx: tx, - qb: NewJoinsQueryBuilder(), - } -} - -type joinReaderWriter struct { - tx *sqlx.Tx - qb JoinsQueryBuilder -} - -func (t *joinReaderWriter) GetSceneMovies(sceneID int) ([]MoviesScenes, error) { - return t.qb.GetSceneMovies(sceneID, t.tx) -} - -func (t *joinReaderWriter) CreatePerformersScenes(newJoins []PerformersScenes) error { - return t.qb.CreatePerformersScenes(newJoins, t.tx) -} - -func (t *joinReaderWriter) UpdatePerformersScenes(sceneID int, updatedJoins []PerformersScenes) error { - return t.qb.UpdatePerformersScenes(sceneID, updatedJoins, t.tx) -} - -func (t *joinReaderWriter) CreateMoviesScenes(newJoins []MoviesScenes) error { - return t.qb.CreateMoviesScenes(newJoins, t.tx) -} - -func (t *joinReaderWriter) 
UpdateMoviesScenes(sceneID int, updatedJoins []MoviesScenes) error { - return t.qb.UpdateMoviesScenes(sceneID, updatedJoins, t.tx) -} - -func (t *joinReaderWriter) UpdateScenesTags(sceneID int, updatedJoins []ScenesTags) error { - return t.qb.UpdateScenesTags(sceneID, updatedJoins, t.tx) -} - -func (t *joinReaderWriter) UpdateSceneMarkersTags(sceneMarkerID int, updatedJoins []SceneMarkersTags) error { - return t.qb.UpdateSceneMarkersTags(sceneMarkerID, updatedJoins, t.tx) -} - -func (t *joinReaderWriter) UpdatePerformersGalleries(galleryID int, updatedJoins []PerformersGalleries) error { - return t.qb.UpdatePerformersGalleries(galleryID, updatedJoins, t.tx) -} - -func (t *joinReaderWriter) UpdateGalleriesTags(galleryID int, updatedJoins []GalleriesTags) error { - return t.qb.UpdateGalleriesTags(galleryID, updatedJoins, t.tx) -} - -func (t *joinReaderWriter) UpdateGalleriesImages(imageID int, updatedJoins []GalleriesImages) error { - return t.qb.UpdateGalleriesImages(imageID, updatedJoins, t.tx) -} - -func (t *joinReaderWriter) UpdatePerformersImages(imageID int, updatedJoins []PerformersImages) error { - return t.qb.UpdatePerformersImages(imageID, updatedJoins, t.tx) -} - -func (t *joinReaderWriter) UpdateImagesTags(imageID int, updatedJoins []ImagesTags) error { - return t.qb.UpdateImagesTags(imageID, updatedJoins, t.tx) -} diff --git a/pkg/models/mocks/GalleryReaderWriter.go b/pkg/models/mocks/GalleryReaderWriter.go index dcc37fc47..6869b9f6b 100644 --- a/pkg/models/mocks/GalleryReaderWriter.go +++ b/pkg/models/mocks/GalleryReaderWriter.go @@ -35,6 +35,41 @@ func (_m *GalleryReaderWriter) All() ([]*models.Gallery, error) { return r0, r1 } +// ClearGalleryId provides a mock function with given fields: sceneID +func (_m *GalleryReaderWriter) ClearGalleryId(sceneID int) error { + ret := _m.Called(sceneID) + + var r0 error + if rf, ok := ret.Get(0).(func(int) error); ok { + r0 = rf(sceneID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Count provides a mock function with given fields: +func (_m *GalleryReaderWriter) Count() (int, error) { + ret := _m.Called() + + var r0 int + if rf, ok := ret.Get(0).(func() int); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int) + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // Create provides a mock function with given fields: newGallery func (_m *GalleryReaderWriter) Create(newGallery models.Gallery) (*models.Gallery, error) { ret := _m.Called(newGallery) @@ -58,6 +93,43 @@ func (_m *GalleryReaderWriter) Create(newGallery models.Gallery) (*models.Galler return r0, r1 } +// Destroy provides a mock function with given fields: id +func (_m *GalleryReaderWriter) Destroy(id int) error { + ret := _m.Called(id) + + var r0 error + if rf, ok := ret.Get(0).(func(int) error); ok { + r0 = rf(id) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Find provides a mock function with given fields: id +func (_m *GalleryReaderWriter) Find(id int) (*models.Gallery, error) { + ret := _m.Called(id) + + var r0 *models.Gallery + if rf, ok := ret.Get(0).(func(int) *models.Gallery); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*models.Gallery) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // FindByChecksum provides a mock function with given fields: checksum func (_m *GalleryReaderWriter) FindByChecksum(checksum string) (*models.Gallery, 
error) { ret := _m.Called(checksum) @@ -173,6 +245,105 @@ func (_m *GalleryReaderWriter) FindMany(ids []int) ([]*models.Gallery, error) { return r0, r1 } +// GetImageIDs provides a mock function with given fields: galleryID +func (_m *GalleryReaderWriter) GetImageIDs(galleryID int) ([]int, error) { + ret := _m.Called(galleryID) + + var r0 []int + if rf, ok := ret.Get(0).(func(int) []int); ok { + r0 = rf(galleryID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]int) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(galleryID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetPerformerIDs provides a mock function with given fields: galleryID +func (_m *GalleryReaderWriter) GetPerformerIDs(galleryID int) ([]int, error) { + ret := _m.Called(galleryID) + + var r0 []int + if rf, ok := ret.Get(0).(func(int) []int); ok { + r0 = rf(galleryID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]int) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(galleryID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetTagIDs provides a mock function with given fields: galleryID +func (_m *GalleryReaderWriter) GetTagIDs(galleryID int) ([]int, error) { + ret := _m.Called(galleryID) + + var r0 []int + if rf, ok := ret.Get(0).(func(int) []int); ok { + r0 = rf(galleryID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]int) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(galleryID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Query provides a mock function with given fields: galleryFilter, findFilter +func (_m *GalleryReaderWriter) Query(galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) ([]*models.Gallery, int, error) { + ret := _m.Called(galleryFilter, findFilter) + + var r0 []*models.Gallery + if rf, ok := ret.Get(0).(func(*models.GalleryFilterType, *models.FindFilterType) []*models.Gallery); ok { + r0 = rf(galleryFilter, findFilter) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.Gallery) + } + } + + var r1 int + if rf, ok := ret.Get(1).(func(*models.GalleryFilterType, *models.FindFilterType) int); ok { + r1 = rf(galleryFilter, findFilter) + } else { + r1 = ret.Get(1).(int) + } + + var r2 error + if rf, ok := ret.Get(2).(func(*models.GalleryFilterType, *models.FindFilterType) error); ok { + r2 = rf(galleryFilter, findFilter) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + // Update provides a mock function with given fields: updatedGallery func (_m *GalleryReaderWriter) Update(updatedGallery models.Gallery) (*models.Gallery, error) { ret := _m.Called(updatedGallery) @@ -195,3 +366,105 @@ func (_m *GalleryReaderWriter) Update(updatedGallery models.Gallery) (*models.Ga return r0, r1 } + +// UpdateFileModTime provides a mock function with given fields: id, modTime +func (_m *GalleryReaderWriter) UpdateFileModTime(id int, modTime models.NullSQLiteTimestamp) error { + ret := _m.Called(id, modTime) + + var r0 error + if rf, ok := ret.Get(0).(func(int, models.NullSQLiteTimestamp) error); ok { + r0 = rf(id, modTime) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// UpdateImages provides a mock function with given fields: galleryID, imageIDs +func (_m *GalleryReaderWriter) UpdateImages(galleryID int, imageIDs []int) error { + ret := _m.Called(galleryID, imageIDs) + + var r0 error + if rf, ok := ret.Get(0).(func(int, []int) error); ok { + r0 = rf(galleryID, imageIDs) + } else { + r0 
= ret.Error(0) + } + + return r0 +} + +// UpdatePartial provides a mock function with given fields: updatedGallery +func (_m *GalleryReaderWriter) UpdatePartial(updatedGallery models.GalleryPartial) (*models.Gallery, error) { + ret := _m.Called(updatedGallery) + + var r0 *models.Gallery + if rf, ok := ret.Get(0).(func(models.GalleryPartial) *models.Gallery); ok { + r0 = rf(updatedGallery) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*models.Gallery) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(models.GalleryPartial) error); ok { + r1 = rf(updatedGallery) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// UpdatePerformers provides a mock function with given fields: galleryID, performerIDs +func (_m *GalleryReaderWriter) UpdatePerformers(galleryID int, performerIDs []int) error { + ret := _m.Called(galleryID, performerIDs) + + var r0 error + if rf, ok := ret.Get(0).(func(int, []int) error); ok { + r0 = rf(galleryID, performerIDs) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// UpdateTags provides a mock function with given fields: galleryID, tagIDs +func (_m *GalleryReaderWriter) UpdateTags(galleryID int, tagIDs []int) error { + ret := _m.Called(galleryID, tagIDs) + + var r0 error + if rf, ok := ret.Get(0).(func(int, []int) error); ok { + r0 = rf(galleryID, tagIDs) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ValidGalleriesForScenePath provides a mock function with given fields: scenePath +func (_m *GalleryReaderWriter) ValidGalleriesForScenePath(scenePath string) ([]*models.Gallery, error) { + ret := _m.Called(scenePath) + + var r0 []*models.Gallery + if rf, ok := ret.Get(0).(func(string) []*models.Gallery); ok { + r0 = rf(scenePath) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.Gallery) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(scenePath) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/pkg/models/mocks/ImageReaderWriter.go b/pkg/models/mocks/ImageReaderWriter.go index 7010075cb..b00cacfc9 100644 --- a/pkg/models/mocks/ImageReaderWriter.go +++ b/pkg/models/mocks/ImageReaderWriter.go @@ -35,6 +35,48 @@ func (_m *ImageReaderWriter) All() ([]*models.Image, error) { return r0, r1 } +// Count provides a mock function with given fields: +func (_m *ImageReaderWriter) Count() (int, error) { + ret := _m.Called() + + var r0 int + if rf, ok := ret.Get(0).(func() int); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int) + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// CountByGalleryID provides a mock function with given fields: galleryID +func (_m *ImageReaderWriter) CountByGalleryID(galleryID int) (int, error) { + ret := _m.Called(galleryID) + + var r0 int + if rf, ok := ret.Get(0).(func(int) int); ok { + r0 = rf(galleryID) + } else { + r0 = ret.Get(0).(int) + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(galleryID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // Create provides a mock function with given fields: newImage func (_m *ImageReaderWriter) Create(newImage models.Image) (*models.Image, error) { ret := _m.Called(newImage) @@ -58,6 +100,64 @@ func (_m *ImageReaderWriter) Create(newImage models.Image) (*models.Image, error return r0, r1 } +// DecrementOCounter provides a mock function with given fields: id +func (_m *ImageReaderWriter) DecrementOCounter(id int) (int, error) { + ret := _m.Called(id) + + var r0 int + if 
rf, ok := ret.Get(0).(func(int) int); ok { + r0 = rf(id) + } else { + r0 = ret.Get(0).(int) + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Destroy provides a mock function with given fields: id +func (_m *ImageReaderWriter) Destroy(id int) error { + ret := _m.Called(id) + + var r0 error + if rf, ok := ret.Get(0).(func(int) error); ok { + r0 = rf(id) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Find provides a mock function with given fields: id +func (_m *ImageReaderWriter) Find(id int) (*models.Image, error) { + ret := _m.Called(id) + + var r0 *models.Image + if rf, ok := ret.Get(0).(func(int) *models.Image); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*models.Image) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // FindByChecksum provides a mock function with given fields: checksum func (_m *ImageReaderWriter) FindByChecksum(checksum string) (*models.Image, error) { ret := _m.Called(checksum) @@ -104,6 +204,29 @@ func (_m *ImageReaderWriter) FindByGalleryID(galleryID int) ([]*models.Image, er return r0, r1 } +// FindByPath provides a mock function with given fields: path +func (_m *ImageReaderWriter) FindByPath(path string) (*models.Image, error) { + ret := _m.Called(path) + + var r0 *models.Image + if rf, ok := ret.Get(0).(func(string) *models.Image); ok { + r0 = rf(path) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*models.Image) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(path) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // FindMany provides a mock function with given fields: ids func (_m *ImageReaderWriter) FindMany(ids []int) ([]*models.Image, error) { ret := _m.Called(ids) @@ -127,6 +250,168 @@ func (_m *ImageReaderWriter) FindMany(ids []int) ([]*models.Image, error) { return r0, r1 } +// GetGalleryIDs provides a mock function with given fields: imageID +func (_m *ImageReaderWriter) GetGalleryIDs(imageID int) ([]int, error) { + ret := _m.Called(imageID) + + var r0 []int + if rf, ok := ret.Get(0).(func(int) []int); ok { + r0 = rf(imageID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]int) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(imageID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetPerformerIDs provides a mock function with given fields: imageID +func (_m *ImageReaderWriter) GetPerformerIDs(imageID int) ([]int, error) { + ret := _m.Called(imageID) + + var r0 []int + if rf, ok := ret.Get(0).(func(int) []int); ok { + r0 = rf(imageID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]int) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(imageID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetTagIDs provides a mock function with given fields: imageID +func (_m *ImageReaderWriter) GetTagIDs(imageID int) ([]int, error) { + ret := _m.Called(imageID) + + var r0 []int + if rf, ok := ret.Get(0).(func(int) []int); ok { + r0 = rf(imageID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]int) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(imageID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// IncrementOCounter provides a mock function with given fields: id +func (_m *ImageReaderWriter) 
IncrementOCounter(id int) (int, error) { + ret := _m.Called(id) + + var r0 int + if rf, ok := ret.Get(0).(func(int) int); ok { + r0 = rf(id) + } else { + r0 = ret.Get(0).(int) + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Query provides a mock function with given fields: imageFilter, findFilter +func (_m *ImageReaderWriter) Query(imageFilter *models.ImageFilterType, findFilter *models.FindFilterType) ([]*models.Image, int, error) { + ret := _m.Called(imageFilter, findFilter) + + var r0 []*models.Image + if rf, ok := ret.Get(0).(func(*models.ImageFilterType, *models.FindFilterType) []*models.Image); ok { + r0 = rf(imageFilter, findFilter) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.Image) + } + } + + var r1 int + if rf, ok := ret.Get(1).(func(*models.ImageFilterType, *models.FindFilterType) int); ok { + r1 = rf(imageFilter, findFilter) + } else { + r1 = ret.Get(1).(int) + } + + var r2 error + if rf, ok := ret.Get(2).(func(*models.ImageFilterType, *models.FindFilterType) error); ok { + r2 = rf(imageFilter, findFilter) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// ResetOCounter provides a mock function with given fields: id +func (_m *ImageReaderWriter) ResetOCounter(id int) (int, error) { + ret := _m.Called(id) + + var r0 int + if rf, ok := ret.Get(0).(func(int) int); ok { + r0 = rf(id) + } else { + r0 = ret.Get(0).(int) + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Size provides a mock function with given fields: +func (_m *ImageReaderWriter) Size() (float64, error) { + ret := _m.Called() + + var r0 float64 + if rf, ok := ret.Get(0).(func() float64); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(float64) + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // Update provides a mock function with given fields: updatedImage func (_m *ImageReaderWriter) Update(updatedImage models.ImagePartial) (*models.Image, error) { ret := _m.Called(updatedImage) @@ -172,3 +457,45 @@ func (_m *ImageReaderWriter) UpdateFull(updatedImage models.Image) (*models.Imag return r0, r1 } + +// UpdateGalleries provides a mock function with given fields: imageID, galleryIDs +func (_m *ImageReaderWriter) UpdateGalleries(imageID int, galleryIDs []int) error { + ret := _m.Called(imageID, galleryIDs) + + var r0 error + if rf, ok := ret.Get(0).(func(int, []int) error); ok { + r0 = rf(imageID, galleryIDs) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// UpdatePerformers provides a mock function with given fields: imageID, performerIDs +func (_m *ImageReaderWriter) UpdatePerformers(imageID int, performerIDs []int) error { + ret := _m.Called(imageID, performerIDs) + + var r0 error + if rf, ok := ret.Get(0).(func(int, []int) error); ok { + r0 = rf(imageID, performerIDs) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// UpdateTags provides a mock function with given fields: imageID, tagIDs +func (_m *ImageReaderWriter) UpdateTags(imageID int, tagIDs []int) error { + ret := _m.Called(imageID, tagIDs) + + var r0 error + if rf, ok := ret.Get(0).(func(int, []int) error); ok { + r0 = rf(imageID, tagIDs) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/pkg/models/mocks/JoinReaderWriter.go b/pkg/models/mocks/JoinReaderWriter.go deleted file mode 100644 index 6fc20057e..000000000 
--- a/pkg/models/mocks/JoinReaderWriter.go +++ /dev/null @@ -1,190 +0,0 @@ -// Code generated by mockery v0.0.0-dev. DO NOT EDIT. - -package mocks - -import ( - models "github.com/stashapp/stash/pkg/models" - mock "github.com/stretchr/testify/mock" -) - -// JoinReaderWriter is an autogenerated mock type for the JoinReaderWriter type -type JoinReaderWriter struct { - mock.Mock -} - -// CreateMoviesScenes provides a mock function with given fields: newJoins -func (_m *JoinReaderWriter) CreateMoviesScenes(newJoins []models.MoviesScenes) error { - ret := _m.Called(newJoins) - - var r0 error - if rf, ok := ret.Get(0).(func([]models.MoviesScenes) error); ok { - r0 = rf(newJoins) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// CreatePerformersScenes provides a mock function with given fields: newJoins -func (_m *JoinReaderWriter) CreatePerformersScenes(newJoins []models.PerformersScenes) error { - ret := _m.Called(newJoins) - - var r0 error - if rf, ok := ret.Get(0).(func([]models.PerformersScenes) error); ok { - r0 = rf(newJoins) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// GetSceneMovies provides a mock function with given fields: sceneID -func (_m *JoinReaderWriter) GetSceneMovies(sceneID int) ([]models.MoviesScenes, error) { - ret := _m.Called(sceneID) - - var r0 []models.MoviesScenes - if rf, ok := ret.Get(0).(func(int) []models.MoviesScenes); ok { - r0 = rf(sceneID) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]models.MoviesScenes) - } - } - - var r1 error - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(sceneID) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// UpdateGalleriesImages provides a mock function with given fields: imageID, updatedJoins -func (_m *JoinReaderWriter) UpdateGalleriesImages(imageID int, updatedJoins []models.GalleriesImages) error { - ret := _m.Called(imageID, updatedJoins) - - var r0 error - if rf, ok := ret.Get(0).(func(int, []models.GalleriesImages) error); ok { - r0 = rf(imageID, updatedJoins) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// UpdateGalleriesTags provides a mock function with given fields: galleryID, updatedJoins -func (_m *JoinReaderWriter) UpdateGalleriesTags(galleryID int, updatedJoins []models.GalleriesTags) error { - ret := _m.Called(galleryID, updatedJoins) - - var r0 error - if rf, ok := ret.Get(0).(func(int, []models.GalleriesTags) error); ok { - r0 = rf(galleryID, updatedJoins) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// UpdateImagesTags provides a mock function with given fields: imageID, updatedJoins -func (_m *JoinReaderWriter) UpdateImagesTags(imageID int, updatedJoins []models.ImagesTags) error { - ret := _m.Called(imageID, updatedJoins) - - var r0 error - if rf, ok := ret.Get(0).(func(int, []models.ImagesTags) error); ok { - r0 = rf(imageID, updatedJoins) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// UpdateMoviesScenes provides a mock function with given fields: sceneID, updatedJoins -func (_m *JoinReaderWriter) UpdateMoviesScenes(sceneID int, updatedJoins []models.MoviesScenes) error { - ret := _m.Called(sceneID, updatedJoins) - - var r0 error - if rf, ok := ret.Get(0).(func(int, []models.MoviesScenes) error); ok { - r0 = rf(sceneID, updatedJoins) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// UpdatePerformersGalleries provides a mock function with given fields: galleryID, updatedJoins -func (_m *JoinReaderWriter) UpdatePerformersGalleries(galleryID int, updatedJoins []models.PerformersGalleries) error { - ret := 
_m.Called(galleryID, updatedJoins) - - var r0 error - if rf, ok := ret.Get(0).(func(int, []models.PerformersGalleries) error); ok { - r0 = rf(galleryID, updatedJoins) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// UpdatePerformersImages provides a mock function with given fields: imageID, updatedJoins -func (_m *JoinReaderWriter) UpdatePerformersImages(imageID int, updatedJoins []models.PerformersImages) error { - ret := _m.Called(imageID, updatedJoins) - - var r0 error - if rf, ok := ret.Get(0).(func(int, []models.PerformersImages) error); ok { - r0 = rf(imageID, updatedJoins) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// UpdatePerformersScenes provides a mock function with given fields: sceneID, updatedJoins -func (_m *JoinReaderWriter) UpdatePerformersScenes(sceneID int, updatedJoins []models.PerformersScenes) error { - ret := _m.Called(sceneID, updatedJoins) - - var r0 error - if rf, ok := ret.Get(0).(func(int, []models.PerformersScenes) error); ok { - r0 = rf(sceneID, updatedJoins) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// UpdateSceneMarkersTags provides a mock function with given fields: sceneMarkerID, updatedJoins -func (_m *JoinReaderWriter) UpdateSceneMarkersTags(sceneMarkerID int, updatedJoins []models.SceneMarkersTags) error { - ret := _m.Called(sceneMarkerID, updatedJoins) - - var r0 error - if rf, ok := ret.Get(0).(func(int, []models.SceneMarkersTags) error); ok { - r0 = rf(sceneMarkerID, updatedJoins) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// UpdateScenesTags provides a mock function with given fields: sceneID, updatedJoins -func (_m *JoinReaderWriter) UpdateScenesTags(sceneID int, updatedJoins []models.ScenesTags) error { - ret := _m.Called(sceneID, updatedJoins) - - var r0 error - if rf, ok := ret.Get(0).(func(int, []models.ScenesTags) error); ok { - r0 = rf(sceneID, updatedJoins) - } else { - r0 = ret.Error(0) - } - - return r0 -} diff --git a/pkg/models/mocks/MovieReaderWriter.go b/pkg/models/mocks/MovieReaderWriter.go index 5521f25e4..be9ec5dd6 100644 --- a/pkg/models/mocks/MovieReaderWriter.go +++ b/pkg/models/mocks/MovieReaderWriter.go @@ -35,6 +35,50 @@ func (_m *MovieReaderWriter) All() ([]*models.Movie, error) { return r0, r1 } +// AllSlim provides a mock function with given fields: +func (_m *MovieReaderWriter) AllSlim() ([]*models.Movie, error) { + ret := _m.Called() + + var r0 []*models.Movie + if rf, ok := ret.Get(0).(func() []*models.Movie); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.Movie) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Count provides a mock function with given fields: +func (_m *MovieReaderWriter) Count() (int, error) { + ret := _m.Called() + + var r0 int + if rf, ok := ret.Get(0).(func() int); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int) + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // Create provides a mock function with given fields: newMovie func (_m *MovieReaderWriter) Create(newMovie models.Movie) (*models.Movie, error) { ret := _m.Called(newMovie) @@ -58,6 +102,34 @@ func (_m *MovieReaderWriter) Create(newMovie models.Movie) (*models.Movie, error return r0, r1 } +// Destroy provides a mock function with given fields: id +func (_m *MovieReaderWriter) Destroy(id int) error { + ret := _m.Called(id) + + var r0 error + if rf, ok := ret.Get(0).(func(int) 
error); ok { + r0 = rf(id) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// DestroyImages provides a mock function with given fields: movieID +func (_m *MovieReaderWriter) DestroyImages(movieID int) error { + ret := _m.Called(movieID) + + var r0 error + if rf, ok := ret.Get(0).(func(int) error); ok { + r0 = rf(movieID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // Find provides a mock function with given fields: id func (_m *MovieReaderWriter) Find(id int) (*models.Movie, error) { ret := _m.Called(id) @@ -196,6 +268,36 @@ func (_m *MovieReaderWriter) GetFrontImage(movieID int) ([]byte, error) { return r0, r1 } +// Query provides a mock function with given fields: movieFilter, findFilter +func (_m *MovieReaderWriter) Query(movieFilter *models.MovieFilterType, findFilter *models.FindFilterType) ([]*models.Movie, int, error) { + ret := _m.Called(movieFilter, findFilter) + + var r0 []*models.Movie + if rf, ok := ret.Get(0).(func(*models.MovieFilterType, *models.FindFilterType) []*models.Movie); ok { + r0 = rf(movieFilter, findFilter) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.Movie) + } + } + + var r1 int + if rf, ok := ret.Get(1).(func(*models.MovieFilterType, *models.FindFilterType) int); ok { + r1 = rf(movieFilter, findFilter) + } else { + r1 = ret.Get(1).(int) + } + + var r2 error + if rf, ok := ret.Get(2).(func(*models.MovieFilterType, *models.FindFilterType) error); ok { + r2 = rf(movieFilter, findFilter) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + // Update provides a mock function with given fields: updatedMovie func (_m *MovieReaderWriter) Update(updatedMovie models.MoviePartial) (*models.Movie, error) { ret := _m.Called(updatedMovie) @@ -242,8 +344,8 @@ func (_m *MovieReaderWriter) UpdateFull(updatedMovie models.Movie) (*models.Movi return r0, r1 } -// UpdateMovieImages provides a mock function with given fields: movieID, frontImage, backImage -func (_m *MovieReaderWriter) UpdateMovieImages(movieID int, frontImage []byte, backImage []byte) error { +// UpdateImages provides a mock function with given fields: movieID, frontImage, backImage +func (_m *MovieReaderWriter) UpdateImages(movieID int, frontImage []byte, backImage []byte) error { ret := _m.Called(movieID, frontImage, backImage) var r0 error diff --git a/pkg/models/mocks/PerformerReaderWriter.go b/pkg/models/mocks/PerformerReaderWriter.go index f10129fd8..24cdc5bbf 100644 --- a/pkg/models/mocks/PerformerReaderWriter.go +++ b/pkg/models/mocks/PerformerReaderWriter.go @@ -35,6 +35,50 @@ func (_m *PerformerReaderWriter) All() ([]*models.Performer, error) { return r0, r1 } +// AllSlim provides a mock function with given fields: +func (_m *PerformerReaderWriter) AllSlim() ([]*models.Performer, error) { + ret := _m.Called() + + var r0 []*models.Performer + if rf, ok := ret.Get(0).(func() []*models.Performer); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.Performer) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Count provides a mock function with given fields: +func (_m *PerformerReaderWriter) Count() (int, error) { + ret := _m.Called() + + var r0 int + if rf, ok := ret.Get(0).(func() int); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int) + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // Create provides a mock function with given fields: newPerformer 
func (_m *PerformerReaderWriter) Create(newPerformer models.Performer) (*models.Performer, error) { ret := _m.Called(newPerformer) @@ -58,6 +102,57 @@ func (_m *PerformerReaderWriter) Create(newPerformer models.Performer) (*models. return r0, r1 } +// Destroy provides a mock function with given fields: id +func (_m *PerformerReaderWriter) Destroy(id int) error { + ret := _m.Called(id) + + var r0 error + if rf, ok := ret.Get(0).(func(int) error); ok { + r0 = rf(id) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// DestroyImage provides a mock function with given fields: performerID +func (_m *PerformerReaderWriter) DestroyImage(performerID int) error { + ret := _m.Called(performerID) + + var r0 error + if rf, ok := ret.Get(0).(func(int) error); ok { + r0 = rf(performerID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Find provides a mock function with given fields: id +func (_m *PerformerReaderWriter) Find(id int) (*models.Performer, error) { + ret := _m.Called(id) + + var r0 *models.Performer + if rf, ok := ret.Get(0).(func(int) *models.Performer); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*models.Performer) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // FindByGalleryID provides a mock function with given fields: galleryID func (_m *PerformerReaderWriter) FindByGalleryID(galleryID int) ([]*models.Performer, error) { ret := _m.Called(galleryID) @@ -196,8 +291,8 @@ func (_m *PerformerReaderWriter) FindNamesBySceneID(sceneID int) ([]*models.Perf return r0, r1 } -// GetPerformerImage provides a mock function with given fields: performerID -func (_m *PerformerReaderWriter) GetPerformerImage(performerID int) ([]byte, error) { +// GetImage provides a mock function with given fields: performerID +func (_m *PerformerReaderWriter) GetImage(performerID int) ([]byte, error) { ret := _m.Called(performerID) var r0 []byte @@ -219,6 +314,59 @@ func (_m *PerformerReaderWriter) GetPerformerImage(performerID int) ([]byte, err return r0, r1 } +// GetStashIDs provides a mock function with given fields: performerID +func (_m *PerformerReaderWriter) GetStashIDs(performerID int) ([]*models.StashID, error) { + ret := _m.Called(performerID) + + var r0 []*models.StashID + if rf, ok := ret.Get(0).(func(int) []*models.StashID); ok { + r0 = rf(performerID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.StashID) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(performerID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Query provides a mock function with given fields: performerFilter, findFilter +func (_m *PerformerReaderWriter) Query(performerFilter *models.PerformerFilterType, findFilter *models.FindFilterType) ([]*models.Performer, int, error) { + ret := _m.Called(performerFilter, findFilter) + + var r0 []*models.Performer + if rf, ok := ret.Get(0).(func(*models.PerformerFilterType, *models.FindFilterType) []*models.Performer); ok { + r0 = rf(performerFilter, findFilter) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.Performer) + } + } + + var r1 int + if rf, ok := ret.Get(1).(func(*models.PerformerFilterType, *models.FindFilterType) int); ok { + r1 = rf(performerFilter, findFilter) + } else { + r1 = ret.Get(1).(int) + } + + var r2 error + if rf, ok := ret.Get(2).(func(*models.PerformerFilterType, *models.FindFilterType) error); ok { + r2 = rf(performerFilter, findFilter) + } 
else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + // Update provides a mock function with given fields: updatedPerformer func (_m *PerformerReaderWriter) Update(updatedPerformer models.PerformerPartial) (*models.Performer, error) { ret := _m.Called(updatedPerformer) @@ -265,8 +413,8 @@ func (_m *PerformerReaderWriter) UpdateFull(updatedPerformer models.Performer) ( return r0, r1 } -// UpdatePerformerImage provides a mock function with given fields: performerID, image -func (_m *PerformerReaderWriter) UpdatePerformerImage(performerID int, image []byte) error { +// UpdateImage provides a mock function with given fields: performerID, image +func (_m *PerformerReaderWriter) UpdateImage(performerID int, image []byte) error { ret := _m.Called(performerID, image) var r0 error @@ -278,3 +426,17 @@ func (_m *PerformerReaderWriter) UpdatePerformerImage(performerID int, image []b return r0 } + +// UpdateStashIDs provides a mock function with given fields: performerID, stashIDs +func (_m *PerformerReaderWriter) UpdateStashIDs(performerID int, stashIDs []models.StashID) error { + ret := _m.Called(performerID, stashIDs) + + var r0 error + if rf, ok := ret.Get(0).(func(int, []models.StashID) error); ok { + r0 = rf(performerID, stashIDs) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/pkg/models/mocks/SceneMarkerReaderWriter.go b/pkg/models/mocks/SceneMarkerReaderWriter.go index df1b4f937..ebecb8e4d 100644 --- a/pkg/models/mocks/SceneMarkerReaderWriter.go +++ b/pkg/models/mocks/SceneMarkerReaderWriter.go @@ -12,6 +12,27 @@ type SceneMarkerReaderWriter struct { mock.Mock } +// CountByTagID provides a mock function with given fields: tagID +func (_m *SceneMarkerReaderWriter) CountByTagID(tagID int) (int, error) { + ret := _m.Called(tagID) + + var r0 int + if rf, ok := ret.Get(0).(func(int) int); ok { + r0 = rf(tagID) + } else { + r0 = ret.Get(0).(int) + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(tagID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // Create provides a mock function with given fields: newSceneMarker func (_m *SceneMarkerReaderWriter) Create(newSceneMarker models.SceneMarker) (*models.SceneMarker, error) { ret := _m.Called(newSceneMarker) @@ -35,6 +56,43 @@ func (_m *SceneMarkerReaderWriter) Create(newSceneMarker models.SceneMarker) (*m return r0, r1 } +// Destroy provides a mock function with given fields: id +func (_m *SceneMarkerReaderWriter) Destroy(id int) error { + ret := _m.Called(id) + + var r0 error + if rf, ok := ret.Get(0).(func(int) error); ok { + r0 = rf(id) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Find provides a mock function with given fields: id +func (_m *SceneMarkerReaderWriter) Find(id int) (*models.SceneMarker, error) { + ret := _m.Called(id) + + var r0 *models.SceneMarker + if rf, ok := ret.Get(0).(func(int) *models.SceneMarker); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*models.SceneMarker) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // FindBySceneID provides a mock function with given fields: sceneID func (_m *SceneMarkerReaderWriter) FindBySceneID(sceneID int) ([]*models.SceneMarker, error) { ret := _m.Called(sceneID) @@ -58,6 +116,105 @@ func (_m *SceneMarkerReaderWriter) FindBySceneID(sceneID int) ([]*models.SceneMa return r0, r1 } +// FindMany provides a mock function with given fields: ids +func (_m *SceneMarkerReaderWriter) FindMany(ids 
[]int) ([]*models.SceneMarker, error) { + ret := _m.Called(ids) + + var r0 []*models.SceneMarker + if rf, ok := ret.Get(0).(func([]int) []*models.SceneMarker); ok { + r0 = rf(ids) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.SceneMarker) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func([]int) error); ok { + r1 = rf(ids) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetMarkerStrings provides a mock function with given fields: q, sort +func (_m *SceneMarkerReaderWriter) GetMarkerStrings(q *string, sort *string) ([]*models.MarkerStringsResultType, error) { + ret := _m.Called(q, sort) + + var r0 []*models.MarkerStringsResultType + if rf, ok := ret.Get(0).(func(*string, *string) []*models.MarkerStringsResultType); ok { + r0 = rf(q, sort) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.MarkerStringsResultType) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(*string, *string) error); ok { + r1 = rf(q, sort) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetTagIDs provides a mock function with given fields: imageID +func (_m *SceneMarkerReaderWriter) GetTagIDs(imageID int) ([]int, error) { + ret := _m.Called(imageID) + + var r0 []int + if rf, ok := ret.Get(0).(func(int) []int); ok { + r0 = rf(imageID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]int) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(imageID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Query provides a mock function with given fields: sceneMarkerFilter, findFilter +func (_m *SceneMarkerReaderWriter) Query(sceneMarkerFilter *models.SceneMarkerFilterType, findFilter *models.FindFilterType) ([]*models.SceneMarker, int, error) { + ret := _m.Called(sceneMarkerFilter, findFilter) + + var r0 []*models.SceneMarker + if rf, ok := ret.Get(0).(func(*models.SceneMarkerFilterType, *models.FindFilterType) []*models.SceneMarker); ok { + r0 = rf(sceneMarkerFilter, findFilter) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.SceneMarker) + } + } + + var r1 int + if rf, ok := ret.Get(1).(func(*models.SceneMarkerFilterType, *models.FindFilterType) int); ok { + r1 = rf(sceneMarkerFilter, findFilter) + } else { + r1 = ret.Get(1).(int) + } + + var r2 error + if rf, ok := ret.Get(2).(func(*models.SceneMarkerFilterType, *models.FindFilterType) error); ok { + r2 = rf(sceneMarkerFilter, findFilter) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + // Update provides a mock function with given fields: updatedSceneMarker func (_m *SceneMarkerReaderWriter) Update(updatedSceneMarker models.SceneMarker) (*models.SceneMarker, error) { ret := _m.Called(updatedSceneMarker) @@ -80,3 +237,40 @@ func (_m *SceneMarkerReaderWriter) Update(updatedSceneMarker models.SceneMarker) return r0, r1 } + +// UpdateTags provides a mock function with given fields: markerID, tagIDs +func (_m *SceneMarkerReaderWriter) UpdateTags(markerID int, tagIDs []int) error { + ret := _m.Called(markerID, tagIDs) + + var r0 error + if rf, ok := ret.Get(0).(func(int, []int) error); ok { + r0 = rf(markerID, tagIDs) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Wall provides a mock function with given fields: q +func (_m *SceneMarkerReaderWriter) Wall(q *string) ([]*models.SceneMarker, error) { + ret := _m.Called(q) + + var r0 []*models.SceneMarker + if rf, ok := ret.Get(0).(func(*string) []*models.SceneMarker); ok { + r0 = rf(q) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.SceneMarker) + } + 
} + + var r1 error + if rf, ok := ret.Get(1).(func(*string) error); ok { + r1 = rf(q) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/pkg/models/mocks/SceneReaderWriter.go b/pkg/models/mocks/SceneReaderWriter.go index a88acc723..2c6862fb9 100644 --- a/pkg/models/mocks/SceneReaderWriter.go +++ b/pkg/models/mocks/SceneReaderWriter.go @@ -35,6 +35,153 @@ func (_m *SceneReaderWriter) All() ([]*models.Scene, error) { return r0, r1 } +// Count provides a mock function with given fields: +func (_m *SceneReaderWriter) Count() (int, error) { + ret := _m.Called() + + var r0 int + if rf, ok := ret.Get(0).(func() int); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int) + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// CountByMovieID provides a mock function with given fields: movieID +func (_m *SceneReaderWriter) CountByMovieID(movieID int) (int, error) { + ret := _m.Called(movieID) + + var r0 int + if rf, ok := ret.Get(0).(func(int) int); ok { + r0 = rf(movieID) + } else { + r0 = ret.Get(0).(int) + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(movieID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// CountByPerformerID provides a mock function with given fields: performerID +func (_m *SceneReaderWriter) CountByPerformerID(performerID int) (int, error) { + ret := _m.Called(performerID) + + var r0 int + if rf, ok := ret.Get(0).(func(int) int); ok { + r0 = rf(performerID) + } else { + r0 = ret.Get(0).(int) + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(performerID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// CountByStudioID provides a mock function with given fields: studioID +func (_m *SceneReaderWriter) CountByStudioID(studioID int) (int, error) { + ret := _m.Called(studioID) + + var r0 int + if rf, ok := ret.Get(0).(func(int) int); ok { + r0 = rf(studioID) + } else { + r0 = ret.Get(0).(int) + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(studioID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// CountByTagID provides a mock function with given fields: tagID +func (_m *SceneReaderWriter) CountByTagID(tagID int) (int, error) { + ret := _m.Called(tagID) + + var r0 int + if rf, ok := ret.Get(0).(func(int) int); ok { + r0 = rf(tagID) + } else { + r0 = ret.Get(0).(int) + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(tagID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// CountMissingChecksum provides a mock function with given fields: +func (_m *SceneReaderWriter) CountMissingChecksum() (int, error) { + ret := _m.Called() + + var r0 int + if rf, ok := ret.Get(0).(func() int); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int) + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// CountMissingOSHash provides a mock function with given fields: +func (_m *SceneReaderWriter) CountMissingOSHash() (int, error) { + ret := _m.Called() + + var r0 int + if rf, ok := ret.Get(0).(func() int); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int) + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // Create provides a mock function with given fields: newScene func (_m *SceneReaderWriter) Create(newScene models.Scene) (*models.Scene, error) { ret := 
_m.Called(newScene) @@ -58,6 +205,78 @@ func (_m *SceneReaderWriter) Create(newScene models.Scene) (*models.Scene, error return r0, r1 } +// DecrementOCounter provides a mock function with given fields: id +func (_m *SceneReaderWriter) DecrementOCounter(id int) (int, error) { + ret := _m.Called(id) + + var r0 int + if rf, ok := ret.Get(0).(func(int) int); ok { + r0 = rf(id) + } else { + r0 = ret.Get(0).(int) + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Destroy provides a mock function with given fields: id +func (_m *SceneReaderWriter) Destroy(id int) error { + ret := _m.Called(id) + + var r0 error + if rf, ok := ret.Get(0).(func(int) error); ok { + r0 = rf(id) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// DestroyCover provides a mock function with given fields: sceneID +func (_m *SceneReaderWriter) DestroyCover(sceneID int) error { + ret := _m.Called(sceneID) + + var r0 error + if rf, ok := ret.Get(0).(func(int) error); ok { + r0 = rf(sceneID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Find provides a mock function with given fields: id +func (_m *SceneReaderWriter) Find(id int) (*models.Scene, error) { + ret := _m.Called(id) + + var r0 *models.Scene + if rf, ok := ret.Get(0).(func(int) *models.Scene); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*models.Scene) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // FindByChecksum provides a mock function with given fields: checksum func (_m *SceneReaderWriter) FindByChecksum(checksum string) (*models.Scene, error) { ret := _m.Called(checksum) @@ -127,6 +346,52 @@ func (_m *SceneReaderWriter) FindByOSHash(oshash string) (*models.Scene, error) return r0, r1 } +// FindByPath provides a mock function with given fields: path +func (_m *SceneReaderWriter) FindByPath(path string) (*models.Scene, error) { + ret := _m.Called(path) + + var r0 *models.Scene + if rf, ok := ret.Get(0).(func(string) *models.Scene); ok { + r0 = rf(path) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*models.Scene) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(path) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// FindByPerformerID provides a mock function with given fields: performerID +func (_m *SceneReaderWriter) FindByPerformerID(performerID int) ([]*models.Scene, error) { + ret := _m.Called(performerID) + + var r0 []*models.Scene + if rf, ok := ret.Get(0).(func(int) []*models.Scene); ok { + r0 = rf(performerID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.Scene) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(performerID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // FindMany provides a mock function with given fields: ids func (_m *SceneReaderWriter) FindMany(ids []int) ([]*models.Scene, error) { ret := _m.Called(ids) @@ -150,8 +415,8 @@ func (_m *SceneReaderWriter) FindMany(ids []int) ([]*models.Scene, error) { return r0, r1 } -// GetSceneCover provides a mock function with given fields: sceneID -func (_m *SceneReaderWriter) GetSceneCover(sceneID int) ([]byte, error) { +// GetCover provides a mock function with given fields: sceneID +func (_m *SceneReaderWriter) GetCover(sceneID int) ([]byte, error) { ret := _m.Called(sceneID) var r0 []byte @@ -173,6 +438,244 @@ func (_m 
*SceneReaderWriter) GetSceneCover(sceneID int) ([]byte, error) { return r0, r1 } +// GetMovies provides a mock function with given fields: sceneID +func (_m *SceneReaderWriter) GetMovies(sceneID int) ([]models.MoviesScenes, error) { + ret := _m.Called(sceneID) + + var r0 []models.MoviesScenes + if rf, ok := ret.Get(0).(func(int) []models.MoviesScenes); ok { + r0 = rf(sceneID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]models.MoviesScenes) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(sceneID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetPerformerIDs provides a mock function with given fields: imageID +func (_m *SceneReaderWriter) GetPerformerIDs(imageID int) ([]int, error) { + ret := _m.Called(imageID) + + var r0 []int + if rf, ok := ret.Get(0).(func(int) []int); ok { + r0 = rf(imageID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]int) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(imageID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetStashIDs provides a mock function with given fields: performerID +func (_m *SceneReaderWriter) GetStashIDs(performerID int) ([]*models.StashID, error) { + ret := _m.Called(performerID) + + var r0 []*models.StashID + if rf, ok := ret.Get(0).(func(int) []*models.StashID); ok { + r0 = rf(performerID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.StashID) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(performerID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetTagIDs provides a mock function with given fields: imageID +func (_m *SceneReaderWriter) GetTagIDs(imageID int) ([]int, error) { + ret := _m.Called(imageID) + + var r0 []int + if rf, ok := ret.Get(0).(func(int) []int); ok { + r0 = rf(imageID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]int) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(imageID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// IncrementOCounter provides a mock function with given fields: id +func (_m *SceneReaderWriter) IncrementOCounter(id int) (int, error) { + ret := _m.Called(id) + + var r0 int + if rf, ok := ret.Get(0).(func(int) int); ok { + r0 = rf(id) + } else { + r0 = ret.Get(0).(int) + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Query provides a mock function with given fields: sceneFilter, findFilter +func (_m *SceneReaderWriter) Query(sceneFilter *models.SceneFilterType, findFilter *models.FindFilterType) ([]*models.Scene, int, error) { + ret := _m.Called(sceneFilter, findFilter) + + var r0 []*models.Scene + if rf, ok := ret.Get(0).(func(*models.SceneFilterType, *models.FindFilterType) []*models.Scene); ok { + r0 = rf(sceneFilter, findFilter) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.Scene) + } + } + + var r1 int + if rf, ok := ret.Get(1).(func(*models.SceneFilterType, *models.FindFilterType) int); ok { + r1 = rf(sceneFilter, findFilter) + } else { + r1 = ret.Get(1).(int) + } + + var r2 error + if rf, ok := ret.Get(2).(func(*models.SceneFilterType, *models.FindFilterType) error); ok { + r2 = rf(sceneFilter, findFilter) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// QueryAllByPathRegex provides a mock function with given fields: regex, ignoreOrganized +func (_m *SceneReaderWriter) QueryAllByPathRegex(regex string, 
ignoreOrganized bool) ([]*models.Scene, error) { + ret := _m.Called(regex, ignoreOrganized) + + var r0 []*models.Scene + if rf, ok := ret.Get(0).(func(string, bool) []*models.Scene); ok { + r0 = rf(regex, ignoreOrganized) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.Scene) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string, bool) error); ok { + r1 = rf(regex, ignoreOrganized) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// QueryByPathRegex provides a mock function with given fields: findFilter +func (_m *SceneReaderWriter) QueryByPathRegex(findFilter *models.FindFilterType) ([]*models.Scene, int, error) { + ret := _m.Called(findFilter) + + var r0 []*models.Scene + if rf, ok := ret.Get(0).(func(*models.FindFilterType) []*models.Scene); ok { + r0 = rf(findFilter) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.Scene) + } + } + + var r1 int + if rf, ok := ret.Get(1).(func(*models.FindFilterType) int); ok { + r1 = rf(findFilter) + } else { + r1 = ret.Get(1).(int) + } + + var r2 error + if rf, ok := ret.Get(2).(func(*models.FindFilterType) error); ok { + r2 = rf(findFilter) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// ResetOCounter provides a mock function with given fields: id +func (_m *SceneReaderWriter) ResetOCounter(id int) (int, error) { + ret := _m.Called(id) + + var r0 int + if rf, ok := ret.Get(0).(func(int) int); ok { + r0 = rf(id) + } else { + r0 = ret.Get(0).(int) + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Size provides a mock function with given fields: +func (_m *SceneReaderWriter) Size() (float64, error) { + ret := _m.Called() + + var r0 float64 + if rf, ok := ret.Get(0).(func() float64); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(float64) + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // Update provides a mock function with given fields: updatedScene func (_m *SceneReaderWriter) Update(updatedScene models.ScenePartial) (*models.Scene, error) { ret := _m.Called(updatedScene) @@ -196,6 +699,34 @@ func (_m *SceneReaderWriter) Update(updatedScene models.ScenePartial) (*models.S return r0, r1 } +// UpdateCover provides a mock function with given fields: sceneID, cover +func (_m *SceneReaderWriter) UpdateCover(sceneID int, cover []byte) error { + ret := _m.Called(sceneID, cover) + + var r0 error + if rf, ok := ret.Get(0).(func(int, []byte) error); ok { + r0 = rf(sceneID, cover) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// UpdateFileModTime provides a mock function with given fields: id, modTime +func (_m *SceneReaderWriter) UpdateFileModTime(id int, modTime models.NullSQLiteTimestamp) error { + ret := _m.Called(id, modTime) + + var r0 error + if rf, ok := ret.Get(0).(func(int, models.NullSQLiteTimestamp) error); ok { + r0 = rf(id, modTime) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // UpdateFull provides a mock function with given fields: updatedScene func (_m *SceneReaderWriter) UpdateFull(updatedScene models.Scene) (*models.Scene, error) { ret := _m.Called(updatedScene) @@ -219,16 +750,81 @@ func (_m *SceneReaderWriter) UpdateFull(updatedScene models.Scene) (*models.Scen return r0, r1 } -// UpdateSceneCover provides a mock function with given fields: sceneID, cover -func (_m *SceneReaderWriter) UpdateSceneCover(sceneID int, cover []byte) error { - ret := _m.Called(sceneID, cover) 
+// UpdateMovies provides a mock function with given fields: sceneID, movies +func (_m *SceneReaderWriter) UpdateMovies(sceneID int, movies []models.MoviesScenes) error { + ret := _m.Called(sceneID, movies) var r0 error - if rf, ok := ret.Get(0).(func(int, []byte) error); ok { - r0 = rf(sceneID, cover) + if rf, ok := ret.Get(0).(func(int, []models.MoviesScenes) error); ok { + r0 = rf(sceneID, movies) } else { r0 = ret.Error(0) } return r0 } + +// UpdatePerformers provides a mock function with given fields: sceneID, performerIDs +func (_m *SceneReaderWriter) UpdatePerformers(sceneID int, performerIDs []int) error { + ret := _m.Called(sceneID, performerIDs) + + var r0 error + if rf, ok := ret.Get(0).(func(int, []int) error); ok { + r0 = rf(sceneID, performerIDs) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// UpdateStashIDs provides a mock function with given fields: sceneID, stashIDs +func (_m *SceneReaderWriter) UpdateStashIDs(sceneID int, stashIDs []models.StashID) error { + ret := _m.Called(sceneID, stashIDs) + + var r0 error + if rf, ok := ret.Get(0).(func(int, []models.StashID) error); ok { + r0 = rf(sceneID, stashIDs) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// UpdateTags provides a mock function with given fields: sceneID, tagIDs +func (_m *SceneReaderWriter) UpdateTags(sceneID int, tagIDs []int) error { + ret := _m.Called(sceneID, tagIDs) + + var r0 error + if rf, ok := ret.Get(0).(func(int, []int) error); ok { + r0 = rf(sceneID, tagIDs) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Wall provides a mock function with given fields: q +func (_m *SceneReaderWriter) Wall(q *string) ([]*models.Scene, error) { + ret := _m.Called(q) + + var r0 []*models.Scene + if rf, ok := ret.Get(0).(func(*string) []*models.Scene); ok { + r0 = rf(q) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.Scene) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(*string) error); ok { + r1 = rf(q) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/pkg/models/mocks/ScrapedItemReaderWriter.go b/pkg/models/mocks/ScrapedItemReaderWriter.go new file mode 100644 index 000000000..8b8e1fbdf --- /dev/null +++ b/pkg/models/mocks/ScrapedItemReaderWriter.go @@ -0,0 +1,59 @@ +// Code generated by mockery v0.0.0-dev. DO NOT EDIT. 
+ +package mocks + +import ( + models "github.com/stashapp/stash/pkg/models" + mock "github.com/stretchr/testify/mock" +) + +// ScrapedItemReaderWriter is an autogenerated mock type for the ScrapedItemReaderWriter type +type ScrapedItemReaderWriter struct { + mock.Mock +} + +// All provides a mock function with given fields: +func (_m *ScrapedItemReaderWriter) All() ([]*models.ScrapedItem, error) { + ret := _m.Called() + + var r0 []*models.ScrapedItem + if rf, ok := ret.Get(0).(func() []*models.ScrapedItem); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.ScrapedItem) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Create provides a mock function with given fields: newObject +func (_m *ScrapedItemReaderWriter) Create(newObject models.ScrapedItem) (*models.ScrapedItem, error) { + ret := _m.Called(newObject) + + var r0 *models.ScrapedItem + if rf, ok := ret.Get(0).(func(models.ScrapedItem) *models.ScrapedItem); ok { + r0 = rf(newObject) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*models.ScrapedItem) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(models.ScrapedItem) error); ok { + r1 = rf(newObject) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/pkg/models/mocks/StudioReaderWriter.go b/pkg/models/mocks/StudioReaderWriter.go index cc51d0755..bd182c95e 100644 --- a/pkg/models/mocks/StudioReaderWriter.go +++ b/pkg/models/mocks/StudioReaderWriter.go @@ -35,6 +35,50 @@ func (_m *StudioReaderWriter) All() ([]*models.Studio, error) { return r0, r1 } +// AllSlim provides a mock function with given fields: +func (_m *StudioReaderWriter) AllSlim() ([]*models.Studio, error) { + ret := _m.Called() + + var r0 []*models.Studio + if rf, ok := ret.Get(0).(func() []*models.Studio); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.Studio) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Count provides a mock function with given fields: +func (_m *StudioReaderWriter) Count() (int, error) { + ret := _m.Called() + + var r0 int + if rf, ok := ret.Get(0).(func() int); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int) + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // Create provides a mock function with given fields: newStudio func (_m *StudioReaderWriter) Create(newStudio models.Studio) (*models.Studio, error) { ret := _m.Called(newStudio) @@ -58,6 +102,34 @@ func (_m *StudioReaderWriter) Create(newStudio models.Studio) (*models.Studio, e return r0, r1 } +// Destroy provides a mock function with given fields: id +func (_m *StudioReaderWriter) Destroy(id int) error { + ret := _m.Called(id) + + var r0 error + if rf, ok := ret.Get(0).(func(int) error); ok { + r0 = rf(id) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// DestroyImage provides a mock function with given fields: studioID +func (_m *StudioReaderWriter) DestroyImage(studioID int) error { + ret := _m.Called(studioID) + + var r0 error + if rf, ok := ret.Get(0).(func(int) error); ok { + r0 = rf(studioID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // Find provides a mock function with given fields: id func (_m *StudioReaderWriter) Find(id int) (*models.Studio, error) { ret := _m.Called(id) @@ -104,6 +176,29 @@ func (_m *StudioReaderWriter) 
FindByName(name string, nocase bool) (*models.Stud return r0, r1 } +// FindChildren provides a mock function with given fields: id +func (_m *StudioReaderWriter) FindChildren(id int) ([]*models.Studio, error) { + ret := _m.Called(id) + + var r0 []*models.Studio + if rf, ok := ret.Get(0).(func(int) []*models.Studio); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.Studio) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // FindMany provides a mock function with given fields: ids func (_m *StudioReaderWriter) FindMany(ids []int) ([]*models.Studio, error) { ret := _m.Called(ids) @@ -127,8 +222,8 @@ func (_m *StudioReaderWriter) FindMany(ids []int) ([]*models.Studio, error) { return r0, r1 } -// GetStudioImage provides a mock function with given fields: studioID -func (_m *StudioReaderWriter) GetStudioImage(studioID int) ([]byte, error) { +// GetImage provides a mock function with given fields: studioID +func (_m *StudioReaderWriter) GetImage(studioID int) ([]byte, error) { ret := _m.Called(studioID) var r0 []byte @@ -150,6 +245,80 @@ func (_m *StudioReaderWriter) GetStudioImage(studioID int) ([]byte, error) { return r0, r1 } +// GetStashIDs provides a mock function with given fields: studioID +func (_m *StudioReaderWriter) GetStashIDs(studioID int) ([]*models.StashID, error) { + ret := _m.Called(studioID) + + var r0 []*models.StashID + if rf, ok := ret.Get(0).(func(int) []*models.StashID); ok { + r0 = rf(studioID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.StashID) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(studioID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// HasImage provides a mock function with given fields: studioID +func (_m *StudioReaderWriter) HasImage(studioID int) (bool, error) { + ret := _m.Called(studioID) + + var r0 bool + if rf, ok := ret.Get(0).(func(int) bool); ok { + r0 = rf(studioID) + } else { + r0 = ret.Get(0).(bool) + } + + var r1 error + if rf, ok := ret.Get(1).(func(int) error); ok { + r1 = rf(studioID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Query provides a mock function with given fields: studioFilter, findFilter +func (_m *StudioReaderWriter) Query(studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) ([]*models.Studio, int, error) { + ret := _m.Called(studioFilter, findFilter) + + var r0 []*models.Studio + if rf, ok := ret.Get(0).(func(*models.StudioFilterType, *models.FindFilterType) []*models.Studio); ok { + r0 = rf(studioFilter, findFilter) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.Studio) + } + } + + var r1 int + if rf, ok := ret.Get(1).(func(*models.StudioFilterType, *models.FindFilterType) int); ok { + r1 = rf(studioFilter, findFilter) + } else { + r1 = ret.Get(1).(int) + } + + var r2 error + if rf, ok := ret.Get(2).(func(*models.StudioFilterType, *models.FindFilterType) error); ok { + r2 = rf(studioFilter, findFilter) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + // Update provides a mock function with given fields: updatedStudio func (_m *StudioReaderWriter) Update(updatedStudio models.StudioPartial) (*models.Studio, error) { ret := _m.Called(updatedStudio) @@ -196,8 +365,8 @@ func (_m *StudioReaderWriter) UpdateFull(updatedStudio models.Studio) (*models.S return r0, r1 } -// UpdateStudioImage provides a mock function with given fields: studioID, image -func (_m 
*StudioReaderWriter) UpdateStudioImage(studioID int, image []byte) error { +// UpdateImage provides a mock function with given fields: studioID, image +func (_m *StudioReaderWriter) UpdateImage(studioID int, image []byte) error { ret := _m.Called(studioID, image) var r0 error @@ -209,3 +378,17 @@ func (_m *StudioReaderWriter) UpdateStudioImage(studioID int, image []byte) erro return r0 } + +// UpdateStashIDs provides a mock function with given fields: studioID, stashIDs +func (_m *StudioReaderWriter) UpdateStashIDs(studioID int, stashIDs []models.StashID) error { + ret := _m.Called(studioID, stashIDs) + + var r0 error + if rf, ok := ret.Get(0).(func(int, []models.StashID) error); ok { + r0 = rf(studioID, stashIDs) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/pkg/models/mocks/TagReaderWriter.go b/pkg/models/mocks/TagReaderWriter.go index ddeea97bf..3d258ea98 100644 --- a/pkg/models/mocks/TagReaderWriter.go +++ b/pkg/models/mocks/TagReaderWriter.go @@ -35,6 +35,50 @@ func (_m *TagReaderWriter) All() ([]*models.Tag, error) { return r0, r1 } +// AllSlim provides a mock function with given fields: +func (_m *TagReaderWriter) AllSlim() ([]*models.Tag, error) { + ret := _m.Called() + + var r0 []*models.Tag + if rf, ok := ret.Get(0).(func() []*models.Tag); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.Tag) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Count provides a mock function with given fields: +func (_m *TagReaderWriter) Count() (int, error) { + ret := _m.Called() + + var r0 int + if rf, ok := ret.Get(0).(func() int); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int) + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // Create provides a mock function with given fields: newTag func (_m *TagReaderWriter) Create(newTag models.Tag) (*models.Tag, error) { ret := _m.Called(newTag) @@ -58,6 +102,34 @@ func (_m *TagReaderWriter) Create(newTag models.Tag) (*models.Tag, error) { return r0, r1 } +// Destroy provides a mock function with given fields: id +func (_m *TagReaderWriter) Destroy(id int) error { + ret := _m.Called(id) + + var r0 error + if rf, ok := ret.Get(0).(func(int) error); ok { + r0 = rf(id) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// DestroyImage provides a mock function with given fields: tagID +func (_m *TagReaderWriter) DestroyImage(tagID int) error { + ret := _m.Called(tagID) + + var r0 error + if rf, ok := ret.Get(0).(func(int) error); ok { + r0 = rf(tagID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // Find provides a mock function with given fields: id func (_m *TagReaderWriter) Find(id int) (*models.Tag, error) { ret := _m.Called(id) @@ -242,8 +314,8 @@ func (_m *TagReaderWriter) FindMany(ids []int) ([]*models.Tag, error) { return r0, r1 } -// GetTagImage provides a mock function with given fields: tagID -func (_m *TagReaderWriter) GetTagImage(tagID int) ([]byte, error) { +// GetImage provides a mock function with given fields: tagID +func (_m *TagReaderWriter) GetImage(tagID int) ([]byte, error) { ret := _m.Called(tagID) var r0 []byte @@ -265,6 +337,36 @@ func (_m *TagReaderWriter) GetTagImage(tagID int) ([]byte, error) { return r0, r1 } +// Query provides a mock function with given fields: tagFilter, findFilter +func (_m *TagReaderWriter) Query(tagFilter *models.TagFilterType, findFilter *models.FindFilterType) 
([]*models.Tag, int, error) { + ret := _m.Called(tagFilter, findFilter) + + var r0 []*models.Tag + if rf, ok := ret.Get(0).(func(*models.TagFilterType, *models.FindFilterType) []*models.Tag); ok { + r0 = rf(tagFilter, findFilter) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*models.Tag) + } + } + + var r1 int + if rf, ok := ret.Get(1).(func(*models.TagFilterType, *models.FindFilterType) int); ok { + r1 = rf(tagFilter, findFilter) + } else { + r1 = ret.Get(1).(int) + } + + var r2 error + if rf, ok := ret.Get(2).(func(*models.TagFilterType, *models.FindFilterType) error); ok { + r2 = rf(tagFilter, findFilter) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + // Update provides a mock function with given fields: updatedTag func (_m *TagReaderWriter) Update(updatedTag models.Tag) (*models.Tag, error) { ret := _m.Called(updatedTag) @@ -288,8 +390,8 @@ func (_m *TagReaderWriter) Update(updatedTag models.Tag) (*models.Tag, error) { return r0, r1 } -// UpdateTagImage provides a mock function with given fields: tagID, image -func (_m *TagReaderWriter) UpdateTagImage(tagID int, image []byte) error { +// UpdateImage provides a mock function with given fields: tagID, image +func (_m *TagReaderWriter) UpdateImage(tagID int, image []byte) error { ret := _m.Called(tagID, image) var r0 error diff --git a/pkg/models/mocks/transaction.go b/pkg/models/mocks/transaction.go new file mode 100644 index 000000000..9b4a19feb --- /dev/null +++ b/pkg/models/mocks/transaction.go @@ -0,0 +1,117 @@ +package mocks + +import ( + "context" + + models "github.com/stashapp/stash/pkg/models" +) + +type TransactionManager struct { + gallery models.GalleryReaderWriter + image models.ImageReaderWriter + movie models.MovieReaderWriter + performer models.PerformerReaderWriter + scene models.SceneReaderWriter + sceneMarker models.SceneMarkerReaderWriter + scrapedItem models.ScrapedItemReaderWriter + studio models.StudioReaderWriter + tag models.TagReaderWriter +} + +func NewTransactionManager() *TransactionManager { + return &TransactionManager{ + gallery: &GalleryReaderWriter{}, + image: &ImageReaderWriter{}, + movie: &MovieReaderWriter{}, + performer: &PerformerReaderWriter{}, + scene: &SceneReaderWriter{}, + sceneMarker: &SceneMarkerReaderWriter{}, + scrapedItem: &ScrapedItemReaderWriter{}, + studio: &StudioReaderWriter{}, + tag: &TagReaderWriter{}, + } +} + +func (t *TransactionManager) WithTxn(ctx context.Context, fn func(r models.Repository) error) error { + return fn(t) +} + +func (t *TransactionManager) Gallery() models.GalleryReaderWriter { + return t.gallery +} + +func (t *TransactionManager) Image() models.ImageReaderWriter { + return t.image +} + +func (t *TransactionManager) Movie() models.MovieReaderWriter { + return t.movie +} + +func (t *TransactionManager) Performer() models.PerformerReaderWriter { + return t.performer +} + +func (t *TransactionManager) SceneMarker() models.SceneMarkerReaderWriter { + return t.sceneMarker +} + +func (t *TransactionManager) Scene() models.SceneReaderWriter { + return t.scene +} + +func (t *TransactionManager) ScrapedItem() models.ScrapedItemReaderWriter { + return t.scrapedItem +} + +func (t *TransactionManager) Studio() models.StudioReaderWriter { + return t.studio +} + +func (t *TransactionManager) Tag() models.TagReaderWriter { + return t.tag +} + +type ReadTransaction struct { + t *TransactionManager +} + +func (t *TransactionManager) WithReadTxn(ctx context.Context, fn func(r models.ReaderRepository) error) error { + return fn(&ReadTransaction{t: t}) +} 
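Not part of the diff itself — a minimal sketch of how the mock TransactionManager above and the mockery-generated reader/writer mocks could be wired together in a unit test, assuming the packages build as shown in this changeset. The test name, query string, and the zero-value scene are illustrative only; the type assertion works because `NewTransactionManager` constructs the concrete `*mocks.SceneReaderWriter` and hands it back through `Scene()`.

```go
package example_test

import (
	"context"
	"testing"

	"github.com/stashapp/stash/pkg/models"
	"github.com/stashapp/stash/pkg/models/mocks"
	"github.com/stretchr/testify/assert"
)

func TestSceneWallWithMocks(t *testing.T) {
	// Mock transaction manager: WithReadTxn simply invokes the callback,
	// so no database is involved.
	txnMgr := mocks.NewTransactionManager()

	// The manager returns the same mock instances it was built with,
	// so we can type-assert to set expectations on them.
	sceneMock := txnMgr.Scene().(*mocks.SceneReaderWriter)

	q := "query"
	expected := []*models.Scene{{}} // zero-value scene is enough for the wiring demo
	sceneMock.On("Wall", &q).Return(expected, nil).Once()

	var got []*models.Scene
	err := txnMgr.WithReadTxn(context.Background(), func(r models.ReaderRepository) error {
		var err error
		got, err = r.Scene().Wall(&q)
		return err
	})

	assert.NoError(t, err)
	assert.Equal(t, expected, got)
	sceneMock.AssertExpectations(t)
}
```

Code that accepts a `models.TransactionManager` (as the resolver now does) can therefore be exercised against these mocks without touching SQLite.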
+ +func (r *ReadTransaction) Gallery() models.GalleryReader { + return r.t.gallery +} + +func (r *ReadTransaction) Image() models.ImageReader { + return r.t.image +} + +func (r *ReadTransaction) Movie() models.MovieReader { + return r.t.movie +} + +func (r *ReadTransaction) Performer() models.PerformerReader { + return r.t.performer +} + +func (r *ReadTransaction) SceneMarker() models.SceneMarkerReader { + return r.t.sceneMarker +} + +func (r *ReadTransaction) Scene() models.SceneReader { + return r.t.scene +} + +func (r *ReadTransaction) ScrapedItem() models.ScrapedItemReader { + return r.t.scrapedItem +} + +func (r *ReadTransaction) Studio() models.StudioReader { + return r.t.studio +} + +func (r *ReadTransaction) Tag() models.TagReader { + return r.t.tag +} diff --git a/pkg/models/model_gallery.go b/pkg/models/model_gallery.go index db8383ad4..acb0ca7c3 100644 --- a/pkg/models/model_gallery.go +++ b/pkg/models/model_gallery.go @@ -42,3 +42,13 @@ type GalleryPartial struct { } const DefaultGthumbWidth int = 640 + +type Galleries []*Gallery + +func (g *Galleries) Append(o interface{}) { + *g = append(*g, o.(*Gallery)) +} + +func (g *Galleries) New() interface{} { + return &Gallery{} +} diff --git a/pkg/models/model_image.go b/pkg/models/model_image.go index 6ad13eb2f..47c21fcd5 100644 --- a/pkg/models/model_image.go +++ b/pkg/models/model_image.go @@ -46,3 +46,13 @@ type ImageFileType struct { Width *int `graphql:"width" json:"width"` Height *int `graphql:"height" json:"height"` } + +type Images []*Image + +func (i *Images) Append(o interface{}) { + *i = append(*i, o.(*Image)) +} + +func (i *Images) New() interface{} { + return &Image{} +} diff --git a/pkg/models/model_joins.go b/pkg/models/model_joins.go index 9c522f70a..802651e21 100644 --- a/pkg/models/model_joins.go +++ b/pkg/models/model_joins.go @@ -2,53 +2,13 @@ package models import "database/sql" -type PerformersScenes struct { - PerformerID int `db:"performer_id" json:"performer_id"` - SceneID int `db:"scene_id" json:"scene_id"` -} - type MoviesScenes struct { MovieID int `db:"movie_id" json:"movie_id"` SceneID int `db:"scene_id" json:"scene_id"` SceneIndex sql.NullInt64 `db:"scene_index" json:"scene_index"` } -type ScenesTags struct { - SceneID int `db:"scene_id" json:"scene_id"` - TagID int `db:"tag_id" json:"tag_id"` -} - -type SceneMarkersTags struct { - SceneMarkerID int `db:"scene_marker_id" json:"scene_marker_id"` - TagID int `db:"tag_id" json:"tag_id"` -} - type StashID struct { StashID string `db:"stash_id" json:"stash_id"` Endpoint string `db:"endpoint" json:"endpoint"` } - -type PerformersImages struct { - PerformerID int `db:"performer_id" json:"performer_id"` - ImageID int `db:"image_id" json:"image_id"` -} - -type ImagesTags struct { - ImageID int `db:"image_id" json:"image_id"` - TagID int `db:"tag_id" json:"tag_id"` -} - -type GalleriesImages struct { - GalleryID int `db:"gallery_id" json:"gallery_id"` - ImageID int `db:"image_id" json:"image_id"` -} - -type PerformersGalleries struct { - PerformerID int `db:"performer_id" json:"performer_id"` - GalleryID int `db:"gallery_id" json:"gallery_id"` -} - -type GalleriesTags struct { - TagID int `db:"tag_id" json:"tag_id"` - GalleryID int `db:"gallery_id" json:"gallery_id"` -} diff --git a/pkg/models/model_movie.go b/pkg/models/model_movie.go index 678ce2944..18bf49a74 100644 --- a/pkg/models/model_movie.go +++ b/pkg/models/model_movie.go @@ -50,3 +50,13 @@ func NewMovie(name string) *Movie { UpdatedAt: SQLiteTimestamp{Timestamp: currentTime}, } } + +type Movies 
[]*Movie + +func (m *Movies) Append(o interface{}) { + *m = append(*m, o.(*Movie)) +} + +func (m *Movies) New() interface{} { + return &Movie{} +} diff --git a/pkg/models/model_performer.go b/pkg/models/model_performer.go index 52dbfa699..4d6134b8a 100644 --- a/pkg/models/model_performer.go +++ b/pkg/models/model_performer.go @@ -65,3 +65,13 @@ func NewPerformer(name string) *Performer { UpdatedAt: SQLiteTimestamp{Timestamp: currentTime}, } } + +type Performers []*Performer + +func (p *Performers) Append(o interface{}) { + *p = append(*p, o.(*Performer)) +} + +func (p *Performers) New() interface{} { + return &Performer{} +} diff --git a/pkg/models/model_scene.go b/pkg/models/model_scene.go index 5c036aee7..5159a2b0a 100644 --- a/pkg/models/model_scene.go +++ b/pkg/models/model_scene.go @@ -95,3 +95,13 @@ type SceneFileType struct { Framerate *float64 `graphql:"framerate" json:"framerate"` Bitrate *int `graphql:"bitrate" json:"bitrate"` } + +type Scenes []*Scene + +func (s *Scenes) Append(o interface{}) { + *s = append(*s, o.(*Scene)) +} + +func (m *Scenes) New() interface{} { + return &Scene{} +} diff --git a/pkg/models/model_scene_marker.go b/pkg/models/model_scene_marker.go index 56139800d..d69b475bb 100644 --- a/pkg/models/model_scene_marker.go +++ b/pkg/models/model_scene_marker.go @@ -13,3 +13,13 @@ type SceneMarker struct { CreatedAt SQLiteTimestamp `db:"created_at" json:"created_at"` UpdatedAt SQLiteTimestamp `db:"updated_at" json:"updated_at"` } + +type SceneMarkers []*SceneMarker + +func (m *SceneMarkers) Append(o interface{}) { + *m = append(*m, o.(*SceneMarker)) +} + +func (m *SceneMarkers) New() interface{} { + return &SceneMarker{} +} diff --git a/pkg/models/model_scraped_item.go b/pkg/models/model_scraped_item.go index db1e05b7c..a3f6aff44 100644 --- a/pkg/models/model_scraped_item.go +++ b/pkg/models/model_scraped_item.go @@ -174,3 +174,13 @@ type ScrapedMovieStudio struct { Name string `graphql:"name" json:"name"` URL *string `graphql:"url" json:"url"` } + +type ScrapedItems []*ScrapedItem + +func (s *ScrapedItems) Append(o interface{}) { + *s = append(*s, o.(*ScrapedItem)) +} + +func (s *ScrapedItems) New() interface{} { + return &ScrapedItem{} +} diff --git a/pkg/models/model_studio.go b/pkg/models/model_studio.go index d3a4940ff..4bc687526 100644 --- a/pkg/models/model_studio.go +++ b/pkg/models/model_studio.go @@ -38,3 +38,13 @@ func NewStudio(name string) *Studio { UpdatedAt: SQLiteTimestamp{Timestamp: currentTime}, } } + +type Studios []*Studio + +func (s *Studios) Append(o interface{}) { + *s = append(*s, o.(*Studio)) +} + +func (s *Studios) New() interface{} { + return &Studio{} +} diff --git a/pkg/models/model_tag.go b/pkg/models/model_tag.go index 181f188a3..7dc419b0b 100644 --- a/pkg/models/model_tag.go +++ b/pkg/models/model_tag.go @@ -18,6 +18,16 @@ func NewTag(name string) *Tag { } } +type Tags []*Tag + +func (t *Tags) Append(o interface{}) { + *t = append(*t, o.(*Tag)) +} + +func (t *Tags) New() interface{} { + return &Tag{} +} + // Original Tag image from: https://fontawesome.com/icons/tag?style=solid // Modified to change color and rotate // Licensed under CC Attribution 4.0: https://fontawesome.com/license diff --git a/pkg/models/movie.go b/pkg/models/movie.go index f4dd7ea89..62f3bf52d 100644 --- a/pkg/models/movie.go +++ b/pkg/models/movie.go @@ -1,9 +1,5 @@ package models -import ( - "github.com/jmoiron/sqlx" -) - type MovieReader interface { Find(id int) (*Movie, error) FindMany(ids []int) ([]*Movie, error) @@ -11,8 +7,9 @@ type MovieReader interface 
{ FindByName(name string, nocase bool) (*Movie, error) FindByNames(names []string, nocase bool) ([]*Movie, error) All() ([]*Movie, error) - // AllSlim() ([]*Movie, error) - // Query(movieFilter *MovieFilterType, findFilter *FindFilterType) ([]*Movie, int) + Count() (int, error) + AllSlim() ([]*Movie, error) + Query(movieFilter *MovieFilterType, findFilter *FindFilterType) ([]*Movie, int, error) GetFrontImage(movieID int) ([]byte, error) GetBackImage(movieID int) ([]byte, error) } @@ -21,68 +18,12 @@ type MovieWriter interface { Create(newMovie Movie) (*Movie, error) Update(updatedMovie MoviePartial) (*Movie, error) UpdateFull(updatedMovie Movie) (*Movie, error) - // Destroy(id string) error - UpdateMovieImages(movieID int, frontImage []byte, backImage []byte) error - // DestroyMovieImages(movieID int) error + Destroy(id int) error + UpdateImages(movieID int, frontImage []byte, backImage []byte) error + DestroyImages(movieID int) error } type MovieReaderWriter interface { MovieReader MovieWriter } - -func NewMovieReaderWriter(tx *sqlx.Tx) MovieReaderWriter { - return &movieReaderWriter{ - tx: tx, - qb: NewMovieQueryBuilder(), - } -} - -type movieReaderWriter struct { - tx *sqlx.Tx - qb MovieQueryBuilder -} - -func (t *movieReaderWriter) Find(id int) (*Movie, error) { - return t.qb.Find(id, t.tx) -} - -func (t *movieReaderWriter) FindMany(ids []int) ([]*Movie, error) { - return t.qb.FindMany(ids) -} - -func (t *movieReaderWriter) FindByName(name string, nocase bool) (*Movie, error) { - return t.qb.FindByName(name, t.tx, nocase) -} - -func (t *movieReaderWriter) FindByNames(names []string, nocase bool) ([]*Movie, error) { - return t.qb.FindByNames(names, t.tx, nocase) -} - -func (t *movieReaderWriter) All() ([]*Movie, error) { - return t.qb.All() -} - -func (t *movieReaderWriter) GetFrontImage(movieID int) ([]byte, error) { - return t.qb.GetFrontImage(movieID, t.tx) -} - -func (t *movieReaderWriter) GetBackImage(movieID int) ([]byte, error) { - return t.qb.GetBackImage(movieID, t.tx) -} - -func (t *movieReaderWriter) Create(newMovie Movie) (*Movie, error) { - return t.qb.Create(newMovie, t.tx) -} - -func (t *movieReaderWriter) Update(updatedMovie MoviePartial) (*Movie, error) { - return t.qb.Update(updatedMovie, t.tx) -} - -func (t *movieReaderWriter) UpdateFull(updatedMovie Movie) (*Movie, error) { - return t.qb.UpdateFull(updatedMovie, t.tx) -} - -func (t *movieReaderWriter) UpdateMovieImages(movieID int, frontImage []byte, backImage []byte) error { - return t.qb.UpdateMovieImages(movieID, frontImage, backImage, t.tx) -} diff --git a/pkg/models/performer.go b/pkg/models/performer.go index 61a4d08e3..92fbe4399 100644 --- a/pkg/models/performer.go +++ b/pkg/models/performer.go @@ -1,94 +1,32 @@ package models -import ( - "github.com/jmoiron/sqlx" -) - type PerformerReader interface { - // Find(id int) (*Performer, error) + Find(id int) (*Performer, error) FindMany(ids []int) ([]*Performer, error) FindBySceneID(sceneID int) ([]*Performer, error) FindNamesBySceneID(sceneID int) ([]*Performer, error) FindByImageID(imageID int) ([]*Performer, error) FindByGalleryID(galleryID int) ([]*Performer, error) FindByNames(names []string, nocase bool) ([]*Performer, error) - // Count() (int, error) + Count() (int, error) All() ([]*Performer, error) - // AllSlim() ([]*Performer, error) - // Query(performerFilter *PerformerFilterType, findFilter *FindFilterType) ([]*Performer, int) - GetPerformerImage(performerID int) ([]byte, error) + AllSlim() ([]*Performer, error) + Query(performerFilter 
*PerformerFilterType, findFilter *FindFilterType) ([]*Performer, int, error) + GetImage(performerID int) ([]byte, error) + GetStashIDs(performerID int) ([]*StashID, error) } type PerformerWriter interface { Create(newPerformer Performer) (*Performer, error) Update(updatedPerformer PerformerPartial) (*Performer, error) UpdateFull(updatedPerformer Performer) (*Performer, error) - // Destroy(id string) error - UpdatePerformerImage(performerID int, image []byte) error - // DestroyPerformerImage(performerID int) error + Destroy(id int) error + UpdateImage(performerID int, image []byte) error + DestroyImage(performerID int) error + UpdateStashIDs(performerID int, stashIDs []StashID) error } type PerformerReaderWriter interface { PerformerReader PerformerWriter } - -func NewPerformerReaderWriter(tx *sqlx.Tx) PerformerReaderWriter { - return &performerReaderWriter{ - tx: tx, - qb: NewPerformerQueryBuilder(), - } -} - -type performerReaderWriter struct { - tx *sqlx.Tx - qb PerformerQueryBuilder -} - -func (t *performerReaderWriter) FindMany(ids []int) ([]*Performer, error) { - return t.qb.FindMany(ids) -} - -func (t *performerReaderWriter) FindByNames(names []string, nocase bool) ([]*Performer, error) { - return t.qb.FindByNames(names, t.tx, nocase) -} - -func (t *performerReaderWriter) All() ([]*Performer, error) { - return t.qb.All() -} - -func (t *performerReaderWriter) GetPerformerImage(performerID int) ([]byte, error) { - return t.qb.GetPerformerImage(performerID, t.tx) -} - -func (t *performerReaderWriter) FindBySceneID(id int) ([]*Performer, error) { - return t.qb.FindBySceneID(id, t.tx) -} - -func (t *performerReaderWriter) FindNamesBySceneID(sceneID int) ([]*Performer, error) { - return t.qb.FindNameBySceneID(sceneID, t.tx) -} - -func (t *performerReaderWriter) FindByImageID(id int) ([]*Performer, error) { - return t.qb.FindByImageID(id, t.tx) -} - -func (t *performerReaderWriter) FindByGalleryID(id int) ([]*Performer, error) { - return t.qb.FindByGalleryID(id, t.tx) -} - -func (t *performerReaderWriter) Create(newPerformer Performer) (*Performer, error) { - return t.qb.Create(newPerformer, t.tx) -} - -func (t *performerReaderWriter) Update(updatedPerformer PerformerPartial) (*Performer, error) { - return t.qb.Update(updatedPerformer, t.tx) -} - -func (t *performerReaderWriter) UpdateFull(updatedPerformer Performer) (*Performer, error) { - return t.qb.UpdateFull(updatedPerformer, t.tx) -} - -func (t *performerReaderWriter) UpdatePerformerImage(performerID int, image []byte) error { - return t.qb.UpdatePerformerImage(performerID, image, t.tx) -} diff --git a/pkg/models/querybuilder_gallery_test.go b/pkg/models/querybuilder_gallery_test.go deleted file mode 100644 index da46a18d9..000000000 --- a/pkg/models/querybuilder_gallery_test.go +++ /dev/null @@ -1,225 +0,0 @@ -// +build integration - -package models_test - -import ( - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/stashapp/stash/pkg/models" -) - -func TestGalleryFind(t *testing.T) { - gqb := models.NewGalleryQueryBuilder() - - const galleryIdx = 0 - gallery, err := gqb.Find(galleryIDs[galleryIdx], nil) - - if err != nil { - t.Fatalf("Error finding gallery: %s", err.Error()) - } - - assert.Equal(t, getGalleryStringValue(galleryIdx, "Path"), gallery.Path.String) - - gallery, err = gqb.Find(0, nil) - - if err != nil { - t.Fatalf("Error finding gallery: %s", err.Error()) - } - - assert.Nil(t, gallery) -} - -func TestGalleryFindByChecksum(t *testing.T) { - gqb := models.NewGalleryQueryBuilder() - - const galleryIdx = 0 
- galleryChecksum := getGalleryStringValue(galleryIdx, "Checksum") - gallery, err := gqb.FindByChecksum(galleryChecksum, nil) - - if err != nil { - t.Fatalf("Error finding gallery: %s", err.Error()) - } - - assert.Equal(t, getGalleryStringValue(galleryIdx, "Path"), gallery.Path.String) - - galleryChecksum = "not exist" - gallery, err = gqb.FindByChecksum(galleryChecksum, nil) - - if err != nil { - t.Fatalf("Error finding gallery: %s", err.Error()) - } - - assert.Nil(t, gallery) -} - -func TestGalleryFindByPath(t *testing.T) { - gqb := models.NewGalleryQueryBuilder() - - const galleryIdx = 0 - galleryPath := getGalleryStringValue(galleryIdx, "Path") - gallery, err := gqb.FindByPath(galleryPath) - - if err != nil { - t.Fatalf("Error finding gallery: %s", err.Error()) - } - - assert.Equal(t, galleryPath, gallery.Path.String) - - galleryPath = "not exist" - gallery, err = gqb.FindByPath(galleryPath) - - if err != nil { - t.Fatalf("Error finding gallery: %s", err.Error()) - } - - assert.Nil(t, gallery) -} - -func TestGalleryFindBySceneID(t *testing.T) { - gqb := models.NewGalleryQueryBuilder() - - sceneID := sceneIDs[sceneIdxWithGallery] - gallery, err := gqb.FindBySceneID(sceneID, nil) - - if err != nil { - t.Fatalf("Error finding gallery: %s", err.Error()) - } - - assert.Equal(t, getGalleryStringValue(galleryIdxWithScene, "Path"), gallery.Path.String) - - gallery, err = gqb.FindBySceneID(0, nil) - - if err != nil { - t.Fatalf("Error finding gallery: %s", err.Error()) - } - - assert.Nil(t, gallery) -} - -func TestGalleryQueryQ(t *testing.T) { - const galleryIdx = 0 - - q := getGalleryStringValue(galleryIdx, pathField) - - sqb := models.NewGalleryQueryBuilder() - - galleryQueryQ(t, sqb, q, galleryIdx) -} - -func galleryQueryQ(t *testing.T, qb models.GalleryQueryBuilder, q string, expectedGalleryIdx int) { - filter := models.FindFilterType{ - Q: &q, - } - galleries, _ := qb.Query(nil, &filter) - - assert.Len(t, galleries, 1) - gallery := galleries[0] - assert.Equal(t, galleryIDs[expectedGalleryIdx], gallery.ID) - - // no Q should return all results - filter.Q = nil - galleries, _ = qb.Query(nil, &filter) - - assert.Len(t, galleries, totalGalleries) -} - -func TestGalleryQueryPath(t *testing.T) { - const galleryIdx = 1 - galleryPath := getGalleryStringValue(galleryIdx, "Path") - - pathCriterion := models.StringCriterionInput{ - Value: galleryPath, - Modifier: models.CriterionModifierEquals, - } - - verifyGalleriesPath(t, pathCriterion) - - pathCriterion.Modifier = models.CriterionModifierNotEquals - verifyGalleriesPath(t, pathCriterion) -} - -func verifyGalleriesPath(t *testing.T, pathCriterion models.StringCriterionInput) { - sqb := models.NewGalleryQueryBuilder() - galleryFilter := models.GalleryFilterType{ - Path: &pathCriterion, - } - - galleries, _ := sqb.Query(&galleryFilter, nil) - - for _, gallery := range galleries { - verifyNullString(t, gallery.Path, pathCriterion) - } -} - -func TestGalleryQueryRating(t *testing.T) { - const rating = 3 - ratingCriterion := models.IntCriterionInput{ - Value: rating, - Modifier: models.CriterionModifierEquals, - } - - verifyGalleriesRating(t, ratingCriterion) - - ratingCriterion.Modifier = models.CriterionModifierNotEquals - verifyGalleriesRating(t, ratingCriterion) - - ratingCriterion.Modifier = models.CriterionModifierGreaterThan - verifyGalleriesRating(t, ratingCriterion) - - ratingCriterion.Modifier = models.CriterionModifierLessThan - verifyGalleriesRating(t, ratingCriterion) - - ratingCriterion.Modifier = models.CriterionModifierIsNull - 
verifyGalleriesRating(t, ratingCriterion) - - ratingCriterion.Modifier = models.CriterionModifierNotNull - verifyGalleriesRating(t, ratingCriterion) -} - -func verifyGalleriesRating(t *testing.T, ratingCriterion models.IntCriterionInput) { - sqb := models.NewGalleryQueryBuilder() - galleryFilter := models.GalleryFilterType{ - Rating: &ratingCriterion, - } - - galleries, _ := sqb.Query(&galleryFilter, nil) - - for _, gallery := range galleries { - verifyInt64(t, gallery.Rating, ratingCriterion) - } -} - -func TestGalleryQueryIsMissingScene(t *testing.T) { - qb := models.NewGalleryQueryBuilder() - isMissing := "scene" - galleryFilter := models.GalleryFilterType{ - IsMissing: &isMissing, - } - - q := getGalleryStringValue(galleryIdxWithScene, titleField) - findFilter := models.FindFilterType{ - Q: &q, - } - - galleries, _ := qb.Query(&galleryFilter, &findFilter) - - assert.Len(t, galleries, 0) - - findFilter.Q = nil - galleries, _ = qb.Query(&galleryFilter, &findFilter) - - // ensure non of the ids equal the one with gallery - for _, gallery := range galleries { - assert.NotEqual(t, galleryIDs[galleryIdxWithScene], gallery.ID) - } -} - -// TODO ValidGalleriesForScenePath -// TODO Count -// TODO All -// TODO Query -// TODO Update -// TODO Destroy -// TODO ClearGalleryId diff --git a/pkg/models/querybuilder_image_test.go b/pkg/models/querybuilder_image_test.go deleted file mode 100644 index 83d662991..000000000 --- a/pkg/models/querybuilder_image_test.go +++ /dev/null @@ -1,652 +0,0 @@ -// +build integration - -package models_test - -import ( - "database/sql" - "strconv" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/stashapp/stash/pkg/models" -) - -func TestImageFind(t *testing.T) { - // assume that the first image is imageWithGalleryPath - sqb := models.NewImageQueryBuilder() - - const imageIdx = 0 - imageID := imageIDs[imageIdx] - image, err := sqb.Find(imageID) - - if err != nil { - t.Fatalf("Error finding image: %s", err.Error()) - } - - assert.Equal(t, getImageStringValue(imageIdx, "Path"), image.Path) - - imageID = 0 - image, err = sqb.Find(imageID) - - if err != nil { - t.Fatalf("Error finding image: %s", err.Error()) - } - - assert.Nil(t, image) -} - -func TestImageFindByPath(t *testing.T) { - sqb := models.NewImageQueryBuilder() - - const imageIdx = 1 - imagePath := getImageStringValue(imageIdx, "Path") - image, err := sqb.FindByPath(imagePath) - - if err != nil { - t.Fatalf("Error finding image: %s", err.Error()) - } - - assert.Equal(t, imageIDs[imageIdx], image.ID) - assert.Equal(t, imagePath, image.Path) - - imagePath = "not exist" - image, err = sqb.FindByPath(imagePath) - - if err != nil { - t.Fatalf("Error finding image: %s", err.Error()) - } - - assert.Nil(t, image) -} - -func TestImageCountByPerformerID(t *testing.T) { - sqb := models.NewImageQueryBuilder() - count, err := sqb.CountByPerformerID(performerIDs[performerIdxWithImage]) - - if err != nil { - t.Fatalf("Error counting images: %s", err.Error()) - } - - assert.Equal(t, 1, count) - - count, err = sqb.CountByPerformerID(0) - - if err != nil { - t.Fatalf("Error counting images: %s", err.Error()) - } - - assert.Equal(t, 0, count) -} - -func TestImageQueryQ(t *testing.T) { - const imageIdx = 2 - - q := getImageStringValue(imageIdx, titleField) - - sqb := models.NewImageQueryBuilder() - - imageQueryQ(t, sqb, q, imageIdx) -} - -func imageQueryQ(t *testing.T, sqb models.ImageQueryBuilder, q string, expectedImageIdx int) { - filter := models.FindFilterType{ - Q: &q, - } - images, _ := sqb.Query(nil, 
&filter) - - assert.Len(t, images, 1) - image := images[0] - assert.Equal(t, imageIDs[expectedImageIdx], image.ID) - - // no Q should return all results - filter.Q = nil - images, _ = sqb.Query(nil, &filter) - - assert.Len(t, images, totalImages) -} - -func TestImageQueryPath(t *testing.T) { - const imageIdx = 1 - imagePath := getImageStringValue(imageIdx, "Path") - - pathCriterion := models.StringCriterionInput{ - Value: imagePath, - Modifier: models.CriterionModifierEquals, - } - - verifyImagePath(t, pathCriterion) - - pathCriterion.Modifier = models.CriterionModifierNotEquals - verifyImagePath(t, pathCriterion) -} - -func verifyImagePath(t *testing.T, pathCriterion models.StringCriterionInput) { - sqb := models.NewImageQueryBuilder() - imageFilter := models.ImageFilterType{ - Path: &pathCriterion, - } - - images, _ := sqb.Query(&imageFilter, nil) - - for _, image := range images { - verifyString(t, image.Path, pathCriterion) - } -} - -func TestImageQueryRating(t *testing.T) { - const rating = 3 - ratingCriterion := models.IntCriterionInput{ - Value: rating, - Modifier: models.CriterionModifierEquals, - } - - verifyImagesRating(t, ratingCriterion) - - ratingCriterion.Modifier = models.CriterionModifierNotEquals - verifyImagesRating(t, ratingCriterion) - - ratingCriterion.Modifier = models.CriterionModifierGreaterThan - verifyImagesRating(t, ratingCriterion) - - ratingCriterion.Modifier = models.CriterionModifierLessThan - verifyImagesRating(t, ratingCriterion) - - ratingCriterion.Modifier = models.CriterionModifierIsNull - verifyImagesRating(t, ratingCriterion) - - ratingCriterion.Modifier = models.CriterionModifierNotNull - verifyImagesRating(t, ratingCriterion) -} - -func verifyImagesRating(t *testing.T, ratingCriterion models.IntCriterionInput) { - sqb := models.NewImageQueryBuilder() - imageFilter := models.ImageFilterType{ - Rating: &ratingCriterion, - } - - images, _ := sqb.Query(&imageFilter, nil) - - for _, image := range images { - verifyInt64(t, image.Rating, ratingCriterion) - } -} - -func TestImageQueryOCounter(t *testing.T) { - const oCounter = 1 - oCounterCriterion := models.IntCriterionInput{ - Value: oCounter, - Modifier: models.CriterionModifierEquals, - } - - verifyImagesOCounter(t, oCounterCriterion) - - oCounterCriterion.Modifier = models.CriterionModifierNotEquals - verifyImagesOCounter(t, oCounterCriterion) - - oCounterCriterion.Modifier = models.CriterionModifierGreaterThan - verifyImagesOCounter(t, oCounterCriterion) - - oCounterCriterion.Modifier = models.CriterionModifierLessThan - verifyImagesOCounter(t, oCounterCriterion) -} - -func verifyImagesOCounter(t *testing.T, oCounterCriterion models.IntCriterionInput) { - sqb := models.NewImageQueryBuilder() - imageFilter := models.ImageFilterType{ - OCounter: &oCounterCriterion, - } - - images, _ := sqb.Query(&imageFilter, nil) - - for _, image := range images { - verifyInt(t, image.OCounter, oCounterCriterion) - } -} - -func TestImageQueryResolution(t *testing.T) { - verifyImagesResolution(t, models.ResolutionEnumLow) - verifyImagesResolution(t, models.ResolutionEnumStandard) - verifyImagesResolution(t, models.ResolutionEnumStandardHd) - verifyImagesResolution(t, models.ResolutionEnumFullHd) - verifyImagesResolution(t, models.ResolutionEnumFourK) - verifyImagesResolution(t, models.ResolutionEnum("unknown")) -} - -func verifyImagesResolution(t *testing.T, resolution models.ResolutionEnum) { - sqb := models.NewImageQueryBuilder() - imageFilter := models.ImageFilterType{ - Resolution: &resolution, - } - - images, _ := 
sqb.Query(&imageFilter, nil) - - for _, image := range images { - verifyImageResolution(t, image.Height, resolution) - } -} - -func verifyImageResolution(t *testing.T, height sql.NullInt64, resolution models.ResolutionEnum) { - assert := assert.New(t) - h := height.Int64 - - switch resolution { - case models.ResolutionEnumLow: - assert.True(h < 480) - case models.ResolutionEnumStandard: - assert.True(h >= 480 && h < 720) - case models.ResolutionEnumStandardHd: - assert.True(h >= 720 && h < 1080) - case models.ResolutionEnumFullHd: - assert.True(h >= 1080 && h < 2160) - case models.ResolutionEnumFourK: - assert.True(h >= 2160) - } -} - -func TestImageQueryIsMissingGalleries(t *testing.T) { - sqb := models.NewImageQueryBuilder() - isMissing := "galleries" - imageFilter := models.ImageFilterType{ - IsMissing: &isMissing, - } - - q := getImageStringValue(imageIdxWithGallery, titleField) - findFilter := models.FindFilterType{ - Q: &q, - } - - images, _ := sqb.Query(&imageFilter, &findFilter) - - assert.Len(t, images, 0) - - findFilter.Q = nil - images, _ = sqb.Query(&imageFilter, &findFilter) - - // ensure non of the ids equal the one with gallery - for _, image := range images { - assert.NotEqual(t, imageIDs[imageIdxWithGallery], image.ID) - } -} - -func TestImageQueryIsMissingStudio(t *testing.T) { - sqb := models.NewImageQueryBuilder() - isMissing := "studio" - imageFilter := models.ImageFilterType{ - IsMissing: &isMissing, - } - - q := getImageStringValue(imageIdxWithStudio, titleField) - findFilter := models.FindFilterType{ - Q: &q, - } - - images, _ := sqb.Query(&imageFilter, &findFilter) - - assert.Len(t, images, 0) - - findFilter.Q = nil - images, _ = sqb.Query(&imageFilter, &findFilter) - - // ensure non of the ids equal the one with studio - for _, image := range images { - assert.NotEqual(t, imageIDs[imageIdxWithStudio], image.ID) - } -} - -func TestImageQueryIsMissingPerformers(t *testing.T) { - sqb := models.NewImageQueryBuilder() - isMissing := "performers" - imageFilter := models.ImageFilterType{ - IsMissing: &isMissing, - } - - q := getImageStringValue(imageIdxWithPerformer, titleField) - findFilter := models.FindFilterType{ - Q: &q, - } - - images, _ := sqb.Query(&imageFilter, &findFilter) - - assert.Len(t, images, 0) - - findFilter.Q = nil - images, _ = sqb.Query(&imageFilter, &findFilter) - - assert.True(t, len(images) > 0) - - // ensure non of the ids equal the one with movies - for _, image := range images { - assert.NotEqual(t, imageIDs[imageIdxWithPerformer], image.ID) - } -} - -func TestImageQueryIsMissingTags(t *testing.T) { - sqb := models.NewImageQueryBuilder() - isMissing := "tags" - imageFilter := models.ImageFilterType{ - IsMissing: &isMissing, - } - - q := getImageStringValue(imageIdxWithTwoTags, titleField) - findFilter := models.FindFilterType{ - Q: &q, - } - - images, _ := sqb.Query(&imageFilter, &findFilter) - - assert.Len(t, images, 0) - - findFilter.Q = nil - images, _ = sqb.Query(&imageFilter, &findFilter) - - assert.True(t, len(images) > 0) -} - -func TestImageQueryIsMissingRating(t *testing.T) { - sqb := models.NewImageQueryBuilder() - isMissing := "rating" - imageFilter := models.ImageFilterType{ - IsMissing: &isMissing, - } - - images, _ := sqb.Query(&imageFilter, nil) - - assert.True(t, len(images) > 0) - - // ensure date is null, empty or "0001-01-01" - for _, image := range images { - assert.True(t, !image.Rating.Valid) - } -} - -func TestImageQueryPerformers(t *testing.T) { - sqb := models.NewImageQueryBuilder() - performerCriterion := 
models.MultiCriterionInput{ - Value: []string{ - strconv.Itoa(performerIDs[performerIdxWithImage]), - strconv.Itoa(performerIDs[performerIdx1WithImage]), - }, - Modifier: models.CriterionModifierIncludes, - } - - imageFilter := models.ImageFilterType{ - Performers: &performerCriterion, - } - - images, _ := sqb.Query(&imageFilter, nil) - - assert.Len(t, images, 2) - - // ensure ids are correct - for _, image := range images { - assert.True(t, image.ID == imageIDs[imageIdxWithPerformer] || image.ID == imageIDs[imageIdxWithTwoPerformers]) - } - - performerCriterion = models.MultiCriterionInput{ - Value: []string{ - strconv.Itoa(performerIDs[performerIdx1WithImage]), - strconv.Itoa(performerIDs[performerIdx2WithImage]), - }, - Modifier: models.CriterionModifierIncludesAll, - } - - images, _ = sqb.Query(&imageFilter, nil) - - assert.Len(t, images, 1) - assert.Equal(t, imageIDs[imageIdxWithTwoPerformers], images[0].ID) - - performerCriterion = models.MultiCriterionInput{ - Value: []string{ - strconv.Itoa(performerIDs[performerIdx1WithImage]), - }, - Modifier: models.CriterionModifierExcludes, - } - - q := getImageStringValue(imageIdxWithTwoPerformers, titleField) - findFilter := models.FindFilterType{ - Q: &q, - } - - images, _ = sqb.Query(&imageFilter, &findFilter) - assert.Len(t, images, 0) -} - -func TestImageQueryTags(t *testing.T) { - sqb := models.NewImageQueryBuilder() - tagCriterion := models.MultiCriterionInput{ - Value: []string{ - strconv.Itoa(tagIDs[tagIdxWithImage]), - strconv.Itoa(tagIDs[tagIdx1WithImage]), - }, - Modifier: models.CriterionModifierIncludes, - } - - imageFilter := models.ImageFilterType{ - Tags: &tagCriterion, - } - - images, _ := sqb.Query(&imageFilter, nil) - - assert.Len(t, images, 2) - - // ensure ids are correct - for _, image := range images { - assert.True(t, image.ID == imageIDs[imageIdxWithTag] || image.ID == imageIDs[imageIdxWithTwoTags]) - } - - tagCriterion = models.MultiCriterionInput{ - Value: []string{ - strconv.Itoa(tagIDs[tagIdx1WithImage]), - strconv.Itoa(tagIDs[tagIdx2WithImage]), - }, - Modifier: models.CriterionModifierIncludesAll, - } - - images, _ = sqb.Query(&imageFilter, nil) - - assert.Len(t, images, 1) - assert.Equal(t, imageIDs[imageIdxWithTwoTags], images[0].ID) - - tagCriterion = models.MultiCriterionInput{ - Value: []string{ - strconv.Itoa(tagIDs[tagIdx1WithImage]), - }, - Modifier: models.CriterionModifierExcludes, - } - - q := getImageStringValue(imageIdxWithTwoTags, titleField) - findFilter := models.FindFilterType{ - Q: &q, - } - - images, _ = sqb.Query(&imageFilter, &findFilter) - assert.Len(t, images, 0) -} - -func TestImageQueryStudio(t *testing.T) { - sqb := models.NewImageQueryBuilder() - studioCriterion := models.MultiCriterionInput{ - Value: []string{ - strconv.Itoa(studioIDs[studioIdxWithImage]), - }, - Modifier: models.CriterionModifierIncludes, - } - - imageFilter := models.ImageFilterType{ - Studios: &studioCriterion, - } - - images, _ := sqb.Query(&imageFilter, nil) - - assert.Len(t, images, 1) - - // ensure id is correct - assert.Equal(t, imageIDs[imageIdxWithStudio], images[0].ID) - - studioCriterion = models.MultiCriterionInput{ - Value: []string{ - strconv.Itoa(studioIDs[studioIdxWithImage]), - }, - Modifier: models.CriterionModifierExcludes, - } - - q := getImageStringValue(imageIdxWithStudio, titleField) - findFilter := models.FindFilterType{ - Q: &q, - } - - images, _ = sqb.Query(&imageFilter, &findFilter) - assert.Len(t, images, 0) -} - -func TestImageQuerySorting(t *testing.T) { - sort := titleField - direction 
:= models.SortDirectionEnumAsc - findFilter := models.FindFilterType{ - Sort: &sort, - Direction: &direction, - } - - sqb := models.NewImageQueryBuilder() - images, _ := sqb.Query(nil, &findFilter) - - // images should be in same order as indexes - firstImage := images[0] - lastImage := images[len(images)-1] - - assert.Equal(t, imageIDs[0], firstImage.ID) - assert.Equal(t, imageIDs[len(imageIDs)-1], lastImage.ID) - - // sort in descending order - direction = models.SortDirectionEnumDesc - - images, _ = sqb.Query(nil, &findFilter) - firstImage = images[0] - lastImage = images[len(images)-1] - - assert.Equal(t, imageIDs[len(imageIDs)-1], firstImage.ID) - assert.Equal(t, imageIDs[0], lastImage.ID) -} - -func TestImageQueryPagination(t *testing.T) { - perPage := 1 - findFilter := models.FindFilterType{ - PerPage: &perPage, - } - - sqb := models.NewImageQueryBuilder() - images, _ := sqb.Query(nil, &findFilter) - - assert.Len(t, images, 1) - - firstID := images[0].ID - - page := 2 - findFilter.Page = &page - images, _ = sqb.Query(nil, &findFilter) - - assert.Len(t, images, 1) - secondID := images[0].ID - assert.NotEqual(t, firstID, secondID) - - perPage = 2 - page = 1 - - images, _ = sqb.Query(nil, &findFilter) - assert.Len(t, images, 2) - assert.Equal(t, firstID, images[0].ID) - assert.Equal(t, secondID, images[1].ID) -} - -func TestImageCountByTagID(t *testing.T) { - sqb := models.NewImageQueryBuilder() - - imageCount, err := sqb.CountByTagID(tagIDs[tagIdxWithImage]) - - if err != nil { - t.Fatalf("error calling CountByTagID: %s", err.Error()) - } - - assert.Equal(t, 1, imageCount) - - imageCount, err = sqb.CountByTagID(0) - - if err != nil { - t.Fatalf("error calling CountByTagID: %s", err.Error()) - } - - assert.Equal(t, 0, imageCount) -} - -func TestImageCountByStudioID(t *testing.T) { - sqb := models.NewImageQueryBuilder() - - imageCount, err := sqb.CountByStudioID(studioIDs[studioIdxWithImage]) - - if err != nil { - t.Fatalf("error calling CountByStudioID: %s", err.Error()) - } - - assert.Equal(t, 1, imageCount) - - imageCount, err = sqb.CountByStudioID(0) - - if err != nil { - t.Fatalf("error calling CountByStudioID: %s", err.Error()) - } - - assert.Equal(t, 0, imageCount) -} - -func TestImageFindByPerformerID(t *testing.T) { - sqb := models.NewImageQueryBuilder() - - images, err := sqb.FindByPerformerID(performerIDs[performerIdxWithImage]) - - if err != nil { - t.Fatalf("error calling FindByPerformerID: %s", err.Error()) - } - - assert.Len(t, images, 1) - assert.Equal(t, imageIDs[imageIdxWithPerformer], images[0].ID) - - images, err = sqb.FindByPerformerID(0) - - if err != nil { - t.Fatalf("error calling FindByPerformerID: %s", err.Error()) - } - - assert.Len(t, images, 0) -} - -func TestImageFindByStudioID(t *testing.T) { - sqb := models.NewImageQueryBuilder() - - images, err := sqb.FindByStudioID(performerIDs[studioIdxWithImage]) - - if err != nil { - t.Fatalf("error calling FindByStudioID: %s", err.Error()) - } - - assert.Len(t, images, 1) - assert.Equal(t, imageIDs[imageIdxWithStudio], images[0].ID) - - images, err = sqb.FindByStudioID(0) - - if err != nil { - t.Fatalf("error calling FindByStudioID: %s", err.Error()) - } - - assert.Len(t, images, 0) -} - -// TODO Update -// TODO IncrementOCounter -// TODO DecrementOCounter -// TODO ResetOCounter -// TODO Destroy -// TODO FindByChecksum -// TODO Count -// TODO SizeCount -// TODO All diff --git a/pkg/models/querybuilder_joins.go b/pkg/models/querybuilder_joins.go deleted file mode 100644 index 3e3cbbd54..000000000 --- 
a/pkg/models/querybuilder_joins.go +++ /dev/null @@ -1,1001 +0,0 @@ -package models - -import ( - "database/sql" - - "github.com/jmoiron/sqlx" - "github.com/stashapp/stash/pkg/database" -) - -type JoinsQueryBuilder struct{} - -func NewJoinsQueryBuilder() JoinsQueryBuilder { - return JoinsQueryBuilder{} -} - -func (qb *JoinsQueryBuilder) GetScenePerformers(sceneID int, tx *sqlx.Tx) ([]PerformersScenes, error) { - ensureTx(tx) - - // Delete the existing joins and then create new ones - query := `SELECT * from performers_scenes WHERE scene_id = ?` - - var rows *sqlx.Rows - var err error - if tx != nil { - rows, err = tx.Queryx(query, sceneID) - } else { - rows, err = database.DB.Queryx(query, sceneID) - } - - if err != nil && err != sql.ErrNoRows { - return nil, err - } - defer rows.Close() - - performerScenes := make([]PerformersScenes, 0) - for rows.Next() { - performerScene := PerformersScenes{} - if err := rows.StructScan(&performerScene); err != nil { - return nil, err - } - performerScenes = append(performerScenes, performerScene) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return performerScenes, nil -} - -func (qb *JoinsQueryBuilder) CreatePerformersScenes(newJoins []PerformersScenes, tx *sqlx.Tx) error { - ensureTx(tx) - for _, join := range newJoins { - _, err := tx.NamedExec( - `INSERT INTO performers_scenes (performer_id, scene_id) VALUES (:performer_id, :scene_id)`, - join, - ) - if err != nil { - return err - } - } - return nil -} - -// AddPerformerScene adds a performer to a scene. It does not make any change -// if the performer already exists on the scene. It returns true if scene -// performer was added. -func (qb *JoinsQueryBuilder) AddPerformerScene(sceneID int, performerID int, tx *sqlx.Tx) (bool, error) { - ensureTx(tx) - - existingPerformers, err := qb.GetScenePerformers(sceneID, tx) - - if err != nil { - return false, err - } - - // ensure not already present - for _, p := range existingPerformers { - if p.PerformerID == performerID && p.SceneID == sceneID { - return false, nil - } - } - - performerJoin := PerformersScenes{ - PerformerID: performerID, - SceneID: sceneID, - } - performerJoins := append(existingPerformers, performerJoin) - - err = qb.UpdatePerformersScenes(sceneID, performerJoins, tx) - - return err == nil, err -} - -func (qb *JoinsQueryBuilder) UpdatePerformersScenes(sceneID int, updatedJoins []PerformersScenes, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing joins and then create new ones - _, err := tx.Exec("DELETE FROM performers_scenes WHERE scene_id = ?", sceneID) - if err != nil { - return err - } - return qb.CreatePerformersScenes(updatedJoins, tx) -} - -func (qb *JoinsQueryBuilder) DestroyPerformersScenes(sceneID int, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing joins - _, err := tx.Exec("DELETE FROM performers_scenes WHERE scene_id = ?", sceneID) - return err -} - -func (qb *JoinsQueryBuilder) GetSceneMovies(sceneID int, tx *sqlx.Tx) ([]MoviesScenes, error) { - query := `SELECT * from movies_scenes WHERE scene_id = ?` - - var rows *sqlx.Rows - var err error - if tx != nil { - rows, err = tx.Queryx(query, sceneID) - } else { - rows, err = database.DB.Queryx(query, sceneID) - } - - if err != nil && err != sql.ErrNoRows { - return nil, err - } - defer rows.Close() - - movieScenes := make([]MoviesScenes, 0) - for rows.Next() { - movieScene := MoviesScenes{} - if err := rows.StructScan(&movieScene); err != nil { - return nil, err - } - movieScenes = append(movieScenes, movieScene) - } - - if 
err := rows.Err(); err != nil { - return nil, err - } - - return movieScenes, nil -} - -func (qb *JoinsQueryBuilder) CreateMoviesScenes(newJoins []MoviesScenes, tx *sqlx.Tx) error { - ensureTx(tx) - for _, join := range newJoins { - _, err := tx.NamedExec( - `INSERT INTO movies_scenes (movie_id, scene_id, scene_index) VALUES (:movie_id, :scene_id, :scene_index)`, - join, - ) - if err != nil { - return err - } - } - return nil -} - -// AddMovieScene adds a movie to a scene. It does not make any change -// if the movie already exists on the scene. It returns true if scene -// movie was added. - -func (qb *JoinsQueryBuilder) AddMoviesScene(sceneID int, movieID int, sceneIdx *int, tx *sqlx.Tx) (bool, error) { - ensureTx(tx) - - existingMovies, err := qb.GetSceneMovies(sceneID, tx) - - if err != nil { - return false, err - } - - // ensure not already present - for _, p := range existingMovies { - if p.MovieID == movieID && p.SceneID == sceneID { - return false, nil - } - } - - movieJoin := MoviesScenes{ - MovieID: movieID, - SceneID: sceneID, - } - - if sceneIdx != nil { - movieJoin.SceneIndex = sql.NullInt64{ - Int64: int64(*sceneIdx), - Valid: true, - } - } - movieJoins := append(existingMovies, movieJoin) - - err = qb.UpdateMoviesScenes(sceneID, movieJoins, tx) - - return err == nil, err -} - -func (qb *JoinsQueryBuilder) UpdateMoviesScenes(sceneID int, updatedJoins []MoviesScenes, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing joins and then create new ones - _, err := tx.Exec("DELETE FROM movies_scenes WHERE scene_id = ?", sceneID) - if err != nil { - return err - } - return qb.CreateMoviesScenes(updatedJoins, tx) -} - -func (qb *JoinsQueryBuilder) DestroyMoviesScenes(sceneID int, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing joins - _, err := tx.Exec("DELETE FROM movies_scenes WHERE scene_id = ?", sceneID) - return err -} - -func (qb *JoinsQueryBuilder) GetSceneTags(sceneID int, tx *sqlx.Tx) ([]ScenesTags, error) { - ensureTx(tx) - - // Delete the existing joins and then create new ones - query := `SELECT * from scenes_tags WHERE scene_id = ?` - - var rows *sqlx.Rows - var err error - if tx != nil { - rows, err = tx.Queryx(query, sceneID) - } else { - rows, err = database.DB.Queryx(query, sceneID) - } - - if err != nil && err != sql.ErrNoRows { - return nil, err - } - defer rows.Close() - - sceneTags := make([]ScenesTags, 0) - for rows.Next() { - sceneTag := ScenesTags{} - if err := rows.StructScan(&sceneTag); err != nil { - return nil, err - } - sceneTags = append(sceneTags, sceneTag) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return sceneTags, nil -} - -func (qb *JoinsQueryBuilder) CreateScenesTags(newJoins []ScenesTags, tx *sqlx.Tx) error { - ensureTx(tx) - for _, join := range newJoins { - _, err := tx.NamedExec( - `INSERT INTO scenes_tags (scene_id, tag_id) VALUES (:scene_id, :tag_id)`, - join, - ) - if err != nil { - return err - } - } - return nil -} - -func (qb *JoinsQueryBuilder) UpdateScenesTags(sceneID int, updatedJoins []ScenesTags, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing joins and then create new ones - _, err := tx.Exec("DELETE FROM scenes_tags WHERE scene_id = ?", sceneID) - if err != nil { - return err - } - return qb.CreateScenesTags(updatedJoins, tx) -} - -// AddSceneTag adds a tag to a scene. It does not make any change if the tag -// already exists on the scene. It returns true if scene tag was added. 
-func (qb *JoinsQueryBuilder) AddSceneTag(sceneID int, tagID int, tx *sqlx.Tx) (bool, error) { - ensureTx(tx) - - existingTags, err := qb.GetSceneTags(sceneID, tx) - - if err != nil { - return false, err - } - - // ensure not already present - for _, p := range existingTags { - if p.TagID == tagID && p.SceneID == sceneID { - return false, nil - } - } - - tagJoin := ScenesTags{ - TagID: tagID, - SceneID: sceneID, - } - tagJoins := append(existingTags, tagJoin) - - err = qb.UpdateScenesTags(sceneID, tagJoins, tx) - - return err == nil, err -} - -func (qb *JoinsQueryBuilder) DestroyScenesTags(sceneID int, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing joins - _, err := tx.Exec("DELETE FROM scenes_tags WHERE scene_id = ?", sceneID) - - return err -} - -func (qb *JoinsQueryBuilder) CreateSceneMarkersTags(newJoins []SceneMarkersTags, tx *sqlx.Tx) error { - ensureTx(tx) - for _, join := range newJoins { - _, err := tx.NamedExec( - `INSERT INTO scene_markers_tags (scene_marker_id, tag_id) VALUES (:scene_marker_id, :tag_id)`, - join, - ) - if err != nil { - return err - } - } - return nil -} - -func (qb *JoinsQueryBuilder) UpdateSceneMarkersTags(sceneMarkerID int, updatedJoins []SceneMarkersTags, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing joins and then create new ones - _, err := tx.Exec("DELETE FROM scene_markers_tags WHERE scene_marker_id = ?", sceneMarkerID) - if err != nil { - return err - } - return qb.CreateSceneMarkersTags(updatedJoins, tx) -} - -func (qb *JoinsQueryBuilder) DestroySceneMarkersTags(sceneMarkerID int, updatedJoins []SceneMarkersTags, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing joins - _, err := tx.Exec("DELETE FROM scene_markers_tags WHERE scene_marker_id = ?", sceneMarkerID) - return err -} - -func (qb *JoinsQueryBuilder) DestroyScenesGalleries(sceneID int, tx *sqlx.Tx) error { - ensureTx(tx) - - // Unset the existing scene id from galleries - _, err := tx.Exec("UPDATE galleries SET scene_id = null WHERE scene_id = ?", sceneID) - - return err -} - -func (qb *JoinsQueryBuilder) DestroyScenesMarkers(sceneID int, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the scene marker tags - _, err := tx.Exec("DELETE t FROM scene_markers_tags t join scene_markers m on t.scene_marker_id = m.id WHERE m.scene_id = ?", sceneID) - - // Delete the existing joins - _, err = tx.Exec("DELETE FROM scene_markers WHERE scene_id = ?", sceneID) - - return err -} - -func (qb *JoinsQueryBuilder) CreateStashIDs(entityName string, entityID int, newJoins []StashID, tx *sqlx.Tx) error { - query := "INSERT INTO " + entityName + "_stash_ids (" + entityName + "_id, endpoint, stash_id) VALUES (?, ?, ?)" - ensureTx(tx) - for _, join := range newJoins { - _, err := tx.Exec(query, entityID, join.Endpoint, join.StashID) - if err != nil { - return err - } - } - return nil -} - -func (qb *JoinsQueryBuilder) GetImagePerformers(imageID int, tx *sqlx.Tx) ([]PerformersImages, error) { - ensureTx(tx) - - // Delete the existing joins and then create new ones - query := `SELECT * from performers_images WHERE image_id = ?` - - var rows *sqlx.Rows - var err error - if tx != nil { - rows, err = tx.Queryx(query, imageID) - } else { - rows, err = database.DB.Queryx(query, imageID) - } - - if err != nil && err != sql.ErrNoRows { - return nil, err - } - defer rows.Close() - - performerImages := make([]PerformersImages, 0) - for rows.Next() { - performerImage := PerformersImages{} - if err := rows.StructScan(&performerImage); err != nil { - return nil, err - } - 
performerImages = append(performerImages, performerImage) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return performerImages, nil -} - -func (qb *JoinsQueryBuilder) CreatePerformersImages(newJoins []PerformersImages, tx *sqlx.Tx) error { - ensureTx(tx) - for _, join := range newJoins { - _, err := tx.NamedExec( - `INSERT INTO performers_images (performer_id, image_id) VALUES (:performer_id, :image_id)`, - join, - ) - if err != nil { - return err - } - } - return nil -} - -// AddPerformerImage adds a performer to a image. It does not make any change -// if the performer already exists on the image. It returns true if image -// performer was added. -func (qb *JoinsQueryBuilder) AddPerformerImage(imageID int, performerID int, tx *sqlx.Tx) (bool, error) { - ensureTx(tx) - - existingPerformers, err := qb.GetImagePerformers(imageID, tx) - - if err != nil { - return false, err - } - - // ensure not already present - for _, p := range existingPerformers { - if p.PerformerID == performerID && p.ImageID == imageID { - return false, nil - } - } - - performerJoin := PerformersImages{ - PerformerID: performerID, - ImageID: imageID, - } - performerJoins := append(existingPerformers, performerJoin) - - err = qb.UpdatePerformersImages(imageID, performerJoins, tx) - - return err == nil, err -} - -func (qb *JoinsQueryBuilder) UpdatePerformersImages(imageID int, updatedJoins []PerformersImages, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing joins and then create new ones - _, err := tx.Exec("DELETE FROM performers_images WHERE image_id = ?", imageID) - if err != nil { - return err - } - return qb.CreatePerformersImages(updatedJoins, tx) -} - -func (qb *JoinsQueryBuilder) DestroyPerformersImages(imageID int, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing joins - _, err := tx.Exec("DELETE FROM performers_images WHERE image_id = ?", imageID) - return err -} - -func (qb *JoinsQueryBuilder) GetImageTags(imageID int, tx *sqlx.Tx) ([]ImagesTags, error) { - ensureTx(tx) - - // Delete the existing joins and then create new ones - query := `SELECT * from images_tags WHERE image_id = ?` - - var rows *sqlx.Rows - var err error - if tx != nil { - rows, err = tx.Queryx(query, imageID) - } else { - rows, err = database.DB.Queryx(query, imageID) - } - - if err != nil && err != sql.ErrNoRows { - return nil, err - } - defer rows.Close() - - imageTags := make([]ImagesTags, 0) - for rows.Next() { - imageTag := ImagesTags{} - if err := rows.StructScan(&imageTag); err != nil { - return nil, err - } - imageTags = append(imageTags, imageTag) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return imageTags, nil -} - -func (qb *JoinsQueryBuilder) CreateImagesTags(newJoins []ImagesTags, tx *sqlx.Tx) error { - ensureTx(tx) - for _, join := range newJoins { - _, err := tx.NamedExec( - `INSERT INTO images_tags (image_id, tag_id) VALUES (:image_id, :tag_id)`, - join, - ) - if err != nil { - return err - } - } - return nil -} - -func (qb *JoinsQueryBuilder) UpdateImagesTags(imageID int, updatedJoins []ImagesTags, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing joins and then create new ones - _, err := tx.Exec("DELETE FROM images_tags WHERE image_id = ?", imageID) - if err != nil { - return err - } - return qb.CreateImagesTags(updatedJoins, tx) -} - -// AddImageTag adds a tag to a image. It does not make any change if the tag -// already exists on the image. It returns true if image tag was added. 
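One note on DestroyScenesMarkers above: it issues a MySQL-style multi-table DELETE ("DELETE t FROM scene_markers_tags t join ...") and then overwrites that statement's error with the second Exec, so the first failure is never reported; SQLite, which these builders run against through database.DB, does not accept the joined DELETE form. A hedged sketch of an equivalent subquery version with both errors checked (destroySceneMarkers is an illustrative name, not the replacement code):

// destroySceneMarkers removes a scene's marker tags, then its markers, inside tx.
// Assumes: import "github.com/jmoiron/sqlx".
func destroySceneMarkers(tx *sqlx.Tx, sceneID int) error {
	if _, err := tx.Exec(
		`DELETE FROM scene_markers_tags
		 WHERE scene_marker_id IN (SELECT id FROM scene_markers WHERE scene_id = ?)`,
		sceneID); err != nil {
		return err
	}
	_, err := tx.Exec("DELETE FROM scene_markers WHERE scene_id = ?", sceneID)
	return err
}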
-func (qb *JoinsQueryBuilder) AddImageTag(imageID int, tagID int, tx *sqlx.Tx) (bool, error) { - ensureTx(tx) - - existingTags, err := qb.GetImageTags(imageID, tx) - - if err != nil { - return false, err - } - - // ensure not already present - for _, p := range existingTags { - if p.TagID == tagID && p.ImageID == imageID { - return false, nil - } - } - - tagJoin := ImagesTags{ - TagID: tagID, - ImageID: imageID, - } - tagJoins := append(existingTags, tagJoin) - - err = qb.UpdateImagesTags(imageID, tagJoins, tx) - - return err == nil, err -} - -func (qb *JoinsQueryBuilder) DestroyImagesTags(imageID int, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing joins - _, err := tx.Exec("DELETE FROM images_tags WHERE image_id = ?", imageID) - - return err -} - -func (qb *JoinsQueryBuilder) GetImageGalleries(imageID int, tx *sqlx.Tx) ([]GalleriesImages, error) { - ensureTx(tx) - - // Delete the existing joins and then create new ones - query := `SELECT * from galleries_images WHERE image_id = ?` - - var rows *sqlx.Rows - var err error - if tx != nil { - rows, err = tx.Queryx(query, imageID) - } else { - rows, err = database.DB.Queryx(query, imageID) - } - - if err != nil && err != sql.ErrNoRows { - return nil, err - } - defer rows.Close() - - galleryImages := make([]GalleriesImages, 0) - for rows.Next() { - galleriesImages := GalleriesImages{} - if err := rows.StructScan(&galleriesImages); err != nil { - return nil, err - } - galleryImages = append(galleryImages, galleriesImages) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return galleryImages, nil -} - -func (qb *JoinsQueryBuilder) CreateGalleriesImages(newJoins []GalleriesImages, tx *sqlx.Tx) error { - ensureTx(tx) - for _, join := range newJoins { - _, err := tx.NamedExec( - `INSERT INTO galleries_images (gallery_id, image_id) VALUES (:gallery_id, :image_id)`, - join, - ) - if err != nil { - return err - } - } - return nil -} - -func (qb *JoinsQueryBuilder) UpdateGalleriesImages(imageID int, updatedJoins []GalleriesImages, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing joins and then create new ones - _, err := tx.Exec("DELETE FROM galleries_images WHERE image_id = ?", imageID) - if err != nil { - return err - } - return qb.CreateGalleriesImages(updatedJoins, tx) -} - -// AddGalleryImage adds a gallery to an image. It does not make any change if the tag -// already exists on the image. It returns true if image tag was added. -func (qb *JoinsQueryBuilder) AddImageGallery(imageID int, galleryID int, tx *sqlx.Tx) (bool, error) { - ensureTx(tx) - - existingGalleries, err := qb.GetImageGalleries(imageID, tx) - - if err != nil { - return false, err - } - - // ensure not already present - for _, p := range existingGalleries { - if p.GalleryID == galleryID && p.ImageID == imageID { - return false, nil - } - } - - galleryJoin := GalleriesImages{ - GalleryID: galleryID, - ImageID: imageID, - } - galleryJoins := append(existingGalleries, galleryJoin) - - err = qb.UpdateGalleriesImages(imageID, galleryJoins, tx) - - return err == nil, err -} - -// RemoveImageGallery removes a gallery from an image. Returns true if the join -// was removed. 
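Also visible throughout this file: writers call ensureTx and require a transaction, while readers such as GetImageGalleries fall back to the shared handle when tx is nil. That read-side fallback, isolated as a small sketch (queryRows is a hypothetical helper, assuming the same database.DB global and the sqlx import used above):

// queryRows runs a read on the supplied transaction, or on the global
// connection when no transaction is in flight.
func queryRows(tx *sqlx.Tx, query string, args ...interface{}) (*sqlx.Rows, error) {
	if tx != nil {
		return tx.Queryx(query, args...)
	}
	return database.DB.Queryx(query, args...)
}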
-func (qb *JoinsQueryBuilder) RemoveImageGallery(imageID int, galleryID int, tx *sqlx.Tx) (bool, error) { - ensureTx(tx) - - existingGalleries, err := qb.GetImageGalleries(imageID, tx) - - if err != nil { - return false, err - } - - // remove the join - var updatedJoins []GalleriesImages - found := false - for _, p := range existingGalleries { - if p.GalleryID == galleryID && p.ImageID == imageID { - found = true - continue - } - - updatedJoins = append(updatedJoins, p) - } - - if found { - err = qb.UpdateGalleriesImages(imageID, updatedJoins, tx) - } - - return found && err == nil, err -} - -func (qb *JoinsQueryBuilder) DestroyImageGalleries(imageID int, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing joins - _, err := tx.Exec("DELETE FROM galleries_images WHERE image_id = ?", imageID) - - return err -} - -func (qb *JoinsQueryBuilder) GetGalleryPerformers(galleryID int, tx *sqlx.Tx) ([]PerformersGalleries, error) { - ensureTx(tx) - - // Delete the existing joins and then create new ones - query := `SELECT * from performers_galleries WHERE gallery_id = ?` - - var rows *sqlx.Rows - var err error - if tx != nil { - rows, err = tx.Queryx(query, galleryID) - } else { - rows, err = database.DB.Queryx(query, galleryID) - } - - if err != nil && err != sql.ErrNoRows { - return nil, err - } - defer rows.Close() - - performerGalleries := make([]PerformersGalleries, 0) - for rows.Next() { - performerGallery := PerformersGalleries{} - if err := rows.StructScan(&performerGallery); err != nil { - return nil, err - } - performerGalleries = append(performerGalleries, performerGallery) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return performerGalleries, nil -} - -func (qb *JoinsQueryBuilder) CreatePerformersGalleries(newJoins []PerformersGalleries, tx *sqlx.Tx) error { - ensureTx(tx) - for _, join := range newJoins { - _, err := tx.NamedExec( - `INSERT INTO performers_galleries (performer_id, gallery_id) VALUES (:performer_id, :gallery_id)`, - join, - ) - if err != nil { - return err - } - } - return nil -} - -// AddPerformerGallery adds a performer to a gallery. It does not make any change -// if the performer already exists on the gallery. It returns true if gallery -// performer was added. 
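RemoveImageGallery above is the inverse of the add helpers: it filters the matching pair out of the loaded joins and only rewrites the table when something was actually dropped. The filter step on its own, as a sketch (removeJoin is an illustrative name; it is pure slice logic with no SQL):

// removeJoin drops the (imageID, galleryID) pair from the loaded joins and
// reports whether anything was removed; the caller then rewrites the table
// via UpdateGalleriesImages, as RemoveImageGallery does above.
func removeJoin(joins []GalleriesImages, imageID, galleryID int) ([]GalleriesImages, bool) {
	var kept []GalleriesImages
	found := false
	for _, j := range joins {
		if j.GalleryID == galleryID && j.ImageID == imageID {
			found = true // skip the row being removed
			continue
		}
		kept = append(kept, j)
	}
	return kept, found
}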
-func (qb *JoinsQueryBuilder) AddPerformerGallery(galleryID int, performerID int, tx *sqlx.Tx) (bool, error) { - ensureTx(tx) - - existingPerformers, err := qb.GetGalleryPerformers(galleryID, tx) - - if err != nil { - return false, err - } - - // ensure not already present - for _, p := range existingPerformers { - if p.PerformerID == performerID && p.GalleryID == galleryID { - return false, nil - } - } - - performerJoin := PerformersGalleries{ - PerformerID: performerID, - GalleryID: galleryID, - } - performerJoins := append(existingPerformers, performerJoin) - - err = qb.UpdatePerformersGalleries(galleryID, performerJoins, tx) - - return err == nil, err -} - -func (qb *JoinsQueryBuilder) UpdatePerformersGalleries(galleryID int, updatedJoins []PerformersGalleries, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing joins and then create new ones - _, err := tx.Exec("DELETE FROM performers_galleries WHERE gallery_id = ?", galleryID) - if err != nil { - return err - } - return qb.CreatePerformersGalleries(updatedJoins, tx) -} - -func (qb *JoinsQueryBuilder) DestroyPerformersGalleries(galleryID int, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing joins - _, err := tx.Exec("DELETE FROM performers_galleries WHERE gallery_id = ?", galleryID) - return err -} - -func (qb *JoinsQueryBuilder) GetGalleryTags(galleryID int, tx *sqlx.Tx) ([]GalleriesTags, error) { - ensureTx(tx) - - // Delete the existing joins and then create new ones - query := `SELECT * from galleries_tags WHERE gallery_id = ?` - - var rows *sqlx.Rows - var err error - if tx != nil { - rows, err = tx.Queryx(query, galleryID) - } else { - rows, err = database.DB.Queryx(query, galleryID) - } - - if err != nil && err != sql.ErrNoRows { - return nil, err - } - defer rows.Close() - - galleryTags := make([]GalleriesTags, 0) - for rows.Next() { - galleryTag := GalleriesTags{} - if err := rows.StructScan(&galleryTag); err != nil { - return nil, err - } - galleryTags = append(galleryTags, galleryTag) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return galleryTags, nil -} - -func (qb *JoinsQueryBuilder) CreateGalleriesTags(newJoins []GalleriesTags, tx *sqlx.Tx) error { - ensureTx(tx) - for _, join := range newJoins { - _, err := tx.NamedExec( - `INSERT INTO galleries_tags (gallery_id, tag_id) VALUES (:gallery_id, :tag_id)`, - join, - ) - if err != nil { - return err - } - } - return nil -} - -func (qb *JoinsQueryBuilder) UpdateGalleriesTags(galleryID int, updatedJoins []GalleriesTags, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing joins and then create new ones - _, err := tx.Exec("DELETE FROM galleries_tags WHERE gallery_id = ?", galleryID) - if err != nil { - return err - } - return qb.CreateGalleriesTags(updatedJoins, tx) -} - -// AddGalleryTag adds a tag to a gallery. It does not make any change if the tag -// already exists on the gallery. It returns true if gallery tag was added. 
-func (qb *JoinsQueryBuilder) AddGalleryTag(galleryID int, tagID int, tx *sqlx.Tx) (bool, error) { - ensureTx(tx) - - existingTags, err := qb.GetGalleryTags(galleryID, tx) - - if err != nil { - return false, err - } - - // ensure not already present - for _, p := range existingTags { - if p.TagID == tagID && p.GalleryID == galleryID { - return false, nil - } - } - - tagJoin := GalleriesTags{ - TagID: tagID, - GalleryID: galleryID, - } - tagJoins := append(existingTags, tagJoin) - - err = qb.UpdateGalleriesTags(galleryID, tagJoins, tx) - - return err == nil, err -} - -func (qb *JoinsQueryBuilder) DestroyGalleriesTags(galleryID int, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing joins - _, err := tx.Exec("DELETE FROM galleries_tags WHERE gallery_id = ?", galleryID) - - return err -} - -func (qb *JoinsQueryBuilder) GetSceneStashIDs(sceneID int) ([]*StashID, error) { - rows, err := database.DB.Queryx(`SELECT stash_id, endpoint from scene_stash_ids WHERE scene_id = ?`, sceneID) - - if err != nil && err != sql.ErrNoRows { - return nil, err - } - defer rows.Close() - - stashIDs := []*StashID{} - for rows.Next() { - stashID := StashID{} - if err := rows.StructScan(&stashID); err != nil { - return nil, err - } - stashIDs = append(stashIDs, &stashID) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return stashIDs, nil -} - -func (qb *JoinsQueryBuilder) GetPerformerStashIDs(performerID int) ([]*StashID, error) { - rows, err := database.DB.Queryx(`SELECT stash_id, endpoint from performer_stash_ids WHERE performer_id = ?`, performerID) - - if err != nil && err != sql.ErrNoRows { - return nil, err - } - defer rows.Close() - - stashIDs := []*StashID{} - for rows.Next() { - stashID := StashID{} - if err := rows.StructScan(&stashID); err != nil { - return nil, err - } - stashIDs = append(stashIDs, &stashID) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return stashIDs, nil -} - -func (qb *JoinsQueryBuilder) GetStudioStashIDs(studioID int) ([]*StashID, error) { - rows, err := database.DB.Queryx(`SELECT stash_id, endpoint from studio_stash_ids WHERE studio_id = ?`, studioID) - - if err != nil && err != sql.ErrNoRows { - return nil, err - } - defer rows.Close() - - stashIDs := []*StashID{} - for rows.Next() { - stashID := StashID{} - if err := rows.StructScan(&stashID); err != nil { - return nil, err - } - stashIDs = append(stashIDs, &stashID) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return stashIDs, nil -} - -func (qb *JoinsQueryBuilder) UpdateSceneStashIDs(sceneID int, updatedJoins []StashID, tx *sqlx.Tx) error { - ensureTx(tx) - - _, err := tx.Exec("DELETE FROM scene_stash_ids WHERE scene_id = ?", sceneID) - if err != nil { - return err - } - return qb.CreateStashIDs("scene", sceneID, updatedJoins, tx) -} - -func (qb *JoinsQueryBuilder) UpdatePerformerStashIDs(performerID int, updatedJoins []StashID, tx *sqlx.Tx) error { - ensureTx(tx) - - _, err := tx.Exec("DELETE FROM performer_stash_ids WHERE performer_id = ?", performerID) - if err != nil { - return err - } - return qb.CreateStashIDs("performer", performerID, updatedJoins, tx) -} - -func (qb *JoinsQueryBuilder) UpdateStudioStashIDs(studioID int, updatedJoins []StashID, tx *sqlx.Tx) error { - ensureTx(tx) - - _, err := tx.Exec("DELETE FROM studio_stash_ids WHERE studio_id = ?", studioID) - if err != nil { - return err - } - return qb.CreateStashIDs("studio", studioID, updatedJoins, tx) -} diff --git a/pkg/models/querybuilder_movies.go 
b/pkg/models/querybuilder_movies.go deleted file mode 100644 index a111cf563..000000000 --- a/pkg/models/querybuilder_movies.go +++ /dev/null @@ -1,312 +0,0 @@ -package models - -import ( - "database/sql" - "fmt" - - "github.com/jmoiron/sqlx" - "github.com/stashapp/stash/pkg/database" -) - -type MovieQueryBuilder struct{} - -func NewMovieQueryBuilder() MovieQueryBuilder { - return MovieQueryBuilder{} -} - -func (qb *MovieQueryBuilder) Create(newMovie Movie, tx *sqlx.Tx) (*Movie, error) { - ensureTx(tx) - result, err := tx.NamedExec( - `INSERT INTO movies (checksum, name, aliases, duration, date, rating, studio_id, director, synopsis, url, created_at, updated_at) - VALUES (:checksum, :name, :aliases, :duration, :date, :rating, :studio_id, :director, :synopsis, :url, :created_at, :updated_at) - `, - newMovie, - ) - if err != nil { - return nil, err - } - movieID, err := result.LastInsertId() - if err != nil { - return nil, err - } - - if err := tx.Get(&newMovie, `SELECT * FROM movies WHERE id = ? LIMIT 1`, movieID); err != nil { - return nil, err - } - return &newMovie, nil -} - -func (qb *MovieQueryBuilder) Update(updatedMovie MoviePartial, tx *sqlx.Tx) (*Movie, error) { - ensureTx(tx) - _, err := tx.NamedExec( - `UPDATE movies SET `+SQLGenKeysPartial(updatedMovie)+` WHERE movies.id = :id`, - updatedMovie, - ) - if err != nil { - return nil, err - } - - return qb.Find(updatedMovie.ID, tx) -} - -func (qb *MovieQueryBuilder) UpdateFull(updatedMovie Movie, tx *sqlx.Tx) (*Movie, error) { - ensureTx(tx) - _, err := tx.NamedExec( - `UPDATE movies SET `+SQLGenKeys(updatedMovie)+` WHERE movies.id = :id`, - updatedMovie, - ) - if err != nil { - return nil, err - } - - return qb.Find(updatedMovie.ID, tx) -} - -func (qb *MovieQueryBuilder) Destroy(id string, tx *sqlx.Tx) error { - // delete movie from movies_scenes - - _, err := tx.Exec("DELETE FROM movies_scenes WHERE movie_id = ?", id) - if err != nil { - return err - } - - // // remove movie from scraped items - // _, err = tx.Exec("UPDATE scraped_items SET movie_id = null WHERE movie_id = ?", id) - // if err != nil { - // return err - // } - - return executeDeleteQuery("movies", id, tx) -} - -func (qb *MovieQueryBuilder) Find(id int, tx *sqlx.Tx) (*Movie, error) { - query := "SELECT * FROM movies WHERE id = ? LIMIT 1" - args := []interface{}{id} - return qb.queryMovie(query, args, tx) -} - -func (qb *MovieQueryBuilder) FindMany(ids []int) ([]*Movie, error) { - var movies []*Movie - for _, id := range ids { - movie, err := qb.Find(id, nil) - if err != nil { - return nil, err - } - - if movie == nil { - return nil, fmt.Errorf("movie with id %d not found", id) - } - - movies = append(movies, movie) - } - - return movies, nil -} - -func (qb *MovieQueryBuilder) FindBySceneID(sceneID int, tx *sqlx.Tx) ([]*Movie, error) { - query := ` - SELECT movies.* FROM movies - LEFT JOIN movies_scenes as scenes_join on scenes_join.movie_id = movies.id - WHERE scenes_join.scene_id = ? - GROUP BY movies.id - ` - args := []interface{}{sceneID} - return qb.queryMovies(query, args, tx) -} - -func (qb *MovieQueryBuilder) FindByName(name string, tx *sqlx.Tx, nocase bool) (*Movie, error) { - query := "SELECT * FROM movies WHERE name = ?" 
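Stepping back to the top of this removed movie builder: Create relies on the insert-then-reselect idiom, where NamedExec writes the row, LastInsertId yields the new primary key, and a follow-up SELECT repopulates the struct so the caller gets back exactly what was stored. A compact sketch of that idiom (createMovie is a hypothetical helper and the column list is trimmed; the real Create above writes the full movie row):

// createMovie inserts a movie and re-reads it by its generated id.
// Assumes: import "github.com/jmoiron/sqlx".
func createMovie(tx *sqlx.Tx, m Movie) (*Movie, error) {
	res, err := tx.NamedExec(
		`INSERT INTO movies (checksum, name, created_at, updated_at)
		 VALUES (:checksum, :name, :created_at, :updated_at)`, m)
	if err != nil {
		return nil, err
	}
	id, err := res.LastInsertId()
	if err != nil {
		return nil, err
	}
	// re-read so database-assigned values come back on the returned struct
	if err := tx.Get(&m, "SELECT * FROM movies WHERE id = ? LIMIT 1", id); err != nil {
		return nil, err
	}
	return &m, nil
}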
- if nocase { - query += " COLLATE NOCASE" - } - query += " LIMIT 1" - args := []interface{}{name} - return qb.queryMovie(query, args, tx) -} - -func (qb *MovieQueryBuilder) FindByNames(names []string, tx *sqlx.Tx, nocase bool) ([]*Movie, error) { - query := "SELECT * FROM movies WHERE name" - if nocase { - query += " COLLATE NOCASE" - } - query += " IN " + getInBinding(len(names)) - var args []interface{} - for _, name := range names { - args = append(args, name) - } - return qb.queryMovies(query, args, tx) -} - -func (qb *MovieQueryBuilder) Count() (int, error) { - return runCountQuery(buildCountQuery("SELECT movies.id FROM movies"), nil) -} - -func (qb *MovieQueryBuilder) All() ([]*Movie, error) { - return qb.queryMovies(selectAll("movies")+qb.getMovieSort(nil), nil, nil) -} - -func (qb *MovieQueryBuilder) AllSlim() ([]*Movie, error) { - return qb.queryMovies("SELECT movies.id, movies.name FROM movies "+qb.getMovieSort(nil), nil, nil) -} - -func (qb *MovieQueryBuilder) Query(movieFilter *MovieFilterType, findFilter *FindFilterType) ([]*Movie, int) { - if findFilter == nil { - findFilter = &FindFilterType{} - } - if movieFilter == nil { - movieFilter = &MovieFilterType{} - } - - var whereClauses []string - var havingClauses []string - var args []interface{} - body := selectDistinctIDs("movies") - body += ` - left join movies_scenes as scenes_join on scenes_join.movie_id = movies.id - left join scenes on scenes_join.scene_id = scenes.id - left join studios as studio on studio.id = movies.studio_id -` - - if q := findFilter.Q; q != nil && *q != "" { - searchColumns := []string{"movies.name"} - clause, thisArgs := getSearchBinding(searchColumns, *q, false) - whereClauses = append(whereClauses, clause) - args = append(args, thisArgs...) - } - - if studiosFilter := movieFilter.Studios; studiosFilter != nil && len(studiosFilter.Value) > 0 { - for _, studioID := range studiosFilter.Value { - args = append(args, studioID) - } - - whereClause, havingClause := getMultiCriterionClause("movies", "studio", "", "", "studio_id", studiosFilter) - whereClauses = appendClause(whereClauses, whereClause) - havingClauses = appendClause(havingClauses, havingClause) - } - - if isMissingFilter := movieFilter.IsMissing; isMissingFilter != nil && *isMissingFilter != "" { - switch *isMissingFilter { - case "front_image": - body += `left join movies_images on movies_images.movie_id = movies.id - ` - whereClauses = appendClause(whereClauses, "movies_images.front_image IS NULL") - case "back_image": - body += `left join movies_images on movies_images.movie_id = movies.id - ` - whereClauses = appendClause(whereClauses, "movies_images.back_image IS NULL") - case "scenes": - body += `left join movies_scenes on movies_scenes.movie_id = movies.id - ` - whereClauses = appendClause(whereClauses, "movies_scenes.scene_id IS NULL") - default: - whereClauses = appendClause(whereClauses, "movies."+*isMissingFilter+" IS NULL") - } - } - - sortAndPagination := qb.getMovieSort(findFilter) + getPagination(findFilter) - idsResult, countResult := executeFindQuery("movies", body, args, sortAndPagination, whereClauses, havingClauses) - - var movies []*Movie - for _, id := range idsResult { - movie, _ := qb.Find(id, nil) - movies = append(movies, movie) - } - - return movies, countResult -} - -func (qb *MovieQueryBuilder) getMovieSort(findFilter *FindFilterType) string { - var sort string - var direction string - if findFilter == nil { - sort = "name" - direction = "ASC" - } else { - sort = findFilter.GetSort("name") - direction = 
findFilter.GetDirection() - } - - // #943 - override name sorting to use natural sort - if sort == "name" { - return " ORDER BY " + getColumn("movies", sort) + " COLLATE NATURAL_CS " + direction - } - - return getSort(sort, direction, "movies") -} - -func (qb *MovieQueryBuilder) queryMovie(query string, args []interface{}, tx *sqlx.Tx) (*Movie, error) { - results, err := qb.queryMovies(query, args, tx) - if err != nil || len(results) < 1 { - return nil, err - } - return results[0], nil -} - -func (qb *MovieQueryBuilder) queryMovies(query string, args []interface{}, tx *sqlx.Tx) ([]*Movie, error) { - var rows *sqlx.Rows - var err error - if tx != nil { - rows, err = tx.Queryx(query, args...) - } else { - rows, err = database.DB.Queryx(query, args...) - } - - if err != nil && err != sql.ErrNoRows { - return nil, err - } - defer rows.Close() - - movies := make([]*Movie, 0) - for rows.Next() { - movie := Movie{} - if err := rows.StructScan(&movie); err != nil { - return nil, err - } - movies = append(movies, &movie) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return movies, nil -} - -func (qb *MovieQueryBuilder) UpdateMovieImages(movieID int, frontImage []byte, backImage []byte, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing cover and then create new - if err := qb.DestroyMovieImages(movieID, tx); err != nil { - return err - } - - _, err := tx.Exec( - `INSERT INTO movies_images (movie_id, front_image, back_image) VALUES (?, ?, ?)`, - movieID, - frontImage, - backImage, - ) - - return err -} - -func (qb *MovieQueryBuilder) DestroyMovieImages(movieID int, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing joins - _, err := tx.Exec("DELETE FROM movies_images WHERE movie_id = ?", movieID) - if err != nil { - return err - } - return err -} - -func (qb *MovieQueryBuilder) GetFrontImage(movieID int, tx *sqlx.Tx) ([]byte, error) { - query := `SELECT front_image from movies_images WHERE movie_id = ?` - return getImage(tx, query, movieID) -} - -func (qb *MovieQueryBuilder) GetBackImage(movieID int, tx *sqlx.Tx) ([]byte, error) { - query := `SELECT back_image from movies_images WHERE movie_id = ?` - return getImage(tx, query, movieID) -} diff --git a/pkg/models/querybuilder_movies_test.go b/pkg/models/querybuilder_movies_test.go deleted file mode 100644 index 407350e36..000000000 --- a/pkg/models/querybuilder_movies_test.go +++ /dev/null @@ -1,275 +0,0 @@ -// +build integration - -package models_test - -import ( - "context" - "database/sql" - "strconv" - "strings" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/stashapp/stash/pkg/database" - "github.com/stashapp/stash/pkg/models" - "github.com/stashapp/stash/pkg/utils" -) - -func TestMovieFindBySceneID(t *testing.T) { - mqb := models.NewMovieQueryBuilder() - sceneID := sceneIDs[sceneIdxWithMovie] - - movies, err := mqb.FindBySceneID(sceneID, nil) - - if err != nil { - t.Fatalf("Error finding movie: %s", err.Error()) - } - - assert.Equal(t, 1, len(movies), "expect 1 movie") - - movie := movies[0] - assert.Equal(t, getMovieStringValue(movieIdxWithScene, "Name"), movie.Name.String) - - movies, err = mqb.FindBySceneID(0, nil) - - if err != nil { - t.Fatalf("Error finding movie: %s", err.Error()) - } - - assert.Equal(t, 0, len(movies)) -} - -func TestMovieFindByName(t *testing.T) { - - mqb := models.NewMovieQueryBuilder() - - name := movieNames[movieIdxWithScene] // find a movie by name - - movie, err := mqb.FindByName(name, nil, false) - - if err != nil { - t.Fatalf("Error finding 
movies: %s", err.Error()) - } - - assert.Equal(t, movieNames[movieIdxWithScene], movie.Name.String) - - name = movieNames[movieIdxWithDupName] // find a movie by name nocase - - movie, err = mqb.FindByName(name, nil, true) - - if err != nil { - t.Fatalf("Error finding movies: %s", err.Error()) - } - // movieIdxWithDupName and movieIdxWithScene should have similar names ( only diff should be Name vs NaMe) - //movie.Name should match with movieIdxWithScene since its ID is before moveIdxWithDupName - assert.Equal(t, movieNames[movieIdxWithScene], movie.Name.String) - //movie.Name should match with movieIdxWithDupName if the check is not case sensitive - assert.Equal(t, strings.ToLower(movieNames[movieIdxWithDupName]), strings.ToLower(movie.Name.String)) -} - -func TestMovieFindByNames(t *testing.T) { - var names []string - - mqb := models.NewMovieQueryBuilder() - - names = append(names, movieNames[movieIdxWithScene]) // find movies by names - - movies, err := mqb.FindByNames(names, nil, false) - if err != nil { - t.Fatalf("Error finding movies: %s", err.Error()) - } - assert.Len(t, movies, 1) - assert.Equal(t, movieNames[movieIdxWithScene], movies[0].Name.String) - - movies, err = mqb.FindByNames(names, nil, true) // find movies by names nocase - if err != nil { - t.Fatalf("Error finding movies: %s", err.Error()) - } - assert.Len(t, movies, 2) // movieIdxWithScene and movieIdxWithDupName - assert.Equal(t, strings.ToLower(movieNames[movieIdxWithScene]), strings.ToLower(movies[0].Name.String)) - assert.Equal(t, strings.ToLower(movieNames[movieIdxWithScene]), strings.ToLower(movies[1].Name.String)) -} - -func TestMovieQueryStudio(t *testing.T) { - mqb := models.NewMovieQueryBuilder() - studioCriterion := models.MultiCriterionInput{ - Value: []string{ - strconv.Itoa(studioIDs[studioIdxWithMovie]), - }, - Modifier: models.CriterionModifierIncludes, - } - - movieFilter := models.MovieFilterType{ - Studios: &studioCriterion, - } - - movies, _ := mqb.Query(&movieFilter, nil) - - assert.Len(t, movies, 1) - - // ensure id is correct - assert.Equal(t, movieIDs[movieIdxWithStudio], movies[0].ID) - - studioCriterion = models.MultiCriterionInput{ - Value: []string{ - strconv.Itoa(studioIDs[studioIdxWithMovie]), - }, - Modifier: models.CriterionModifierExcludes, - } - - q := getMovieStringValue(movieIdxWithStudio, titleField) - findFilter := models.FindFilterType{ - Q: &q, - } - - movies, _ = mqb.Query(&movieFilter, &findFilter) - assert.Len(t, movies, 0) -} - -func TestMovieUpdateMovieImages(t *testing.T) { - mqb := models.NewMovieQueryBuilder() - - // create movie to test against - ctx := context.TODO() - tx := database.DB.MustBeginTx(ctx, nil) - - const name = "TestMovieUpdateMovieImages" - movie := models.Movie{ - Name: sql.NullString{String: name, Valid: true}, - Checksum: utils.MD5FromString(name), - } - created, err := mqb.Create(movie, tx) - if err != nil { - tx.Rollback() - t.Fatalf("Error creating movie: %s", err.Error()) - } - - frontImage := []byte("frontImage") - backImage := []byte("backImage") - err = mqb.UpdateMovieImages(created.ID, frontImage, backImage, tx) - if err != nil { - tx.Rollback() - t.Fatalf("Error updating movie images: %s", err.Error()) - } - - if err := tx.Commit(); err != nil { - tx.Rollback() - t.Fatalf("Error committing: %s", err.Error()) - } - - // ensure images are set - storedFront, err := mqb.GetFrontImage(created.ID, nil) - if err != nil { - t.Fatalf("Error getting front image: %s", err.Error()) - } - assert.Equal(t, storedFront, frontImage) - - storedBack, err := 
mqb.GetBackImage(created.ID, nil) - if err != nil { - t.Fatalf("Error getting back image: %s", err.Error()) - } - assert.Equal(t, storedBack, backImage) - - // set front image only - newImage := []byte("newImage") - tx = database.DB.MustBeginTx(ctx, nil) - err = mqb.UpdateMovieImages(created.ID, newImage, nil, tx) - if err != nil { - tx.Rollback() - t.Fatalf("Error updating movie images: %s", err.Error()) - } - - storedFront, err = mqb.GetFrontImage(created.ID, tx) - if err != nil { - tx.Rollback() - t.Fatalf("Error getting front image: %s", err.Error()) - } - assert.Equal(t, storedFront, newImage) - - // back image should be nil - storedBack, err = mqb.GetBackImage(created.ID, tx) - if err != nil { - tx.Rollback() - t.Fatalf("Error getting back image: %s", err.Error()) - } - assert.Nil(t, nil) - - // set back image only - err = mqb.UpdateMovieImages(created.ID, nil, newImage, tx) - if err == nil { - tx.Rollback() - t.Fatalf("Expected error setting nil front image") - } - - if err := tx.Commit(); err != nil { - tx.Rollback() - t.Fatalf("Error committing: %s", err.Error()) - } -} - -func TestMovieDestroyMovieImages(t *testing.T) { - mqb := models.NewMovieQueryBuilder() - - // create movie to test against - ctx := context.TODO() - tx := database.DB.MustBeginTx(ctx, nil) - - const name = "TestMovieDestroyMovieImages" - movie := models.Movie{ - Name: sql.NullString{String: name, Valid: true}, - Checksum: utils.MD5FromString(name), - } - created, err := mqb.Create(movie, tx) - if err != nil { - tx.Rollback() - t.Fatalf("Error creating movie: %s", err.Error()) - } - - frontImage := []byte("frontImage") - backImage := []byte("backImage") - err = mqb.UpdateMovieImages(created.ID, frontImage, backImage, tx) - if err != nil { - tx.Rollback() - t.Fatalf("Error updating movie images: %s", err.Error()) - } - - if err := tx.Commit(); err != nil { - tx.Rollback() - t.Fatalf("Error committing: %s", err.Error()) - } - - tx = database.DB.MustBeginTx(ctx, nil) - - err = mqb.DestroyMovieImages(created.ID, tx) - if err != nil { - tx.Rollback() - t.Fatalf("Error destroying movie images: %s", err.Error()) - } - - if err := tx.Commit(); err != nil { - tx.Rollback() - t.Fatalf("Error committing: %s", err.Error()) - } - - // front image should be nil - storedFront, err := mqb.GetFrontImage(created.ID, nil) - if err != nil { - t.Fatalf("Error getting front image: %s", err.Error()) - } - assert.Nil(t, storedFront) - - // back image should be nil - storedBack, err := mqb.GetBackImage(created.ID, nil) - if err != nil { - t.Fatalf("Error getting back image: %s", err.Error()) - } - assert.Nil(t, storedBack) -} - -// TODO Update -// TODO Destroy -// TODO Find -// TODO Count -// TODO All -// TODO Query diff --git a/pkg/models/querybuilder_performer_test.go b/pkg/models/querybuilder_performer_test.go deleted file mode 100644 index eb65b55bf..000000000 --- a/pkg/models/querybuilder_performer_test.go +++ /dev/null @@ -1,257 +0,0 @@ -// +build integration - -package models_test - -import ( - "context" - "database/sql" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/assert" - - "github.com/stashapp/stash/pkg/database" - "github.com/stashapp/stash/pkg/models" - "github.com/stashapp/stash/pkg/utils" -) - -func TestPerformerFindBySceneID(t *testing.T) { - pqb := models.NewPerformerQueryBuilder() - sceneID := sceneIDs[sceneIdxWithPerformer] - - performers, err := pqb.FindBySceneID(sceneID, nil) - - if err != nil { - t.Fatalf("Error finding performer: %s", err.Error()) - } - - assert.Equal(t, 1, len(performers)) - 
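The image tests above all repeat the same transaction choreography by hand: MustBeginTx, do the writes, Rollback on any failure, Commit at the end, and roll back again if even the commit fails. That choreography as a single sketch (withTestTxn is a hypothetical helper, assuming the context, testing, sqlx and database imports already used in these tests):

// withTestTxn wraps a test body in a transaction, failing the test on error.
func withTestTxn(t *testing.T, fn func(tx *sqlx.Tx) error) {
	t.Helper()
	tx := database.DB.MustBeginTx(context.TODO(), nil)
	if err := fn(tx); err != nil {
		tx.Rollback()
		t.Fatalf("transaction failed: %s", err.Error())
	}
	if err := tx.Commit(); err != nil {
		tx.Rollback()
		t.Fatalf("Error committing: %s", err.Error())
	}
}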
performer := performers[0] - - assert.Equal(t, getPerformerStringValue(performerIdxWithScene, "Name"), performer.Name.String) - - performers, err = pqb.FindBySceneID(0, nil) - - if err != nil { - t.Fatalf("Error finding performer: %s", err.Error()) - } - - assert.Equal(t, 0, len(performers)) -} - -func TestPerformerFindNameBySceneID(t *testing.T) { - pqb := models.NewPerformerQueryBuilder() - sceneID := sceneIDs[sceneIdxWithPerformer] - - performers, err := pqb.FindNameBySceneID(sceneID, nil) - - if err != nil { - t.Fatalf("Error finding performer: %s", err.Error()) - } - - assert.Equal(t, 1, len(performers)) - performer := performers[0] - - assert.Equal(t, getPerformerStringValue(performerIdxWithScene, "Name"), performer.Name.String) - - performers, err = pqb.FindBySceneID(0, nil) - - if err != nil { - t.Fatalf("Error finding performer: %s", err.Error()) - } - - assert.Equal(t, 0, len(performers)) -} - -func TestPerformerFindByNames(t *testing.T) { - var names []string - - pqb := models.NewPerformerQueryBuilder() - - names = append(names, performerNames[performerIdxWithScene]) // find performers by names - - performers, err := pqb.FindByNames(names, nil, false) - if err != nil { - t.Fatalf("Error finding performers: %s", err.Error()) - } - assert.Len(t, performers, 1) - assert.Equal(t, performerNames[performerIdxWithScene], performers[0].Name.String) - - performers, err = pqb.FindByNames(names, nil, true) // find performers by names nocase - if err != nil { - t.Fatalf("Error finding performers: %s", err.Error()) - } - assert.Len(t, performers, 2) // performerIdxWithScene and performerIdxWithDupName - assert.Equal(t, strings.ToLower(performerNames[performerIdxWithScene]), strings.ToLower(performers[0].Name.String)) - assert.Equal(t, strings.ToLower(performerNames[performerIdxWithScene]), strings.ToLower(performers[1].Name.String)) - - names = append(names, performerNames[performerIdx1WithScene]) // find performers by names ( 2 names ) - - performers, err = pqb.FindByNames(names, nil, false) - if err != nil { - t.Fatalf("Error finding performers: %s", err.Error()) - } - assert.Len(t, performers, 2) // performerIdxWithScene and performerIdx1WithScene - assert.Equal(t, performerNames[performerIdxWithScene], performers[0].Name.String) - assert.Equal(t, performerNames[performerIdx1WithScene], performers[1].Name.String) - - performers, err = pqb.FindByNames(names, nil, true) // find performers by names ( 2 names nocase) - if err != nil { - t.Fatalf("Error finding performers: %s", err.Error()) - } - assert.Len(t, performers, 4) // performerIdxWithScene and performerIdxWithDupName , performerIdx1WithScene and performerIdx1WithDupName - assert.Equal(t, performerNames[performerIdxWithScene], performers[0].Name.String) - assert.Equal(t, performerNames[performerIdx1WithScene], performers[1].Name.String) - assert.Equal(t, performerNames[performerIdx1WithDupName], performers[2].Name.String) - assert.Equal(t, performerNames[performerIdxWithDupName], performers[3].Name.String) - -} - -func TestPerformerUpdatePerformerImage(t *testing.T) { - qb := models.NewPerformerQueryBuilder() - - // create performer to test against - ctx := context.TODO() - tx := database.DB.MustBeginTx(ctx, nil) - - const name = "TestPerformerUpdatePerformerImage" - performer := models.Performer{ - Name: sql.NullString{String: name, Valid: true}, - Checksum: utils.MD5FromString(name), - Favorite: sql.NullBool{Bool: false, Valid: true}, - } - created, err := qb.Create(performer, tx) - if err != nil { - tx.Rollback() - t.Fatalf("Error 
creating performer: %s", err.Error()) - } - - image := []byte("image") - err = qb.UpdatePerformerImage(created.ID, image, tx) - if err != nil { - tx.Rollback() - t.Fatalf("Error updating performer image: %s", err.Error()) - } - - if err := tx.Commit(); err != nil { - tx.Rollback() - t.Fatalf("Error committing: %s", err.Error()) - } - - // ensure image set - storedImage, err := qb.GetPerformerImage(created.ID, nil) - if err != nil { - t.Fatalf("Error getting image: %s", err.Error()) - } - assert.Equal(t, storedImage, image) - - // set nil image - tx = database.DB.MustBeginTx(ctx, nil) - err = qb.UpdatePerformerImage(created.ID, nil, tx) - if err == nil { - t.Fatalf("Expected error setting nil image") - } - - tx.Rollback() -} - -func TestPerformerDestroyPerformerImage(t *testing.T) { - qb := models.NewPerformerQueryBuilder() - - // create performer to test against - ctx := context.TODO() - tx := database.DB.MustBeginTx(ctx, nil) - - const name = "TestPerformerDestroyPerformerImage" - performer := models.Performer{ - Name: sql.NullString{String: name, Valid: true}, - Checksum: utils.MD5FromString(name), - Favorite: sql.NullBool{Bool: false, Valid: true}, - } - created, err := qb.Create(performer, tx) - if err != nil { - tx.Rollback() - t.Fatalf("Error creating performer: %s", err.Error()) - } - - image := []byte("image") - err = qb.UpdatePerformerImage(created.ID, image, tx) - if err != nil { - tx.Rollback() - t.Fatalf("Error updating performer image: %s", err.Error()) - } - - if err := tx.Commit(); err != nil { - tx.Rollback() - t.Fatalf("Error committing: %s", err.Error()) - } - - tx = database.DB.MustBeginTx(ctx, nil) - - err = qb.DestroyPerformerImage(created.ID, tx) - if err != nil { - tx.Rollback() - t.Fatalf("Error destroying performer image: %s", err.Error()) - } - - if err := tx.Commit(); err != nil { - tx.Rollback() - t.Fatalf("Error committing: %s", err.Error()) - } - - // image should be nil - storedImage, err := qb.GetPerformerImage(created.ID, nil) - if err != nil { - t.Fatalf("Error getting image: %s", err.Error()) - } - assert.Nil(t, storedImage) -} - -func TestPerformerQueryAge(t *testing.T) { - const age = 19 - ageCriterion := models.IntCriterionInput{ - Value: age, - Modifier: models.CriterionModifierEquals, - } - - verifyPerformerAge(t, ageCriterion) - - ageCriterion.Modifier = models.CriterionModifierNotEquals - verifyPerformerAge(t, ageCriterion) - - ageCriterion.Modifier = models.CriterionModifierGreaterThan - verifyPerformerAge(t, ageCriterion) - - ageCriterion.Modifier = models.CriterionModifierLessThan - verifyPerformerAge(t, ageCriterion) -} - -func verifyPerformerAge(t *testing.T, ageCriterion models.IntCriterionInput) { - qb := models.NewPerformerQueryBuilder() - performerFilter := models.PerformerFilterType{ - Age: &ageCriterion, - } - - performers, _ := qb.Query(&performerFilter, nil) - - now := time.Now() - for _, performer := range performers { - bd := performer.Birthdate.String - d, _ := time.Parse("2006-01-02", bd) - age := now.Year() - d.Year() - if now.YearDay() < d.YearDay() { - age = age - 1 - } - - verifyInt(t, age, ageCriterion) - } -} - -// TODO Update -// TODO Destroy -// TODO Find -// TODO Count -// TODO All -// TODO AllSlim -// TODO Query diff --git a/pkg/models/querybuilder_scene_marker_test.go b/pkg/models/querybuilder_scene_marker_test.go deleted file mode 100644 index ee8be5406..000000000 --- a/pkg/models/querybuilder_scene_marker_test.go +++ /dev/null @@ -1,68 +0,0 @@ -// +build integration - -package models_test - -import ( - "testing" - - 
"github.com/stretchr/testify/assert" - - "github.com/stashapp/stash/pkg/models" -) - -func TestMarkerFindBySceneID(t *testing.T) { - mqb := models.NewSceneMarkerQueryBuilder() - - sceneID := sceneIDs[sceneIdxWithMarker] - markers, err := mqb.FindBySceneID(sceneID, nil) - - if err != nil { - t.Fatalf("Error finding markers: %s", err.Error()) - } - - assert.Len(t, markers, 1) - assert.Equal(t, markerIDs[markerIdxWithScene], markers[0].ID) - - markers, err = mqb.FindBySceneID(0, nil) - - if err != nil { - t.Fatalf("Error finding marker: %s", err.Error()) - } - - assert.Len(t, markers, 0) -} - -func TestMarkerCountByTagID(t *testing.T) { - mqb := models.NewSceneMarkerQueryBuilder() - - markerCount, err := mqb.CountByTagID(tagIDs[tagIdxWithPrimaryMarker]) - - if err != nil { - t.Fatalf("error calling CountByTagID: %s", err.Error()) - } - - assert.Equal(t, 1, markerCount) - - markerCount, err = mqb.CountByTagID(tagIDs[tagIdxWithMarker]) - - if err != nil { - t.Fatalf("error calling CountByTagID: %s", err.Error()) - } - - assert.Equal(t, 1, markerCount) - - markerCount, err = mqb.CountByTagID(0) - - if err != nil { - t.Fatalf("error calling CountByTagID: %s", err.Error()) - } - - assert.Equal(t, 0, markerCount) -} - -// TODO Update -// TODO Destroy -// TODO Find -// TODO GetMarkerStrings -// TODO Wall -// TODO Query diff --git a/pkg/models/querybuilder_scene_test.go b/pkg/models/querybuilder_scene_test.go deleted file mode 100644 index 06b300674..000000000 --- a/pkg/models/querybuilder_scene_test.go +++ /dev/null @@ -1,1064 +0,0 @@ -// +build integration - -package models_test - -import ( - "context" - "database/sql" - "strconv" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/stashapp/stash/pkg/database" - "github.com/stashapp/stash/pkg/models" - "github.com/stashapp/stash/pkg/utils" -) - -func TestSceneFind(t *testing.T) { - // assume that the first scene is sceneWithGalleryPath - sqb := models.NewSceneQueryBuilder() - - const sceneIdx = 0 - sceneID := sceneIDs[sceneIdx] - scene, err := sqb.Find(sceneID) - - if err != nil { - t.Fatalf("Error finding scene: %s", err.Error()) - } - - assert.Equal(t, getSceneStringValue(sceneIdx, "Path"), scene.Path) - - sceneID = 0 - scene, err = sqb.Find(sceneID) - - if err != nil { - t.Fatalf("Error finding scene: %s", err.Error()) - } - - assert.Nil(t, scene) -} - -func TestSceneFindByPath(t *testing.T) { - sqb := models.NewSceneQueryBuilder() - - const sceneIdx = 1 - scenePath := getSceneStringValue(sceneIdx, "Path") - scene, err := sqb.FindByPath(scenePath) - - if err != nil { - t.Fatalf("Error finding scene: %s", err.Error()) - } - - assert.Equal(t, sceneIDs[sceneIdx], scene.ID) - assert.Equal(t, scenePath, scene.Path) - - scenePath = "not exist" - scene, err = sqb.FindByPath(scenePath) - - if err != nil { - t.Fatalf("Error finding scene: %s", err.Error()) - } - - assert.Nil(t, scene) -} - -func TestSceneCountByPerformerID(t *testing.T) { - sqb := models.NewSceneQueryBuilder() - count, err := sqb.CountByPerformerID(performerIDs[performerIdxWithScene]) - - if err != nil { - t.Fatalf("Error counting scenes: %s", err.Error()) - } - - assert.Equal(t, 1, count) - - count, err = sqb.CountByPerformerID(0) - - if err != nil { - t.Fatalf("Error counting scenes: %s", err.Error()) - } - - assert.Equal(t, 0, count) -} - -func TestSceneWall(t *testing.T) { - sqb := models.NewSceneQueryBuilder() - - const sceneIdx = 2 - wallQuery := getSceneStringValue(sceneIdx, "Details") - scenes, err := sqb.Wall(&wallQuery) - - if err != nil { - t.Fatalf("Error 
finding scenes: %s", err.Error()) - } - - assert.Len(t, scenes, 1) - scene := scenes[0] - assert.Equal(t, sceneIDs[sceneIdx], scene.ID) - assert.Equal(t, getSceneStringValue(sceneIdx, "Path"), scene.Path) - - wallQuery = "not exist" - scenes, err = sqb.Wall(&wallQuery) - - if err != nil { - t.Fatalf("Error finding scene: %s", err.Error()) - } - - assert.Len(t, scenes, 0) -} - -func TestSceneQueryQ(t *testing.T) { - const sceneIdx = 2 - - q := getSceneStringValue(sceneIdx, titleField) - - sqb := models.NewSceneQueryBuilder() - - sceneQueryQ(t, sqb, q, sceneIdx) -} - -func sceneQueryQ(t *testing.T, sqb models.SceneQueryBuilder, q string, expectedSceneIdx int) { - filter := models.FindFilterType{ - Q: &q, - } - scenes, _ := sqb.Query(nil, &filter) - - assert.Len(t, scenes, 1) - scene := scenes[0] - assert.Equal(t, sceneIDs[expectedSceneIdx], scene.ID) - - // no Q should return all results - filter.Q = nil - scenes, _ = sqb.Query(nil, &filter) - - assert.Len(t, scenes, totalScenes) -} - -func TestSceneQueryPath(t *testing.T) { - const sceneIdx = 1 - scenePath := getSceneStringValue(sceneIdx, "Path") - - pathCriterion := models.StringCriterionInput{ - Value: scenePath, - Modifier: models.CriterionModifierEquals, - } - - verifyScenesPath(t, pathCriterion) - - pathCriterion.Modifier = models.CriterionModifierNotEquals - verifyScenesPath(t, pathCriterion) -} - -func verifyScenesPath(t *testing.T, pathCriterion models.StringCriterionInput) { - sqb := models.NewSceneQueryBuilder() - sceneFilter := models.SceneFilterType{ - Path: &pathCriterion, - } - - scenes, _ := sqb.Query(&sceneFilter, nil) - - for _, scene := range scenes { - verifyString(t, scene.Path, pathCriterion) - } -} - -func verifyNullString(t *testing.T, value sql.NullString, criterion models.StringCriterionInput) { - t.Helper() - assert := assert.New(t) - if criterion.Modifier == models.CriterionModifierIsNull { - assert.False(value.Valid, "expect is null values to be null") - } - if criterion.Modifier == models.CriterionModifierNotNull { - assert.True(value.Valid, "expect is null values to be null") - } - if criterion.Modifier == models.CriterionModifierEquals { - assert.Equal(criterion.Value, value.String) - } - if criterion.Modifier == models.CriterionModifierNotEquals { - assert.NotEqual(criterion.Value, value.String) - } -} - -func verifyString(t *testing.T, value string, criterion models.StringCriterionInput) { - t.Helper() - assert := assert.New(t) - if criterion.Modifier == models.CriterionModifierEquals { - assert.Equal(criterion.Value, value) - } - if criterion.Modifier == models.CriterionModifierNotEquals { - assert.NotEqual(criterion.Value, value) - } -} - -func TestSceneQueryRating(t *testing.T) { - const rating = 3 - ratingCriterion := models.IntCriterionInput{ - Value: rating, - Modifier: models.CriterionModifierEquals, - } - - verifyScenesRating(t, ratingCriterion) - - ratingCriterion.Modifier = models.CriterionModifierNotEquals - verifyScenesRating(t, ratingCriterion) - - ratingCriterion.Modifier = models.CriterionModifierGreaterThan - verifyScenesRating(t, ratingCriterion) - - ratingCriterion.Modifier = models.CriterionModifierLessThan - verifyScenesRating(t, ratingCriterion) - - ratingCriterion.Modifier = models.CriterionModifierIsNull - verifyScenesRating(t, ratingCriterion) - - ratingCriterion.Modifier = models.CriterionModifierNotNull - verifyScenesRating(t, ratingCriterion) -} - -func verifyScenesRating(t *testing.T, ratingCriterion models.IntCriterionInput) { - sqb := models.NewSceneQueryBuilder() - sceneFilter 
:= models.SceneFilterType{ - Rating: &ratingCriterion, - } - - scenes, _ := sqb.Query(&sceneFilter, nil) - - for _, scene := range scenes { - verifyInt64(t, scene.Rating, ratingCriterion) - } -} - -func verifyInt64(t *testing.T, value sql.NullInt64, criterion models.IntCriterionInput) { - t.Helper() - assert := assert.New(t) - if criterion.Modifier == models.CriterionModifierIsNull { - assert.False(value.Valid, "expect is null values to be null") - } - if criterion.Modifier == models.CriterionModifierNotNull { - assert.True(value.Valid, "expect is null values to be null") - } - if criterion.Modifier == models.CriterionModifierEquals { - assert.Equal(int64(criterion.Value), value.Int64) - } - if criterion.Modifier == models.CriterionModifierNotEquals { - assert.NotEqual(int64(criterion.Value), value.Int64) - } - if criterion.Modifier == models.CriterionModifierGreaterThan { - assert.True(value.Int64 > int64(criterion.Value)) - } - if criterion.Modifier == models.CriterionModifierLessThan { - assert.True(value.Int64 < int64(criterion.Value)) - } -} - -func TestSceneQueryOCounter(t *testing.T) { - const oCounter = 1 - oCounterCriterion := models.IntCriterionInput{ - Value: oCounter, - Modifier: models.CriterionModifierEquals, - } - - verifyScenesOCounter(t, oCounterCriterion) - - oCounterCriterion.Modifier = models.CriterionModifierNotEquals - verifyScenesOCounter(t, oCounterCriterion) - - oCounterCriterion.Modifier = models.CriterionModifierGreaterThan - verifyScenesOCounter(t, oCounterCriterion) - - oCounterCriterion.Modifier = models.CriterionModifierLessThan - verifyScenesOCounter(t, oCounterCriterion) -} - -func verifyScenesOCounter(t *testing.T, oCounterCriterion models.IntCriterionInput) { - sqb := models.NewSceneQueryBuilder() - sceneFilter := models.SceneFilterType{ - OCounter: &oCounterCriterion, - } - - scenes, _ := sqb.Query(&sceneFilter, nil) - - for _, scene := range scenes { - verifyInt(t, scene.OCounter, oCounterCriterion) - } -} - -func verifyInt(t *testing.T, value int, criterion models.IntCriterionInput) { - t.Helper() - assert := assert.New(t) - if criterion.Modifier == models.CriterionModifierEquals { - assert.Equal(criterion.Value, value) - } - if criterion.Modifier == models.CriterionModifierNotEquals { - assert.NotEqual(criterion.Value, value) - } - if criterion.Modifier == models.CriterionModifierGreaterThan { - assert.True(value > criterion.Value) - } - if criterion.Modifier == models.CriterionModifierLessThan { - assert.True(value < criterion.Value) - } -} - -func TestSceneQueryDuration(t *testing.T) { - duration := 200.432 - - durationCriterion := models.IntCriterionInput{ - Value: int(duration), - Modifier: models.CriterionModifierEquals, - } - verifyScenesDuration(t, durationCriterion) - - durationCriterion.Modifier = models.CriterionModifierNotEquals - verifyScenesDuration(t, durationCriterion) - - durationCriterion.Modifier = models.CriterionModifierGreaterThan - verifyScenesDuration(t, durationCriterion) - - durationCriterion.Modifier = models.CriterionModifierLessThan - verifyScenesDuration(t, durationCriterion) - - durationCriterion.Modifier = models.CriterionModifierIsNull - verifyScenesDuration(t, durationCriterion) - - durationCriterion.Modifier = models.CriterionModifierNotNull - verifyScenesDuration(t, durationCriterion) -} - -func verifyScenesDuration(t *testing.T, durationCriterion models.IntCriterionInput) { - sqb := models.NewSceneQueryBuilder() - sceneFilter := models.SceneFilterType{ - Duration: &durationCriterion, - } - - scenes, _ := 
sqb.Query(&sceneFilter, nil) - - for _, scene := range scenes { - if durationCriterion.Modifier == models.CriterionModifierEquals { - assert.True(t, scene.Duration.Float64 >= float64(durationCriterion.Value) && scene.Duration.Float64 < float64(durationCriterion.Value+1)) - } else if durationCriterion.Modifier == models.CriterionModifierNotEquals { - assert.True(t, scene.Duration.Float64 < float64(durationCriterion.Value) || scene.Duration.Float64 >= float64(durationCriterion.Value+1)) - } else { - verifyFloat64(t, scene.Duration, durationCriterion) - } - } -} - -func verifyFloat64(t *testing.T, value sql.NullFloat64, criterion models.IntCriterionInput) { - assert := assert.New(t) - if criterion.Modifier == models.CriterionModifierIsNull { - assert.False(value.Valid, "expect is null values to be null") - } - if criterion.Modifier == models.CriterionModifierNotNull { - assert.True(value.Valid, "expect is null values to be null") - } - if criterion.Modifier == models.CriterionModifierEquals { - assert.Equal(float64(criterion.Value), value.Float64) - } - if criterion.Modifier == models.CriterionModifierNotEquals { - assert.NotEqual(float64(criterion.Value), value.Float64) - } - if criterion.Modifier == models.CriterionModifierGreaterThan { - assert.True(value.Float64 > float64(criterion.Value)) - } - if criterion.Modifier == models.CriterionModifierLessThan { - assert.True(value.Float64 < float64(criterion.Value)) - } -} - -func TestSceneQueryResolution(t *testing.T) { - verifyScenesResolution(t, models.ResolutionEnumLow) - verifyScenesResolution(t, models.ResolutionEnumStandard) - verifyScenesResolution(t, models.ResolutionEnumStandardHd) - verifyScenesResolution(t, models.ResolutionEnumFullHd) - verifyScenesResolution(t, models.ResolutionEnumFourK) - verifyScenesResolution(t, models.ResolutionEnum("unknown")) -} - -func verifyScenesResolution(t *testing.T, resolution models.ResolutionEnum) { - sqb := models.NewSceneQueryBuilder() - sceneFilter := models.SceneFilterType{ - Resolution: &resolution, - } - - scenes, _ := sqb.Query(&sceneFilter, nil) - - for _, scene := range scenes { - verifySceneResolution(t, scene.Height, resolution) - } -} - -func verifySceneResolution(t *testing.T, height sql.NullInt64, resolution models.ResolutionEnum) { - assert := assert.New(t) - h := height.Int64 - - switch resolution { - case models.ResolutionEnumLow: - assert.True(h < 480) - case models.ResolutionEnumStandard: - assert.True(h >= 480 && h < 720) - case models.ResolutionEnumStandardHd: - assert.True(h >= 720 && h < 1080) - case models.ResolutionEnumFullHd: - assert.True(h >= 1080 && h < 2160) - case models.ResolutionEnumFourK: - assert.True(h >= 2160) - } -} - -func TestSceneQueryHasMarkers(t *testing.T) { - sqb := models.NewSceneQueryBuilder() - hasMarkers := "true" - sceneFilter := models.SceneFilterType{ - HasMarkers: &hasMarkers, - } - - q := getSceneStringValue(sceneIdxWithMarker, titleField) - findFilter := models.FindFilterType{ - Q: &q, - } - - scenes, _ := sqb.Query(&sceneFilter, &findFilter) - - assert.Len(t, scenes, 1) - assert.Equal(t, sceneIDs[sceneIdxWithMarker], scenes[0].ID) - - hasMarkers = "false" - scenes, _ = sqb.Query(&sceneFilter, &findFilter) - assert.Len(t, scenes, 0) - - findFilter.Q = nil - scenes, _ = sqb.Query(&sceneFilter, &findFilter) - - assert.NotEqual(t, 0, len(scenes)) - - // ensure non of the ids equal the one with gallery - for _, scene := range scenes { - assert.NotEqual(t, sceneIDs[sceneIdxWithMarker], scene.ID) - } -} - -func TestSceneQueryIsMissingGallery(t 
*testing.T) { - sqb := models.NewSceneQueryBuilder() - isMissing := "gallery" - sceneFilter := models.SceneFilterType{ - IsMissing: &isMissing, - } - - q := getSceneStringValue(sceneIdxWithGallery, titleField) - findFilter := models.FindFilterType{ - Q: &q, - } - - scenes, _ := sqb.Query(&sceneFilter, &findFilter) - - assert.Len(t, scenes, 0) - - findFilter.Q = nil - scenes, _ = sqb.Query(&sceneFilter, &findFilter) - - // ensure none of the ids equal the one with gallery - for _, scene := range scenes { - assert.NotEqual(t, sceneIDs[sceneIdxWithGallery], scene.ID) - } -} - -func TestSceneQueryIsMissingStudio(t *testing.T) { - sqb := models.NewSceneQueryBuilder() - isMissing := "studio" - sceneFilter := models.SceneFilterType{ - IsMissing: &isMissing, - } - - q := getSceneStringValue(sceneIdxWithStudio, titleField) - findFilter := models.FindFilterType{ - Q: &q, - } - - scenes, _ := sqb.Query(&sceneFilter, &findFilter) - - assert.Len(t, scenes, 0) - - findFilter.Q = nil - scenes, _ = sqb.Query(&sceneFilter, &findFilter) - - // ensure none of the ids equal the one with studio - for _, scene := range scenes { - assert.NotEqual(t, sceneIDs[sceneIdxWithStudio], scene.ID) - } -} - -func TestSceneQueryIsMissingMovies(t *testing.T) { - sqb := models.NewSceneQueryBuilder() - isMissing := "movie" - sceneFilter := models.SceneFilterType{ - IsMissing: &isMissing, - } - - q := getSceneStringValue(sceneIdxWithMovie, titleField) - findFilter := models.FindFilterType{ - Q: &q, - } - - scenes, _ := sqb.Query(&sceneFilter, &findFilter) - - assert.Len(t, scenes, 0) - - findFilter.Q = nil - scenes, _ = sqb.Query(&sceneFilter, &findFilter) - - // ensure none of the ids equal the one with movies - for _, scene := range scenes { - assert.NotEqual(t, sceneIDs[sceneIdxWithMovie], scene.ID) - } -} - -func TestSceneQueryIsMissingPerformers(t *testing.T) { - sqb := models.NewSceneQueryBuilder() - isMissing := "performers" - sceneFilter := models.SceneFilterType{ - IsMissing: &isMissing, - } - - q := getSceneStringValue(sceneIdxWithPerformer, titleField) - findFilter := models.FindFilterType{ - Q: &q, - } - - scenes, _ := sqb.Query(&sceneFilter, &findFilter) - - assert.Len(t, scenes, 0) - - findFilter.Q = nil - scenes, _ = sqb.Query(&sceneFilter, &findFilter) - - assert.True(t, len(scenes) > 0) - - // ensure none of the ids equal the one with a performer - for _, scene := range scenes { - assert.NotEqual(t, sceneIDs[sceneIdxWithPerformer], scene.ID) - } -} - -func TestSceneQueryIsMissingDate(t *testing.T) { - sqb := models.NewSceneQueryBuilder() - isMissing := "date" - sceneFilter := models.SceneFilterType{ - IsMissing: &isMissing, - } - - scenes, _ := sqb.Query(&sceneFilter, nil) - - assert.True(t, len(scenes) > 0) - - // ensure date is null, empty or "0001-01-01" - for _, scene := range scenes { - assert.True(t, !scene.Date.Valid || scene.Date.String == "" || scene.Date.String == "0001-01-01") - } -} - -func TestSceneQueryIsMissingTags(t *testing.T) { - sqb := models.NewSceneQueryBuilder() - isMissing := "tags" - sceneFilter := models.SceneFilterType{ - IsMissing: &isMissing, - } - - q := getSceneStringValue(sceneIdxWithTwoTags, titleField) - findFilter := models.FindFilterType{ - Q: &q, - } - - scenes, _ := sqb.Query(&sceneFilter, &findFilter) - - assert.Len(t, scenes, 0) - - findFilter.Q = nil - scenes, _ = sqb.Query(&sceneFilter, &findFilter) - - assert.True(t, len(scenes) > 0) -} - -func TestSceneQueryIsMissingRating(t *testing.T) { - sqb := models.NewSceneQueryBuilder() - isMissing := "rating" - sceneFilter := 
models.SceneFilterType{ - IsMissing: &isMissing, - } - - scenes, _ := sqb.Query(&sceneFilter, nil) - - assert.True(t, len(scenes) > 0) - - // ensure date is null, empty or "0001-01-01" - for _, scene := range scenes { - assert.True(t, !scene.Rating.Valid) - } -} - -func TestSceneQueryPerformers(t *testing.T) { - sqb := models.NewSceneQueryBuilder() - performerCriterion := models.MultiCriterionInput{ - Value: []string{ - strconv.Itoa(performerIDs[performerIdxWithScene]), - strconv.Itoa(performerIDs[performerIdx1WithScene]), - }, - Modifier: models.CriterionModifierIncludes, - } - - sceneFilter := models.SceneFilterType{ - Performers: &performerCriterion, - } - - scenes, _ := sqb.Query(&sceneFilter, nil) - - assert.Len(t, scenes, 2) - - // ensure ids are correct - for _, scene := range scenes { - assert.True(t, scene.ID == sceneIDs[sceneIdxWithPerformer] || scene.ID == sceneIDs[sceneIdxWithTwoPerformers]) - } - - performerCriterion = models.MultiCriterionInput{ - Value: []string{ - strconv.Itoa(performerIDs[performerIdx1WithScene]), - strconv.Itoa(performerIDs[performerIdx2WithScene]), - }, - Modifier: models.CriterionModifierIncludesAll, - } - - scenes, _ = sqb.Query(&sceneFilter, nil) - - assert.Len(t, scenes, 1) - assert.Equal(t, sceneIDs[sceneIdxWithTwoPerformers], scenes[0].ID) - - performerCriterion = models.MultiCriterionInput{ - Value: []string{ - strconv.Itoa(performerIDs[performerIdx1WithScene]), - }, - Modifier: models.CriterionModifierExcludes, - } - - q := getSceneStringValue(sceneIdxWithTwoPerformers, titleField) - findFilter := models.FindFilterType{ - Q: &q, - } - - scenes, _ = sqb.Query(&sceneFilter, &findFilter) - assert.Len(t, scenes, 0) -} - -func TestSceneQueryTags(t *testing.T) { - sqb := models.NewSceneQueryBuilder() - tagCriterion := models.MultiCriterionInput{ - Value: []string{ - strconv.Itoa(tagIDs[tagIdxWithScene]), - strconv.Itoa(tagIDs[tagIdx1WithScene]), - }, - Modifier: models.CriterionModifierIncludes, - } - - sceneFilter := models.SceneFilterType{ - Tags: &tagCriterion, - } - - scenes, _ := sqb.Query(&sceneFilter, nil) - - assert.Len(t, scenes, 2) - - // ensure ids are correct - for _, scene := range scenes { - assert.True(t, scene.ID == sceneIDs[sceneIdxWithTag] || scene.ID == sceneIDs[sceneIdxWithTwoTags]) - } - - tagCriterion = models.MultiCriterionInput{ - Value: []string{ - strconv.Itoa(tagIDs[tagIdx1WithScene]), - strconv.Itoa(tagIDs[tagIdx2WithScene]), - }, - Modifier: models.CriterionModifierIncludesAll, - } - - scenes, _ = sqb.Query(&sceneFilter, nil) - - assert.Len(t, scenes, 1) - assert.Equal(t, sceneIDs[sceneIdxWithTwoTags], scenes[0].ID) - - tagCriterion = models.MultiCriterionInput{ - Value: []string{ - strconv.Itoa(tagIDs[tagIdx1WithScene]), - }, - Modifier: models.CriterionModifierExcludes, - } - - q := getSceneStringValue(sceneIdxWithTwoTags, titleField) - findFilter := models.FindFilterType{ - Q: &q, - } - - scenes, _ = sqb.Query(&sceneFilter, &findFilter) - assert.Len(t, scenes, 0) -} - -func TestSceneQueryStudio(t *testing.T) { - sqb := models.NewSceneQueryBuilder() - studioCriterion := models.MultiCriterionInput{ - Value: []string{ - strconv.Itoa(studioIDs[studioIdxWithScene]), - }, - Modifier: models.CriterionModifierIncludes, - } - - sceneFilter := models.SceneFilterType{ - Studios: &studioCriterion, - } - - scenes, _ := sqb.Query(&sceneFilter, nil) - - assert.Len(t, scenes, 1) - - // ensure id is correct - assert.Equal(t, sceneIDs[sceneIdxWithStudio], scenes[0].ID) - - studioCriterion = models.MultiCriterionInput{ - Value: []string{ - 
strconv.Itoa(studioIDs[studioIdxWithScene]), - }, - Modifier: models.CriterionModifierExcludes, - } - - q := getSceneStringValue(sceneIdxWithStudio, titleField) - findFilter := models.FindFilterType{ - Q: &q, - } - - scenes, _ = sqb.Query(&sceneFilter, &findFilter) - assert.Len(t, scenes, 0) -} - -func TestSceneQueryMovies(t *testing.T) { - sqb := models.NewSceneQueryBuilder() - movieCriterion := models.MultiCriterionInput{ - Value: []string{ - strconv.Itoa(movieIDs[movieIdxWithScene]), - }, - Modifier: models.CriterionModifierIncludes, - } - - sceneFilter := models.SceneFilterType{ - Movies: &movieCriterion, - } - - scenes, _ := sqb.Query(&sceneFilter, nil) - - assert.Len(t, scenes, 1) - - // ensure id is correct - assert.Equal(t, sceneIDs[sceneIdxWithMovie], scenes[0].ID) - - movieCriterion = models.MultiCriterionInput{ - Value: []string{ - strconv.Itoa(movieIDs[movieIdxWithScene]), - }, - Modifier: models.CriterionModifierExcludes, - } - - q := getSceneStringValue(sceneIdxWithMovie, titleField) - findFilter := models.FindFilterType{ - Q: &q, - } - - scenes, _ = sqb.Query(&sceneFilter, &findFilter) - assert.Len(t, scenes, 0) -} - -func TestSceneQuerySorting(t *testing.T) { - sort := titleField - direction := models.SortDirectionEnumAsc - findFilter := models.FindFilterType{ - Sort: &sort, - Direction: &direction, - } - - sqb := models.NewSceneQueryBuilder() - scenes, _ := sqb.Query(nil, &findFilter) - - // scenes should be in same order as indexes - firstScene := scenes[0] - lastScene := scenes[len(scenes)-1] - - assert.Equal(t, sceneIDs[0], firstScene.ID) - assert.Equal(t, sceneIDs[len(sceneIDs)-1], lastScene.ID) - - // sort in descending order - direction = models.SortDirectionEnumDesc - - scenes, _ = sqb.Query(nil, &findFilter) - firstScene = scenes[0] - lastScene = scenes[len(scenes)-1] - - assert.Equal(t, sceneIDs[len(sceneIDs)-1], firstScene.ID) - assert.Equal(t, sceneIDs[0], lastScene.ID) -} - -func TestSceneQueryPagination(t *testing.T) { - perPage := 1 - findFilter := models.FindFilterType{ - PerPage: &perPage, - } - - sqb := models.NewSceneQueryBuilder() - scenes, _ := sqb.Query(nil, &findFilter) - - assert.Len(t, scenes, 1) - - firstID := scenes[0].ID - - page := 2 - findFilter.Page = &page - scenes, _ = sqb.Query(nil, &findFilter) - - assert.Len(t, scenes, 1) - secondID := scenes[0].ID - assert.NotEqual(t, firstID, secondID) - - perPage = 2 - page = 1 - - scenes, _ = sqb.Query(nil, &findFilter) - assert.Len(t, scenes, 2) - assert.Equal(t, firstID, scenes[0].ID) - assert.Equal(t, secondID, scenes[1].ID) -} - -func TestSceneCountByTagID(t *testing.T) { - sqb := models.NewSceneQueryBuilder() - - sceneCount, err := sqb.CountByTagID(tagIDs[tagIdxWithScene]) - - if err != nil { - t.Fatalf("error calling CountByTagID: %s", err.Error()) - } - - assert.Equal(t, 1, sceneCount) - - sceneCount, err = sqb.CountByTagID(0) - - if err != nil { - t.Fatalf("error calling CountByTagID: %s", err.Error()) - } - - assert.Equal(t, 0, sceneCount) -} - -func TestSceneCountByMovieID(t *testing.T) { - sqb := models.NewSceneQueryBuilder() - - sceneCount, err := sqb.CountByMovieID(movieIDs[movieIdxWithScene]) - - if err != nil { - t.Fatalf("error calling CountByMovieID: %s", err.Error()) - } - - assert.Equal(t, 1, sceneCount) - - sceneCount, err = sqb.CountByMovieID(0) - - if err != nil { - t.Fatalf("error calling CountByMovieID: %s", err.Error()) - } - - assert.Equal(t, 0, sceneCount) -} - -func TestSceneCountByStudioID(t *testing.T) { - sqb := models.NewSceneQueryBuilder() - - sceneCount, err := 
sqb.CountByStudioID(studioIDs[studioIdxWithScene]) - - if err != nil { - t.Fatalf("error calling CountByStudioID: %s", err.Error()) - } - - assert.Equal(t, 1, sceneCount) - - sceneCount, err = sqb.CountByStudioID(0) - - if err != nil { - t.Fatalf("error calling CountByStudioID: %s", err.Error()) - } - - assert.Equal(t, 0, sceneCount) -} - -func TestFindByMovieID(t *testing.T) { - sqb := models.NewSceneQueryBuilder() - - scenes, err := sqb.FindByMovieID(movieIDs[movieIdxWithScene]) - - if err != nil { - t.Fatalf("error calling FindByMovieID: %s", err.Error()) - } - - assert.Len(t, scenes, 1) - assert.Equal(t, sceneIDs[sceneIdxWithMovie], scenes[0].ID) - - scenes, err = sqb.FindByMovieID(0) - - if err != nil { - t.Fatalf("error calling FindByMovieID: %s", err.Error()) - } - - assert.Len(t, scenes, 0) -} - -func TestFindByPerformerID(t *testing.T) { - sqb := models.NewSceneQueryBuilder() - - scenes, err := sqb.FindByPerformerID(performerIDs[performerIdxWithScene]) - - if err != nil { - t.Fatalf("error calling FindByPerformerID: %s", err.Error()) - } - - assert.Len(t, scenes, 1) - assert.Equal(t, sceneIDs[sceneIdxWithPerformer], scenes[0].ID) - - scenes, err = sqb.FindByPerformerID(0) - - if err != nil { - t.Fatalf("error calling FindByPerformerID: %s", err.Error()) - } - - assert.Len(t, scenes, 0) -} - -func TestFindByStudioID(t *testing.T) { - sqb := models.NewSceneQueryBuilder() - - scenes, err := sqb.FindByStudioID(studioIDs[studioIdxWithScene]) - - if err != nil { - t.Fatalf("error calling FindByStudioID: %s", err.Error()) - } - - assert.Len(t, scenes, 1) - assert.Equal(t, sceneIDs[sceneIdxWithStudio], scenes[0].ID) - - scenes, err = sqb.FindByStudioID(0) - - if err != nil { - t.Fatalf("error calling FindByStudioID: %s", err.Error()) - } - - assert.Len(t, scenes, 0) -} - -func TestSceneUpdateSceneCover(t *testing.T) { - qb := models.NewSceneQueryBuilder() - - // create scene to test against - ctx := context.TODO() - tx := database.DB.MustBeginTx(ctx, nil) - - const name = "TestSceneUpdateSceneCover" - scene := models.Scene{ - Path: name, - Checksum: sql.NullString{String: utils.MD5FromString(name), Valid: true}, - } - created, err := qb.Create(scene, tx) - if err != nil { - tx.Rollback() - t.Fatalf("Error creating scene: %s", err.Error()) - } - - image := []byte("image") - err = qb.UpdateSceneCover(created.ID, image, tx) - if err != nil { - tx.Rollback() - t.Fatalf("Error updating scene cover: %s", err.Error()) - } - - if err := tx.Commit(); err != nil { - tx.Rollback() - t.Fatalf("Error committing: %s", err.Error()) - } - - // ensure image set - storedImage, err := qb.GetSceneCover(created.ID, nil) - if err != nil { - t.Fatalf("Error getting image: %s", err.Error()) - } - assert.Equal(t, storedImage, image) - - // set nil image - tx = database.DB.MustBeginTx(ctx, nil) - err = qb.UpdateSceneCover(created.ID, nil, tx) - if err == nil { - t.Fatalf("Expected error setting nil image") - } - - tx.Rollback() -} - -func TestSceneDestroySceneCover(t *testing.T) { - qb := models.NewSceneQueryBuilder() - - // create scene to test against - ctx := context.TODO() - tx := database.DB.MustBeginTx(ctx, nil) - - const name = "TestSceneDestroySceneCover" - scene := models.Scene{ - Path: name, - Checksum: sql.NullString{String: utils.MD5FromString(name), Valid: true}, - } - created, err := qb.Create(scene, tx) - if err != nil { - tx.Rollback() - t.Fatalf("Error creating scene: %s", err.Error()) - } - - image := []byte("image") - err = qb.UpdateSceneCover(created.ID, image, tx) - if err != nil { 
- tx.Rollback() - t.Fatalf("Error updating scene image: %s", err.Error()) - } - - if err := tx.Commit(); err != nil { - tx.Rollback() - t.Fatalf("Error committing: %s", err.Error()) - } - - tx = database.DB.MustBeginTx(ctx, nil) - - err = qb.DestroySceneCover(created.ID, tx) - if err != nil { - tx.Rollback() - t.Fatalf("Error destroying scene cover: %s", err.Error()) - } - - if err := tx.Commit(); err != nil { - tx.Rollback() - t.Fatalf("Error committing: %s", err.Error()) - } - - // image should be nil - storedImage, err := qb.GetSceneCover(created.ID, nil) - if err != nil { - t.Fatalf("Error getting image: %s", err.Error()) - } - assert.Nil(t, storedImage) -} - -// TODO Update -// TODO IncrementOCounter -// TODO DecrementOCounter -// TODO ResetOCounter -// TODO Destroy -// TODO FindByChecksum -// TODO Count -// TODO SizeCount -// TODO All diff --git a/pkg/models/querybuilder_scraped_item.go b/pkg/models/querybuilder_scraped_item.go deleted file mode 100644 index d78dfd918..000000000 --- a/pkg/models/querybuilder_scraped_item.go +++ /dev/null @@ -1,113 +0,0 @@ -package models - -import ( - "database/sql" - "github.com/jmoiron/sqlx" - "github.com/stashapp/stash/pkg/database" -) - -type ScrapedItemQueryBuilder struct{} - -func NewScrapedItemQueryBuilder() ScrapedItemQueryBuilder { - return ScrapedItemQueryBuilder{} -} - -func (qb *ScrapedItemQueryBuilder) Create(newScrapedItem ScrapedItem, tx *sqlx.Tx) (*ScrapedItem, error) { - ensureTx(tx) - result, err := tx.NamedExec( - `INSERT INTO scraped_items (title, description, url, date, rating, tags, models, episode, gallery_filename, - gallery_url, video_filename, video_url, studio_id, created_at, updated_at) - VALUES (:title, :description, :url, :date, :rating, :tags, :models, :episode, :gallery_filename, - :gallery_url, :video_filename, :video_url, :studio_id, :created_at, :updated_at) - `, - newScrapedItem, - ) - if err != nil { - return nil, err - } - scrapedItemID, err := result.LastInsertId() - if err != nil { - return nil, err - } - if err := tx.Get(&newScrapedItem, `SELECT * FROM scraped_items WHERE id = ? LIMIT 1`, scrapedItemID); err != nil { - return nil, err - } - return &newScrapedItem, nil -} - -func (qb *ScrapedItemQueryBuilder) Update(updatedScrapedItem ScrapedItem, tx *sqlx.Tx) (*ScrapedItem, error) { - ensureTx(tx) - _, err := tx.NamedExec( - `UPDATE scraped_items SET `+SQLGenKeys(updatedScrapedItem)+` WHERE scraped_items.id = :id`, - updatedScrapedItem, - ) - if err != nil { - return nil, err - } - - if err := tx.Get(&updatedScrapedItem, `SELECT * FROM scraped_items WHERE id = ? LIMIT 1`, updatedScrapedItem.ID); err != nil { - return nil, err - } - return &updatedScrapedItem, nil -} - -func (qb *ScrapedItemQueryBuilder) Find(id int) (*ScrapedItem, error) { - query := "SELECT * FROM scraped_items WHERE id = ? 
LIMIT 1" - args := []interface{}{id} - return qb.queryScrapedItem(query, args, nil) -} - -func (qb *ScrapedItemQueryBuilder) All() ([]*ScrapedItem, error) { - return qb.queryScrapedItems(selectAll("scraped_items")+qb.getScrapedItemsSort(nil), nil, nil) -} - -func (qb *ScrapedItemQueryBuilder) getScrapedItemsSort(findFilter *FindFilterType) string { - var sort string - var direction string - if findFilter == nil { - sort = "id" // TODO studio_id and title - direction = "ASC" - } else { - sort = findFilter.GetSort("id") - direction = findFilter.GetDirection() - } - return getSort(sort, direction, "scraped_items") -} - -func (qb *ScrapedItemQueryBuilder) queryScrapedItem(query string, args []interface{}, tx *sqlx.Tx) (*ScrapedItem, error) { - results, err := qb.queryScrapedItems(query, args, tx) - if err != nil || len(results) < 1 { - return nil, err - } - return results[0], nil -} - -func (qb *ScrapedItemQueryBuilder) queryScrapedItems(query string, args []interface{}, tx *sqlx.Tx) ([]*ScrapedItem, error) { - var rows *sqlx.Rows - var err error - if tx != nil { - rows, err = tx.Queryx(query, args...) - } else { - rows, err = database.DB.Queryx(query, args...) - } - - if err != nil && err != sql.ErrNoRows { - return nil, err - } - defer rows.Close() - - scrapedItems := make([]*ScrapedItem, 0) - for rows.Next() { - scrapedItem := ScrapedItem{} - if err := rows.StructScan(&scrapedItem); err != nil { - return nil, err - } - scrapedItems = append(scrapedItems, &scrapedItem) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return scrapedItems, nil -} diff --git a/pkg/models/querybuilder_sql.go b/pkg/models/querybuilder_sql.go deleted file mode 100644 index 50fa7d7b5..000000000 --- a/pkg/models/querybuilder_sql.go +++ /dev/null @@ -1,484 +0,0 @@ -package models - -import ( - "database/sql" - "fmt" - "math/rand" - "reflect" - "strconv" - "strings" - - "github.com/jmoiron/sqlx" - "github.com/stashapp/stash/pkg/database" - "github.com/stashapp/stash/pkg/logger" -) - -type queryBuilder struct { - tableName string - body string - - whereClauses []string - havingClauses []string - args []interface{} - - sortAndPagination string -} - -func (qb queryBuilder) executeFind() ([]int, int) { - return executeFindQuery(qb.tableName, qb.body, qb.args, qb.sortAndPagination, qb.whereClauses, qb.havingClauses) -} - -func (qb *queryBuilder) addWhere(clauses ...string) { - for _, clause := range clauses { - if len(clause) > 0 { - qb.whereClauses = append(qb.whereClauses, clause) - } - } -} - -func (qb *queryBuilder) addHaving(clauses ...string) { - for _, clause := range clauses { - if len(clause) > 0 { - qb.havingClauses = append(qb.havingClauses, clause) - } - } -} - -func (qb *queryBuilder) addArg(args ...interface{}) { - qb.args = append(qb.args, args...) -} - -func (qb *queryBuilder) handleIntCriterionInput(c *IntCriterionInput, column string) { - if c != nil { - clause, count := getIntCriterionWhereClause(column, *c) - qb.addWhere(clause) - if count == 1 { - qb.addArg(c.Value) - } - } -} - -func (qb *queryBuilder) handleStringCriterionInput(c *StringCriterionInput, column string) { - if c != nil { - if modifier := c.Modifier; c.Modifier.IsValid() { - switch modifier { - case CriterionModifierIncludes: - clause, thisArgs := getSearchBinding([]string{column}, c.Value, false) - qb.addWhere(clause) - qb.addArg(thisArgs...) - case CriterionModifierExcludes: - clause, thisArgs := getSearchBinding([]string{column}, c.Value, true) - qb.addWhere(clause) - qb.addArg(thisArgs...) 
- case CriterionModifierEquals: - qb.addWhere(column + " LIKE ?") - qb.addArg(c.Value) - case CriterionModifierNotEquals: - qb.addWhere(column + " NOT LIKE ?") - qb.addArg(c.Value) - default: - clause, count := getSimpleCriterionClause(modifier, "?") - qb.addWhere(column + " " + clause) - if count == 1 { - qb.addArg(c.Value) - } - } - } - } -} - -var randomSortFloat = rand.Float64() - -func selectAll(tableName string) string { - idColumn := getColumn(tableName, "*") - return "SELECT " + idColumn + " FROM " + tableName + " " -} - -func selectDistinctIDs(tableName string) string { - idColumn := getColumn(tableName, "id") - return "SELECT DISTINCT " + idColumn + " FROM " + tableName + " " -} - -func buildCountQuery(query string) string { - return "SELECT COUNT(*) as count FROM (" + query + ") as temp" -} - -func getColumn(tableName string, columnName string) string { - return tableName + "." + columnName -} - -func getPagination(findFilter *FindFilterType) string { - if findFilter == nil { - panic("nil find filter for pagination") - } - - var page int - if findFilter.Page == nil || *findFilter.Page < 1 { - page = 1 - } else { - page = *findFilter.Page - } - - var perPage int - if findFilter.PerPage == nil { - perPage = 25 - } else { - perPage = *findFilter.PerPage - } - - if perPage > 1000 { - perPage = 1000 - } else if perPage < 1 { - perPage = 1 - } - - page = (page - 1) * perPage - return " LIMIT " + strconv.Itoa(perPage) + " OFFSET " + strconv.Itoa(page) + " " -} - -func getSort(sort string, direction string, tableName string) string { - if direction != "ASC" && direction != "DESC" { - direction = "ASC" - } - - const randomSeedPrefix = "random_" - - if strings.HasSuffix(sort, "_count") { - var relationTableName = strings.TrimSuffix(sort, "_count") // TODO: pluralize? - colName := getColumn(relationTableName, "id") - return " ORDER BY COUNT(distinct " + colName + ") " + direction - } else if strings.Compare(sort, "filesize") == 0 { - colName := getColumn(tableName, "size") - return " ORDER BY cast(" + colName + " as integer) " + direction - } else if strings.HasPrefix(sort, randomSeedPrefix) { - // seed as a parameter from the UI - // turn the provided seed into a float - seedStr := "0." 
+ sort[len(randomSeedPrefix):] - seed, err := strconv.ParseFloat(seedStr, 32) - if err != nil { - // fallback to default seed - seed = randomSortFloat - } - return getRandomSort(tableName, direction, seed) - } else if strings.Compare(sort, "random") == 0 { - return getRandomSort(tableName, direction, randomSortFloat) - } else { - colName := getColumn(tableName, sort) - var additional string - if tableName == "scenes" { - additional = ", bitrate DESC, framerate DESC, scenes.rating DESC, scenes.duration DESC" - } else if tableName == "scene_markers" { - additional = ", scene_markers.scene_id ASC, scene_markers.seconds ASC" - } - if strings.Compare(sort, "name") == 0 { - return " ORDER BY " + colName + " COLLATE NOCASE " + direction + additional - } - if strings.Compare(sort, "title") == 0 { - return " ORDER BY " + colName + " COLLATE NATURAL_CS " + direction + additional - } - - return " ORDER BY " + colName + " " + direction + additional - } -} - -func getRandomSort(tableName string, direction string, seed float64) string { - // https://stackoverflow.com/a/24511461 - colName := getColumn(tableName, "id") - randomSortString := strconv.FormatFloat(seed, 'f', 16, 32) - return " ORDER BY " + "(substr(" + colName + " * " + randomSortString + ", length(" + colName + ") + 2))" + " " + direction -} - -func getSearchBinding(columns []string, q string, not bool) (string, []interface{}) { - var likeClauses []string - var args []interface{} - - notStr := "" - binaryType := " OR " - if not { - notStr = " NOT " - binaryType = " AND " - } - - queryWords := strings.Split(q, " ") - trimmedQuery := strings.Trim(q, "\"") - if trimmedQuery == q { - // Search for any word - for _, word := range queryWords { - for _, column := range columns { - likeClauses = append(likeClauses, column+notStr+" LIKE ?") - args = append(args, "%"+word+"%") - } - } - } else { - // Search the exact query - for _, column := range columns { - likeClauses = append(likeClauses, column+notStr+" LIKE ?") - args = append(args, "%"+trimmedQuery+"%") - } - } - likes := strings.Join(likeClauses, binaryType) - - return "(" + likes + ")", args -} - -func getInBinding(length int) string { - bindings := strings.Repeat("?, ", length) - bindings = strings.TrimRight(bindings, ", ") - return "(" + bindings + ")" -} - -func getCriterionModifierBinding(criterionModifier CriterionModifier, value interface{}) (string, int) { - var length int - switch x := value.(type) { - case []string: - length = len(x) - case []int: - length = len(x) - default: - length = 1 - } - if modifier := criterionModifier.String(); criterionModifier.IsValid() { - switch modifier { - case "EQUALS", "NOT_EQUALS", "GREATER_THAN", "LESS_THAN", "IS_NULL", "NOT_NULL": - return getSimpleCriterionClause(criterionModifier, "?") - case "INCLUDES": - return "IN " + getInBinding(length), length // TODO? - case "EXCLUDES": - return "NOT IN " + getInBinding(length), length // TODO? 
- default: - logger.Errorf("todo") - return "= ?", 1 // TODO - } - } - return "= ?", 1 // TODO -} - -func getSimpleCriterionClause(criterionModifier CriterionModifier, rhs string) (string, int) { - if modifier := criterionModifier.String(); criterionModifier.IsValid() { - switch modifier { - case "EQUALS": - return "= " + rhs, 1 - case "NOT_EQUALS": - return "!= " + rhs, 1 - case "GREATER_THAN": - return "> " + rhs, 1 - case "LESS_THAN": - return "< " + rhs, 1 - case "IS_NULL": - return "IS NULL", 0 - case "NOT_NULL": - return "IS NOT NULL", 0 - default: - logger.Errorf("todo") - return "= ?", 1 // TODO - } - } - - return "= ?", 1 // TODO -} - -func getIntCriterionWhereClause(column string, input IntCriterionInput) (string, int) { - binding, count := getCriterionModifierBinding(input.Modifier, input.Value) - return column + " " + binding, count -} - -// returns where clause and having clause -func getMultiCriterionClause(primaryTable, foreignTable, joinTable, primaryFK, foreignFK string, criterion *MultiCriterionInput) (string, string) { - whereClause := "" - havingClause := "" - if criterion.Modifier == CriterionModifierIncludes { - // includes any of the provided ids - whereClause = foreignTable + ".id IN " + getInBinding(len(criterion.Value)) - } else if criterion.Modifier == CriterionModifierIncludesAll { - // includes all of the provided ids - whereClause = foreignTable + ".id IN " + getInBinding(len(criterion.Value)) - havingClause = "count(distinct " + foreignTable + ".id) IS " + strconv.Itoa(len(criterion.Value)) - } else if criterion.Modifier == CriterionModifierExcludes { - // excludes all of the provided ids - if joinTable != "" { - whereClause = "not exists (select " + joinTable + "." + primaryFK + " from " + joinTable + " where " + joinTable + "." + primaryFK + " = " + primaryTable + ".id and " + joinTable + "." + foreignFK + " in " + getInBinding(len(criterion.Value)) + ")" - } else { - whereClause = "not exists (select s.id from " + primaryTable + " as s where s.id = " + primaryTable + ".id and s." 
+ foreignFK + " in " + getInBinding(len(criterion.Value)) + ")" - } - } - - return whereClause, havingClause -} - -func runIdsQuery(query string, args []interface{}) ([]int, error) { - var result []struct { - Int int `db:"id"` - } - if err := database.DB.Select(&result, query, args...); err != nil && err != sql.ErrNoRows { - return []int{}, err - } - - vsm := make([]int, len(result)) - for i, v := range result { - vsm[i] = v.Int - } - return vsm, nil -} - -func runCountQuery(query string, args []interface{}) (int, error) { - // Perform query and fetch result - result := struct { - Int int `db:"count"` - }{0} - if err := database.DB.Get(&result, query, args...); err != nil && err != sql.ErrNoRows { - return 0, err - } - - return result.Int, nil -} - -func runSumQuery(query string, args []interface{}) (float64, error) { - // Perform query and fetch result - result := struct { - Float64 float64 `db:"sum"` - }{0} - if err := database.DB.Get(&result, query, args...); err != nil && err != sql.ErrNoRows { - return 0, err - } - - return result.Float64, nil -} - -func executeFindQuery(tableName string, body string, args []interface{}, sortAndPagination string, whereClauses []string, havingClauses []string) ([]int, int) { - if len(whereClauses) > 0 { - body = body + " WHERE " + strings.Join(whereClauses, " AND ") // TODO handle AND or OR - } - body = body + " GROUP BY " + tableName + ".id " - if len(havingClauses) > 0 { - body = body + " HAVING " + strings.Join(havingClauses, " AND ") // TODO handle AND or OR - } - - countQuery := buildCountQuery(body) - idsQuery := body + sortAndPagination - - // Perform query and fetch result - logger.Tracef("SQL: %s, args: %v", idsQuery, args) - - countResult, countErr := runCountQuery(countQuery, args) - idsResult, idsErr := runIdsQuery(idsQuery, args) - - if countErr != nil { - logger.Errorf("Error executing count query with SQL: %s, args: %v, error: %s", countQuery, args, countErr.Error()) - panic(countErr) - } - if idsErr != nil { - logger.Errorf("Error executing find query with SQL: %s, args: %v, error: %s", idsQuery, args, idsErr.Error()) - panic(idsErr) - } - - return idsResult, countResult -} - -func executeDeleteQuery(tableName string, id string, tx *sqlx.Tx) error { - if tx == nil { - panic("must use a transaction") - } - idColumnName := getColumn(tableName, "id") - _, err := tx.Exec( - `DELETE FROM `+tableName+` WHERE `+idColumnName+` = ?`, - id, - ) - return err -} - -func ensureTx(tx *sqlx.Tx) { - if tx == nil { - panic("must use a transaction") - } -} - -// https://github.com/jmoiron/sqlx/issues/410 -// sqlGenKeys is used for passing a struct and returning a string -// of keys for non empty key:values. These keys are formated -// keyname=:keyname with a comma seperating them -func SQLGenKeys(i interface{}) string { - return sqlGenKeys(i, false) -} - -// support a partial interface. When a partial interface is provided, -// keys will always be included if the value is not null. 
The partial -// interface must therefore consist of pointers -func SQLGenKeysPartial(i interface{}) string { - return sqlGenKeys(i, true) -} - -func sqlGenKeys(i interface{}, partial bool) string { - var query []string - v := reflect.ValueOf(i) - for i := 0; i < v.NumField(); i++ { - //get key for struct tag - rawKey := v.Type().Field(i).Tag.Get("db") - key := strings.Split(rawKey, ",")[0] - if key == "id" { - continue - } - - var add bool - switch t := v.Field(i).Interface().(type) { - case string: - add = partial || t != "" - case int: - add = partial || t != 0 - case float64: - add = partial || t != 0 - case bool: - add = true - case SQLiteTimestamp: - add = partial || !t.Timestamp.IsZero() - case NullSQLiteTimestamp: - add = partial || t.Valid - case SQLiteDate: - add = partial || t.Valid - case sql.NullString: - add = partial || t.Valid - case sql.NullBool: - add = partial || t.Valid - case sql.NullInt64: - add = partial || t.Valid - case sql.NullFloat64: - add = partial || t.Valid - default: - reflectValue := reflect.ValueOf(t) - isNil := reflectValue.IsNil() - add = !isNil - } - - if add { - query = append(query, fmt.Sprintf("%s=:%s", key, key)) - } - } - return strings.Join(query, ", ") -} - -func getImage(tx *sqlx.Tx, query string, args ...interface{}) ([]byte, error) { - var rows *sqlx.Rows - var err error - if tx != nil { - rows, err = tx.Queryx(query, args...) - } else { - rows, err = database.DB.Queryx(query, args...) - } - - if err != nil && err != sql.ErrNoRows { - return nil, err - } - defer rows.Close() - - var ret []byte - if rows.Next() { - if err := rows.Scan(&ret); err != nil { - return nil, err - } - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return ret, nil -} diff --git a/pkg/models/querybuilder_studio.go b/pkg/models/querybuilder_studio.go deleted file mode 100644 index dcdeeec2f..000000000 --- a/pkg/models/querybuilder_studio.go +++ /dev/null @@ -1,307 +0,0 @@ -package models - -import ( - "database/sql" - "fmt" - - "github.com/jmoiron/sqlx" - "github.com/stashapp/stash/pkg/database" -) - -type StudioQueryBuilder struct{} - -func NewStudioQueryBuilder() StudioQueryBuilder { - return StudioQueryBuilder{} -} - -func (qb *StudioQueryBuilder) Create(newStudio Studio, tx *sqlx.Tx) (*Studio, error) { - ensureTx(tx) - result, err := tx.NamedExec( - `INSERT INTO studios (checksum, name, url, parent_id, created_at, updated_at) - VALUES (:checksum, :name, :url, :parent_id, :created_at, :updated_at) - `, - newStudio, - ) - if err != nil { - return nil, err - } - studioID, err := result.LastInsertId() - if err != nil { - return nil, err - } - - if err := tx.Get(&newStudio, `SELECT * FROM studios WHERE id = ? LIMIT 1`, studioID); err != nil { - return nil, err - } - return &newStudio, nil -} - -func (qb *StudioQueryBuilder) Update(updatedStudio StudioPartial, tx *sqlx.Tx) (*Studio, error) { - ensureTx(tx) - _, err := tx.NamedExec( - `UPDATE studios SET `+SQLGenKeysPartial(updatedStudio)+` WHERE studios.id = :id`, - updatedStudio, - ) - if err != nil { - return nil, err - } - - var ret Studio - if err := tx.Get(&ret, `SELECT * FROM studios WHERE id = ? 
LIMIT 1`, updatedStudio.ID); err != nil { - return nil, err - } - return &ret, nil -} - -func (qb *StudioQueryBuilder) UpdateFull(updatedStudio Studio, tx *sqlx.Tx) (*Studio, error) { - ensureTx(tx) - _, err := tx.NamedExec( - `UPDATE studios SET `+SQLGenKeys(updatedStudio)+` WHERE studios.id = :id`, - updatedStudio, - ) - if err != nil { - return nil, err - } - - var ret Studio - if err := tx.Get(&ret, `SELECT * FROM studios WHERE id = ? LIMIT 1`, updatedStudio.ID); err != nil { - return nil, err - } - return &ret, nil -} - -func (qb *StudioQueryBuilder) Destroy(id string, tx *sqlx.Tx) error { - // remove studio from scenes - _, err := tx.Exec("UPDATE scenes SET studio_id = null WHERE studio_id = ?", id) - if err != nil { - return err - } - - // remove studio from scraped items - _, err = tx.Exec("UPDATE scraped_items SET studio_id = null WHERE studio_id = ?", id) - if err != nil { - return err - } - - return executeDeleteQuery("studios", id, tx) -} - -func (qb *StudioQueryBuilder) Find(id int, tx *sqlx.Tx) (*Studio, error) { - query := "SELECT * FROM studios WHERE id = ? LIMIT 1" - args := []interface{}{id} - return qb.queryStudio(query, args, tx) -} - -func (qb *StudioQueryBuilder) FindMany(ids []int) ([]*Studio, error) { - var studios []*Studio - for _, id := range ids { - studio, err := qb.Find(id, nil) - if err != nil { - return nil, err - } - - if studio == nil { - return nil, fmt.Errorf("studio with id %d not found", id) - } - - studios = append(studios, studio) - } - - return studios, nil -} - -func (qb *StudioQueryBuilder) FindChildren(id int, tx *sqlx.Tx) ([]*Studio, error) { - query := "SELECT studios.* FROM studios WHERE studios.parent_id = ?" - args := []interface{}{id} - return qb.queryStudios(query, args, tx) -} - -func (qb *StudioQueryBuilder) FindBySceneID(sceneID int) (*Studio, error) { - query := "SELECT studios.* FROM studios JOIN scenes ON studios.id = scenes.studio_id WHERE scenes.id = ? LIMIT 1" - args := []interface{}{sceneID} - return qb.queryStudio(query, args, nil) -} - -func (qb *StudioQueryBuilder) FindByName(name string, tx *sqlx.Tx, nocase bool) (*Studio, error) { - query := "SELECT * FROM studios WHERE name = ?" - if nocase { - query += " COLLATE NOCASE" - } - query += " LIMIT 1" - args := []interface{}{name} - return qb.queryStudio(query, args, tx) -} - -func (qb *StudioQueryBuilder) Count() (int, error) { - return runCountQuery(buildCountQuery("SELECT studios.id FROM studios"), nil) -} - -func (qb *StudioQueryBuilder) All() ([]*Studio, error) { - return qb.queryStudios(selectAll("studios")+qb.getStudioSort(nil), nil, nil) -} - -func (qb *StudioQueryBuilder) AllSlim() ([]*Studio, error) { - return qb.queryStudios("SELECT studios.id, studios.name, studios.parent_id FROM studios "+qb.getStudioSort(nil), nil, nil) -} - -func (qb *StudioQueryBuilder) Query(studioFilter *StudioFilterType, findFilter *FindFilterType) ([]*Studio, int) { - if studioFilter == nil { - studioFilter = &StudioFilterType{} - } - if findFilter == nil { - findFilter = &FindFilterType{} - } - - var whereClauses []string - var havingClauses []string - var args []interface{} - body := selectDistinctIDs("studios") - body += ` - left join scenes on studios.id = scenes.studio_id - left join studio_stash_ids on studio_stash_ids.studio_id = studios.id - ` - - if q := findFilter.Q; q != nil && *q != "" { - searchColumns := []string{"studios.name"} - - clause, thisArgs := getSearchBinding(searchColumns, *q, false) - whereClauses = append(whereClauses, clause) - args = append(args, thisArgs...) 
- } - - if parentsFilter := studioFilter.Parents; parentsFilter != nil && len(parentsFilter.Value) > 0 { - body += ` - left join studios as parent_studio on parent_studio.id = studios.parent_id - ` - - for _, studioID := range parentsFilter.Value { - args = append(args, studioID) - } - - whereClause, havingClause := getMultiCriterionClause("studios", "parent_studio", "", "", "parent_id", parentsFilter) - whereClauses = appendClause(whereClauses, whereClause) - havingClauses = appendClause(havingClauses, havingClause) - } - - if stashIDFilter := studioFilter.StashID; stashIDFilter != nil { - whereClauses = append(whereClauses, "studio_stash_ids.stash_id = ?") - args = append(args, stashIDFilter) - } - - if isMissingFilter := studioFilter.IsMissing; isMissingFilter != nil && *isMissingFilter != "" { - switch *isMissingFilter { - case "image": - body += `left join studios_image on studios_image.studio_id = studios.id - ` - whereClauses = appendClause(whereClauses, "studios_image.studio_id IS NULL") - case "stash_id": - whereClauses = appendClause(whereClauses, "studio_stash_ids.studio_id IS NULL") - default: - whereClauses = appendClause(whereClauses, "studios."+*isMissingFilter+" IS NULL") - } - } - - sortAndPagination := qb.getStudioSort(findFilter) + getPagination(findFilter) - idsResult, countResult := executeFindQuery("studios", body, args, sortAndPagination, whereClauses, havingClauses) - - var studios []*Studio - for _, id := range idsResult { - studio, _ := qb.Find(id, nil) - studios = append(studios, studio) - } - - return studios, countResult -} - -func (qb *StudioQueryBuilder) getStudioSort(findFilter *FindFilterType) string { - var sort string - var direction string - if findFilter == nil { - sort = "name" - direction = "ASC" - } else { - sort = findFilter.GetSort("name") - direction = findFilter.GetDirection() - } - return getSort(sort, direction, "studios") -} - -func (qb *StudioQueryBuilder) queryStudio(query string, args []interface{}, tx *sqlx.Tx) (*Studio, error) { - results, err := qb.queryStudios(query, args, tx) - if err != nil || len(results) < 1 { - return nil, err - } - return results[0], nil -} - -func (qb *StudioQueryBuilder) queryStudios(query string, args []interface{}, tx *sqlx.Tx) ([]*Studio, error) { - var rows *sqlx.Rows - var err error - if tx != nil { - rows, err = tx.Queryx(query, args...) - } else { - rows, err = database.DB.Queryx(query, args...) 
- } - - if err != nil && err != sql.ErrNoRows { - return nil, err - } - defer rows.Close() - - studios := make([]*Studio, 0) - for rows.Next() { - studio := Studio{} - if err := rows.StructScan(&studio); err != nil { - return nil, err - } - studios = append(studios, &studio) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return studios, nil -} - -func (qb *StudioQueryBuilder) UpdateStudioImage(studioID int, image []byte, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing cover and then create new - if err := qb.DestroyStudioImage(studioID, tx); err != nil { - return err - } - - _, err := tx.Exec( - `INSERT INTO studios_image (studio_id, image) VALUES (?, ?)`, - studioID, - image, - ) - - return err -} - -func (qb *StudioQueryBuilder) DestroyStudioImage(studioID int, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing joins - _, err := tx.Exec("DELETE FROM studios_image WHERE studio_id = ?", studioID) - if err != nil { - return err - } - return err -} - -func (qb *StudioQueryBuilder) GetStudioImage(studioID int, tx *sqlx.Tx) ([]byte, error) { - query := `SELECT image from studios_image WHERE studio_id = ?` - return getImage(tx, query, studioID) -} - -func (qb *StudioQueryBuilder) HasStudioImage(studioID int) (bool, error) { - ret, err := runCountQuery(buildCountQuery("SELECT studio_id from studios_image WHERE studio_id = ?"), []interface{}{studioID}) - if err != nil { - return false, err - } - - return ret == 1, nil -} diff --git a/pkg/models/querybuilder_studio_test.go b/pkg/models/querybuilder_studio_test.go deleted file mode 100644 index ed192046a..000000000 --- a/pkg/models/querybuilder_studio_test.go +++ /dev/null @@ -1,311 +0,0 @@ -// +build integration - -package models_test - -import ( - "context" - "database/sql" - "strconv" - "strings" - "testing" - - "github.com/stashapp/stash/pkg/database" - "github.com/stashapp/stash/pkg/models" - "github.com/stretchr/testify/assert" -) - -func TestStudioFindByName(t *testing.T) { - - sqb := models.NewStudioQueryBuilder() - - name := studioNames[studioIdxWithScene] // find a studio by name - - studio, err := sqb.FindByName(name, nil, false) - - if err != nil { - t.Fatalf("Error finding studios: %s", err.Error()) - } - - assert.Equal(t, studioNames[studioIdxWithScene], studio.Name.String) - - name = studioNames[studioIdxWithDupName] // find a studio by name nocase - - studio, err = sqb.FindByName(name, nil, true) - - if err != nil { - t.Fatalf("Error finding studios: %s", err.Error()) - } - // studioIdxWithDupName and studioIdxWithScene should have similar names ( only diff should be Name vs NaMe) - //studio.Name should match with studioIdxWithScene since its ID is before studioIdxWithDupName - assert.Equal(t, studioNames[studioIdxWithScene], studio.Name.String) - //studio.Name should match with studioIdxWithDupName if the check is not case sensitive - assert.Equal(t, strings.ToLower(studioNames[studioIdxWithDupName]), strings.ToLower(studio.Name.String)) - -} - -func TestStudioQueryParent(t *testing.T) { - sqb := models.NewStudioQueryBuilder() - studioCriterion := models.MultiCriterionInput{ - Value: []string{ - strconv.Itoa(studioIDs[studioIdxWithChildStudio]), - }, - Modifier: models.CriterionModifierIncludes, - } - - studioFilter := models.StudioFilterType{ - Parents: &studioCriterion, - } - - studios, _ := sqb.Query(&studioFilter, nil) - - assert.Len(t, studios, 1) - - // ensure id is correct - assert.Equal(t, sceneIDs[studioIdxWithParentStudio], studios[0].ID) - - studioCriterion = 
models.MultiCriterionInput{ - Value: []string{ - strconv.Itoa(studioIDs[studioIdxWithChildStudio]), - }, - Modifier: models.CriterionModifierExcludes, - } - - q := getStudioStringValue(studioIdxWithParentStudio, titleField) - findFilter := models.FindFilterType{ - Q: &q, - } - - studios, _ = sqb.Query(&studioFilter, &findFilter) - assert.Len(t, studios, 0) -} - -func TestStudioDestroyParent(t *testing.T) { - const parentName = "parent" - const childName = "child" - - // create parent and child studios - ctx := context.TODO() - tx := database.DB.MustBeginTx(ctx, nil) - - createdParent, err := createStudio(tx, parentName, nil) - if err != nil { - tx.Rollback() - t.Fatalf("Error creating parent studio: %s", err.Error()) - } - - parentID := int64(createdParent.ID) - createdChild, err := createStudio(tx, childName, &parentID) - if err != nil { - tx.Rollback() - t.Fatalf("Error creating child studio: %s", err.Error()) - } - - if err := tx.Commit(); err != nil { - tx.Rollback() - t.Fatalf("Error committing: %s", err.Error()) - } - - sqb := models.NewStudioQueryBuilder() - - // destroy the parent - tx = database.DB.MustBeginTx(ctx, nil) - - err = sqb.Destroy(strconv.Itoa(createdParent.ID), tx) - if err != nil { - tx.Rollback() - t.Fatalf("Error destroying parent studio: %s", err.Error()) - } - - if err := tx.Commit(); err != nil { - tx.Rollback() - t.Fatalf("Error committing: %s", err.Error()) - } - - // destroy the child - tx = database.DB.MustBeginTx(ctx, nil) - - err = sqb.Destroy(strconv.Itoa(createdChild.ID), tx) - if err != nil { - tx.Rollback() - t.Fatalf("Error destroying child studio: %s", err.Error()) - } - - if err := tx.Commit(); err != nil { - tx.Rollback() - t.Fatalf("Error committing: %s", err.Error()) - } -} - -func TestStudioFindChildren(t *testing.T) { - sqb := models.NewStudioQueryBuilder() - - studios, err := sqb.FindChildren(studioIDs[studioIdxWithChildStudio], nil) - - if err != nil { - t.Fatalf("error calling FindChildren: %s", err.Error()) - } - - assert.Len(t, studios, 1) - assert.Equal(t, studioIDs[studioIdxWithParentStudio], studios[0].ID) - - studios, err = sqb.FindChildren(0, nil) - - if err != nil { - t.Fatalf("error calling FindChildren: %s", err.Error()) - } - - assert.Len(t, studios, 0) -} - -func TestStudioUpdateClearParent(t *testing.T) { - const parentName = "clearParent_parent" - const childName = "clearParent_child" - - // create parent and child studios - ctx := context.TODO() - tx := database.DB.MustBeginTx(ctx, nil) - - createdParent, err := createStudio(tx, parentName, nil) - if err != nil { - tx.Rollback() - t.Fatalf("Error creating parent studio: %s", err.Error()) - } - - parentID := int64(createdParent.ID) - createdChild, err := createStudio(tx, childName, &parentID) - if err != nil { - tx.Rollback() - t.Fatalf("Error creating child studio: %s", err.Error()) - } - - if err := tx.Commit(); err != nil { - tx.Rollback() - t.Fatalf("Error committing: %s", err.Error()) - } - - sqb := models.NewStudioQueryBuilder() - - // clear the parent id from the child - tx = database.DB.MustBeginTx(ctx, nil) - - updatePartial := models.StudioPartial{ - ID: createdChild.ID, - ParentID: &sql.NullInt64{Valid: false}, - } - - updatedStudio, err := sqb.Update(updatePartial, tx) - - if err != nil { - tx.Rollback() - t.Fatalf("Error updated studio: %s", err.Error()) - } - - if updatedStudio.ParentID.Valid { - t.Error("updated studio has parent ID set") - } - - if err := tx.Commit(); err != nil { - tx.Rollback() - t.Fatalf("Error committing: %s", err.Error()) - } -} - -func 
TestStudioUpdateStudioImage(t *testing.T) { - qb := models.NewStudioQueryBuilder() - - // create performer to test against - ctx := context.TODO() - tx := database.DB.MustBeginTx(ctx, nil) - - const name = "TestStudioUpdateStudioImage" - created, err := createStudio(tx, name, nil) - if err != nil { - tx.Rollback() - t.Fatalf("Error creating studio: %s", err.Error()) - } - - image := []byte("image") - err = qb.UpdateStudioImage(created.ID, image, tx) - if err != nil { - tx.Rollback() - t.Fatalf("Error updating studio image: %s", err.Error()) - } - - if err := tx.Commit(); err != nil { - tx.Rollback() - t.Fatalf("Error committing: %s", err.Error()) - } - - // ensure image set - storedImage, err := qb.GetStudioImage(created.ID, nil) - if err != nil { - t.Fatalf("Error getting image: %s", err.Error()) - } - assert.Equal(t, storedImage, image) - - // set nil image - tx = database.DB.MustBeginTx(ctx, nil) - err = qb.UpdateStudioImage(created.ID, nil, tx) - if err == nil { - t.Fatalf("Expected error setting nil image") - } - - tx.Rollback() -} - -func TestStudioDestroyStudioImage(t *testing.T) { - qb := models.NewStudioQueryBuilder() - - // create performer to test against - ctx := context.TODO() - tx := database.DB.MustBeginTx(ctx, nil) - - const name = "TestStudioDestroyStudioImage" - created, err := createStudio(tx, name, nil) - if err != nil { - tx.Rollback() - t.Fatalf("Error creating studio: %s", err.Error()) - } - - image := []byte("image") - err = qb.UpdateStudioImage(created.ID, image, tx) - if err != nil { - tx.Rollback() - t.Fatalf("Error updating studio image: %s", err.Error()) - } - - if err := tx.Commit(); err != nil { - tx.Rollback() - t.Fatalf("Error committing: %s", err.Error()) - } - - tx = database.DB.MustBeginTx(ctx, nil) - - err = qb.DestroyStudioImage(created.ID, tx) - if err != nil { - tx.Rollback() - t.Fatalf("Error destroying studio image: %s", err.Error()) - } - - if err := tx.Commit(); err != nil { - tx.Rollback() - t.Fatalf("Error committing: %s", err.Error()) - } - - // image should be nil - storedImage, err := qb.GetStudioImage(created.ID, nil) - if err != nil { - t.Fatalf("Error getting image: %s", err.Error()) - } - assert.Nil(t, storedImage) -} - -// TODO Create -// TODO Update -// TODO Destroy -// TODO Find -// TODO FindBySceneID -// TODO Count -// TODO All -// TODO AllSlim -// TODO Query diff --git a/pkg/models/querybuilder_tag_test.go b/pkg/models/querybuilder_tag_test.go deleted file mode 100644 index 02b431ce7..000000000 --- a/pkg/models/querybuilder_tag_test.go +++ /dev/null @@ -1,314 +0,0 @@ -// +build integration - -package models_test - -import ( - "context" - "database/sql" - "strings" - "testing" - - "github.com/stashapp/stash/pkg/database" - "github.com/stashapp/stash/pkg/models" - "github.com/stretchr/testify/assert" -) - -func TestMarkerFindBySceneMarkerID(t *testing.T) { - tqb := models.NewTagQueryBuilder() - - markerID := markerIDs[markerIdxWithScene] - - tags, err := tqb.FindBySceneMarkerID(markerID, nil) - - if err != nil { - t.Fatalf("Error finding tags: %s", err.Error()) - } - - assert.Len(t, tags, 1) - assert.Equal(t, tagIDs[tagIdxWithMarker], tags[0].ID) - - tags, err = tqb.FindBySceneMarkerID(0, nil) - - if err != nil { - t.Fatalf("Error finding tags: %s", err.Error()) - } - - assert.Len(t, tags, 0) -} - -func TestTagFindByName(t *testing.T) { - - tqb := models.NewTagQueryBuilder() - - name := tagNames[tagIdxWithScene] // find a tag by name - - tag, err := tqb.FindByName(name, nil, false) - - if err != nil { - t.Fatalf("Error finding tags: 
%s", err.Error()) - } - - assert.Equal(t, tagNames[tagIdxWithScene], tag.Name) - - name = tagNames[tagIdxWithDupName] // find a tag by name nocase - - tag, err = tqb.FindByName(name, nil, true) - - if err != nil { - t.Fatalf("Error finding tags: %s", err.Error()) - } - // tagIdxWithDupName and tagIdxWithScene should have similar names ( only diff should be Name vs NaMe) - //tag.Name should match with tagIdxWithScene since its ID is before tagIdxWithDupName - assert.Equal(t, tagNames[tagIdxWithScene], tag.Name) - //tag.Name should match with tagIdxWithDupName if the check is not case sensitive - assert.Equal(t, strings.ToLower(tagNames[tagIdxWithDupName]), strings.ToLower(tag.Name)) - -} - -func TestTagFindByNames(t *testing.T) { - var names []string - - tqb := models.NewTagQueryBuilder() - - names = append(names, tagNames[tagIdxWithScene]) // find tags by names - - tags, err := tqb.FindByNames(names, nil, false) - if err != nil { - t.Fatalf("Error finding tags: %s", err.Error()) - } - assert.Len(t, tags, 1) - assert.Equal(t, tagNames[tagIdxWithScene], tags[0].Name) - - tags, err = tqb.FindByNames(names, nil, true) // find tags by names nocase - if err != nil { - t.Fatalf("Error finding tags: %s", err.Error()) - } - assert.Len(t, tags, 2) // tagIdxWithScene and tagIdxWithDupName - assert.Equal(t, strings.ToLower(tagNames[tagIdxWithScene]), strings.ToLower(tags[0].Name)) - assert.Equal(t, strings.ToLower(tagNames[tagIdxWithScene]), strings.ToLower(tags[1].Name)) - - names = append(names, tagNames[tagIdx1WithScene]) // find tags by names ( 2 names ) - - tags, err = tqb.FindByNames(names, nil, false) - if err != nil { - t.Fatalf("Error finding tags: %s", err.Error()) - } - assert.Len(t, tags, 2) // tagIdxWithScene and tagIdx1WithScene - assert.Equal(t, tagNames[tagIdxWithScene], tags[0].Name) - assert.Equal(t, tagNames[tagIdx1WithScene], tags[1].Name) - - tags, err = tqb.FindByNames(names, nil, true) // find tags by names ( 2 names nocase) - if err != nil { - t.Fatalf("Error finding tags: %s", err.Error()) - } - assert.Len(t, tags, 4) // tagIdxWithScene and tagIdxWithDupName , tagIdx1WithScene and tagIdx1WithDupName - assert.Equal(t, tagNames[tagIdxWithScene], tags[0].Name) - assert.Equal(t, tagNames[tagIdx1WithScene], tags[1].Name) - assert.Equal(t, tagNames[tagIdx1WithDupName], tags[2].Name) - assert.Equal(t, tagNames[tagIdxWithDupName], tags[3].Name) - -} - -func TestTagQueryIsMissingImage(t *testing.T) { - qb := models.NewTagQueryBuilder() - isMissing := "image" - tagFilter := models.TagFilterType{ - IsMissing: &isMissing, - } - - q := getTagStringValue(tagIdxWithCoverImage, "name") - findFilter := models.FindFilterType{ - Q: &q, - } - - tags, _ := qb.Query(&tagFilter, &findFilter) - - assert.Len(t, tags, 0) - - findFilter.Q = nil - tags, _ = qb.Query(&tagFilter, &findFilter) - - // ensure non of the ids equal the one with image - for _, tag := range tags { - assert.NotEqual(t, tagIDs[tagIdxWithCoverImage], tag.ID) - } -} - -func TestTagQuerySceneCount(t *testing.T) { - countCriterion := models.IntCriterionInput{ - Value: 1, - Modifier: models.CriterionModifierEquals, - } - - verifyTagSceneCount(t, countCriterion) - - countCriterion.Modifier = models.CriterionModifierNotEquals - verifyTagSceneCount(t, countCriterion) - - countCriterion.Modifier = models.CriterionModifierLessThan - verifyTagSceneCount(t, countCriterion) - - countCriterion.Value = 0 - countCriterion.Modifier = models.CriterionModifierGreaterThan - verifyTagSceneCount(t, countCriterion) -} - -func verifyTagSceneCount(t 
*testing.T, sceneCountCriterion models.IntCriterionInput) { - qb := models.NewTagQueryBuilder() - tagFilter := models.TagFilterType{ - SceneCount: &sceneCountCriterion, - } - - tags, _ := qb.Query(&tagFilter, nil) - - for _, tag := range tags { - verifyInt64(t, sql.NullInt64{ - Int64: int64(getTagSceneCount(tag.ID)), - Valid: true, - }, sceneCountCriterion) - } -} - -// disabled due to performance issues - -// func TestTagQueryMarkerCount(t *testing.T) { -// countCriterion := models.IntCriterionInput{ -// Value: 1, -// Modifier: models.CriterionModifierEquals, -// } - -// verifyTagMarkerCount(t, countCriterion) - -// countCriterion.Modifier = models.CriterionModifierNotEquals -// verifyTagMarkerCount(t, countCriterion) - -// countCriterion.Modifier = models.CriterionModifierLessThan -// verifyTagMarkerCount(t, countCriterion) - -// countCriterion.Value = 0 -// countCriterion.Modifier = models.CriterionModifierGreaterThan -// verifyTagMarkerCount(t, countCriterion) -// } - -func verifyTagMarkerCount(t *testing.T, markerCountCriterion models.IntCriterionInput) { - qb := models.NewTagQueryBuilder() - tagFilter := models.TagFilterType{ - MarkerCount: &markerCountCriterion, - } - - tags, _ := qb.Query(&tagFilter, nil) - - for _, tag := range tags { - verifyInt64(t, sql.NullInt64{ - Int64: int64(getTagMarkerCount(tag.ID)), - Valid: true, - }, markerCountCriterion) - } -} - -func TestTagUpdateTagImage(t *testing.T) { - qb := models.NewTagQueryBuilder() - - // create tag to test against - ctx := context.TODO() - tx := database.DB.MustBeginTx(ctx, nil) - - const name = "TestTagUpdateTagImage" - tag := models.Tag{ - Name: name, - } - created, err := qb.Create(tag, tx) - if err != nil { - tx.Rollback() - t.Fatalf("Error creating tag: %s", err.Error()) - } - - image := []byte("image") - err = qb.UpdateTagImage(created.ID, image, tx) - if err != nil { - tx.Rollback() - t.Fatalf("Error updating tag image: %s", err.Error()) - } - - if err := tx.Commit(); err != nil { - tx.Rollback() - t.Fatalf("Error committing: %s", err.Error()) - } - - // ensure image set - storedImage, err := qb.GetTagImage(created.ID, nil) - if err != nil { - t.Fatalf("Error getting image: %s", err.Error()) - } - assert.Equal(t, storedImage, image) - - // set nil image - tx = database.DB.MustBeginTx(ctx, nil) - err = qb.UpdateTagImage(created.ID, nil, tx) - if err == nil { - t.Fatalf("Expected error setting nil image") - } - - tx.Rollback() -} - -func TestTagDestroyTagImage(t *testing.T) { - qb := models.NewTagQueryBuilder() - - // create tag to test against - ctx := context.TODO() - tx := database.DB.MustBeginTx(ctx, nil) - - const name = "TestTagDestroyTagImage" - tag := models.Tag{ - Name: name, - } - created, err := qb.Create(tag, tx) - if err != nil { - tx.Rollback() - t.Fatalf("Error creating tag: %s", err.Error()) - } - - image := []byte("image") - err = qb.UpdateTagImage(created.ID, image, tx) - if err != nil { - tx.Rollback() - t.Fatalf("Error updating tag image: %s", err.Error()) - } - - if err := tx.Commit(); err != nil { - tx.Rollback() - t.Fatalf("Error committing: %s", err.Error()) - } - - tx = database.DB.MustBeginTx(ctx, nil) - - err = qb.DestroyTagImage(created.ID, tx) - if err != nil { - tx.Rollback() - t.Fatalf("Error destroying tag image: %s", err.Error()) - } - - if err := tx.Commit(); err != nil { - tx.Rollback() - t.Fatalf("Error committing: %s", err.Error()) - } - - // image should be nil - storedImage, err := qb.GetTagImage(created.ID, nil) - if err != nil { - t.Fatalf("Error getting image: 
%s", err.Error()) - } - assert.Nil(t, storedImage) -} - -// TODO Create -// TODO Update -// TODO Destroy -// TODO Find -// TODO FindBySceneID -// TODO FindBySceneMarkerID -// TODO Count -// TODO All -// TODO AllSlim -// TODO Query diff --git a/pkg/models/repository.go b/pkg/models/repository.go new file mode 100644 index 000000000..dfed77453 --- /dev/null +++ b/pkg/models/repository.go @@ -0,0 +1,25 @@ +package models + +type Repository interface { + Gallery() GalleryReaderWriter + Image() ImageReaderWriter + Movie() MovieReaderWriter + Performer() PerformerReaderWriter + Scene() SceneReaderWriter + SceneMarker() SceneMarkerReaderWriter + ScrapedItem() ScrapedItemReaderWriter + Studio() StudioReaderWriter + Tag() TagReaderWriter +} + +type ReaderRepository interface { + Gallery() GalleryReader + Image() ImageReader + Movie() MovieReader + Performer() PerformerReader + Scene() SceneReader + SceneMarker() SceneMarkerReader + ScrapedItem() ScrapedItemReader + Studio() StudioReader + Tag() TagReader +} diff --git a/pkg/models/scene.go b/pkg/models/scene.go index 2580506b6..e86671e3b 100644 --- a/pkg/models/scene.go +++ b/pkg/models/scene.go @@ -1,102 +1,53 @@ package models -import ( - "github.com/jmoiron/sqlx" -) - type SceneReader interface { - // Find(id int) (*Scene, error) + Find(id int) (*Scene, error) FindMany(ids []int) ([]*Scene, error) FindByChecksum(checksum string) (*Scene, error) FindByOSHash(oshash string) (*Scene, error) - // FindByPath(path string) (*Scene, error) - // FindByPerformerID(performerID int) ([]*Scene, error) - // CountByPerformerID(performerID int) (int, error) + FindByPath(path string) (*Scene, error) + FindByPerformerID(performerID int) ([]*Scene, error) + CountByPerformerID(performerID int) (int, error) // FindByStudioID(studioID int) ([]*Scene, error) FindByMovieID(movieID int) ([]*Scene, error) - // CountByMovieID(movieID int) (int, error) - // Count() (int, error) + CountByMovieID(movieID int) (int, error) + Count() (int, error) + Size() (float64, error) // SizeCount() (string, error) - // CountByStudioID(studioID int) (int, error) - // CountByTagID(tagID int) (int, error) - // CountMissingChecksum() (int, error) - // CountMissingOSHash() (int, error) - // Wall(q *string) ([]*Scene, error) + CountByStudioID(studioID int) (int, error) + CountByTagID(tagID int) (int, error) + CountMissingChecksum() (int, error) + CountMissingOSHash() (int, error) + Wall(q *string) ([]*Scene, error) All() ([]*Scene, error) - // Query(sceneFilter *SceneFilterType, findFilter *FindFilterType) ([]*Scene, int) - // QueryAllByPathRegex(regex string) ([]*Scene, error) - // QueryByPathRegex(findFilter *FindFilterType) ([]*Scene, int) - GetSceneCover(sceneID int) ([]byte, error) + Query(sceneFilter *SceneFilterType, findFilter *FindFilterType) ([]*Scene, int, error) + QueryAllByPathRegex(regex string, ignoreOrganized bool) ([]*Scene, error) + QueryByPathRegex(findFilter *FindFilterType) ([]*Scene, int, error) + GetCover(sceneID int) ([]byte, error) + GetMovies(sceneID int) ([]MoviesScenes, error) + GetTagIDs(imageID int) ([]int, error) + GetPerformerIDs(imageID int) ([]int, error) + GetStashIDs(performerID int) ([]*StashID, error) } type SceneWriter interface { Create(newScene Scene) (*Scene, error) Update(updatedScene ScenePartial) (*Scene, error) UpdateFull(updatedScene Scene) (*Scene, error) - // IncrementOCounter(id int) (int, error) - // DecrementOCounter(id int) (int, error) - // ResetOCounter(id int) (int, error) - // Destroy(id string) error - // UpdateFormat(id int, format 
string) error - // UpdateOSHash(id int, oshash string) error - // UpdateChecksum(id int, checksum string) error - UpdateSceneCover(sceneID int, cover []byte) error - // DestroySceneCover(sceneID int) error + IncrementOCounter(id int) (int, error) + DecrementOCounter(id int) (int, error) + ResetOCounter(id int) (int, error) + UpdateFileModTime(id int, modTime NullSQLiteTimestamp) error + Destroy(id int) error + UpdateCover(sceneID int, cover []byte) error + DestroyCover(sceneID int) error + UpdatePerformers(sceneID int, performerIDs []int) error + UpdateTags(sceneID int, tagIDs []int) error + UpdateMovies(sceneID int, movies []MoviesScenes) error + UpdateStashIDs(sceneID int, stashIDs []StashID) error } type SceneReaderWriter interface { SceneReader SceneWriter } - -func NewSceneReaderWriter(tx *sqlx.Tx) SceneReaderWriter { - return &sceneReaderWriter{ - tx: tx, - qb: NewSceneQueryBuilder(), - } -} - -type sceneReaderWriter struct { - tx *sqlx.Tx - qb SceneQueryBuilder -} - -func (t *sceneReaderWriter) FindMany(ids []int) ([]*Scene, error) { - return t.qb.FindMany(ids) -} - -func (t *sceneReaderWriter) FindByChecksum(checksum string) (*Scene, error) { - return t.qb.FindByChecksum(checksum) -} - -func (t *sceneReaderWriter) FindByOSHash(oshash string) (*Scene, error) { - return t.qb.FindByOSHash(oshash) -} - -func (t *sceneReaderWriter) FindByMovieID(movieID int) ([]*Scene, error) { - return t.qb.FindByMovieID(movieID) -} - -func (t *sceneReaderWriter) All() ([]*Scene, error) { - return t.qb.All() -} - -func (t *sceneReaderWriter) GetSceneCover(sceneID int) ([]byte, error) { - return t.qb.GetSceneCover(sceneID, t.tx) -} - -func (t *sceneReaderWriter) Create(newScene Scene) (*Scene, error) { - return t.qb.Create(newScene, t.tx) -} - -func (t *sceneReaderWriter) Update(updatedScene ScenePartial) (*Scene, error) { - return t.qb.Update(updatedScene, t.tx) -} - -func (t *sceneReaderWriter) UpdateFull(updatedScene Scene) (*Scene, error) { - return t.qb.UpdateFull(updatedScene, t.tx) -} - -func (t *sceneReaderWriter) UpdateSceneCover(sceneID int, cover []byte) error { - return t.qb.UpdateSceneCover(sceneID, cover, t.tx) -} diff --git a/pkg/models/scene_marker.go b/pkg/models/scene_marker.go index 530e00684..49f169098 100644 --- a/pkg/models/scene_marker.go +++ b/pkg/models/scene_marker.go @@ -1,50 +1,24 @@ package models -import ( - "github.com/jmoiron/sqlx" -) - type SceneMarkerReader interface { - // Find(id int) (*SceneMarker, error) - // FindMany(ids []int) ([]*SceneMarker, error) + Find(id int) (*SceneMarker, error) + FindMany(ids []int) ([]*SceneMarker, error) FindBySceneID(sceneID int) ([]*SceneMarker, error) - // CountByTagID(tagID int) (int, error) - // GetMarkerStrings(q *string, sort *string) ([]*MarkerStringsResultType, error) - // Wall(q *string) ([]*SceneMarker, error) - // Query(sceneMarkerFilter *SceneMarkerFilterType, findFilter *FindFilterType) ([]*SceneMarker, int) + CountByTagID(tagID int) (int, error) + GetMarkerStrings(q *string, sort *string) ([]*MarkerStringsResultType, error) + Wall(q *string) ([]*SceneMarker, error) + Query(sceneMarkerFilter *SceneMarkerFilterType, findFilter *FindFilterType) ([]*SceneMarker, int, error) + GetTagIDs(imageID int) ([]int, error) } type SceneMarkerWriter interface { Create(newSceneMarker SceneMarker) (*SceneMarker, error) Update(updatedSceneMarker SceneMarker) (*SceneMarker, error) - // Destroy(id string) error + Destroy(id int) error + UpdateTags(markerID int, tagIDs []int) error } type SceneMarkerReaderWriter interface { SceneMarkerReader 
SceneMarkerWriter } - -func NewSceneMarkerReaderWriter(tx *sqlx.Tx) SceneMarkerReaderWriter { - return &sceneMarkerReaderWriter{ - tx: tx, - qb: NewSceneMarkerQueryBuilder(), - } -} - -type sceneMarkerReaderWriter struct { - tx *sqlx.Tx - qb SceneMarkerQueryBuilder -} - -func (t *sceneMarkerReaderWriter) FindBySceneID(sceneID int) ([]*SceneMarker, error) { - return t.qb.FindBySceneID(sceneID, t.tx) -} - -func (t *sceneMarkerReaderWriter) Create(newSceneMarker SceneMarker) (*SceneMarker, error) { - return t.qb.Create(newSceneMarker, t.tx) -} - -func (t *sceneMarkerReaderWriter) Update(updatedSceneMarker SceneMarker) (*SceneMarker, error) { - return t.qb.Update(updatedSceneMarker, t.tx) -} diff --git a/pkg/models/scraped.go b/pkg/models/scraped.go index c2167e902..ecbddf68a 100644 --- a/pkg/models/scraped.go +++ b/pkg/models/scraped.go @@ -1,87 +1,14 @@ package models -import "strconv" - -// MatchScrapedScenePerformer matches the provided performer with the -// performers in the database and sets the ID field if one is found. -func MatchScrapedScenePerformer(p *ScrapedScenePerformer) error { - qb := NewPerformerQueryBuilder() - - performers, err := qb.FindByNames([]string{p.Name}, nil, true) - - if err != nil { - return err - } - - if len(performers) != 1 { - // ignore - cannot match - return nil - } - - id := strconv.Itoa(performers[0].ID) - p.ID = &id - return nil +type ScrapedItemReader interface { + All() ([]*ScrapedItem, error) } -// MatchScrapedSceneStudio matches the provided studio with the studios -// in the database and sets the ID field if one is found. -func MatchScrapedSceneStudio(s *ScrapedSceneStudio) error { - qb := NewStudioQueryBuilder() - - studio, err := qb.FindByName(s.Name, nil, true) - - if err != nil { - return err - } - - if studio == nil { - // ignore - cannot match - return nil - } - - id := strconv.Itoa(studio.ID) - s.ID = &id - return nil +type ScrapedItemWriter interface { + Create(newObject ScrapedItem) (*ScrapedItem, error) } -// MatchScrapedSceneMovie matches the provided movie with the movies -// in the database and sets the ID field if one is found. -func MatchScrapedSceneMovie(m *ScrapedSceneMovie) error { - qb := NewMovieQueryBuilder() - - movies, err := qb.FindByNames([]string{m.Name}, nil, true) - - if err != nil { - return err - } - - if len(movies) != 1 { - // ignore - cannot match - return nil - } - - id := strconv.Itoa(movies[0].ID) - m.ID = &id - return nil -} - -// MatchScrapedSceneTag matches the provided tag with the tags -// in the database and sets the ID field if one is found. 
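The interfaces above drop the sqlx.Tx-backed wrappers in favor of plain reader/writer interfaces handed out by Repository and ReaderRepository. A minimal sketch of a helper written purely against ReaderRepository; the function name and shape are illustrative assumptions, not part of this change:

```go
package example // illustrative only

import (
	"github.com/stashapp/stash/pkg/models"
)

// sceneWithTags loads a scene and its tags through whatever readers the
// enclosing transaction provides; it has no knowledge of sqlx or SQL.
func sceneWithTags(repo models.ReaderRepository, sceneID int) (*models.Scene, []*models.Tag, error) {
	scene, err := repo.Scene().Find(sceneID)
	if err != nil || scene == nil {
		return nil, nil, err
	}

	tags, err := repo.Tag().FindBySceneID(sceneID)
	if err != nil {
		return nil, nil, err
	}

	return scene, tags, nil
}
```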
-func MatchScrapedSceneTag(s *ScrapedSceneTag) error { - qb := NewTagQueryBuilder() - - tag, err := qb.FindByName(s.Name, nil, true) - - if err != nil { - return err - } - - if tag == nil { - // ignore - cannot match - return nil - } - - id := strconv.Itoa(tag.ID) - s.ID = &id - return nil +type ScrapedItemReaderWriter interface { + ScrapedItemReader + ScrapedItemWriter } diff --git a/pkg/models/modelstest/sql.go b/pkg/models/sql.go similarity index 65% rename from pkg/models/modelstest/sql.go rename to pkg/models/sql.go index dbc34bab5..af7f8374e 100644 --- a/pkg/models/modelstest/sql.go +++ b/pkg/models/sql.go @@ -1,4 +1,4 @@ -package modelstest +package models import "database/sql" @@ -14,11 +14,4 @@ func NullInt64(v int64) sql.NullInt64 { Int64: v, Valid: true, } -} - -func NullBool(v bool) sql.NullBool { - return sql.NullBool{ - Bool: v, - Valid: true, - } -} +} \ No newline at end of file diff --git a/pkg/models/stash_ids.go b/pkg/models/stash_ids.go new file mode 100644 index 000000000..fb20522a5 --- /dev/null +++ b/pkg/models/stash_ids.go @@ -0,0 +1,14 @@ +package models + +func StashIDsFromInput(i []*StashIDInput) []StashID { + var ret []StashID + for _, stashID := range i { + newJoin := StashID{ + StashID: stashID.StashID, + Endpoint: stashID.Endpoint, + } + ret = append(ret, newJoin) + } + + return ret +} diff --git a/pkg/models/studio.go b/pkg/models/studio.go index 8916ad3a4..6addf8cd1 100644 --- a/pkg/models/studio.go +++ b/pkg/models/studio.go @@ -1,80 +1,30 @@ package models -import ( - "github.com/jmoiron/sqlx" -) - type StudioReader interface { Find(id int) (*Studio, error) FindMany(ids []int) ([]*Studio, error) - // FindChildren(id int) ([]*Studio, error) - // FindBySceneID(sceneID int) (*Studio, error) + FindChildren(id int) ([]*Studio, error) FindByName(name string, nocase bool) (*Studio, error) - // Count() (int, error) + Count() (int, error) All() ([]*Studio, error) - // AllSlim() ([]*Studio, error) - // Query(studioFilter *StudioFilterType, findFilter *FindFilterType) ([]*Studio, int) - GetStudioImage(studioID int) ([]byte, error) + AllSlim() ([]*Studio, error) + Query(studioFilter *StudioFilterType, findFilter *FindFilterType) ([]*Studio, int, error) + GetImage(studioID int) ([]byte, error) + HasImage(studioID int) (bool, error) + GetStashIDs(studioID int) ([]*StashID, error) } type StudioWriter interface { Create(newStudio Studio) (*Studio, error) Update(updatedStudio StudioPartial) (*Studio, error) UpdateFull(updatedStudio Studio) (*Studio, error) - // Destroy(id string) error - UpdateStudioImage(studioID int, image []byte) error - // DestroyStudioImage(studioID int) error + Destroy(id int) error + UpdateImage(studioID int, image []byte) error + DestroyImage(studioID int) error + UpdateStashIDs(studioID int, stashIDs []StashID) error } type StudioReaderWriter interface { StudioReader StudioWriter } - -func NewStudioReaderWriter(tx *sqlx.Tx) StudioReaderWriter { - return &studioReaderWriter{ - tx: tx, - qb: NewStudioQueryBuilder(), - } -} - -type studioReaderWriter struct { - tx *sqlx.Tx - qb StudioQueryBuilder -} - -func (t *studioReaderWriter) Find(id int) (*Studio, error) { - return t.qb.Find(id, t.tx) -} - -func (t *studioReaderWriter) FindMany(ids []int) ([]*Studio, error) { - return t.qb.FindMany(ids) -} - -func (t *studioReaderWriter) FindByName(name string, nocase bool) (*Studio, error) { - return t.qb.FindByName(name, t.tx, nocase) -} - -func (t *studioReaderWriter) All() ([]*Studio, error) { - return t.qb.All() -} - -func (t *studioReaderWriter) 
GetStudioImage(studioID int) ([]byte, error) { - return t.qb.GetStudioImage(studioID, t.tx) -} - -func (t *studioReaderWriter) Create(newStudio Studio) (*Studio, error) { - return t.qb.Create(newStudio, t.tx) -} - -func (t *studioReaderWriter) Update(updatedStudio StudioPartial) (*Studio, error) { - return t.qb.Update(updatedStudio, t.tx) -} - -func (t *studioReaderWriter) UpdateFull(updatedStudio Studio) (*Studio, error) { - return t.qb.UpdateFull(updatedStudio, t.tx) -} - -func (t *studioReaderWriter) UpdateStudioImage(studioID int, image []byte) error { - return t.qb.UpdateStudioImage(studioID, image, t.tx) -} diff --git a/pkg/models/tag.go b/pkg/models/tag.go index 5816f60b7..7bdd8a197 100644 --- a/pkg/models/tag.go +++ b/pkg/models/tag.go @@ -1,9 +1,5 @@ package models -import ( - "github.com/jmoiron/sqlx" -) - type TagReader interface { Find(id int) (*Tag, error) FindMany(ids []int) ([]*Tag, error) @@ -13,86 +9,22 @@ type TagReader interface { FindByGalleryID(galleryID int) ([]*Tag, error) FindByName(name string, nocase bool) (*Tag, error) FindByNames(names []string, nocase bool) ([]*Tag, error) - // Count() (int, error) + Count() (int, error) All() ([]*Tag, error) - // AllSlim() ([]*Tag, error) - // Query(tagFilter *TagFilterType, findFilter *FindFilterType) ([]*Tag, int, error) - GetTagImage(tagID int) ([]byte, error) + AllSlim() ([]*Tag, error) + Query(tagFilter *TagFilterType, findFilter *FindFilterType) ([]*Tag, int, error) + GetImage(tagID int) ([]byte, error) } type TagWriter interface { Create(newTag Tag) (*Tag, error) Update(updatedTag Tag) (*Tag, error) - // Destroy(id string) error - UpdateTagImage(tagID int, image []byte) error - // DestroyTagImage(tagID int) error + Destroy(id int) error + UpdateImage(tagID int, image []byte) error + DestroyImage(tagID int) error } type TagReaderWriter interface { TagReader TagWriter } - -func NewTagReaderWriter(tx *sqlx.Tx) TagReaderWriter { - return &tagReaderWriter{ - tx: tx, - qb: NewTagQueryBuilder(), - } -} - -type tagReaderWriter struct { - tx *sqlx.Tx - qb TagQueryBuilder -} - -func (t *tagReaderWriter) Find(id int) (*Tag, error) { - return t.qb.Find(id, t.tx) -} - -func (t *tagReaderWriter) FindMany(ids []int) ([]*Tag, error) { - return t.qb.FindMany(ids) -} - -func (t *tagReaderWriter) All() ([]*Tag, error) { - return t.qb.All() -} - -func (t *tagReaderWriter) FindBySceneMarkerID(sceneMarkerID int) ([]*Tag, error) { - return t.qb.FindBySceneMarkerID(sceneMarkerID, t.tx) -} - -func (t *tagReaderWriter) FindByName(name string, nocase bool) (*Tag, error) { - return t.qb.FindByName(name, t.tx, nocase) -} - -func (t *tagReaderWriter) FindByNames(names []string, nocase bool) ([]*Tag, error) { - return t.qb.FindByNames(names, t.tx, nocase) -} - -func (t *tagReaderWriter) GetTagImage(tagID int) ([]byte, error) { - return t.qb.GetTagImage(tagID, t.tx) -} - -func (t *tagReaderWriter) FindBySceneID(sceneID int) ([]*Tag, error) { - return t.qb.FindBySceneID(sceneID, t.tx) -} - -func (t *tagReaderWriter) FindByImageID(imageID int) ([]*Tag, error) { - return t.qb.FindByImageID(imageID, t.tx) -} - -func (t *tagReaderWriter) FindByGalleryID(imageID int) ([]*Tag, error) { - return t.qb.FindByGalleryID(imageID, t.tx) -} - -func (t *tagReaderWriter) Create(newTag Tag) (*Tag, error) { - return t.qb.Create(newTag, t.tx) -} - -func (t *tagReaderWriter) Update(updatedTag Tag) (*Tag, error) { - return t.qb.Update(updatedTag, t.tx) -} - -func (t *tagReaderWriter) UpdateTagImage(tagID int, image []byte) error { - return t.qb.UpdateTagImage(tagID, image, 
t.tx) -} diff --git a/pkg/models/transaction.go b/pkg/models/transaction.go new file mode 100644 index 000000000..525de5a99 --- /dev/null +++ b/pkg/models/transaction.go @@ -0,0 +1,70 @@ +package models + +import "context" + +type Transaction interface { + Begin() error + Rollback() error + Commit() error + Repository() Repository +} + +type ReadTransaction interface { + Begin() error + Rollback() error + Commit() error + Repository() ReaderRepository +} + +type TransactionManager interface { + WithTxn(ctx context.Context, fn func(r Repository) error) error + WithReadTxn(ctx context.Context, fn func(r ReaderRepository) error) error +} + +func WithTxn(txn Transaction, fn func(r Repository) error) error { + err := txn.Begin() + if err != nil { + return err + } + + defer func() { + if p := recover(); p != nil { + // a panic occurred, rollback and repanic + txn.Rollback() + panic(p) + } else if err != nil { + // something went wrong, rollback + txn.Rollback() + } else { + // all good, commit + err = txn.Commit() + } + }() + + err = fn(txn.Repository()) + return err +} + +func WithROTxn(txn ReadTransaction, fn func(r ReaderRepository) error) error { + err := txn.Begin() + if err != nil { + return err + } + + defer func() { + if p := recover(); p != nil { + // a panic occurred, rollback and repanic + txn.Rollback() + panic(p) + } else if err != nil { + // something went wrong, rollback + txn.Rollback() + } else { + // all good, commit + err = txn.Commit() + } + }() + + err = fn(txn.Repository()) + return err +} diff --git a/pkg/movie/export_test.go b/pkg/movie/export_test.go index c5298dc82..b889a0428 100644 --- a/pkg/movie/export_test.go +++ b/pkg/movie/export_test.go @@ -7,7 +7,6 @@ import ( "github.com/stashapp/stash/pkg/manager/jsonschema" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models/mocks" - "github.com/stashapp/stash/pkg/models/modelstest" "github.com/stretchr/testify/assert" "testing" @@ -52,7 +51,7 @@ var frontImageBytes = []byte("frontImageBytes") var backImageBytes = []byte("backImageBytes") var studio models.Studio = models.Studio{ - Name: modelstest.NullString(studioName), + Name: models.NullString(studioName), } var createTime time.Time = time.Date(2001, 01, 01, 0, 0, 0, 0, time.UTC) @@ -61,8 +60,8 @@ var updateTime time.Time = time.Date(2002, 01, 01, 0, 0, 0, 0, time.UTC) func createFullMovie(id int, studioID int) models.Movie { return models.Movie{ ID: id, - Name: modelstest.NullString(movieName), - Aliases: modelstest.NullString(movieAliases), + Name: models.NullString(movieName), + Aliases: models.NullString(movieAliases), Date: date, Rating: sql.NullInt64{ Int64: rating, @@ -72,9 +71,9 @@ func createFullMovie(id int, studioID int) models.Movie { Int64: duration, Valid: true, }, - Director: modelstest.NullString(director), - Synopsis: modelstest.NullString(synopsis), - URL: modelstest.NullString(url), + Director: models.NullString(director), + Synopsis: models.NullString(synopsis), + URL: models.NullString(url), StudioID: sql.NullInt64{ Int64: int64(studioID), Valid: true, diff --git a/pkg/movie/import.go b/pkg/movie/import.go index 7e1065df6..7108944fb 100644 --- a/pkg/movie/import.go +++ b/pkg/movie/import.go @@ -117,7 +117,7 @@ func (i *Importer) createStudio(name string) (int, error) { func (i *Importer) PostImport(id int) error { if len(i.frontImageData) > 0 { - if err := i.ReaderWriter.UpdateMovieImages(id, i.frontImageData, i.backImageData); err != nil { + if err := i.ReaderWriter.UpdateImages(id, i.frontImageData, i.backImageData); err != 
nil { return fmt.Errorf("error setting movie images: %s", err.Error()) } } diff --git a/pkg/movie/import_test.go b/pkg/movie/import_test.go index afdc484d1..e81e1a9f1 100644 --- a/pkg/movie/import_test.go +++ b/pkg/movie/import_test.go @@ -7,7 +7,6 @@ import ( "github.com/stashapp/stash/pkg/manager/jsonschema" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models/mocks" - "github.com/stashapp/stash/pkg/models/modelstest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) @@ -158,10 +157,10 @@ func TestImporterPostImport(t *testing.T) { backImageData: backImageBytes, } - updateMovieImageErr := errors.New("UpdateMovieImage error") + updateMovieImageErr := errors.New("UpdateImages error") - readerWriter.On("UpdateMovieImages", movieID, frontImageBytes, backImageBytes).Return(nil).Once() - readerWriter.On("UpdateMovieImages", errImageID, frontImageBytes, backImageBytes).Return(updateMovieImageErr).Once() + readerWriter.On("UpdateImages", movieID, frontImageBytes, backImageBytes).Return(nil).Once() + readerWriter.On("UpdateImages", errImageID, frontImageBytes, backImageBytes).Return(updateMovieImageErr).Once() err := i.PostImport(movieID) assert.Nil(t, err) @@ -210,11 +209,11 @@ func TestCreate(t *testing.T) { readerWriter := &mocks.MovieReaderWriter{} movie := models.Movie{ - Name: modelstest.NullString(movieName), + Name: models.NullString(movieName), } movieErr := models.Movie{ - Name: modelstest.NullString(movieNameErr), + Name: models.NullString(movieNameErr), } i := Importer{ @@ -244,11 +243,11 @@ func TestUpdate(t *testing.T) { readerWriter := &mocks.MovieReaderWriter{} movie := models.Movie{ - Name: modelstest.NullString(movieName), + Name: models.NullString(movieName), } movieErr := models.Movie{ - Name: modelstest.NullString(movieNameErr), + Name: models.NullString(movieNameErr), } i := Importer{ diff --git a/pkg/performer/export.go b/pkg/performer/export.go index 6f9b44ee8..a038a2560 100644 --- a/pkg/performer/export.go +++ b/pkg/performer/export.go @@ -67,7 +67,7 @@ func ToJSON(reader models.PerformerReader, performer *models.Performer) (*jsonsc newPerformerJSON.Favorite = performer.Favorite.Bool } - image, err := reader.GetPerformerImage(performer.ID) + image, err := reader.GetImage(performer.ID) if err != nil { return nil, fmt.Errorf("error getting performers image: %s", err.Error()) } diff --git a/pkg/performer/export_test.go b/pkg/performer/export_test.go index 6df042341..aa880e40c 100644 --- a/pkg/performer/export_test.go +++ b/pkg/performer/export_test.go @@ -7,7 +7,6 @@ import ( "github.com/stashapp/stash/pkg/manager/jsonschema" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models/mocks" - "github.com/stashapp/stash/pkg/models/modelstest" "github.com/stashapp/stash/pkg/utils" "github.com/stretchr/testify/assert" @@ -53,27 +52,27 @@ var updateTime time.Time = time.Date(2002, 01, 01, 0, 0, 0, 0, time.Local) func createFullPerformer(id int, name string) *models.Performer { return &models.Performer{ ID: id, - Name: modelstest.NullString(name), + Name: models.NullString(name), Checksum: utils.MD5FromString(name), - URL: modelstest.NullString(url), - Aliases: modelstest.NullString(aliases), + URL: models.NullString(url), + Aliases: models.NullString(aliases), Birthdate: birthDate, - CareerLength: modelstest.NullString(careerLength), - Country: modelstest.NullString(country), - Ethnicity: modelstest.NullString(ethnicity), - EyeColor: modelstest.NullString(eyeColor), - FakeTits: modelstest.NullString(fakeTits), + 
CareerLength: models.NullString(careerLength), + Country: models.NullString(country), + Ethnicity: models.NullString(ethnicity), + EyeColor: models.NullString(eyeColor), + FakeTits: models.NullString(fakeTits), Favorite: sql.NullBool{ Bool: true, Valid: true, }, - Gender: modelstest.NullString(gender), - Height: modelstest.NullString(height), - Instagram: modelstest.NullString(instagram), - Measurements: modelstest.NullString(measurements), - Piercings: modelstest.NullString(piercings), - Tattoos: modelstest.NullString(tattoos), - Twitter: modelstest.NullString(twitter), + Gender: models.NullString(gender), + Height: models.NullString(height), + Instagram: models.NullString(instagram), + Measurements: models.NullString(measurements), + Piercings: models.NullString(piercings), + Tattoos: models.NullString(tattoos), + Twitter: models.NullString(twitter), CreatedAt: models.SQLiteTimestamp{ Timestamp: createTime, }, @@ -170,9 +169,9 @@ func TestToJSON(t *testing.T) { imageErr := errors.New("error getting image") - mockPerformerReader.On("GetPerformerImage", performerID).Return(imageBytes, nil).Once() - mockPerformerReader.On("GetPerformerImage", noImageID).Return(nil, nil).Once() - mockPerformerReader.On("GetPerformerImage", errImageID).Return(nil, imageErr).Once() + mockPerformerReader.On("GetImage", performerID).Return(imageBytes, nil).Once() + mockPerformerReader.On("GetImage", noImageID).Return(nil, nil).Once() + mockPerformerReader.On("GetImage", errImageID).Return(nil, imageErr).Once() for i, s := range scenarios { tag := s.input diff --git a/pkg/performer/import.go b/pkg/performer/import.go index 605722577..b07c073e2 100644 --- a/pkg/performer/import.go +++ b/pkg/performer/import.go @@ -33,7 +33,7 @@ func (i *Importer) PreImport() error { func (i *Importer) PostImport(id int) error { if len(i.imageData) > 0 { - if err := i.ReaderWriter.UpdatePerformerImage(id, i.imageData); err != nil { + if err := i.ReaderWriter.UpdateImage(id, i.imageData); err != nil { return fmt.Errorf("error setting performer image: %s", err.Error()) } } diff --git a/pkg/performer/import_test.go b/pkg/performer/import_test.go index 5560ba265..43509702d 100644 --- a/pkg/performer/import_test.go +++ b/pkg/performer/import_test.go @@ -6,7 +6,6 @@ import ( "github.com/stashapp/stash/pkg/manager/jsonschema" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models/mocks" - "github.com/stashapp/stash/pkg/models/modelstest" "github.com/stashapp/stash/pkg/utils" "github.com/stretchr/testify/assert" @@ -62,10 +61,10 @@ func TestImporterPostImport(t *testing.T) { imageData: imageBytes, } - updatePerformerImageErr := errors.New("UpdatePerformerImage error") + updatePerformerImageErr := errors.New("UpdateImage error") - readerWriter.On("UpdatePerformerImage", performerID, imageBytes).Return(nil).Once() - readerWriter.On("UpdatePerformerImage", errImageID, imageBytes).Return(updatePerformerImageErr).Once() + readerWriter.On("UpdateImage", performerID, imageBytes).Return(nil).Once() + readerWriter.On("UpdateImage", errImageID, imageBytes).Return(updatePerformerImageErr).Once() err := i.PostImport(performerID) assert.Nil(t, err) @@ -116,11 +115,11 @@ func TestCreate(t *testing.T) { readerWriter := &mocks.PerformerReaderWriter{} performer := models.Performer{ - Name: modelstest.NullString(performerName), + Name: models.NullString(performerName), } performerErr := models.Performer{ - Name: modelstest.NullString(performerNameErr), + Name: models.NullString(performerNameErr), } i := Importer{ @@ -150,11 +149,11 @@ 
func TestUpdate(t *testing.T) { readerWriter := &mocks.PerformerReaderWriter{} performer := models.Performer{ - Name: modelstest.NullString(performerName), + Name: models.NullString(performerName), } performerErr := models.Performer{ - Name: modelstest.NullString(performerNameErr), + Name: models.NullString(performerNameErr), } i := Importer{ diff --git a/pkg/scene/export.go b/pkg/scene/export.go index c7e680eeb..95af51119 100644 --- a/pkg/scene/export.go +++ b/pkg/scene/export.go @@ -52,7 +52,7 @@ func ToBasicJSON(reader models.SceneReader, scene *models.Scene) (*jsonschema.Sc newSceneJSON.File = getSceneFileJSON(scene) - cover, err := reader.GetSceneCover(scene.ID) + cover, err := reader.GetCover(scene.ID) if err != nil { return nil, fmt.Errorf("error getting scene cover: %s", err.Error()) } @@ -165,7 +165,7 @@ func getTagNames(tags []*models.Tag) []string { } // GetDependentTagIDs returns a slice of unique tag IDs that this scene references. -func GetDependentTagIDs(tags models.TagReader, joins models.JoinReader, markerReader models.SceneMarkerReader, scene *models.Scene) ([]int, error) { +func GetDependentTagIDs(tags models.TagReader, markerReader models.SceneMarkerReader, scene *models.Scene) ([]int, error) { var ret []int t, err := tags.FindBySceneID(scene.ID) @@ -199,8 +199,8 @@ func GetDependentTagIDs(tags models.TagReader, joins models.JoinReader, markerRe // GetSceneMoviesJSON returns a slice of SceneMovie JSON representation objects // corresponding to the provided scene's scene movie relationships. -func GetSceneMoviesJSON(movieReader models.MovieReader, joinReader models.JoinReader, scene *models.Scene) ([]jsonschema.SceneMovie, error) { - sceneMovies, err := joinReader.GetSceneMovies(scene.ID) +func GetSceneMoviesJSON(movieReader models.MovieReader, sceneReader models.SceneReader, scene *models.Scene) ([]jsonschema.SceneMovie, error) { + sceneMovies, err := sceneReader.GetMovies(scene.ID) if err != nil { return nil, fmt.Errorf("error getting scene movies: %s", err.Error()) } @@ -225,10 +225,10 @@ func GetSceneMoviesJSON(movieReader models.MovieReader, joinReader models.JoinRe } // GetDependentMovieIDs returns a slice of movie IDs that this scene references. 
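Because GetSceneMoviesJSON and GetDependentMovieIDs now take a SceneReader instead of a JoinReader, an export task can obtain every reader it needs from a single read transaction. A sketch of that wiring, assuming the caller holds a models.TransactionManager; the exportSceneMovies wrapper and its package are illustrative only:

```go
package task // illustrative only

import (
	"context"

	"github.com/stashapp/stash/pkg/manager/jsonschema"
	"github.com/stashapp/stash/pkg/models"
	"github.com/stashapp/stash/pkg/scene"
)

// exportSceneMovies gathers the scene's movie JSON and the movie IDs it
// depends on inside one read transaction. Only the called helpers are
// part of this change; the wrapper itself is an assumption.
func exportSceneMovies(ctx context.Context, txnManager models.TransactionManager, s *models.Scene) ([]jsonschema.SceneMovie, []int, error) {
	var movies []jsonschema.SceneMovie
	var movieIDs []int

	err := txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
		var err error
		movies, err = scene.GetSceneMoviesJSON(r.Movie(), r.Scene(), s)
		if err != nil {
			return err
		}

		movieIDs, err = scene.GetDependentMovieIDs(r.Scene(), s)
		return err
	})

	return movies, movieIDs, err
}
```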
-func GetDependentMovieIDs(joins models.JoinReader, scene *models.Scene) ([]int, error) { +func GetDependentMovieIDs(sceneReader models.SceneReader, scene *models.Scene) ([]int, error) { var ret []int - m, err := joins.GetSceneMovies(scene.ID) + m, err := sceneReader.GetMovies(scene.ID) if err != nil { return nil, err } diff --git a/pkg/scene/export_test.go b/pkg/scene/export_test.go index 7ca7e2bb5..682d0a9f6 100644 --- a/pkg/scene/export_test.go +++ b/pkg/scene/export_test.go @@ -7,7 +7,6 @@ import ( "github.com/stashapp/stash/pkg/manager/jsonschema" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models/mocks" - "github.com/stashapp/stash/pkg/models/modelstest" "github.com/stretchr/testify/assert" "testing" @@ -92,33 +91,33 @@ var updateTime time.Time = time.Date(2002, 01, 01, 0, 0, 0, 0, time.UTC) func createFullScene(id int) models.Scene { return models.Scene{ ID: id, - Title: modelstest.NullString(title), - AudioCodec: modelstest.NullString(audioCodec), - Bitrate: modelstest.NullInt64(bitrate), - Checksum: modelstest.NullString(checksum), + Title: models.NullString(title), + AudioCodec: models.NullString(audioCodec), + Bitrate: models.NullInt64(bitrate), + Checksum: models.NullString(checksum), Date: models.SQLiteDate{ String: date, Valid: true, }, - Details: modelstest.NullString(details), + Details: models.NullString(details), Duration: sql.NullFloat64{ Float64: duration, Valid: true, }, - Format: modelstest.NullString(format), + Format: models.NullString(format), Framerate: sql.NullFloat64{ Float64: framerate, Valid: true, }, - Height: modelstest.NullInt64(height), + Height: models.NullInt64(height), OCounter: ocounter, - OSHash: modelstest.NullString(oshash), - Rating: modelstest.NullInt64(rating), + OSHash: models.NullString(oshash), + Rating: models.NullInt64(rating), Organized: organized, - Size: modelstest.NullString(size), - VideoCodec: modelstest.NullString(videoCodec), - Width: modelstest.NullInt64(width), - URL: modelstest.NullString(url), + Size: models.NullString(size), + VideoCodec: models.NullString(videoCodec), + Width: models.NullInt64(width), + URL: models.NullString(url), CreatedAt: models.SQLiteTimestamp{ Timestamp: createTime, }, @@ -213,9 +212,9 @@ func TestToJSON(t *testing.T) { imageErr := errors.New("error getting image") - mockSceneReader.On("GetSceneCover", sceneID).Return(imageBytes, nil).Once() - mockSceneReader.On("GetSceneCover", noImageID).Return(nil, nil).Once() - mockSceneReader.On("GetSceneCover", errImageID).Return(nil, imageErr).Once() + mockSceneReader.On("GetCover", sceneID).Return(imageBytes, nil).Once() + mockSceneReader.On("GetCover", noImageID).Return(nil, nil).Once() + mockSceneReader.On("GetCover", errImageID).Return(nil, imageErr).Once() for i, s := range scenarios { scene := s.input @@ -235,7 +234,7 @@ func TestToJSON(t *testing.T) { func createStudioScene(studioID int) models.Scene { return models.Scene{ - StudioID: modelstest.NullInt64(int64(studioID)), + StudioID: models.NullInt64(int64(studioID)), } } @@ -269,7 +268,7 @@ func TestGetStudioName(t *testing.T) { studioErr := errors.New("error getting image") mockStudioReader.On("Find", studioID).Return(&models.Studio{ - Name: modelstest.NullString(studioName), + Name: models.NullString(studioName), }, nil).Once() mockStudioReader.On("Find", missingStudioID).Return(nil, nil).Once() mockStudioReader.On("Find", errStudioID).Return(nil, studioErr).Once() @@ -436,44 +435,44 @@ var getSceneMoviesJSONScenarios = []sceneMoviesTestScenario{ var validMovies = 
[]models.MoviesScenes{ { MovieID: validMovie1, - SceneIndex: modelstest.NullInt64(movie1Scene), + SceneIndex: models.NullInt64(movie1Scene), }, { MovieID: validMovie2, - SceneIndex: modelstest.NullInt64(movie2Scene), + SceneIndex: models.NullInt64(movie2Scene), }, } var invalidMovies = []models.MoviesScenes{ { MovieID: invalidMovie, - SceneIndex: modelstest.NullInt64(movie1Scene), + SceneIndex: models.NullInt64(movie1Scene), }, } func TestGetSceneMoviesJSON(t *testing.T) { mockMovieReader := &mocks.MovieReaderWriter{} - mockJoinReader := &mocks.JoinReaderWriter{} + mockSceneReader := &mocks.SceneReaderWriter{} joinErr := errors.New("error getting scene movies") movieErr := errors.New("error getting movie") - mockJoinReader.On("GetSceneMovies", sceneID).Return(validMovies, nil).Once() - mockJoinReader.On("GetSceneMovies", noMoviesID).Return(nil, nil).Once() - mockJoinReader.On("GetSceneMovies", errMoviesID).Return(nil, joinErr).Once() - mockJoinReader.On("GetSceneMovies", errFindMovieID).Return(invalidMovies, nil).Once() + mockSceneReader.On("GetMovies", sceneID).Return(validMovies, nil).Once() + mockSceneReader.On("GetMovies", noMoviesID).Return(nil, nil).Once() + mockSceneReader.On("GetMovies", errMoviesID).Return(nil, joinErr).Once() + mockSceneReader.On("GetMovies", errFindMovieID).Return(invalidMovies, nil).Once() mockMovieReader.On("Find", validMovie1).Return(&models.Movie{ - Name: modelstest.NullString(movie1Name), + Name: models.NullString(movie1Name), }, nil).Once() mockMovieReader.On("Find", validMovie2).Return(&models.Movie{ - Name: modelstest.NullString(movie2Name), + Name: models.NullString(movie2Name), }, nil).Once() mockMovieReader.On("Find", invalidMovie).Return(nil, movieErr).Once() for i, s := range getSceneMoviesJSONScenarios { scene := s.input - json, err := GetSceneMoviesJSON(mockMovieReader, mockJoinReader, &scene) + json, err := GetSceneMoviesJSON(mockMovieReader, mockSceneReader, &scene) if !s.err && err != nil { t.Errorf("[%d] unexpected error: %s", i, err.Error()) diff --git a/pkg/scene/import.go b/pkg/scene/import.go index 1e634d31f..b3bc52420 100644 --- a/pkg/scene/import.go +++ b/pkg/scene/import.go @@ -18,7 +18,6 @@ type Importer struct { PerformerWriter models.PerformerReaderWriter MovieWriter models.MovieReaderWriter TagWriter models.TagReaderWriter - JoinWriter models.JoinReaderWriter Input jsonschema.Scene Path string MissingRefBehaviour models.ImportMissingRefEnum @@ -329,7 +328,7 @@ func (i *Importer) populateTags() error { func (i *Importer) PostImport(id int) error { if len(i.coverImageData) > 0 { - if err := i.ReaderWriter.UpdateSceneCover(id, i.coverImageData); err != nil { + if err := i.ReaderWriter.UpdateCover(id, i.coverImageData); err != nil { return fmt.Errorf("error setting scene images: %s", err.Error()) } } @@ -343,15 +342,12 @@ func (i *Importer) PostImport(id int) error { } if len(i.performers) > 0 { - var performerJoins []models.PerformersScenes + var performerIDs []int for _, performer := range i.performers { - join := models.PerformersScenes{ - PerformerID: performer.ID, - SceneID: id, - } - performerJoins = append(performerJoins, join) + performerIDs = append(performerIDs, performer.ID) } - if err := i.JoinWriter.UpdatePerformersScenes(id, performerJoins); err != nil { + + if err := i.ReaderWriter.UpdatePerformers(id, performerIDs); err != nil { return fmt.Errorf("failed to associate performers: %s", err.Error()) } } @@ -360,21 +356,17 @@ func (i *Importer) PostImport(id int) error { for index := range i.movies { i.movies[index].SceneID 
= id } - if err := i.JoinWriter.UpdateMoviesScenes(id, i.movies); err != nil { + if err := i.ReaderWriter.UpdateMovies(id, i.movies); err != nil { return fmt.Errorf("failed to associate movies: %s", err.Error()) } } if len(i.tags) > 0 { - var tagJoins []models.ScenesTags - for _, tag := range i.tags { - join := models.ScenesTags{ - SceneID: id, - TagID: tag.ID, - } - tagJoins = append(tagJoins, join) + var tagIDs []int + for _, t := range i.tags { + tagIDs = append(tagIDs, t.ID) } - if err := i.JoinWriter.UpdateScenesTags(id, tagJoins); err != nil { + if err := i.ReaderWriter.UpdateTags(id, tagIDs); err != nil { return fmt.Errorf("failed to associate tags: %s", err.Error()) } } diff --git a/pkg/scene/import_test.go b/pkg/scene/import_test.go index e43e0ff43..971d7ae31 100644 --- a/pkg/scene/import_test.go +++ b/pkg/scene/import_test.go @@ -7,7 +7,6 @@ import ( "github.com/stashapp/stash/pkg/manager/jsonschema" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models/mocks" - "github.com/stashapp/stash/pkg/models/modelstest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) @@ -233,7 +232,7 @@ func TestImporterPreImportWithPerformer(t *testing.T) { performerReaderWriter.On("FindByNames", []string{existingPerformerName}, false).Return([]*models.Performer{ { ID: existingPerformerID, - Name: modelstest.NullString(existingPerformerName), + Name: models.NullString(existingPerformerName), }, }, nil).Once() performerReaderWriter.On("FindByNames", []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once() @@ -323,7 +322,7 @@ func TestImporterPreImportWithMovie(t *testing.T) { movieReaderWriter.On("FindByName", existingMovieName, false).Return(&models.Movie{ ID: existingMovieID, - Name: modelstest.NullString(existingMovieName), + Name: models.NullString(existingMovieName), }, nil).Once() movieReaderWriter.On("FindByName", existingMovieErr, false).Return(nil, errors.New("FindByName error")).Once() @@ -493,10 +492,10 @@ func TestImporterPostImport(t *testing.T) { coverImageData: imageBytes, } - updateSceneImageErr := errors.New("UpdateSceneCover error") + updateSceneImageErr := errors.New("UpdateCover error") - readerWriter.On("UpdateSceneCover", sceneID, imageBytes).Return(nil).Once() - readerWriter.On("UpdateSceneCover", errImageID, imageBytes).Return(updateSceneImageErr).Once() + readerWriter.On("UpdateCover", sceneID, imageBytes).Return(nil).Once() + readerWriter.On("UpdateCover", errImageID, imageBytes).Return(updateSceneImageErr).Once() err := i.PostImport(sceneID) assert.Nil(t, err) @@ -520,11 +519,11 @@ func TestImporterPostImportUpdateGallery(t *testing.T) { updateErr := errors.New("Update error") updateArg := *i.gallery - updateArg.SceneID = modelstest.NullInt64(sceneID) + updateArg.SceneID = models.NullInt64(sceneID) galleryReaderWriter.On("Update", updateArg).Return(nil, nil).Once() - updateArg.SceneID = modelstest.NullInt64(errGalleryID) + updateArg.SceneID = models.NullInt64(errGalleryID) galleryReaderWriter.On("Update", updateArg).Return(nil, updateErr).Once() err := i.PostImport(sceneID) @@ -537,10 +536,10 @@ func TestImporterPostImportUpdateGallery(t *testing.T) { } func TestImporterPostImportUpdatePerformers(t *testing.T) { - joinReaderWriter := &mocks.JoinReaderWriter{} + sceneReaderWriter := &mocks.SceneReaderWriter{} i := Importer{ - JoinWriter: joinReaderWriter, + ReaderWriter: sceneReaderWriter, performers: []*models.Performer{ { ID: existingPerformerID, @@ -548,15 +547,10 @@ func 
TestImporterPostImportUpdatePerformers(t *testing.T) { }, } - updateErr := errors.New("UpdatePerformersScenes error") + updateErr := errors.New("UpdatePerformers error") - joinReaderWriter.On("UpdatePerformersScenes", sceneID, []models.PerformersScenes{ - { - PerformerID: existingPerformerID, - SceneID: sceneID, - }, - }).Return(nil).Once() - joinReaderWriter.On("UpdatePerformersScenes", errPerformersID, mock.AnythingOfType("[]models.PerformersScenes")).Return(updateErr).Once() + sceneReaderWriter.On("UpdatePerformers", sceneID, []int{existingPerformerID}).Return(nil).Once() + sceneReaderWriter.On("UpdatePerformers", errPerformersID, mock.AnythingOfType("[]int")).Return(updateErr).Once() err := i.PostImport(sceneID) assert.Nil(t, err) @@ -564,14 +558,14 @@ func TestImporterPostImportUpdatePerformers(t *testing.T) { err = i.PostImport(errPerformersID) assert.NotNil(t, err) - joinReaderWriter.AssertExpectations(t) + sceneReaderWriter.AssertExpectations(t) } func TestImporterPostImportUpdateMovies(t *testing.T) { - joinReaderWriter := &mocks.JoinReaderWriter{} + sceneReaderWriter := &mocks.SceneReaderWriter{} i := Importer{ - JoinWriter: joinReaderWriter, + ReaderWriter: sceneReaderWriter, movies: []models.MoviesScenes{ { MovieID: existingMovieID, @@ -579,15 +573,15 @@ func TestImporterPostImportUpdateMovies(t *testing.T) { }, } - updateErr := errors.New("UpdateMoviesScenes error") + updateErr := errors.New("UpdateMovies error") - joinReaderWriter.On("UpdateMoviesScenes", sceneID, []models.MoviesScenes{ + sceneReaderWriter.On("UpdateMovies", sceneID, []models.MoviesScenes{ { MovieID: existingMovieID, SceneID: sceneID, }, }).Return(nil).Once() - joinReaderWriter.On("UpdateMoviesScenes", errMoviesID, mock.AnythingOfType("[]models.MoviesScenes")).Return(updateErr).Once() + sceneReaderWriter.On("UpdateMovies", errMoviesID, mock.AnythingOfType("[]models.MoviesScenes")).Return(updateErr).Once() err := i.PostImport(sceneID) assert.Nil(t, err) @@ -595,14 +589,14 @@ func TestImporterPostImportUpdateMovies(t *testing.T) { err = i.PostImport(errMoviesID) assert.NotNil(t, err) - joinReaderWriter.AssertExpectations(t) + sceneReaderWriter.AssertExpectations(t) } func TestImporterPostImportUpdateTags(t *testing.T) { - joinReaderWriter := &mocks.JoinReaderWriter{} + sceneReaderWriter := &mocks.SceneReaderWriter{} i := Importer{ - JoinWriter: joinReaderWriter, + ReaderWriter: sceneReaderWriter, tags: []*models.Tag{ { ID: existingTagID, @@ -610,15 +604,10 @@ func TestImporterPostImportUpdateTags(t *testing.T) { }, } - updateErr := errors.New("UpdateScenesTags error") + updateErr := errors.New("UpdateTags error") - joinReaderWriter.On("UpdateScenesTags", sceneID, []models.ScenesTags{ - { - TagID: existingTagID, - SceneID: sceneID, - }, - }).Return(nil).Once() - joinReaderWriter.On("UpdateScenesTags", errTagsID, mock.AnythingOfType("[]models.ScenesTags")).Return(updateErr).Once() + sceneReaderWriter.On("UpdateTags", sceneID, []int{existingTagID}).Return(nil).Once() + sceneReaderWriter.On("UpdateTags", errTagsID, mock.AnythingOfType("[]int")).Return(updateErr).Once() err := i.PostImport(sceneID) assert.Nil(t, err) @@ -626,7 +615,7 @@ func TestImporterPostImportUpdateTags(t *testing.T) { err = i.PostImport(errTagsID) assert.NotNil(t, err) - joinReaderWriter.AssertExpectations(t) + sceneReaderWriter.AssertExpectations(t) } func TestImporterFindExistingID(t *testing.T) { @@ -691,11 +680,11 @@ func TestCreate(t *testing.T) { readerWriter := &mocks.SceneReaderWriter{} scene := models.Scene{ - Title: 
modelstest.NullString(title), + Title: models.NullString(title), } sceneErr := models.Scene{ - Title: modelstest.NullString(sceneNameErr), + Title: models.NullString(sceneNameErr), } i := Importer{ @@ -726,11 +715,11 @@ func TestUpdate(t *testing.T) { readerWriter := &mocks.SceneReaderWriter{} scene := models.Scene{ - Title: modelstest.NullString(title), + Title: models.NullString(title), } sceneErr := models.Scene{ - Title: modelstest.NullString(sceneNameErr), + Title: models.NullString(sceneNameErr), } i := Importer{ diff --git a/pkg/scene/marker_import.go b/pkg/scene/marker_import.go index 9a559b384..d15563eb5 100644 --- a/pkg/scene/marker_import.go +++ b/pkg/scene/marker_import.go @@ -13,7 +13,6 @@ type MarkerImporter struct { SceneID int ReaderWriter models.SceneMarkerReaderWriter TagWriter models.TagReaderWriter - JoinWriter models.JoinReaderWriter Input jsonschema.SceneMarker MissingRefBehaviour models.ImportMissingRefEnum @@ -66,15 +65,11 @@ func (i *MarkerImporter) populateTags() error { func (i *MarkerImporter) PostImport(id int) error { if len(i.tags) > 0 { - var tagJoins []models.SceneMarkersTags - for _, tag := range i.tags { - join := models.SceneMarkersTags{ - SceneMarkerID: id, - TagID: tag.ID, - } - tagJoins = append(tagJoins, join) + var tagIDs []int + for _, t := range i.tags { + tagIDs = append(tagIDs, t.ID) } - if err := i.JoinWriter.UpdateSceneMarkersTags(id, tagJoins); err != nil { + if err := i.ReaderWriter.UpdateTags(id, tagIDs); err != nil { return fmt.Errorf("failed to associate tags: %s", err.Error()) } } diff --git a/pkg/scene/marker_import_test.go b/pkg/scene/marker_import_test.go index 23d0d7cf6..f70700a5c 100644 --- a/pkg/scene/marker_import_test.go +++ b/pkg/scene/marker_import_test.go @@ -71,10 +71,10 @@ func TestMarkerImporterPreImportWithTag(t *testing.T) { } func TestMarkerImporterPostImportUpdateTags(t *testing.T) { - joinReaderWriter := &mocks.JoinReaderWriter{} + sceneMarkerReaderWriter := &mocks.SceneMarkerReaderWriter{} i := MarkerImporter{ - JoinWriter: joinReaderWriter, + ReaderWriter: sceneMarkerReaderWriter, tags: []*models.Tag{ { ID: existingTagID, @@ -82,15 +82,10 @@ func TestMarkerImporterPostImportUpdateTags(t *testing.T) { }, } - updateErr := errors.New("UpdateSceneMarkersTags error") + updateErr := errors.New("UpdateTags error") - joinReaderWriter.On("UpdateSceneMarkersTags", sceneID, []models.SceneMarkersTags{ - { - TagID: existingTagID, - SceneMarkerID: sceneID, - }, - }).Return(nil).Once() - joinReaderWriter.On("UpdateSceneMarkersTags", errTagsID, mock.AnythingOfType("[]models.SceneMarkersTags")).Return(updateErr).Once() + sceneMarkerReaderWriter.On("UpdateTags", sceneID, []int{existingTagID}).Return(nil).Once() + sceneMarkerReaderWriter.On("UpdateTags", errTagsID, mock.AnythingOfType("[]int")).Return(updateErr).Once() err := i.PostImport(sceneID) assert.Nil(t, err) @@ -98,7 +93,7 @@ func TestMarkerImporterPostImportUpdateTags(t *testing.T) { err = i.PostImport(errTagsID) assert.NotNil(t, err) - joinReaderWriter.AssertExpectations(t) + sceneMarkerReaderWriter.AssertExpectations(t) } func TestMarkerImporterFindExistingID(t *testing.T) { diff --git a/pkg/scene/update.go b/pkg/scene/update.go new file mode 100644 index 000000000..940a1c92e --- /dev/null +++ b/pkg/scene/update.go @@ -0,0 +1,85 @@ +package scene + +import ( + "database/sql" + + "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/utils" +) + +func UpdateFormat(qb models.SceneWriter, id int, format string) (*models.Scene, error) { + return 
qb.Update(models.ScenePartial{ + ID: id, + Format: &sql.NullString{ + String: format, + Valid: true, + }, + }) +} + +func UpdateOSHash(qb models.SceneWriter, id int, oshash string) (*models.Scene, error) { + return qb.Update(models.ScenePartial{ + ID: id, + OSHash: &sql.NullString{ + String: oshash, + Valid: true, + }, + }) +} + +func UpdateChecksum(qb models.SceneWriter, id int, checksum string) (*models.Scene, error) { + return qb.Update(models.ScenePartial{ + ID: id, + Checksum: &sql.NullString{ + String: checksum, + Valid: true, + }, + }) +} + +func UpdateFileModTime(qb models.SceneWriter, id int, modTime models.NullSQLiteTimestamp) (*models.Scene, error) { + return qb.Update(models.ScenePartial{ + ID: id, + FileModTime: &modTime, + }) +} + +func AddPerformer(qb models.SceneReaderWriter, id int, performerID int) (bool, error) { + performerIDs, err := qb.GetPerformerIDs(id) + if err != nil { + return false, err + } + + oldLen := len(performerIDs) + performerIDs = utils.IntAppendUnique(performerIDs, performerID) + + if len(performerIDs) != oldLen { + if err := qb.UpdatePerformers(id, performerIDs); err != nil { + return false, err + } + + return true, nil + } + + return false, nil +} + +func AddTag(qb models.SceneReaderWriter, id int, tagID int) (bool, error) { + tagIDs, err := qb.GetTagIDs(id) + if err != nil { + return false, err + } + + oldLen := len(tagIDs) + tagIDs = utils.IntAppendUnique(tagIDs, tagID) + + if len(tagIDs) != oldLen { + if err := qb.UpdateTags(id, tagIDs); err != nil { + return false, err + } + + return true, nil + } + + return false, nil +} diff --git a/pkg/scraper/action.go b/pkg/scraper/action.go index d3937ea8b..ca7e82b2c 100644 --- a/pkg/scraper/action.go +++ b/pkg/scraper/action.go @@ -46,16 +46,16 @@ type scraper interface { scrapeMovieByURL(url string) (*models.ScrapedMovie, error) } -func getScraper(scraper scraperTypeConfig, config config, globalConfig GlobalConfig) scraper { +func getScraper(scraper scraperTypeConfig, txnManager models.TransactionManager, config config, globalConfig GlobalConfig) scraper { switch scraper.Action { case scraperActionScript: return newScriptScraper(scraper, config, globalConfig) case scraperActionStash: - return newStashScraper(scraper, config, globalConfig) + return newStashScraper(scraper, txnManager, config, globalConfig) case scraperActionXPath: - return newXpathScraper(scraper, config, globalConfig) + return newXpathScraper(scraper, txnManager, config, globalConfig) case scraperActionJson: - return newJsonScraper(scraper, config, globalConfig) + return newJsonScraper(scraper, txnManager, config, globalConfig) } panic("unknown scraper action: " + scraper.Action) diff --git a/pkg/scraper/config.go b/pkg/scraper/config.go index 2401869fc..c81fb2061 100644 --- a/pkg/scraper/config.go +++ b/pkg/scraper/config.go @@ -304,33 +304,33 @@ func (c config) matchesPerformerURL(url string) bool { return false } -func (c config) ScrapePerformerNames(name string, globalConfig GlobalConfig) ([]*models.ScrapedPerformer, error) { +func (c config) ScrapePerformerNames(name string, txnManager models.TransactionManager, globalConfig GlobalConfig) ([]*models.ScrapedPerformer, error) { if c.PerformerByName != nil { - s := getScraper(*c.PerformerByName, c, globalConfig) + s := getScraper(*c.PerformerByName, txnManager, c, globalConfig) return s.scrapePerformersByName(name) } return nil, nil } -func (c config) ScrapePerformer(scrapedPerformer models.ScrapedPerformerInput, globalConfig GlobalConfig) (*models.ScrapedPerformer, error) { +func (c 
config) ScrapePerformer(scrapedPerformer models.ScrapedPerformerInput, txnManager models.TransactionManager, globalConfig GlobalConfig) (*models.ScrapedPerformer, error) { if c.PerformerByFragment != nil { - s := getScraper(*c.PerformerByFragment, c, globalConfig) + s := getScraper(*c.PerformerByFragment, txnManager, c, globalConfig) return s.scrapePerformerByFragment(scrapedPerformer) } // try to match against URL if present if scrapedPerformer.URL != nil && *scrapedPerformer.URL != "" { - return c.ScrapePerformerURL(*scrapedPerformer.URL, globalConfig) + return c.ScrapePerformerURL(*scrapedPerformer.URL, txnManager, globalConfig) } return nil, nil } -func (c config) ScrapePerformerURL(url string, globalConfig GlobalConfig) (*models.ScrapedPerformer, error) { +func (c config) ScrapePerformerURL(url string, txnManager models.TransactionManager, globalConfig GlobalConfig) (*models.ScrapedPerformer, error) { for _, scraper := range c.PerformerByURL { if scraper.matchesURL(url) { - s := getScraper(scraper.scraperTypeConfig, c, globalConfig) + s := getScraper(scraper.scraperTypeConfig, txnManager, c, globalConfig) ret, err := s.scrapePerformerByURL(url) if err != nil { return nil, err @@ -386,19 +386,19 @@ func (c config) matchesMovieURL(url string) bool { return false } -func (c config) ScrapeScene(scene models.SceneUpdateInput, globalConfig GlobalConfig) (*models.ScrapedScene, error) { +func (c config) ScrapeScene(scene models.SceneUpdateInput, txnManager models.TransactionManager, globalConfig GlobalConfig) (*models.ScrapedScene, error) { if c.SceneByFragment != nil { - s := getScraper(*c.SceneByFragment, c, globalConfig) + s := getScraper(*c.SceneByFragment, txnManager, c, globalConfig) return s.scrapeSceneByFragment(scene) } return nil, nil } -func (c config) ScrapeSceneURL(url string, globalConfig GlobalConfig) (*models.ScrapedScene, error) { +func (c config) ScrapeSceneURL(url string, txnManager models.TransactionManager, globalConfig GlobalConfig) (*models.ScrapedScene, error) { for _, scraper := range c.SceneByURL { if scraper.matchesURL(url) { - s := getScraper(scraper.scraperTypeConfig, c, globalConfig) + s := getScraper(scraper.scraperTypeConfig, txnManager, c, globalConfig) ret, err := s.scrapeSceneByURL(url) if err != nil { return nil, err @@ -413,19 +413,19 @@ func (c config) ScrapeSceneURL(url string, globalConfig GlobalConfig) (*models.S return nil, nil } -func (c config) ScrapeGallery(gallery models.GalleryUpdateInput, globalConfig GlobalConfig) (*models.ScrapedGallery, error) { +func (c config) ScrapeGallery(gallery models.GalleryUpdateInput, txnManager models.TransactionManager, globalConfig GlobalConfig) (*models.ScrapedGallery, error) { if c.GalleryByFragment != nil { - s := getScraper(*c.GalleryByFragment, c, globalConfig) + s := getScraper(*c.GalleryByFragment, txnManager, c, globalConfig) return s.scrapeGalleryByFragment(gallery) } return nil, nil } -func (c config) ScrapeGalleryURL(url string, globalConfig GlobalConfig) (*models.ScrapedGallery, error) { +func (c config) ScrapeGalleryURL(url string, txnManager models.TransactionManager, globalConfig GlobalConfig) (*models.ScrapedGallery, error) { for _, scraper := range c.GalleryByURL { if scraper.matchesURL(url) { - s := getScraper(scraper.scraperTypeConfig, c, globalConfig) + s := getScraper(scraper.scraperTypeConfig, txnManager, c, globalConfig) ret, err := s.scrapeGalleryByURL(url) if err != nil { return nil, err @@ -440,10 +440,10 @@ func (c config) ScrapeGalleryURL(url string, globalConfig GlobalConfig) (*models 
return nil, nil } -func (c config) ScrapeMovieURL(url string, globalConfig GlobalConfig) (*models.ScrapedMovie, error) { +func (c config) ScrapeMovieURL(url string, txnManager models.TransactionManager, globalConfig GlobalConfig) (*models.ScrapedMovie, error) { for _, scraper := range c.MovieByURL { if scraper.matchesURL(url) { - s := getScraper(scraper.scraperTypeConfig, c, globalConfig) + s := getScraper(scraper.scraperTypeConfig, txnManager, c, globalConfig) ret, err := s.scrapeMovieByURL(url) if err != nil { return nil, err diff --git a/pkg/scraper/json.go b/pkg/scraper/json.go index 00590f5fe..b40b04e77 100644 --- a/pkg/scraper/json.go +++ b/pkg/scraper/json.go @@ -15,13 +15,15 @@ type jsonScraper struct { scraper scraperTypeConfig config config globalConfig GlobalConfig + txnManager models.TransactionManager } -func newJsonScraper(scraper scraperTypeConfig, config config, globalConfig GlobalConfig) *jsonScraper { +func newJsonScraper(scraper scraperTypeConfig, txnManager models.TransactionManager, config config, globalConfig GlobalConfig) *jsonScraper { return &jsonScraper{ scraper: scraper, config: config, globalConfig: globalConfig, + txnManager: txnManager, } } @@ -138,7 +140,7 @@ func (s *jsonScraper) scrapePerformerByFragment(scrapedPerformer models.ScrapedP } func (s *jsonScraper) scrapeSceneByFragment(scene models.SceneUpdateInput) (*models.ScrapedScene, error) { - storedScene, err := sceneFromUpdateFragment(scene) + storedScene, err := sceneFromUpdateFragment(scene, s.txnManager) if err != nil { return nil, err } @@ -171,7 +173,7 @@ func (s *jsonScraper) scrapeSceneByFragment(scene models.SceneUpdateInput) (*mod } func (s *jsonScraper) scrapeGalleryByFragment(gallery models.GalleryUpdateInput) (*models.ScrapedGallery, error) { - storedGallery, err := galleryFromUpdateFragment(gallery) + storedGallery, err := galleryFromUpdateFragment(gallery, s.txnManager) if err != nil { return nil, err } diff --git a/pkg/scraper/matchers.go b/pkg/scraper/matchers.go new file mode 100644 index 000000000..3fcba3a48 --- /dev/null +++ b/pkg/scraper/matchers.go @@ -0,0 +1,83 @@ +package scraper + +import ( + "strconv" + + "github.com/stashapp/stash/pkg/models" +) + +// MatchScrapedScenePerformer matches the provided performer with the +// performers in the database and sets the ID field if one is found. +func MatchScrapedScenePerformer(qb models.PerformerReader, p *models.ScrapedScenePerformer) error { + performers, err := qb.FindByNames([]string{p.Name}, true) + + if err != nil { + return err + } + + if len(performers) != 1 { + // ignore - cannot match + return nil + } + + id := strconv.Itoa(performers[0].ID) + p.ID = &id + return nil +} + +// MatchScrapedSceneStudio matches the provided studio with the studios +// in the database and sets the ID field if one is found. +func MatchScrapedSceneStudio(qb models.StudioReader, s *models.ScrapedSceneStudio) error { + studio, err := qb.FindByName(s.Name, true) + + if err != nil { + return err + } + + if studio == nil { + // ignore - cannot match + return nil + } + + id := strconv.Itoa(studio.ID) + s.ID = &id + return nil +} + +// MatchScrapedSceneMovie matches the provided movie with the movies +// in the database and sets the ID field if one is found. 
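Since the matchers now receive their readers as arguments, they can be exercised with the generated mocks in the same way as the importer tests above. A test sketch under that assumption; the test name and fixture values are invented for illustration:

```go
package scraper

import (
	"testing"

	"github.com/stashapp/stash/pkg/models"
	"github.com/stashapp/stash/pkg/models/mocks"
	"github.com/stretchr/testify/assert"
)

// TestMatchScrapedScenePerformer checks that a single case-insensitive name
// match sets the scraped performer's ID to the stored performer's ID.
func TestMatchScrapedScenePerformer(t *testing.T) {
	pqb := &mocks.PerformerReaderWriter{}

	const name = "performer name"
	// the matcher queries by name with nocase=true
	pqb.On("FindByNames", []string{name}, true).Return([]*models.Performer{
		{ID: 42, Name: models.NullString(name)},
	}, nil).Once()

	performer := models.ScrapedScenePerformer{Name: name}
	err := MatchScrapedScenePerformer(pqb, &performer)

	assert.Nil(t, err)
	assert.Equal(t, "42", *performer.ID)
	pqb.AssertExpectations(t)
}
```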
+func MatchScrapedSceneMovie(qb models.MovieReader, m *models.ScrapedSceneMovie) error { + movies, err := qb.FindByNames([]string{m.Name}, true) + + if err != nil { + return err + } + + if len(movies) != 1 { + // ignore - cannot match + return nil + } + + id := strconv.Itoa(movies[0].ID) + m.ID = &id + return nil +} + +// MatchScrapedSceneTag matches the provided tag with the tags +// in the database and sets the ID field if one is found. +func MatchScrapedSceneTag(qb models.TagReader, s *models.ScrapedSceneTag) error { + tag, err := qb.FindByName(s.Name, true) + + if err != nil { + return err + } + + if tag == nil { + // ignore - cannot match + return nil + } + + id := strconv.Itoa(tag.ID) + s.ID = &id + return nil +} diff --git a/pkg/scraper/scrapers.go b/pkg/scraper/scrapers.go index 23918d684..9c40e2a77 100644 --- a/pkg/scraper/scrapers.go +++ b/pkg/scraper/scrapers.go @@ -1,6 +1,7 @@ package scraper import ( + "context" "errors" "os" "path/filepath" @@ -34,6 +35,7 @@ func (c GlobalConfig) isCDPPathWS() bool { type Cache struct { scrapers []config globalConfig GlobalConfig + txnManager models.TransactionManager } // NewCache returns a new Cache loading scraper configurations from the @@ -42,7 +44,7 @@ type Cache struct { // // Scraper configurations are loaded from yml files in the provided scrapers // directory and any subdirectories. -func NewCache(globalConfig GlobalConfig) (*Cache, error) { +func NewCache(globalConfig GlobalConfig, txnManager models.TransactionManager) (*Cache, error) { scrapers, err := loadScrapers(globalConfig.Path) if err != nil { return nil, err @@ -51,6 +53,7 @@ func NewCache(globalConfig GlobalConfig) (*Cache, error) { return &Cache{ globalConfig: globalConfig, scrapers: scrapers, + txnManager: txnManager, }, nil } @@ -178,7 +181,7 @@ func (c Cache) ScrapePerformerList(scraperID string, query string) ([]*models.Sc // find scraper with the provided id s := c.findScraper(scraperID) if s != nil { - return s.ScrapePerformerNames(query, c.globalConfig) + return s.ScrapePerformerNames(query, c.txnManager, c.globalConfig) } return nil, errors.New("Scraper with ID " + scraperID + " not found") @@ -190,7 +193,7 @@ func (c Cache) ScrapePerformer(scraperID string, scrapedPerformer models.Scraped // find scraper with the provided id s := c.findScraper(scraperID) if s != nil { - ret, err := s.ScrapePerformer(scrapedPerformer, c.globalConfig) + ret, err := s.ScrapePerformer(scrapedPerformer, c.txnManager, c.globalConfig) if err != nil { return nil, err } @@ -212,7 +215,7 @@ func (c Cache) ScrapePerformer(scraperID string, scrapedPerformer models.Scraped func (c Cache) ScrapePerformerURL(url string) (*models.ScrapedPerformer, error) { for _, s := range c.scrapers { if s.matchesPerformerURL(url) { - ret, err := s.ScrapePerformerURL(url, c.globalConfig) + ret, err := s.ScrapePerformerURL(url, c.txnManager, c.globalConfig) if err != nil { return nil, err } @@ -230,32 +233,43 @@ func (c Cache) ScrapePerformerURL(url string) (*models.ScrapedPerformer, error) } func (c Cache) postScrapeScene(ret *models.ScrapedScene) error { - for _, p := range ret.Performers { - err := models.MatchScrapedScenePerformer(p) - if err != nil { - return err - } - } + if err := c.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + pqb := r.Performer() + mqb := r.Movie() + tqb := r.Tag() + sqb := r.Studio() - for _, p := range ret.Movies { - err := models.MatchScrapedSceneMovie(p) - if err != nil { - return err + for _, p := range ret.Performers { + err := 
MatchScrapedScenePerformer(pqb, p) + if err != nil { + return err + } } - } - for _, t := range ret.Tags { - err := models.MatchScrapedSceneTag(t) - if err != nil { - return err + for _, p := range ret.Movies { + err := MatchScrapedSceneMovie(mqb, p) + if err != nil { + return err + } } - } - if ret.Studio != nil { - err := models.MatchScrapedSceneStudio(ret.Studio) - if err != nil { - return err + for _, t := range ret.Tags { + err := MatchScrapedSceneTag(tqb, t) + if err != nil { + return err + } } + + if ret.Studio != nil { + err := MatchScrapedSceneStudio(sqb, ret.Studio) + if err != nil { + return err + } + } + + return nil + }); err != nil { + return err } // post-process - set the image if applicable @@ -267,25 +281,35 @@ func (c Cache) postScrapeScene(ret *models.ScrapedScene) error { } func (c Cache) postScrapeGallery(ret *models.ScrapedGallery) error { - for _, p := range ret.Performers { - err := models.MatchScrapedScenePerformer(p) - if err != nil { - return err - } - } + if err := c.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + pqb := r.Performer() + tqb := r.Tag() + sqb := r.Studio() - for _, t := range ret.Tags { - err := models.MatchScrapedSceneTag(t) - if err != nil { - return err + for _, p := range ret.Performers { + err := MatchScrapedScenePerformer(pqb, p) + if err != nil { + return err + } } - } - if ret.Studio != nil { - err := models.MatchScrapedSceneStudio(ret.Studio) - if err != nil { - return err + for _, t := range ret.Tags { + err := MatchScrapedSceneTag(tqb, t) + if err != nil { + return err + } } + + if ret.Studio != nil { + err := MatchScrapedSceneStudio(sqb, ret.Studio) + if err != nil { + return err + } + } + + return nil + }); err != nil { + return err } return nil @@ -296,7 +320,7 @@ func (c Cache) ScrapeScene(scraperID string, scene models.SceneUpdateInput) (*mo // find scraper with the provided id s := c.findScraper(scraperID) if s != nil { - ret, err := s.ScrapeScene(scene, c.globalConfig) + ret, err := s.ScrapeScene(scene, c.txnManager, c.globalConfig) if err != nil { return nil, err @@ -321,7 +345,7 @@ func (c Cache) ScrapeScene(scraperID string, scene models.SceneUpdateInput) (*mo func (c Cache) ScrapeSceneURL(url string) (*models.ScrapedScene, error) { for _, s := range c.scrapers { if s.matchesSceneURL(url) { - ret, err := s.ScrapeSceneURL(url, c.globalConfig) + ret, err := s.ScrapeSceneURL(url, c.txnManager, c.globalConfig) if err != nil { return nil, err @@ -343,7 +367,7 @@ func (c Cache) ScrapeSceneURL(url string) (*models.ScrapedScene, error) { func (c Cache) ScrapeGallery(scraperID string, gallery models.GalleryUpdateInput) (*models.ScrapedGallery, error) { s := c.findScraper(scraperID) if s != nil { - ret, err := s.ScrapeGallery(gallery, c.globalConfig) + ret, err := s.ScrapeGallery(gallery, c.txnManager, c.globalConfig) if err != nil { return nil, err @@ -368,7 +392,7 @@ func (c Cache) ScrapeGallery(scraperID string, gallery models.GalleryUpdateInput func (c Cache) ScrapeGalleryURL(url string) (*models.ScrapedGallery, error) { for _, s := range c.scrapers { if s.matchesGalleryURL(url) { - ret, err := s.ScrapeGalleryURL(url, c.globalConfig) + ret, err := s.ScrapeGalleryURL(url, c.txnManager, c.globalConfig) if err != nil { return nil, err @@ -386,10 +410,8 @@ func (c Cache) ScrapeGalleryURL(url string) (*models.ScrapedGallery, error) { return nil, nil } -func matchMovieStudio(s *models.ScrapedMovieStudio) error { - qb := models.NewStudioQueryBuilder() - - studio, err := qb.FindByName(s.Name, nil, true) 
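Taken together, these Cache methods keep their public signatures; only the transaction manager supplied to NewCache is threaded through to the per-scraper configs. A rough sketch of how calling code wires this up, with the function and variable names here being assumptions rather than names from the patch:

// Illustrative sketch only; the names are assumptions.
func scrapeGalleryExample(globalConfig GlobalConfig, txnManager models.TransactionManager, url string) (*models.ScrapedGallery, error) {
	// The cache owns the transaction manager; callers pass only scraper input.
	cache, err := NewCache(globalConfig, txnManager)
	if err != nil {
		return nil, err
	}
	return cache.ScrapeGalleryURL(url)
}

The reworked matchMovieStudio below follows the same reader-passing convention as the helpers in matchers.go.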
+func matchMovieStudio(qb models.StudioReader, s *models.ScrapedMovieStudio) error { + studio, err := qb.FindByName(s.Name, true) if err != nil { return err @@ -411,14 +433,15 @@ func matchMovieStudio(s *models.ScrapedMovieStudio) error { func (c Cache) ScrapeMovieURL(url string) (*models.ScrapedMovie, error) { for _, s := range c.scrapers { if s.matchesMovieURL(url) { - ret, err := s.ScrapeMovieURL(url, c.globalConfig) + ret, err := s.ScrapeMovieURL(url, c.txnManager, c.globalConfig) if err != nil { return nil, err } if ret.Studio != nil { - err := matchMovieStudio(ret.Studio) - if err != nil { + if err := c.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + return matchMovieStudio(r.Studio(), ret.Studio) + }); err != nil { return nil, err } } diff --git a/pkg/scraper/stash.go b/pkg/scraper/stash.go index 873828e54..a53afcab5 100644 --- a/pkg/scraper/stash.go +++ b/pkg/scraper/stash.go @@ -15,13 +15,15 @@ type stashScraper struct { scraper scraperTypeConfig config config globalConfig GlobalConfig + txnManager models.TransactionManager } -func newStashScraper(scraper scraperTypeConfig, config config, globalConfig GlobalConfig) *stashScraper { +func newStashScraper(scraper scraperTypeConfig, txnManager models.TransactionManager, config config, globalConfig GlobalConfig) *stashScraper { return &stashScraper{ scraper: scraper, config: config, globalConfig: globalConfig, + txnManager: txnManager, } } @@ -117,15 +119,17 @@ func (s *stashScraper) scrapePerformerByFragment(scrapedPerformer models.Scraped func (s *stashScraper) scrapeSceneByFragment(scene models.SceneUpdateInput) (*models.ScrapedScene, error) { // query by MD5 // assumes that the scene exists in the database - qb := models.NewSceneQueryBuilder() id, err := strconv.Atoi(scene.ID) if err != nil { return nil, err } - storedScene, err := qb.Find(id) - - if err != nil { + var storedScene *models.Scene + if err := s.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + var err error + storedScene, err = r.Scene().Find(id) + return err + }); err != nil { return nil, err } @@ -185,17 +189,21 @@ func (s *stashScraper) scrapeSceneByFragment(scene models.SceneUpdateInput) (*mo } func (s *stashScraper) scrapeGalleryByFragment(scene models.GalleryUpdateInput) (*models.ScrapedGallery, error) { - // query by MD5 - // assumes that the gallery exists in the database - qb := models.NewGalleryQueryBuilder() id, err := strconv.Atoi(scene.ID) if err != nil { return nil, err } - storedGallery, err := qb.Find(id, nil) + // query by MD5 + // assumes that the gallery exists in the database + var storedGallery *models.Gallery + if err := s.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + qb := r.Gallery() - if err != nil { + var err error + storedGallery, err = qb.Find(id) + return err + }); err != nil { return nil, err } @@ -262,23 +270,36 @@ func (s *stashScraper) scrapeMovieByURL(url string) (*models.ScrapedMovie, error return nil, errors.New("scrapeMovieByURL not supported for stash scraper") } -func sceneFromUpdateFragment(scene models.SceneUpdateInput) (*models.Scene, error) { - qb := models.NewSceneQueryBuilder() +func sceneFromUpdateFragment(scene models.SceneUpdateInput, txnManager models.TransactionManager) (*models.Scene, error) { id, err := strconv.Atoi(scene.ID) if err != nil { return nil, err } // TODO - should we modify it with the input? 
- return qb.Find(id) + var ret *models.Scene + if err := txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + var err error + ret, err = r.Scene().Find(id) + return err + }); err != nil { + return nil, err + } + return ret, nil } -func galleryFromUpdateFragment(gallery models.GalleryUpdateInput) (*models.Gallery, error) { - qb := models.NewGalleryQueryBuilder() +func galleryFromUpdateFragment(gallery models.GalleryUpdateInput, txnManager models.TransactionManager) (ret *models.Gallery, err error) { id, err := strconv.Atoi(gallery.ID) if err != nil { return nil, err } - return qb.Find(id, nil) + if err := txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + ret, err = r.Gallery().Find(id) + return err + }); err != nil { + return nil, err + } + + return ret, nil } diff --git a/pkg/scraper/stashbox/stash_box.go b/pkg/scraper/stashbox/stash_box.go index 4b99d557d..77ae492f6 100644 --- a/pkg/scraper/stashbox/stash_box.go +++ b/pkg/scraper/stashbox/stash_box.go @@ -13,6 +13,7 @@ import ( "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/scraper" "github.com/stashapp/stash/pkg/scraper/stashbox/graphql" "github.com/stashapp/stash/pkg/utils" ) @@ -23,11 +24,12 @@ const imageGetTimeout = time.Second * 30 // Client represents the client interface to a stash-box server instance. type Client struct { - client *graphql.Client + client *graphql.Client + txnManager models.TransactionManager } // NewClient returns a new instance of a stash-box client. -func NewClient(box models.StashBox) *Client { +func NewClient(box models.StashBox, txnManager models.TransactionManager) *Client { authHeader := func(req *http.Request) { req.Header.Set("ApiKey", box.APIKey) } @@ -37,7 +39,8 @@ func NewClient(box models.StashBox) *Client { } return &Client{ - client: client, + client: client, + txnManager: txnManager, } } @@ -52,7 +55,7 @@ func (c Client) QueryStashBoxScene(queryStr string) ([]*models.ScrapedScene, err var ret []*models.ScrapedScene for _, s := range sceneFragments { - ss, err := sceneFragmentToScrapedScene(s) + ss, err := sceneFragmentToScrapedScene(c.txnManager, s) if err != nil { return nil, err } @@ -65,28 +68,38 @@ func (c Client) QueryStashBoxScene(queryStr string) ([]*models.ScrapedScene, err // FindStashBoxScenesByFingerprints queries stash-box for scenes using every // scene's MD5 checksum and/or oshash. 
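The stash-box client now mirrors scraper.NewCache: it is constructed with both the box configuration and the shared transaction manager, and database lookups in the methods below happen inside read transactions obtained from it. A small sketch of the wiring, not taken from the patch, with box, txnManager and query as assumed caller-supplied values:

// Illustrative sketch only; box, txnManager and query are assumed caller-supplied values.
func queryStashBoxExample(box models.StashBox, txnManager models.TransactionManager, query string) ([]*models.ScrapedScene, error) {
	// The client keeps the transaction manager for matching scraped results
	// against the local database.
	client := NewClient(box, txnManager)
	return client.QueryStashBoxScene(query)
}

The fingerprint lookups that follow then resolve local scene data through readers obtained from that manager.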
func (c Client) FindStashBoxScenesByFingerprints(sceneIDs []string) ([]*models.ScrapedScene, error) { - qb := models.NewSceneQueryBuilder() + ids, err := utils.StringSliceToIntSlice(sceneIDs) + if err != nil { + return nil, err + } var fingerprints []string - for _, sceneID := range sceneIDs { - idInt, _ := strconv.Atoi(sceneID) - scene, err := qb.Find(idInt) - if err != nil { - return nil, err + if err := c.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + qb := r.Scene() + + for _, sceneID := range ids { + scene, err := qb.Find(sceneID) + if err != nil { + return err + } + + if scene == nil { + return fmt.Errorf("scene with id %d not found", sceneID) + } + + if scene.Checksum.Valid { + fingerprints = append(fingerprints, scene.Checksum.String) + } + + if scene.OSHash.Valid { + fingerprints = append(fingerprints, scene.OSHash.String) + } } - if scene == nil { - return nil, fmt.Errorf("scene with id %d not found", idInt) - } - - if scene.Checksum.Valid { - fingerprints = append(fingerprints, scene.Checksum.String) - } - - if scene.OSHash.Valid { - fingerprints = append(fingerprints, scene.OSHash.String) - } + return nil + }); err != nil { + return nil, err } return c.findStashBoxScenesByFingerprints(fingerprints) @@ -108,7 +121,7 @@ func (c Client) findStashBoxScenesByFingerprints(fingerprints []string) ([]*mode sceneFragments := scenes.FindScenesByFingerprints for _, s := range sceneFragments { - ss, err := sceneFragmentToScrapedScene(s) + ss, err := sceneFragmentToScrapedScene(c.txnManager, s) if err != nil { return nil, err } @@ -120,59 +133,68 @@ func (c Client) findStashBoxScenesByFingerprints(fingerprints []string) ([]*mode } func (c Client) SubmitStashBoxFingerprints(sceneIDs []string, endpoint string) (bool, error) { - qb := models.NewSceneQueryBuilder() - jqb := models.NewJoinsQueryBuilder() + ids, err := utils.StringSliceToIntSlice(sceneIDs) + if err != nil { + return false, err + } var fingerprints []graphql.FingerprintSubmission - for _, sceneID := range sceneIDs { - idInt, _ := strconv.Atoi(sceneID) - scene, err := qb.Find(idInt) - if err != nil { - return false, err - } + if err := c.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + qb := r.Scene() - if scene == nil { - continue - } - - stashIDs, err := jqb.GetSceneStashIDs(idInt) - if err != nil { - return false, err - } - - sceneStashID := "" - for _, stashID := range stashIDs { - if stashID.Endpoint == endpoint { - sceneStashID = stashID.StashID + for _, sceneID := range ids { + scene, err := qb.Find(sceneID) + if err != nil { + return err } - } - if sceneStashID != "" { - if scene.Checksum.Valid && scene.Duration.Valid { - fingerprint := graphql.FingerprintInput{ - Hash: scene.Checksum.String, - Algorithm: graphql.FingerprintAlgorithmMd5, - Duration: int(scene.Duration.Float64), + if scene == nil { + continue + } + + stashIDs, err := qb.GetStashIDs(sceneID) + if err != nil { + return err + } + + sceneStashID := "" + for _, stashID := range stashIDs { + if stashID.Endpoint == endpoint { + sceneStashID = stashID.StashID } - fingerprints = append(fingerprints, graphql.FingerprintSubmission{ - SceneID: sceneStashID, - Fingerprint: &fingerprint, - }) } - if scene.OSHash.Valid && scene.Duration.Valid { - fingerprint := graphql.FingerprintInput{ - Hash: scene.OSHash.String, - Algorithm: graphql.FingerprintAlgorithmOshash, - Duration: int(scene.Duration.Float64), + if sceneStashID != "" { + if scene.Checksum.Valid && scene.Duration.Valid { + fingerprint := 
graphql.FingerprintInput{ + Hash: scene.Checksum.String, + Algorithm: graphql.FingerprintAlgorithmMd5, + Duration: int(scene.Duration.Float64), + } + fingerprints = append(fingerprints, graphql.FingerprintSubmission{ + SceneID: sceneStashID, + Fingerprint: &fingerprint, + }) + } + + if scene.OSHash.Valid && scene.Duration.Valid { + fingerprint := graphql.FingerprintInput{ + Hash: scene.OSHash.String, + Algorithm: graphql.FingerprintAlgorithmOshash, + Duration: int(scene.Duration.Float64), + } + fingerprints = append(fingerprints, graphql.FingerprintSubmission{ + SceneID: sceneStashID, + Fingerprint: &fingerprint, + }) } - fingerprints = append(fingerprints, graphql.FingerprintSubmission{ - SceneID: sceneStashID, - Fingerprint: &fingerprint, - }) } } + + return nil + }); err != nil { + return false, err } return c.submitStashBoxFingerprints(fingerprints) @@ -355,7 +377,7 @@ func getFingerprints(scene *graphql.SceneFragment) []*models.StashBoxFingerprint return fingerprints } -func sceneFragmentToScrapedScene(s *graphql.SceneFragment) (*models.ScrapedScene, error) { +func sceneFragmentToScrapedScene(txnManager models.TransactionManager, s *graphql.SceneFragment) (*models.ScrapedScene, error) { stashID := s.ID ss := &models.ScrapedScene{ Title: s.Title, @@ -375,42 +397,51 @@ func sceneFragmentToScrapedScene(s *graphql.SceneFragment) (*models.ScrapedScene ss.Image = getFirstImage(s.Images) } - if s.Studio != nil { - studioID := s.Studio.ID - ss.Studio = &models.ScrapedSceneStudio{ - Name: s.Studio.Name, - URL: findURL(s.Studio.Urls, "HOME"), - RemoteSiteID: &studioID, + if err := txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error { + pqb := r.Performer() + tqb := r.Tag() + + if s.Studio != nil { + studioID := s.Studio.ID + ss.Studio = &models.ScrapedSceneStudio{ + Name: s.Studio.Name, + URL: findURL(s.Studio.Urls, "HOME"), + RemoteSiteID: &studioID, + } + + err := scraper.MatchScrapedSceneStudio(r.Studio(), ss.Studio) + if err != nil { + return err + } } - err := models.MatchScrapedSceneStudio(ss.Studio) - if err != nil { - return nil, err - } - } + for _, p := range s.Performers { + sp := performerFragmentToScrapedScenePerformer(p.Performer) - for _, p := range s.Performers { - sp := performerFragmentToScrapedScenePerformer(p.Performer) + err := scraper.MatchScrapedScenePerformer(pqb, sp) + if err != nil { + return err + } - err := models.MatchScrapedScenePerformer(sp) - if err != nil { - return nil, err + ss.Performers = append(ss.Performers, sp) } - ss.Performers = append(ss.Performers, sp) - } + for _, t := range s.Tags { + st := &models.ScrapedSceneTag{ + Name: t.Name, + } - for _, t := range s.Tags { - st := &models.ScrapedSceneTag{ - Name: t.Name, + err := scraper.MatchScrapedSceneTag(tqb, st) + if err != nil { + return err + } + + ss.Tags = append(ss.Tags, st) } - err := models.MatchScrapedSceneTag(st) - if err != nil { - return nil, err - } - - ss.Tags = append(ss.Tags, st) + return nil + }); err != nil { + return nil, err } return ss, nil diff --git a/pkg/scraper/xpath.go b/pkg/scraper/xpath.go index 787187bd0..f4f55bcdc 100644 --- a/pkg/scraper/xpath.go +++ b/pkg/scraper/xpath.go @@ -19,13 +19,15 @@ type xpathScraper struct { scraper scraperTypeConfig config config globalConfig GlobalConfig + txnManager models.TransactionManager } -func newXpathScraper(scraper scraperTypeConfig, config config, globalConfig GlobalConfig) *xpathScraper { +func newXpathScraper(scraper scraperTypeConfig, txnManager models.TransactionManager, config config, globalConfig 
GlobalConfig) *xpathScraper { return &xpathScraper{ scraper: scraper, config: config, globalConfig: globalConfig, + txnManager: txnManager, } } @@ -119,7 +121,7 @@ func (s *xpathScraper) scrapePerformerByFragment(scrapedPerformer models.Scraped } func (s *xpathScraper) scrapeSceneByFragment(scene models.SceneUpdateInput) (*models.ScrapedScene, error) { - storedScene, err := sceneFromUpdateFragment(scene) + storedScene, err := sceneFromUpdateFragment(scene, s.txnManager) if err != nil { return nil, err } @@ -152,7 +154,7 @@ func (s *xpathScraper) scrapeSceneByFragment(scene models.SceneUpdateInput) (*mo } func (s *xpathScraper) scrapeGalleryByFragment(gallery models.GalleryUpdateInput) (*models.ScrapedGallery, error) { - storedGallery, err := galleryFromUpdateFragment(gallery) + storedGallery, err := galleryFromUpdateFragment(gallery, s.txnManager) if err != nil { return nil, err } diff --git a/pkg/scraper/xpath_test.go b/pkg/scraper/xpath_test.go index bd81b121c..f5950460c 100644 --- a/pkg/scraper/xpath_test.go +++ b/pkg/scraper/xpath_test.go @@ -807,7 +807,7 @@ xPathScrapers: globalConfig := GlobalConfig{} - performer, err := c.ScrapePerformerURL(ts.URL, globalConfig) + performer, err := c.ScrapePerformerURL(ts.URL, nil, globalConfig) if err != nil { t.Errorf("Error scraping performer: %s", err.Error()) diff --git a/pkg/models/querybuilder_gallery.go b/pkg/sqlite/gallery.go similarity index 53% rename from pkg/models/querybuilder_gallery.go rename to pkg/sqlite/gallery.go index 4a8c8ff2d..cdc3ae0c3 100644 --- a/pkg/models/querybuilder_gallery.go +++ b/pkg/sqlite/gallery.go @@ -1,4 +1,4 @@ -package models +package sqlite import ( "database/sql" @@ -6,105 +6,79 @@ import ( "path/filepath" "strconv" - "github.com/jmoiron/sqlx" - "github.com/stashapp/stash/pkg/database" + "github.com/stashapp/stash/pkg/models" ) const galleryTable = "galleries" -type GalleryQueryBuilder struct{} +const performersGalleriesTable = "performers_galleries" +const galleriesTagsTable = "galleries_tags" +const galleriesImagesTable = "galleries_images" +const galleryIDColumn = "gallery_id" -func NewGalleryQueryBuilder() GalleryQueryBuilder { - return GalleryQueryBuilder{} +type galleryQueryBuilder struct { + repository } -func (qb *GalleryQueryBuilder) Create(newGallery Gallery, tx *sqlx.Tx) (*Gallery, error) { - ensureTx(tx) - result, err := tx.NamedExec( - `INSERT INTO galleries (path, checksum, zip, title, date, details, url, studio_id, rating, organized, scene_id, file_mod_time, created_at, updated_at) - VALUES (:path, :checksum, :zip, :title, :date, :details, :url, :studio_id, :rating, :organized, :scene_id, :file_mod_time, :created_at, :updated_at) - `, - newGallery, - ) - if err != nil { - return nil, err +func NewGalleryReaderWriter(tx dbi) *galleryQueryBuilder { + return &galleryQueryBuilder{ + repository{ + tx: tx, + tableName: galleryTable, + idColumn: idColumn, + }, } - galleryID, err := result.LastInsertId() - if err != nil { - return nil, err - } - if err := tx.Get(&newGallery, `SELECT * FROM galleries WHERE id = ? 
LIMIT 1`, galleryID); err != nil { - return nil, err - } - return &newGallery, nil } -func (qb *GalleryQueryBuilder) Update(updatedGallery Gallery, tx *sqlx.Tx) (*Gallery, error) { - ensureTx(tx) - _, err := tx.NamedExec( - `UPDATE galleries SET `+SQLGenKeys(updatedGallery)+` WHERE galleries.id = :id`, - updatedGallery, - ) - if err != nil { +func (qb *galleryQueryBuilder) Create(newObject models.Gallery) (*models.Gallery, error) { + var ret models.Gallery + if err := qb.insertObject(newObject, &ret); err != nil { return nil, err } - if err := tx.Get(&updatedGallery, `SELECT * FROM galleries WHERE id = ? LIMIT 1`, updatedGallery.ID); err != nil { - return nil, err - } - return &updatedGallery, nil + return &ret, nil } -func (qb *GalleryQueryBuilder) UpdatePartial(updatedGallery GalleryPartial, tx *sqlx.Tx) (*Gallery, error) { - ensureTx(tx) - _, err := tx.NamedExec( - `UPDATE galleries SET `+SQLGenKeysPartial(updatedGallery)+` WHERE galleries.id = :id`, - updatedGallery, - ) - if err != nil { +func (qb *galleryQueryBuilder) Update(updatedObject models.Gallery) (*models.Gallery, error) { + const partial = false + if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil { return nil, err } - return qb.Find(updatedGallery.ID, tx) + return qb.Find(updatedObject.ID) } -func (qb *GalleryQueryBuilder) UpdateChecksum(id int, checksum string, tx *sqlx.Tx) error { - ensureTx(tx) - _, err := tx.Exec( - `UPDATE galleries SET checksum = ? WHERE galleries.id = ? `, - checksum, id, - ) - if err != nil { - return err +func (qb *galleryQueryBuilder) UpdatePartial(updatedObject models.GalleryPartial) (*models.Gallery, error) { + const partial = true + if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil { + return nil, err } - return nil + return qb.Find(updatedObject.ID) } -func (qb *GalleryQueryBuilder) UpdateFileModTime(id int, modTime NullSQLiteTimestamp, tx *sqlx.Tx) error { - ensureTx(tx) - _, err := tx.Exec( - `UPDATE galleries SET file_mod_time = ? WHERE galleries.id = ? `, - modTime, id, - ) - if err != nil { - return err - } - - return nil +func (qb *galleryQueryBuilder) UpdateChecksum(id int, checksum string) error { + return qb.updateMap(id, map[string]interface{}{ + "checksum": checksum, + }) } -func (qb *GalleryQueryBuilder) Destroy(id int, tx *sqlx.Tx) error { - return executeDeleteQuery("galleries", strconv.Itoa(id), tx) +func (qb *galleryQueryBuilder) UpdateFileModTime(id int, modTime models.NullSQLiteTimestamp) error { + return qb.updateMap(id, map[string]interface{}{ + "file_mod_time": modTime, + }) +} + +func (qb *galleryQueryBuilder) Destroy(id int) error { + return qb.destroyExisting([]int{id}) } type GalleryNullSceneID struct { SceneID sql.NullInt64 } -func (qb *GalleryQueryBuilder) ClearGalleryId(sceneID int, tx *sqlx.Tx) error { - ensureTx(tx) - _, err := tx.NamedExec( +func (qb *galleryQueryBuilder) ClearGalleryId(sceneID int) error { + _, err := qb.tx.NamedExec( `UPDATE galleries SET scene_id = null WHERE scene_id = :sceneid`, GalleryNullSceneID{ SceneID: sql.NullInt64{ @@ -116,16 +90,21 @@ func (qb *GalleryQueryBuilder) ClearGalleryId(sceneID int, tx *sqlx.Tx) error { return err } -func (qb *GalleryQueryBuilder) Find(id int, tx *sqlx.Tx) (*Gallery, error) { - query := "SELECT * FROM galleries WHERE id = ? 
LIMIT 1" - args := []interface{}{id} - return qb.queryGallery(query, args, tx) +func (qb *galleryQueryBuilder) Find(id int) (*models.Gallery, error) { + var ret models.Gallery + if err := qb.get(id, &ret); err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + return &ret, nil } -func (qb *GalleryQueryBuilder) FindMany(ids []int) ([]*Gallery, error) { - var galleries []*Gallery +func (qb *galleryQueryBuilder) FindMany(ids []int) ([]*models.Gallery, error) { + var galleries []*models.Gallery for _, id := range ids { - gallery, err := qb.Find(id, nil) + gallery, err := qb.Find(id) if err != nil { return nil, err } @@ -140,67 +119,65 @@ func (qb *GalleryQueryBuilder) FindMany(ids []int) ([]*Gallery, error) { return galleries, nil } -func (qb *GalleryQueryBuilder) FindByChecksum(checksum string, tx *sqlx.Tx) (*Gallery, error) { +func (qb *galleryQueryBuilder) FindByChecksum(checksum string) (*models.Gallery, error) { query := "SELECT * FROM galleries WHERE checksum = ? LIMIT 1" args := []interface{}{checksum} - return qb.queryGallery(query, args, tx) + return qb.queryGallery(query, args) } -func (qb *GalleryQueryBuilder) FindByPath(path string) (*Gallery, error) { +func (qb *galleryQueryBuilder) FindByPath(path string) (*models.Gallery, error) { query := "SELECT * FROM galleries WHERE path = ? LIMIT 1" args := []interface{}{path} - return qb.queryGallery(query, args, nil) + return qb.queryGallery(query, args) } -func (qb *GalleryQueryBuilder) FindBySceneID(sceneID int, tx *sqlx.Tx) (*Gallery, error) { +func (qb *galleryQueryBuilder) FindBySceneID(sceneID int) (*models.Gallery, error) { query := "SELECT galleries.* FROM galleries WHERE galleries.scene_id = ? LIMIT 1" args := []interface{}{sceneID} - return qb.queryGallery(query, args, tx) + return qb.queryGallery(query, args) } -func (qb *GalleryQueryBuilder) ValidGalleriesForScenePath(scenePath string) ([]*Gallery, error) { +func (qb *galleryQueryBuilder) ValidGalleriesForScenePath(scenePath string) ([]*models.Gallery, error) { sceneDirPath := filepath.Dir(scenePath) query := "SELECT galleries.* FROM galleries WHERE galleries.scene_id IS NULL AND galleries.path LIKE '" + sceneDirPath + "%' ORDER BY path ASC" - return qb.queryGalleries(query, nil, nil) + return qb.queryGalleries(query, nil) } -func (qb *GalleryQueryBuilder) FindByImageID(imageID int, tx *sqlx.Tx) ([]*Gallery, error) { +func (qb *galleryQueryBuilder) FindByImageID(imageID int) ([]*models.Gallery, error) { query := selectAll(galleryTable) + ` LEFT JOIN galleries_images as images_join on images_join.gallery_id = galleries.id WHERE images_join.image_id = ? GROUP BY galleries.id ` args := []interface{}{imageID} - return qb.queryGalleries(query, args, tx) + return qb.queryGalleries(query, args) } -func (qb *GalleryQueryBuilder) CountByImageID(imageID int) (int, error) { +func (qb *galleryQueryBuilder) CountByImageID(imageID int) (int, error) { query := `SELECT image_id FROM galleries_images WHERE image_id = ? 
GROUP BY gallery_id` args := []interface{}{imageID} - return runCountQuery(buildCountQuery(query), args) + return qb.runCountQuery(qb.buildCountQuery(query), args) } -func (qb *GalleryQueryBuilder) Count() (int, error) { - return runCountQuery(buildCountQuery("SELECT galleries.id FROM galleries"), nil) +func (qb *galleryQueryBuilder) Count() (int, error) { + return qb.runCountQuery(qb.buildCountQuery("SELECT galleries.id FROM galleries"), nil) } -func (qb *GalleryQueryBuilder) All() ([]*Gallery, error) { - return qb.queryGalleries(selectAll("galleries")+qb.getGallerySort(nil), nil, nil) +func (qb *galleryQueryBuilder) All() ([]*models.Gallery, error) { + return qb.queryGalleries(selectAll("galleries")+qb.getGallerySort(nil), nil) } -func (qb *GalleryQueryBuilder) Query(galleryFilter *GalleryFilterType, findFilter *FindFilterType) ([]*Gallery, int) { +func (qb *galleryQueryBuilder) Query(galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) ([]*models.Gallery, int, error) { if galleryFilter == nil { - galleryFilter = &GalleryFilterType{} + galleryFilter = &models.GalleryFilterType{} } if findFilter == nil { - findFilter = &FindFilterType{} + findFilter = &models.FindFilterType{} } - query := queryBuilder{ - tableName: galleryTable, - } + query := qb.newQuery() query.body = selectDistinctIDs("galleries") query.body += ` @@ -292,18 +269,25 @@ func (qb *GalleryQueryBuilder) Query(galleryFilter *GalleryFilterType, findFilte } query.sortAndPagination = qb.getGallerySort(findFilter) + getPagination(findFilter) - idsResult, countResult := query.executeFind() + idsResult, countResult, err := query.executeFind() + if err != nil { + return nil, 0, err + } - var galleries []*Gallery + var galleries []*models.Gallery for _, id := range idsResult { - gallery, _ := qb.Find(id, nil) + gallery, err := qb.Find(id) + if err != nil { + return nil, 0, err + } + galleries = append(galleries, gallery) } - return galleries, countResult + return galleries, countResult, nil } -func (qb *GalleryQueryBuilder) handleAverageResolutionFilter(query *queryBuilder, resolutionFilter *ResolutionEnum) { +func (qb *galleryQueryBuilder) handleAverageResolutionFilter(query *queryBuilder, resolutionFilter *models.ResolutionEnum) { if resolutionFilter == nil { return } @@ -369,7 +353,7 @@ func (qb *GalleryQueryBuilder) handleAverageResolutionFilter(query *queryBuilder } } -func (qb *GalleryQueryBuilder) getGallerySort(findFilter *FindFilterType) string { +func (qb *galleryQueryBuilder) getGallerySort(findFilter *models.FindFilterType) string { var sort string var direction string if findFilter == nil { @@ -382,40 +366,79 @@ func (qb *GalleryQueryBuilder) getGallerySort(findFilter *FindFilterType) string return getSort(sort, direction, "galleries") } -func (qb *GalleryQueryBuilder) queryGallery(query string, args []interface{}, tx *sqlx.Tx) (*Gallery, error) { - results, err := qb.queryGalleries(query, args, tx) +func (qb *galleryQueryBuilder) queryGallery(query string, args []interface{}) (*models.Gallery, error) { + results, err := qb.queryGalleries(query, args) if err != nil || len(results) < 1 { return nil, err } return results[0], nil } -func (qb *GalleryQueryBuilder) queryGalleries(query string, args []interface{}, tx *sqlx.Tx) ([]*Gallery, error) { - var rows *sqlx.Rows - var err error - if tx != nil { - rows, err = tx.Queryx(query, args...) - } else { - rows, err = database.DB.Queryx(query, args...) 
- } - - if err != nil && err != sql.ErrNoRows { - return nil, err - } - defer rows.Close() - - galleries := make([]*Gallery, 0) - for rows.Next() { - gallery := Gallery{} - if err := rows.StructScan(&gallery); err != nil { - return nil, err - } - galleries = append(galleries, &gallery) - } - - if err := rows.Err(); err != nil { +func (qb *galleryQueryBuilder) queryGalleries(query string, args []interface{}) ([]*models.Gallery, error) { + var ret models.Galleries + if err := qb.query(query, args, &ret); err != nil { return nil, err } - return galleries, nil + return []*models.Gallery(ret), nil +} + +func (qb *galleryQueryBuilder) performersRepository() *joinRepository { + return &joinRepository{ + repository: repository{ + tx: qb.tx, + tableName: performersGalleriesTable, + idColumn: galleryIDColumn, + }, + fkColumn: "performer_id", + } +} + +func (qb *galleryQueryBuilder) GetPerformerIDs(galleryID int) ([]int, error) { + return qb.performersRepository().getIDs(galleryID) +} + +func (qb *galleryQueryBuilder) UpdatePerformers(galleryID int, performerIDs []int) error { + // Delete the existing joins and then create new ones + return qb.performersRepository().replace(galleryID, performerIDs) +} + +func (qb *galleryQueryBuilder) tagsRepository() *joinRepository { + return &joinRepository{ + repository: repository{ + tx: qb.tx, + tableName: galleriesTagsTable, + idColumn: galleryIDColumn, + }, + fkColumn: "tag_id", + } +} + +func (qb *galleryQueryBuilder) GetTagIDs(galleryID int) ([]int, error) { + return qb.tagsRepository().getIDs(galleryID) +} + +func (qb *galleryQueryBuilder) UpdateTags(galleryID int, tagIDs []int) error { + // Delete the existing joins and then create new ones + return qb.tagsRepository().replace(galleryID, tagIDs) +} + +func (qb *galleryQueryBuilder) imagesRepository() *joinRepository { + return &joinRepository{ + repository: repository{ + tx: qb.tx, + tableName: galleriesImagesTable, + idColumn: galleryIDColumn, + }, + fkColumn: "image_id", + } +} + +func (qb *galleryQueryBuilder) GetImageIDs(galleryID int) ([]int, error) { + return qb.imagesRepository().getIDs(galleryID) +} + +func (qb *galleryQueryBuilder) UpdateImages(galleryID int, imageIDs []int) error { + // Delete the existing joins and then create new ones + return qb.imagesRepository().replace(galleryID, imageIDs) } diff --git a/pkg/sqlite/gallery_test.go b/pkg/sqlite/gallery_test.go new file mode 100644 index 000000000..8ca671db2 --- /dev/null +++ b/pkg/sqlite/gallery_test.go @@ -0,0 +1,274 @@ +// +build integration + +package sqlite_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/stashapp/stash/pkg/models" +) + +func TestGalleryFind(t *testing.T) { + withTxn(func(r models.Repository) error { + gqb := r.Gallery() + + const galleryIdx = 0 + gallery, err := gqb.Find(galleryIDs[galleryIdx]) + + if err != nil { + t.Errorf("Error finding gallery: %s", err.Error()) + } + + assert.Equal(t, getGalleryStringValue(galleryIdx, "Path"), gallery.Path.String) + + gallery, err = gqb.Find(0) + + if err != nil { + t.Errorf("Error finding gallery: %s", err.Error()) + } + + assert.Nil(t, gallery) + + return nil + }) +} + +func TestGalleryFindByChecksum(t *testing.T) { + withTxn(func(r models.Repository) error { + gqb := r.Gallery() + + const galleryIdx = 0 + galleryChecksum := getGalleryStringValue(galleryIdx, "Checksum") + gallery, err := gqb.FindByChecksum(galleryChecksum) + + if err != nil { + t.Errorf("Error finding gallery: %s", err.Error()) + } + + assert.Equal(t, 
getGalleryStringValue(galleryIdx, "Path"), gallery.Path.String) + + galleryChecksum = "not exist" + gallery, err = gqb.FindByChecksum(galleryChecksum) + + if err != nil { + t.Errorf("Error finding gallery: %s", err.Error()) + } + + assert.Nil(t, gallery) + + return nil + }) +} + +func TestGalleryFindByPath(t *testing.T) { + withTxn(func(r models.Repository) error { + gqb := r.Gallery() + + const galleryIdx = 0 + galleryPath := getGalleryStringValue(galleryIdx, "Path") + gallery, err := gqb.FindByPath(galleryPath) + + if err != nil { + t.Errorf("Error finding gallery: %s", err.Error()) + } + + assert.Equal(t, galleryPath, gallery.Path.String) + + galleryPath = "not exist" + gallery, err = gqb.FindByPath(galleryPath) + + if err != nil { + t.Errorf("Error finding gallery: %s", err.Error()) + } + + assert.Nil(t, gallery) + + return nil + }) +} + +func TestGalleryFindBySceneID(t *testing.T) { + withTxn(func(r models.Repository) error { + gqb := r.Gallery() + + sceneID := sceneIDs[sceneIdxWithGallery] + gallery, err := gqb.FindBySceneID(sceneID) + + if err != nil { + t.Errorf("Error finding gallery: %s", err.Error()) + } + + assert.Equal(t, getGalleryStringValue(galleryIdxWithScene, "Path"), gallery.Path.String) + + gallery, err = gqb.FindBySceneID(0) + + if err != nil { + t.Errorf("Error finding gallery: %s", err.Error()) + } + + assert.Nil(t, gallery) + + return nil + }) +} + +func TestGalleryQueryQ(t *testing.T) { + withTxn(func(r models.Repository) error { + const galleryIdx = 0 + + q := getGalleryStringValue(galleryIdx, pathField) + + sqb := r.Gallery() + + galleryQueryQ(t, sqb, q, galleryIdx) + + return nil + }) +} + +func galleryQueryQ(t *testing.T, qb models.GalleryReader, q string, expectedGalleryIdx int) { + filter := models.FindFilterType{ + Q: &q, + } + galleries, _, err := qb.Query(nil, &filter) + if err != nil { + t.Errorf("Error querying gallery: %s", err.Error()) + } + + assert.Len(t, galleries, 1) + gallery := galleries[0] + assert.Equal(t, galleryIDs[expectedGalleryIdx], gallery.ID) + + // no Q should return all results + filter.Q = nil + galleries, _, err = qb.Query(nil, &filter) + if err != nil { + t.Errorf("Error querying gallery: %s", err.Error()) + } + + assert.Len(t, galleries, totalGalleries) +} + +func TestGalleryQueryPath(t *testing.T) { + withTxn(func(r models.Repository) error { + const galleryIdx = 1 + galleryPath := getGalleryStringValue(galleryIdx, "Path") + + pathCriterion := models.StringCriterionInput{ + Value: galleryPath, + Modifier: models.CriterionModifierEquals, + } + + verifyGalleriesPath(t, r.Gallery(), pathCriterion) + + pathCriterion.Modifier = models.CriterionModifierNotEquals + verifyGalleriesPath(t, r.Gallery(), pathCriterion) + + return nil + }) +} + +func verifyGalleriesPath(t *testing.T, sqb models.GalleryReader, pathCriterion models.StringCriterionInput) { + galleryFilter := models.GalleryFilterType{ + Path: &pathCriterion, + } + + galleries, _, err := sqb.Query(&galleryFilter, nil) + if err != nil { + t.Errorf("Error querying gallery: %s", err.Error()) + } + + for _, gallery := range galleries { + verifyNullString(t, gallery.Path, pathCriterion) + } +} + +func TestGalleryQueryRating(t *testing.T) { + const rating = 3 + ratingCriterion := models.IntCriterionInput{ + Value: rating, + Modifier: models.CriterionModifierEquals, + } + + verifyGalleriesRating(t, ratingCriterion) + + ratingCriterion.Modifier = models.CriterionModifierNotEquals + verifyGalleriesRating(t, ratingCriterion) + + ratingCriterion.Modifier = models.CriterionModifierGreaterThan 
+ verifyGalleriesRating(t, ratingCriterion) + + ratingCriterion.Modifier = models.CriterionModifierLessThan + verifyGalleriesRating(t, ratingCriterion) + + ratingCriterion.Modifier = models.CriterionModifierIsNull + verifyGalleriesRating(t, ratingCriterion) + + ratingCriterion.Modifier = models.CriterionModifierNotNull + verifyGalleriesRating(t, ratingCriterion) +} + +func verifyGalleriesRating(t *testing.T, ratingCriterion models.IntCriterionInput) { + withTxn(func(r models.Repository) error { + sqb := r.Gallery() + galleryFilter := models.GalleryFilterType{ + Rating: &ratingCriterion, + } + + galleries, _, err := sqb.Query(&galleryFilter, nil) + if err != nil { + t.Errorf("Error querying gallery: %s", err.Error()) + } + + for _, gallery := range galleries { + verifyInt64(t, gallery.Rating, ratingCriterion) + } + + return nil + }) +} + +func TestGalleryQueryIsMissingScene(t *testing.T) { + withTxn(func(r models.Repository) error { + qb := r.Gallery() + isMissing := "scene" + galleryFilter := models.GalleryFilterType{ + IsMissing: &isMissing, + } + + q := getGalleryStringValue(galleryIdxWithScene, titleField) + findFilter := models.FindFilterType{ + Q: &q, + } + + galleries, _, err := qb.Query(&galleryFilter, &findFilter) + if err != nil { + t.Errorf("Error querying gallery: %s", err.Error()) + } + + assert.Len(t, galleries, 0) + + findFilter.Q = nil + galleries, _, err = qb.Query(&galleryFilter, &findFilter) + if err != nil { + t.Errorf("Error querying gallery: %s", err.Error()) + } + + // ensure none of the ids equal the one with the scene + for _, gallery := range galleries { + assert.NotEqual(t, galleryIDs[galleryIdxWithScene], gallery.ID) + } + + return nil + }) +} + +// TODO ValidGalleriesForScenePath +// TODO Count +// TODO All +// TODO Query +// TODO Update +// TODO Destroy +// TODO ClearGalleryId diff --git a/pkg/models/querybuilder_image.go b/pkg/sqlite/image.go similarity index 59% rename from pkg/models/querybuilder_image.go rename to pkg/sqlite/image.go index 04fb54df5..48123786b 100644 --- a/pkg/models/querybuilder_image.go +++ b/pkg/sqlite/image.go @@ -1,15 +1,16 @@ -package models +package sqlite import ( "database/sql" "fmt" - "strconv" - "github.com/jmoiron/sqlx" - "github.com/stashapp/stash/pkg/database" + "github.com/stashapp/stash/pkg/models" ) const imageTable = "images" +const imageIDColumn = "image_id" +const performersImagesTable = "performers_images" +const imagesTagsTable = "images_tags" var imagesForPerformerQuery = selectAll(imageTable) + ` LEFT JOIN performers_images as performers_join on performers_join.image_id = images.id @@ -52,85 +53,57 @@ WHERE gallery_id = ?
GROUP BY image_id ` -type ImageQueryBuilder struct{} - -func NewImageQueryBuilder() ImageQueryBuilder { - return ImageQueryBuilder{} +type imageQueryBuilder struct { + repository } -func (qb *ImageQueryBuilder) Create(newImage Image, tx *sqlx.Tx) (*Image, error) { - ensureTx(tx) - result, err := tx.NamedExec( - `INSERT INTO images (checksum, path, title, rating, organized, o_counter, size, - width, height, studio_id, file_mod_time, created_at, updated_at) - VALUES (:checksum, :path, :title, :rating, :organized, :o_counter, :size, - :width, :height, :studio_id, :file_mod_time, :created_at, :updated_at) - `, - newImage, - ) - if err != nil { - return nil, err +func NewImageReaderWriter(tx dbi) *imageQueryBuilder { + return &imageQueryBuilder{ + repository{ + tx: tx, + tableName: imageTable, + idColumn: idColumn, + }, } - imageID, err := result.LastInsertId() - if err != nil { - return nil, err - } - if err := tx.Get(&newImage, `SELECT * FROM images WHERE id = ? LIMIT 1`, imageID); err != nil { - return nil, err - } - return &newImage, nil } -func (qb *ImageQueryBuilder) Update(updatedImage ImagePartial, tx *sqlx.Tx) (*Image, error) { - ensureTx(tx) - _, err := tx.NamedExec( - `UPDATE images SET `+SQLGenKeysPartial(updatedImage)+` WHERE images.id = :id`, - updatedImage, - ) - if err != nil { +func (qb *imageQueryBuilder) Create(newObject models.Image) (*models.Image, error) { + var ret models.Image + if err := qb.insertObject(newObject, &ret); err != nil { return nil, err } - return qb.find(updatedImage.ID, tx) + return &ret, nil } -func (qb *ImageQueryBuilder) UpdateFull(updatedImage Image, tx *sqlx.Tx) (*Image, error) { - ensureTx(tx) - _, err := tx.NamedExec( - `UPDATE images SET `+SQLGenKeys(updatedImage)+` WHERE images.id = :id`, - updatedImage, - ) - if err != nil { +func (qb *imageQueryBuilder) Update(updatedObject models.ImagePartial) (*models.Image, error) { + const partial = true + if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil { return nil, err } - return qb.find(updatedImage.ID, tx) + return qb.find(updatedObject.ID) } -func (qb *ImageQueryBuilder) UpdateFileModTime(id int, modTime NullSQLiteTimestamp, tx *sqlx.Tx) error { - ensureTx(tx) - _, err := tx.Exec( - `UPDATE images SET file_mod_time = ? WHERE images.id = ? `, - modTime, id, - ) - if err != nil { - return err +func (qb *imageQueryBuilder) UpdateFull(updatedObject models.Image) (*models.Image, error) { + const partial = false + if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil { + return nil, err } - return nil + return qb.find(updatedObject.ID) } -func (qb *ImageQueryBuilder) IncrementOCounter(id int, tx *sqlx.Tx) (int, error) { - ensureTx(tx) - _, err := tx.Exec( - `UPDATE images SET o_counter = o_counter + 1 WHERE images.id = ?`, +func (qb *imageQueryBuilder) IncrementOCounter(id int) (int, error) { + _, err := qb.tx.Exec( + `UPDATE `+imageTable+` SET o_counter = o_counter + 1 WHERE `+imageTable+`.id = ?`, id, ) if err != nil { return 0, err } - image, err := qb.find(id, tx) + image, err := qb.find(id) if err != nil { return 0, err } @@ -138,17 +111,16 @@ func (qb *ImageQueryBuilder) IncrementOCounter(id int, tx *sqlx.Tx) (int, error) return image.OCounter, nil } -func (qb *ImageQueryBuilder) DecrementOCounter(id int, tx *sqlx.Tx) (int, error) { - ensureTx(tx) - _, err := tx.Exec( - `UPDATE images SET o_counter = o_counter - 1 WHERE images.id = ? 
and images.o_counter > 0`, +func (qb *imageQueryBuilder) DecrementOCounter(id int) (int, error) { + _, err := qb.tx.Exec( + `UPDATE `+imageTable+` SET o_counter = o_counter - 1 WHERE `+imageTable+`.id = ? and `+imageTable+`.o_counter > 0`, id, ) if err != nil { return 0, err } - image, err := qb.find(id, tx) + image, err := qb.find(id) if err != nil { return 0, err } @@ -156,17 +128,16 @@ func (qb *ImageQueryBuilder) DecrementOCounter(id int, tx *sqlx.Tx) (int, error) return image.OCounter, nil } -func (qb *ImageQueryBuilder) ResetOCounter(id int, tx *sqlx.Tx) (int, error) { - ensureTx(tx) - _, err := tx.Exec( - `UPDATE images SET o_counter = 0 WHERE images.id = ?`, +func (qb *imageQueryBuilder) ResetOCounter(id int) (int, error) { + _, err := qb.tx.Exec( + `UPDATE `+imageTable+` SET o_counter = 0 WHERE `+imageTable+`.id = ?`, id, ) if err != nil { return 0, err } - image, err := qb.find(id, tx) + image, err := qb.find(id) if err != nil { return 0, err } @@ -174,15 +145,16 @@ func (qb *ImageQueryBuilder) ResetOCounter(id int, tx *sqlx.Tx) (int, error) { return image.OCounter, nil } -func (qb *ImageQueryBuilder) Destroy(id int, tx *sqlx.Tx) error { - return executeDeleteQuery("images", strconv.Itoa(id), tx) -} -func (qb *ImageQueryBuilder) Find(id int) (*Image, error) { - return qb.find(id, nil) +func (qb *imageQueryBuilder) Destroy(id int) error { + return qb.destroyExisting([]int{id}) } -func (qb *ImageQueryBuilder) FindMany(ids []int) ([]*Image, error) { - var images []*Image +func (qb *imageQueryBuilder) Find(id int) (*models.Image, error) { + return qb.find(id) +} + +func (qb *imageQueryBuilder) FindMany(ids []int) ([]*models.Image, error) { + var images []*models.Image for _, id := range ids { image, err := qb.Find(id) if err != nil { @@ -199,82 +171,60 @@ func (qb *ImageQueryBuilder) FindMany(ids []int) ([]*Image, error) { return images, nil } -func (qb *ImageQueryBuilder) find(id int, tx *sqlx.Tx) (*Image, error) { - query := selectAll(imageTable) + "WHERE id = ? LIMIT 1" - args := []interface{}{id} - return qb.queryImage(query, args, tx) +func (qb *imageQueryBuilder) find(id int) (*models.Image, error) { + var ret models.Image + if err := qb.get(id, &ret); err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + return &ret, nil } -func (qb *ImageQueryBuilder) FindByChecksum(checksum string) (*Image, error) { +func (qb *imageQueryBuilder) FindByChecksum(checksum string) (*models.Image, error) { query := "SELECT * FROM images WHERE checksum = ? LIMIT 1" args := []interface{}{checksum} - return qb.queryImage(query, args, nil) + return qb.queryImage(query, args) } -func (qb *ImageQueryBuilder) FindByPath(path string) (*Image, error) { +func (qb *imageQueryBuilder) FindByPath(path string) (*models.Image, error) { query := selectAll(imageTable) + "WHERE path = ? 
LIMIT 1" args := []interface{}{path} - return qb.queryImage(query, args, nil) + return qb.queryImage(query, args) } -func (qb *ImageQueryBuilder) FindByPerformerID(performerID int) ([]*Image, error) { - args := []interface{}{performerID} - return qb.queryImages(imagesForPerformerQuery, args, nil) -} - -func (qb *ImageQueryBuilder) CountByPerformerID(performerID int) (int, error) { - args := []interface{}{performerID} - return runCountQuery(buildCountQuery(countImagesForPerformerQuery), args) -} - -func (qb *ImageQueryBuilder) FindByStudioID(studioID int) ([]*Image, error) { - args := []interface{}{studioID} - return qb.queryImages(imagesForStudioQuery, args, nil) -} - -func (qb *ImageQueryBuilder) FindByGalleryID(galleryID int) ([]*Image, error) { +func (qb *imageQueryBuilder) FindByGalleryID(galleryID int) ([]*models.Image, error) { args := []interface{}{galleryID} - return qb.queryImages(imagesForGalleryQuery+qb.getImageSort(nil), args, nil) + return qb.queryImages(imagesForGalleryQuery+qb.getImageSort(nil), args) } -func (qb *ImageQueryBuilder) CountByGalleryID(galleryID int) (int, error) { +func (qb *imageQueryBuilder) CountByGalleryID(galleryID int) (int, error) { args := []interface{}{galleryID} - return runCountQuery(buildCountQuery(countImagesForGalleryQuery), args) + return qb.runCountQuery(qb.buildCountQuery(countImagesForGalleryQuery), args) } -func (qb *ImageQueryBuilder) Count() (int, error) { - return runCountQuery(buildCountQuery("SELECT images.id FROM images"), nil) +func (qb *imageQueryBuilder) Count() (int, error) { + return qb.runCountQuery(qb.buildCountQuery("SELECT images.id FROM images"), nil) } -func (qb *ImageQueryBuilder) Size() (float64, error) { - return runSumQuery("SELECT SUM(cast(size as double)) as sum FROM images", nil) +func (qb *imageQueryBuilder) Size() (float64, error) { + return qb.runSumQuery("SELECT SUM(cast(size as double)) as sum FROM images", nil) } -func (qb *ImageQueryBuilder) CountByStudioID(studioID int) (int, error) { - args := []interface{}{studioID} - return runCountQuery(buildCountQuery(imagesForStudioQuery), args) +func (qb *imageQueryBuilder) All() ([]*models.Image, error) { + return qb.queryImages(selectAll(imageTable)+qb.getImageSort(nil), nil) } -func (qb *ImageQueryBuilder) CountByTagID(tagID int) (int, error) { - args := []interface{}{tagID} - return runCountQuery(buildCountQuery(countImagesForTagQuery), args) -} - -func (qb *ImageQueryBuilder) All() ([]*Image, error) { - return qb.queryImages(selectAll(imageTable)+qb.getImageSort(nil), nil, nil) -} - -func (qb *ImageQueryBuilder) Query(imageFilter *ImageFilterType, findFilter *FindFilterType) ([]*Image, int) { +func (qb *imageQueryBuilder) Query(imageFilter *models.ImageFilterType, findFilter *models.FindFilterType) ([]*models.Image, int, error) { if imageFilter == nil { - imageFilter = &ImageFilterType{} + imageFilter = &models.ImageFilterType{} } if findFilter == nil { - findFilter = &FindFilterType{} + findFilter = &models.FindFilterType{} } - query := queryBuilder{ - tableName: imageTable, - } + query := qb.newQuery() query.body = selectDistinctIDs(imageTable) query.body += ` @@ -411,18 +361,25 @@ func (qb *ImageQueryBuilder) Query(imageFilter *ImageFilterType, findFilter *Fin } query.sortAndPagination = qb.getImageSort(findFilter) + getPagination(findFilter) - idsResult, countResult := query.executeFind() + idsResult, countResult, err := query.executeFind() + if err != nil { + return nil, 0, err + } - var images []*Image + var images []*models.Image for _, id := range idsResult 
{ - image, _ := qb.Find(id) + image, err := qb.Find(id) + if err != nil { + return nil, 0, err + } + images = append(images, image) } - return images, countResult + return images, countResult, nil } -func (qb *ImageQueryBuilder) getImageSort(findFilter *FindFilterType) string { +func (qb *imageQueryBuilder) getImageSort(findFilter *models.FindFilterType) string { if findFilter == nil { return " ORDER BY images.path ASC " } @@ -431,40 +388,79 @@ func (qb *ImageQueryBuilder) getImageSort(findFilter *FindFilterType) string { return getSort(sort, direction, "images") } -func (qb *ImageQueryBuilder) queryImage(query string, args []interface{}, tx *sqlx.Tx) (*Image, error) { - results, err := qb.queryImages(query, args, tx) +func (qb *imageQueryBuilder) queryImage(query string, args []interface{}) (*models.Image, error) { + results, err := qb.queryImages(query, args) if err != nil || len(results) < 1 { return nil, err } return results[0], nil } -func (qb *ImageQueryBuilder) queryImages(query string, args []interface{}, tx *sqlx.Tx) ([]*Image, error) { - var rows *sqlx.Rows - var err error - if tx != nil { - rows, err = tx.Queryx(query, args...) - } else { - rows, err = database.DB.Queryx(query, args...) - } - - if err != nil && err != sql.ErrNoRows { - return nil, err - } - defer rows.Close() - - images := make([]*Image, 0) - for rows.Next() { - image := Image{} - if err := rows.StructScan(&image); err != nil { - return nil, err - } - images = append(images, &image) - } - - if err := rows.Err(); err != nil { +func (qb *imageQueryBuilder) queryImages(query string, args []interface{}) ([]*models.Image, error) { + var ret models.Images + if err := qb.query(query, args, &ret); err != nil { return nil, err } - return images, nil + return []*models.Image(ret), nil +} + +func (qb *imageQueryBuilder) galleriesRepository() *joinRepository { + return &joinRepository{ + repository: repository{ + tx: qb.tx, + tableName: galleriesImagesTable, + idColumn: imageIDColumn, + }, + fkColumn: galleryIDColumn, + } +} + +func (qb *imageQueryBuilder) GetGalleryIDs(imageID int) ([]int, error) { + return qb.galleriesRepository().getIDs(imageID) +} + +func (qb *imageQueryBuilder) UpdateGalleries(imageID int, galleryIDs []int) error { + // Delete the existing joins and then create new ones + return qb.galleriesRepository().replace(imageID, galleryIDs) +} + +func (qb *imageQueryBuilder) performersRepository() *joinRepository { + return &joinRepository{ + repository: repository{ + tx: qb.tx, + tableName: performersImagesTable, + idColumn: imageIDColumn, + }, + fkColumn: performerIDColumn, + } +} + +func (qb *imageQueryBuilder) GetPerformerIDs(imageID int) ([]int, error) { + return qb.performersRepository().getIDs(imageID) +} + +func (qb *imageQueryBuilder) UpdatePerformers(imageID int, performerIDs []int) error { + // Delete the existing joins and then create new ones + return qb.performersRepository().replace(imageID, performerIDs) +} + +func (qb *imageQueryBuilder) tagsRepository() *joinRepository { + return &joinRepository{ + repository: repository{ + tx: qb.tx, + tableName: imagesTagsTable, + idColumn: imageIDColumn, + }, + fkColumn: tagIDColumn, + } +} + +func (qb *imageQueryBuilder) GetTagIDs(imageID int) ([]int, error) { + return qb.tagsRepository().getIDs(imageID) +} + +func (qb *imageQueryBuilder) UpdateTags(imageID int, tagIDs []int) error { + // Delete the existing joins and then create new ones + return qb.tagsRepository().replace(imageID, tagIDs) } diff --git a/pkg/sqlite/image_test.go 
b/pkg/sqlite/image_test.go new file mode 100644 index 000000000..a56f9b888 --- /dev/null +++ b/pkg/sqlite/image_test.go @@ -0,0 +1,703 @@ +// +build integration + +package sqlite_test + +import ( + "database/sql" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/stashapp/stash/pkg/models" +) + +func TestImageFind(t *testing.T) { + withTxn(func(r models.Repository) error { + // assume that the first image is imageWithGalleryPath + sqb := r.Image() + + const imageIdx = 0 + imageID := imageIDs[imageIdx] + image, err := sqb.Find(imageID) + + if err != nil { + t.Errorf("Error finding image: %s", err.Error()) + } + + assert.Equal(t, getImageStringValue(imageIdx, "Path"), image.Path) + + imageID = 0 + image, err = sqb.Find(imageID) + + if err != nil { + t.Errorf("Error finding image: %s", err.Error()) + } + + assert.Nil(t, image) + + return nil + }) +} + +func TestImageFindByPath(t *testing.T) { + withTxn(func(r models.Repository) error { + sqb := r.Image() + + const imageIdx = 1 + imagePath := getImageStringValue(imageIdx, "Path") + image, err := sqb.FindByPath(imagePath) + + if err != nil { + t.Errorf("Error finding image: %s", err.Error()) + } + + assert.Equal(t, imageIDs[imageIdx], image.ID) + assert.Equal(t, imagePath, image.Path) + + imagePath = "not exist" + image, err = sqb.FindByPath(imagePath) + + if err != nil { + t.Errorf("Error finding image: %s", err.Error()) + } + + assert.Nil(t, image) + + return nil + }) +} + +func TestImageQueryQ(t *testing.T) { + withTxn(func(r models.Repository) error { + const imageIdx = 2 + + q := getImageStringValue(imageIdx, titleField) + + sqb := r.Image() + + imageQueryQ(t, sqb, q, imageIdx) + + return nil + }) +} + +func imageQueryQ(t *testing.T, sqb models.ImageReader, q string, expectedImageIdx int) { + filter := models.FindFilterType{ + Q: &q, + } + images, _, err := sqb.Query(nil, &filter) + if err != nil { + t.Errorf("Error querying image: %s", err.Error()) + } + + assert.Len(t, images, 1) + image := images[0] + assert.Equal(t, imageIDs[expectedImageIdx], image.ID) + + // no Q should return all results + filter.Q = nil + images, _, err = sqb.Query(nil, &filter) + if err != nil { + t.Errorf("Error querying image: %s", err.Error()) + } + + assert.Len(t, images, totalImages) +} + +func TestImageQueryPath(t *testing.T) { + const imageIdx = 1 + imagePath := getImageStringValue(imageIdx, "Path") + + pathCriterion := models.StringCriterionInput{ + Value: imagePath, + Modifier: models.CriterionModifierEquals, + } + + verifyImagePath(t, pathCriterion) + + pathCriterion.Modifier = models.CriterionModifierNotEquals + verifyImagePath(t, pathCriterion) +} + +func verifyImagePath(t *testing.T, pathCriterion models.StringCriterionInput) { + withTxn(func(r models.Repository) error { + sqb := r.Image() + imageFilter := models.ImageFilterType{ + Path: &pathCriterion, + } + + images, _, err := sqb.Query(&imageFilter, nil) + if err != nil { + t.Errorf("Error querying image: %s", err.Error()) + } + + for _, image := range images { + verifyString(t, image.Path, pathCriterion) + } + + return nil + }) +} + +func TestImageQueryRating(t *testing.T) { + const rating = 3 + ratingCriterion := models.IntCriterionInput{ + Value: rating, + Modifier: models.CriterionModifierEquals, + } + + verifyImagesRating(t, ratingCriterion) + + ratingCriterion.Modifier = models.CriterionModifierNotEquals + verifyImagesRating(t, ratingCriterion) + + ratingCriterion.Modifier = models.CriterionModifierGreaterThan + verifyImagesRating(t, ratingCriterion) + + 
ratingCriterion.Modifier = models.CriterionModifierLessThan + verifyImagesRating(t, ratingCriterion) + + ratingCriterion.Modifier = models.CriterionModifierIsNull + verifyImagesRating(t, ratingCriterion) + + ratingCriterion.Modifier = models.CriterionModifierNotNull + verifyImagesRating(t, ratingCriterion) +} + +func verifyImagesRating(t *testing.T, ratingCriterion models.IntCriterionInput) { + withTxn(func(r models.Repository) error { + sqb := r.Image() + imageFilter := models.ImageFilterType{ + Rating: &ratingCriterion, + } + + images, _, err := sqb.Query(&imageFilter, nil) + if err != nil { + t.Errorf("Error querying image: %s", err.Error()) + } + + for _, image := range images { + verifyInt64(t, image.Rating, ratingCriterion) + } + + return nil + }) +} + +func TestImageQueryOCounter(t *testing.T) { + const oCounter = 1 + oCounterCriterion := models.IntCriterionInput{ + Value: oCounter, + Modifier: models.CriterionModifierEquals, + } + + verifyImagesOCounter(t, oCounterCriterion) + + oCounterCriterion.Modifier = models.CriterionModifierNotEquals + verifyImagesOCounter(t, oCounterCriterion) + + oCounterCriterion.Modifier = models.CriterionModifierGreaterThan + verifyImagesOCounter(t, oCounterCriterion) + + oCounterCriterion.Modifier = models.CriterionModifierLessThan + verifyImagesOCounter(t, oCounterCriterion) +} + +func verifyImagesOCounter(t *testing.T, oCounterCriterion models.IntCriterionInput) { + withTxn(func(r models.Repository) error { + sqb := r.Image() + imageFilter := models.ImageFilterType{ + OCounter: &oCounterCriterion, + } + + images, _, err := sqb.Query(&imageFilter, nil) + if err != nil { + t.Errorf("Error querying image: %s", err.Error()) + } + + for _, image := range images { + verifyInt(t, image.OCounter, oCounterCriterion) + } + + return nil + }) +} + +func TestImageQueryResolution(t *testing.T) { + verifyImagesResolution(t, models.ResolutionEnumLow) + verifyImagesResolution(t, models.ResolutionEnumStandard) + verifyImagesResolution(t, models.ResolutionEnumStandardHd) + verifyImagesResolution(t, models.ResolutionEnumFullHd) + verifyImagesResolution(t, models.ResolutionEnumFourK) + verifyImagesResolution(t, models.ResolutionEnum("unknown")) +} + +func verifyImagesResolution(t *testing.T, resolution models.ResolutionEnum) { + withTxn(func(r models.Repository) error { + sqb := r.Image() + imageFilter := models.ImageFilterType{ + Resolution: &resolution, + } + + images, _, err := sqb.Query(&imageFilter, nil) + if err != nil { + t.Errorf("Error querying image: %s", err.Error()) + } + + for _, image := range images { + verifyImageResolution(t, image.Height, resolution) + } + + return nil + }) +} + +func verifyImageResolution(t *testing.T, height sql.NullInt64, resolution models.ResolutionEnum) { + assert := assert.New(t) + h := height.Int64 + + switch resolution { + case models.ResolutionEnumLow: + assert.True(h < 480) + case models.ResolutionEnumStandard: + assert.True(h >= 480 && h < 720) + case models.ResolutionEnumStandardHd: + assert.True(h >= 720 && h < 1080) + case models.ResolutionEnumFullHd: + assert.True(h >= 1080 && h < 2160) + case models.ResolutionEnumFourK: + assert.True(h >= 2160) + } +} + +func TestImageQueryIsMissingGalleries(t *testing.T) { + withTxn(func(r models.Repository) error { + sqb := r.Image() + isMissing := "galleries" + imageFilter := models.ImageFilterType{ + IsMissing: &isMissing, + } + + q := getImageStringValue(imageIdxWithGallery, titleField) + findFilter := models.FindFilterType{ + Q: &q, + } + + images, _, err := sqb.Query(&imageFilter, 
&findFilter) + if err != nil { + t.Errorf("Error querying image: %s", err.Error()) + } + + assert.Len(t, images, 0) + + findFilter.Q = nil + images, _, err = sqb.Query(&imageFilter, &findFilter) + if err != nil { + t.Errorf("Error querying image: %s", err.Error()) + } + + // ensure none of the ids equal the one with gallery + for _, image := range images { + assert.NotEqual(t, imageIDs[imageIdxWithGallery], image.ID) + } + + return nil + }) +} + +func TestImageQueryIsMissingStudio(t *testing.T) { + withTxn(func(r models.Repository) error { + sqb := r.Image() + isMissing := "studio" + imageFilter := models.ImageFilterType{ + IsMissing: &isMissing, + } + + q := getImageStringValue(imageIdxWithStudio, titleField) + findFilter := models.FindFilterType{ + Q: &q, + } + + images, _, err := sqb.Query(&imageFilter, &findFilter) + if err != nil { + t.Errorf("Error querying image: %s", err.Error()) + } + + assert.Len(t, images, 0) + + findFilter.Q = nil + images, _, err = sqb.Query(&imageFilter, &findFilter) + if err != nil { + t.Errorf("Error querying image: %s", err.Error()) + } + + // ensure none of the ids equal the one with studio + for _, image := range images { + assert.NotEqual(t, imageIDs[imageIdxWithStudio], image.ID) + } + + return nil + }) +} + +func TestImageQueryIsMissingPerformers(t *testing.T) { + withTxn(func(r models.Repository) error { + sqb := r.Image() + isMissing := "performers" + imageFilter := models.ImageFilterType{ + IsMissing: &isMissing, + } + + q := getImageStringValue(imageIdxWithPerformer, titleField) + findFilter := models.FindFilterType{ + Q: &q, + } + + images, _, err := sqb.Query(&imageFilter, &findFilter) + if err != nil { + t.Errorf("Error querying image: %s", err.Error()) + } + + assert.Len(t, images, 0) + + findFilter.Q = nil + images, _, err = sqb.Query(&imageFilter, &findFilter) + if err != nil { + t.Errorf("Error querying image: %s", err.Error()) + } + + assert.True(t, len(images) > 0) + + // ensure none of the ids equal the one with a performer + for _, image := range images { + assert.NotEqual(t, imageIDs[imageIdxWithPerformer], image.ID) + } + + return nil + }) +} + +func TestImageQueryIsMissingTags(t *testing.T) { + withTxn(func(r models.Repository) error { + sqb := r.Image() + isMissing := "tags" + imageFilter := models.ImageFilterType{ + IsMissing: &isMissing, + } + + q := getImageStringValue(imageIdxWithTwoTags, titleField) + findFilter := models.FindFilterType{ + Q: &q, + } + + images, _, err := sqb.Query(&imageFilter, &findFilter) + if err != nil { + t.Errorf("Error querying image: %s", err.Error()) + } + + assert.Len(t, images, 0) + + findFilter.Q = nil + images, _, err = sqb.Query(&imageFilter, &findFilter) + if err != nil { + t.Errorf("Error querying image: %s", err.Error()) + } + + assert.True(t, len(images) > 0) + + return nil + }) +} + +func TestImageQueryIsMissingRating(t *testing.T) { + withTxn(func(r models.Repository) error { + sqb := r.Image() + isMissing := "rating" + imageFilter := models.ImageFilterType{ + IsMissing: &isMissing, + } + + images, _, err := sqb.Query(&imageFilter, nil) + if err != nil { + t.Errorf("Error querying image: %s", err.Error()) + } + + assert.True(t, len(images) > 0) + + // ensure rating is null + for _, image := range images { + assert.True(t, !image.Rating.Valid) + } + + return nil + }) +} + +func TestImageQueryPerformers(t *testing.T) { + withTxn(func(r models.Repository) error { + sqb := r.Image() + performerCriterion := models.MultiCriterionInput{ + Value: []string{ +
strconv.Itoa(performerIDs[performerIdxWithImage]), + strconv.Itoa(performerIDs[performerIdx1WithImage]), + }, + Modifier: models.CriterionModifierIncludes, + } + + imageFilter := models.ImageFilterType{ + Performers: &performerCriterion, + } + + images, _, err := sqb.Query(&imageFilter, nil) + if err != nil { + t.Errorf("Error querying image: %s", err.Error()) + } + + assert.Len(t, images, 2) + + // ensure ids are correct + for _, image := range images { + assert.True(t, image.ID == imageIDs[imageIdxWithPerformer] || image.ID == imageIDs[imageIdxWithTwoPerformers]) + } + + performerCriterion = models.MultiCriterionInput{ + Value: []string{ + strconv.Itoa(performerIDs[performerIdx1WithImage]), + strconv.Itoa(performerIDs[performerIdx2WithImage]), + }, + Modifier: models.CriterionModifierIncludesAll, + } + + images, _, err = sqb.Query(&imageFilter, nil) + if err != nil { + t.Errorf("Error querying image: %s", err.Error()) + } + + assert.Len(t, images, 1) + assert.Equal(t, imageIDs[imageIdxWithTwoPerformers], images[0].ID) + + performerCriterion = models.MultiCriterionInput{ + Value: []string{ + strconv.Itoa(performerIDs[performerIdx1WithImage]), + }, + Modifier: models.CriterionModifierExcludes, + } + + q := getImageStringValue(imageIdxWithTwoPerformers, titleField) + findFilter := models.FindFilterType{ + Q: &q, + } + + images, _, err = sqb.Query(&imageFilter, &findFilter) + if err != nil { + t.Errorf("Error querying image: %s", err.Error()) + } + assert.Len(t, images, 0) + + return nil + }) +} + +func TestImageQueryTags(t *testing.T) { + withTxn(func(r models.Repository) error { + sqb := r.Image() + tagCriterion := models.MultiCriterionInput{ + Value: []string{ + strconv.Itoa(tagIDs[tagIdxWithImage]), + strconv.Itoa(tagIDs[tagIdx1WithImage]), + }, + Modifier: models.CriterionModifierIncludes, + } + + imageFilter := models.ImageFilterType{ + Tags: &tagCriterion, + } + + images, _, err := sqb.Query(&imageFilter, nil) + if err != nil { + t.Errorf("Error querying image: %s", err.Error()) + } + + assert.Len(t, images, 2) + + // ensure ids are correct + for _, image := range images { + assert.True(t, image.ID == imageIDs[imageIdxWithTag] || image.ID == imageIDs[imageIdxWithTwoTags]) + } + + tagCriterion = models.MultiCriterionInput{ + Value: []string{ + strconv.Itoa(tagIDs[tagIdx1WithImage]), + strconv.Itoa(tagIDs[tagIdx2WithImage]), + }, + Modifier: models.CriterionModifierIncludesAll, + } + + images, _, err = sqb.Query(&imageFilter, nil) + if err != nil { + t.Errorf("Error querying image: %s", err.Error()) + } + + assert.Len(t, images, 1) + assert.Equal(t, imageIDs[imageIdxWithTwoTags], images[0].ID) + + tagCriterion = models.MultiCriterionInput{ + Value: []string{ + strconv.Itoa(tagIDs[tagIdx1WithImage]), + }, + Modifier: models.CriterionModifierExcludes, + } + + q := getImageStringValue(imageIdxWithTwoTags, titleField) + findFilter := models.FindFilterType{ + Q: &q, + } + + images, _, err = sqb.Query(&imageFilter, &findFilter) + if err != nil { + t.Errorf("Error querying image: %s", err.Error()) + } + assert.Len(t, images, 0) + + return nil + }) +} + +func TestImageQueryStudio(t *testing.T) { + withTxn(func(r models.Repository) error { + sqb := r.Image() + studioCriterion := models.MultiCriterionInput{ + Value: []string{ + strconv.Itoa(studioIDs[studioIdxWithImage]), + }, + Modifier: models.CriterionModifierIncludes, + } + + imageFilter := models.ImageFilterType{ + Studios: &studioCriterion, + } + + images, _, err := sqb.Query(&imageFilter, nil) + if err != nil { + t.Errorf("Error querying 
image: %s", err.Error()) + } + + assert.Len(t, images, 1) + + // ensure id is correct + assert.Equal(t, imageIDs[imageIdxWithStudio], images[0].ID) + + studioCriterion = models.MultiCriterionInput{ + Value: []string{ + strconv.Itoa(studioIDs[studioIdxWithImage]), + }, + Modifier: models.CriterionModifierExcludes, + } + + q := getImageStringValue(imageIdxWithStudio, titleField) + findFilter := models.FindFilterType{ + Q: &q, + } + + images, _, err = sqb.Query(&imageFilter, &findFilter) + if err != nil { + t.Errorf("Error querying image: %s", err.Error()) + } + assert.Len(t, images, 0) + + return nil + }) +} + +func TestImageQuerySorting(t *testing.T) { + withTxn(func(r models.Repository) error { + sort := titleField + direction := models.SortDirectionEnumAsc + findFilter := models.FindFilterType{ + Sort: &sort, + Direction: &direction, + } + + sqb := r.Image() + images, _, err := sqb.Query(nil, &findFilter) + if err != nil { + t.Errorf("Error querying image: %s", err.Error()) + } + + // images should be in same order as indexes + firstImage := images[0] + lastImage := images[len(images)-1] + + assert.Equal(t, imageIDs[0], firstImage.ID) + assert.Equal(t, imageIDs[len(imageIDs)-1], lastImage.ID) + + // sort in descending order + direction = models.SortDirectionEnumDesc + + images, _, err = sqb.Query(nil, &findFilter) + if err != nil { + t.Errorf("Error querying image: %s", err.Error()) + } + firstImage = images[0] + lastImage = images[len(images)-1] + + assert.Equal(t, imageIDs[len(imageIDs)-1], firstImage.ID) + assert.Equal(t, imageIDs[0], lastImage.ID) + + return nil + }) +} + +func TestImageQueryPagination(t *testing.T) { + withTxn(func(r models.Repository) error { + perPage := 1 + findFilter := models.FindFilterType{ + PerPage: &perPage, + } + + sqb := r.Image() + images, _, err := sqb.Query(nil, &findFilter) + if err != nil { + t.Errorf("Error querying image: %s", err.Error()) + } + + assert.Len(t, images, 1) + + firstID := images[0].ID + + page := 2 + findFilter.Page = &page + images, _, err = sqb.Query(nil, &findFilter) + if err != nil { + t.Errorf("Error querying image: %s", err.Error()) + } + + assert.Len(t, images, 1) + secondID := images[0].ID + assert.NotEqual(t, firstID, secondID) + + perPage = 2 + page = 1 + + images, _, err = sqb.Query(nil, &findFilter) + if err != nil { + t.Errorf("Error querying image: %s", err.Error()) + } + assert.Len(t, images, 2) + assert.Equal(t, firstID, images[0].ID) + assert.Equal(t, secondID, images[1].ID) + + return nil + }) +} + +// TODO Update +// TODO IncrementOCounter +// TODO DecrementOCounter +// TODO ResetOCounter +// TODO Destroy +// TODO FindByChecksum +// TODO Count +// TODO SizeCount +// TODO All diff --git a/pkg/sqlite/movies.go b/pkg/sqlite/movies.go new file mode 100644 index 000000000..40565aeac --- /dev/null +++ b/pkg/sqlite/movies.go @@ -0,0 +1,263 @@ +package sqlite + +import ( + "database/sql" + "fmt" + + "github.com/stashapp/stash/pkg/models" +) + +const movieTable = "movies" + +type movieQueryBuilder struct { + repository +} + +func NewMovieReaderWriter(tx dbi) *movieQueryBuilder { + return &movieQueryBuilder{ + repository{ + tx: tx, + tableName: movieTable, + idColumn: idColumn, + }, + } +} + +func (qb *movieQueryBuilder) Create(newObject models.Movie) (*models.Movie, error) { + var ret models.Movie + if err := qb.insertObject(newObject, &ret); err != nil { + return nil, err + } + + return &ret, nil +} + +func (qb *movieQueryBuilder) Update(updatedObject models.MoviePartial) (*models.Movie, error) { + const partial = true + if 
err := qb.update(updatedObject.ID, updatedObject, partial); err != nil { + return nil, err + } + + return qb.Find(updatedObject.ID) +} + +func (qb *movieQueryBuilder) UpdateFull(updatedObject models.Movie) (*models.Movie, error) { + const partial = false + if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil { + return nil, err + } + + return qb.Find(updatedObject.ID) +} + +func (qb *movieQueryBuilder) Destroy(id int) error { + return qb.destroyExisting([]int{id}) +} + +func (qb *movieQueryBuilder) Find(id int) (*models.Movie, error) { + var ret models.Movie + if err := qb.get(id, &ret); err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + return &ret, nil +} + +func (qb *movieQueryBuilder) FindMany(ids []int) ([]*models.Movie, error) { + var movies []*models.Movie + for _, id := range ids { + movie, err := qb.Find(id) + if err != nil { + return nil, err + } + + if movie == nil { + return nil, fmt.Errorf("movie with id %d not found", id) + } + + movies = append(movies, movie) + } + + return movies, nil +} + +func (qb *movieQueryBuilder) FindByName(name string, nocase bool) (*models.Movie, error) { + query := "SELECT * FROM movies WHERE name = ?" + if nocase { + query += " COLLATE NOCASE" + } + query += " LIMIT 1" + args := []interface{}{name} + return qb.queryMovie(query, args) +} + +func (qb *movieQueryBuilder) FindByNames(names []string, nocase bool) ([]*models.Movie, error) { + query := "SELECT * FROM movies WHERE name" + if nocase { + query += " COLLATE NOCASE" + } + query += " IN " + getInBinding(len(names)) + var args []interface{} + for _, name := range names { + args = append(args, name) + } + return qb.queryMovies(query, args) +} + +func (qb *movieQueryBuilder) Count() (int, error) { + return qb.runCountQuery(qb.buildCountQuery("SELECT movies.id FROM movies"), nil) +} + +func (qb *movieQueryBuilder) All() ([]*models.Movie, error) { + return qb.queryMovies(selectAll("movies")+qb.getMovieSort(nil), nil) +} + +func (qb *movieQueryBuilder) AllSlim() ([]*models.Movie, error) { + return qb.queryMovies("SELECT movies.id, movies.name FROM movies "+qb.getMovieSort(nil), nil) +} + +func (qb *movieQueryBuilder) Query(movieFilter *models.MovieFilterType, findFilter *models.FindFilterType) ([]*models.Movie, int, error) { + if findFilter == nil { + findFilter = &models.FindFilterType{} + } + if movieFilter == nil { + movieFilter = &models.MovieFilterType{} + } + + var whereClauses []string + var havingClauses []string + var args []interface{} + body := selectDistinctIDs("movies") + body += ` + left join movies_scenes as scenes_join on scenes_join.movie_id = movies.id + left join scenes on scenes_join.scene_id = scenes.id + left join studios as studio on studio.id = movies.studio_id +` + + if q := findFilter.Q; q != nil && *q != "" { + searchColumns := []string{"movies.name"} + clause, thisArgs := getSearchBinding(searchColumns, *q, false) + whereClauses = append(whereClauses, clause) + args = append(args, thisArgs...) 
+ } + + if studiosFilter := movieFilter.Studios; studiosFilter != nil && len(studiosFilter.Value) > 0 { + for _, studioID := range studiosFilter.Value { + args = append(args, studioID) + } + + whereClause, havingClause := getMultiCriterionClause("movies", "studio", "", "", "studio_id", studiosFilter) + whereClauses = appendClause(whereClauses, whereClause) + havingClauses = appendClause(havingClauses, havingClause) + } + + if isMissingFilter := movieFilter.IsMissing; isMissingFilter != nil && *isMissingFilter != "" { + switch *isMissingFilter { + case "front_image": + body += `left join movies_images on movies_images.movie_id = movies.id + ` + whereClauses = appendClause(whereClauses, "movies_images.front_image IS NULL") + case "back_image": + body += `left join movies_images on movies_images.movie_id = movies.id + ` + whereClauses = appendClause(whereClauses, "movies_images.back_image IS NULL") + case "scenes": + body += `left join movies_scenes on movies_scenes.movie_id = movies.id + ` + whereClauses = appendClause(whereClauses, "movies_scenes.scene_id IS NULL") + default: + whereClauses = appendClause(whereClauses, "movies."+*isMissingFilter+" IS NULL") + } + } + + sortAndPagination := qb.getMovieSort(findFilter) + getPagination(findFilter) + idsResult, countResult, err := qb.executeFindQuery(body, args, sortAndPagination, whereClauses, havingClauses) + if err != nil { + return nil, 0, err + } + + var movies []*models.Movie + for _, id := range idsResult { + movie, err := qb.Find(id) + if err != nil { + return nil, 0, err + } + + movies = append(movies, movie) + } + + return movies, countResult, nil +} + +func (qb *movieQueryBuilder) getMovieSort(findFilter *models.FindFilterType) string { + var sort string + var direction string + if findFilter == nil { + sort = "name" + direction = "ASC" + } else { + sort = findFilter.GetSort("name") + direction = findFilter.GetDirection() + } + + // #943 - override name sorting to use natural sort + if sort == "name" { + return " ORDER BY " + getColumn("movies", sort) + " COLLATE NATURAL_CS " + direction + } + + return getSort(sort, direction, "movies") +} + +func (qb *movieQueryBuilder) queryMovie(query string, args []interface{}) (*models.Movie, error) { + results, err := qb.queryMovies(query, args) + if err != nil || len(results) < 1 { + return nil, err + } + return results[0], nil +} + +func (qb *movieQueryBuilder) queryMovies(query string, args []interface{}) ([]*models.Movie, error) { + var ret models.Movies + if err := qb.query(query, args, &ret); err != nil { + return nil, err + } + + return []*models.Movie(ret), nil +} + +func (qb *movieQueryBuilder) UpdateImages(movieID int, frontImage []byte, backImage []byte) error { + // Delete the existing cover and then create new + if err := qb.DestroyImages(movieID); err != nil { + return err + } + + _, err := qb.tx.Exec( + `INSERT INTO movies_images (movie_id, front_image, back_image) VALUES (?, ?, ?)`, + movieID, + frontImage, + backImage, + ) + + return err +} + +func (qb *movieQueryBuilder) DestroyImages(movieID int) error { + // Delete the existing joins + _, err := qb.tx.Exec("DELETE FROM movies_images WHERE movie_id = ?", movieID) + if err != nil { + return err + } + return err +} + +func (qb *movieQueryBuilder) GetFrontImage(movieID int) ([]byte, error) { + query := `SELECT front_image from movies_images WHERE movie_id = ?` + return getImage(qb.tx, query, movieID) +} + +func (qb *movieQueryBuilder) GetBackImage(movieID int) ([]byte, error) { + query := `SELECT back_image from movies_images 
WHERE movie_id = ?` + return getImage(qb.tx, query, movieID) +} diff --git a/pkg/sqlite/movies_test.go b/pkg/sqlite/movies_test.go new file mode 100644 index 000000000..235f0184c --- /dev/null +++ b/pkg/sqlite/movies_test.go @@ -0,0 +1,241 @@ +// +build integration + +package sqlite_test + +import ( + "database/sql" + "fmt" + "strconv" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/utils" +) + +func TestMovieFindByName(t *testing.T) { + withTxn(func(r models.Repository) error { + mqb := r.Movie() + + name := movieNames[movieIdxWithScene] // find a movie by name + + movie, err := mqb.FindByName(name, false) + + if err != nil { + t.Errorf("Error finding movies: %s", err.Error()) + } + + assert.Equal(t, movieNames[movieIdxWithScene], movie.Name.String) + + name = movieNames[movieIdxWithDupName] // find a movie by name nocase + + movie, err = mqb.FindByName(name, true) + + if err != nil { + t.Errorf("Error finding movies: %s", err.Error()) + } + // movieIdxWithDupName and movieIdxWithScene should have similar names ( only diff should be Name vs NaMe) + //movie.Name should match with movieIdxWithScene since its ID is before moveIdxWithDupName + assert.Equal(t, movieNames[movieIdxWithScene], movie.Name.String) + //movie.Name should match with movieIdxWithDupName if the check is not case sensitive + assert.Equal(t, strings.ToLower(movieNames[movieIdxWithDupName]), strings.ToLower(movie.Name.String)) + + return nil + }) +} + +func TestMovieFindByNames(t *testing.T) { + withTxn(func(r models.Repository) error { + var names []string + + mqb := r.Movie() + + names = append(names, movieNames[movieIdxWithScene]) // find movies by names + + movies, err := mqb.FindByNames(names, false) + if err != nil { + t.Errorf("Error finding movies: %s", err.Error()) + } + assert.Len(t, movies, 1) + assert.Equal(t, movieNames[movieIdxWithScene], movies[0].Name.String) + + movies, err = mqb.FindByNames(names, true) // find movies by names nocase + if err != nil { + t.Errorf("Error finding movies: %s", err.Error()) + } + assert.Len(t, movies, 2) // movieIdxWithScene and movieIdxWithDupName + assert.Equal(t, strings.ToLower(movieNames[movieIdxWithScene]), strings.ToLower(movies[0].Name.String)) + assert.Equal(t, strings.ToLower(movieNames[movieIdxWithScene]), strings.ToLower(movies[1].Name.String)) + + return nil + }) +} + +func TestMovieQueryStudio(t *testing.T) { + withTxn(func(r models.Repository) error { + mqb := r.Movie() + studioCriterion := models.MultiCriterionInput{ + Value: []string{ + strconv.Itoa(studioIDs[studioIdxWithMovie]), + }, + Modifier: models.CriterionModifierIncludes, + } + + movieFilter := models.MovieFilterType{ + Studios: &studioCriterion, + } + + movies, _, err := mqb.Query(&movieFilter, nil) + if err != nil { + t.Errorf("Error querying movie: %s", err.Error()) + } + + assert.Len(t, movies, 1) + + // ensure id is correct + assert.Equal(t, movieIDs[movieIdxWithStudio], movies[0].ID) + + studioCriterion = models.MultiCriterionInput{ + Value: []string{ + strconv.Itoa(studioIDs[studioIdxWithMovie]), + }, + Modifier: models.CriterionModifierExcludes, + } + + q := getMovieStringValue(movieIdxWithStudio, titleField) + findFilter := models.FindFilterType{ + Q: &q, + } + + movies, _, err = mqb.Query(&movieFilter, &findFilter) + if err != nil { + t.Errorf("Error querying movie: %s", err.Error()) + } + assert.Len(t, movies, 0) + + return nil + }) +} + +func TestMovieUpdateMovieImages(t *testing.T) { + if err := 
withTxn(func(r models.Repository) error { + mqb := r.Movie() + + // create movie to test against + const name = "TestMovieUpdateMovieImages" + movie := models.Movie{ + Name: sql.NullString{String: name, Valid: true}, + Checksum: utils.MD5FromString(name), + } + created, err := mqb.Create(movie) + if err != nil { + return fmt.Errorf("Error creating movie: %s", err.Error()) + } + + frontImage := []byte("frontImage") + backImage := []byte("backImage") + err = mqb.UpdateImages(created.ID, frontImage, backImage) + if err != nil { + return fmt.Errorf("Error updating movie images: %s", err.Error()) + } + + // ensure images are set + storedFront, err := mqb.GetFrontImage(created.ID) + if err != nil { + return fmt.Errorf("Error getting front image: %s", err.Error()) + } + assert.Equal(t, storedFront, frontImage) + + storedBack, err := mqb.GetBackImage(created.ID) + if err != nil { + return fmt.Errorf("Error getting back image: %s", err.Error()) + } + assert.Equal(t, storedBack, backImage) + + // set front image only + newImage := []byte("newImage") + err = mqb.UpdateImages(created.ID, newImage, nil) + if err != nil { + return fmt.Errorf("Error updating movie images: %s", err.Error()) + } + + storedFront, err = mqb.GetFrontImage(created.ID) + if err != nil { + return fmt.Errorf("Error getting front image: %s", err.Error()) + } + assert.Equal(t, storedFront, newImage) + + // back image should be nil + storedBack, err = mqb.GetBackImage(created.ID) + if err != nil { + return fmt.Errorf("Error getting back image: %s", err.Error()) + } + assert.Nil(t, storedBack) + + // set back image only + err = mqb.UpdateImages(created.ID, nil, newImage) + if err == nil { + return fmt.Errorf("Expected error setting nil front image") + } + + return nil + }); err != nil { + t.Error(err.Error()) + } +} + +func TestMovieDestroyMovieImages(t *testing.T) { + if err := withTxn(func(r models.Repository) error { + mqb := r.Movie() + + // create movie to test against + const name = "TestMovieDestroyMovieImages" + movie := models.Movie{ + Name: sql.NullString{String: name, Valid: true}, + Checksum: utils.MD5FromString(name), + } + created, err := mqb.Create(movie) + if err != nil { + return fmt.Errorf("Error creating movie: %s", err.Error()) + } + + frontImage := []byte("frontImage") + backImage := []byte("backImage") + err = mqb.UpdateImages(created.ID, frontImage, backImage) + if err != nil { + return fmt.Errorf("Error updating movie images: %s", err.Error()) + } + + err = mqb.DestroyImages(created.ID) + if err != nil { + return fmt.Errorf("Error destroying movie images: %s", err.Error()) + } + + // front image should be nil + storedFront, err := mqb.GetFrontImage(created.ID) + if err != nil { + return fmt.Errorf("Error getting front image: %s", err.Error()) + } + assert.Nil(t, storedFront) + + // back image should be nil + storedBack, err := mqb.GetBackImage(created.ID) + if err != nil { + return fmt.Errorf("Error getting back image: %s", err.Error()) + } + assert.Nil(t, storedBack) + + return nil + }); err != nil { + t.Error(err.Error()) + } +} + +// TODO Update +// TODO Destroy +// TODO Find +// TODO Count +// TODO All +// TODO Query diff --git a/pkg/models/querybuilder_performer.go b/pkg/sqlite/performer.go similarity index 56% rename from pkg/models/querybuilder_performer.go rename to pkg/sqlite/performer.go index 7810143ed..078b46e4a 100644 --- a/pkg/models/querybuilder_performer.go +++ b/pkg/sqlite/performer.go @@ -1,4 +1,4 @@ -package models +package sqlite import ( "database/sql" @@ -6,96 +6,86 @@ import ( "strconv" "time"
- "github.com/jmoiron/sqlx" - "github.com/stashapp/stash/pkg/database" + "github.com/stashapp/stash/pkg/models" ) -type PerformerQueryBuilder struct{} +const performerTable = "performers" +const performerIDColumn = "performer_id" -func NewPerformerQueryBuilder() PerformerQueryBuilder { - return PerformerQueryBuilder{} +type performerQueryBuilder struct { + repository } -func (qb *PerformerQueryBuilder) Create(newPerformer Performer, tx *sqlx.Tx) (*Performer, error) { - ensureTx(tx) - result, err := tx.NamedExec( - `INSERT INTO performers (checksum, name, url, gender, twitter, instagram, birthdate, ethnicity, country, - eye_color, height, measurements, fake_tits, career_length, tattoos, piercings, - aliases, favorite, created_at, updated_at) - VALUES (:checksum, :name, :url, :gender, :twitter, :instagram, :birthdate, :ethnicity, :country, - :eye_color, :height, :measurements, :fake_tits, :career_length, :tattoos, :piercings, - :aliases, :favorite, :created_at, :updated_at) - `, - newPerformer, - ) - if err != nil { - return nil, err +func NewPerformerReaderWriter(tx dbi) *performerQueryBuilder { + return &performerQueryBuilder{ + repository{ + tx: tx, + tableName: performerTable, + idColumn: idColumn, + }, } - performerID, err := result.LastInsertId() - if err != nil { - return nil, err - } - - if err := tx.Get(&newPerformer, `SELECT * FROM performers WHERE id = ? LIMIT 1`, performerID); err != nil { - return nil, err - } - return &newPerformer, nil } -func (qb *PerformerQueryBuilder) Update(updatedPerformer PerformerPartial, tx *sqlx.Tx) (*Performer, error) { - ensureTx(tx) - _, err := tx.NamedExec( - `UPDATE performers SET `+SQLGenKeysPartial(updatedPerformer)+` WHERE performers.id = :id`, - updatedPerformer, - ) - if err != nil { +func (qb *performerQueryBuilder) Create(newObject models.Performer) (*models.Performer, error) { + var ret models.Performer + if err := qb.insertObject(newObject, &ret); err != nil { return nil, err } - var ret Performer - if err := tx.Get(&ret, `SELECT * FROM performers WHERE id = ? 
LIMIT 1`, updatedPerformer.ID); err != nil { + return &ret, nil +} + +func (qb *performerQueryBuilder) Update(updatedObject models.PerformerPartial) (*models.Performer, error) { + const partial = true + if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil { + return nil, err + } + + var ret models.Performer + if err := qb.get(updatedObject.ID, &ret); err != nil { + return nil, err + } + + return &ret, nil +} + +func (qb *performerQueryBuilder) UpdateFull(updatedObject models.Performer) (*models.Performer, error) { + const partial = false + if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil { + return nil, err + } + + var ret models.Performer + if err := qb.get(updatedObject.ID, &ret); err != nil { + return nil, err + } + + return &ret, nil +} + +func (qb *performerQueryBuilder) Destroy(id int) error { + // TODO - add on delete cascade to performers_scenes + _, err := qb.tx.Exec("DELETE FROM performers_scenes WHERE performer_id = ?", id) + if err != nil { + return err + } + + return qb.destroyExisting([]int{id}) +} + +func (qb *performerQueryBuilder) Find(id int) (*models.Performer, error) { + var ret models.Performer + if err := qb.get(id, &ret); err != nil { + if err == sql.ErrNoRows { + return nil, nil + } return nil, err } return &ret, nil } -func (qb *PerformerQueryBuilder) UpdateFull(updatedPerformer Performer, tx *sqlx.Tx) (*Performer, error) { - ensureTx(tx) - _, err := tx.NamedExec( - `UPDATE performers SET `+SQLGenKeys(updatedPerformer)+` WHERE performers.id = :id`, - updatedPerformer, - ) - if err != nil { - return nil, err - } - - if err := tx.Get(&updatedPerformer, `SELECT * FROM performers WHERE id = ? LIMIT 1`, updatedPerformer.ID); err != nil { - return nil, err - } - return &updatedPerformer, nil -} - -func (qb *PerformerQueryBuilder) Destroy(id string, tx *sqlx.Tx) error { - _, err := tx.Exec("DELETE FROM performers_scenes WHERE performer_id = ?", id) - if err != nil { - return err - } - - return executeDeleteQuery("performers", id, tx) -} - -func (qb *PerformerQueryBuilder) Find(id int) (*Performer, error) { - query := "SELECT * FROM performers WHERE id = ? LIMIT 1" - args := []interface{}{id} - results, err := qb.queryPerformers(query, args, nil) - if err != nil || len(results) < 1 { - return nil, err - } - return results[0], nil -} - -func (qb *PerformerQueryBuilder) FindMany(ids []int) ([]*Performer, error) { - var performers []*Performer +func (qb *performerQueryBuilder) FindMany(ids []int) ([]*models.Performer, error) { + var performers []*models.Performer for _, id := range ids { performer, err := qb.Find(id) if err != nil { @@ -112,44 +102,44 @@ func (qb *PerformerQueryBuilder) FindMany(ids []int) ([]*Performer, error) { return performers, nil } -func (qb *PerformerQueryBuilder) FindBySceneID(sceneID int, tx *sqlx.Tx) ([]*Performer, error) { +func (qb *performerQueryBuilder) FindBySceneID(sceneID int) ([]*models.Performer, error) { query := selectAll("performers") + ` LEFT JOIN performers_scenes as scenes_join on scenes_join.performer_id = performers.id WHERE scenes_join.scene_id = ? 
` args := []interface{}{sceneID} - return qb.queryPerformers(query, args, tx) + return qb.queryPerformers(query, args) } -func (qb *PerformerQueryBuilder) FindByImageID(imageID int, tx *sqlx.Tx) ([]*Performer, error) { +func (qb *performerQueryBuilder) FindByImageID(imageID int) ([]*models.Performer, error) { query := selectAll("performers") + ` LEFT JOIN performers_images as images_join on images_join.performer_id = performers.id WHERE images_join.image_id = ? ` args := []interface{}{imageID} - return qb.queryPerformers(query, args, tx) + return qb.queryPerformers(query, args) } -func (qb *PerformerQueryBuilder) FindByGalleryID(galleryID int, tx *sqlx.Tx) ([]*Performer, error) { +func (qb *performerQueryBuilder) FindByGalleryID(galleryID int) ([]*models.Performer, error) { query := selectAll("performers") + ` LEFT JOIN performers_galleries as galleries_join on galleries_join.performer_id = performers.id WHERE galleries_join.gallery_id = ? ` args := []interface{}{galleryID} - return qb.queryPerformers(query, args, tx) + return qb.queryPerformers(query, args) } -func (qb *PerformerQueryBuilder) FindNameBySceneID(sceneID int, tx *sqlx.Tx) ([]*Performer, error) { +func (qb *performerQueryBuilder) FindNamesBySceneID(sceneID int) ([]*models.Performer, error) { query := ` SELECT performers.name FROM performers LEFT JOIN performers_scenes as scenes_join on scenes_join.performer_id = performers.id WHERE scenes_join.scene_id = ? ` args := []interface{}{sceneID} - return qb.queryPerformers(query, args, tx) + return qb.queryPerformers(query, args) } -func (qb *PerformerQueryBuilder) FindByNames(names []string, tx *sqlx.Tx, nocase bool) ([]*Performer, error) { +func (qb *performerQueryBuilder) FindByNames(names []string, nocase bool) ([]*models.Performer, error) { query := "SELECT * FROM performers WHERE name" if nocase { query += " COLLATE NOCASE" @@ -160,33 +150,31 @@ func (qb *PerformerQueryBuilder) FindByNames(names []string, tx *sqlx.Tx, nocase for _, name := range names { args = append(args, name) } - return qb.queryPerformers(query, args, tx) + return qb.queryPerformers(query, args) } -func (qb *PerformerQueryBuilder) Count() (int, error) { - return runCountQuery(buildCountQuery("SELECT performers.id FROM performers"), nil) +func (qb *performerQueryBuilder) Count() (int, error) { + return qb.runCountQuery(qb.buildCountQuery("SELECT performers.id FROM performers"), nil) } -func (qb *PerformerQueryBuilder) All() ([]*Performer, error) { - return qb.queryPerformers(selectAll("performers")+qb.getPerformerSort(nil), nil, nil) +func (qb *performerQueryBuilder) All() ([]*models.Performer, error) { + return qb.queryPerformers(selectAll("performers")+qb.getPerformerSort(nil), nil) } -func (qb *PerformerQueryBuilder) AllSlim() ([]*Performer, error) { - return qb.queryPerformers("SELECT performers.id, performers.name, performers.gender FROM performers "+qb.getPerformerSort(nil), nil, nil) +func (qb *performerQueryBuilder) AllSlim() ([]*models.Performer, error) { + return qb.queryPerformers("SELECT performers.id, performers.name, performers.gender FROM performers "+qb.getPerformerSort(nil), nil) } -func (qb *PerformerQueryBuilder) Query(performerFilter *PerformerFilterType, findFilter *FindFilterType) ([]*Performer, int) { +func (qb *performerQueryBuilder) Query(performerFilter *models.PerformerFilterType, findFilter *models.FindFilterType) ([]*models.Performer, int, error) { if performerFilter == nil { - performerFilter = &PerformerFilterType{} + performerFilter = &models.PerformerFilterType{} } if 
findFilter == nil { - findFilter = &FindFilterType{} + findFilter = &models.FindFilterType{} } tableName := "performers" - query := queryBuilder{ - tableName: tableName, - } + query := qb.newQuery() query.body = selectDistinctIDs(tableName) query.body += ` @@ -263,18 +251,24 @@ func (qb *PerformerQueryBuilder) Query(performerFilter *PerformerFilterType, fin query.handleStringCriterionInput(performerFilter.Aliases, tableName+".aliases") query.sortAndPagination = qb.getPerformerSort(findFilter) + getPagination(findFilter) - idsResult, countResult := query.executeFind() + idsResult, countResult, err := query.executeFind() + if err != nil { + return nil, 0, err + } - var performers []*Performer + var performers []*models.Performer for _, id := range idsResult { - performer, _ := qb.Find(id) + performer, err := qb.Find(id) + if err != nil { + return nil, 0, err + } performers = append(performers, performer) } - return performers, countResult + return performers, countResult, nil } -func getBirthYearFilterClause(criterionModifier CriterionModifier, value int) ([]string, []interface{}) { +func getBirthYearFilterClause(criterionModifier models.CriterionModifier, value int) ([]string, []interface{}) { var clauses []string var args []interface{} @@ -309,7 +303,7 @@ func getBirthYearFilterClause(criterionModifier CriterionModifier, value int) ([ return clauses, args } -func getAgeFilterClause(criterionModifier CriterionModifier, value int) ([]string, []interface{}) { +func getAgeFilterClause(criterionModifier models.CriterionModifier, value int) ([]string, []interface{}) { var clauses []string var args []interface{} @@ -345,7 +339,7 @@ func getAgeFilterClause(criterionModifier CriterionModifier, value int) ([]strin return clauses, args } -func (qb *PerformerQueryBuilder) getPerformerSort(findFilter *FindFilterType) string { +func (qb *performerQueryBuilder) getPerformerSort(findFilter *models.FindFilterType) string { var sort string var direction string if findFilter == nil { @@ -358,65 +352,52 @@ func (qb *PerformerQueryBuilder) getPerformerSort(findFilter *FindFilterType) st return getSort(sort, direction, "performers") } -func (qb *PerformerQueryBuilder) queryPerformers(query string, args []interface{}, tx *sqlx.Tx) ([]*Performer, error) { - var rows *sqlx.Rows - var err error - if tx != nil { - rows, err = tx.Queryx(query, args...) - } else { - rows, err = database.DB.Queryx(query, args...) 
- } - - if err != nil && err != sql.ErrNoRows { - return nil, err - } - defer rows.Close() - - performers := make([]*Performer, 0) - for rows.Next() { - performer := Performer{} - if err := rows.StructScan(&performer); err != nil { - return nil, err - } - performers = append(performers, &performer) - } - - if err := rows.Err(); err != nil { +func (qb *performerQueryBuilder) queryPerformers(query string, args []interface{}) ([]*models.Performer, error) { + var ret models.Performers + if err := qb.query(query, args, &ret); err != nil { return nil, err } - return performers, nil + return []*models.Performer(ret), nil } -func (qb *PerformerQueryBuilder) UpdatePerformerImage(performerID int, image []byte, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing cover and then create new - if err := qb.DestroyPerformerImage(performerID, tx); err != nil { - return err +func (qb *performerQueryBuilder) imageRepository() *imageRepository { + return &imageRepository{ + repository: repository{ + tx: qb.tx, + tableName: "performers_image", + idColumn: performerIDColumn, + }, + imageColumn: "image", } - - _, err := tx.Exec( - `INSERT INTO performers_image (performer_id, image) VALUES (?, ?)`, - performerID, - image, - ) - - return err } -func (qb *PerformerQueryBuilder) DestroyPerformerImage(performerID int, tx *sqlx.Tx) error { - ensureTx(tx) +func (qb *performerQueryBuilder) GetImage(performerID int) ([]byte, error) { + return qb.imageRepository().get(performerID) +} - // Delete the existing joins - _, err := tx.Exec("DELETE FROM performers_image WHERE performer_id = ?", performerID) - if err != nil { - return err +func (qb *performerQueryBuilder) UpdateImage(performerID int, image []byte) error { + return qb.imageRepository().replace(performerID, image) +} + +func (qb *performerQueryBuilder) DestroyImage(performerID int) error { + return qb.imageRepository().destroy([]int{performerID}) +} + +func (qb *performerQueryBuilder) stashIDRepository() *stashIDRepository { + return &stashIDRepository{ + repository{ + tx: qb.tx, + tableName: "performer_stash_ids", + idColumn: performerIDColumn, + }, } - return err } -func (qb *PerformerQueryBuilder) GetPerformerImage(performerID int, tx *sqlx.Tx) ([]byte, error) { - query := `SELECT image from performers_image WHERE performer_id = ?` - return getImage(tx, query, performerID) +func (qb *performerQueryBuilder) GetStashIDs(performerID int) ([]*models.StashID, error) { + return qb.stashIDRepository().get(performerID) +} + +func (qb *performerQueryBuilder) UpdateStashIDs(performerID int, stashIDs []models.StashID) error { + return qb.stashIDRepository().replace(performerID, stashIDs) } diff --git a/pkg/sqlite/performer_test.go b/pkg/sqlite/performer_test.go new file mode 100644 index 000000000..530430403 --- /dev/null +++ b/pkg/sqlite/performer_test.go @@ -0,0 +1,250 @@ +// +build integration + +package sqlite_test + +import ( + "database/sql" + "fmt" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/utils" +) + +func TestPerformerFindBySceneID(t *testing.T) { + withTxn(func(r models.Repository) error { + pqb := r.Performer() + sceneID := sceneIDs[sceneIdxWithPerformer] + + performers, err := pqb.FindBySceneID(sceneID) + + if err != nil { + t.Errorf("Error finding performer: %s", err.Error()) + } + + assert.Equal(t, 1, len(performers)) + performer := performers[0] + + assert.Equal(t, getPerformerStringValue(performerIdxWithScene, "Name"), 
performer.Name.String) + + performers, err = pqb.FindBySceneID(0) + + if err != nil { + t.Errorf("Error finding performer: %s", err.Error()) + } + + assert.Equal(t, 0, len(performers)) + + return nil + }) +} + +func TestPerformerFindByNames(t *testing.T) { + withTxn(func(r models.Repository) error { + var names []string + + pqb := r.Performer() + + names = append(names, performerNames[performerIdxWithScene]) // find performers by names + + performers, err := pqb.FindByNames(names, false) + if err != nil { + t.Errorf("Error finding performers: %s", err.Error()) + } + assert.Len(t, performers, 1) + assert.Equal(t, performerNames[performerIdxWithScene], performers[0].Name.String) + + performers, err = pqb.FindByNames(names, true) // find performers by names nocase + if err != nil { + t.Errorf("Error finding performers: %s", err.Error()) + } + assert.Len(t, performers, 2) // performerIdxWithScene and performerIdxWithDupName + assert.Equal(t, strings.ToLower(performerNames[performerIdxWithScene]), strings.ToLower(performers[0].Name.String)) + assert.Equal(t, strings.ToLower(performerNames[performerIdxWithScene]), strings.ToLower(performers[1].Name.String)) + + names = append(names, performerNames[performerIdx1WithScene]) // find performers by names ( 2 names ) + + performers, err = pqb.FindByNames(names, false) + if err != nil { + t.Errorf("Error finding performers: %s", err.Error()) + } + assert.Len(t, performers, 2) // performerIdxWithScene and performerIdx1WithScene + assert.Equal(t, performerNames[performerIdxWithScene], performers[0].Name.String) + assert.Equal(t, performerNames[performerIdx1WithScene], performers[1].Name.String) + + performers, err = pqb.FindByNames(names, true) // find performers by names ( 2 names nocase) + if err != nil { + t.Errorf("Error finding performers: %s", err.Error()) + } + assert.Len(t, performers, 4) // performerIdxWithScene and performerIdxWithDupName , performerIdx1WithScene and performerIdx1WithDupName + assert.Equal(t, performerNames[performerIdxWithScene], performers[0].Name.String) + assert.Equal(t, performerNames[performerIdx1WithScene], performers[1].Name.String) + assert.Equal(t, performerNames[performerIdx1WithDupName], performers[2].Name.String) + assert.Equal(t, performerNames[performerIdxWithDupName], performers[3].Name.String) + + return nil + }) +} + +func TestPerformerUpdatePerformerImage(t *testing.T) { + if err := withTxn(func(r models.Repository) error { + qb := r.Performer() + + // create performer to test against + const name = "TestPerformerUpdatePerformerImage" + performer := models.Performer{ + Name: sql.NullString{String: name, Valid: true}, + Checksum: utils.MD5FromString(name), + Favorite: sql.NullBool{Bool: false, Valid: true}, + } + created, err := qb.Create(performer) + if err != nil { + return fmt.Errorf("Error creating performer: %s", err.Error()) + } + + image := []byte("image") + err = qb.UpdateImage(created.ID, image) + if err != nil { + return fmt.Errorf("Error updating performer image: %s", err.Error()) + } + + // ensure image set + storedImage, err := qb.GetImage(created.ID) + if err != nil { + return fmt.Errorf("Error getting image: %s", err.Error()) + } + assert.Equal(t, storedImage, image) + + // set nil image + err = qb.UpdateImage(created.ID, nil) + if err == nil { + return fmt.Errorf("Expected error setting nil image") + } + + return nil + }); err != nil { + t.Error(err.Error()) + } +} + +func TestPerformerDestroyPerformerImage(t *testing.T) { + if err := withTxn(func(r models.Repository) error { + qb := 
r.Performer() + + // create performer to test against + const name = "TestPerformerDestroyPerformerImage" + performer := models.Performer{ + Name: sql.NullString{String: name, Valid: true}, + Checksum: utils.MD5FromString(name), + Favorite: sql.NullBool{Bool: false, Valid: true}, + } + created, err := qb.Create(performer) + if err != nil { + return fmt.Errorf("Error creating performer: %s", err.Error()) + } + + image := []byte("image") + err = qb.UpdateImage(created.ID, image) + if err != nil { + return fmt.Errorf("Error updating performer image: %s", err.Error()) + } + + err = qb.DestroyImage(created.ID) + if err != nil { + return fmt.Errorf("Error destroying performer image: %s", err.Error()) + } + + // image should be nil + storedImage, err := qb.GetImage(created.ID) + if err != nil { + return fmt.Errorf("Error getting image: %s", err.Error()) + } + assert.Nil(t, storedImage) + + return nil + }); err != nil { + t.Error(err.Error()) + } +} + +func TestPerformerQueryAge(t *testing.T) { + const age = 19 + ageCriterion := models.IntCriterionInput{ + Value: age, + Modifier: models.CriterionModifierEquals, + } + + verifyPerformerAge(t, ageCriterion) + + ageCriterion.Modifier = models.CriterionModifierNotEquals + verifyPerformerAge(t, ageCriterion) + + ageCriterion.Modifier = models.CriterionModifierGreaterThan + verifyPerformerAge(t, ageCriterion) + + ageCriterion.Modifier = models.CriterionModifierLessThan + verifyPerformerAge(t, ageCriterion) +} + +func verifyPerformerAge(t *testing.T, ageCriterion models.IntCriterionInput) { + withTxn(func(r models.Repository) error { + qb := r.Performer() + performerFilter := models.PerformerFilterType{ + Age: &ageCriterion, + } + + performers, _, err := qb.Query(&performerFilter, nil) + if err != nil { + t.Errorf("Error querying performer: %s", err.Error()) + } + + now := time.Now() + for _, performer := range performers { + bd := performer.Birthdate.String + d, _ := time.Parse("2006-01-02", bd) + age := now.Year() - d.Year() + if now.YearDay() < d.YearDay() { + age = age - 1 + } + + verifyInt(t, age, ageCriterion) + } + + return nil + }) +} + +func TestPerformerStashIDs(t *testing.T) { + if err := withTxn(func(r models.Repository) error { + qb := r.Performer() + + // create performer to test against + const name = "TestStashIDs" + performer := models.Performer{ + Name: sql.NullString{String: name, Valid: true}, + Checksum: utils.MD5FromString(name), + Favorite: sql.NullBool{Bool: false, Valid: true}, + } + created, err := qb.Create(performer) + if err != nil { + return fmt.Errorf("Error creating performer: %s", err.Error()) + } + + testStashIDReaderWriter(t, qb, created.ID) + return nil + }); err != nil { + t.Error(err.Error()) + } +} + +// TODO Update +// TODO Destroy +// TODO Find +// TODO Count +// TODO All +// TODO AllSlim +// TODO Query diff --git a/pkg/sqlite/query.go b/pkg/sqlite/query.go new file mode 100644 index 000000000..58e5420de --- /dev/null +++ b/pkg/sqlite/query.go @@ -0,0 +1,78 @@ +package sqlite + +import "github.com/stashapp/stash/pkg/models" + +type queryBuilder struct { + repository *repository + + body string + + whereClauses []string + havingClauses []string + args []interface{} + + sortAndPagination string +} + +func (qb queryBuilder) executeFind() ([]int, int, error) { + return qb.repository.executeFindQuery(qb.body, qb.args, qb.sortAndPagination, qb.whereClauses, qb.havingClauses) +} + +func (qb *queryBuilder) addWhere(clauses ...string) { + for _, clause := range clauses { + if len(clause) > 0 { + qb.whereClauses = 
append(qb.whereClauses, clause) + } + } +} + +func (qb *queryBuilder) addHaving(clauses ...string) { + for _, clause := range clauses { + if len(clause) > 0 { + qb.havingClauses = append(qb.havingClauses, clause) + } + } +} + +func (qb *queryBuilder) addArg(args ...interface{}) { + qb.args = append(qb.args, args...) +} + +func (qb *queryBuilder) handleIntCriterionInput(c *models.IntCriterionInput, column string) { + if c != nil { + clause, count := getIntCriterionWhereClause(column, *c) + qb.addWhere(clause) + if count == 1 { + qb.addArg(c.Value) + } + } +} + +func (qb *queryBuilder) handleStringCriterionInput(c *models.StringCriterionInput, column string) { + if c != nil { + if modifier := c.Modifier; c.Modifier.IsValid() { + switch modifier { + case models.CriterionModifierIncludes: + clause, thisArgs := getSearchBinding([]string{column}, c.Value, false) + qb.addWhere(clause) + qb.addArg(thisArgs...) + case models.CriterionModifierExcludes: + clause, thisArgs := getSearchBinding([]string{column}, c.Value, true) + qb.addWhere(clause) + qb.addArg(thisArgs...) + case models.CriterionModifierEquals: + qb.addWhere(column + " LIKE ?") + qb.addArg(c.Value) + case models.CriterionModifierNotEquals: + qb.addWhere(column + " NOT LIKE ?") + qb.addArg(c.Value) + default: + clause, count := getSimpleCriterionClause(modifier, "?") + qb.addWhere(column + " " + clause) + if count == 1 { + qb.addArg(c.Value) + } + } + } + } +} diff --git a/pkg/sqlite/repository.go b/pkg/sqlite/repository.go new file mode 100644 index 000000000..99d2eeb01 --- /dev/null +++ b/pkg/sqlite/repository.go @@ -0,0 +1,412 @@ +package sqlite + +import ( + "database/sql" + "fmt" + "reflect" + "strings" + + "github.com/jmoiron/sqlx" + + "github.com/stashapp/stash/pkg/logger" + "github.com/stashapp/stash/pkg/models" +) + +const idColumn = "id" + +type objectList interface { + Append(o interface{}) + New() interface{} +} + +type repository struct { + tx dbi + tableName string + idColumn string +} + +func (r *repository) get(id int, dest interface{}) error { + stmt := fmt.Sprintf("SELECT * FROM %s WHERE %s = ? 
LIMIT 1", r.tableName, r.idColumn) + return r.tx.Get(dest, stmt, id) +} + +func (r *repository) getAll(id int, f func(rows *sqlx.Rows) error) error { + stmt := fmt.Sprintf("SELECT * FROM %s WHERE %s = ?", r.tableName, r.idColumn) + return r.queryFunc(stmt, []interface{}{id}, f) +} + +func (r *repository) insert(obj interface{}) (sql.Result, error) { + stmt := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", r.tableName, listKeys(obj, false), listKeys(obj, true)) + return r.tx.NamedExec(stmt, obj) +} + +func (r *repository) insertObject(obj interface{}, out interface{}) error { + result, err := r.insert(obj) + if err != nil { + return err + } + id, err := result.LastInsertId() + if err != nil { + return err + } + return r.get(int(id), out) +} + +func (r *repository) update(id int, obj interface{}, partial bool) error { + exists, err := r.exists(id) + if err != nil { + return err + } + + if !exists { + return fmt.Errorf("%s %d does not exist in %s", r.idColumn, id, r.tableName) + } + + stmt := fmt.Sprintf("UPDATE %s SET %s WHERE %s.%s = :id", r.tableName, updateSet(obj, partial), r.tableName, r.idColumn) + _, err = r.tx.NamedExec(stmt, obj) + + return err +} + +func (r *repository) updateMap(id int, m map[string]interface{}) error { + exists, err := r.exists(id) + if err != nil { + return err + } + + if !exists { + return fmt.Errorf("%s %d does not exist in %s", r.idColumn, id, r.tableName) + } + + stmt := fmt.Sprintf("UPDATE %s SET %s WHERE %s.%s = :id", r.tableName, updateSetMap(m), r.tableName, r.idColumn) + _, err = r.tx.NamedExec(stmt, m) + + return err +} + +func (r *repository) destroyExisting(ids []int) error { + for _, id := range ids { + exists, err := r.exists(id) + if err != nil { + return err + } + + if !exists { + return fmt.Errorf("%s %d does not exist in %s", r.idColumn, id, r.tableName) + } + } + + return r.destroy(ids) +} + +func (r *repository) destroy(ids []int) error { + for _, id := range ids { + stmt := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", r.tableName, r.idColumn) + if _, err := r.tx.Exec(stmt, id); err != nil { + return err + } + } + + return nil +} + +func (r *repository) exists(id int) (bool, error) { + stmt := fmt.Sprintf("SELECT %s FROM %s WHERE %s = ? 
LIMIT 1", r.idColumn, r.tableName, r.idColumn) + stmt = r.buildCountQuery(stmt) + + c, err := r.runCountQuery(stmt, []interface{}{id}) + if err != nil { + return false, err + } + + return c == 1, nil +} + +func (r *repository) buildCountQuery(query string) string { + return "SELECT COUNT(*) as count FROM (" + query + ") as temp" +} + +func (r *repository) runCountQuery(query string, args []interface{}) (int, error) { + result := struct { + Int int `db:"count"` + }{0} + + // Perform query and fetch result + if err := r.tx.Get(&result, query, args...); err != nil && err != sql.ErrNoRows { + return 0, err + } + + return result.Int, nil +} + +func (r *repository) runIdsQuery(query string, args []interface{}) ([]int, error) { + var result []struct { + Int int `db:"id"` + } + + if err := r.tx.Select(&result, query, args...); err != nil && err != sql.ErrNoRows { + return []int{}, err + } + + vsm := make([]int, len(result)) + for i, v := range result { + vsm[i] = v.Int + } + return vsm, nil +} + +func (r *repository) runSumQuery(query string, args []interface{}) (float64, error) { + // Perform query and fetch result + result := struct { + Float64 float64 `db:"sum"` + }{0} + + // Perform query and fetch result + if err := r.tx.Get(&result, query, args...); err != nil && err != sql.ErrNoRows { + return 0, err + } + + return result.Float64, nil +} + +func (r *repository) queryFunc(query string, args []interface{}, f func(rows *sqlx.Rows) error) error { + rows, err := r.tx.Queryx(query, args...) + + if err != nil && err != sql.ErrNoRows { + return err + } + defer rows.Close() + + for rows.Next() { + if err := f(rows); err != nil { + return err + } + } + + if err := rows.Err(); err != nil { + return err + } + + return nil +} + +func (r *repository) query(query string, args []interface{}, out objectList) error { + rows, err := r.tx.Queryx(query, args...) + + if err != nil && err != sql.ErrNoRows { + return err + } + defer rows.Close() + + for rows.Next() { + object := out.New() + if err := rows.StructScan(object); err != nil { + return err + } + out.Append(object) + } + + if err := rows.Err(); err != nil { + return err + } + + return nil +} + +func (r *repository) querySimple(query string, args []interface{}, out interface{}) error { + rows, err := r.tx.Queryx(query, args...) 
+ + if err != nil && err != sql.ErrNoRows { + return err + } + defer rows.Close() + + if rows.Next() { + if err := rows.Scan(out); err != nil { + return err + } + } + + if err := rows.Err(); err != nil { + return err + } + + return nil +} + +func (r *repository) executeFindQuery(body string, args []interface{}, sortAndPagination string, whereClauses []string, havingClauses []string) ([]int, int, error) { + if len(whereClauses) > 0 { + body = body + " WHERE " + strings.Join(whereClauses, " AND ") // TODO handle AND or OR + } + body = body + " GROUP BY " + r.tableName + ".id " + if len(havingClauses) > 0 { + body = body + " HAVING " + strings.Join(havingClauses, " AND ") // TODO handle AND or OR + } + + countQuery := r.buildCountQuery(body) + idsQuery := body + sortAndPagination + + // Perform query and fetch result + logger.Tracef("SQL: %s, args: %v", idsQuery, args) + + var countResult int + var countErr error + var idsResult []int + var idsErr error + + countResult, countErr = r.runCountQuery(countQuery, args) + idsResult, idsErr = r.runIdsQuery(idsQuery, args) + + if countErr != nil { + return nil, 0, fmt.Errorf("Error executing count query with SQL: %s, args: %v, error: %s", countQuery, args, countErr.Error()) + } + if idsErr != nil { + return nil, 0, fmt.Errorf("Error executing find query with SQL: %s, args: %v, error: %s", idsQuery, args, idsErr.Error()) + } + + return idsResult, countResult, nil +} + +func (r *repository) newQuery() queryBuilder { + return queryBuilder{ + repository: r, + } +} + +type joinRepository struct { + repository + fkColumn string +} + +func (r *joinRepository) getIDs(id int) ([]int, error) { + query := fmt.Sprintf(`SELECT %s as id from %s WHERE %s = ?`, r.fkColumn, r.tableName, r.idColumn) + return r.runIdsQuery(query, []interface{}{id}) +} + +func (r *joinRepository) insert(id, foreignID int) (sql.Result, error) { + stmt := fmt.Sprintf("INSERT INTO %s (%s, %s) VALUES (?, ?)", r.tableName, r.idColumn, r.fkColumn) + return r.tx.Exec(stmt, id, foreignID) +} + +func (r *joinRepository) replace(id int, foreignIDs []int) error { + if err := r.destroy([]int{id}); err != nil { + return err + } + + for _, fk := range foreignIDs { + if _, err := r.insert(id, fk); err != nil { + return err + } + } + + return nil +} + +type imageRepository struct { + repository + imageColumn string +} + +func (r *imageRepository) get(id int) ([]byte, error) { + query := fmt.Sprintf("SELECT %s from %s WHERE %s = ?", r.imageColumn, r.tableName, r.idColumn) + var ret []byte + err := r.querySimple(query, []interface{}{id}, &ret) + return ret, err +} + +func (r *imageRepository) replace(id int, image []byte) error { + if err := r.destroy([]int{id}); err != nil { + return err + } + + stmt := fmt.Sprintf("INSERT INTO %s (%s, %s) VALUES (?, ?)", r.tableName, r.idColumn, r.imageColumn) + _, err := r.tx.Exec(stmt, id, image) + + return err +} + +type stashIDRepository struct { + repository +} + +type stashIDs []*models.StashID + +func (s *stashIDs) Append(o interface{}) { + *s = append(*s, o.(*models.StashID)) +} + +func (s *stashIDs) New() interface{} { + return &models.StashID{} +} + +func (r *stashIDRepository) get(id int) ([]*models.StashID, error) { + query := fmt.Sprintf("SELECT stash_id, endpoint from %s WHERE %s = ?", r.tableName, r.idColumn) + var ret stashIDs + err := r.query(query, []interface{}{id}, &ret) + return []*models.StashID(ret), err +} + +func (r *stashIDRepository) replace(id int, newIDs []models.StashID) error { + if err := r.destroy([]int{id}); err != nil { + return err + 
} + + query := fmt.Sprintf("INSERT INTO %s (%s, endpoint, stash_id) VALUES (?, ?, ?)", r.tableName, r.idColumn) + for _, stashID := range newIDs { + _, err := r.tx.Exec(query, id, stashID.Endpoint, stashID.StashID) + if err != nil { + return err + } + } + return nil +} + +func listKeys(i interface{}, addPrefix bool) string { + var query []string + v := reflect.ValueOf(i) + for i := 0; i < v.NumField(); i++ { + //get key for struct tag + rawKey := v.Type().Field(i).Tag.Get("db") + key := strings.Split(rawKey, ",")[0] + if key == "id" { + continue + } + if addPrefix { + key = ":" + key + } + query = append(query, key) + } + return strings.Join(query, ", ") +} + +func updateSet(i interface{}, partial bool) string { + var query []string + v := reflect.ValueOf(i) + for i := 0; i < v.NumField(); i++ { + //get key for struct tag + rawKey := v.Type().Field(i).Tag.Get("db") + key := strings.Split(rawKey, ",")[0] + if key == "id" { + continue + } + + add := true + if partial { + reflectValue := reflect.ValueOf(v.Field(i).Interface()) + add = !reflectValue.IsNil() + } + + if add { + query = append(query, fmt.Sprintf("%s=:%s", key, key)) + } + } + return strings.Join(query, ", ") +} + +func updateSetMap(m map[string]interface{}) string { + var query []string + for k := range m { + query = append(query, fmt.Sprintf("%s=:%s", k, k)) + } + return strings.Join(query, ", ") +} diff --git a/pkg/models/querybuilder_scene.go b/pkg/sqlite/scene.go similarity index 56% rename from pkg/models/querybuilder_scene.go rename to pkg/sqlite/scene.go index e9524da02..3a8bf99bd 100644 --- a/pkg/models/querybuilder_scene.go +++ b/pkg/sqlite/scene.go @@ -1,4 +1,4 @@ -package models +package sqlite import ( "database/sql" @@ -6,10 +6,14 @@ import ( "strings" "github.com/jmoiron/sqlx" - "github.com/stashapp/stash/pkg/database" + "github.com/stashapp/stash/pkg/models" ) const sceneTable = "scenes" +const sceneIDColumn = "scene_id" +const performersScenesTable = "performers_scenes" +const scenesTagsTable = "scenes_tags" +const moviesScenesTable = "movies_scenes" var scenesForPerformerQuery = selectAll(sceneTable) + ` LEFT JOIN performers_scenes as performers_join on performers_join.scene_id = scenes.id @@ -50,77 +54,55 @@ SELECT id FROM scenes WHERE scenes.oshash is null ` -type SceneQueryBuilder struct{} - -func NewSceneQueryBuilder() SceneQueryBuilder { - return SceneQueryBuilder{} +type sceneQueryBuilder struct { + repository } -func (qb *SceneQueryBuilder) Create(newScene Scene, tx *sqlx.Tx) (*Scene, error) { - ensureTx(tx) - result, err := tx.NamedExec( - `INSERT INTO scenes (oshash, checksum, path, title, details, url, date, rating, organized, o_counter, size, duration, video_codec, - audio_codec, format, width, height, framerate, bitrate, studio_id, file_mod_time, created_at, updated_at) - VALUES (:oshash, :checksum, :path, :title, :details, :url, :date, :rating, :organized, :o_counter, :size, :duration, :video_codec, - :audio_codec, :format, :width, :height, :framerate, :bitrate, :studio_id, :file_mod_time, :created_at, :updated_at) - `, - newScene, - ) - if err != nil { - return nil, err +func NewSceneReaderWriter(tx dbi) *sceneQueryBuilder { + return &sceneQueryBuilder{ + repository{ + tx: tx, + tableName: sceneTable, + idColumn: idColumn, + }, } - sceneID, err := result.LastInsertId() - if err != nil { - return nil, err - } - if err := tx.Get(&newScene, `SELECT * FROM scenes WHERE id = ? 
LIMIT 1`, sceneID); err != nil { - return nil, err - } - return &newScene, nil } -func (qb *SceneQueryBuilder) Update(updatedScene ScenePartial, tx *sqlx.Tx) (*Scene, error) { - ensureTx(tx) - _, err := tx.NamedExec( - `UPDATE scenes SET `+SQLGenKeysPartial(updatedScene)+` WHERE scenes.id = :id`, - updatedScene, - ) - if err != nil { +func (qb *sceneQueryBuilder) Create(newObject models.Scene) (*models.Scene, error) { + var ret models.Scene + if err := qb.insertObject(newObject, &ret); err != nil { return nil, err } - return qb.find(updatedScene.ID, tx) + return &ret, nil } -func (qb *SceneQueryBuilder) UpdateFull(updatedScene Scene, tx *sqlx.Tx) (*Scene, error) { - ensureTx(tx) - _, err := tx.NamedExec( - `UPDATE scenes SET `+SQLGenKeys(updatedScene)+` WHERE scenes.id = :id`, - updatedScene, - ) - if err != nil { +func (qb *sceneQueryBuilder) Update(updatedObject models.ScenePartial) (*models.Scene, error) { + const partial = true + if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil { return nil, err } - return qb.find(updatedScene.ID, tx) + return qb.find(updatedObject.ID) } -func (qb *SceneQueryBuilder) UpdateFileModTime(id int, modTime NullSQLiteTimestamp, tx *sqlx.Tx) error { - ensureTx(tx) - _, err := tx.Exec( - `UPDATE scenes SET file_mod_time = ? WHERE scenes.id = ? `, - modTime, id, - ) - if err != nil { - return err +func (qb *sceneQueryBuilder) UpdateFull(updatedObject models.Scene) (*models.Scene, error) { + const partial = false + if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil { + return nil, err } - return nil + return qb.find(updatedObject.ID) } -func (qb *SceneQueryBuilder) IncrementOCounter(id int, tx *sqlx.Tx) (int, error) { - ensureTx(tx) - _, err := tx.Exec( +func (qb *sceneQueryBuilder) UpdateFileModTime(id int, modTime models.NullSQLiteTimestamp) error { + return qb.updateMap(id, map[string]interface{}{ + "file_mod_time": modTime, + }) +} + +func (qb *sceneQueryBuilder) IncrementOCounter(id int) (int, error) { + _, err := qb.tx.Exec( `UPDATE scenes SET o_counter = o_counter + 1 WHERE scenes.id = ?`, id, ) @@ -128,7 +110,7 @@ func (qb *SceneQueryBuilder) IncrementOCounter(id int, tx *sqlx.Tx) (int, error) return 0, err } - scene, err := qb.find(id, tx) + scene, err := qb.find(id) if err != nil { return 0, err } @@ -136,9 +118,8 @@ func (qb *SceneQueryBuilder) IncrementOCounter(id int, tx *sqlx.Tx) (int, error) return scene.OCounter, nil } -func (qb *SceneQueryBuilder) DecrementOCounter(id int, tx *sqlx.Tx) (int, error) { - ensureTx(tx) - _, err := tx.Exec( +func (qb *sceneQueryBuilder) DecrementOCounter(id int) (int, error) { + _, err := qb.tx.Exec( `UPDATE scenes SET o_counter = o_counter - 1 WHERE scenes.id = ? 
and scenes.o_counter > 0`, id, ) @@ -146,7 +127,7 @@ func (qb *SceneQueryBuilder) DecrementOCounter(id int, tx *sqlx.Tx) (int, error) return 0, err } - scene, err := qb.find(id, tx) + scene, err := qb.find(id) if err != nil { return 0, err } @@ -154,9 +135,8 @@ func (qb *SceneQueryBuilder) DecrementOCounter(id int, tx *sqlx.Tx) (int, error) return scene.OCounter, nil } -func (qb *SceneQueryBuilder) ResetOCounter(id int, tx *sqlx.Tx) (int, error) { - ensureTx(tx) - _, err := tx.Exec( +func (qb *sceneQueryBuilder) ResetOCounter(id int) (int, error) { + _, err := qb.tx.Exec( `UPDATE scenes SET o_counter = 0 WHERE scenes.id = ?`, id, ) @@ -164,7 +144,7 @@ func (qb *SceneQueryBuilder) ResetOCounter(id int, tx *sqlx.Tx) (int, error) { return 0, err } - scene, err := qb.find(id, tx) + scene, err := qb.find(id) if err != nil { return 0, err } @@ -172,19 +152,25 @@ func (qb *SceneQueryBuilder) ResetOCounter(id int, tx *sqlx.Tx) (int, error) { return scene.OCounter, nil } -func (qb *SceneQueryBuilder) Destroy(id string, tx *sqlx.Tx) error { - _, err := tx.Exec("DELETE FROM movies_scenes WHERE scene_id = ?", id) - if err != nil { +func (qb *sceneQueryBuilder) Destroy(id int) error { + // delete all related table rows + // TODO - this should be handled by a delete cascade + if err := qb.performersRepository().destroy([]int{id}); err != nil { return err } - return executeDeleteQuery("scenes", id, tx) -} -func (qb *SceneQueryBuilder) Find(id int) (*Scene, error) { - return qb.find(id, nil) + + // scene markers should be handled prior to calling destroy + // galleries should be handled prior to calling destroy + + return qb.destroyExisting([]int{id}) } -func (qb *SceneQueryBuilder) FindMany(ids []int) ([]*Scene, error) { - var scenes []*Scene +func (qb *sceneQueryBuilder) Find(id int) (*models.Scene, error) { + return qb.find(id) +} + +func (qb *sceneQueryBuilder) FindMany(ids []int) ([]*models.Scene, error) { + var scenes []*models.Scene for _, id := range ids { scene, err := qb.Find(id) if err != nil { @@ -201,107 +187,105 @@ func (qb *SceneQueryBuilder) FindMany(ids []int) ([]*Scene, error) { return scenes, nil } -func (qb *SceneQueryBuilder) find(id int, tx *sqlx.Tx) (*Scene, error) { - query := selectAll(sceneTable) + "WHERE id = ? LIMIT 1" - args := []interface{}{id} - return qb.queryScene(query, args, tx) +func (qb *sceneQueryBuilder) find(id int) (*models.Scene, error) { + var ret models.Scene + if err := qb.get(id, &ret); err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + return &ret, nil } -func (qb *SceneQueryBuilder) FindByChecksum(checksum string) (*Scene, error) { +func (qb *sceneQueryBuilder) FindByChecksum(checksum string) (*models.Scene, error) { query := "SELECT * FROM scenes WHERE checksum = ? LIMIT 1" args := []interface{}{checksum} - return qb.queryScene(query, args, nil) + return qb.queryScene(query, args) } -func (qb *SceneQueryBuilder) FindByOSHash(oshash string) (*Scene, error) { +func (qb *sceneQueryBuilder) FindByOSHash(oshash string) (*models.Scene, error) { query := "SELECT * FROM scenes WHERE oshash = ? LIMIT 1" args := []interface{}{oshash} - return qb.queryScene(query, args, nil) + return qb.queryScene(query, args) } -func (qb *SceneQueryBuilder) FindByPath(path string) (*Scene, error) { +func (qb *sceneQueryBuilder) FindByPath(path string) (*models.Scene, error) { query := selectAll(sceneTable) + "WHERE path = ? 
LIMIT 1" args := []interface{}{path} - return qb.queryScene(query, args, nil) + return qb.queryScene(query, args) } -func (qb *SceneQueryBuilder) FindByPerformerID(performerID int) ([]*Scene, error) { +func (qb *sceneQueryBuilder) FindByPerformerID(performerID int) ([]*models.Scene, error) { args := []interface{}{performerID} - return qb.queryScenes(scenesForPerformerQuery, args, nil) + return qb.queryScenes(scenesForPerformerQuery, args) } -func (qb *SceneQueryBuilder) CountByPerformerID(performerID int) (int, error) { +func (qb *sceneQueryBuilder) CountByPerformerID(performerID int) (int, error) { args := []interface{}{performerID} - return runCountQuery(buildCountQuery(countScenesForPerformerQuery), args) + return qb.runCountQuery(qb.buildCountQuery(countScenesForPerformerQuery), args) } -func (qb *SceneQueryBuilder) FindByStudioID(studioID int) ([]*Scene, error) { - args := []interface{}{studioID} - return qb.queryScenes(scenesForStudioQuery, args, nil) -} - -func (qb *SceneQueryBuilder) FindByMovieID(movieID int) ([]*Scene, error) { +func (qb *sceneQueryBuilder) FindByMovieID(movieID int) ([]*models.Scene, error) { args := []interface{}{movieID} - return qb.queryScenes(scenesForMovieQuery, args, nil) + return qb.queryScenes(scenesForMovieQuery, args) } -func (qb *SceneQueryBuilder) CountByMovieID(movieID int) (int, error) { +func (qb *sceneQueryBuilder) CountByMovieID(movieID int) (int, error) { args := []interface{}{movieID} - return runCountQuery(buildCountQuery(scenesForMovieQuery), args) + return qb.runCountQuery(qb.buildCountQuery(scenesForMovieQuery), args) } -func (qb *SceneQueryBuilder) Count() (int, error) { - return runCountQuery(buildCountQuery("SELECT scenes.id FROM scenes"), nil) +func (qb *sceneQueryBuilder) Count() (int, error) { + return qb.runCountQuery(qb.buildCountQuery("SELECT scenes.id FROM scenes"), nil) } -func (qb *SceneQueryBuilder) Size() (float64, error) { - return runSumQuery("SELECT SUM(cast(size as double)) as sum FROM scenes", nil) +func (qb *sceneQueryBuilder) Size() (float64, error) { + return qb.runSumQuery("SELECT SUM(cast(size as double)) as sum FROM scenes", nil) } -func (qb *SceneQueryBuilder) CountByStudioID(studioID int) (int, error) { +func (qb *sceneQueryBuilder) CountByStudioID(studioID int) (int, error) { args := []interface{}{studioID} - return runCountQuery(buildCountQuery(scenesForStudioQuery), args) + return qb.runCountQuery(qb.buildCountQuery(scenesForStudioQuery), args) } -func (qb *SceneQueryBuilder) CountByTagID(tagID int) (int, error) { +func (qb *sceneQueryBuilder) CountByTagID(tagID int) (int, error) { args := []interface{}{tagID} - return runCountQuery(buildCountQuery(countScenesForTagQuery), args) + return qb.runCountQuery(qb.buildCountQuery(countScenesForTagQuery), args) } // CountMissingChecksum returns the number of scenes missing a checksum value. -func (qb *SceneQueryBuilder) CountMissingChecksum() (int, error) { - return runCountQuery(buildCountQuery(countScenesForMissingChecksumQuery), []interface{}{}) +func (qb *sceneQueryBuilder) CountMissingChecksum() (int, error) { + return qb.runCountQuery(qb.buildCountQuery(countScenesForMissingChecksumQuery), []interface{}{}) } // CountMissingOSHash returns the number of scenes missing an oshash value. 
-func (qb *SceneQueryBuilder) CountMissingOSHash() (int, error) { - return runCountQuery(buildCountQuery(countScenesForMissingOSHashQuery), []interface{}{}) +func (qb *sceneQueryBuilder) CountMissingOSHash() (int, error) { + return qb.runCountQuery(qb.buildCountQuery(countScenesForMissingOSHashQuery), []interface{}{}) } -func (qb *SceneQueryBuilder) Wall(q *string) ([]*Scene, error) { +func (qb *sceneQueryBuilder) Wall(q *string) ([]*models.Scene, error) { s := "" if q != nil { s = *q } query := selectAll(sceneTable) + "WHERE scenes.details LIKE '%" + s + "%' ORDER BY RANDOM() LIMIT 80" - return qb.queryScenes(query, nil, nil) + return qb.queryScenes(query, nil) } -func (qb *SceneQueryBuilder) All() ([]*Scene, error) { - return qb.queryScenes(selectAll(sceneTable)+qb.getSceneSort(nil), nil, nil) +func (qb *sceneQueryBuilder) All() ([]*models.Scene, error) { + return qb.queryScenes(selectAll(sceneTable)+qb.getSceneSort(nil), nil) } -func (qb *SceneQueryBuilder) Query(sceneFilter *SceneFilterType, findFilter *FindFilterType) ([]*Scene, int) { +func (qb *sceneQueryBuilder) Query(sceneFilter *models.SceneFilterType, findFilter *models.FindFilterType) ([]*models.Scene, int, error) { if sceneFilter == nil { - sceneFilter = &SceneFilterType{} + sceneFilter = &models.SceneFilterType{} } if findFilter == nil { - findFilter = &FindFilterType{} + findFilter = &models.FindFilterType{} } - query := queryBuilder{ - tableName: sceneTable, - } + query := qb.newQuery() query.body = selectDistinctIDs(sceneTable) query.body += ` @@ -452,15 +436,21 @@ func (qb *SceneQueryBuilder) Query(sceneFilter *SceneFilterType, findFilter *Fin } query.sortAndPagination = qb.getSceneSort(findFilter) + getPagination(findFilter) - idsResult, countResult := query.executeFind() + idsResult, countResult, err := query.executeFind() + if err != nil { + return nil, 0, err + } - var scenes []*Scene + var scenes []*models.Scene for _, id := range idsResult { - scene, _ := qb.Find(id) + scene, err := qb.Find(id) + if err != nil { + return nil, 0, err + } scenes = append(scenes, scene) } - return scenes, countResult + return scenes, countResult, nil } func appendClause(clauses []string, clause string) []string { @@ -471,7 +461,7 @@ func appendClause(clauses []string, clause string) []string { return clauses } -func getDurationWhereClause(durationFilter IntCriterionInput) (string, []interface{}) { +func getDurationWhereClause(durationFilter models.IntCriterionInput) (string, []interface{}) { // special case for duration. We accept duration as seconds as int but the // field is floating point. Change the equals filter to return a range // between x and x + 1 @@ -480,11 +470,11 @@ func getDurationWhereClause(durationFilter IntCriterionInput) (string, []interfa args := []interface{}{} value := durationFilter.Value - if durationFilter.Modifier == CriterionModifierEquals { + if durationFilter.Modifier == models.CriterionModifierEquals { clause = "scenes.duration >= ? AND scenes.duration < ?" args = append(args, value) args = append(args, value+1) - } else if durationFilter.Modifier == CriterionModifierNotEquals { + } else if durationFilter.Modifier == models.CriterionModifierNotEquals { clause = "(scenes.duration < ? 
OR scenes.duration >= ?)" args = append(args, value) args = append(args, value+1) @@ -499,7 +489,7 @@ func getDurationWhereClause(durationFilter IntCriterionInput) (string, []interfa return clause, args } -func (qb *SceneQueryBuilder) QueryAllByPathRegex(regex string, ignoreOrganized bool) ([]*Scene, error) { +func (qb *sceneQueryBuilder) QueryAllByPathRegex(regex string, ignoreOrganized bool) ([]*models.Scene, error) { var args []interface{} body := selectDistinctIDs("scenes") + " WHERE scenes.path regexp ?" @@ -509,13 +499,13 @@ func (qb *SceneQueryBuilder) QueryAllByPathRegex(regex string, ignoreOrganized b args = append(args, "(?i)"+regex) - idsResult, err := runIdsQuery(body, args) + idsResult, err := qb.runIdsQuery(body, args) if err != nil { return nil, err } - var scenes []*Scene + var scenes []*models.Scene for _, id := range idsResult { scene, err := qb.Find(id) @@ -529,9 +519,9 @@ func (qb *SceneQueryBuilder) QueryAllByPathRegex(regex string, ignoreOrganized b return scenes, nil } -func (qb *SceneQueryBuilder) QueryByPathRegex(findFilter *FindFilterType) ([]*Scene, int) { +func (qb *sceneQueryBuilder) QueryByPathRegex(findFilter *models.FindFilterType) ([]*models.Scene, int, error) { if findFilter == nil { - findFilter = &FindFilterType{} + findFilter = &models.FindFilterType{} } var whereClauses []string @@ -545,18 +535,25 @@ func (qb *SceneQueryBuilder) QueryByPathRegex(findFilter *FindFilterType) ([]*Sc } sortAndPagination := qb.getSceneSort(findFilter) + getPagination(findFilter) - idsResult, countResult := executeFindQuery("scenes", body, args, sortAndPagination, whereClauses, havingClauses) + idsResult, countResult, err := qb.executeFindQuery(body, args, sortAndPagination, whereClauses, havingClauses) + if err != nil { + return nil, 0, err + } - var scenes []*Scene + var scenes []*models.Scene for _, id := range idsResult { - scene, _ := qb.Find(id) + scene, err := qb.Find(id) + if err != nil { + return nil, 0, err + } + scenes = append(scenes, scene) } - return scenes, countResult + return scenes, countResult, nil } -func (qb *SceneQueryBuilder) getSceneSort(findFilter *FindFilterType) string { +func (qb *sceneQueryBuilder) getSceneSort(findFilter *models.FindFilterType) string { if findFilter == nil { return " ORDER BY scenes.path, scenes.date ASC " } @@ -565,112 +562,141 @@ func (qb *SceneQueryBuilder) getSceneSort(findFilter *FindFilterType) string { return getSort(sort, direction, "scenes") } -func (qb *SceneQueryBuilder) queryScene(query string, args []interface{}, tx *sqlx.Tx) (*Scene, error) { - results, err := qb.queryScenes(query, args, tx) +func (qb *sceneQueryBuilder) queryScene(query string, args []interface{}) (*models.Scene, error) { + results, err := qb.queryScenes(query, args) if err != nil || len(results) < 1 { return nil, err } return results[0], nil } -func (qb *SceneQueryBuilder) queryScenes(query string, args []interface{}, tx *sqlx.Tx) ([]*Scene, error) { - var rows *sqlx.Rows - var err error - if tx != nil { - rows, err = tx.Queryx(query, args...) - } else { - rows, err = database.DB.Queryx(query, args...) 
- } - - if err != nil && err != sql.ErrNoRows { +func (qb *sceneQueryBuilder) queryScenes(query string, args []interface{}) ([]*models.Scene, error) { + var ret models.Scenes + if err := qb.query(query, args, &ret); err != nil { return nil, err } - defer rows.Close() - scenes := make([]*Scene, 0) - for rows.Next() { - scene := Scene{} - if err := rows.StructScan(&scene); err != nil { - return nil, err + return []*models.Scene(ret), nil +} + +func (qb *sceneQueryBuilder) imageRepository() *imageRepository { + return &imageRepository{ + repository: repository{ + tx: qb.tx, + tableName: "scenes_cover", + idColumn: sceneIDColumn, + }, + imageColumn: "cover", + } +} + +func (qb *sceneQueryBuilder) GetCover(sceneID int) ([]byte, error) { + return qb.imageRepository().get(sceneID) +} + +func (qb *sceneQueryBuilder) UpdateCover(sceneID int, image []byte) error { + return qb.imageRepository().replace(sceneID, image) +} + +func (qb *sceneQueryBuilder) DestroyCover(sceneID int) error { + return qb.imageRepository().destroy([]int{sceneID}) +} + +func (qb *sceneQueryBuilder) moviesRepository() *repository { + return &repository{ + tx: qb.tx, + tableName: moviesScenesTable, + idColumn: sceneIDColumn, + } +} + +func (qb *sceneQueryBuilder) GetMovies(id int) (ret []models.MoviesScenes, err error) { + if err := qb.moviesRepository().getAll(id, func(rows *sqlx.Rows) error { + var ms models.MoviesScenes + if err := rows.StructScan(&ms); err != nil { + return err } - scenes = append(scenes, &scene) - } - if err := rows.Err(); err != nil { + ret = append(ret, ms) + return nil + }); err != nil { return nil, err } - return scenes, nil + return ret, nil } -func (qb *SceneQueryBuilder) UpdateFormat(id int, format string, tx *sqlx.Tx) error { - ensureTx(tx) - _, err := tx.Exec( - `UPDATE scenes SET format = ? WHERE scenes.id = ? `, - format, id, - ) - if err != nil { +func (qb *sceneQueryBuilder) UpdateMovies(sceneID int, movies []models.MoviesScenes) error { + // destroy existing joins + r := qb.moviesRepository() + if err := r.destroy([]int{sceneID}); err != nil { return err } + for _, m := range movies { + m.SceneID = sceneID + if _, err := r.insert(m); err != nil { + return err + } + } + return nil } -func (qb *SceneQueryBuilder) UpdateOSHash(id int, oshash string, tx *sqlx.Tx) error { - ensureTx(tx) - _, err := tx.Exec( - `UPDATE scenes SET oshash = ? WHERE scenes.id = ? `, - oshash, id, - ) - if err != nil { - return err +func (qb *sceneQueryBuilder) performersRepository() *joinRepository { + return &joinRepository{ + repository: repository{ + tx: qb.tx, + tableName: performersScenesTable, + idColumn: sceneIDColumn, + }, + fkColumn: performerIDColumn, } - - return nil } -func (qb *SceneQueryBuilder) UpdateChecksum(id int, checksum string, tx *sqlx.Tx) error { - ensureTx(tx) - _, err := tx.Exec( - `UPDATE scenes SET checksum = ? WHERE scenes.id = ? 
`, - checksum, id, - ) - if err != nil { - return err +func (qb *sceneQueryBuilder) GetPerformerIDs(id int) ([]int, error) { + return qb.performersRepository().getIDs(id) +} + +func (qb *sceneQueryBuilder) UpdatePerformers(id int, performerIDs []int) error { + // Delete the existing joins and then create new ones + return qb.performersRepository().replace(id, performerIDs) +} + +func (qb *sceneQueryBuilder) tagsRepository() *joinRepository { + return &joinRepository{ + repository: repository{ + tx: qb.tx, + tableName: scenesTagsTable, + idColumn: sceneIDColumn, + }, + fkColumn: tagIDColumn, } - - return nil } -func (qb *SceneQueryBuilder) UpdateSceneCover(sceneID int, cover []byte, tx *sqlx.Tx) error { - ensureTx(tx) +func (qb *sceneQueryBuilder) GetTagIDs(id int) ([]int, error) { + return qb.tagsRepository().getIDs(id) +} - // Delete the existing cover and then create new - if err := qb.DestroySceneCover(sceneID, tx); err != nil { - return err +func (qb *sceneQueryBuilder) UpdateTags(id int, tagIDs []int) error { + // Delete the existing joins and then create new ones + return qb.tagsRepository().replace(id, tagIDs) +} + +func (qb *sceneQueryBuilder) stashIDRepository() *stashIDRepository { + return &stashIDRepository{ + repository{ + tx: qb.tx, + tableName: "scene_stash_ids", + idColumn: sceneIDColumn, + }, } - - _, err := tx.Exec( - `INSERT INTO scenes_cover (scene_id, cover) VALUES (?, ?)`, - sceneID, - cover, - ) - - return err } -func (qb *SceneQueryBuilder) DestroySceneCover(sceneID int, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing joins - _, err := tx.Exec("DELETE FROM scenes_cover WHERE scene_id = ?", sceneID) - if err != nil { - return err - } - return err +func (qb *sceneQueryBuilder) GetStashIDs(sceneID int) ([]*models.StashID, error) { + return qb.stashIDRepository().get(sceneID) } -func (qb *SceneQueryBuilder) GetSceneCover(sceneID int, tx *sqlx.Tx) ([]byte, error) { - query := `SELECT cover from scenes_cover WHERE scene_id = ?` - return getImage(tx, query, sceneID) +func (qb *sceneQueryBuilder) UpdateStashIDs(sceneID int, stashIDs []models.StashID) error { + return qb.stashIDRepository().replace(sceneID, stashIDs) } diff --git a/pkg/models/querybuilder_scene_marker.go b/pkg/sqlite/scene_marker.go similarity index 61% rename from pkg/models/querybuilder_scene_marker.go rename to pkg/sqlite/scene_marker.go index 5edba444c..9c8e6e9bc 100644 --- a/pkg/models/querybuilder_scene_marker.go +++ b/pkg/sqlite/scene_marker.go @@ -1,14 +1,16 @@ -package models +package sqlite import ( "database/sql" "fmt" "strconv" - "github.com/jmoiron/sqlx" "github.com/stashapp/stash/pkg/database" + "github.com/stashapp/stash/pkg/models" ) +const sceneMarkerTable = "scene_markers" + const countSceneMarkersForTagQuery = ` SELECT scene_markers.id FROM scene_markers LEFT JOIN scene_markers_tags as tags_join on tags_join.scene_marker_id = scene_markers.id @@ -16,66 +18,59 @@ WHERE tags_join.tag_id = ? OR scene_markers.primary_tag_id = ? 
GROUP BY scene_markers.id ` -type SceneMarkerQueryBuilder struct{} - -func NewSceneMarkerQueryBuilder() SceneMarkerQueryBuilder { - return SceneMarkerQueryBuilder{} +type sceneMarkerQueryBuilder struct { + repository } -func (qb *SceneMarkerQueryBuilder) Create(newSceneMarker SceneMarker, tx *sqlx.Tx) (*SceneMarker, error) { - ensureTx(tx) - result, err := tx.NamedExec( - `INSERT INTO scene_markers (title, seconds, primary_tag_id, scene_id, created_at, updated_at) - VALUES (:title, :seconds, :primary_tag_id, :scene_id, :created_at, :updated_at) - `, - newSceneMarker, - ) - if err != nil { - return nil, err +func NewSceneMarkerReaderWriter(tx dbi) *sceneMarkerQueryBuilder { + return &sceneMarkerQueryBuilder{ + repository{ + tx: tx, + tableName: sceneMarkerTable, + idColumn: idColumn, + }, } - sceneMarkerID, err := result.LastInsertId() - if err != nil { - return nil, err - } - - if err := tx.Get(&newSceneMarker, `SELECT * FROM scene_markers WHERE id = ? LIMIT 1`, sceneMarkerID); err != nil { - return nil, err - } - return &newSceneMarker, nil } -func (qb *SceneMarkerQueryBuilder) Update(updatedSceneMarker SceneMarker, tx *sqlx.Tx) (*SceneMarker, error) { - ensureTx(tx) - _, err := tx.NamedExec( - `UPDATE scene_markers SET `+SQLGenKeys(updatedSceneMarker)+` WHERE scene_markers.id = :id`, - updatedSceneMarker, - ) - if err != nil { +func (qb *sceneMarkerQueryBuilder) Create(newObject models.SceneMarker) (*models.SceneMarker, error) { + var ret models.SceneMarker + if err := qb.insertObject(newObject, &ret); err != nil { return nil, err } - if err := tx.Get(&updatedSceneMarker, `SELECT * FROM scene_markers WHERE id = ? LIMIT 1`, updatedSceneMarker.ID); err != nil { + return &ret, nil +} + +func (qb *sceneMarkerQueryBuilder) Update(updatedObject models.SceneMarker) (*models.SceneMarker, error) { + const partial = false + if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil { return nil, err } - return &updatedSceneMarker, nil + + var ret models.SceneMarker + if err := qb.get(updatedObject.ID, &ret); err != nil { + return nil, err + } + + return &ret, nil } -func (qb *SceneMarkerQueryBuilder) Destroy(id string, tx *sqlx.Tx) error { - return executeDeleteQuery("scene_markers", id, tx) +func (qb *sceneMarkerQueryBuilder) Destroy(id int) error { + return qb.destroyExisting([]int{id}) } -func (qb *SceneMarkerQueryBuilder) Find(id int) (*SceneMarker, error) { +func (qb *sceneMarkerQueryBuilder) Find(id int) (*models.SceneMarker, error) { query := "SELECT * FROM scene_markers WHERE id = ? LIMIT 1" args := []interface{}{id} - results, err := qb.querySceneMarkers(query, args, nil) + results, err := qb.querySceneMarkers(query, args) if err != nil || len(results) < 1 { return nil, err } return results[0], nil } -func (qb *SceneMarkerQueryBuilder) FindMany(ids []int) ([]*SceneMarker, error) { - var markers []*SceneMarker +func (qb *sceneMarkerQueryBuilder) FindMany(ids []int) ([]*models.SceneMarker, error) { + var markers []*models.SceneMarker for _, id := range ids { marker, err := qb.Find(id) if err != nil { @@ -92,7 +87,7 @@ func (qb *SceneMarkerQueryBuilder) FindMany(ids []int) ([]*SceneMarker, error) { return markers, nil } -func (qb *SceneMarkerQueryBuilder) FindBySceneID(sceneID int, tx *sqlx.Tx) ([]*SceneMarker, error) { +func (qb *sceneMarkerQueryBuilder) FindBySceneID(sceneID int) ([]*models.SceneMarker, error) { query := ` SELECT scene_markers.* FROM scene_markers WHERE scene_markers.scene_id = ? 
@@ -100,15 +95,15 @@ func (qb *SceneMarkerQueryBuilder) FindBySceneID(sceneID int, tx *sqlx.Tx) ([]*S ORDER BY scene_markers.seconds ASC ` args := []interface{}{sceneID} - return qb.querySceneMarkers(query, args, tx) + return qb.querySceneMarkers(query, args) } -func (qb *SceneMarkerQueryBuilder) CountByTagID(tagID int) (int, error) { +func (qb *sceneMarkerQueryBuilder) CountByTagID(tagID int) (int, error) { args := []interface{}{tagID, tagID} - return runCountQuery(buildCountQuery(countSceneMarkersForTagQuery), args) + return qb.runCountQuery(qb.buildCountQuery(countSceneMarkersForTagQuery), args) } -func (qb *SceneMarkerQueryBuilder) GetMarkerStrings(q *string, sort *string) ([]*MarkerStringsResultType, error) { +func (qb *sceneMarkerQueryBuilder) GetMarkerStrings(q *string, sort *string) ([]*models.MarkerStringsResultType, error) { query := "SELECT count(*) as `count`, scene_markers.id as id, scene_markers.title as title FROM scene_markers" if q != nil { query = query + " WHERE title LIKE '%" + *q + "%'" @@ -123,21 +118,21 @@ func (qb *SceneMarkerQueryBuilder) GetMarkerStrings(q *string, sort *string) ([] return qb.queryMarkerStringsResultType(query, args) } -func (qb *SceneMarkerQueryBuilder) Wall(q *string) ([]*SceneMarker, error) { +func (qb *sceneMarkerQueryBuilder) Wall(q *string) ([]*models.SceneMarker, error) { s := "" if q != nil { s = *q } query := "SELECT scene_markers.* FROM scene_markers WHERE scene_markers.title LIKE '%" + s + "%' ORDER BY RANDOM() LIMIT 80" - return qb.querySceneMarkers(query, nil, nil) + return qb.querySceneMarkers(query, nil) } -func (qb *SceneMarkerQueryBuilder) Query(sceneMarkerFilter *SceneMarkerFilterType, findFilter *FindFilterType) ([]*SceneMarker, int) { +func (qb *sceneMarkerQueryBuilder) Query(sceneMarkerFilter *models.SceneMarkerFilterType, findFilter *models.FindFilterType) ([]*models.SceneMarker, int, error) { if sceneMarkerFilter == nil { - sceneMarkerFilter = &SceneMarkerFilterType{} + sceneMarkerFilter = &models.SceneMarkerFilterType{} } if findFilter == nil { - findFilter = &FindFilterType{} + findFilter = &models.FindFilterType{} } var whereClauses []string @@ -164,7 +159,7 @@ func (qb *SceneMarkerQueryBuilder) Query(sceneMarkerFilter *SceneMarkerFilterTyp length := len(tagsFilter.Value) - if tagsFilter.Modifier == CriterionModifierIncludes || tagsFilter.Modifier == CriterionModifierIncludesAll { + if tagsFilter.Modifier == models.CriterionModifierIncludes || tagsFilter.Modifier == models.CriterionModifierIncludesAll { body += " LEFT JOIN tags AS ptj ON ptj.id = scene_markers.primary_tag_id AND ptj.id IN " + getInBinding(length) body += " LEFT JOIN scene_markers_tags AS tj ON tj.scene_marker_id = scene_markers.id AND tj.tag_id IN " + getInBinding(length) @@ -172,12 +167,12 @@ func (qb *SceneMarkerQueryBuilder) Query(sceneMarkerFilter *SceneMarkerFilterTyp requiredCount := 1 // all required for include all - if tagsFilter.Modifier == CriterionModifierIncludesAll { + if tagsFilter.Modifier == models.CriterionModifierIncludesAll { requiredCount = length } havingClauses = append(havingClauses, "((COUNT(DISTINCT ptj.id) + COUNT(DISTINCT tj.tag_id)) >= "+strconv.Itoa(requiredCount)+")") - } else if tagsFilter.Modifier == CriterionModifierExcludes { + } else if tagsFilter.Modifier == models.CriterionModifierExcludes { // excludes all of the provided ids whereClauses = append(whereClauses, "scene_markers.primary_tag_id not in "+getInBinding(length)) whereClauses = append(whereClauses, "not exists (select smt.scene_marker_id from 
scene_markers_tags as smt where smt.scene_marker_id = scene_markers.id and smt.tag_id in "+getInBinding(length)+")") @@ -194,19 +189,19 @@ func (qb *SceneMarkerQueryBuilder) Query(sceneMarkerFilter *SceneMarkerFilterTyp if sceneTagsFilter := sceneMarkerFilter.SceneTags; sceneTagsFilter != nil && len(sceneTagsFilter.Value) > 0 { length := len(sceneTagsFilter.Value) - if sceneTagsFilter.Modifier == CriterionModifierIncludes || sceneTagsFilter.Modifier == CriterionModifierIncludesAll { + if sceneTagsFilter.Modifier == models.CriterionModifierIncludes || sceneTagsFilter.Modifier == models.CriterionModifierIncludesAll { body += " LEFT JOIN scenes_tags AS scene_tags_join ON scene_tags_join.scene_id = scene.id AND scene_tags_join.tag_id IN " + getInBinding(length) // only one required for include any requiredCount := 1 // all required for include all - if sceneTagsFilter.Modifier == CriterionModifierIncludesAll { + if sceneTagsFilter.Modifier == models.CriterionModifierIncludesAll { requiredCount = length } havingClauses = append(havingClauses, "COUNT(DISTINCT scene_tags_join.tag_id) >= "+strconv.Itoa(requiredCount)) - } else if sceneTagsFilter.Modifier == CriterionModifierExcludes { + } else if sceneTagsFilter.Modifier == models.CriterionModifierExcludes { // excludes all of the provided ids whereClauses = append(whereClauses, "not exists (select st.scene_id from scenes_tags as st where st.scene_id = scene.id AND st.tag_id IN "+getInBinding(length)+")") } @@ -219,7 +214,7 @@ func (qb *SceneMarkerQueryBuilder) Query(sceneMarkerFilter *SceneMarkerFilterTyp if performersFilter := sceneMarkerFilter.Performers; performersFilter != nil && len(performersFilter.Value) > 0 { length := len(performersFilter.Value) - if performersFilter.Modifier == CriterionModifierIncludes || performersFilter.Modifier == CriterionModifierIncludesAll { + if performersFilter.Modifier == models.CriterionModifierIncludes || performersFilter.Modifier == models.CriterionModifierIncludesAll { body += " LEFT JOIN performers_scenes as scene_performers ON scene.id = scene_performers.scene_id" whereClauses = append(whereClauses, "scene_performers.performer_id IN "+getInBinding(length)) @@ -227,12 +222,12 @@ func (qb *SceneMarkerQueryBuilder) Query(sceneMarkerFilter *SceneMarkerFilterTyp requiredCount := 1 // all required for include all - if performersFilter.Modifier == CriterionModifierIncludesAll { + if performersFilter.Modifier == models.CriterionModifierIncludesAll { requiredCount = length } havingClauses = append(havingClauses, "COUNT(DISTINCT scene_performers.performer_id) >= "+strconv.Itoa(requiredCount)) - } else if performersFilter.Modifier == CriterionModifierExcludes { + } else if performersFilter.Modifier == models.CriterionModifierExcludes { // excludes all of the provided ids whereClauses = append(whereClauses, "not exists (select sp.scene_id from performers_scenes as sp where sp.scene_id = scene.id AND sp.performer_id IN "+getInBinding(length)+")") } @@ -254,18 +249,25 @@ func (qb *SceneMarkerQueryBuilder) Query(sceneMarkerFilter *SceneMarkerFilterTyp } sortAndPagination := qb.getSceneMarkerSort(findFilter) + getPagination(findFilter) - idsResult, countResult := executeFindQuery("scene_markers", body, args, sortAndPagination, whereClauses, havingClauses) + idsResult, countResult, err := qb.executeFindQuery(body, args, sortAndPagination, whereClauses, havingClauses) + if err != nil { + return nil, 0, err + } - var sceneMarkers []*SceneMarker + var sceneMarkers []*models.SceneMarker for _, id := range idsResult { - 
sceneMarker, _ := qb.Find(id) + sceneMarker, err := qb.Find(id) + if err != nil { + return nil, 0, err + } + sceneMarkers = append(sceneMarkers, sceneMarker) } - return sceneMarkers, countResult + return sceneMarkers, countResult, nil } -func (qb *SceneMarkerQueryBuilder) getSceneMarkerSort(findFilter *FindFilterType) string { +func (qb *sceneMarkerQueryBuilder) getSceneMarkerSort(findFilter *models.FindFilterType) string { sort := findFilter.GetSort("title") direction := findFilter.GetDirection() tableName := "scene_markers" @@ -276,46 +278,25 @@ func (qb *SceneMarkerQueryBuilder) getSceneMarkerSort(findFilter *FindFilterType return getSort(sort, direction, tableName) } -func (qb *SceneMarkerQueryBuilder) querySceneMarkers(query string, args []interface{}, tx *sqlx.Tx) ([]*SceneMarker, error) { - var rows *sqlx.Rows - var err error - if tx != nil { - rows, err = tx.Queryx(query, args...) - } else { - rows, err = database.DB.Queryx(query, args...) - } - - if err != nil && err != sql.ErrNoRows { - return nil, err - } - defer rows.Close() - - sceneMarkers := make([]*SceneMarker, 0) - for rows.Next() { - sceneMarker := SceneMarker{} - if err := rows.StructScan(&sceneMarker); err != nil { - return nil, err - } - sceneMarkers = append(sceneMarkers, &sceneMarker) - } - - if err := rows.Err(); err != nil { +func (qb *sceneMarkerQueryBuilder) querySceneMarkers(query string, args []interface{}) ([]*models.SceneMarker, error) { + var ret models.SceneMarkers + if err := qb.query(query, args, &ret); err != nil { return nil, err } - return sceneMarkers, nil + return []*models.SceneMarker(ret), nil } -func (qb *SceneMarkerQueryBuilder) queryMarkerStringsResultType(query string, args []interface{}) ([]*MarkerStringsResultType, error) { +func (qb *sceneMarkerQueryBuilder) queryMarkerStringsResultType(query string, args []interface{}) ([]*models.MarkerStringsResultType, error) { rows, err := database.DB.Queryx(query, args...) 
if err != nil && err != sql.ErrNoRows { return nil, err } defer rows.Close() - markerStrings := make([]*MarkerStringsResultType, 0) + markerStrings := make([]*models.MarkerStringsResultType, 0) for rows.Next() { - markerString := MarkerStringsResultType{} + markerString := models.MarkerStringsResultType{} if err := rows.StructScan(&markerString); err != nil { return nil, err } @@ -328,3 +309,23 @@ func (qb *SceneMarkerQueryBuilder) queryMarkerStringsResultType(query string, ar return markerStrings, nil } + +func (qb *sceneMarkerQueryBuilder) tagsRepository() *joinRepository { + return &joinRepository{ + repository: repository{ + tx: qb.tx, + tableName: "scene_markers_tags", + idColumn: "scene_marker_id", + }, + fkColumn: tagIDColumn, + } +} + +func (qb *sceneMarkerQueryBuilder) GetTagIDs(id int) ([]int, error) { + return qb.tagsRepository().getIDs(id) +} + +func (qb *sceneMarkerQueryBuilder) UpdateTags(id int, tagIDs []int) error { + // Delete the existing joins and then create new ones + return qb.tagsRepository().replace(id, tagIDs) +} diff --git a/pkg/sqlite/scene_marker_test.go b/pkg/sqlite/scene_marker_test.go new file mode 100644 index 000000000..bc3d95daa --- /dev/null +++ b/pkg/sqlite/scene_marker_test.go @@ -0,0 +1,75 @@ +// +build integration + +package sqlite_test + +import ( + "testing" + + "github.com/stashapp/stash/pkg/models" + "github.com/stretchr/testify/assert" +) + +func TestMarkerFindBySceneID(t *testing.T) { + withTxn(func(r models.Repository) error { + mqb := r.SceneMarker() + + sceneID := sceneIDs[sceneIdxWithMarker] + markers, err := mqb.FindBySceneID(sceneID) + + if err != nil { + t.Errorf("Error finding markers: %s", err.Error()) + } + + assert.Len(t, markers, 1) + assert.Equal(t, markerIDs[markerIdxWithScene], markers[0].ID) + + markers, err = mqb.FindBySceneID(0) + + if err != nil { + t.Errorf("Error finding marker: %s", err.Error()) + } + + assert.Len(t, markers, 0) + + return nil + }) +} + +func TestMarkerCountByTagID(t *testing.T) { + withTxn(func(r models.Repository) error { + mqb := r.SceneMarker() + + markerCount, err := mqb.CountByTagID(tagIDs[tagIdxWithPrimaryMarker]) + + if err != nil { + t.Errorf("error calling CountByTagID: %s", err.Error()) + } + + assert.Equal(t, 1, markerCount) + + markerCount, err = mqb.CountByTagID(tagIDs[tagIdxWithMarker]) + + if err != nil { + t.Errorf("error calling CountByTagID: %s", err.Error()) + } + + assert.Equal(t, 1, markerCount) + + markerCount, err = mqb.CountByTagID(0) + + if err != nil { + t.Errorf("error calling CountByTagID: %s", err.Error()) + } + + assert.Equal(t, 0, markerCount) + + return nil + }) +} + +// TODO Update +// TODO Destroy +// TODO Find +// TODO GetMarkerStrings +// TODO Wall +// TODO Query diff --git a/pkg/sqlite/scene_test.go b/pkg/sqlite/scene_test.go new file mode 100644 index 000000000..f6066b252 --- /dev/null +++ b/pkg/sqlite/scene_test.go @@ -0,0 +1,1169 @@ +// +build integration + +package sqlite_test + +import ( + "database/sql" + "fmt" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/utils" +) + +func TestSceneFind(t *testing.T) { + withTxn(func(r models.Repository) error { + // assume that the first scene is sceneWithGalleryPath + sqb := r.Scene() + + const sceneIdx = 0 + sceneID := sceneIDs[sceneIdx] + scene, err := sqb.Find(sceneID) + + if err != nil { + t.Errorf("Error finding scene: %s", err.Error()) + } + + assert.Equal(t, getSceneStringValue(sceneIdx, "Path"), scene.Path) + + sceneID = 0 + 
scene, err = sqb.Find(sceneID) + + if err != nil { + t.Errorf("Error finding scene: %s", err.Error()) + } + + assert.Nil(t, scene) + + return nil + }) +} + +func TestSceneFindByPath(t *testing.T) { + withTxn(func(r models.Repository) error { + sqb := r.Scene() + + const sceneIdx = 1 + scenePath := getSceneStringValue(sceneIdx, "Path") + scene, err := sqb.FindByPath(scenePath) + + if err != nil { + t.Errorf("Error finding scene: %s", err.Error()) + } + + assert.Equal(t, sceneIDs[sceneIdx], scene.ID) + assert.Equal(t, scenePath, scene.Path) + + scenePath = "not exist" + scene, err = sqb.FindByPath(scenePath) + + if err != nil { + t.Errorf("Error finding scene: %s", err.Error()) + } + + assert.Nil(t, scene) + + return nil + }) +} + +func TestSceneCountByPerformerID(t *testing.T) { + withTxn(func(r models.Repository) error { + sqb := r.Scene() + count, err := sqb.CountByPerformerID(performerIDs[performerIdxWithScene]) + + if err != nil { + t.Errorf("Error counting scenes: %s", err.Error()) + } + + assert.Equal(t, 1, count) + + count, err = sqb.CountByPerformerID(0) + + if err != nil { + t.Errorf("Error counting scenes: %s", err.Error()) + } + + assert.Equal(t, 0, count) + + return nil + }) +} + +func TestSceneWall(t *testing.T) { + withTxn(func(r models.Repository) error { + sqb := r.Scene() + + const sceneIdx = 2 + wallQuery := getSceneStringValue(sceneIdx, "Details") + scenes, err := sqb.Wall(&wallQuery) + + if err != nil { + t.Errorf("Error finding scenes: %s", err.Error()) + } + + assert.Len(t, scenes, 1) + scene := scenes[0] + assert.Equal(t, sceneIDs[sceneIdx], scene.ID) + assert.Equal(t, getSceneStringValue(sceneIdx, "Path"), scene.Path) + + wallQuery = "not exist" + scenes, err = sqb.Wall(&wallQuery) + + if err != nil { + t.Errorf("Error finding scene: %s", err.Error()) + } + + assert.Len(t, scenes, 0) + + return nil + }) +} + +func TestSceneQueryQ(t *testing.T) { + const sceneIdx = 2 + + q := getSceneStringValue(sceneIdx, titleField) + + withTxn(func(r models.Repository) error { + sqb := r.Scene() + + sceneQueryQ(t, sqb, q, sceneIdx) + + return nil + }) +} + +func queryScene(t *testing.T, sqb models.SceneReader, sceneFilter *models.SceneFilterType, findFilter *models.FindFilterType) []*models.Scene { + scenes, _, err := sqb.Query(sceneFilter, findFilter) + if err != nil { + t.Errorf("Error querying scene: %s", err.Error()) + } + + return scenes +} + +func sceneQueryQ(t *testing.T, sqb models.SceneReader, q string, expectedSceneIdx int) { + filter := models.FindFilterType{ + Q: &q, + } + scenes := queryScene(t, sqb, nil, &filter) + + assert.Len(t, scenes, 1) + scene := scenes[0] + assert.Equal(t, sceneIDs[expectedSceneIdx], scene.ID) + + // no Q should return all results + filter.Q = nil + scenes = queryScene(t, sqb, nil, &filter) + + assert.Len(t, scenes, totalScenes) +} + +func TestSceneQueryPath(t *testing.T) { + const sceneIdx = 1 + scenePath := getSceneStringValue(sceneIdx, "Path") + + pathCriterion := models.StringCriterionInput{ + Value: scenePath, + Modifier: models.CriterionModifierEquals, + } + + verifyScenesPath(t, pathCriterion) + + pathCriterion.Modifier = models.CriterionModifierNotEquals + verifyScenesPath(t, pathCriterion) +} + +func verifyScenesPath(t *testing.T, pathCriterion models.StringCriterionInput) { + withTxn(func(r models.Repository) error { + sqb := r.Scene() + sceneFilter := models.SceneFilterType{ + Path: &pathCriterion, + } + + scenes := queryScene(t, sqb, &sceneFilter, nil) + + for _, scene := range scenes { + verifyString(t, scene.Path, pathCriterion) + 
} + + return nil + }) +} +
+func verifyNullString(t *testing.T, value sql.NullString, criterion models.StringCriterionInput) { + t.Helper() + assert := assert.New(t) + if criterion.Modifier == models.CriterionModifierIsNull { + assert.False(value.Valid, "expect is null values to be null") + } + if criterion.Modifier == models.CriterionModifierNotNull { + assert.True(value.Valid, "expect not null values to be not null") + } + if criterion.Modifier == models.CriterionModifierEquals { + assert.Equal(criterion.Value, value.String) + } + if criterion.Modifier == models.CriterionModifierNotEquals { + assert.NotEqual(criterion.Value, value.String) + } +} +
+func verifyString(t *testing.T, value string, criterion models.StringCriterionInput) { + t.Helper() + assert := assert.New(t) + if criterion.Modifier == models.CriterionModifierEquals { + assert.Equal(criterion.Value, value) + } + if criterion.Modifier == models.CriterionModifierNotEquals { + assert.NotEqual(criterion.Value, value) + } +} +
+func TestSceneQueryRating(t *testing.T) { + const rating = 3 + ratingCriterion := models.IntCriterionInput{ + Value: rating, + Modifier: models.CriterionModifierEquals, + } + + verifyScenesRating(t, ratingCriterion) + + ratingCriterion.Modifier = models.CriterionModifierNotEquals + verifyScenesRating(t, ratingCriterion) + + ratingCriterion.Modifier = models.CriterionModifierGreaterThan + verifyScenesRating(t, ratingCriterion) + + ratingCriterion.Modifier = models.CriterionModifierLessThan + verifyScenesRating(t, ratingCriterion) + + ratingCriterion.Modifier = models.CriterionModifierIsNull + verifyScenesRating(t, ratingCriterion) + + ratingCriterion.Modifier = models.CriterionModifierNotNull + verifyScenesRating(t, ratingCriterion) +} +
+func verifyScenesRating(t *testing.T, ratingCriterion models.IntCriterionInput) { + withTxn(func(r models.Repository) error { + sqb := r.Scene() + sceneFilter := models.SceneFilterType{ + Rating: &ratingCriterion, + } + + scenes := queryScene(t, sqb, &sceneFilter, nil) + + for _, scene := range scenes { + verifyInt64(t, scene.Rating, ratingCriterion) + } + + return nil + }) +} +
+func verifyInt64(t *testing.T, value sql.NullInt64, criterion models.IntCriterionInput) { + t.Helper() + assert := assert.New(t) + if criterion.Modifier == models.CriterionModifierIsNull { + assert.False(value.Valid, "expect is null values to be null") + } + if criterion.Modifier == models.CriterionModifierNotNull { + assert.True(value.Valid, "expect not null values to be not null") + } + if criterion.Modifier == models.CriterionModifierEquals { + assert.Equal(int64(criterion.Value), value.Int64) + } + if criterion.Modifier == models.CriterionModifierNotEquals { + assert.NotEqual(int64(criterion.Value), value.Int64) + } + if criterion.Modifier == models.CriterionModifierGreaterThan { + assert.True(value.Int64 > int64(criterion.Value)) + } + if criterion.Modifier == models.CriterionModifierLessThan { + assert.True(value.Int64 < int64(criterion.Value)) + } +} +
+func TestSceneQueryOCounter(t *testing.T) { + const oCounter = 1 + oCounterCriterion := models.IntCriterionInput{ + Value: oCounter, + Modifier: models.CriterionModifierEquals, + } + + verifyScenesOCounter(t, oCounterCriterion) + + oCounterCriterion.Modifier = models.CriterionModifierNotEquals + verifyScenesOCounter(t, oCounterCriterion) + + oCounterCriterion.Modifier = models.CriterionModifierGreaterThan + verifyScenesOCounter(t, oCounterCriterion) + + oCounterCriterion.Modifier = models.CriterionModifierLessThan + verifyScenesOCounter(t, oCounterCriterion) +} +
+func verifyScenesOCounter(t *testing.T, oCounterCriterion models.IntCriterionInput) { + withTxn(func(r models.Repository) error { + sqb := r.Scene() + sceneFilter := models.SceneFilterType{ + OCounter: &oCounterCriterion, + } + + scenes := queryScene(t, sqb, &sceneFilter, nil) + + for _, scene := range scenes { + verifyInt(t, scene.OCounter, oCounterCriterion) + } + + return nil + }) +} +
+func verifyInt(t *testing.T, value int, criterion models.IntCriterionInput) { + t.Helper() + assert := assert.New(t) + if criterion.Modifier == models.CriterionModifierEquals { + assert.Equal(criterion.Value, value) + } + if criterion.Modifier == models.CriterionModifierNotEquals { + assert.NotEqual(criterion.Value, value) + } + if criterion.Modifier == models.CriterionModifierGreaterThan { + assert.True(value > criterion.Value) + } + if criterion.Modifier == models.CriterionModifierLessThan { + assert.True(value < criterion.Value) + } +} +
+func TestSceneQueryDuration(t *testing.T) { + duration := 200.432 + + durationCriterion := models.IntCriterionInput{ + Value: int(duration), + Modifier: models.CriterionModifierEquals, + } + verifyScenesDuration(t, durationCriterion) + + durationCriterion.Modifier = models.CriterionModifierNotEquals + verifyScenesDuration(t, durationCriterion) + + durationCriterion.Modifier = models.CriterionModifierGreaterThan + verifyScenesDuration(t, durationCriterion) + + durationCriterion.Modifier = models.CriterionModifierLessThan + verifyScenesDuration(t, durationCriterion) + + durationCriterion.Modifier = models.CriterionModifierIsNull + verifyScenesDuration(t, durationCriterion) + + durationCriterion.Modifier = models.CriterionModifierNotNull + verifyScenesDuration(t, durationCriterion) +} +
+func verifyScenesDuration(t *testing.T, durationCriterion models.IntCriterionInput) { + withTxn(func(r models.Repository) error { + sqb := r.Scene() + sceneFilter := models.SceneFilterType{ + Duration: &durationCriterion, + } + + scenes := queryScene(t, sqb, &sceneFilter, nil) + + for _, scene := range scenes { + if durationCriterion.Modifier == models.CriterionModifierEquals { + assert.True(t, scene.Duration.Float64 >= float64(durationCriterion.Value) && scene.Duration.Float64 < float64(durationCriterion.Value+1)) + } else if durationCriterion.Modifier == models.CriterionModifierNotEquals { + assert.True(t, scene.Duration.Float64 < float64(durationCriterion.Value) || scene.Duration.Float64 >= float64(durationCriterion.Value+1)) + } else { + verifyFloat64(t, scene.Duration, durationCriterion) + } + } + + return nil + }) +} +
+func verifyFloat64(t *testing.T, value sql.NullFloat64, criterion models.IntCriterionInput) { + assert := assert.New(t) + if criterion.Modifier == models.CriterionModifierIsNull { + assert.False(value.Valid, "expect is null values to be null") + } + if criterion.Modifier == models.CriterionModifierNotNull { + assert.True(value.Valid, "expect not null values to be not null") + } + if criterion.Modifier == models.CriterionModifierEquals { + assert.Equal(float64(criterion.Value), value.Float64) + } + if criterion.Modifier == models.CriterionModifierNotEquals { + assert.NotEqual(float64(criterion.Value), value.Float64) + } + if criterion.Modifier == models.CriterionModifierGreaterThan { + assert.True(value.Float64 > float64(criterion.Value)) + } + if criterion.Modifier == models.CriterionModifierLessThan { + assert.True(value.Float64 < float64(criterion.Value)) + } +} +
+func TestSceneQueryResolution(t *testing.T) { + verifyScenesResolution(t,
models.ResolutionEnumLow) + verifyScenesResolution(t, models.ResolutionEnumStandard) + verifyScenesResolution(t, models.ResolutionEnumStandardHd) + verifyScenesResolution(t, models.ResolutionEnumFullHd) + verifyScenesResolution(t, models.ResolutionEnumFourK) + verifyScenesResolution(t, models.ResolutionEnum("unknown")) +} +
+func verifyScenesResolution(t *testing.T, resolution models.ResolutionEnum) { + withTxn(func(r models.Repository) error { + sqb := r.Scene() + sceneFilter := models.SceneFilterType{ + Resolution: &resolution, + } + + scenes := queryScene(t, sqb, &sceneFilter, nil) + + for _, scene := range scenes { + verifySceneResolution(t, scene.Height, resolution) + } + + return nil + }) +} +
+func verifySceneResolution(t *testing.T, height sql.NullInt64, resolution models.ResolutionEnum) { + assert := assert.New(t) + h := height.Int64 + + switch resolution { + case models.ResolutionEnumLow: + assert.True(h < 480) + case models.ResolutionEnumStandard: + assert.True(h >= 480 && h < 720) + case models.ResolutionEnumStandardHd: + assert.True(h >= 720 && h < 1080) + case models.ResolutionEnumFullHd: + assert.True(h >= 1080 && h < 2160) + case models.ResolutionEnumFourK: + assert.True(h >= 2160) + } +} +
+func TestSceneQueryHasMarkers(t *testing.T) { + withTxn(func(r models.Repository) error { + sqb := r.Scene() + hasMarkers := "true" + sceneFilter := models.SceneFilterType{ + HasMarkers: &hasMarkers, + } + + q := getSceneStringValue(sceneIdxWithMarker, titleField) + findFilter := models.FindFilterType{ + Q: &q, + } + + scenes := queryScene(t, sqb, &sceneFilter, &findFilter) + + assert.Len(t, scenes, 1) + assert.Equal(t, sceneIDs[sceneIdxWithMarker], scenes[0].ID) + + hasMarkers = "false" + scenes = queryScene(t, sqb, &sceneFilter, &findFilter) + assert.Len(t, scenes, 0) + + findFilter.Q = nil + scenes = queryScene(t, sqb, &sceneFilter, &findFilter) + + assert.NotEqual(t, 0, len(scenes)) + + // ensure none of the ids equal the one with markers + for _, scene := range scenes { + assert.NotEqual(t, sceneIDs[sceneIdxWithMarker], scene.ID) + } + + return nil + }) +} +
+func TestSceneQueryIsMissingGallery(t *testing.T) { + withTxn(func(r models.Repository) error { + sqb := r.Scene() + isMissing := "gallery" + sceneFilter := models.SceneFilterType{ + IsMissing: &isMissing, + } + + q := getSceneStringValue(sceneIdxWithGallery, titleField) + findFilter := models.FindFilterType{ + Q: &q, + } + + scenes := queryScene(t, sqb, &sceneFilter, &findFilter) + + assert.Len(t, scenes, 0) + + findFilter.Q = nil + scenes = queryScene(t, sqb, &sceneFilter, &findFilter) + + // ensure none of the ids equal the one with gallery + for _, scene := range scenes { + assert.NotEqual(t, sceneIDs[sceneIdxWithGallery], scene.ID) + } + + return nil + }) +} +
+func TestSceneQueryIsMissingStudio(t *testing.T) { + withTxn(func(r models.Repository) error { + sqb := r.Scene() + isMissing := "studio" + sceneFilter := models.SceneFilterType{ + IsMissing: &isMissing, + } + + q := getSceneStringValue(sceneIdxWithStudio, titleField) + findFilter := models.FindFilterType{ + Q: &q, + } + + scenes := queryScene(t, sqb, &sceneFilter, &findFilter) + + assert.Len(t, scenes, 0) + + findFilter.Q = nil + scenes = queryScene(t, sqb, &sceneFilter, &findFilter) + + // ensure none of the ids equal the one with studio + for _, scene := range scenes { + assert.NotEqual(t, sceneIDs[sceneIdxWithStudio], scene.ID) + } + + return nil + }) +} +
+func TestSceneQueryIsMissingMovies(t *testing.T) { + withTxn(func(r models.Repository) error { + sqb :=
r.Scene() + isMissing := "movie" + sceneFilter := models.SceneFilterType{ + IsMissing: &isMissing, + } + + q := getSceneStringValue(sceneIdxWithMovie, titleField) + findFilter := models.FindFilterType{ + Q: &q, + } + + scenes := queryScene(t, sqb, &sceneFilter, &findFilter) + + assert.Len(t, scenes, 0) + + findFilter.Q = nil + scenes = queryScene(t, sqb, &sceneFilter, &findFilter) + + // ensure none of the ids equal the one with movies + for _, scene := range scenes { + assert.NotEqual(t, sceneIDs[sceneIdxWithMovie], scene.ID) + } + + return nil + }) +} +
+func TestSceneQueryIsMissingPerformers(t *testing.T) { + withTxn(func(r models.Repository) error { + sqb := r.Scene() + isMissing := "performers" + sceneFilter := models.SceneFilterType{ + IsMissing: &isMissing, + } + + q := getSceneStringValue(sceneIdxWithPerformer, titleField) + findFilter := models.FindFilterType{ + Q: &q, + } + + scenes := queryScene(t, sqb, &sceneFilter, &findFilter) + + assert.Len(t, scenes, 0) + + findFilter.Q = nil + scenes = queryScene(t, sqb, &sceneFilter, &findFilter) + + assert.True(t, len(scenes) > 0) + + // ensure none of the ids equal the one with performers + for _, scene := range scenes { + assert.NotEqual(t, sceneIDs[sceneIdxWithPerformer], scene.ID) + } + + return nil + }) +} +
+func TestSceneQueryIsMissingDate(t *testing.T) { + withTxn(func(r models.Repository) error { + sqb := r.Scene() + isMissing := "date" + sceneFilter := models.SceneFilterType{ + IsMissing: &isMissing, + } + + scenes := queryScene(t, sqb, &sceneFilter, nil) + + assert.True(t, len(scenes) > 0) + + // ensure date is null, empty or "0001-01-01" + for _, scene := range scenes { + assert.True(t, !scene.Date.Valid || scene.Date.String == "" || scene.Date.String == "0001-01-01") + } + + return nil + }) +} +
+func TestSceneQueryIsMissingTags(t *testing.T) { + withTxn(func(r models.Repository) error { + sqb := r.Scene() + isMissing := "tags" + sceneFilter := models.SceneFilterType{ + IsMissing: &isMissing, + } + + q := getSceneStringValue(sceneIdxWithTwoTags, titleField) + findFilter := models.FindFilterType{ + Q: &q, + } + + scenes := queryScene(t, sqb, &sceneFilter, &findFilter) + + assert.Len(t, scenes, 0) + + findFilter.Q = nil + scenes = queryScene(t, sqb, &sceneFilter, &findFilter) + + assert.True(t, len(scenes) > 0) + + return nil + }) +} +
+func TestSceneQueryIsMissingRating(t *testing.T) { + withTxn(func(r models.Repository) error { + sqb := r.Scene() + isMissing := "rating" + sceneFilter := models.SceneFilterType{ + IsMissing: &isMissing, + } + + scenes := queryScene(t, sqb, &sceneFilter, nil) + + assert.True(t, len(scenes) > 0) + + // ensure rating is null + for _, scene := range scenes { + assert.True(t, !scene.Rating.Valid) + } + + return nil + }) +} +
+func TestSceneQueryPerformers(t *testing.T) { + withTxn(func(r models.Repository) error { + sqb := r.Scene() + performerCriterion := models.MultiCriterionInput{ + Value: []string{ + strconv.Itoa(performerIDs[performerIdxWithScene]), + strconv.Itoa(performerIDs[performerIdx1WithScene]), + }, + Modifier: models.CriterionModifierIncludes, + } + + sceneFilter := models.SceneFilterType{ + Performers: &performerCriterion, + } + + scenes := queryScene(t, sqb, &sceneFilter, nil) + + assert.Len(t, scenes, 2) + + // ensure ids are correct + for _, scene := range scenes { + assert.True(t, scene.ID == sceneIDs[sceneIdxWithPerformer] || scene.ID == sceneIDs[sceneIdxWithTwoPerformers]) + } + + performerCriterion = models.MultiCriterionInput{ + Value: []string{ +
strconv.Itoa(performerIDs[performerIdx1WithScene]), + strconv.Itoa(performerIDs[performerIdx2WithScene]), + }, + Modifier: models.CriterionModifierIncludesAll, + } + + scenes = queryScene(t, sqb, &sceneFilter, nil) + + assert.Len(t, scenes, 1) + assert.Equal(t, sceneIDs[sceneIdxWithTwoPerformers], scenes[0].ID) + + performerCriterion = models.MultiCriterionInput{ + Value: []string{ + strconv.Itoa(performerIDs[performerIdx1WithScene]), + }, + Modifier: models.CriterionModifierExcludes, + } + + q := getSceneStringValue(sceneIdxWithTwoPerformers, titleField) + findFilter := models.FindFilterType{ + Q: &q, + } + + scenes = queryScene(t, sqb, &sceneFilter, &findFilter) + assert.Len(t, scenes, 0) + + return nil + }) +} + +func TestSceneQueryTags(t *testing.T) { + withTxn(func(r models.Repository) error { + sqb := r.Scene() + tagCriterion := models.MultiCriterionInput{ + Value: []string{ + strconv.Itoa(tagIDs[tagIdxWithScene]), + strconv.Itoa(tagIDs[tagIdx1WithScene]), + }, + Modifier: models.CriterionModifierIncludes, + } + + sceneFilter := models.SceneFilterType{ + Tags: &tagCriterion, + } + + scenes := queryScene(t, sqb, &sceneFilter, nil) + assert.Len(t, scenes, 2) + + // ensure ids are correct + for _, scene := range scenes { + assert.True(t, scene.ID == sceneIDs[sceneIdxWithTag] || scene.ID == sceneIDs[sceneIdxWithTwoTags]) + } + + tagCriterion = models.MultiCriterionInput{ + Value: []string{ + strconv.Itoa(tagIDs[tagIdx1WithScene]), + strconv.Itoa(tagIDs[tagIdx2WithScene]), + }, + Modifier: models.CriterionModifierIncludesAll, + } + + scenes = queryScene(t, sqb, &sceneFilter, nil) + + assert.Len(t, scenes, 1) + assert.Equal(t, sceneIDs[sceneIdxWithTwoTags], scenes[0].ID) + + tagCriterion = models.MultiCriterionInput{ + Value: []string{ + strconv.Itoa(tagIDs[tagIdx1WithScene]), + }, + Modifier: models.CriterionModifierExcludes, + } + + q := getSceneStringValue(sceneIdxWithTwoTags, titleField) + findFilter := models.FindFilterType{ + Q: &q, + } + + scenes = queryScene(t, sqb, &sceneFilter, &findFilter) + assert.Len(t, scenes, 0) + + return nil + }) +} + +func TestSceneQueryStudio(t *testing.T) { + withTxn(func(r models.Repository) error { + sqb := r.Scene() + studioCriterion := models.MultiCriterionInput{ + Value: []string{ + strconv.Itoa(studioIDs[studioIdxWithScene]), + }, + Modifier: models.CriterionModifierIncludes, + } + + sceneFilter := models.SceneFilterType{ + Studios: &studioCriterion, + } + + scenes := queryScene(t, sqb, &sceneFilter, nil) + + assert.Len(t, scenes, 1) + + // ensure id is correct + assert.Equal(t, sceneIDs[sceneIdxWithStudio], scenes[0].ID) + + studioCriterion = models.MultiCriterionInput{ + Value: []string{ + strconv.Itoa(studioIDs[studioIdxWithScene]), + }, + Modifier: models.CriterionModifierExcludes, + } + + q := getSceneStringValue(sceneIdxWithStudio, titleField) + findFilter := models.FindFilterType{ + Q: &q, + } + + scenes = queryScene(t, sqb, &sceneFilter, &findFilter) + assert.Len(t, scenes, 0) + + return nil + }) +} + +func TestSceneQueryMovies(t *testing.T) { + withTxn(func(r models.Repository) error { + sqb := r.Scene() + movieCriterion := models.MultiCriterionInput{ + Value: []string{ + strconv.Itoa(movieIDs[movieIdxWithScene]), + }, + Modifier: models.CriterionModifierIncludes, + } + + sceneFilter := models.SceneFilterType{ + Movies: &movieCriterion, + } + + scenes := queryScene(t, sqb, &sceneFilter, nil) + + assert.Len(t, scenes, 1) + + // ensure id is correct + assert.Equal(t, sceneIDs[sceneIdxWithMovie], scenes[0].ID) + + movieCriterion = 
models.MultiCriterionInput{ + Value: []string{ + strconv.Itoa(movieIDs[movieIdxWithScene]), + }, + Modifier: models.CriterionModifierExcludes, + } + + q := getSceneStringValue(sceneIdxWithMovie, titleField) + findFilter := models.FindFilterType{ + Q: &q, + } + + scenes = queryScene(t, sqb, &sceneFilter, &findFilter) + assert.Len(t, scenes, 0) + + return nil + }) +} + +func TestSceneQuerySorting(t *testing.T) { + sort := titleField + direction := models.SortDirectionEnumAsc + findFilter := models.FindFilterType{ + Sort: &sort, + Direction: &direction, + } + + withTxn(func(r models.Repository) error { + sqb := r.Scene() + scenes := queryScene(t, sqb, nil, &findFilter) + + // scenes should be in same order as indexes + firstScene := scenes[0] + lastScene := scenes[len(scenes)-1] + + assert.Equal(t, sceneIDs[0], firstScene.ID) + assert.Equal(t, sceneIDs[len(sceneIDs)-1], lastScene.ID) + + // sort in descending order + direction = models.SortDirectionEnumDesc + + scenes = queryScene(t, sqb, nil, &findFilter) + firstScene = scenes[0] + lastScene = scenes[len(scenes)-1] + + assert.Equal(t, sceneIDs[len(sceneIDs)-1], firstScene.ID) + assert.Equal(t, sceneIDs[0], lastScene.ID) + + return nil + }) +} + +func TestSceneQueryPagination(t *testing.T) { + perPage := 1 + findFilter := models.FindFilterType{ + PerPage: &perPage, + } + + withTxn(func(r models.Repository) error { + sqb := r.Scene() + scenes := queryScene(t, sqb, nil, &findFilter) + + assert.Len(t, scenes, 1) + + firstID := scenes[0].ID + + page := 2 + findFilter.Page = &page + scenes = queryScene(t, sqb, nil, &findFilter) + + assert.Len(t, scenes, 1) + secondID := scenes[0].ID + assert.NotEqual(t, firstID, secondID) + + perPage = 2 + page = 1 + + scenes = queryScene(t, sqb, nil, &findFilter) + assert.Len(t, scenes, 2) + assert.Equal(t, firstID, scenes[0].ID) + assert.Equal(t, secondID, scenes[1].ID) + + return nil + }) +} + +func TestSceneCountByTagID(t *testing.T) { + withTxn(func(r models.Repository) error { + sqb := r.Scene() + + sceneCount, err := sqb.CountByTagID(tagIDs[tagIdxWithScene]) + + if err != nil { + t.Errorf("error calling CountByTagID: %s", err.Error()) + } + + assert.Equal(t, 1, sceneCount) + + sceneCount, err = sqb.CountByTagID(0) + + if err != nil { + t.Errorf("error calling CountByTagID: %s", err.Error()) + } + + assert.Equal(t, 0, sceneCount) + + return nil + }) +} + +func TestSceneCountByMovieID(t *testing.T) { + withTxn(func(r models.Repository) error { + sqb := r.Scene() + + sceneCount, err := sqb.CountByMovieID(movieIDs[movieIdxWithScene]) + + if err != nil { + t.Errorf("error calling CountByMovieID: %s", err.Error()) + } + + assert.Equal(t, 1, sceneCount) + + sceneCount, err = sqb.CountByMovieID(0) + + if err != nil { + t.Errorf("error calling CountByMovieID: %s", err.Error()) + } + + assert.Equal(t, 0, sceneCount) + + return nil + }) +} + +func TestSceneCountByStudioID(t *testing.T) { + withTxn(func(r models.Repository) error { + sqb := r.Scene() + + sceneCount, err := sqb.CountByStudioID(studioIDs[studioIdxWithScene]) + + if err != nil { + t.Errorf("error calling CountByStudioID: %s", err.Error()) + } + + assert.Equal(t, 1, sceneCount) + + sceneCount, err = sqb.CountByStudioID(0) + + if err != nil { + t.Errorf("error calling CountByStudioID: %s", err.Error()) + } + + assert.Equal(t, 0, sceneCount) + + return nil + }) +} + +func TestFindByMovieID(t *testing.T) { + withTxn(func(r models.Repository) error { + sqb := r.Scene() + + scenes, err := sqb.FindByMovieID(movieIDs[movieIdxWithScene]) + + if err != nil { + 
t.Errorf("error calling FindByMovieID: %s", err.Error()) + } + + assert.Len(t, scenes, 1) + assert.Equal(t, sceneIDs[sceneIdxWithMovie], scenes[0].ID) + + scenes, err = sqb.FindByMovieID(0) + + if err != nil { + t.Errorf("error calling FindByMovieID: %s", err.Error()) + } + + assert.Len(t, scenes, 0) + + return nil + }) +} + +func TestFindByPerformerID(t *testing.T) { + withTxn(func(r models.Repository) error { + sqb := r.Scene() + + scenes, err := sqb.FindByPerformerID(performerIDs[performerIdxWithScene]) + + if err != nil { + t.Errorf("error calling FindByPerformerID: %s", err.Error()) + } + + assert.Len(t, scenes, 1) + assert.Equal(t, sceneIDs[sceneIdxWithPerformer], scenes[0].ID) + + scenes, err = sqb.FindByPerformerID(0) + + if err != nil { + t.Errorf("error calling FindByPerformerID: %s", err.Error()) + } + + assert.Len(t, scenes, 0) + + return nil + }) +} + +func TestSceneUpdateSceneCover(t *testing.T) { + if err := withTxn(func(r models.Repository) error { + qb := r.Scene() + + // create scene to test against + const name = "TestSceneUpdateSceneCover" + scene := models.Scene{ + Path: name, + Checksum: sql.NullString{String: utils.MD5FromString(name), Valid: true}, + } + created, err := qb.Create(scene) + if err != nil { + return fmt.Errorf("Error creating scene: %s", err.Error()) + } + + image := []byte("image") + err = qb.UpdateCover(created.ID, image) + if err != nil { + return fmt.Errorf("Error updating scene cover: %s", err.Error()) + } + + // ensure image set + storedImage, err := qb.GetCover(created.ID) + if err != nil { + return fmt.Errorf("Error getting image: %s", err.Error()) + } + assert.Equal(t, storedImage, image) + + // set nil image + err = qb.UpdateCover(created.ID, nil) + if err == nil { + return fmt.Errorf("Expected error setting nil image") + } + + return nil + }); err != nil { + t.Error(err.Error()) + } +} + +func TestSceneDestroySceneCover(t *testing.T) { + if err := withTxn(func(r models.Repository) error { + qb := r.Scene() + + // create scene to test against + const name = "TestSceneDestroySceneCover" + scene := models.Scene{ + Path: name, + Checksum: sql.NullString{String: utils.MD5FromString(name), Valid: true}, + } + created, err := qb.Create(scene) + if err != nil { + return fmt.Errorf("Error creating scene: %s", err.Error()) + } + + image := []byte("image") + err = qb.UpdateCover(created.ID, image) + if err != nil { + return fmt.Errorf("Error updating scene image: %s", err.Error()) + } + + err = qb.DestroyCover(created.ID) + if err != nil { + return fmt.Errorf("Error destroying scene cover: %s", err.Error()) + } + + // image should be nil + storedImage, err := qb.GetCover(created.ID) + if err != nil { + return fmt.Errorf("Error getting image: %s", err.Error()) + } + assert.Nil(t, storedImage) + + return nil + }); err != nil { + t.Error(err.Error()) + } +} + +func TestSceneStashIDs(t *testing.T) { + if err := withTxn(func(r models.Repository) error { + qb := r.Scene() + + // create scene to test against + const name = "TestSceneStashIDs" + scene := models.Scene{ + Path: name, + Checksum: sql.NullString{String: utils.MD5FromString(name), Valid: true}, + } + created, err := qb.Create(scene) + if err != nil { + return fmt.Errorf("Error creating scene: %s", err.Error()) + } + + testStashIDReaderWriter(t, qb, created.ID) + return nil + }); err != nil { + t.Error(err.Error()) + } +} + +// TODO Update +// TODO IncrementOCounter +// TODO DecrementOCounter +// TODO ResetOCounter +// TODO Destroy +// TODO FindByChecksum +// TODO Count +// TODO SizeCount 
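// Editor's note: every scene test above follows the same shape — withTxn (the helper
// added to setup_test.go further down in this diff) opens a transaction through
// sqlite.NewTransactionManager and hands the callback a models.Repository. A minimal
// sketch of that pattern, using only calls that already appear in these tests; the
// function name below is illustrative, and rollback-on-error is assumed from the
// transaction manager contract rather than shown in this diff:

func exampleRepositoryTestPattern(t *testing.T) {
	if err := withTxn(func(r models.Repository) error {
		sqb := r.Scene()

		// reader calls run against the transaction held by the repository
		count, err := sqb.CountByTagID(tagIDs[tagIdxWithScene])
		if err != nil {
			return err // a non-nil return is propagated by withTxn and skips the commit
		}

		assert.Equal(t, 1, count)
		return nil // returning nil commits the transaction
	}); err != nil {
		t.Error(err.Error())
	}
}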
+// TODO All diff --git a/pkg/sqlite/scraped_item.go b/pkg/sqlite/scraped_item.go new file mode 100644 index 000000000..39e040f89 --- /dev/null +++ b/pkg/sqlite/scraped_item.go @@ -0,0 +1,90 @@ +package sqlite + +import ( + "database/sql" + + "github.com/stashapp/stash/pkg/models" +) + +const scrapedItemTable = "scraped_items" + +type scrapedItemQueryBuilder struct { + repository +} + +func NewScrapedItemReaderWriter(tx dbi) *scrapedItemQueryBuilder { + return &scrapedItemQueryBuilder{ + repository{ + tx: tx, + tableName: scrapedItemTable, + idColumn: idColumn, + }, + } +} + +func (qb *scrapedItemQueryBuilder) Create(newObject models.ScrapedItem) (*models.ScrapedItem, error) { + var ret models.ScrapedItem + if err := qb.insertObject(newObject, &ret); err != nil { + return nil, err + } + + return &ret, nil +} + +func (qb *scrapedItemQueryBuilder) Update(updatedObject models.ScrapedItem) (*models.ScrapedItem, error) { + const partial = false + if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil { + return nil, err + } + + return qb.find(updatedObject.ID) +} + +func (qb *scrapedItemQueryBuilder) Find(id int) (*models.ScrapedItem, error) { + return qb.find(id) +} + +func (qb *scrapedItemQueryBuilder) find(id int) (*models.ScrapedItem, error) { + var ret models.ScrapedItem + if err := qb.get(id, &ret); err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + return &ret, nil +} + +func (qb *scrapedItemQueryBuilder) All() ([]*models.ScrapedItem, error) { + return qb.queryScrapedItems(selectAll("scraped_items")+qb.getScrapedItemsSort(nil), nil) +} + +func (qb *scrapedItemQueryBuilder) getScrapedItemsSort(findFilter *models.FindFilterType) string { + var sort string + var direction string + if findFilter == nil { + sort = "id" // TODO studio_id and title + direction = "ASC" + } else { + sort = findFilter.GetSort("id") + direction = findFilter.GetDirection() + } + return getSort(sort, direction, "scraped_items") +} + +func (qb *scrapedItemQueryBuilder) queryScrapedItem(query string, args []interface{}) (*models.ScrapedItem, error) { + results, err := qb.queryScrapedItems(query, args) + if err != nil || len(results) < 1 { + return nil, err + } + return results[0], nil +} + +func (qb *scrapedItemQueryBuilder) queryScrapedItems(query string, args []interface{}) ([]*models.ScrapedItem, error) { + var ret models.ScrapedItems + if err := qb.query(query, args, &ret); err != nil { + return nil, err + } + + return []*models.ScrapedItem(ret), nil +} diff --git a/pkg/models/setup_test.go b/pkg/sqlite/setup_test.go similarity index 62% rename from pkg/models/setup_test.go rename to pkg/sqlite/setup_test.go index 489cf54fa..aa663f159 100644 --- a/pkg/models/setup_test.go +++ b/pkg/sqlite/setup_test.go @@ -1,6 +1,6 @@ // +build integration -package models_test +package sqlite_test import ( "context" @@ -13,11 +13,11 @@ import ( "testing" "time" - "github.com/jmoiron/sqlx" - "github.com/stashapp/stash/pkg/database" + "github.com/stashapp/stash/pkg/gallery" "github.com/stashapp/stash/pkg/models" - "github.com/stashapp/stash/pkg/models/modelstest" + "github.com/stashapp/stash/pkg/scene" + "github.com/stashapp/stash/pkg/sqlite" "github.com/stashapp/stash/pkg/utils" ) @@ -117,6 +117,11 @@ func TestMain(m *testing.M) { os.Exit(ret) } +func withTxn(f func(r models.Repository) error) error { + t := sqlite.NewTransactionManager() + return t.WithTxn(context.TODO(), f) +} + func testTeardown(databaseFile string) { err := database.DB.Close() @@ -154,121 +159,92 @@ func 
runTests(m *testing.M) int { } func populateDB() error { - ctx := context.TODO() - tx := database.DB.MustBeginTx(ctx, nil) + if err := withTxn(func(r models.Repository) error { + if err := createScenes(r.Scene(), totalScenes); err != nil { + return fmt.Errorf("error creating scenes: %s", err.Error()) + } - if err := createScenes(tx, totalScenes); err != nil { - tx.Rollback() + if err := createImages(r.Image(), totalImages); err != nil { + return fmt.Errorf("error creating images: %s", err.Error()) + } + + if err := createGalleries(r.Gallery(), totalGalleries); err != nil { + return fmt.Errorf("error creating galleries: %s", err.Error()) + } + + if err := createMovies(r.Movie(), moviesNameCase, moviesNameNoCase); err != nil { + return fmt.Errorf("error creating movies: %s", err.Error()) + } + + if err := createPerformers(r.Performer(), performersNameCase, performersNameNoCase); err != nil { + return fmt.Errorf("error creating performers: %s", err.Error()) + } + + if err := createTags(r.Tag(), tagsNameCase, tagsNameNoCase); err != nil { + return fmt.Errorf("error creating tags: %s", err.Error()) + } + + if err := addTagImage(r.Tag(), tagIdxWithCoverImage); err != nil { + return fmt.Errorf("error adding tag image: %s", err.Error()) + } + + if err := createStudios(r.Studio(), studiosNameCase, studiosNameNoCase); err != nil { + return fmt.Errorf("error creating studios: %s", err.Error()) + } + + if err := linkSceneGallery(r.Gallery(), sceneIdxWithGallery, galleryIdxWithScene); err != nil { + return fmt.Errorf("error linking scene to gallery: %s", err.Error()) + } + + if err := linkSceneMovie(r.Scene(), sceneIdxWithMovie, movieIdxWithScene); err != nil { + return fmt.Errorf("error scene to movie: %s", err.Error()) + } + + if err := linkScenePerformers(r.Scene()); err != nil { + return fmt.Errorf("error linking scene performers: %s", err.Error()) + } + + if err := linkSceneTags(r.Scene()); err != nil { + return fmt.Errorf("error linking scene tags: %s", err.Error()) + } + + if err := linkSceneStudio(r.Scene(), sceneIdxWithStudio, studioIdxWithScene); err != nil { + return fmt.Errorf("error linking scene studio: %s", err.Error()) + } + + if err := linkImageGallery(r.Gallery(), imageIdxWithGallery, galleryIdxWithImage); err != nil { + return fmt.Errorf("error linking gallery images: %s", err.Error()) + } + + if err := linkImagePerformers(r.Image()); err != nil { + return fmt.Errorf("error linking image performers: %s", err.Error()) + } + + if err := linkImageTags(r.Image()); err != nil { + return fmt.Errorf("error linking image tags: %s", err.Error()) + } + + if err := linkImageStudio(r.Image(), imageIdxWithStudio, studioIdxWithImage); err != nil { + return fmt.Errorf("error linking image studio: %s", err.Error()) + } + + if err := linkMovieStudio(r.Movie(), movieIdxWithStudio, studioIdxWithMovie); err != nil { + return fmt.Errorf("error linking movie studio: %s", err.Error()) + } + + if err := linkStudioParent(r.Studio(), studioIdxWithChildStudio, studioIdxWithParentStudio); err != nil { + return fmt.Errorf("error linking studio parent: %s", err.Error()) + } + + if err := createMarker(r.SceneMarker(), sceneIdxWithMarker, tagIdxWithPrimaryMarker, []int{tagIdxWithMarker}); err != nil { + return fmt.Errorf("error creating scene marker: %s", err.Error()) + } + + return nil + }); err != nil { return err } - if err := createImages(tx, totalImages); err != nil { - tx.Rollback() - return err - } - - if err := createGalleries(tx, totalGalleries); err != nil { - tx.Rollback() - return err - } - - if err := 
createMovies(tx, moviesNameCase, moviesNameNoCase); err != nil { - tx.Rollback() - return err - } - - if err := createPerformers(tx, performersNameCase, performersNameNoCase); err != nil { - tx.Rollback() - return err - } - - if err := createTags(tx, tagsNameCase, tagsNameNoCase); err != nil { - tx.Rollback() - return err - } - - if err := addTagImage(tx, tagIdxWithCoverImage); err != nil { - tx.Rollback() - return err - } - - if err := createStudios(tx, studiosNameCase, studiosNameNoCase); err != nil { - tx.Rollback() - return err - } - - // TODO - the link methods use Find which don't accept a transaction, so - // to commit the transaction and start a new one - if err := tx.Commit(); err != nil { - return fmt.Errorf("Error committing: %s", err.Error()) - } - - tx = database.DB.MustBeginTx(ctx, nil) - - if err := linkSceneGallery(tx, sceneIdxWithGallery, galleryIdxWithScene); err != nil { - tx.Rollback() - return err - } - - if err := linkSceneMovie(tx, sceneIdxWithMovie, movieIdxWithScene); err != nil { - tx.Rollback() - return err - } - - if err := linkScenePerformers(tx); err != nil { - tx.Rollback() - return err - } - - if err := linkSceneTags(tx); err != nil { - tx.Rollback() - return err - } - - if err := linkSceneStudio(tx, sceneIdxWithStudio, studioIdxWithScene); err != nil { - tx.Rollback() - return err - } - - if err := linkImageGallery(tx, imageIdxWithGallery, galleryIdxWithImage); err != nil { - tx.Rollback() - return err - } - - if err := linkImagePerformers(tx); err != nil { - tx.Rollback() - return err - } - - if err := linkImageTags(tx); err != nil { - tx.Rollback() - return err - } - - if err := linkImageStudio(tx, imageIdxWithStudio, studioIdxWithImage); err != nil { - tx.Rollback() - return err - } - - if err := linkMovieStudio(tx, movieIdxWithStudio, studioIdxWithMovie); err != nil { - tx.Rollback() - return err - } - - if err := linkStudioParent(tx, studioIdxWithChildStudio, studioIdxWithParentStudio); err != nil { - tx.Rollback() - return err - } - - if err := createMarker(tx, sceneIdxWithMarker, tagIdxWithPrimaryMarker, []int{tagIdxWithMarker}); err != nil { - tx.Rollback() - return err - } - - if err := tx.Commit(); err != nil { - return fmt.Errorf("Error committing: %s", err.Error()) - } - return nil } @@ -313,9 +289,7 @@ func getSceneDate(index int) models.SQLiteDate { } } -func createScenes(tx *sqlx.Tx, n int) error { - sqb := models.NewSceneQueryBuilder() - +func createScenes(sqb models.SceneReaderWriter, n int) error { for i := 0; i < n; i++ { scene := models.Scene{ Path: getSceneStringValue(i, pathField), @@ -329,7 +303,7 @@ func createScenes(tx *sqlx.Tx, n int) error { Date: getSceneDate(i), } - created, err := sqb.Create(scene, tx) + created, err := sqb.Create(scene) if err != nil { return fmt.Errorf("Error creating scene %v+: %s", scene, err.Error()) @@ -345,9 +319,7 @@ func getImageStringValue(index int, field string) string { return fmt.Sprintf("image_%04d_%s", index, field) } -func createImages(tx *sqlx.Tx, n int) error { - qb := models.NewImageQueryBuilder() - +func createImages(qb models.ImageReaderWriter, n int) error { for i := 0; i < n; i++ { image := models.Image{ Path: getImageStringValue(i, pathField), @@ -358,7 +330,7 @@ func createImages(tx *sqlx.Tx, n int) error { Height: getHeight(i), } - created, err := qb.Create(image, tx) + created, err := qb.Create(image) if err != nil { return fmt.Errorf("Error creating image %v+: %s", image, err.Error()) @@ -374,16 +346,14 @@ func getGalleryStringValue(index int, field string) string { return 
"gallery_" + strconv.FormatInt(int64(index), 10) + "_" + field } -func createGalleries(tx *sqlx.Tx, n int) error { - gqb := models.NewGalleryQueryBuilder() - +func createGalleries(gqb models.GalleryReaderWriter, n int) error { for i := 0; i < n; i++ { gallery := models.Gallery{ - Path: modelstest.NullString(getGalleryStringValue(i, pathField)), + Path: models.NullString(getGalleryStringValue(i, pathField)), Checksum: getGalleryStringValue(i, checksumField), } - created, err := gqb.Create(gallery, tx) + created, err := gqb.Create(gallery) if err != nil { return fmt.Errorf("Error creating gallery %v+: %s", gallery, err.Error()) @@ -400,8 +370,7 @@ func getMovieStringValue(index int, field string) string { } //createMoviees creates n movies with plain Name and o movies with camel cased NaMe included -func createMovies(tx *sqlx.Tx, n int, o int) error { - mqb := models.NewMovieQueryBuilder() +func createMovies(mqb models.MovieReaderWriter, n int, o int) error { const namePlain = "Name" const nameNoCase = "NaMe" @@ -421,7 +390,7 @@ func createMovies(tx *sqlx.Tx, n int, o int) error { Checksum: utils.MD5FromString(name), } - created, err := mqb.Create(movie, tx) + created, err := mqb.Create(movie) if err != nil { return fmt.Errorf("Error creating movie [%d] %v+: %s", i, movie, err.Error()) @@ -451,8 +420,7 @@ func getPerformerBirthdate(index int) string { } //createPerformers creates n performers with plain Name and o performers with camel cased NaMe included -func createPerformers(tx *sqlx.Tx, n int, o int) error { - pqb := models.NewPerformerQueryBuilder() +func createPerformers(pqb models.PerformerReaderWriter, n int, o int) error { const namePlain = "Name" const nameNoCase = "NaMe" @@ -477,7 +445,7 @@ func createPerformers(tx *sqlx.Tx, n int, o int) error { }, } - created, err := pqb.Create(performer, tx) + created, err := pqb.Create(performer) if err != nil { return fmt.Errorf("Error creating performer %v+: %s", performer, err.Error()) @@ -511,8 +479,7 @@ func getTagMarkerCount(id int) int { } //createTags creates n tags with plain Name and o tags with camel cased NaMe included -func createTags(tx *sqlx.Tx, n int, o int) error { - tqb := models.NewTagQueryBuilder() +func createTags(tqb models.TagReaderWriter, n int, o int) error { const namePlain = "Name" const nameNoCase = "NaMe" @@ -531,7 +498,7 @@ func createTags(tx *sqlx.Tx, n int, o int) error { Name: getTagStringValue(index, name), } - created, err := tqb.Create(tag, tx) + created, err := tqb.Create(tag) if err != nil { return fmt.Errorf("Error creating tag %v+: %s", tag, err.Error()) @@ -548,8 +515,7 @@ func getStudioStringValue(index int, field string) string { return "studio_" + strconv.FormatInt(int64(index), 10) + "_" + field } -func createStudio(tx *sqlx.Tx, name string, parentID *int64) (*models.Studio, error) { - sqb := models.NewStudioQueryBuilder() +func createStudio(sqb models.StudioReaderWriter, name string, parentID *int64) (*models.Studio, error) { studio := models.Studio{ Name: sql.NullString{String: name, Valid: true}, Checksum: utils.MD5FromString(name), @@ -559,7 +525,7 @@ func createStudio(tx *sqlx.Tx, name string, parentID *int64) (*models.Studio, er studio.ParentID = sql.NullInt64{Int64: *parentID, Valid: true} } - created, err := sqb.Create(studio, tx) + created, err := sqb.Create(studio) if err != nil { return nil, fmt.Errorf("Error creating studio %v+: %s", studio, err.Error()) @@ -569,7 +535,7 @@ func createStudio(tx *sqlx.Tx, name string, parentID *int64) (*models.Studio, er } //createStudios creates n 
studios with plain Name and o studios with camel cased NaMe included -func createStudios(tx *sqlx.Tx, n int, o int) error { +func createStudios(sqb models.StudioReaderWriter, n int, o int) error { const namePlain = "Name" const nameNoCase = "NaMe" @@ -584,7 +550,7 @@ func createStudios(tx *sqlx.Tx, n int, o int) error { // studios [ i ] and [ n + o - i - 1 ] should have similar names with only the Name!=NaMe part different name = getStudioStringValue(index, name) - created, err := createStudio(tx, name, nil) + created, err := createStudio(sqb, name, nil) if err != nil { return err @@ -597,15 +563,13 @@ func createStudios(tx *sqlx.Tx, n int, o int) error { return nil } -func createMarker(tx *sqlx.Tx, sceneIdx, primaryTagIdx int, tagIdxs []int) error { - mqb := models.NewSceneMarkerQueryBuilder() - +func createMarker(mqb models.SceneMarkerReaderWriter, sceneIdx, primaryTagIdx int, tagIdxs []int) error { marker := models.SceneMarker{ SceneID: sql.NullInt64{Int64: int64(sceneIDs[sceneIdx]), Valid: true}, PrimaryTagID: tagIDs[primaryTagIdx], } - created, err := mqb.Create(marker, tx) + created, err := mqb.Create(marker) if err != nil { return fmt.Errorf("Error creating marker %v+: %s", marker, err.Error()) @@ -613,57 +577,54 @@ func createMarker(tx *sqlx.Tx, sceneIdx, primaryTagIdx int, tagIdxs []int) error markerIDs = append(markerIDs, created.ID) - jqb := models.NewJoinsQueryBuilder() - - joins := []models.SceneMarkersTags{} + newTagIDs := []int{} for _, tagIdx := range tagIdxs { - join := models.SceneMarkersTags{ - SceneMarkerID: created.ID, - TagID: tagIDs[tagIdx], - } - joins = append(joins, join) + newTagIDs = append(newTagIDs, tagIDs[tagIdx]) } - if err := jqb.CreateSceneMarkersTags(joins, tx); err != nil { + if err := mqb.UpdateTags(created.ID, newTagIDs); err != nil { return fmt.Errorf("Error creating marker/tag join: %s", err.Error()) } return nil } -func linkSceneMovie(tx *sqlx.Tx, sceneIndex, movieIndex int) error { - jqb := models.NewJoinsQueryBuilder() +func linkSceneMovie(qb models.SceneReaderWriter, sceneIndex, movieIndex int) error { + sceneID := sceneIDs[sceneIndex] + movies, err := qb.GetMovies(sceneID) + if err != nil { + return err + } - _, err := jqb.AddMoviesScene(sceneIDs[sceneIndex], movieIDs[movieIndex], nil, tx) - return err + movies = append(movies, models.MoviesScenes{ + MovieID: movieIDs[movieIndex], + SceneID: sceneID, + }) + return qb.UpdateMovies(sceneID, movies) } -func linkScenePerformers(tx *sqlx.Tx) error { - if err := linkScenePerformer(tx, sceneIdxWithPerformer, performerIdxWithScene); err != nil { +func linkScenePerformers(qb models.SceneReaderWriter) error { + if err := linkScenePerformer(qb, sceneIdxWithPerformer, performerIdxWithScene); err != nil { return err } - if err := linkScenePerformer(tx, sceneIdxWithTwoPerformers, performerIdx1WithScene); err != nil { + if err := linkScenePerformer(qb, sceneIdxWithTwoPerformers, performerIdx1WithScene); err != nil { return err } - if err := linkScenePerformer(tx, sceneIdxWithTwoPerformers, performerIdx2WithScene); err != nil { + if err := linkScenePerformer(qb, sceneIdxWithTwoPerformers, performerIdx2WithScene); err != nil { return err } return nil } -func linkScenePerformer(tx *sqlx.Tx, sceneIndex, performerIndex int) error { - jqb := models.NewJoinsQueryBuilder() - - _, err := jqb.AddPerformerScene(sceneIDs[sceneIndex], performerIDs[performerIndex], tx) +func linkScenePerformer(qb models.SceneReaderWriter, sceneIndex, performerIndex int) error { + _, err := scene.AddPerformer(qb, sceneIDs[sceneIndex], 
performerIDs[performerIndex]) return err } -func linkSceneGallery(tx *sqlx.Tx, sceneIndex, galleryIndex int) error { - gqb := models.NewGalleryQueryBuilder() - - gallery, err := gqb.Find(galleryIDs[galleryIndex], nil) +func linkSceneGallery(gqb models.GalleryReaderWriter, sceneIndex, galleryIndex int) error { + gallery, err := gqb.Find(galleryIDs[galleryIndex]) if err != nil { return fmt.Errorf("error finding gallery: %s", err.Error()) @@ -674,132 +635,126 @@ func linkSceneGallery(tx *sqlx.Tx, sceneIndex, galleryIndex int) error { } gallery.SceneID = sql.NullInt64{Int64: int64(sceneIDs[sceneIndex]), Valid: true} - _, err = gqb.Update(*gallery, tx) + _, err = gqb.Update(*gallery) return err } -func linkSceneTags(tx *sqlx.Tx) error { - if err := linkSceneTag(tx, sceneIdxWithTag, tagIdxWithScene); err != nil { +func linkSceneTags(qb models.SceneReaderWriter) error { + if err := linkSceneTag(qb, sceneIdxWithTag, tagIdxWithScene); err != nil { return err } - if err := linkSceneTag(tx, sceneIdxWithTwoTags, tagIdx1WithScene); err != nil { + if err := linkSceneTag(qb, sceneIdxWithTwoTags, tagIdx1WithScene); err != nil { return err } - if err := linkSceneTag(tx, sceneIdxWithTwoTags, tagIdx2WithScene); err != nil { + if err := linkSceneTag(qb, sceneIdxWithTwoTags, tagIdx2WithScene); err != nil { return err } return nil } -func linkSceneTag(tx *sqlx.Tx, sceneIndex, tagIndex int) error { - jqb := models.NewJoinsQueryBuilder() - - _, err := jqb.AddSceneTag(sceneIDs[sceneIndex], tagIDs[tagIndex], tx) +func linkSceneTag(qb models.SceneReaderWriter, sceneIndex, tagIndex int) error { + _, err := scene.AddTag(qb, sceneIDs[sceneIndex], tagIDs[tagIndex]) return err } -func linkSceneStudio(tx *sqlx.Tx, sceneIndex, studioIndex int) error { - sqb := models.NewSceneQueryBuilder() - +func linkSceneStudio(sqb models.SceneWriter, sceneIndex, studioIndex int) error { scene := models.ScenePartial{ ID: sceneIDs[sceneIndex], StudioID: &sql.NullInt64{Int64: int64(studioIDs[studioIndex]), Valid: true}, } - _, err := sqb.Update(scene, tx) + _, err := sqb.Update(scene) return err } -func linkImageGallery(tx *sqlx.Tx, imageIndex, galleryIndex int) error { - jqb := models.NewJoinsQueryBuilder() - - _, err := jqb.AddImageGallery(imageIDs[imageIndex], galleryIDs[galleryIndex], tx) - - return err +func linkImageGallery(gqb models.GalleryReaderWriter, imageIndex, galleryIndex int) error { + return gallery.AddImage(gqb, galleryIDs[galleryIndex], imageIDs[imageIndex]) } -func linkImageTags(tx *sqlx.Tx) error { - if err := linkImageTag(tx, imageIdxWithTag, tagIdxWithImage); err != nil { +func linkImageTags(iqb models.ImageReaderWriter) error { + if err := linkImageTag(iqb, imageIdxWithTag, tagIdxWithImage); err != nil { return err } - if err := linkImageTag(tx, imageIdxWithTwoTags, tagIdx1WithImage); err != nil { + if err := linkImageTag(iqb, imageIdxWithTwoTags, tagIdx1WithImage); err != nil { return err } - if err := linkImageTag(tx, imageIdxWithTwoTags, tagIdx2WithImage); err != nil { + if err := linkImageTag(iqb, imageIdxWithTwoTags, tagIdx2WithImage); err != nil { return err } return nil } -func linkImageTag(tx *sqlx.Tx, imageIndex, tagIndex int) error { - jqb := models.NewJoinsQueryBuilder() +func linkImageTag(iqb models.ImageReaderWriter, imageIndex, tagIndex int) error { + imageID := imageIDs[imageIndex] + tags, err := iqb.GetTagIDs(imageID) + if err != nil { + return err + } - _, err := jqb.AddImageTag(imageIDs[imageIndex], tagIDs[tagIndex], tx) - return err + tags = append(tags, tagIDs[tagIndex]) + + return 
iqb.UpdateTags(imageID, tags) } -func linkImageStudio(tx *sqlx.Tx, imageIndex, studioIndex int) error { - sqb := models.NewImageQueryBuilder() - +func linkImageStudio(qb models.ImageWriter, imageIndex, studioIndex int) error { image := models.ImagePartial{ ID: imageIDs[imageIndex], StudioID: &sql.NullInt64{Int64: int64(studioIDs[studioIndex]), Valid: true}, } - _, err := sqb.Update(image, tx) + _, err := qb.Update(image) return err } -func linkImagePerformers(tx *sqlx.Tx) error { - if err := linkImagePerformer(tx, imageIdxWithPerformer, performerIdxWithImage); err != nil { +func linkImagePerformers(qb models.ImageReaderWriter) error { + if err := linkImagePerformer(qb, imageIdxWithPerformer, performerIdxWithImage); err != nil { return err } - if err := linkImagePerformer(tx, imageIdxWithTwoPerformers, performerIdx1WithImage); err != nil { + if err := linkImagePerformer(qb, imageIdxWithTwoPerformers, performerIdx1WithImage); err != nil { return err } - if err := linkImagePerformer(tx, imageIdxWithTwoPerformers, performerIdx2WithImage); err != nil { + if err := linkImagePerformer(qb, imageIdxWithTwoPerformers, performerIdx2WithImage); err != nil { return err } return nil } -func linkImagePerformer(tx *sqlx.Tx, imageIndex, performerIndex int) error { - jqb := models.NewJoinsQueryBuilder() +func linkImagePerformer(iqb models.ImageReaderWriter, imageIndex, performerIndex int) error { + imageID := imageIDs[imageIndex] + performers, err := iqb.GetPerformerIDs(imageID) + if err != nil { + return err + } - _, err := jqb.AddPerformerImage(imageIDs[imageIndex], performerIDs[performerIndex], tx) - return err + performers = append(performers, performerIDs[performerIndex]) + + return iqb.UpdatePerformers(imageID, performers) } -func linkMovieStudio(tx *sqlx.Tx, movieIndex, studioIndex int) error { - mqb := models.NewMovieQueryBuilder() - +func linkMovieStudio(mqb models.MovieWriter, movieIndex, studioIndex int) error { movie := models.MoviePartial{ ID: movieIDs[movieIndex], StudioID: &sql.NullInt64{Int64: int64(studioIDs[studioIndex]), Valid: true}, } - _, err := mqb.Update(movie, tx) + _, err := mqb.Update(movie) return err } -func linkStudioParent(tx *sqlx.Tx, parentIndex, childIndex int) error { - sqb := models.NewStudioQueryBuilder() - +func linkStudioParent(qb models.StudioWriter, parentIndex, childIndex int) error { studio := models.StudioPartial{ ID: studioIDs[childIndex], ParentID: &sql.NullInt64{Int64: int64(studioIDs[parentIndex]), Valid: true}, } - _, err := sqb.Update(studio, tx) + _, err := qb.Update(studio) return err } -func addTagImage(tx *sqlx.Tx, tagIndex int) error { - qb := models.NewTagQueryBuilder() - - return qb.UpdateTagImage(tagIDs[tagIndex], models.DefaultTagImage, tx) +func addTagImage(qb models.TagWriter, tagIndex int) error { + return qb.UpdateImage(tagIDs[tagIndex], models.DefaultTagImage) } diff --git a/pkg/sqlite/sql.go b/pkg/sqlite/sql.go new file mode 100644 index 000000000..27b4db64d --- /dev/null +++ b/pkg/sqlite/sql.go @@ -0,0 +1,254 @@ +package sqlite + +import ( + "database/sql" + "math/rand" + "strconv" + "strings" + + "github.com/jmoiron/sqlx" + "github.com/stashapp/stash/pkg/logger" + "github.com/stashapp/stash/pkg/models" +) + +var randomSortFloat = rand.Float64() + +func selectAll(tableName string) string { + idColumn := getColumn(tableName, "*") + return "SELECT " + idColumn + " FROM " + tableName + " " +} + +func selectDistinctIDs(tableName string) string { + idColumn := getColumn(tableName, "id") + return "SELECT DISTINCT " + idColumn + " FROM " + tableName 
+ " " +} + +func getColumn(tableName string, columnName string) string { + return tableName + "." + columnName +} + +func getPagination(findFilter *models.FindFilterType) string { + if findFilter == nil { + panic("nil find filter for pagination") + } + + var page int + if findFilter.Page == nil || *findFilter.Page < 1 { + page = 1 + } else { + page = *findFilter.Page + } + + var perPage int + if findFilter.PerPage == nil { + perPage = 25 + } else { + perPage = *findFilter.PerPage + } + + if perPage > 1000 { + perPage = 1000 + } else if perPage < 1 { + perPage = 1 + } + + page = (page - 1) * perPage + return " LIMIT " + strconv.Itoa(perPage) + " OFFSET " + strconv.Itoa(page) + " " +} + +func getSort(sort string, direction string, tableName string) string { + if direction != "ASC" && direction != "DESC" { + direction = "ASC" + } + + const randomSeedPrefix = "random_" + + if strings.HasSuffix(sort, "_count") { + var relationTableName = strings.TrimSuffix(sort, "_count") // TODO: pluralize? + colName := getColumn(relationTableName, "id") + return " ORDER BY COUNT(distinct " + colName + ") " + direction + } else if strings.Compare(sort, "filesize") == 0 { + colName := getColumn(tableName, "size") + return " ORDER BY cast(" + colName + " as integer) " + direction + } else if strings.HasPrefix(sort, randomSeedPrefix) { + // seed as a parameter from the UI + // turn the provided seed into a float + seedStr := "0." + sort[len(randomSeedPrefix):] + seed, err := strconv.ParseFloat(seedStr, 32) + if err != nil { + // fallback to default seed + seed = randomSortFloat + } + return getRandomSort(tableName, direction, seed) + } else if strings.Compare(sort, "random") == 0 { + return getRandomSort(tableName, direction, randomSortFloat) + } else { + colName := getColumn(tableName, sort) + var additional string + if tableName == "scenes" { + additional = ", bitrate DESC, framerate DESC, scenes.rating DESC, scenes.duration DESC" + } else if tableName == "scene_markers" { + additional = ", scene_markers.scene_id ASC, scene_markers.seconds ASC" + } + if strings.Compare(sort, "name") == 0 { + return " ORDER BY " + colName + " COLLATE NOCASE " + direction + additional + } + if strings.Compare(sort, "title") == 0 { + return " ORDER BY " + colName + " COLLATE NATURAL_CS " + direction + additional + } + + return " ORDER BY " + colName + " " + direction + additional + } +} + +func getRandomSort(tableName string, direction string, seed float64) string { + // https://stackoverflow.com/a/24511461 + colName := getColumn(tableName, "id") + randomSortString := strconv.FormatFloat(seed, 'f', 16, 32) + return " ORDER BY " + "(substr(" + colName + " * " + randomSortString + ", length(" + colName + ") + 2))" + " " + direction +} + +func getSearchBinding(columns []string, q string, not bool) (string, []interface{}) { + var likeClauses []string + var args []interface{} + + notStr := "" + binaryType := " OR " + if not { + notStr = " NOT " + binaryType = " AND " + } + + queryWords := strings.Split(q, " ") + trimmedQuery := strings.Trim(q, "\"") + if trimmedQuery == q { + // Search for any word + for _, word := range queryWords { + for _, column := range columns { + likeClauses = append(likeClauses, column+notStr+" LIKE ?") + args = append(args, "%"+word+"%") + } + } + } else { + // Search the exact query + for _, column := range columns { + likeClauses = append(likeClauses, column+notStr+" LIKE ?") + args = append(args, "%"+trimmedQuery+"%") + } + } + likes := strings.Join(likeClauses, binaryType) + + return "(" + likes + ")", 
args +} + +func getInBinding(length int) string { + bindings := strings.Repeat("?, ", length) + bindings = strings.TrimRight(bindings, ", ") + return "(" + bindings + ")" +} + +func getCriterionModifierBinding(criterionModifier models.CriterionModifier, value interface{}) (string, int) { + var length int + switch x := value.(type) { + case []string: + length = len(x) + case []int: + length = len(x) + default: + length = 1 + } + if modifier := criterionModifier.String(); criterionModifier.IsValid() { + switch modifier { + case "EQUALS", "NOT_EQUALS", "GREATER_THAN", "LESS_THAN", "IS_NULL", "NOT_NULL": + return getSimpleCriterionClause(criterionModifier, "?") + case "INCLUDES": + return "IN " + getInBinding(length), length // TODO? + case "EXCLUDES": + return "NOT IN " + getInBinding(length), length // TODO? + default: + logger.Errorf("todo") + return "= ?", 1 // TODO + } + } + return "= ?", 1 // TODO +} + +func getSimpleCriterionClause(criterionModifier models.CriterionModifier, rhs string) (string, int) { + if modifier := criterionModifier.String(); criterionModifier.IsValid() { + switch modifier { + case "EQUALS": + return "= " + rhs, 1 + case "NOT_EQUALS": + return "!= " + rhs, 1 + case "GREATER_THAN": + return "> " + rhs, 1 + case "LESS_THAN": + return "< " + rhs, 1 + case "IS_NULL": + return "IS NULL", 0 + case "NOT_NULL": + return "IS NOT NULL", 0 + default: + logger.Errorf("todo") + return "= ?", 1 // TODO + } + } + + return "= ?", 1 // TODO +} + +func getIntCriterionWhereClause(column string, input models.IntCriterionInput) (string, int) { + binding, count := getCriterionModifierBinding(input.Modifier, input.Value) + return column + " " + binding, count +} + +// returns where clause and having clause +func getMultiCriterionClause(primaryTable, foreignTable, joinTable, primaryFK, foreignFK string, criterion *models.MultiCriterionInput) (string, string) { + whereClause := "" + havingClause := "" + if criterion.Modifier == models.CriterionModifierIncludes { + // includes any of the provided ids + whereClause = foreignTable + ".id IN " + getInBinding(len(criterion.Value)) + } else if criterion.Modifier == models.CriterionModifierIncludesAll { + // includes all of the provided ids + whereClause = foreignTable + ".id IN " + getInBinding(len(criterion.Value)) + havingClause = "count(distinct " + foreignTable + ".id) IS " + strconv.Itoa(len(criterion.Value)) + } else if criterion.Modifier == models.CriterionModifierExcludes { + // excludes all of the provided ids + if joinTable != "" { + whereClause = "not exists (select " + joinTable + "." + primaryFK + " from " + joinTable + " where " + joinTable + "." + primaryFK + " = " + primaryTable + ".id and " + joinTable + "." + foreignFK + " in " + getInBinding(len(criterion.Value)) + ")" + } else { + whereClause = "not exists (select s.id from " + primaryTable + " as s where s.id = " + primaryTable + ".id and s." + foreignFK + " in " + getInBinding(len(criterion.Value)) + ")" + } + } + + return whereClause, havingClause +} + +func ensureTx(tx *sqlx.Tx) { + if tx == nil { + panic("must use a transaction") + } +} + +func getImage(tx dbi, query string, args ...interface{}) ([]byte, error) { + rows, err := tx.Queryx(query, args...) 
+ + if err != nil && err != sql.ErrNoRows { + return nil, err + } + defer rows.Close() + + var ret []byte + if rows.Next() { + if err := rows.Scan(&ret); err != nil { + return nil, err + } + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return ret, nil +} diff --git a/pkg/sqlite/stash_id_test.go b/pkg/sqlite/stash_id_test.go new file mode 100644 index 000000000..943abc9c0 --- /dev/null +++ b/pkg/sqlite/stash_id_test.go @@ -0,0 +1,70 @@ +package sqlite_test + +import ( + "testing" + + "github.com/stashapp/stash/pkg/models" + "github.com/stretchr/testify/assert" +) + +type stashIDReaderWriter interface { + GetStashIDs(performerID int) ([]*models.StashID, error) + UpdateStashIDs(performerID int, stashIDs []models.StashID) error +} + +func testStashIDReaderWriter(t *testing.T, r stashIDReaderWriter, id int) { + // ensure no stash IDs to begin with + testNoStashIDs(t, r, id) + + // ensure GetStashIDs with non-existing also returns none + testNoStashIDs(t, r, -1) + + // add stash ids + const stashIDStr = "stashID" + const endpoint = "endpoint" + stashID := models.StashID{ + StashID: stashIDStr, + Endpoint: endpoint, + } + + // update stash ids and ensure was updated + if err := r.UpdateStashIDs(id, []models.StashID{stashID}); err != nil { + t.Error(err.Error()) + } + + testStashIDs(t, r, id, []*models.StashID{&stashID}) + + // update non-existing id - should return error + if err := r.UpdateStashIDs(-1, []models.StashID{stashID}); err == nil { + t.Error("expected error when updating non-existing id") + } + + // remove stash ids and ensure was updated + if err := r.UpdateStashIDs(id, []models.StashID{}); err != nil { + t.Error(err.Error()) + } + + testNoStashIDs(t, r, id) +} + +func testNoStashIDs(t *testing.T, r stashIDReaderWriter, id int) { + t.Helper() + stashIDs, err := r.GetStashIDs(id) + if err != nil { + t.Error(err.Error()) + return + } + + assert.Len(t, stashIDs, 0) +} + +func testStashIDs(t *testing.T, r stashIDReaderWriter, id int, expected []*models.StashID) { + t.Helper() + stashIDs, err := r.GetStashIDs(id) + if err != nil { + t.Error(err.Error()) + return + } + + assert.Equal(t, stashIDs, expected) +} diff --git a/pkg/sqlite/studio.go b/pkg/sqlite/studio.go new file mode 100644 index 000000000..84c04eed5 --- /dev/null +++ b/pkg/sqlite/studio.go @@ -0,0 +1,277 @@ +package sqlite + +import ( + "database/sql" + "fmt" + + "github.com/stashapp/stash/pkg/models" +) + +const studioTable = "studios" +const studioIDColumn = "studio_id" + +type studioQueryBuilder struct { + repository +} + +func NewStudioReaderWriter(tx dbi) *studioQueryBuilder { + return &studioQueryBuilder{ + repository{ + tx: tx, + tableName: studioTable, + idColumn: idColumn, + }, + } +} + +func (qb *studioQueryBuilder) Create(newObject models.Studio) (*models.Studio, error) { + var ret models.Studio + if err := qb.insertObject(newObject, &ret); err != nil { + return nil, err + } + + return &ret, nil +} + +func (qb *studioQueryBuilder) Update(updatedObject models.StudioPartial) (*models.Studio, error) { + const partial = true + if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil { + return nil, err + } + + return qb.Find(updatedObject.ID) +} + +func (qb *studioQueryBuilder) UpdateFull(updatedObject models.Studio) (*models.Studio, error) { + const partial = false + if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil { + return nil, err + } + + return qb.Find(updatedObject.ID) +} + +func (qb *studioQueryBuilder) Destroy(id int) error { + // TODO - set null 
on foreign key in scraped items + // remove studio from scraped items + _, err := qb.tx.Exec("UPDATE scraped_items SET studio_id = null WHERE studio_id = ?", id) + if err != nil { + return err + } + + return qb.destroyExisting([]int{id}) +} + +func (qb *studioQueryBuilder) Find(id int) (*models.Studio, error) { + var ret models.Studio + if err := qb.get(id, &ret); err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + return &ret, nil +} + +func (qb *studioQueryBuilder) FindMany(ids []int) ([]*models.Studio, error) { + var studios []*models.Studio + for _, id := range ids { + studio, err := qb.Find(id) + if err != nil { + return nil, err + } + + if studio == nil { + return nil, fmt.Errorf("studio with id %d not found", id) + } + + studios = append(studios, studio) + } + + return studios, nil +} + +func (qb *studioQueryBuilder) FindChildren(id int) ([]*models.Studio, error) { + query := "SELECT studios.* FROM studios WHERE studios.parent_id = ?" + args := []interface{}{id} + return qb.queryStudios(query, args) +} + +func (qb *studioQueryBuilder) FindBySceneID(sceneID int) (*models.Studio, error) { + query := "SELECT studios.* FROM studios JOIN scenes ON studios.id = scenes.studio_id WHERE scenes.id = ? LIMIT 1" + args := []interface{}{sceneID} + return qb.queryStudio(query, args) +} + +func (qb *studioQueryBuilder) FindByName(name string, nocase bool) (*models.Studio, error) { + query := "SELECT * FROM studios WHERE name = ?" + if nocase { + query += " COLLATE NOCASE" + } + query += " LIMIT 1" + args := []interface{}{name} + return qb.queryStudio(query, args) +} + +func (qb *studioQueryBuilder) Count() (int, error) { + return qb.runCountQuery(qb.buildCountQuery("SELECT studios.id FROM studios"), nil) +} + +func (qb *studioQueryBuilder) All() ([]*models.Studio, error) { + return qb.queryStudios(selectAll("studios")+qb.getStudioSort(nil), nil) +} + +func (qb *studioQueryBuilder) AllSlim() ([]*models.Studio, error) { + return qb.queryStudios("SELECT studios.id, studios.name, studios.parent_id FROM studios "+qb.getStudioSort(nil), nil) +} + +func (qb *studioQueryBuilder) Query(studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) ([]*models.Studio, int, error) { + if studioFilter == nil { + studioFilter = &models.StudioFilterType{} + } + if findFilter == nil { + findFilter = &models.FindFilterType{} + } + + var whereClauses []string + var havingClauses []string + var args []interface{} + body := selectDistinctIDs("studios") + body += ` + left join scenes on studios.id = scenes.studio_id + left join studio_stash_ids on studio_stash_ids.studio_id = studios.id + ` + + if q := findFilter.Q; q != nil && *q != "" { + searchColumns := []string{"studios.name"} + + clause, thisArgs := getSearchBinding(searchColumns, *q, false) + whereClauses = append(whereClauses, clause) + args = append(args, thisArgs...) 
+ } + + if parentsFilter := studioFilter.Parents; parentsFilter != nil && len(parentsFilter.Value) > 0 { + body += ` + left join studios as parent_studio on parent_studio.id = studios.parent_id + ` + + for _, studioID := range parentsFilter.Value { + args = append(args, studioID) + } + + whereClause, havingClause := getMultiCriterionClause("studios", "parent_studio", "", "", "parent_id", parentsFilter) + whereClauses = appendClause(whereClauses, whereClause) + havingClauses = appendClause(havingClauses, havingClause) + } + + if stashIDFilter := studioFilter.StashID; stashIDFilter != nil { + whereClauses = append(whereClauses, "studio_stash_ids.stash_id = ?") + args = append(args, stashIDFilter) + } + + if isMissingFilter := studioFilter.IsMissing; isMissingFilter != nil && *isMissingFilter != "" { + switch *isMissingFilter { + case "image": + body += `left join studios_image on studios_image.studio_id = studios.id + ` + whereClauses = appendClause(whereClauses, "studios_image.studio_id IS NULL") + case "stash_id": + whereClauses = appendClause(whereClauses, "studio_stash_ids.studio_id IS NULL") + default: + whereClauses = appendClause(whereClauses, "studios."+*isMissingFilter+" IS NULL") + } + } + + sortAndPagination := qb.getStudioSort(findFilter) + getPagination(findFilter) + idsResult, countResult, err := qb.executeFindQuery(body, args, sortAndPagination, whereClauses, havingClauses) + if err != nil { + return nil, 0, err + } + + var studios []*models.Studio + for _, id := range idsResult { + studio, err := qb.Find(id) + if err != nil { + return nil, 0, err + } + + studios = append(studios, studio) + } + + return studios, countResult, nil +} + +func (qb *studioQueryBuilder) getStudioSort(findFilter *models.FindFilterType) string { + var sort string + var direction string + if findFilter == nil { + sort = "name" + direction = "ASC" + } else { + sort = findFilter.GetSort("name") + direction = findFilter.GetDirection() + } + return getSort(sort, direction, "studios") +} + +func (qb *studioQueryBuilder) queryStudio(query string, args []interface{}) (*models.Studio, error) { + results, err := qb.queryStudios(query, args) + if err != nil || len(results) < 1 { + return nil, err + } + return results[0], nil +} + +func (qb *studioQueryBuilder) queryStudios(query string, args []interface{}) ([]*models.Studio, error) { + var ret models.Studios + if err := qb.query(query, args, &ret); err != nil { + return nil, err + } + + return []*models.Studio(ret), nil +} + +func (qb *studioQueryBuilder) imageRepository() *imageRepository { + return &imageRepository{ + repository: repository{ + tx: qb.tx, + tableName: "studios_image", + idColumn: studioIDColumn, + }, + imageColumn: "image", + } +} + +func (qb *studioQueryBuilder) GetImage(studioID int) ([]byte, error) { + return qb.imageRepository().get(studioID) +} + +func (qb *studioQueryBuilder) HasImage(studioID int) (bool, error) { + return qb.imageRepository().exists(studioID) +} + +func (qb *studioQueryBuilder) UpdateImage(studioID int, image []byte) error { + return qb.imageRepository().replace(studioID, image) +} + +func (qb *studioQueryBuilder) DestroyImage(studioID int) error { + return qb.imageRepository().destroy([]int{studioID}) +} + +func (qb *studioQueryBuilder) stashIDRepository() *stashIDRepository { + return &stashIDRepository{ + repository{ + tx: qb.tx, + tableName: "studio_stash_ids", + idColumn: studioIDColumn, + }, + } +} + +func (qb *studioQueryBuilder) GetStashIDs(studioID int) ([]*models.StashID, error) { + return 
qb.stashIDRepository().get(studioID) +} + +func (qb *studioQueryBuilder) UpdateStashIDs(studioID int, stashIDs []models.StashID) error { + return qb.stashIDRepository().replace(studioID, stashIDs) +} diff --git a/pkg/sqlite/studio_test.go b/pkg/sqlite/studio_test.go new file mode 100644 index 000000000..0dd3e4585 --- /dev/null +++ b/pkg/sqlite/studio_test.go @@ -0,0 +1,294 @@ +// +build integration + +package sqlite_test + +import ( + "database/sql" + "errors" + "fmt" + "strconv" + "strings" + "testing" + + "github.com/stashapp/stash/pkg/models" + "github.com/stretchr/testify/assert" +) + +func TestStudioFindByName(t *testing.T) { + withTxn(func(r models.Repository) error { + sqb := r.Studio() + + name := studioNames[studioIdxWithScene] // find a studio by name + + studio, err := sqb.FindByName(name, false) + + if err != nil { + t.Errorf("Error finding studios: %s", err.Error()) + } + + assert.Equal(t, studioNames[studioIdxWithScene], studio.Name.String) + + name = studioNames[studioIdxWithDupName] // find a studio by name nocase + + studio, err = sqb.FindByName(name, true) + + if err != nil { + t.Errorf("Error finding studios: %s", err.Error()) + } + // studioIdxWithDupName and studioIdxWithScene should have similar names ( only diff should be Name vs NaMe) + //studio.Name should match with studioIdxWithScene since its ID is before studioIdxWithDupName + assert.Equal(t, studioNames[studioIdxWithScene], studio.Name.String) + //studio.Name should match with studioIdxWithDupName if the check is not case sensitive + assert.Equal(t, strings.ToLower(studioNames[studioIdxWithDupName]), strings.ToLower(studio.Name.String)) + + return nil + }) +} + +func TestStudioQueryParent(t *testing.T) { + withTxn(func(r models.Repository) error { + sqb := r.Studio() + studioCriterion := models.MultiCriterionInput{ + Value: []string{ + strconv.Itoa(studioIDs[studioIdxWithChildStudio]), + }, + Modifier: models.CriterionModifierIncludes, + } + + studioFilter := models.StudioFilterType{ + Parents: &studioCriterion, + } + + studios, _, err := sqb.Query(&studioFilter, nil) + if err != nil { + t.Errorf("Error querying studio: %s", err.Error()) + } + + assert.Len(t, studios, 1) + + // ensure id is correct + assert.Equal(t, sceneIDs[studioIdxWithParentStudio], studios[0].ID) + + studioCriterion = models.MultiCriterionInput{ + Value: []string{ + strconv.Itoa(studioIDs[studioIdxWithChildStudio]), + }, + Modifier: models.CriterionModifierExcludes, + } + + q := getStudioStringValue(studioIdxWithParentStudio, titleField) + findFilter := models.FindFilterType{ + Q: &q, + } + + studios, _, err = sqb.Query(&studioFilter, &findFilter) + if err != nil { + t.Errorf("Error querying studio: %s", err.Error()) + } + assert.Len(t, studios, 0) + + return nil + }) +} + +func TestStudioDestroyParent(t *testing.T) { + const parentName = "parent" + const childName = "child" + + // create parent and child studios + if err := withTxn(func(r models.Repository) error { + createdParent, err := createStudio(r.Studio(), parentName, nil) + if err != nil { + return fmt.Errorf("Error creating parent studio: %s", err.Error()) + } + + parentID := int64(createdParent.ID) + createdChild, err := createStudio(r.Studio(), childName, &parentID) + if err != nil { + return fmt.Errorf("Error creating child studio: %s", err.Error()) + } + + sqb := r.Studio() + + // destroy the parent + err = sqb.Destroy(createdParent.ID) + if err != nil { + return fmt.Errorf("Error destroying parent studio: %s", err.Error()) + } + + // destroy the child + err = 
sqb.Destroy(createdChild.ID) + if err != nil { + return fmt.Errorf("Error destroying child studio: %s", err.Error()) + } + + return nil + }); err != nil { + t.Error(err.Error()) + } +} + +func TestStudioFindChildren(t *testing.T) { + withTxn(func(r models.Repository) error { + sqb := r.Studio() + + studios, err := sqb.FindChildren(studioIDs[studioIdxWithChildStudio]) + + if err != nil { + t.Errorf("error calling FindChildren: %s", err.Error()) + } + + assert.Len(t, studios, 1) + assert.Equal(t, studioIDs[studioIdxWithParentStudio], studios[0].ID) + + studios, err = sqb.FindChildren(0) + + if err != nil { + t.Errorf("error calling FindChildren: %s", err.Error()) + } + + assert.Len(t, studios, 0) + + return nil + }) +} + +func TestStudioUpdateClearParent(t *testing.T) { + const parentName = "clearParent_parent" + const childName = "clearParent_child" + + // create parent and child studios + if err := withTxn(func(r models.Repository) error { + createdParent, err := createStudio(r.Studio(), parentName, nil) + if err != nil { + return fmt.Errorf("Error creating parent studio: %s", err.Error()) + } + + parentID := int64(createdParent.ID) + createdChild, err := createStudio(r.Studio(), childName, &parentID) + if err != nil { + return fmt.Errorf("Error creating child studio: %s", err.Error()) + } + + sqb := r.Studio() + + // clear the parent id from the child + updatePartial := models.StudioPartial{ + ID: createdChild.ID, + ParentID: &sql.NullInt64{Valid: false}, + } + + updatedStudio, err := sqb.Update(updatePartial) + + if err != nil { + return fmt.Errorf("Error updating studio: %s", err.Error()) + } + + if updatedStudio.ParentID.Valid { + return errors.New("updated studio has parent ID set") + } + + return nil + }); err != nil { + t.Error(err.Error()) + } +} + +func TestStudioUpdateStudioImage(t *testing.T) { + if err := withTxn(func(r models.Repository) error { + qb := r.Studio() + + // create studio to test against + const name = "TestStudioUpdateStudioImage" + created, err := createStudio(r.Studio(), name, nil) + if err != nil { + return fmt.Errorf("Error creating studio: %s", err.Error()) + } + + image := []byte("image") + err = qb.UpdateImage(created.ID, image) + if err != nil { + return fmt.Errorf("Error updating studio image: %s", err.Error()) + } + + // ensure image set + storedImage, err := qb.GetImage(created.ID) + if err != nil { + return fmt.Errorf("Error getting image: %s", err.Error()) + } + assert.Equal(t, storedImage, image) + + // set nil image + err = qb.UpdateImage(created.ID, nil) + if err == nil { + return fmt.Errorf("Expected error setting nil image") + } + + return nil + }); err != nil { + t.Error(err.Error()) + } +} + +func TestStudioDestroyStudioImage(t *testing.T) { + if err := withTxn(func(r models.Repository) error { + qb := r.Studio() + + // create studio to test against + const name = "TestStudioDestroyStudioImage" + created, err := createStudio(r.Studio(), name, nil) + if err != nil { + return fmt.Errorf("Error creating studio: %s", err.Error()) + } + + image := []byte("image") + err = qb.UpdateImage(created.ID, image) + if err != nil { + return fmt.Errorf("Error updating studio image: %s", err.Error()) + } + + err = qb.DestroyImage(created.ID) + if err != nil { + return fmt.Errorf("Error destroying studio image: %s", err.Error()) + } + + // image should be nil + storedImage, err := qb.GetImage(created.ID) + if err != nil { + return fmt.Errorf("Error getting image: %s", err.Error()) + } + assert.Nil(t, storedImage) + + return nil + }); err != nil { + 
t.Error(err.Error()) + } +} + +func TestStudioStashIDs(t *testing.T) { + if err := withTxn(func(r models.Repository) error { + qb := r.Studio() + + // create studio to test against + const name = "TestStudioStashIDs" + created, err := createStudio(r.Studio(), name, nil) + if err != nil { + return fmt.Errorf("Error creating studio: %s", err.Error()) + } + + testStashIDReaderWriter(t, qb, created.ID) + return nil + }); err != nil { + t.Error(err.Error()) + } +} + +// TODO Create +// TODO Update +// TODO Destroy +// TODO Find +// TODO FindBySceneID +// TODO Count +// TODO All +// TODO AllSlim +// TODO Query diff --git a/pkg/models/querybuilder_tag.go b/pkg/sqlite/tag.go similarity index 50% rename from pkg/models/querybuilder_tag.go rename to pkg/sqlite/tag.go index de6d78b2f..688b6e4b4 100644 --- a/pkg/models/querybuilder_tag.go +++ b/pkg/sqlite/tag.go @@ -1,70 +1,58 @@ -package models +package sqlite import ( "database/sql" "errors" "fmt" - "github.com/jmoiron/sqlx" - "github.com/stashapp/stash/pkg/database" + "github.com/stashapp/stash/pkg/models" ) const tagTable = "tags" +const tagIDColumn = "tag_id" -type TagQueryBuilder struct{} - -func NewTagQueryBuilder() TagQueryBuilder { - return TagQueryBuilder{} +type tagQueryBuilder struct { + repository } -func (qb *TagQueryBuilder) Create(newTag Tag, tx *sqlx.Tx) (*Tag, error) { - ensureTx(tx) - result, err := tx.NamedExec( - `INSERT INTO tags (name, created_at, updated_at) - VALUES (:name, :created_at, :updated_at) - `, - newTag, - ) - if err != nil { - return nil, err +func NewTagReaderWriter(tx dbi) *tagQueryBuilder { + return &tagQueryBuilder{ + repository{ + tx: tx, + tableName: tagTable, + idColumn: idColumn, + }, } - studioID, err := result.LastInsertId() - if err != nil { - return nil, err - } - - if err := tx.Get(&newTag, `SELECT * FROM tags WHERE id = ? LIMIT 1`, studioID); err != nil { - return nil, err - } - - return &newTag, nil } -func (qb *TagQueryBuilder) Update(updatedTag Tag, tx *sqlx.Tx) (*Tag, error) { - ensureTx(tx) - query := `UPDATE tags SET ` + SQLGenKeys(updatedTag) + ` WHERE tags.id = :id` - _, err := tx.NamedExec( - query, - updatedTag, - ) - if err != nil { +func (qb *tagQueryBuilder) Create(newObject models.Tag) (*models.Tag, error) { + var ret models.Tag + if err := qb.insertObject(newObject, &ret); err != nil { return nil, err } - if err := tx.Get(&updatedTag, `SELECT * FROM tags WHERE id = ? 
LIMIT 1`, updatedTag.ID); err != nil { - return nil, err - } - return &updatedTag, nil + return &ret, nil } -func (qb *TagQueryBuilder) Destroy(id string, tx *sqlx.Tx) error { +func (qb *tagQueryBuilder) Update(updatedObject models.Tag) (*models.Tag, error) { + const partial = false + if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil { + return nil, err + } + + return qb.Find(updatedObject.ID) +} + +func (qb *tagQueryBuilder) Destroy(id int) error { + // TODO - add delete cascade to foreign key // delete tag from scenes and markers first - _, err := tx.Exec("DELETE FROM scenes_tags WHERE tag_id = ?", id) + _, err := qb.tx.Exec("DELETE FROM scenes_tags WHERE tag_id = ?", id) if err != nil { return err } - _, err = tx.Exec("DELETE FROM scene_markers_tags WHERE tag_id = ?", id) + // TODO - add delete cascade to foreign key + _, err = qb.tx.Exec("DELETE FROM scene_markers_tags WHERE tag_id = ?", id) if err != nil { return err } @@ -72,7 +60,7 @@ func (qb *TagQueryBuilder) Destroy(id string, tx *sqlx.Tx) error { // cannot unset primary_tag_id in scene_markers because it is not nullable countQuery := "SELECT COUNT(*) as count FROM scene_markers where primary_tag_id = ?" args := []interface{}{id} - primaryMarkers, err := runCountQuery(countQuery, args) + primaryMarkers, err := qb.runCountQuery(countQuery, args) if err != nil { return err } @@ -81,19 +69,24 @@ func (qb *TagQueryBuilder) Destroy(id string, tx *sqlx.Tx) error { return errors.New("Cannot delete tag used as a primary tag in scene markers") } - return executeDeleteQuery("tags", id, tx) + return qb.destroyExisting([]int{id}) } -func (qb *TagQueryBuilder) Find(id int, tx *sqlx.Tx) (*Tag, error) { - query := "SELECT * FROM tags WHERE id = ? LIMIT 1" - args := []interface{}{id} - return qb.queryTag(query, args, tx) +func (qb *tagQueryBuilder) Find(id int) (*models.Tag, error) { + var ret models.Tag + if err := qb.get(id, &ret); err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + return &ret, nil } -func (qb *TagQueryBuilder) FindMany(ids []int) ([]*Tag, error) { - var tags []*Tag +func (qb *tagQueryBuilder) FindMany(ids []int) ([]*models.Tag, error) { + var tags []*models.Tag for _, id := range ids { - tag, err := qb.Find(id, nil) + tag, err := qb.Find(id) if err != nil { return nil, err } @@ -108,7 +101,7 @@ func (qb *TagQueryBuilder) FindMany(ids []int) ([]*Tag, error) { return tags, nil } -func (qb *TagQueryBuilder) FindBySceneID(sceneID int, tx *sqlx.Tx) ([]*Tag, error) { +func (qb *tagQueryBuilder) FindBySceneID(sceneID int) ([]*models.Tag, error) { query := ` SELECT tags.* FROM tags LEFT JOIN scenes_tags as scenes_join on scenes_join.tag_id = tags.id @@ -117,10 +110,10 @@ func (qb *TagQueryBuilder) FindBySceneID(sceneID int, tx *sqlx.Tx) ([]*Tag, erro ` query += qb.getTagSort(nil) args := []interface{}{sceneID} - return qb.queryTags(query, args, tx) + return qb.queryTags(query, args) } -func (qb *TagQueryBuilder) FindByImageID(imageID int, tx *sqlx.Tx) ([]*Tag, error) { +func (qb *tagQueryBuilder) FindByImageID(imageID int) ([]*models.Tag, error) { query := ` SELECT tags.* FROM tags LEFT JOIN images_tags as images_join on images_join.tag_id = tags.id @@ -129,10 +122,10 @@ func (qb *TagQueryBuilder) FindByImageID(imageID int, tx *sqlx.Tx) ([]*Tag, erro ` query += qb.getTagSort(nil) args := []interface{}{imageID} - return qb.queryTags(query, args, tx) + return qb.queryTags(query, args) } -func (qb *TagQueryBuilder) FindByGalleryID(galleryID int, tx *sqlx.Tx) ([]*Tag, error) { 
+func (qb *tagQueryBuilder) FindByGalleryID(galleryID int) ([]*models.Tag, error) { query := ` SELECT tags.* FROM tags LEFT JOIN galleries_tags as galleries_join on galleries_join.tag_id = tags.id @@ -141,10 +134,10 @@ func (qb *TagQueryBuilder) FindByGalleryID(galleryID int, tx *sqlx.Tx) ([]*Tag, ` query += qb.getTagSort(nil) args := []interface{}{galleryID} - return qb.queryTags(query, args, tx) + return qb.queryTags(query, args) } -func (qb *TagQueryBuilder) FindBySceneMarkerID(sceneMarkerID int, tx *sqlx.Tx) ([]*Tag, error) { +func (qb *tagQueryBuilder) FindBySceneMarkerID(sceneMarkerID int) ([]*models.Tag, error) { query := ` SELECT tags.* FROM tags LEFT JOIN scene_markers_tags as scene_markers_join on scene_markers_join.tag_id = tags.id @@ -153,20 +146,20 @@ func (qb *TagQueryBuilder) FindBySceneMarkerID(sceneMarkerID int, tx *sqlx.Tx) ( ` query += qb.getTagSort(nil) args := []interface{}{sceneMarkerID} - return qb.queryTags(query, args, tx) + return qb.queryTags(query, args) } -func (qb *TagQueryBuilder) FindByName(name string, tx *sqlx.Tx, nocase bool) (*Tag, error) { +func (qb *tagQueryBuilder) FindByName(name string, nocase bool) (*models.Tag, error) { query := "SELECT * FROM tags WHERE name = ?" if nocase { query += " COLLATE NOCASE" } query += " LIMIT 1" args := []interface{}{name} - return qb.queryTag(query, args, tx) + return qb.queryTag(query, args) } -func (qb *TagQueryBuilder) FindByNames(names []string, tx *sqlx.Tx, nocase bool) ([]*Tag, error) { +func (qb *tagQueryBuilder) FindByNames(names []string, nocase bool) ([]*models.Tag, error) { query := "SELECT * FROM tags WHERE name" if nocase { query += " COLLATE NOCASE" @@ -176,32 +169,30 @@ func (qb *TagQueryBuilder) FindByNames(names []string, tx *sqlx.Tx, nocase bool) for _, name := range names { args = append(args, name) } - return qb.queryTags(query, args, tx) + return qb.queryTags(query, args) } -func (qb *TagQueryBuilder) Count() (int, error) { - return runCountQuery(buildCountQuery("SELECT tags.id FROM tags"), nil) +func (qb *tagQueryBuilder) Count() (int, error) { + return qb.runCountQuery(qb.buildCountQuery("SELECT tags.id FROM tags"), nil) } -func (qb *TagQueryBuilder) All() ([]*Tag, error) { - return qb.queryTags(selectAll("tags")+qb.getTagSort(nil), nil, nil) +func (qb *tagQueryBuilder) All() ([]*models.Tag, error) { + return qb.queryTags(selectAll("tags")+qb.getTagSort(nil), nil) } -func (qb *TagQueryBuilder) AllSlim() ([]*Tag, error) { - return qb.queryTags("SELECT tags.id, tags.name FROM tags "+qb.getTagSort(nil), nil, nil) +func (qb *tagQueryBuilder) AllSlim() ([]*models.Tag, error) { + return qb.queryTags("SELECT tags.id, tags.name FROM tags "+qb.getTagSort(nil), nil) } -func (qb *TagQueryBuilder) Query(tagFilter *TagFilterType, findFilter *FindFilterType) ([]*Tag, int) { +func (qb *tagQueryBuilder) Query(tagFilter *models.TagFilterType, findFilter *models.FindFilterType) ([]*models.Tag, int, error) { if tagFilter == nil { - tagFilter = &TagFilterType{} + tagFilter = &models.TagFilterType{} } if findFilter == nil { - findFilter = &FindFilterType{} + findFilter = &models.FindFilterType{} } - query := queryBuilder{ - tableName: tagTable, - } + query := qb.newQuery() query.body = selectDistinctIDs(tagTable) @@ -256,18 +247,24 @@ func (qb *TagQueryBuilder) Query(tagFilter *TagFilterType, findFilter *FindFilte // } query.sortAndPagination = qb.getTagSort(findFilter) + getPagination(findFilter) - idsResult, countResult := query.executeFind() + idsResult, countResult, err := query.executeFind() + if err != nil { + 
return nil, 0, err + } - var tags []*Tag + var tags []*models.Tag for _, id := range idsResult { - tag, _ := qb.Find(id, nil) + tag, err := qb.Find(id) + if err != nil { + return nil, 0, err + } tags = append(tags, tag) } - return tags, countResult + return tags, countResult, nil } -func (qb *TagQueryBuilder) getTagSort(findFilter *FindFilterType) string { +func (qb *tagQueryBuilder) getTagSort(findFilter *models.FindFilterType) string { var sort string var direction string if findFilter == nil { @@ -280,73 +277,46 @@ func (qb *TagQueryBuilder) getTagSort(findFilter *FindFilterType) string { return getSort(sort, direction, "tags") } -func (qb *TagQueryBuilder) queryTag(query string, args []interface{}, tx *sqlx.Tx) (*Tag, error) { - results, err := qb.queryTags(query, args, tx) +func (qb *tagQueryBuilder) queryTag(query string, args []interface{}) (*models.Tag, error) { + results, err := qb.queryTags(query, args) if err != nil || len(results) < 1 { return nil, err } return results[0], nil } -func (qb *TagQueryBuilder) queryTags(query string, args []interface{}, tx *sqlx.Tx) ([]*Tag, error) { - var rows *sqlx.Rows - var err error - if tx != nil { - rows, err = tx.Queryx(query, args...) - } else { - rows, err = database.DB.Queryx(query, args...) - } - - if err != nil && err != sql.ErrNoRows { - return nil, err - } - defer rows.Close() - - tags := make([]*Tag, 0) - for rows.Next() { - tag := Tag{} - if err := rows.StructScan(&tag); err != nil { - return nil, err - } - tags = append(tags, &tag) - } - - if err := rows.Err(); err != nil { +func (qb *tagQueryBuilder) queryTags(query string, args []interface{}) ([]*models.Tag, error) { + var ret models.Tags + if err := qb.query(query, args, &ret); err != nil { return nil, err } - return tags, nil + return []*models.Tag(ret), nil } -func (qb *TagQueryBuilder) UpdateTagImage(tagID int, image []byte, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing cover and then create new - if err := qb.DestroyTagImage(tagID, tx); err != nil { - return err +func (qb *tagQueryBuilder) imageRepository() *imageRepository { + return &imageRepository{ + repository: repository{ + tx: qb.tx, + tableName: "tags_image", + idColumn: tagIDColumn, + }, + imageColumn: "image", } - - _, err := tx.Exec( - `INSERT INTO tags_image (tag_id, image) VALUES (?, ?)`, - tagID, - image, - ) - - return err } -func (qb *TagQueryBuilder) DestroyTagImage(tagID int, tx *sqlx.Tx) error { - ensureTx(tx) - - // Delete the existing joins - _, err := tx.Exec("DELETE FROM tags_image WHERE tag_id = ?", tagID) - if err != nil { - return err - } - return err +func (qb *tagQueryBuilder) GetImage(tagID int) ([]byte, error) { + return qb.imageRepository().get(tagID) } -func (qb *TagQueryBuilder) GetTagImage(tagID int, tx *sqlx.Tx) ([]byte, error) { - query := `SELECT image from tags_image WHERE tag_id = ?` - return getImage(tx, query, tagID) +func (qb *tagQueryBuilder) HasImage(tagID int) (bool, error) { + return qb.imageRepository().exists(tagID) +} + +func (qb *tagQueryBuilder) UpdateImage(tagID int, image []byte) error { + return qb.imageRepository().replace(tagID, image) +} + +func (qb *tagQueryBuilder) DestroyImage(tagID int) error { + return qb.imageRepository().destroy([]int{tagID}) } diff --git a/pkg/sqlite/tag_test.go b/pkg/sqlite/tag_test.go new file mode 100644 index 000000000..532a23201 --- /dev/null +++ b/pkg/sqlite/tag_test.go @@ -0,0 +1,327 @@ +// +build integration + +package sqlite_test + +import ( + "database/sql" + "fmt" + "strings" + "testing" + + 
"github.com/stashapp/stash/pkg/models" + "github.com/stretchr/testify/assert" +) + +func TestMarkerFindBySceneMarkerID(t *testing.T) { + withTxn(func(r models.Repository) error { + tqb := r.Tag() + + markerID := markerIDs[markerIdxWithScene] + + tags, err := tqb.FindBySceneMarkerID(markerID) + + if err != nil { + t.Errorf("Error finding tags: %s", err.Error()) + } + + assert.Len(t, tags, 1) + assert.Equal(t, tagIDs[tagIdxWithMarker], tags[0].ID) + + tags, err = tqb.FindBySceneMarkerID(0) + + if err != nil { + t.Errorf("Error finding tags: %s", err.Error()) + } + + assert.Len(t, tags, 0) + + return nil + }) +} + +func TestTagFindByName(t *testing.T) { + withTxn(func(r models.Repository) error { + tqb := r.Tag() + + name := tagNames[tagIdxWithScene] // find a tag by name + + tag, err := tqb.FindByName(name, false) + + if err != nil { + t.Errorf("Error finding tags: %s", err.Error()) + } + + assert.Equal(t, tagNames[tagIdxWithScene], tag.Name) + + name = tagNames[tagIdxWithDupName] // find a tag by name nocase + + tag, err = tqb.FindByName(name, true) + + if err != nil { + t.Errorf("Error finding tags: %s", err.Error()) + } + // tagIdxWithDupName and tagIdxWithScene should have similar names ( only diff should be Name vs NaMe) + //tag.Name should match with tagIdxWithScene since its ID is before tagIdxWithDupName + assert.Equal(t, tagNames[tagIdxWithScene], tag.Name) + //tag.Name should match with tagIdxWithDupName if the check is not case sensitive + assert.Equal(t, strings.ToLower(tagNames[tagIdxWithDupName]), strings.ToLower(tag.Name)) + + return nil + }) +} + +func TestTagFindByNames(t *testing.T) { + var names []string + + withTxn(func(r models.Repository) error { + tqb := r.Tag() + + names = append(names, tagNames[tagIdxWithScene]) // find tags by names + + tags, err := tqb.FindByNames(names, false) + if err != nil { + t.Errorf("Error finding tags: %s", err.Error()) + } + assert.Len(t, tags, 1) + assert.Equal(t, tagNames[tagIdxWithScene], tags[0].Name) + + tags, err = tqb.FindByNames(names, true) // find tags by names nocase + if err != nil { + t.Errorf("Error finding tags: %s", err.Error()) + } + assert.Len(t, tags, 2) // tagIdxWithScene and tagIdxWithDupName + assert.Equal(t, strings.ToLower(tagNames[tagIdxWithScene]), strings.ToLower(tags[0].Name)) + assert.Equal(t, strings.ToLower(tagNames[tagIdxWithScene]), strings.ToLower(tags[1].Name)) + + names = append(names, tagNames[tagIdx1WithScene]) // find tags by names ( 2 names ) + + tags, err = tqb.FindByNames(names, false) + if err != nil { + t.Errorf("Error finding tags: %s", err.Error()) + } + assert.Len(t, tags, 2) // tagIdxWithScene and tagIdx1WithScene + assert.Equal(t, tagNames[tagIdxWithScene], tags[0].Name) + assert.Equal(t, tagNames[tagIdx1WithScene], tags[1].Name) + + tags, err = tqb.FindByNames(names, true) // find tags by names ( 2 names nocase) + if err != nil { + t.Errorf("Error finding tags: %s", err.Error()) + } + assert.Len(t, tags, 4) // tagIdxWithScene and tagIdxWithDupName , tagIdx1WithScene and tagIdx1WithDupName + assert.Equal(t, tagNames[tagIdxWithScene], tags[0].Name) + assert.Equal(t, tagNames[tagIdx1WithScene], tags[1].Name) + assert.Equal(t, tagNames[tagIdx1WithDupName], tags[2].Name) + assert.Equal(t, tagNames[tagIdxWithDupName], tags[3].Name) + + return nil + }) +} + +func TestTagQueryIsMissingImage(t *testing.T) { + withTxn(func(r models.Repository) error { + qb := r.Tag() + isMissing := "image" + tagFilter := models.TagFilterType{ + IsMissing: &isMissing, + } + + q := 
getTagStringValue(tagIdxWithCoverImage, "name") + findFilter := models.FindFilterType{ + Q: &q, + } + + tags, _, err := qb.Query(&tagFilter, &findFilter) + if err != nil { + t.Errorf("Error querying tag: %s", err.Error()) + } + + assert.Len(t, tags, 0) + + findFilter.Q = nil + tags, _, err = qb.Query(&tagFilter, &findFilter) + if err != nil { + t.Errorf("Error querying tag: %s", err.Error()) + } + + // ensure none of the ids equal the one with image + for _, tag := range tags { + assert.NotEqual(t, tagIDs[tagIdxWithCoverImage], tag.ID) + } + + return nil + }) +} + +func TestTagQuerySceneCount(t *testing.T) { + countCriterion := models.IntCriterionInput{ + Value: 1, + Modifier: models.CriterionModifierEquals, + } + + verifyTagSceneCount(t, countCriterion) + + countCriterion.Modifier = models.CriterionModifierNotEquals + verifyTagSceneCount(t, countCriterion) + + countCriterion.Modifier = models.CriterionModifierLessThan + verifyTagSceneCount(t, countCriterion) + + countCriterion.Value = 0 + countCriterion.Modifier = models.CriterionModifierGreaterThan + verifyTagSceneCount(t, countCriterion) +} + +func verifyTagSceneCount(t *testing.T, sceneCountCriterion models.IntCriterionInput) { + withTxn(func(r models.Repository) error { + qb := r.Tag() + tagFilter := models.TagFilterType{ + SceneCount: &sceneCountCriterion, + } + + tags, _, err := qb.Query(&tagFilter, nil) + if err != nil { + t.Errorf("Error querying tag: %s", err.Error()) + } + + for _, tag := range tags { + verifyInt64(t, sql.NullInt64{ + Int64: int64(getTagSceneCount(tag.ID)), + Valid: true, + }, sceneCountCriterion) + } + + return nil + }) +} + +// disabled due to performance issues + +// func TestTagQueryMarkerCount(t *testing.T) { +// countCriterion := models.IntCriterionInput{ +// Value: 1, +// Modifier: models.CriterionModifierEquals, +// } + +// verifyTagMarkerCount(t, countCriterion) + +// countCriterion.Modifier = models.CriterionModifierNotEquals +// verifyTagMarkerCount(t, countCriterion) + +// countCriterion.Modifier = models.CriterionModifierLessThan +// verifyTagMarkerCount(t, countCriterion) + +// countCriterion.Value = 0 +// countCriterion.Modifier = models.CriterionModifierGreaterThan +// verifyTagMarkerCount(t, countCriterion) +// } + +func verifyTagMarkerCount(t *testing.T, markerCountCriterion models.IntCriterionInput) { + withTxn(func(r models.Repository) error { + qb := r.Tag() + tagFilter := models.TagFilterType{ + MarkerCount: &markerCountCriterion, + } + + tags, _, err := qb.Query(&tagFilter, nil) + if err != nil { + t.Errorf("Error querying tag: %s", err.Error()) + } + + for _, tag := range tags { + verifyInt64(t, sql.NullInt64{ + Int64: int64(getTagMarkerCount(tag.ID)), + Valid: true, + }, markerCountCriterion) + } + + return nil + }) +} + +func TestTagUpdateTagImage(t *testing.T) { + if err := withTxn(func(r models.Repository) error { + qb := r.Tag() + + // create tag to test against + const name = "TestTagUpdateTagImage" + tag := models.Tag{ + Name: name, + } + created, err := qb.Create(tag) + if err != nil { + return fmt.Errorf("Error creating tag: %s", err.Error()) + } + + image := []byte("image") + err = qb.UpdateImage(created.ID, image) + if err != nil { + return fmt.Errorf("Error updating tag image: %s", err.Error()) + } + + // ensure image set + storedImage, err := qb.GetImage(created.ID) + if err != nil { + return fmt.Errorf("Error getting image: %s", err.Error()) + } + assert.Equal(t, storedImage, image) + + // set nil image + err = qb.UpdateImage(created.ID, nil) + if err == nil { + return 
fmt.Errorf("Expected error setting nil image") + } + + return nil + }); err != nil { + t.Error(err.Error()) + } +} + +func TestTagDestroyTagImage(t *testing.T) { + if err := withTxn(func(r models.Repository) error { + qb := r.Tag() + + // create performer to test against + const name = "TestTagDestroyTagImage" + tag := models.Tag{ + Name: name, + } + created, err := qb.Create(tag) + if err != nil { + return fmt.Errorf("Error creating tag: %s", err.Error()) + } + + image := []byte("image") + err = qb.UpdateImage(created.ID, image) + if err != nil { + return fmt.Errorf("Error updating studio image: %s", err.Error()) + } + + err = qb.DestroyImage(created.ID) + if err != nil { + return fmt.Errorf("Error destroying studio image: %s", err.Error()) + } + + // image should be nil + storedImage, err := qb.GetImage(created.ID) + if err != nil { + return fmt.Errorf("Error getting image: %s", err.Error()) + } + assert.Nil(t, storedImage) + + return nil + }); err != nil { + t.Error(err.Error()) + } +} + +// TODO Create +// TODO Update +// TODO Destroy +// TODO Find +// TODO FindBySceneID +// TODO FindBySceneMarkerID +// TODO Count +// TODO All +// TODO AllSlim +// TODO Query diff --git a/pkg/sqlite/transaction.go b/pkg/sqlite/transaction.go new file mode 100644 index 000000000..8ec0fd5d2 --- /dev/null +++ b/pkg/sqlite/transaction.go @@ -0,0 +1,195 @@ +package sqlite + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/jmoiron/sqlx" + "github.com/stashapp/stash/pkg/database" + "github.com/stashapp/stash/pkg/models" +) + +type dbi interface { + Get(dest interface{}, query string, args ...interface{}) error + Select(dest interface{}, query string, args ...interface{}) error + Queryx(query string, args ...interface{}) (*sqlx.Rows, error) + NamedExec(query string, arg interface{}) (sql.Result, error) + Exec(query string, args ...interface{}) (sql.Result, error) +} + +type transaction struct { + Ctx context.Context + tx *sqlx.Tx +} + +func (t *transaction) Begin() error { + if t.tx != nil { + return errors.New("transaction already begun") + } + + var err error + t.tx, err = database.DB.BeginTxx(t.Ctx, nil) + if err != nil { + return fmt.Errorf("error starting transaction: %s", err.Error()) + } + + return nil +} + +func (t *transaction) Rollback() error { + if t.tx == nil { + return errors.New("not in transaction") + } + + err := t.tx.Rollback() + if err != nil { + return fmt.Errorf("error rolling back transaction: %s", err.Error()) + } + t.tx = nil + + return nil +} + +func (t *transaction) Commit() error { + if t.tx == nil { + return errors.New("not in transaction") + } + + err := t.tx.Commit() + if err != nil { + return fmt.Errorf("error committing transaction: %s", err.Error()) + } + t.tx = nil + + return nil +} + +func (t *transaction) Repository() models.Repository { + return t +} + +func (t *transaction) ensureTx() { + if t.tx == nil { + panic("tx is nil") + } +} + +func (t *transaction) Gallery() models.GalleryReaderWriter { + t.ensureTx() + return NewGalleryReaderWriter(t.tx) +} + +func (t *transaction) Image() models.ImageReaderWriter { + t.ensureTx() + return NewImageReaderWriter(t.tx) +} + +func (t *transaction) Movie() models.MovieReaderWriter { + t.ensureTx() + return NewMovieReaderWriter(t.tx) +} + +func (t *transaction) Performer() models.PerformerReaderWriter { + t.ensureTx() + return NewPerformerReaderWriter(t.tx) +} + +func (t *transaction) SceneMarker() models.SceneMarkerReaderWriter { + t.ensureTx() + return NewSceneMarkerReaderWriter(t.tx) +} + +func (t 
*transaction) Scene() models.SceneReaderWriter { + t.ensureTx() + return NewSceneReaderWriter(t.tx) +} + +func (t *transaction) ScrapedItem() models.ScrapedItemReaderWriter { + t.ensureTx() + return NewScrapedItemReaderWriter(t.tx) +} + +func (t *transaction) Studio() models.StudioReaderWriter { + t.ensureTx() + return NewStudioReaderWriter(t.tx) +} + +func (t *transaction) Tag() models.TagReaderWriter { + t.ensureTx() + return NewTagReaderWriter(t.tx) +} + +type ReadTransaction struct{} + +func (t *ReadTransaction) Begin() error { + return nil +} + +func (t *ReadTransaction) Rollback() error { + return nil +} + +func (t *ReadTransaction) Commit() error { + return nil +} + +func (t *ReadTransaction) Repository() models.ReaderRepository { + return t +} + +func (t *ReadTransaction) Gallery() models.GalleryReader { + return NewGalleryReaderWriter(database.DB) +} + +func (t *ReadTransaction) Image() models.ImageReader { + return NewImageReaderWriter(database.DB) +} + +func (t *ReadTransaction) Movie() models.MovieReader { + return NewMovieReaderWriter(database.DB) +} + +func (t *ReadTransaction) Performer() models.PerformerReader { + return NewPerformerReaderWriter(database.DB) +} + +func (t *ReadTransaction) SceneMarker() models.SceneMarkerReader { + return NewSceneMarkerReaderWriter(database.DB) +} + +func (t *ReadTransaction) Scene() models.SceneReader { + return NewSceneReaderWriter(database.DB) +} + +func (t *ReadTransaction) ScrapedItem() models.ScrapedItemReader { + return NewScrapedItemReaderWriter(database.DB) +} + +func (t *ReadTransaction) Studio() models.StudioReader { + return NewStudioReaderWriter(database.DB) +} + +func (t *ReadTransaction) Tag() models.TagReader { + return NewTagReaderWriter(database.DB) +} + +type TransactionManager struct { + // only allow one write transaction at a time + c chan struct{} +} + +func NewTransactionManager() *TransactionManager { + return &TransactionManager{ + c: make(chan struct{}, 1), + } +} + +func (t *TransactionManager) WithTxn(ctx context.Context, fn func(r models.Repository) error) error { + return models.WithTxn(&transaction{Ctx: ctx}, fn) +} + +func (t *TransactionManager) WithReadTxn(ctx context.Context, fn func(r models.ReaderRepository) error) error { + return models.WithROTxn(&ReadTransaction{}, fn) +} diff --git a/pkg/studio/export.go b/pkg/studio/export.go index 7779a0be0..f7e72d53d 100644 --- a/pkg/studio/export.go +++ b/pkg/studio/export.go @@ -34,7 +34,7 @@ func ToJSON(reader models.StudioReader, studio *models.Studio) (*jsonschema.Stud } } - image, err := reader.GetStudioImage(studio.ID) + image, err := reader.GetImage(studio.ID) if err != nil { return nil, fmt.Errorf("error getting studio image: %s", err.Error()) } diff --git a/pkg/studio/export_test.go b/pkg/studio/export_test.go index b8802bcbc..72807df48 100644 --- a/pkg/studio/export_test.go +++ b/pkg/studio/export_test.go @@ -6,7 +6,6 @@ import ( "github.com/stashapp/stash/pkg/manager/jsonschema" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models/mocks" - "github.com/stashapp/stash/pkg/models/modelstest" "github.com/stretchr/testify/assert" "testing" @@ -31,7 +30,7 @@ const url = "url" const parentStudioName = "parentStudio" var parentStudio models.Studio = models.Studio{ - Name: modelstest.NullString(parentStudioName), + Name: models.NullString(parentStudioName), } var imageBytes = []byte("imageBytes") @@ -44,9 +43,9 @@ var updateTime time.Time = time.Date(2002, 01, 01, 0, 0, 0, 0, time.UTC) func createFullStudio(id int, parentID int) 
models.Studio { return models.Studio{ ID: id, - Name: modelstest.NullString(studioName), - URL: modelstest.NullString(url), - ParentID: modelstest.NullInt64(int64(parentID)), + Name: models.NullString(studioName), + URL: models.NullString(url), + ParentID: models.NullInt64(int64(parentID)), CreatedAt: models.SQLiteTimestamp{ Timestamp: createTime, }, @@ -139,11 +138,11 @@ func TestToJSON(t *testing.T) { imageErr := errors.New("error getting image") - mockStudioReader.On("GetStudioImage", studioID).Return(imageBytes, nil).Once() - mockStudioReader.On("GetStudioImage", noImageID).Return(nil, nil).Once() - mockStudioReader.On("GetStudioImage", errImageID).Return(nil, imageErr).Once() - mockStudioReader.On("GetStudioImage", missingParentStudioID).Return(imageBytes, nil).Maybe() - mockStudioReader.On("GetStudioImage", errStudioID).Return(imageBytes, nil).Maybe() + mockStudioReader.On("GetImage", studioID).Return(imageBytes, nil).Once() + mockStudioReader.On("GetImage", noImageID).Return(nil, nil).Once() + mockStudioReader.On("GetImage", errImageID).Return(nil, imageErr).Once() + mockStudioReader.On("GetImage", missingParentStudioID).Return(imageBytes, nil).Maybe() + mockStudioReader.On("GetImage", errStudioID).Return(imageBytes, nil).Maybe() parentStudioErr := errors.New("error getting parent studio") diff --git a/pkg/studio/import.go b/pkg/studio/import.go index 64924f475..1e8afcc67 100644 --- a/pkg/studio/import.go +++ b/pkg/studio/import.go @@ -94,7 +94,7 @@ func (i *Importer) createParentStudio(name string) (int, error) { func (i *Importer) PostImport(id int) error { if len(i.imageData) > 0 { - if err := i.ReaderWriter.UpdateStudioImage(id, i.imageData); err != nil { + if err := i.ReaderWriter.UpdateImage(id, i.imageData); err != nil { return fmt.Errorf("error setting studio image: %s", err.Error()) } } diff --git a/pkg/studio/import_test.go b/pkg/studio/import_test.go index bc71c8b10..339e8b9b6 100644 --- a/pkg/studio/import_test.go +++ b/pkg/studio/import_test.go @@ -7,7 +7,6 @@ import ( "github.com/stashapp/stash/pkg/manager/jsonschema" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models/mocks" - "github.com/stashapp/stash/pkg/models/modelstest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) @@ -143,10 +142,10 @@ func TestImporterPostImport(t *testing.T) { imageData: imageBytes, } - updateStudioImageErr := errors.New("UpdateStudioImage error") + updateStudioImageErr := errors.New("UpdateImage error") - readerWriter.On("UpdateStudioImage", studioID, imageBytes).Return(nil).Once() - readerWriter.On("UpdateStudioImage", errImageID, imageBytes).Return(updateStudioImageErr).Once() + readerWriter.On("UpdateImage", studioID, imageBytes).Return(nil).Once() + readerWriter.On("UpdateImage", errImageID, imageBytes).Return(updateStudioImageErr).Once() err := i.PostImport(studioID) assert.Nil(t, err) @@ -195,11 +194,11 @@ func TestCreate(t *testing.T) { readerWriter := &mocks.StudioReaderWriter{} studio := models.Studio{ - Name: modelstest.NullString(studioName), + Name: models.NullString(studioName), } studioErr := models.Studio{ - Name: modelstest.NullString(studioNameErr), + Name: models.NullString(studioNameErr), } i := Importer{ @@ -229,11 +228,11 @@ func TestUpdate(t *testing.T) { readerWriter := &mocks.StudioReaderWriter{} studio := models.Studio{ - Name: modelstest.NullString(studioName), + Name: models.NullString(studioName), } studioErr := models.Studio{ - Name: modelstest.NullString(studioNameErr), + Name: 
models.NullString(studioNameErr), } i := Importer{ diff --git a/pkg/tag/export.go b/pkg/tag/export.go index 9671ee869..456bf68d8 100644 --- a/pkg/tag/export.go +++ b/pkg/tag/export.go @@ -16,7 +16,7 @@ func ToJSON(reader models.TagReader, tag *models.Tag) (*jsonschema.Tag, error) { UpdatedAt: models.JSONTime{Time: tag.UpdatedAt.Timestamp}, } - image, err := reader.GetTagImage(tag.ID) + image, err := reader.GetImage(tag.ID) if err != nil { return nil, fmt.Errorf("error getting tag image: %s", err.Error()) } diff --git a/pkg/tag/export_test.go b/pkg/tag/export_test.go index a494aebf4..30c8448bc 100644 --- a/pkg/tag/export_test.go +++ b/pkg/tag/export_test.go @@ -84,9 +84,9 @@ func TestToJSON(t *testing.T) { imageErr := errors.New("error getting image") - mockTagReader.On("GetTagImage", tagID).Return(models.DefaultTagImage, nil).Once() - mockTagReader.On("GetTagImage", noImageID).Return(nil, nil).Once() - mockTagReader.On("GetTagImage", errImageID).Return(nil, imageErr).Once() + mockTagReader.On("GetImage", tagID).Return(models.DefaultTagImage, nil).Once() + mockTagReader.On("GetImage", noImageID).Return(nil, nil).Once() + mockTagReader.On("GetImage", errImageID).Return(nil, imageErr).Once() for i, s := range scenarios { tag := s.tag diff --git a/pkg/tag/import.go b/pkg/tag/import.go index 5985253e6..f050a4d5e 100644 --- a/pkg/tag/import.go +++ b/pkg/tag/import.go @@ -36,7 +36,7 @@ func (i *Importer) PreImport() error { func (i *Importer) PostImport(id int) error { if len(i.imageData) > 0 { - if err := i.ReaderWriter.UpdateTagImage(id, i.imageData); err != nil { + if err := i.ReaderWriter.UpdateImage(id, i.imageData); err != nil { return fmt.Errorf("error setting tag image: %s", err.Error()) } } diff --git a/pkg/tag/import_test.go b/pkg/tag/import_test.go index b99dd012c..3bcd829aa 100644 --- a/pkg/tag/import_test.go +++ b/pkg/tag/import_test.go @@ -59,10 +59,10 @@ func TestImporterPostImport(t *testing.T) { imageData: imageBytes, } - updateTagImageErr := errors.New("UpdateTagImage error") + updateTagImageErr := errors.New("UpdateImage error") - readerWriter.On("UpdateTagImage", tagID, imageBytes).Return(nil).Once() - readerWriter.On("UpdateTagImage", errImageID, imageBytes).Return(updateTagImageErr).Once() + readerWriter.On("UpdateImage", tagID, imageBytes).Return(nil).Once() + readerWriter.On("UpdateImage", errImageID, imageBytes).Return(updateTagImageErr).Once() err := i.PostImport(tagID) assert.Nil(t, err) diff --git a/pkg/utils/int_collections.go b/pkg/utils/int_collections.go index 5b59a81f7..41daa5185 100644 --- a/pkg/utils/int_collections.go +++ b/pkg/utils/int_collections.go @@ -37,3 +37,16 @@ func IntAppendUniques(vs []int, toAdd []int) []int { return vs } + +// IntExclude removes all instances of any value in toExclude from the vs int +// slice. It returns the new or unchanged int slice. +func IntExclude(vs []int, toExclude []int) []int { + var ret []int + for _, v := range vs { + if !IntInclude(toExclude, v) { + ret = append(ret, v) + } + } + + return ret +} diff --git a/pkg/utils/string_collections.go b/pkg/utils/string_collections.go index aff2d136a..af135eb71 100644 --- a/pkg/utils/string_collections.go +++ b/pkg/utils/string_collections.go @@ -35,14 +35,17 @@ func StrMap(vs []string, f func(string) string) []string { return vsm } -// StringSliceToIntSlice converts a slice of strings to a slice of ints. If any -// values cannot be parsed, then they are inserted into the returned slice as -// 0. 
-func StringSliceToIntSlice(ss []string) []int { +// StringSliceToIntSlice converts a slice of strings to a slice of ints. +// Returns an error if any values cannot be parsed. +func StringSliceToIntSlice(ss []string) ([]int, error) { ret := make([]int, len(ss)) for i, v := range ss { - ret[i], _ = strconv.Atoi(v) + var err error + ret[i], err = strconv.Atoi(v) + if err != nil { + return nil, err + } } - return ret + return ret, nil } diff --git a/ui/v2.5/src/core/StashService.ts b/ui/v2.5/src/core/StashService.ts index 581108095..4992fea16 100644 --- a/ui/v2.5/src/core/StashService.ts +++ b/ui/v2.5/src/core/StashService.ts @@ -178,8 +178,10 @@ export const queryFindPerformers = (filter: ListFilterModel) => }, }); -export const useFindGallery = (id: string) => - GQL.useFindGalleryQuery({ variables: { id } }); +export const useFindGallery = (id: string) => { + const skip = id === "new"; + return GQL.useFindGalleryQuery({ variables: { id }, skip }); +}; export const useFindScene = (id: string) => GQL.useFindSceneQuery({ variables: { id } }); export const useSceneStreams = (id: string) =>