Data layer restructuring (#997)

* Move query builders to sqlite package
* Add transaction system
* Wrap model resolvers in transaction
* Add error return value for StringSliceToIntSlice
* Update/refactor mutation resolvers
* Convert query builders
* Remove unused join types
* Add stash id unit tests
* Use WAL journal mode
Author: WithoutPants
Date: 2021-01-18 12:23:20 +11:00
Committed by: GitHub
Parent: 7bae990c67
Commit: 1e04deb3d4
168 changed files with 12683 additions and 10863 deletions
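
The resolver hunks below all follow one shape: instead of constructing query builders directly, a resolver asks the injected transaction manager for a repository inside a read (or read-write) transaction. The following is a condensed sketch drawn from the hunks in this diff (the names are exactly as they appear below); it is illustrative only, not additional code from the commit:

    // The Resolver now carries a models.TransactionManager instead of each
    // resolver building its own query builders.
    type Resolver struct {
        txnManager models.TransactionManager
    }

    // withReadTxn runs fn inside a read-only transaction and hands it a
    // ReaderRepository; withTxn does the same with a writable Repository.
    func (r *Resolver) withReadTxn(ctx context.Context, fn func(r models.ReaderRepository) error) error {
        return r.txnManager.WithReadTxn(ctx, fn)
    }

    // A typical read resolver captures its result in a named return value,
    // so the closure only has to report the error.
    func (r *queryResolver) SceneWall(ctx context.Context, q *string) (ret []*models.Scene, err error) {
        if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
            ret, err = repo.Scene().Wall(q)
            return err
        }); err != nil {
            return nil, err
        }
        return ret, nil
    }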

View File

@@ -5,12 +5,13 @@ import (
     "sort"
     "strconv"
 
-    "github.com/99designs/gqlgen/graphql"
     "github.com/stashapp/stash/pkg/logger"
     "github.com/stashapp/stash/pkg/models"
 )
 
-type Resolver struct{}
+type Resolver struct {
+    txnManager models.TransactionManager
+}
 
 func (r *Resolver) Gallery() models.GalleryResolver {
     return &galleryResolver{r}
@@ -79,35 +80,75 @@ type scrapedSceneMovieResolver struct{ *Resolver }
 type scrapedScenePerformerResolver struct{ *Resolver }
 type scrapedSceneStudioResolver struct{ *Resolver }
 
-func (r *queryResolver) MarkerWall(ctx context.Context, q *string) ([]*models.SceneMarker, error) {
-    qb := models.NewSceneMarkerQueryBuilder()
-    return qb.Wall(q)
+func (r *Resolver) withTxn(ctx context.Context, fn func(r models.Repository) error) error {
+    return r.txnManager.WithTxn(ctx, fn)
 }
 
-func (r *queryResolver) SceneWall(ctx context.Context, q *string) ([]*models.Scene, error) {
-    qb := models.NewSceneQueryBuilder()
-    return qb.Wall(q)
+func (r *Resolver) withReadTxn(ctx context.Context, fn func(r models.ReaderRepository) error) error {
+    return r.txnManager.WithReadTxn(ctx, fn)
 }
 
-func (r *queryResolver) MarkerStrings(ctx context.Context, q *string, sort *string) ([]*models.MarkerStringsResultType, error) {
-    qb := models.NewSceneMarkerQueryBuilder()
-    return qb.GetMarkerStrings(q, sort)
+func (r *queryResolver) MarkerWall(ctx context.Context, q *string) (ret []*models.SceneMarker, err error) {
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        ret, err = repo.SceneMarker().Wall(q)
+        return err
+    }); err != nil {
+        return nil, err
+    }
+    return ret, nil
+}
+
+func (r *queryResolver) SceneWall(ctx context.Context, q *string) (ret []*models.Scene, err error) {
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        ret, err = repo.Scene().Wall(q)
+        return err
+    }); err != nil {
+        return nil, err
+    }
+    return ret, nil
+}
+
+func (r *queryResolver) MarkerStrings(ctx context.Context, q *string, sort *string) (ret []*models.MarkerStringsResultType, err error) {
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        ret, err = repo.SceneMarker().GetMarkerStrings(q, sort)
+        return err
+    }); err != nil {
+        return nil, err
+    }
+    return ret, nil
 }
 
 func (r *queryResolver) ValidGalleriesForScene(ctx context.Context, scene_id *string) ([]*models.Gallery, error) {
     if scene_id == nil {
         panic("nil scene id") // TODO make scene_id mandatory
     }
-    sceneID, _ := strconv.Atoi(*scene_id)
-    sqb := models.NewSceneQueryBuilder()
-    scene, err := sqb.Find(sceneID)
+    sceneID, err := strconv.Atoi(*scene_id)
     if err != nil {
         return nil, err
     }
-    qb := models.NewGalleryQueryBuilder()
-    validGalleries, err := qb.ValidGalleriesForScenePath(scene.Path)
-    sceneGallery, _ := qb.FindBySceneID(sceneID, nil)
+
+    var validGalleries []*models.Gallery
+    var sceneGallery *models.Gallery
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        sqb := repo.Scene()
+        scene, err := sqb.Find(sceneID)
+        if err != nil {
+            return err
+        }
+
+        qb := repo.Gallery()
+        validGalleries, err = qb.ValidGalleriesForScenePath(scene.Path)
+        if err != nil {
+            return err
+        }
+        sceneGallery, err = qb.FindBySceneID(sceneID)
+        return err
+    }); err != nil {
+        return nil, err
+    }
+
     if sceneGallery != nil {
         validGalleries = append(validGalleries, sceneGallery)
     }
@@ -115,33 +156,43 @@ func (r *queryResolver) ValidGalleriesForScene(ctx context.Context, scene_id *st
 }
 
 func (r *queryResolver) Stats(ctx context.Context) (*models.StatsResultType, error) {
-    scenesQB := models.NewSceneQueryBuilder()
-    scenesCount, _ := scenesQB.Count()
-    scenesSize, _ := scenesQB.Size()
-    imageQB := models.NewImageQueryBuilder()
-    imageCount, _ := imageQB.Count()
-    imageSize, _ := imageQB.Size()
-    galleryQB := models.NewGalleryQueryBuilder()
-    galleryCount, _ := galleryQB.Count()
-    performersQB := models.NewPerformerQueryBuilder()
-    performersCount, _ := performersQB.Count()
-    studiosQB := models.NewStudioQueryBuilder()
-    studiosCount, _ := studiosQB.Count()
-    moviesQB := models.NewMovieQueryBuilder()
-    moviesCount, _ := moviesQB.Count()
-    tagsQB := models.NewTagQueryBuilder()
-    tagsCount, _ := tagsQB.Count()
-    return &models.StatsResultType{
-        SceneCount:     scenesCount,
-        ScenesSize:     scenesSize,
-        ImageCount:     imageCount,
-        ImagesSize:     imageSize,
-        GalleryCount:   galleryCount,
-        PerformerCount: performersCount,
-        StudioCount:    studiosCount,
-        MovieCount:     moviesCount,
-        TagCount:       tagsCount,
-    }, nil
+    var ret models.StatsResultType
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        scenesQB := repo.Scene()
+        imageQB := repo.Image()
+        galleryQB := repo.Gallery()
+        studiosQB := repo.Studio()
+        performersQB := repo.Performer()
+        moviesQB := repo.Movie()
+        tagsQB := repo.Tag()
+        scenesCount, _ := scenesQB.Count()
+        scenesSize, _ := scenesQB.Size()
+        imageCount, _ := imageQB.Count()
+        imageSize, _ := imageQB.Size()
+        galleryCount, _ := galleryQB.Count()
+        performersCount, _ := performersQB.Count()
+        studiosCount, _ := studiosQB.Count()
+        moviesCount, _ := moviesQB.Count()
+        tagsCount, _ := tagsQB.Count()
+
+        ret = models.StatsResultType{
+            SceneCount:     scenesCount,
+            ScenesSize:     scenesSize,
+            ImageCount:     imageCount,
+            ImagesSize:     imageSize,
+            GalleryCount:   galleryCount,
+            PerformerCount: performersCount,
+            StudioCount:    studiosCount,
+            MovieCount:     moviesCount,
+            TagCount:       tagsCount,
+        }
+        return nil
+    }); err != nil {
+        return nil, err
+    }
+
+    return &ret, nil
 }
 
 func (r *queryResolver) Version(ctx context.Context) (*models.Version, error) {
@@ -171,31 +222,41 @@ func (r *queryResolver) Latestversion(ctx context.Context) (*models.ShortVersion
 // Get scene marker tags which show up under the video.
 func (r *queryResolver) SceneMarkerTags(ctx context.Context, scene_id string) ([]*models.SceneMarkerTag, error) {
-    sceneID, _ := strconv.Atoi(scene_id)
-    sqb := models.NewSceneMarkerQueryBuilder()
-    sceneMarkers, err := sqb.FindBySceneID(sceneID, nil)
+    sceneID, err := strconv.Atoi(scene_id)
     if err != nil {
         return nil, err
     }
-    tags := make(map[int]*models.SceneMarkerTag)
     var keys []int
-    tqb := models.NewTagQueryBuilder()
-    for _, sceneMarker := range sceneMarkers {
-        markerPrimaryTag, err := tqb.Find(sceneMarker.PrimaryTagID, nil)
+    tags := make(map[int]*models.SceneMarkerTag)
+
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        sceneMarkers, err := repo.SceneMarker().FindBySceneID(sceneID)
         if err != nil {
-            return nil, err
+            return err
         }
-        _, hasKey := tags[markerPrimaryTag.ID]
-        var sceneMarkerTag *models.SceneMarkerTag
-        if !hasKey {
-            sceneMarkerTag = &models.SceneMarkerTag{Tag: markerPrimaryTag}
-            tags[markerPrimaryTag.ID] = sceneMarkerTag
-            keys = append(keys, markerPrimaryTag.ID)
-        } else {
-            sceneMarkerTag = tags[markerPrimaryTag.ID]
+
+        tqb := repo.Tag()
+        for _, sceneMarker := range sceneMarkers {
+            markerPrimaryTag, err := tqb.Find(sceneMarker.PrimaryTagID)
+            if err != nil {
+                return err
+            }
+            _, hasKey := tags[markerPrimaryTag.ID]
+            var sceneMarkerTag *models.SceneMarkerTag
+            if !hasKey {
+                sceneMarkerTag = &models.SceneMarkerTag{Tag: markerPrimaryTag}
+                tags[markerPrimaryTag.ID] = sceneMarkerTag
+                keys = append(keys, markerPrimaryTag.ID)
+            } else {
+                sceneMarkerTag = tags[markerPrimaryTag.ID]
+            }
+            tags[markerPrimaryTag.ID].SceneMarkers = append(tags[markerPrimaryTag.ID].SceneMarkers, sceneMarker)
         }
-        tags[markerPrimaryTag.ID].SceneMarkers = append(tags[markerPrimaryTag.ID].SceneMarkers, sceneMarker)
+        return nil
+    }); err != nil {
+        return nil, err
     }
 
     // Sort so that primary tags that show up earlier in the video are first.
@@ -212,13 +273,3 @@ func (r *queryResolver) SceneMarkerTags(ctx context.Context, scene_id string) ([
     return result, nil
 }
-
-// wasFieldIncluded returns true if the given field was included in the request.
-// Slices are unmarshalled to empty slices even if the field was omitted. This
-// method determines if it was omitted altogether.
-func wasFieldIncluded(ctx context.Context, field string) bool {
-    rctx := graphql.GetRequestContext(ctx)
-    _, ret := rctx.Variables[field]
-    return ret
-}

View File

@@ -22,30 +22,39 @@ func (r *galleryResolver) Title(ctx context.Context, obj *models.Gallery) (*stri
     return nil, nil
 }
 
-func (r *galleryResolver) Images(ctx context.Context, obj *models.Gallery) ([]*models.Image, error) {
-    qb := models.NewImageQueryBuilder()
-    return qb.FindByGalleryID(obj.ID)
-}
-
-func (r *galleryResolver) Cover(ctx context.Context, obj *models.Gallery) (*models.Image, error) {
-    qb := models.NewImageQueryBuilder()
-    imgs, err := qb.FindByGalleryID(obj.ID)
-    if err != nil {
+func (r *galleryResolver) Images(ctx context.Context, obj *models.Gallery) (ret []*models.Image, err error) {
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        var err error
+        ret, err = repo.Image().FindByGalleryID(obj.ID)
+        return err
+    }); err != nil {
         return nil, err
     }
-    var ret *models.Image
-    if len(imgs) > 0 {
-        ret = imgs[0]
-    }
-    for _, img := range imgs {
-        if image.IsCover(img) {
-            ret = img
-            break
+    return ret, nil
+}
+
+func (r *galleryResolver) Cover(ctx context.Context, obj *models.Gallery) (ret *models.Image, err error) {
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        imgs, err := repo.Image().FindByGalleryID(obj.ID)
+        if err != nil {
+            return err
         }
+        if len(imgs) > 0 {
+            ret = imgs[0]
+        }
+        for _, img := range imgs {
+            if image.IsCover(img) {
+                ret = img
+                break
+            }
+        }
+        return nil
+    }); err != nil {
+        return nil, err
     }
 
     return ret, nil
@@ -81,35 +90,71 @@ func (r *galleryResolver) Rating(ctx context.Context, obj *models.Gallery) (*int
     return nil, nil
 }
 
-func (r *galleryResolver) Scene(ctx context.Context, obj *models.Gallery) (*models.Scene, error) {
+func (r *galleryResolver) Scene(ctx context.Context, obj *models.Gallery) (ret *models.Scene, err error) {
     if !obj.SceneID.Valid {
         return nil, nil
     }
-    qb := models.NewSceneQueryBuilder()
-    return qb.Find(int(obj.SceneID.Int64))
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        var err error
+        ret, err = repo.Scene().Find(int(obj.SceneID.Int64))
+        return err
+    }); err != nil {
+        return nil, err
+    }
+    return ret, nil
 }
 
-func (r *galleryResolver) Studio(ctx context.Context, obj *models.Gallery) (*models.Studio, error) {
+func (r *galleryResolver) Studio(ctx context.Context, obj *models.Gallery) (ret *models.Studio, err error) {
     if !obj.StudioID.Valid {
         return nil, nil
     }
-    qb := models.NewStudioQueryBuilder()
-    return qb.Find(int(obj.StudioID.Int64), nil)
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        var err error
+        ret, err = repo.Studio().Find(int(obj.StudioID.Int64))
+        return err
+    }); err != nil {
+        return nil, err
+    }
+    return ret, nil
 }
 
-func (r *galleryResolver) Tags(ctx context.Context, obj *models.Gallery) ([]*models.Tag, error) {
-    qb := models.NewTagQueryBuilder()
-    return qb.FindByGalleryID(obj.ID, nil)
+func (r *galleryResolver) Tags(ctx context.Context, obj *models.Gallery) (ret []*models.Tag, err error) {
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        var err error
+        ret, err = repo.Tag().FindByGalleryID(obj.ID)
+        return err
+    }); err != nil {
+        return nil, err
+    }
+    return ret, nil
 }
 
-func (r *galleryResolver) Performers(ctx context.Context, obj *models.Gallery) ([]*models.Performer, error) {
-    qb := models.NewPerformerQueryBuilder()
-    return qb.FindByGalleryID(obj.ID, nil)
+func (r *galleryResolver) Performers(ctx context.Context, obj *models.Gallery) (ret []*models.Performer, err error) {
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        var err error
+        ret, err = repo.Performer().FindByGalleryID(obj.ID)
+        return err
+    }); err != nil {
+        return nil, err
+    }
+    return ret, nil
 }
 
-func (r *galleryResolver) ImageCount(ctx context.Context, obj *models.Gallery) (int, error) {
-    qb := models.NewImageQueryBuilder()
-    return qb.CountByGalleryID(obj.ID)
+func (r *galleryResolver) ImageCount(ctx context.Context, obj *models.Gallery) (ret int, err error) {
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        var err error
+        ret, err = repo.Image().CountByGalleryID(obj.ID)
+        return err
+    }); err != nil {
+        return 0, err
+    }
+    return ret, nil
 }

View File

@@ -43,26 +43,51 @@ func (r *imageResolver) Paths(ctx context.Context, obj *models.Image) (*models.I
     }, nil
 }
 
-func (r *imageResolver) Galleries(ctx context.Context, obj *models.Image) ([]*models.Gallery, error) {
-    qb := models.NewGalleryQueryBuilder()
-    return qb.FindByImageID(obj.ID, nil)
+func (r *imageResolver) Galleries(ctx context.Context, obj *models.Image) (ret []*models.Gallery, err error) {
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        var err error
+        ret, err = repo.Gallery().FindByImageID(obj.ID)
+        return err
+    }); err != nil {
+        return nil, err
+    }
+    return ret, nil
 }
 
-func (r *imageResolver) Studio(ctx context.Context, obj *models.Image) (*models.Studio, error) {
+func (r *imageResolver) Studio(ctx context.Context, obj *models.Image) (ret *models.Studio, err error) {
     if !obj.StudioID.Valid {
         return nil, nil
     }
-    qb := models.NewStudioQueryBuilder()
-    return qb.Find(int(obj.StudioID.Int64), nil)
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        ret, err = repo.Studio().Find(int(obj.StudioID.Int64))
+        return err
+    }); err != nil {
+        return nil, err
+    }
+    return ret, nil
 }
 
-func (r *imageResolver) Tags(ctx context.Context, obj *models.Image) ([]*models.Tag, error) {
-    qb := models.NewTagQueryBuilder()
-    return qb.FindByImageID(obj.ID, nil)
+func (r *imageResolver) Tags(ctx context.Context, obj *models.Image) (ret []*models.Tag, err error) {
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        ret, err = repo.Tag().FindByImageID(obj.ID)
+        return err
+    }); err != nil {
+        return nil, err
+    }
+    return ret, nil
 }
 
-func (r *imageResolver) Performers(ctx context.Context, obj *models.Image) ([]*models.Performer, error) {
-    qb := models.NewPerformerQueryBuilder()
-    return qb.FindByImageID(obj.ID, nil)
+func (r *imageResolver) Performers(ctx context.Context, obj *models.Image) (ret []*models.Performer, err error) {
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        ret, err = repo.Performer().FindByImageID(obj.ID)
+        return err
+    }); err != nil {
+        return nil, err
+    }
+    return ret, nil
 }

View File

@@ -53,10 +53,16 @@ func (r *movieResolver) Rating(ctx context.Context, obj *models.Movie) (*int, er
     return nil, nil
 }
 
-func (r *movieResolver) Studio(ctx context.Context, obj *models.Movie) (*models.Studio, error) {
-    qb := models.NewStudioQueryBuilder()
+func (r *movieResolver) Studio(ctx context.Context, obj *models.Movie) (ret *models.Studio, err error) {
     if obj.StudioID.Valid {
-        return qb.Find(int(obj.StudioID.Int64), nil)
+        if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+            ret, err = repo.Studio().Find(int(obj.StudioID.Int64))
+            return err
+        }); err != nil {
+            return nil, err
+        }
+        return ret, nil
     }
 
     return nil, nil
@@ -88,8 +94,14 @@ func (r *movieResolver) BackImagePath(ctx context.Context, obj *models.Movie) (*
     return &backimagePath, nil
 }
 
-func (r *movieResolver) SceneCount(ctx context.Context, obj *models.Movie) (*int, error) {
-    qb := models.NewSceneQueryBuilder()
-    res, err := qb.CountByMovieID(obj.ID)
+func (r *movieResolver) SceneCount(ctx context.Context, obj *models.Movie) (ret *int, err error) {
+    var res int
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        res, err = repo.Scene().CountByMovieID(obj.ID)
+        return err
+    }); err != nil {
+        return nil, err
+    }
+
     return &res, err
 }

View File

@@ -138,18 +138,36 @@ func (r *performerResolver) ImagePath(ctx context.Context, obj *models.Performer
     return &imagePath, nil
 }
 
-func (r *performerResolver) SceneCount(ctx context.Context, obj *models.Performer) (*int, error) {
-    qb := models.NewSceneQueryBuilder()
-    res, err := qb.CountByPerformerID(obj.ID)
-    return &res, err
+func (r *performerResolver) SceneCount(ctx context.Context, obj *models.Performer) (ret *int, err error) {
+    var res int
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        res, err = repo.Scene().CountByPerformerID(obj.ID)
+        return err
+    }); err != nil {
+        return nil, err
+    }
+
+    return &res, nil
 }
 
-func (r *performerResolver) Scenes(ctx context.Context, obj *models.Performer) ([]*models.Scene, error) {
-    qb := models.NewSceneQueryBuilder()
-    return qb.FindByPerformerID(obj.ID)
+func (r *performerResolver) Scenes(ctx context.Context, obj *models.Performer) (ret []*models.Scene, err error) {
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        ret, err = repo.Scene().FindByPerformerID(obj.ID)
+        return err
+    }); err != nil {
+        return nil, err
+    }
+    return ret, nil
 }
 
-func (r *performerResolver) StashIds(ctx context.Context, obj *models.Performer) ([]*models.StashID, error) {
-    qb := models.NewJoinsQueryBuilder()
-    return qb.GetPerformerStashIDs(obj.ID)
+func (r *performerResolver) StashIds(ctx context.Context, obj *models.Performer) (ret []*models.StashID, err error) {
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        ret, err = repo.Performer().GetStashIDs(obj.ID)
+        return err
+    }); err != nil {
+        return nil, err
+    }
+    return ret, nil
 }

View File

@@ -94,64 +94,109 @@ func (r *sceneResolver) Paths(ctx context.Context, obj *models.Scene) (*models.S
     }, nil
 }
 
-func (r *sceneResolver) SceneMarkers(ctx context.Context, obj *models.Scene) ([]*models.SceneMarker, error) {
-    qb := models.NewSceneMarkerQueryBuilder()
-    return qb.FindBySceneID(obj.ID, nil)
-}
-
-func (r *sceneResolver) Gallery(ctx context.Context, obj *models.Scene) (*models.Gallery, error) {
-    qb := models.NewGalleryQueryBuilder()
-    return qb.FindBySceneID(obj.ID, nil)
-}
-
-func (r *sceneResolver) Studio(ctx context.Context, obj *models.Scene) (*models.Studio, error) {
-    qb := models.NewStudioQueryBuilder()
-    return qb.FindBySceneID(obj.ID)
-}
-
-func (r *sceneResolver) Movies(ctx context.Context, obj *models.Scene) ([]*models.SceneMovie, error) {
-    joinQB := models.NewJoinsQueryBuilder()
-    qb := models.NewMovieQueryBuilder()
-    sceneMovies, err := joinQB.GetSceneMovies(obj.ID, nil)
-    if err != nil {
+func (r *sceneResolver) SceneMarkers(ctx context.Context, obj *models.Scene) (ret []*models.SceneMarker, err error) {
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        ret, err = repo.SceneMarker().FindBySceneID(obj.ID)
+        return err
+    }); err != nil {
         return nil, err
     }
-    var ret []*models.SceneMovie
-    for _, sm := range sceneMovies {
-        movie, err := qb.Find(sm.MovieID, nil)
+    return ret, nil
+}
+
+func (r *sceneResolver) Gallery(ctx context.Context, obj *models.Scene) (ret *models.Gallery, err error) {
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        ret, err = repo.Gallery().FindBySceneID(obj.ID)
+        return err
+    }); err != nil {
+        return nil, err
+    }
+    return ret, nil
+}
+
+func (r *sceneResolver) Studio(ctx context.Context, obj *models.Scene) (ret *models.Studio, err error) {
+    if !obj.StudioID.Valid {
+        return nil, nil
+    }
+
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        ret, err = repo.Studio().Find(int(obj.StudioID.Int64))
+        return err
+    }); err != nil {
+        return nil, err
+    }
+    return ret, nil
+}
+
+func (r *sceneResolver) Movies(ctx context.Context, obj *models.Scene) (ret []*models.SceneMovie, err error) {
+    if err := r.withTxn(ctx, func(repo models.Repository) error {
+        qb := repo.Scene()
+        mqb := repo.Movie()
+
+        sceneMovies, err := qb.GetMovies(obj.ID)
         if err != nil {
-            return nil, err
+            return err
         }
-        sceneIdx := sm.SceneIndex
-        sceneMovie := &models.SceneMovie{
-            Movie: movie,
+
+        for _, sm := range sceneMovies {
+            movie, err := mqb.Find(sm.MovieID)
+            if err != nil {
+                return err
+            }
+
+            sceneIdx := sm.SceneIndex
+            sceneMovie := &models.SceneMovie{
+                Movie: movie,
+            }
+
+            if sceneIdx.Valid {
+                var idx int
+                idx = int(sceneIdx.Int64)
+                sceneMovie.SceneIndex = &idx
+            }
+
+            ret = append(ret, sceneMovie)
         }
-        if sceneIdx.Valid {
-            var idx int
-            idx = int(sceneIdx.Int64)
-            sceneMovie.SceneIndex = &idx
-        }
-        ret = append(ret, sceneMovie)
+
+        return nil
+    }); err != nil {
+        return nil, err
     }
     return ret, nil
 }
 
-func (r *sceneResolver) Tags(ctx context.Context, obj *models.Scene) ([]*models.Tag, error) {
-    qb := models.NewTagQueryBuilder()
-    return qb.FindBySceneID(obj.ID, nil)
+func (r *sceneResolver) Tags(ctx context.Context, obj *models.Scene) (ret []*models.Tag, err error) {
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        ret, err = repo.Tag().FindBySceneID(obj.ID)
+        return err
+    }); err != nil {
+        return nil, err
+    }
+    return ret, nil
 }
 
-func (r *sceneResolver) Performers(ctx context.Context, obj *models.Scene) ([]*models.Performer, error) {
-    qb := models.NewPerformerQueryBuilder()
-    return qb.FindBySceneID(obj.ID, nil)
+func (r *sceneResolver) Performers(ctx context.Context, obj *models.Scene) (ret []*models.Performer, err error) {
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        ret, err = repo.Performer().FindBySceneID(obj.ID)
+        return err
+    }); err != nil {
+        return nil, err
+    }
+    return ret, nil
 }
 
-func (r *sceneResolver) StashIds(ctx context.Context, obj *models.Scene) ([]*models.StashID, error) {
-    qb := models.NewJoinsQueryBuilder()
-    return qb.GetSceneStashIDs(obj.ID)
+func (r *sceneResolver) StashIds(ctx context.Context, obj *models.Scene) (ret []*models.StashID, err error) {
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        ret, err = repo.Scene().GetStashIDs(obj.ID)
+        return err
+    }); err != nil {
+        return nil, err
+    }
+    return ret, nil
 }

View File

@@ -2,29 +2,47 @@ package api
 
 import (
     "context"
 
     "github.com/stashapp/stash/pkg/api/urlbuilders"
     "github.com/stashapp/stash/pkg/models"
 )
 
-func (r *sceneMarkerResolver) Scene(ctx context.Context, obj *models.SceneMarker) (*models.Scene, error) {
+func (r *sceneMarkerResolver) Scene(ctx context.Context, obj *models.SceneMarker) (ret *models.Scene, err error) {
     if !obj.SceneID.Valid {
         panic("Invalid scene id")
     }
-    qb := models.NewSceneQueryBuilder()
-    sceneID := int(obj.SceneID.Int64)
-    scene, err := qb.Find(sceneID)
-    return scene, err
+
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        sceneID := int(obj.SceneID.Int64)
+        ret, err = repo.Scene().Find(sceneID)
+        return err
+    }); err != nil {
+        return nil, err
+    }
+    return ret, nil
 }
 
-func (r *sceneMarkerResolver) PrimaryTag(ctx context.Context, obj *models.SceneMarker) (*models.Tag, error) {
-    qb := models.NewTagQueryBuilder()
-    tag, err := qb.Find(obj.PrimaryTagID, nil)
-    return tag, err
+func (r *sceneMarkerResolver) PrimaryTag(ctx context.Context, obj *models.SceneMarker) (ret *models.Tag, err error) {
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        ret, err = repo.Tag().Find(obj.PrimaryTagID)
+        return err
+    }); err != nil {
+        return nil, err
+    }
+    return ret, err
 }
 
-func (r *sceneMarkerResolver) Tags(ctx context.Context, obj *models.SceneMarker) ([]*models.Tag, error) {
-    qb := models.NewTagQueryBuilder()
-    return qb.FindBySceneMarkerID(obj.ID, nil)
+func (r *sceneMarkerResolver) Tags(ctx context.Context, obj *models.SceneMarker) (ret []*models.Tag, err error) {
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        ret, err = repo.Tag().FindBySceneMarkerID(obj.ID)
+        return err
+    }); err != nil {
+        return nil, err
+    }
+    return ret, err
 }
 
 func (r *sceneMarkerResolver) Stream(ctx context.Context, obj *models.SceneMarker) (string, error) {

View File

@@ -25,10 +25,12 @@ func (r *studioResolver) ImagePath(ctx context.Context, obj *models.Studio) (*st
     baseURL, _ := ctx.Value(BaseURLCtxKey).(string)
     imagePath := urlbuilders.NewStudioURLBuilder(baseURL, obj.ID).GetStudioImageURL()
 
-    qb := models.NewStudioQueryBuilder()
-    hasImage, err := qb.HasStudioImage(obj.ID)
-
-    if err != nil {
+    var hasImage bool
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        var err error
+        hasImage, err = repo.Studio().HasImage(obj.ID)
+        return err
+    }); err != nil {
         return nil, err
     }
 
@@ -40,27 +42,51 @@ func (r *studioResolver) ImagePath(ctx context.Context, obj *models.Studio) (*st
     return &imagePath, nil
 }
 
-func (r *studioResolver) SceneCount(ctx context.Context, obj *models.Studio) (*int, error) {
-    qb := models.NewSceneQueryBuilder()
-    res, err := qb.CountByStudioID(obj.ID)
+func (r *studioResolver) SceneCount(ctx context.Context, obj *models.Studio) (ret *int, err error) {
+    var res int
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        res, err = repo.Scene().CountByStudioID(obj.ID)
+        return err
+    }); err != nil {
+        return nil, err
+    }
+
     return &res, err
 }
 
-func (r *studioResolver) ParentStudio(ctx context.Context, obj *models.Studio) (*models.Studio, error) {
+func (r *studioResolver) ParentStudio(ctx context.Context, obj *models.Studio) (ret *models.Studio, err error) {
     if !obj.ParentID.Valid {
         return nil, nil
     }
-    qb := models.NewStudioQueryBuilder()
-    return qb.Find(int(obj.ParentID.Int64), nil)
+
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        ret, err = repo.Studio().Find(int(obj.ParentID.Int64))
+        return err
+    }); err != nil {
+        return nil, err
+    }
+    return ret, nil
 }
 
-func (r *studioResolver) ChildStudios(ctx context.Context, obj *models.Studio) ([]*models.Studio, error) {
-    qb := models.NewStudioQueryBuilder()
-    return qb.FindChildren(obj.ID, nil)
+func (r *studioResolver) ChildStudios(ctx context.Context, obj *models.Studio) (ret []*models.Studio, err error) {
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        ret, err = repo.Studio().FindChildren(obj.ID)
+        return err
+    }); err != nil {
+        return nil, err
+    }
+    return ret, nil
 }
 
-func (r *studioResolver) StashIds(ctx context.Context, obj *models.Studio) ([]*models.StashID, error) {
-    qb := models.NewJoinsQueryBuilder()
-    return qb.GetStudioStashIDs(obj.ID)
+func (r *studioResolver) StashIds(ctx context.Context, obj *models.Studio) (ret []*models.StashID, err error) {
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        ret, err = repo.Studio().GetStashIDs(obj.ID)
+        return err
+    }); err != nil {
+        return nil, err
+    }
+    return ret, nil
 }

View File

@@ -7,21 +7,27 @@ import (
     "github.com/stashapp/stash/pkg/models"
 )
 
-func (r *tagResolver) SceneCount(ctx context.Context, obj *models.Tag) (*int, error) {
-    qb := models.NewSceneQueryBuilder()
-    if obj == nil {
-        return nil, nil
+func (r *tagResolver) SceneCount(ctx context.Context, obj *models.Tag) (ret *int, err error) {
+    var count int
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        count, err = repo.Scene().CountByTagID(obj.ID)
+        return err
+    }); err != nil {
+        return nil, err
     }
-    count, err := qb.CountByTagID(obj.ID)
+
     return &count, err
 }
 
-func (r *tagResolver) SceneMarkerCount(ctx context.Context, obj *models.Tag) (*int, error) {
-    qb := models.NewSceneMarkerQueryBuilder()
-    if obj == nil {
-        return nil, nil
+func (r *tagResolver) SceneMarkerCount(ctx context.Context, obj *models.Tag) (ret *int, err error) {
+    var count int
+    if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
+        count, err = repo.SceneMarker().CountByTagID(obj.ID)
+        return err
+    }); err != nil {
+        return nil, err
     }
-    count, err := qb.CountByTagID(obj.ID)
+
     return &count, err
 }

View File

@@ -52,7 +52,7 @@ func (r *mutationResolver) ConfigureGeneral(ctx context.Context, input models.Co
     if input.VideoFileNamingAlgorithm != config.GetVideoFileNamingAlgorithm() {
         // validate changing VideoFileNamingAlgorithm
-        if err := manager.ValidateVideoFileNamingAlgorithm(input.VideoFileNamingAlgorithm); err != nil {
+        if err := manager.ValidateVideoFileNamingAlgorithm(r.txnManager, input.VideoFileNamingAlgorithm); err != nil {
             return makeConfigGeneralResult(), err
         }

View File

@@ -4,11 +4,10 @@ import (
     "context"
     "database/sql"
     "errors"
+    "fmt"
     "strconv"
     "time"
 
-    "github.com/jmoiron/sqlx"
-    "github.com/stashapp/stash/pkg/database"
     "github.com/stashapp/stash/pkg/manager"
     "github.com/stashapp/stash/pkg/models"
     "github.com/stashapp/stash/pkg/utils"
@@ -69,108 +68,102 @@ func (r *mutationResolver) GalleryCreate(ctx context.Context, input models.Galle
         newGallery.SceneID = sql.NullInt64{Valid: false}
     }
 
-    // Start the transaction and save the performer
-    tx := database.DB.MustBeginTx(ctx, nil)
-    qb := models.NewGalleryQueryBuilder()
-    jqb := models.NewJoinsQueryBuilder()
-    gallery, err := qb.Create(newGallery, tx)
-    if err != nil {
-        _ = tx.Rollback()
-        return nil, err
-    }
-
-    // Save the performers
-    var performerJoins []models.PerformersGalleries
-    for _, pid := range input.PerformerIds {
-        performerID, _ := strconv.Atoi(pid)
-        performerJoin := models.PerformersGalleries{
-            PerformerID: performerID,
-            GalleryID:   gallery.ID,
-        }
-        performerJoins = append(performerJoins, performerJoin)
-    }
-    if err := jqb.UpdatePerformersGalleries(gallery.ID, performerJoins, tx); err != nil {
-        return nil, err
-    }
-
-    // Save the tags
-    var tagJoins []models.GalleriesTags
-    for _, tid := range input.TagIds {
-        tagID, _ := strconv.Atoi(tid)
-        tagJoin := models.GalleriesTags{
-            GalleryID: gallery.ID,
-            TagID:     tagID,
-        }
-        tagJoins = append(tagJoins, tagJoin)
-    }
-    if err := jqb.UpdateGalleriesTags(gallery.ID, tagJoins, tx); err != nil {
-        return nil, err
-    }
-
-    // Commit
-    if err := tx.Commit(); err != nil {
+    // Start the transaction and save the gallery
+    var gallery *models.Gallery
+    if err := r.withTxn(ctx, func(repo models.Repository) error {
+        qb := repo.Gallery()
+        var err error
+        gallery, err = qb.Create(newGallery)
+        if err != nil {
+            return err
+        }
+
+        // Save the performers
+        if err := r.updateGalleryPerformers(qb, gallery.ID, input.PerformerIds); err != nil {
+            return err
+        }
+
+        // Save the tags
+        if err := r.updateGalleryTags(qb, gallery.ID, input.TagIds); err != nil {
+            return err
+        }
+
+        return nil
+    }); err != nil {
         return nil, err
     }
 
     return gallery, nil
 }
 
-func (r *mutationResolver) GalleryUpdate(ctx context.Context, input models.GalleryUpdateInput) (*models.Gallery, error) {
-    // Start the transaction and save the gallery
-    tx := database.DB.MustBeginTx(ctx, nil)
+func (r *mutationResolver) updateGalleryPerformers(qb models.GalleryReaderWriter, galleryID int, performerIDs []string) error {
+    ids, err := utils.StringSliceToIntSlice(performerIDs)
+    if err != nil {
+        return err
+    }
+    return qb.UpdatePerformers(galleryID, ids)
+}
+
+func (r *mutationResolver) updateGalleryTags(qb models.GalleryReaderWriter, galleryID int, tagIDs []string) error {
+    ids, err := utils.StringSliceToIntSlice(tagIDs)
+    if err != nil {
+        return err
+    }
+    return qb.UpdateTags(galleryID, ids)
+}
+
+func (r *mutationResolver) GalleryUpdate(ctx context.Context, input models.GalleryUpdateInput) (ret *models.Gallery, err error) {
     translator := changesetTranslator{
         inputMap: getUpdateInputMap(ctx),
     }
 
-    ret, err := r.galleryUpdate(input, translator, tx)
-    if err != nil {
-        _ = tx.Rollback()
-        return nil, err
-    }
-
-    // Commit
-    if err := tx.Commit(); err != nil {
+    // Start the transaction and save the gallery
+    if err := r.withTxn(ctx, func(repo models.Repository) error {
+        ret, err = r.galleryUpdate(input, translator, repo)
+        return err
+    }); err != nil {
         return nil, err
     }
 
     return ret, nil
 }
 
-func (r *mutationResolver) GalleriesUpdate(ctx context.Context, input []*models.GalleryUpdateInput) ([]*models.Gallery, error) {
-    // Start the transaction and save the gallery
-    tx := database.DB.MustBeginTx(ctx, nil)
+func (r *mutationResolver) GalleriesUpdate(ctx context.Context, input []*models.GalleryUpdateInput) (ret []*models.Gallery, err error) {
     inputMaps := getUpdateInputMaps(ctx)
 
-    var ret []*models.Gallery
-
-    for i, gallery := range input {
-        translator := changesetTranslator{
-            inputMap: inputMaps[i],
-        }
-
-        thisGallery, err := r.galleryUpdate(*gallery, translator, tx)
-        ret = append(ret, thisGallery)
-        if err != nil {
-            _ = tx.Rollback()
-            return nil, err
-        }
-    }
-
-    // Commit
-    if err := tx.Commit(); err != nil {
+    // Start the transaction and save the gallery
+    if err := r.withTxn(ctx, func(repo models.Repository) error {
+        for i, gallery := range input {
+            translator := changesetTranslator{
+                inputMap: inputMaps[i],
+            }
+
+            thisGallery, err := r.galleryUpdate(*gallery, translator, repo)
+            if err != nil {
+                return err
+            }
+
+            ret = append(ret, thisGallery)
+        }
+
+        return nil
+    }); err != nil {
         return nil, err
     }
 
     return ret, nil
 }
 
-func (r *mutationResolver) galleryUpdate(input models.GalleryUpdateInput, translator changesetTranslator, tx *sqlx.Tx) (*models.Gallery, error) {
-    qb := models.NewGalleryQueryBuilder()
+func (r *mutationResolver) galleryUpdate(input models.GalleryUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Gallery, error) {
+    qb := repo.Gallery()
 
     // Populate gallery from the input
-    galleryID, _ := strconv.Atoi(input.ID)
-    originalGallery, err := qb.Find(galleryID, nil)
+    galleryID, err := strconv.Atoi(input.ID)
+    if err != nil {
+        return nil, err
+    }
+
+    originalGallery, err := qb.Find(galleryID)
     if err != nil {
         return nil, err
     }
@@ -209,40 +202,21 @@ func (r *mutationResolver) galleryUpdate(input models.GalleryUpdateInput, transl
     // gallery scene is set from the scene only
 
-    jqb := models.NewJoinsQueryBuilder()
-    gallery, err := qb.UpdatePartial(updatedGallery, tx)
+    gallery, err := qb.UpdatePartial(updatedGallery)
     if err != nil {
         return nil, err
     }
 
     // Save the performers
     if translator.hasField("performer_ids") {
-        var performerJoins []models.PerformersGalleries
-        for _, pid := range input.PerformerIds {
-            performerID, _ := strconv.Atoi(pid)
-            performerJoin := models.PerformersGalleries{
-                PerformerID: performerID,
-                GalleryID:   galleryID,
-            }
-            performerJoins = append(performerJoins, performerJoin)
-        }
-        if err := jqb.UpdatePerformersGalleries(galleryID, performerJoins, tx); err != nil {
+        if err := r.updateGalleryPerformers(qb, galleryID, input.PerformerIds); err != nil {
             return nil, err
         }
     }
 
     // Save the tags
     if translator.hasField("tag_ids") {
-        var tagJoins []models.GalleriesTags
-        for _, tid := range input.TagIds {
-            tagID, _ := strconv.Atoi(tid)
-            tagJoin := models.GalleriesTags{
-                GalleryID: galleryID,
-                TagID:     tagID,
-            }
-            tagJoins = append(tagJoins, tagJoin)
-        }
-        if err := jqb.UpdateGalleriesTags(galleryID, tagJoins, tx); err != nil {
+        if err := r.updateGalleryTags(qb, galleryID, input.TagIds); err != nil {
             return nil, err
         }
     }
@@ -254,11 +228,6 @@ func (r *mutationResolver) BulkGalleryUpdate(ctx context.Context, input models.B
     // Populate gallery from the input
     updatedTime := time.Now()
 
-    // Start the transaction and save the gallery marker
-    tx := database.DB.MustBeginTx(ctx, nil)
-    qb := models.NewGalleryQueryBuilder()
-    jqb := models.NewJoinsQueryBuilder()
-
     translator := changesetTranslator{
         inputMap: getUpdateInputMap(ctx),
     }
@@ -277,181 +246,143 @@ func (r *mutationResolver) BulkGalleryUpdate(ctx context.Context, input models.B
     ret := []*models.Gallery{}
 
-    for _, galleryIDStr := range input.Ids {
-        galleryID, _ := strconv.Atoi(galleryIDStr)
-        updatedGallery.ID = galleryID
-
-        gallery, err := qb.UpdatePartial(updatedGallery, tx)
-        if err != nil {
-            _ = tx.Rollback()
-            return nil, err
-        }
-
-        ret = append(ret, gallery)
-
-        // Save the performers
-        if translator.hasField("performer_ids") {
-            performerIDs, err := adjustGalleryPerformerIDs(tx, galleryID, *input.PerformerIds)
-            if err != nil {
-                _ = tx.Rollback()
-                return nil, err
-            }
-
-            var performerJoins []models.PerformersGalleries
-            for _, performerID := range performerIDs {
-                performerJoin := models.PerformersGalleries{
-                    PerformerID: performerID,
-                    GalleryID:   galleryID,
-                }
-                performerJoins = append(performerJoins, performerJoin)
-            }
-            if err := jqb.UpdatePerformersGalleries(galleryID, performerJoins, tx); err != nil {
-                _ = tx.Rollback()
-                return nil, err
-            }
-        }
-
-        // Save the tags
-        if translator.hasField("tag_ids") {
-            tagIDs, err := adjustGalleryTagIDs(tx, galleryID, *input.TagIds)
-            if err != nil {
-                _ = tx.Rollback()
-                return nil, err
-            }
-
-            var tagJoins []models.GalleriesTags
-            for _, tagID := range tagIDs {
-                tagJoin := models.GalleriesTags{
-                    GalleryID: galleryID,
-                    TagID:     tagID,
-                }
-                tagJoins = append(tagJoins, tagJoin)
-            }
-            if err := jqb.UpdateGalleriesTags(galleryID, tagJoins, tx); err != nil {
-                _ = tx.Rollback()
-                return nil, err
-            }
-        }
-    }
-
-    // Commit
-    if err := tx.Commit(); err != nil {
+    // Start the transaction and save the galleries
+    if err := r.withTxn(ctx, func(repo models.Repository) error {
+        qb := repo.Gallery()
+
+        for _, galleryIDStr := range input.Ids {
+            galleryID, _ := strconv.Atoi(galleryIDStr)
+            updatedGallery.ID = galleryID
+
+            gallery, err := qb.UpdatePartial(updatedGallery)
+            if err != nil {
+                return err
+            }
+
+            ret = append(ret, gallery)
+
+            // Save the performers
+            if translator.hasField("performer_ids") {
+                performerIDs, err := adjustGalleryPerformerIDs(qb, galleryID, *input.PerformerIds)
+                if err != nil {
+                    return err
+                }
+
+                if err := qb.UpdatePerformers(galleryID, performerIDs); err != nil {
+                    return err
+                }
+            }
+
+            // Save the tags
+            if translator.hasField("tag_ids") {
+                tagIDs, err := adjustGalleryTagIDs(qb, galleryID, *input.TagIds)
+                if err != nil {
+                    return err
+                }
+
+                if err := qb.UpdateTags(galleryID, tagIDs); err != nil {
+                    return err
+                }
+            }
+        }
+
+        return nil
+    }); err != nil {
         return nil, err
     }
 
     return ret, nil
 }
 
-func adjustGalleryPerformerIDs(tx *sqlx.Tx, galleryID int, ids models.BulkUpdateIds) ([]int, error) {
-    var ret []int
-
-    jqb := models.NewJoinsQueryBuilder()
-    if ids.Mode == models.BulkUpdateIDModeAdd || ids.Mode == models.BulkUpdateIDModeRemove {
-        // adding to the joins
-        performerJoins, err := jqb.GetGalleryPerformers(galleryID, tx)
-        if err != nil {
-            return nil, err
-        }
-        for _, join := range performerJoins {
-            ret = append(ret, join.PerformerID)
-        }
-    }
-
-    return adjustIDs(ret, ids), nil
-}
-
-func adjustGalleryTagIDs(tx *sqlx.Tx, galleryID int, ids models.BulkUpdateIds) ([]int, error) {
-    var ret []int
-
-    jqb := models.NewJoinsQueryBuilder()
-    if ids.Mode == models.BulkUpdateIDModeAdd || ids.Mode == models.BulkUpdateIDModeRemove {
-        // adding to the joins
-        tagJoins, err := jqb.GetGalleryTags(galleryID, tx)
-        if err != nil {
-            return nil, err
-        }
-        for _, join := range tagJoins {
-            ret = append(ret, join.TagID)
-        }
-    }
-
+func adjustGalleryPerformerIDs(qb models.GalleryReader, galleryID int, ids models.BulkUpdateIds) (ret []int, err error) {
+    ret, err = qb.GetPerformerIDs(galleryID)
+    if err != nil {
+        return nil, err
+    }
+
+    return adjustIDs(ret, ids), nil
+}
+
+func adjustGalleryTagIDs(qb models.GalleryReader, galleryID int, ids models.BulkUpdateIds) (ret []int, err error) {
+    ret, err = qb.GetTagIDs(galleryID)
+    if err != nil {
+        return nil, err
+    }
+
     return adjustIDs(ret, ids), nil
 }
 
 func (r *mutationResolver) GalleryDestroy(ctx context.Context, input models.GalleryDestroyInput) (bool, error) {
-    qb := models.NewGalleryQueryBuilder()
-    iqb := models.NewImageQueryBuilder()
-    tx := database.DB.MustBeginTx(ctx, nil)
-
+    galleryIDs, err := utils.StringSliceToIntSlice(input.Ids)
+    if err != nil {
+        return false, err
+    }
+
     var galleries []*models.Gallery
     var imgsToPostProcess []*models.Image
     var imgsToDelete []*models.Image
-    for _, id := range input.Ids {
-        galleryID, _ := strconv.Atoi(id)
-
-        gallery, err := qb.Find(galleryID, tx)
-        if gallery != nil {
-            galleries = append(galleries, gallery)
-        }
-
-        err = qb.Destroy(galleryID, tx)
-        if err != nil {
-            tx.Rollback()
-            return false, err
-        }
-
-        // if this is a zip-based gallery, delete the images as well
-        if gallery.Zip {
-            imgs, err := iqb.FindByGalleryID(galleryID)
-            if err != nil {
-                tx.Rollback()
-                return false, err
-            }
-
-            for _, img := range imgs {
-                err = iqb.Destroy(img.ID, tx)
-                if err != nil {
-                    tx.Rollback()
-                    return false, err
-                }
-
-                imgsToPostProcess = append(imgsToPostProcess, img)
-            }
-        } else if input.DeleteFile != nil && *input.DeleteFile {
-            // Delete image if it is only attached to this gallery
-            imgs, err := iqb.FindByGalleryID(galleryID)
-            if err != nil {
-                tx.Rollback()
-                return false, err
-            }
-
-            for _, img := range imgs {
-                imgGalleries, err := qb.FindByImageID(img.ID, tx)
-                if err != nil {
-                    tx.Rollback()
-                    return false, err
-                }
-
-                if len(imgGalleries) == 0 {
-                    err = iqb.Destroy(img.ID, tx)
-                    if err != nil {
-                        tx.Rollback()
-                        return false, err
-                    }
-
-                    imgsToDelete = append(imgsToDelete, img)
-                    imgsToPostProcess = append(imgsToPostProcess, img)
-                }
-            }
-        }
-    }
-
-    if err := tx.Commit(); err != nil {
+    if err := r.withTxn(ctx, func(repo models.Repository) error {
+        qb := repo.Gallery()
+        iqb := repo.Image()
+
+        for _, id := range galleryIDs {
+            gallery, err := qb.Find(id)
+            if err != nil {
+                return err
+            }
+
+            if gallery == nil {
+                return fmt.Errorf("gallery with id %d not found", id)
+            }
+
+            galleries = append(galleries, gallery)
+
+            // if this is a zip-based gallery, delete the images as well first
+            if gallery.Zip {
+                imgs, err := iqb.FindByGalleryID(id)
+                if err != nil {
+                    return err
+                }
+
+                for _, img := range imgs {
+                    if err := iqb.Destroy(img.ID); err != nil {
+                        return err
+                    }
+
+                    imgsToPostProcess = append(imgsToPostProcess, img)
+                }
+            } else if input.DeleteFile != nil && *input.DeleteFile {
+                // Delete image if it is only attached to this gallery
+                imgs, err := iqb.FindByGalleryID(id)
+                if err != nil {
+                    return err
+                }
+
+                for _, img := range imgs {
+                    imgGalleries, err := qb.FindByImageID(img.ID)
+                    if err != nil {
+                        return err
+                    }
+
+                    if len(imgGalleries) == 0 {
+                        if err := iqb.Destroy(img.ID); err != nil {
+                            return err
+                        }
+
+                        imgsToDelete = append(imgsToDelete, img)
+                        imgsToPostProcess = append(imgsToPostProcess, img)
+                    }
+                }
+            }
+
+            if err := qb.Destroy(id); err != nil {
+                return err
+            }
+        }
+
+        return nil
+    }); err != nil {
         return false, err
     }
@@ -479,34 +410,39 @@ func (r *mutationResolver) GalleryDestroy(ctx context.Context, input models.Gall
 }
 
 func (r *mutationResolver) AddGalleryImages(ctx context.Context, input models.GalleryAddInput) (bool, error) {
-    galleryID, _ := strconv.Atoi(input.GalleryID)
-    qb := models.NewGalleryQueryBuilder()
-    gallery, err := qb.Find(galleryID, nil)
+    galleryID, err := strconv.Atoi(input.GalleryID)
     if err != nil {
         return false, err
     }
 
-    if gallery == nil {
-        return false, errors.New("gallery not found")
-    }
-
-    if gallery.Zip {
-        return false, errors.New("cannot modify zip gallery images")
-    }
-
-    jqb := models.NewJoinsQueryBuilder()
-    tx := database.DB.MustBeginTx(ctx, nil)
-    for _, id := range input.ImageIds {
-        imageID, _ := strconv.Atoi(id)
-        _, err := jqb.AddImageGallery(imageID, galleryID, tx)
-        if err != nil {
-            tx.Rollback()
-            return false, err
-        }
-    }
-
-    if err := tx.Commit(); err != nil {
+    imageIDs, err := utils.StringSliceToIntSlice(input.ImageIds)
+    if err != nil {
+        return false, err
+    }
+
+    if err := r.withTxn(ctx, func(repo models.Repository) error {
+        qb := repo.Gallery()
+        gallery, err := qb.Find(galleryID)
+        if err != nil {
+            return err
+        }
+
+        if gallery == nil {
+            return errors.New("gallery not found")
+        }
+
+        if gallery.Zip {
+            return errors.New("cannot modify zip gallery images")
+        }
+
+        newIDs, err := qb.GetImageIDs(galleryID)
+        if err != nil {
+            return err
+        }
+
+        newIDs = utils.IntAppendUniques(newIDs, imageIDs)
+        return qb.UpdateImages(galleryID, newIDs)
+    }); err != nil {
         return false, err
     }
@@ -514,34 +450,39 @@ func (r *mutationResolver) AddGalleryImages(ctx context.Context, input models.Ga
 }
 
 func (r *mutationResolver) RemoveGalleryImages(ctx context.Context, input models.GalleryRemoveInput) (bool, error) {
-    galleryID, _ := strconv.Atoi(input.GalleryID)
-    qb := models.NewGalleryQueryBuilder()
-    gallery, err := qb.Find(galleryID, nil)
+    galleryID, err := strconv.Atoi(input.GalleryID)
     if err != nil {
         return false, err
     }
 
-    if gallery == nil {
-        return false, errors.New("gallery not found")
-    }
-
-    if gallery.Zip {
-        return false, errors.New("cannot modify zip gallery images")
-    }
-
-    jqb := models.NewJoinsQueryBuilder()
-    tx := database.DB.MustBeginTx(ctx, nil)
-    for _, id := range input.ImageIds {
-        imageID, _ := strconv.Atoi(id)
-        _, err := jqb.RemoveImageGallery(imageID, galleryID, tx)
-        if err != nil {
-            tx.Rollback()
-            return false, err
-        }
-    }
-
-    if err := tx.Commit(); err != nil {
+    imageIDs, err := utils.StringSliceToIntSlice(input.ImageIds)
+    if err != nil {
+        return false, err
+    }
+
+    if err := r.withTxn(ctx, func(repo models.Repository) error {
+        qb := repo.Gallery()
+        gallery, err := qb.Find(galleryID)
+        if err != nil {
+            return err
+        }
+
+        if gallery == nil {
+            return errors.New("gallery not found")
+        }
+
+        if gallery.Zip {
+            return errors.New("cannot modify zip gallery images")
+        }
+
+        newIDs, err := qb.GetImageIDs(galleryID)
+        if err != nil {
+            return err
+        }
+
+        newIDs = utils.IntExclude(newIDs, imageIDs)
+        return qb.UpdateImages(galleryID, newIDs)
+    }); err != nil {
        return false, err
     }

View File

@@ -2,71 +2,63 @@ package api
import ( import (
"context" "context"
"fmt"
"strconv" "strconv"
"time" "time"
"github.com/jmoiron/sqlx"
"github.com/stashapp/stash/pkg/database"
"github.com/stashapp/stash/pkg/manager" "github.com/stashapp/stash/pkg/manager"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/utils"
) )
func (r *mutationResolver) ImageUpdate(ctx context.Context, input models.ImageUpdateInput) (*models.Image, error) { func (r *mutationResolver) ImageUpdate(ctx context.Context, input models.ImageUpdateInput) (ret *models.Image, err error) {
// Start the transaction and save the image
tx := database.DB.MustBeginTx(ctx, nil)
translator := changesetTranslator{ translator := changesetTranslator{
inputMap: getUpdateInputMap(ctx), inputMap: getUpdateInputMap(ctx),
} }
ret, err := r.imageUpdate(input, translator, tx) // Start the transaction and save the image
if err := r.withTxn(ctx, func(repo models.Repository) error {
if err != nil { ret, err = r.imageUpdate(input, translator, repo)
_ = tx.Rollback() return err
return nil, err }); err != nil {
}
// Commit
if err := tx.Commit(); err != nil {
return nil, err return nil, err
} }
return ret, nil return ret, nil
} }
func (r *mutationResolver) ImagesUpdate(ctx context.Context, input []*models.ImageUpdateInput) ([]*models.Image, error) { func (r *mutationResolver) ImagesUpdate(ctx context.Context, input []*models.ImageUpdateInput) (ret []*models.Image, err error) {
// Start the transaction and save the image
tx := database.DB.MustBeginTx(ctx, nil)
inputMaps := getUpdateInputMaps(ctx) inputMaps := getUpdateInputMaps(ctx)
var ret []*models.Image // Start the transaction and save the image
if err := r.withTxn(ctx, func(repo models.Repository) error {
for i, image := range input {
translator := changesetTranslator{
inputMap: inputMaps[i],
}
for i, image := range input { thisImage, err := r.imageUpdate(*image, translator, repo)
translator := changesetTranslator{ if err != nil {
inputMap: inputMaps[i], return err
}
ret = append(ret, thisImage)
} }
thisImage, err := r.imageUpdate(*image, translator, tx) return nil
ret = append(ret, thisImage) }); err != nil {
if err != nil {
_ = tx.Rollback()
return nil, err
}
}
// Commit
if err := tx.Commit(); err != nil {
return nil, err return nil, err
} }
return ret, nil return ret, nil
} }
func (r *mutationResolver) imageUpdate(input models.ImageUpdateInput, translator changesetTranslator, tx *sqlx.Tx) (*models.Image, error) { func (r *mutationResolver) imageUpdate(input models.ImageUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Image, error) {
// Populate image from the input // Populate image from the input
imageID, _ := strconv.Atoi(input.ID) imageID, err := strconv.Atoi(input.ID)
if err != nil {
return nil, err
}
updatedTime := time.Now() updatedTime := time.Now()
updatedImage := models.ImagePartial{ updatedImage := models.ImagePartial{
@@ -79,43 +71,28 @@ func (r *mutationResolver) imageUpdate(input models.ImageUpdateInput, translator
updatedImage.StudioID = translator.nullInt64FromString(input.StudioID, "studio_id") updatedImage.StudioID = translator.nullInt64FromString(input.StudioID, "studio_id")
updatedImage.Organized = input.Organized updatedImage.Organized = input.Organized
qb := models.NewImageQueryBuilder() qb := repo.Image()
jqb := models.NewJoinsQueryBuilder() image, err := qb.Update(updatedImage)
image, err := qb.Update(updatedImage, tx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// don't set the galleries directly. Use add/remove gallery images interface instead if translator.hasField("gallery_ids") {
if err := r.updateImageGalleries(qb, imageID, input.GalleryIds); err != nil {
return nil, err
}
}
// Save the performers // Save the performers
if translator.hasField("performer_ids") { if translator.hasField("performer_ids") {
var performerJoins []models.PerformersImages if err := r.updateImagePerformers(qb, imageID, input.PerformerIds); err != nil {
for _, pid := range input.PerformerIds {
performerID, _ := strconv.Atoi(pid)
performerJoin := models.PerformersImages{
PerformerID: performerID,
ImageID: imageID,
}
performerJoins = append(performerJoins, performerJoin)
}
if err := jqb.UpdatePerformersImages(imageID, performerJoins, tx); err != nil {
return nil, err return nil, err
} }
} }
// Save the tags // Save the tags
if translator.hasField("tag_ids") { if translator.hasField("tag_ids") {
var tagJoins []models.ImagesTags if err := r.updateImageTags(qb, imageID, input.TagIds); err != nil {
for _, tid := range input.TagIds {
tagID, _ := strconv.Atoi(tid)
tagJoin := models.ImagesTags{
ImageID: imageID,
TagID: tagID,
}
tagJoins = append(tagJoins, tagJoin)
}
if err := jqb.UpdateImagesTags(imageID, tagJoins, tx); err != nil {
return nil, err return nil, err
} }
} }
@@ -123,15 +100,39 @@ func (r *mutationResolver) imageUpdate(input models.ImageUpdateInput, translator
return image, nil return image, nil
} }
func (r *mutationResolver) BulkImageUpdate(ctx context.Context, input models.BulkImageUpdateInput) ([]*models.Image, error) { func (r *mutationResolver) updateImageGalleries(qb models.ImageReaderWriter, imageID int, galleryIDs []string) error {
ids, err := utils.StringSliceToIntSlice(galleryIDs)
if err != nil {
return err
}
return qb.UpdateGalleries(imageID, ids)
}
func (r *mutationResolver) updateImagePerformers(qb models.ImageReaderWriter, imageID int, performerIDs []string) error {
ids, err := utils.StringSliceToIntSlice(performerIDs)
if err != nil {
return err
}
return qb.UpdatePerformers(imageID, ids)
}
func (r *mutationResolver) updateImageTags(qb models.ImageReaderWriter, imageID int, tagsIDs []string) error {
ids, err := utils.StringSliceToIntSlice(tagsIDs)
if err != nil {
return err
}
return qb.UpdateTags(imageID, ids)
}
func (r *mutationResolver) BulkImageUpdate(ctx context.Context, input models.BulkImageUpdateInput) (ret []*models.Image, err error) {
imageIDs, err := utils.StringSliceToIntSlice(input.Ids)
if err != nil {
return nil, err
}
// Populate image from the input // Populate image from the input
updatedTime := time.Now() updatedTime := time.Now()
// Start the transaction and save the image marker
tx := database.DB.MustBeginTx(ctx, nil)
qb := models.NewImageQueryBuilder()
jqb := models.NewJoinsQueryBuilder()
updatedImage := models.ImagePartial{ updatedImage := models.ImagePartial{
UpdatedAt: &models.SQLiteTimestamp{Timestamp: updatedTime}, UpdatedAt: &models.SQLiteTimestamp{Timestamp: updatedTime},
} }
@@ -145,168 +146,113 @@ func (r *mutationResolver) BulkImageUpdate(ctx context.Context, input models.Bul
    updatedImage.StudioID = translator.nullInt64FromString(input.StudioID, "studio_id")
    updatedImage.Organized = input.Organized

    ret := []*models.Image{}

    // Start the transaction and save the image marker
    if err := r.withTxn(ctx, func(repo models.Repository) error {
        qb := repo.Image()

        for _, imageID := range imageIDs {
            updatedImage.ID = imageID

            image, err := qb.Update(updatedImage)
            if err != nil {
                return err
            }

            ret = append(ret, image)

            // Save the galleries
            if translator.hasField("gallery_ids") {
                galleryIDs, err := adjustImageGalleryIDs(qb, imageID, *input.GalleryIds)
                if err != nil {
                    return err
                }

                if err := qb.UpdateGalleries(imageID, galleryIDs); err != nil {
                    return err
                }
            }

            // Save the performers
            if translator.hasField("performer_ids") {
                performerIDs, err := adjustImagePerformerIDs(qb, imageID, *input.PerformerIds)
                if err != nil {
                    return err
                }

                if err := qb.UpdatePerformers(imageID, performerIDs); err != nil {
                    return err
                }
            }

            // Save the tags
            if translator.hasField("tag_ids") {
                tagIDs, err := adjustImageTagIDs(qb, imageID, *input.TagIds)
                if err != nil {
                    return err
                }

                if err := qb.UpdateTags(imageID, tagIDs); err != nil {
                    return err
                }
            }
        }

        return nil
    }); err != nil {
        return nil, err
    }

    return ret, nil
}

func adjustImageGalleryIDs(qb models.ImageReader, imageID int, ids models.BulkUpdateIds) (ret []int, err error) {
    ret, err = qb.GetGalleryIDs(imageID)
    if err != nil {
        return nil, err
    }

    return adjustIDs(ret, ids), nil
}

func adjustImagePerformerIDs(qb models.ImageReader, imageID int, ids models.BulkUpdateIds) (ret []int, err error) {
    ret, err = qb.GetPerformerIDs(imageID)
    if err != nil {
        return nil, err
    }

    return adjustIDs(ret, ids), nil
}

func adjustImageTagIDs(qb models.ImageReader, imageID int, ids models.BulkUpdateIds) (ret []int, err error) {
    ret, err = qb.GetTagIDs(imageID)
    if err != nil {
        return nil, err
    }

    return adjustIDs(ret, ids), nil
}

func (r *mutationResolver) ImageDestroy(ctx context.Context, input models.ImageDestroyInput) (ret bool, err error) {
    imageID, err := strconv.Atoi(input.ID)
    if err != nil {
        return false, err
    }

    var image *models.Image
    if err := r.withTxn(ctx, func(repo models.Repository) error {
        qb := repo.Image()

        image, err = qb.Find(imageID)
        if err != nil {
            return err
        }

        if image == nil {
            return fmt.Errorf("image with id %d not found", imageID)
        }

        return qb.Destroy(imageID)
    }); err != nil {
        return false, err
    }
@@ -325,27 +271,35 @@ func (r *mutationResolver) ImageDestroy(ctx context.Context, input models.ImageD
    return true, nil
}

func (r *mutationResolver) ImagesDestroy(ctx context.Context, input models.ImagesDestroyInput) (ret bool, err error) {
    imageIDs, err := utils.StringSliceToIntSlice(input.Ids)
    if err != nil {
        return false, err
    }

    var images []*models.Image
    if err := r.withTxn(ctx, func(repo models.Repository) error {
        qb := repo.Image()

        for _, imageID := range imageIDs {
            image, err := qb.Find(imageID)
            if err != nil {
                return err
            }

            if image == nil {
                return fmt.Errorf("image with id %d not found", imageID)
            }

            images = append(images, image)

            if err := qb.Destroy(imageID); err != nil {
                return err
            }
        }

        return nil
    }); err != nil {
        return false, err
    }
@@ -366,62 +320,56 @@ func (r *mutationResolver) ImagesDestroy(ctx context.Context, input models.Image
    return true, nil
}

func (r *mutationResolver) ImageIncrementO(ctx context.Context, id string) (ret int, err error) {
    imageID, err := strconv.Atoi(id)
    if err != nil {
        return 0, err
    }

    if err := r.withTxn(ctx, func(repo models.Repository) error {
        qb := repo.Image()
        ret, err = qb.IncrementOCounter(imageID)
        return err
    }); err != nil {
        return 0, err
    }

    return ret, nil
}

func (r *mutationResolver) ImageDecrementO(ctx context.Context, id string) (ret int, err error) {
    imageID, err := strconv.Atoi(id)
    if err != nil {
        return 0, err
    }

    if err := r.withTxn(ctx, func(repo models.Repository) error {
        qb := repo.Image()
        ret, err = qb.DecrementOCounter(imageID)
        return err
    }); err != nil {
        return 0, err
    }

    return ret, nil
}

func (r *mutationResolver) ImageResetO(ctx context.Context, id string) (ret int, err error) {
    imageID, err := strconv.Atoi(id)
    if err != nil {
        return 0, err
    }

    if err := r.withTxn(ctx, func(repo models.Repository) error {
        qb := repo.Image()
        ret, err = qb.ResetOCounter(imageID)
        return err
    }); err != nil {
        return 0, err
    }

    return ret, nil
}

View File

@@ -6,7 +6,6 @@ import (
    "strconv"
    "time"

    "github.com/stashapp/stash/pkg/models"
    "github.com/stashapp/stash/pkg/utils"
)
@@ -85,24 +84,23 @@ func (r *mutationResolver) MovieCreate(ctx context.Context, input models.MovieCr
    }

    // Start the transaction and save the movie
    var movie *models.Movie
    if err := r.withTxn(ctx, func(repo models.Repository) error {
        qb := repo.Movie()
        movie, err = qb.Create(newMovie)
        if err != nil {
            return err
        }

        // update image table
        if len(frontimageData) > 0 {
            if err := qb.UpdateImages(movie.ID, frontimageData, backimageData); err != nil {
                return err
            }
        }

        return nil
    }); err != nil {
        return nil, err
    }
@@ -111,7 +109,10 @@ func (r *mutationResolver) MovieCreate(ctx context.Context, input models.MovieCr
func (r *mutationResolver) MovieUpdate(ctx context.Context, input models.MovieUpdateInput) (*models.Movie, error) {
    // Populate movie from the input
    movieID, err := strconv.Atoi(input.ID)
    if err != nil {
        return nil, err
    }

    updatedMovie := models.MoviePartial{
        ID: movieID,
@@ -123,7 +124,6 @@ func (r *mutationResolver) MovieUpdate(ctx context.Context, input models.MovieUp
    }

    var frontimageData []byte

    frontImageIncluded := translator.hasField("front_image")
    if input.FrontImage != nil {
        _, frontimageData, err = utils.ProcessBase64Image(*input.FrontImage)
@@ -157,53 +157,49 @@ func (r *mutationResolver) MovieUpdate(ctx context.Context, input models.MovieUp
    updatedMovie.URL = translator.nullString(input.URL, "url")

    // Start the transaction and save the movie
    var movie *models.Movie
    if err := r.withTxn(ctx, func(repo models.Repository) error {
        qb := repo.Movie()
        movie, err = qb.Update(updatedMovie)
        if err != nil {
            return err
        }

        // update image table
        if frontImageIncluded || backImageIncluded {
            if !frontImageIncluded {
                frontimageData, err = qb.GetFrontImage(updatedMovie.ID)
                if err != nil {
                    return err
                }
            }

            if !backImageIncluded {
                backimageData, err = qb.GetBackImage(updatedMovie.ID)
                if err != nil {
                    return err
                }
            }

            if len(frontimageData) == 0 && len(backimageData) == 0 {
                // both images are being nulled. Destroy them.
                if err := qb.DestroyImages(movie.ID); err != nil {
                    return err
                }
            } else {
                // HACK - if front image is null and back image is not null, then set the front image
                // to the default image since we can't have a null front image and a non-null back image
                if frontimageData == nil && backimageData != nil {
                    _, frontimageData, _ = utils.ProcessBase64Image(models.DefaultMovieImage)
                }

                if err := qb.UpdateImages(movie.ID, frontimageData, backimageData); err != nil {
                    return err
                }
            }
        }

        return nil
    }); err != nil {
        return nil, err
    }
@@ -211,28 +207,35 @@ func (r *mutationResolver) MovieUpdate(ctx context.Context, input models.MovieUp
}

func (r *mutationResolver) MovieDestroy(ctx context.Context, input models.MovieDestroyInput) (bool, error) {
    id, err := strconv.Atoi(input.ID)
    if err != nil {
        return false, err
    }

    if err := r.withTxn(ctx, func(repo models.Repository) error {
        return repo.Movie().Destroy(id)
    }); err != nil {
        return false, err
    }

    return true, nil
}

func (r *mutationResolver) MoviesDestroy(ctx context.Context, movieIDs []string) (bool, error) {
    ids, err := utils.StringSliceToIntSlice(movieIDs)
    if err != nil {
        return false, err
    }

    if err := r.withTxn(ctx, func(repo models.Repository) error {
        qb := repo.Movie()
        for _, id := range ids {
            if err := qb.Destroy(id); err != nil {
                return err
            }
        }

        return nil
    }); err != nil {
        return false, err
    }

    return true, nil
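MoviesDestroy, like the other *sDestroy resolvers in this commit, converts the incoming ID strings up front with utils.StringSliceToIntSlice and fails fast on a malformed ID before any transaction is opened. The helper's body is outside this excerpt; a plausible implementation is sketched below for reference only.

package utilsketch

import "strconv"

// StringSliceToIntSlice converts a slice of string IDs into ints, returning
// an error as soon as one element fails to parse.
func StringSliceToIntSlice(ss []string) ([]int, error) {
    ret := make([]int, 0, len(ss))
    for _, v := range ss {
        n, err := strconv.Atoi(v)
        if err != nil {
            return nil, err
        }
        ret = append(ret, n)
    }
    return ret, nil
}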

View File

@@ -6,7 +6,6 @@ import (
    "strconv"
    "time"

    "github.com/stashapp/stash/pkg/models"
    "github.com/stashapp/stash/pkg/utils"
)
@@ -86,41 +85,32 @@ func (r *mutationResolver) PerformerCreate(ctx context.Context, input models.Per
    }

    // Start the transaction and save the performer
    var performer *models.Performer
    if err := r.withTxn(ctx, func(repo models.Repository) error {
        qb := repo.Performer()
        performer, err = qb.Create(newPerformer)
        if err != nil {
            return err
        }

        // update image table
        if len(imageData) > 0 {
            if err := qb.UpdateImage(performer.ID, imageData); err != nil {
                return err
            }
        }

        // Save the stash_ids
        if input.StashIds != nil {
            stashIDJoins := models.StashIDsFromInput(input.StashIds)
            if err := qb.UpdateStashIDs(performer.ID, stashIDJoins); err != nil {
                return err
            }
        }

        return nil
    }); err != nil {
        return nil, err
    }
@@ -183,47 +173,38 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input models.Per
    updatedPerformer.Favorite = translator.nullBool(input.Favorite, "favorite")

    // Start the transaction and save the performer
    var performer *models.Performer
    if err := r.withTxn(ctx, func(repo models.Repository) error {
        qb := repo.Performer()
        var err error
        performer, err = qb.Update(updatedPerformer)
        if err != nil {
            return err
        }

        // update image table
        if len(imageData) > 0 {
            if err := qb.UpdateImage(performer.ID, imageData); err != nil {
                return err
            }
        } else if imageIncluded {
            // must be unsetting
            if err := qb.DestroyImage(performer.ID); err != nil {
                return err
            }
        }

        // Save the stash_ids
        if translator.hasField("stash_ids") {
            stashIDJoins := models.StashIDsFromInput(input.StashIds)
            if err := qb.UpdateStashIDs(performerID, stashIDJoins); err != nil {
                return err
            }
        }

        return nil
    }); err != nil {
        return nil, err
    }
@@ -231,28 +212,35 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input models.Per
}

func (r *mutationResolver) PerformerDestroy(ctx context.Context, input models.PerformerDestroyInput) (bool, error) {
    id, err := strconv.Atoi(input.ID)
    if err != nil {
        return false, err
    }

    if err := r.withTxn(ctx, func(repo models.Repository) error {
        return repo.Performer().Destroy(id)
    }); err != nil {
        return false, err
    }

    return true, nil
}

func (r *mutationResolver) PerformersDestroy(ctx context.Context, performerIDs []string) (bool, error) {
    ids, err := utils.StringSliceToIntSlice(performerIDs)
    if err != nil {
        return false, err
    }

    if err := r.withTxn(ctx, func(repo models.Repository) error {
        qb := repo.Performer()
        for _, id := range ids {
            if err := qb.Destroy(id); err != nil {
                return err
            }
        }

        return nil
    }); err != nil {
        return false, err
    }

    return true, nil
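PerformerCreate and PerformerUpdate now build their stash ID rows through models.StashIDsFromInput rather than an inline loop. Its definition is elsewhere in the commit; judging by the loop it replaces, the conversion is roughly the following sketch (the local StashIDInput type stands in for the generated GraphQL input and is an assumption of this sketch).

package modelsketch

// StashID mirrors the stored join row.
type StashID struct {
    StashID  string
    Endpoint string
}

// StashIDInput stands in for the GraphQL input type assumed here.
type StashIDInput struct {
    StashID  string
    Endpoint string
}

// StashIDsFromInput flattens the GraphQL input into the model rows that
// UpdateStashIDs persists.
func StashIDsFromInput(input []*StashIDInput) []StashID {
    var ret []StashID
    for _, id := range input {
        ret = append(ret, StashID{
            StashID:  id.StashID,
            Endpoint: id.Endpoint,
        })
    }
    return ret
}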

View File

@@ -3,73 +3,64 @@ package api
import (
    "context"
    "database/sql"
    "fmt"
    "strconv"
    "time"

    "github.com/stashapp/stash/pkg/manager"
    "github.com/stashapp/stash/pkg/manager/config"
    "github.com/stashapp/stash/pkg/models"
    "github.com/stashapp/stash/pkg/utils"
)

func (r *mutationResolver) SceneUpdate(ctx context.Context, input models.SceneUpdateInput) (ret *models.Scene, err error) {
    translator := changesetTranslator{
        inputMap: getUpdateInputMap(ctx),
    }

    // Start the transaction and save the scene
    if err := r.withTxn(ctx, func(repo models.Repository) error {
        ret, err = r.sceneUpdate(input, translator, repo)
        return err
    }); err != nil {
        return nil, err
    }

    return ret, nil
}

func (r *mutationResolver) ScenesUpdate(ctx context.Context, input []*models.SceneUpdateInput) (ret []*models.Scene, err error) {
    inputMaps := getUpdateInputMaps(ctx)

    // Start the transaction and save the scene
    if err := r.withTxn(ctx, func(repo models.Repository) error {
        for i, scene := range input {
            translator := changesetTranslator{
                inputMap: inputMaps[i],
            }

            thisScene, err := r.sceneUpdate(*scene, translator, repo)
            ret = append(ret, thisScene)
            if err != nil {
                return err
            }
        }

        return nil
    }); err != nil {
        return nil, err
    }

    return ret, nil
}

func (r *mutationResolver) sceneUpdate(input models.SceneUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Scene, error) {
    // Populate scene from the input
    sceneID, err := strconv.Atoi(input.ID)
    if err != nil {
        return nil, err
    }

    var coverImageData []byte
@@ -97,24 +88,23 @@ func (r *mutationResolver) sceneUpdate(input models.SceneUpdateInput, translator
        // update the cover after updating the scene
    }

    qb := repo.Scene()
    scene, err := qb.Update(updatedScene)
    if err != nil {
        return nil, err
    }

    // update cover table
    if len(coverImageData) > 0 {
        if err := qb.UpdateCover(sceneID, coverImageData); err != nil {
            return nil, err
        }
    }

    // Clear the existing gallery value
    if translator.hasField("gallery_id") {
        gqb := repo.Gallery()
        err = gqb.ClearGalleryId(sceneID)
        if err != nil {
            return nil, err
        }
@@ -122,13 +112,12 @@ func (r *mutationResolver) sceneUpdate(input models.SceneUpdateInput, translator
        if input.GalleryID != nil {
            // Save the gallery
            galleryID, _ := strconv.Atoi(*input.GalleryID)
            updatedGallery := models.GalleryPartial{
                ID:        galleryID,
                SceneID:   &sql.NullInt64{Int64: int64(sceneID), Valid: true},
                UpdatedAt: &models.SQLiteTimestamp{Timestamp: updatedTime},
            }

            _, err := gqb.UpdatePartial(updatedGallery)
            if err != nil {
                return nil, err
            }
@@ -137,59 +126,29 @@ func (r *mutationResolver) sceneUpdate(input models.SceneUpdateInput, translator
    // Save the performers
    if translator.hasField("performer_ids") {
        if err := r.updateScenePerformers(qb, sceneID, input.PerformerIds); err != nil {
            return nil, err
        }
    }

    // Save the movies
    if translator.hasField("movies") {
        if err := r.updateSceneMovies(qb, sceneID, input.Movies); err != nil {
            return nil, err
        }
    }

    // Save the tags
    if translator.hasField("tag_ids") {
        if err := r.updateSceneTags(qb, sceneID, input.TagIds); err != nil {
            return nil, err
        }
    }

    // Save the stash_ids
    if translator.hasField("stash_ids") {
        stashIDJoins := models.StashIDsFromInput(input.StashIds)
        if err := qb.UpdateStashIDs(sceneID, stashIDJoins); err != nil {
            return nil, err
        }
    }
@@ -202,25 +161,57 @@ func (r *mutationResolver) sceneUpdate(input models.SceneUpdateInput, translator
        }
    }

    return scene, nil
}

func (r *mutationResolver) updateScenePerformers(qb models.SceneReaderWriter, sceneID int, performerIDs []string) error {
    ids, err := utils.StringSliceToIntSlice(performerIDs)
    if err != nil {
        return err
    }
    return qb.UpdatePerformers(sceneID, ids)
}

func (r *mutationResolver) updateSceneMovies(qb models.SceneReaderWriter, sceneID int, movies []*models.SceneMovieInput) error {
    var movieJoins []models.MoviesScenes

    for _, movie := range movies {
        movieID, err := strconv.Atoi(movie.MovieID)
        if err != nil {
            return err
        }

        movieJoin := models.MoviesScenes{
            MovieID: movieID,
        }

        if movie.SceneIndex != nil {
            movieJoin.SceneIndex = sql.NullInt64{
                Int64: int64(*movie.SceneIndex),
                Valid: true,
            }
        }

        movieJoins = append(movieJoins, movieJoin)
    }

    return qb.UpdateMovies(sceneID, movieJoins)
}

func (r *mutationResolver) updateSceneTags(qb models.SceneReaderWriter, sceneID int, tagsIDs []string) error {
    ids, err := utils.StringSliceToIntSlice(tagsIDs)
    if err != nil {
        return err
    }
    return qb.UpdateTags(sceneID, ids)
}

func (r *mutationResolver) BulkSceneUpdate(ctx context.Context, input models.BulkSceneUpdateInput) ([]*models.Scene, error) {
    sceneIDs, err := utils.StringSliceToIntSlice(input.Ids)
    if err != nil {
        return nil, err
    }

    // Populate scene from the input
    updatedTime := time.Now()
@@ -228,11 +219,6 @@ func (r *mutationResolver) BulkSceneUpdate(ctx context.Context, input models.Bul
        inputMap: getUpdateInputMap(ctx),
    }

    updatedScene := models.ScenePartial{
        UpdatedAt: &models.SQLiteTimestamp{Timestamp: updatedTime},
    }
@@ -247,84 +233,62 @@ func (r *mutationResolver) BulkSceneUpdate(ctx context.Context, input models.Bul
    ret := []*models.Scene{}

    // Start the transaction and save the scene marker
    if err := r.withTxn(ctx, func(repo models.Repository) error {
        qb := repo.Scene()
        gqb := repo.Gallery()

        for _, sceneID := range sceneIDs {
            updatedScene.ID = sceneID

            scene, err := qb.Update(updatedScene)
            if err != nil {
                return err
            }

            ret = append(ret, scene)

            if translator.hasField("gallery_id") {
                // Save the gallery
                galleryID, _ := strconv.Atoi(*input.GalleryID)
                updatedGallery := models.GalleryPartial{
                    ID:        galleryID,
                    SceneID:   &sql.NullInt64{Int64: int64(sceneID), Valid: true},
                    UpdatedAt: &models.SQLiteTimestamp{Timestamp: updatedTime},
                }

                if _, err := gqb.UpdatePartial(updatedGallery); err != nil {
                    return err
                }
            }

            // Save the performers
            if translator.hasField("performer_ids") {
                performerIDs, err := adjustScenePerformerIDs(qb, sceneID, *input.PerformerIds)
                if err != nil {
                    return err
                }

                if err := qb.UpdatePerformers(sceneID, performerIDs); err != nil {
                    return err
                }
            }

            // Save the tags
            if translator.hasField("tag_ids") {
                tagIDs, err := adjustSceneTagIDs(qb, sceneID, *input.TagIds)
                if err != nil {
                    return err
                }

                if err := qb.UpdateTags(sceneID, tagIDs); err != nil {
                    return err
                }
            }
        }

        return nil
    }); err != nil {
        return nil, err
    }
@@ -332,6 +296,17 @@ func (r *mutationResolver) BulkSceneUpdate(ctx context.Context, input models.Bul
}

func adjustIDs(existingIDs []int, updateIDs models.BulkUpdateIds) []int {
    // if we are setting the ids, just return the ids
    if updateIDs.Mode == models.BulkUpdateIDModeSet {
        existingIDs = []int{}
        for _, idStr := range updateIDs.Ids {
            id, _ := strconv.Atoi(idStr)
            existingIDs = append(existingIDs, id)
        }

        return existingIDs
    }

    for _, idStr := range updateIDs.Ids {
        id, _ := strconv.Atoi(idStr)
@@ -357,63 +332,53 @@ func adjustIDs(existingIDs []int, updateIDs models.BulkUpdateIds) []int {
    return existingIDs
}

func adjustScenePerformerIDs(qb models.SceneReader, sceneID int, ids models.BulkUpdateIds) (ret []int, err error) {
    ret, err = qb.GetPerformerIDs(sceneID)
    if err != nil {
        return nil, err
    }

    return adjustIDs(ret, ids), nil
}

func adjustSceneTagIDs(qb models.SceneReader, sceneID int, ids models.BulkUpdateIds) (ret []int, err error) {
    ret, err = qb.GetTagIDs(sceneID)
    if err != nil {
        return nil, err
    }

    return adjustIDs(ret, ids), nil
}

func (r *mutationResolver) SceneDestroy(ctx context.Context, input models.SceneDestroyInput) (bool, error) {
    sceneID, err := strconv.Atoi(input.ID)
    if err != nil {
        return false, err
    }

    var scene *models.Scene
    var postCommitFunc func()
    if err := r.withTxn(ctx, func(repo models.Repository) error {
        qb := repo.Scene()
        var err error
        scene, err = qb.Find(sceneID)
        if err != nil {
            return err
        }

        if scene == nil {
            return fmt.Errorf("scene with id %d not found", sceneID)
        }

        postCommitFunc, err = manager.DestroyScene(scene, repo)
        return err
    }); err != nil {
        return false, err
    }

    // perform the post-commit actions
    postCommitFunc()

    // if delete generated is true, then delete the generated files
    // for the scene
    if input.DeleteGenerated != nil && *input.DeleteGenerated {
@@ -430,27 +395,33 @@ func (r *mutationResolver) SceneDestroy(ctx context.Context, input models.SceneD
}

func (r *mutationResolver) ScenesDestroy(ctx context.Context, input models.ScenesDestroyInput) (bool, error) {
    var scenes []*models.Scene
    var postCommitFuncs []func()
    if err := r.withTxn(ctx, func(repo models.Repository) error {
        qb := repo.Scene()

        for _, id := range input.Ids {
            sceneID, _ := strconv.Atoi(id)

            scene, err := qb.Find(sceneID)
            if scene != nil {
                scenes = append(scenes, scene)

                f, err := manager.DestroyScene(scene, repo)
                if err != nil {
                    return err
                }

                postCommitFuncs = append(postCommitFuncs, f)
            }
        }

        return nil
    }); err != nil {
        return false, err
    }

    for _, f := range postCommitFuncs {
        f()
    }

    fileNamingAlgo := config.GetVideoFileNamingAlgorithm()
@@ -472,8 +443,16 @@ func (r *mutationResolver) ScenesDestroy(ctx context.Context, input models.Scene
}

func (r *mutationResolver) SceneMarkerCreate(ctx context.Context, input models.SceneMarkerCreateInput) (*models.SceneMarker, error) {
    primaryTagID, err := strconv.Atoi(input.PrimaryTagID)
    if err != nil {
        return nil, err
    }

    sceneID, err := strconv.Atoi(input.SceneID)
    if err != nil {
        return nil, err
    }

    currentTime := time.Now()
    newSceneMarker := models.SceneMarker{
        Title: input.Title,
@@ -484,14 +463,31 @@ func (r *mutationResolver) SceneMarkerCreate(ctx context.Context, input models.S
        UpdatedAt: models.SQLiteTimestamp{Timestamp: currentTime},
    }

    tagIDs, err := utils.StringSliceToIntSlice(input.TagIds)
    if err != nil {
        return nil, err
    }

    return r.changeMarker(ctx, create, newSceneMarker, tagIDs)
}

func (r *mutationResolver) SceneMarkerUpdate(ctx context.Context, input models.SceneMarkerUpdateInput) (*models.SceneMarker, error) {
    // Populate scene marker from the input
    sceneMarkerID, err := strconv.Atoi(input.ID)
    if err != nil {
        return nil, err
    }

    primaryTagID, err := strconv.Atoi(input.PrimaryTagID)
    if err != nil {
        return nil, err
    }

    sceneID, err := strconv.Atoi(input.SceneID)
    if err != nil {
        return nil, err
    }

    updatedSceneMarker := models.SceneMarker{
        ID:    sceneMarkerID,
        Title: input.Title,
@@ -501,168 +497,151 @@ func (r *mutationResolver) SceneMarkerUpdate(ctx context.Context, input models.S
        UpdatedAt: models.SQLiteTimestamp{Timestamp: time.Now()},
    }

    tagIDs, err := utils.StringSliceToIntSlice(input.TagIds)
    if err != nil {
        return nil, err
    }

    return r.changeMarker(ctx, update, updatedSceneMarker, tagIDs)
}

func (r *mutationResolver) SceneMarkerDestroy(ctx context.Context, id string) (bool, error) {
    markerID, err := strconv.Atoi(id)
    if err != nil {
        return false, err
    }

    var postCommitFunc func()
    if err := r.withTxn(ctx, func(repo models.Repository) error {
        qb := repo.SceneMarker()
        sqb := repo.Scene()

        marker, err := qb.Find(markerID)
        if err != nil {
            return err
        }

        if marker == nil {
            return fmt.Errorf("scene marker with id %d not found", markerID)
        }

        scene, err := sqb.Find(int(marker.SceneID.Int64))
        if err != nil {
            return err
        }

        postCommitFunc, err = manager.DestroySceneMarker(scene, marker, qb)
        return err
    }); err != nil {
        return false, err
    }

    postCommitFunc()

    return true, nil
}

func (r *mutationResolver) changeMarker(ctx context.Context, changeType int, changedMarker models.SceneMarker, tagIDs []int) (*models.SceneMarker, error) {
    var existingMarker *models.SceneMarker
    var sceneMarker *models.SceneMarker
    var scene *models.Scene

    // Start the transaction and save the scene marker
    if err := r.withTxn(ctx, func(repo models.Repository) error {
        qb := repo.SceneMarker()
        sqb := repo.Scene()

        var err error
        switch changeType {
        case create:
            sceneMarker, err = qb.Create(changedMarker)
        case update:
            // check to see if timestamp was changed
            existingMarker, err = qb.Find(changedMarker.ID)
            if err != nil {
                return err
            }

            sceneMarker, err = qb.Update(changedMarker)
            if err != nil {
                return err
            }

            scene, err = sqb.Find(int(existingMarker.SceneID.Int64))
        }

        if err != nil {
            return err
        }

        // Save the marker tags
        // If this tag is the primary tag, then let's not add it.
        tagIDs = utils.IntExclude(tagIDs, []int{changedMarker.PrimaryTagID})
        return qb.UpdateTags(sceneMarker.ID, tagIDs)
    }); err != nil {
        return nil, err
    }

    // remove the marker preview if the timestamp was changed
    if scene != nil && existingMarker != nil && existingMarker.Seconds != changedMarker.Seconds {
        seconds := int(existingMarker.Seconds)
        manager.DeleteSceneMarkerFiles(scene, seconds, config.GetVideoFileNamingAlgorithm())
    }

    return sceneMarker, nil
}

func (r *mutationResolver) SceneIncrementO(ctx context.Context, id string) (ret int, err error) {
    sceneID, err := strconv.Atoi(id)
    if err != nil {
        return 0, err
    }

    if err := r.withTxn(ctx, func(repo models.Repository) error {
        qb := repo.Scene()
        ret, err = qb.IncrementOCounter(sceneID)
        return err
    }); err != nil {
        return 0, err
    }

    return ret, nil
}

func (r *mutationResolver) SceneDecrementO(ctx context.Context, id string) (ret int, err error) {
    sceneID, err := strconv.Atoi(id)
    if err != nil {
        return 0, err
    }

    if err := r.withTxn(ctx, func(repo models.Repository) error {
        qb := repo.Scene()
        ret, err = qb.DecrementOCounter(sceneID)
        return err
    }); err != nil {
        return 0, err
    }

    return ret, nil
}

func (r *mutationResolver) SceneResetO(ctx context.Context, id string) (ret int, err error) {
    sceneID, err := strconv.Atoi(id)
    if err != nil {
        return 0, err
    }

    if err := r.withTxn(ctx, func(repo models.Repository) error {
        qb := repo.Scene()
        ret, err = qb.ResetOCounter(sceneID)
        return err
    }); err != nil {
        return 0, err
    }

    return ret, nil
}

func (r *mutationResolver) SceneGenerateScreenshot(ctx context.Context, id string, at *float64) (string, error) {
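SceneDestroy, ScenesDestroy and SceneMarkerDestroy all follow the same two-phase shape: the database deletes run inside withTxn, while the manager returns a post-commit function that is only invoked once the wrapper has committed, so generated files are never removed for a transaction that rolled back. A compressed, illustrative sketch of that shape follows; every name in it is a placeholder, not the project's API.

package destroysketch

import (
    "context"
    "os"
)

// txnRunner is a placeholder for the resolver's withTxn helper: commit when
// the callback returns nil, roll back otherwise.
type txnRunner func(ctx context.Context, fn func() error) error

// destroyWithPostCommit deletes a row inside the transaction but defers the
// filesystem cleanup until after a successful commit.
func destroyWithPostCommit(ctx context.Context, withTxn txnRunner, deleteRow func() error, generatedPath string) error {
    var postCommit func()

    if err := withTxn(ctx, func() error {
        if err := deleteRow(); err != nil {
            return err
        }

        // capture the cleanup, don't run it: a rollback must leave the
        // generated files untouched
        postCommit = func() { _ = os.RemoveAll(generatedPath) }
        return nil
    }); err != nil {
        return err
    }

    postCommit()
    return nil
}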

View File

@@ -16,7 +16,7 @@ func (r *mutationResolver) SubmitStashBoxFingerprints(ctx context.Context, input
        return false, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex)
    }

    client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager)

    return client.SubmitStashBoxFingerprints(input.SceneIds, boxes[input.StashBoxIndex].Endpoint)
}

View File

@@ -6,7 +6,6 @@ import (
    "strconv"
    "time"

    "github.com/stashapp/stash/pkg/manager"
    "github.com/stashapp/stash/pkg/models"
    "github.com/stashapp/stash/pkg/utils"
@@ -44,41 +43,33 @@ func (r *mutationResolver) StudioCreate(ctx context.Context, input models.Studio
    }

    // Start the transaction and save the studio
    var studio *models.Studio
    if err := r.withTxn(ctx, func(repo models.Repository) error {
        qb := repo.Studio()
        var err error
        studio, err = qb.Create(newStudio)
        if err != nil {
            return err
        }

        // update image table
        if len(imageData) > 0 {
            if err := qb.UpdateImage(studio.ID, imageData); err != nil {
                return err
            }
        }

        // Save the stash_ids
        if input.StashIds != nil {
            stashIDJoins := models.StashIDsFromInput(input.StashIds)
            if err := qb.UpdateStashIDs(studio.ID, stashIDJoins); err != nil {
                return err
            }
        }

        return nil
    }); err != nil {
        return nil, err
    }
@@ -87,7 +78,10 @@ func (r *mutationResolver) StudioCreate(ctx context.Context, input models.Studio
func (r *mutationResolver) StudioUpdate(ctx context.Context, input models.StudioUpdateInput) (*models.Studio, error) {
    // Populate studio from the input
    studioID, err := strconv.Atoi(input.ID)
    if err != nil {
        return nil, err
    }

    translator := changesetTranslator{
        inputMap: getUpdateInputMap(ctx),
@@ -118,52 +112,42 @@ func (r *mutationResolver) StudioUpdate(ctx context.Context, input models.Studio
    updatedStudio.ParentID = translator.nullInt64FromString(input.ParentID, "parent_id")

    // Start the transaction and save the studio
    var studio *models.Studio
    if err := r.withTxn(ctx, func(repo models.Repository) error {
        qb := repo.Studio()

        if err := manager.ValidateModifyStudio(updatedStudio, qb); err != nil {
            return err
        }

        var err error
        studio, err = qb.Update(updatedStudio)
        if err != nil {
            return err
        }

        // update image table
        if len(imageData) > 0 {
            if err := qb.UpdateImage(studio.ID, imageData); err != nil {
                return err
            }
        } else if imageIncluded {
            // must be unsetting
            if err := qb.DestroyImage(studio.ID); err != nil {
                return err
            }
        }

        // Save the stash_ids
        if translator.hasField("stash_ids") {
            stashIDJoins := models.StashIDsFromInput(input.StashIds)
            if err := qb.UpdateStashIDs(studioID, stashIDJoins); err != nil {
                return err
            }
        }

        return nil
    }); err != nil {
        return nil, err
    }
@@ -171,28 +155,35 @@ func (r *mutationResolver) StudioUpdate(ctx context.Context, input models.Studio
}

func (r *mutationResolver) StudioDestroy(ctx context.Context, input models.StudioDestroyInput) (bool, error) {
    id, err := strconv.Atoi(input.ID)
    if err != nil {
        return false, err
    }

    if err := r.withTxn(ctx, func(repo models.Repository) error {
        return repo.Studio().Destroy(id)
    }); err != nil {
        return false, err
    }

    return true, nil
}

func (r *mutationResolver) StudiosDestroy(ctx context.Context, studioIDs []string) (bool, error) {
    ids, err := utils.StringSliceToIntSlice(studioIDs)
    if err != nil {
        return false, err
    }

    if err := r.withTxn(ctx, func(repo models.Repository) error {
        qb := repo.Studio()
        for _, id := range ids {
            if err := qb.Destroy(id); err != nil {
                return err
            }
        }

        return nil
    }); err != nil {
        return false, err
    }

    return true, nil

View File

@@ -6,7 +6,6 @@ import (
    "strconv"
    "time"

    "github.com/stashapp/stash/pkg/manager"
    "github.com/stashapp/stash/pkg/models"
    "github.com/stashapp/stash/pkg/utils"
@@ -33,31 +32,29 @@ func (r *mutationResolver) TagCreate(ctx context.Context, input models.TagCreate
    }

    // Start the transaction and save the tag
    var tag *models.Tag
    if err := r.withTxn(ctx, func(repo models.Repository) error {
        qb := repo.Tag()

        // ensure name is unique
        if err := manager.EnsureTagNameUnique(newTag, qb); err != nil {
            return err
        }

        tag, err = qb.Create(newTag)
        if err != nil {
            return err
        }

        // update image table
        if len(imageData) > 0 {
            if err := qb.UpdateImage(tag.ID, imageData); err != nil {
                return err
            }
        }

        return nil
    }); err != nil {
        return nil, err
    }
@@ -66,7 +63,11 @@ func (r *mutationResolver) TagCreate(ctx context.Context, input models.TagCreate
func (r *mutationResolver) TagUpdate(ctx context.Context, input models.TagUpdateInput) (*models.Tag, error) {
    // Populate tag from the input
    tagID, err := strconv.Atoi(input.ID)
    if err != nil {
        return nil, err
    }

    updatedTag := models.Tag{
        ID:   tagID,
        Name: input.Name,
@@ -74,7 +75,6 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input models.TagUpdate
    }

    var imageData []byte

    translator := changesetTranslator{
        inputMap: getUpdateInputMap(ctx),
@@ -90,50 +90,45 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input models.TagUpdate
    }

    // Start the transaction and save the tag
    var tag *models.Tag
    if err := r.withTxn(ctx, func(repo models.Repository) error {
        qb := repo.Tag()

        // ensure name is unique
        existing, err := qb.Find(tagID)
        if err != nil {
            return err
        }

        if existing == nil {
            return fmt.Errorf("Tag with ID %d not found", tagID)
        }

        if existing.Name != updatedTag.Name {
            if err := manager.EnsureTagNameUnique(updatedTag, qb); err != nil {
                return err
            }
        }

        tag, err = qb.Update(updatedTag)
        if err != nil {
            return err
        }

        // update image table
        if len(imageData) > 0 {
            if err := qb.UpdateImage(tag.ID, imageData); err != nil {
                return err
            }
        } else if imageIncluded {
            // must be unsetting
            if err := qb.DestroyImage(tag.ID); err != nil {
                return err
            }
        }

        return nil
    }); err != nil {
        return nil, err
    }
@@ -141,29 +136,35 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input models.TagUpdate
} }
func (r *mutationResolver) TagDestroy(ctx context.Context, input models.TagDestroyInput) (bool, error) { func (r *mutationResolver) TagDestroy(ctx context.Context, input models.TagDestroyInput) (bool, error) {
qb := models.NewTagQueryBuilder() tagID, err := strconv.Atoi(input.ID)
tx := database.DB.MustBeginTx(ctx, nil) if err != nil {
if err := qb.Destroy(input.ID, tx); err != nil {
_ = tx.Rollback()
return false, err return false, err
} }
if err := tx.Commit(); err != nil {
if err := r.withTxn(ctx, func(repo models.Repository) error {
return repo.Tag().Destroy(tagID)
}); err != nil {
return false, err return false, err
} }
return true, nil return true, nil
} }
func (r *mutationResolver) TagsDestroy(ctx context.Context, ids []string) (bool, error) { func (r *mutationResolver) TagsDestroy(ctx context.Context, tagIDs []string) (bool, error) {
qb := models.NewTagQueryBuilder() ids, err := utils.StringSliceToIntSlice(tagIDs)
tx := database.DB.MustBeginTx(ctx, nil) if err != nil {
return false, err
}
for _, id := range ids { if err := r.withTxn(ctx, func(repo models.Repository) error {
if err := qb.Destroy(id, tx); err != nil { qb := repo.Tag()
_ = tx.Rollback() for _, id := range ids {
return false, err if err := qb.Destroy(id); err != nil {
return err
}
} }
}
if err := tx.Commit(); err != nil { return nil
}); err != nil {
return false, err return false, err
} }
return true, nil return true, nil
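The tag mutations above all follow the same transaction-wrapping shape: convert and validate IDs before the transaction is opened, then run every repository call inside a single r.withTxn closure, where returning an error rolls the whole operation back and returning nil commits it. A minimal sketch of that shape, assuming only the TransactionManager/Repository interfaces introduced by this commit (the helper name findTagChecked is illustrative and not part of the change):

package api

import (
	"context"
	"fmt"
	"strconv"

	"github.com/stashapp/stash/pkg/models"
)

// findTagChecked sketches the validate-then-transact shape shared by the tag
// mutations: Atoi outside the transaction, all repository work inside it.
func findTagChecked(ctx context.Context, txnManager models.TransactionManager, inputID string) (*models.Tag, error) {
	id, err := strconv.Atoi(inputID)
	if err != nil {
		return nil, err
	}

	var ret *models.Tag
	if err := txnManager.WithTxn(ctx, func(repo models.Repository) error {
		existing, err := repo.Tag().Find(id)
		if err != nil {
			return err
		}
		if existing == nil {
			// returning an error aborts the closure and rolls back
			return fmt.Errorf("tag with ID %d not found", id)
		}
		ret = existing
		return nil
	}); err != nil {
		return nil, err
	}

	return ret, nil
}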

View File

@@ -0,0 +1,70 @@
package api
import (
"context"
"errors"
"testing"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/mocks"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
// TODO - move this into a common area
func newResolver() *Resolver {
return &Resolver{
txnManager: mocks.NewTransactionManager(),
}
}
const tagName = "tagName"
const errTagName = "errTagName"
const existingTagID = 1
const existingTagName = "existingTagName"
const newTagID = 2
func TestTagCreate(t *testing.T) {
r := newResolver()
tagRW := r.txnManager.(*mocks.TransactionManager).Tag().(*mocks.TagReaderWriter)
tagRW.On("FindByName", existingTagName, true).Return(&models.Tag{
ID: existingTagID,
Name: existingTagName,
}, nil).Once()
tagRW.On("FindByName", errTagName, true).Return(nil, nil).Once()
expectedErr := errors.New("TagCreate error")
tagRW.On("Create", mock.AnythingOfType("models.Tag")).Return(nil, expectedErr)
_, err := r.Mutation().TagCreate(context.TODO(), models.TagCreateInput{
Name: existingTagName,
})
assert.NotNil(t, err)
_, err = r.Mutation().TagCreate(context.TODO(), models.TagCreateInput{
Name: errTagName,
})
assert.Equal(t, expectedErr, err)
tagRW.AssertExpectations(t)
r = newResolver()
tagRW = r.txnManager.(*mocks.TransactionManager).Tag().(*mocks.TagReaderWriter)
tagRW.On("FindByName", tagName, true).Return(nil, nil).Once()
tagRW.On("Create", mock.AnythingOfType("models.Tag")).Return(&models.Tag{
ID: newTagID,
Name: tagName,
}, nil)
tag, err := r.Mutation().TagCreate(context.TODO(), models.TagCreateInput{
Name: tagName,
})
assert.Nil(t, err)
assert.NotNil(t, tag)
}
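The mocked TransactionManager runs the transaction callback synchronously against mocked reader/writers, which is what lets the resolver logic be exercised without a database. A hypothetical companion test for TagDestroy could reuse the same helpers and imports as the file above; the "Destroy" expectation below is an assumption about the generated TagReaderWriter mock, presumed to follow the same testify conventions as the calls in TestTagCreate:

func TestTagDestroy(t *testing.T) {
	r := newResolver()
	tagRW := r.txnManager.(*mocks.TransactionManager).Tag().(*mocks.TagReaderWriter)

	// Destroy returns only an error in the new repository interface
	tagRW.On("Destroy", existingTagID).Return(nil).Once()

	ok, err := r.Mutation().TagDestroy(context.TODO(), models.TagDestroyInput{
		ID: "1", // "1" parses to existingTagID
	})

	assert.Nil(t, err)
	assert.True(t, ok)
	tagRW.AssertExpectations(t)
}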

View File

@@ -7,17 +7,37 @@ import (
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
) )
func (r *queryResolver) FindGallery(ctx context.Context, id string) (*models.Gallery, error) { func (r *queryResolver) FindGallery(ctx context.Context, id string) (ret *models.Gallery, err error) {
qb := models.NewGalleryQueryBuilder() idInt, err := strconv.Atoi(id)
idInt, _ := strconv.Atoi(id) if err != nil {
return qb.Find(idInt, nil) return nil, err
}
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Gallery().Find(idInt)
return err
}); err != nil {
return nil, err
}
return ret, nil
} }
func (r *queryResolver) FindGalleries(ctx context.Context, galleryFilter *models.GalleryFilterType, filter *models.FindFilterType) (*models.FindGalleriesResultType, error) { func (r *queryResolver) FindGalleries(ctx context.Context, galleryFilter *models.GalleryFilterType, filter *models.FindFilterType) (ret *models.FindGalleriesResultType, err error) {
qb := models.NewGalleryQueryBuilder() if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
galleries, total := qb.Query(galleryFilter, filter) galleries, total, err := repo.Gallery().Query(galleryFilter, filter)
return &models.FindGalleriesResultType{ if err != nil {
Count: total, return err
Galleries: galleries, }
}, nil
ret = &models.FindGalleriesResultType{
Count: total,
Galleries: galleries,
}
return nil
}); err != nil {
return nil, err
}
return ret, nil
} }
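These query resolvers rely on Go named result parameters: the closure assigns to ret and err with =, not :=, so it writes the resolver's own return values rather than shadowing them, and the assigned values survive once WithReadTxn has completed. A small self-contained illustration of that idiom (withReadTxn below is a stand-in that only runs the callback; it is not the real wrapper):

package main

import (
	"errors"
	"fmt"
)

// withReadTxn is a stand-in for the real wrapper: it only runs the callback,
// which is enough to show how the closure writes the enclosing named results.
func withReadTxn(fn func() error) error {
	return fn()
}

func lookup(id int) (string, error) {
	if id == 1 {
		return "found", nil
	}
	return "", errors.New("not found")
}

// find mirrors the resolver shape above: 'ret, err =' (not ':=') assigns the
// named return values, so the result is still set after the closure returns.
func find(id int) (ret string, err error) {
	if err := withReadTxn(func() error {
		ret, err = lookup(id)
		return err
	}); err != nil {
		return "", err
	}
	return ret, nil
}

func main() {
	fmt.Println(find(1)) // found <nil>
	fmt.Println(find(2)) // "" not found
}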

View File

@@ -8,23 +8,48 @@ import (
) )
func (r *queryResolver) FindImage(ctx context.Context, id *string, checksum *string) (*models.Image, error) { func (r *queryResolver) FindImage(ctx context.Context, id *string, checksum *string) (*models.Image, error) {
qb := models.NewImageQueryBuilder()
var image *models.Image var image *models.Image
var err error
if id != nil { if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
idInt, _ := strconv.Atoi(*id) qb := repo.Image()
image, err = qb.Find(idInt) var err error
} else if checksum != nil {
image, err = qb.FindByChecksum(*checksum) if id != nil {
idInt, err := strconv.Atoi(*id)
if err != nil {
return err
}
image, err = qb.Find(idInt)
} else if checksum != nil {
image, err = qb.FindByChecksum(*checksum)
}
return err
}); err != nil {
return nil, err
} }
return image, err
return image, nil
} }
func (r *queryResolver) FindImages(ctx context.Context, imageFilter *models.ImageFilterType, imageIds []int, filter *models.FindFilterType) (*models.FindImagesResultType, error) { func (r *queryResolver) FindImages(ctx context.Context, imageFilter *models.ImageFilterType, imageIds []int, filter *models.FindFilterType) (ret *models.FindImagesResultType, err error) {
qb := models.NewImageQueryBuilder() if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
images, total := qb.Query(imageFilter, filter) qb := repo.Image()
return &models.FindImagesResultType{ images, total, err := qb.Query(imageFilter, filter)
Count: total, if err != nil {
Images: images, return err
}, nil }
ret = &models.FindImagesResultType{
Count: total,
Images: images,
}
return nil
}); err != nil {
return nil, err
}
return ret, nil
} }

View File

@@ -7,27 +7,60 @@ import (
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
) )
func (r *queryResolver) FindMovie(ctx context.Context, id string) (*models.Movie, error) { func (r *queryResolver) FindMovie(ctx context.Context, id string) (ret *models.Movie, err error) {
qb := models.NewMovieQueryBuilder() idInt, err := strconv.Atoi(id)
idInt, _ := strconv.Atoi(id) if err != nil {
return qb.Find(idInt, nil) return nil, err
}
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Movie().Find(idInt)
return err
}); err != nil {
return nil, err
}
return ret, nil
} }
func (r *queryResolver) FindMovies(ctx context.Context, movieFilter *models.MovieFilterType, filter *models.FindFilterType) (*models.FindMoviesResultType, error) { func (r *queryResolver) FindMovies(ctx context.Context, movieFilter *models.MovieFilterType, filter *models.FindFilterType) (ret *models.FindMoviesResultType, err error) {
qb := models.NewMovieQueryBuilder() if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
movies, total := qb.Query(movieFilter, filter) movies, total, err := repo.Movie().Query(movieFilter, filter)
return &models.FindMoviesResultType{ if err != nil {
Count: total, return err
Movies: movies, }
}, nil
ret = &models.FindMoviesResultType{
Count: total,
Movies: movies,
}
return nil
}); err != nil {
return nil, err
}
return ret, nil
} }
func (r *queryResolver) AllMovies(ctx context.Context) ([]*models.Movie, error) { func (r *queryResolver) AllMovies(ctx context.Context) (ret []*models.Movie, err error) {
qb := models.NewMovieQueryBuilder() if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
return qb.All() ret, err = repo.Movie().All()
return err
}); err != nil {
return nil, err
}
return ret, nil
} }
func (r *queryResolver) AllMoviesSlim(ctx context.Context) ([]*models.Movie, error) { func (r *queryResolver) AllMoviesSlim(ctx context.Context) (ret []*models.Movie, err error) {
qb := models.NewMovieQueryBuilder() if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
return qb.AllSlim() ret, err = repo.Movie().AllSlim()
return err
}); err != nil {
return nil, err
}
return ret, nil
} }

View File

@@ -2,31 +2,64 @@ package api
import ( import (
"context" "context"
"github.com/stashapp/stash/pkg/models"
"strconv" "strconv"
"github.com/stashapp/stash/pkg/models"
) )
func (r *queryResolver) FindPerformer(ctx context.Context, id string) (*models.Performer, error) { func (r *queryResolver) FindPerformer(ctx context.Context, id string) (ret *models.Performer, err error) {
qb := models.NewPerformerQueryBuilder() idInt, err := strconv.Atoi(id)
idInt, _ := strconv.Atoi(id) if err != nil {
return qb.Find(idInt) return nil, err
}
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Performer().Find(idInt)
return err
}); err != nil {
return nil, err
}
return ret, nil
} }
func (r *queryResolver) FindPerformers(ctx context.Context, performerFilter *models.PerformerFilterType, filter *models.FindFilterType) (*models.FindPerformersResultType, error) { func (r *queryResolver) FindPerformers(ctx context.Context, performerFilter *models.PerformerFilterType, filter *models.FindFilterType) (ret *models.FindPerformersResultType, err error) {
qb := models.NewPerformerQueryBuilder() if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
performers, total := qb.Query(performerFilter, filter) performers, total, err := repo.Performer().Query(performerFilter, filter)
return &models.FindPerformersResultType{ if err != nil {
Count: total, return err
Performers: performers, }
}, nil
ret = &models.FindPerformersResultType{
Count: total,
Performers: performers,
}
return nil
}); err != nil {
return nil, err
}
return ret, nil
} }
func (r *queryResolver) AllPerformers(ctx context.Context) ([]*models.Performer, error) { func (r *queryResolver) AllPerformers(ctx context.Context) (ret []*models.Performer, err error) {
qb := models.NewPerformerQueryBuilder() if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
return qb.All() ret, err = repo.Performer().All()
return err
}); err != nil {
return nil, err
}
return ret, nil
} }
func (r *queryResolver) AllPerformersSlim(ctx context.Context) ([]*models.Performer, error) { func (r *queryResolver) AllPerformersSlim(ctx context.Context) (ret []*models.Performer, err error) {
qb := models.NewPerformerQueryBuilder() if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
return qb.AllSlim() ret, err = repo.Performer().AllSlim()
return err
}); err != nil {
return nil, err
}
return ret, nil
} }

View File

@@ -9,70 +9,114 @@ import (
) )
func (r *queryResolver) FindScene(ctx context.Context, id *string, checksum *string) (*models.Scene, error) { func (r *queryResolver) FindScene(ctx context.Context, id *string, checksum *string) (*models.Scene, error) {
qb := models.NewSceneQueryBuilder()
var scene *models.Scene var scene *models.Scene
var err error if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
if id != nil { qb := repo.Scene()
idInt, _ := strconv.Atoi(*id) var err error
scene, err = qb.Find(idInt) if id != nil {
} else if checksum != nil { idInt, err := strconv.Atoi(*id)
scene, err = qb.FindByChecksum(*checksum) if err != nil {
} return err
return scene, err }
} scene, err = qb.Find(idInt)
} else if checksum != nil {
func (r *queryResolver) FindSceneByHash(ctx context.Context, input models.SceneHashInput) (*models.Scene, error) { scene, err = qb.FindByChecksum(*checksum)
qb := models.NewSceneQueryBuilder()
var scene *models.Scene
var err error
if input.Checksum != nil {
scene, err = qb.FindByChecksum(*input.Checksum)
if err != nil {
return nil, err
} }
}
if scene == nil && input.Oshash != nil { return err
scene, err = qb.FindByOSHash(*input.Oshash) }); err != nil {
if err != nil {
return nil, err
}
}
return scene, err
}
func (r *queryResolver) FindScenes(ctx context.Context, sceneFilter *models.SceneFilterType, sceneIds []int, filter *models.FindFilterType) (*models.FindScenesResultType, error) {
qb := models.NewSceneQueryBuilder()
scenes, total := qb.Query(sceneFilter, filter)
return &models.FindScenesResultType{
Count: total,
Scenes: scenes,
}, nil
}
func (r *queryResolver) FindScenesByPathRegex(ctx context.Context, filter *models.FindFilterType) (*models.FindScenesResultType, error) {
qb := models.NewSceneQueryBuilder()
scenes, total := qb.QueryByPathRegex(filter)
return &models.FindScenesResultType{
Count: total,
Scenes: scenes,
}, nil
}
func (r *queryResolver) ParseSceneFilenames(ctx context.Context, filter *models.FindFilterType, config models.SceneParserInput) (*models.SceneParserResultType, error) {
parser := manager.NewSceneFilenameParser(filter, config)
result, count, err := parser.Parse()
if err != nil {
return nil, err return nil, err
} }
return &models.SceneParserResultType{ return scene, nil
Count: count, }
Results: result,
}, nil func (r *queryResolver) FindSceneByHash(ctx context.Context, input models.SceneHashInput) (*models.Scene, error) {
var scene *models.Scene
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
qb := repo.Scene()
var err error
if input.Checksum != nil {
scene, err = qb.FindByChecksum(*input.Checksum)
if err != nil {
return err
}
}
if scene == nil && input.Oshash != nil {
scene, err = qb.FindByOSHash(*input.Oshash)
if err != nil {
return err
}
}
return nil
}); err != nil {
return nil, err
}
return scene, nil
}
func (r *queryResolver) FindScenes(ctx context.Context, sceneFilter *models.SceneFilterType, sceneIds []int, filter *models.FindFilterType) (ret *models.FindScenesResultType, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
scenes, total, err := repo.Scene().Query(sceneFilter, filter)
if err != nil {
return err
}
ret = &models.FindScenesResultType{
Count: total,
Scenes: scenes,
}
return nil
}); err != nil {
return nil, err
}
return ret, nil
}
func (r *queryResolver) FindScenesByPathRegex(ctx context.Context, filter *models.FindFilterType) (ret *models.FindScenesResultType, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
scenes, total, err := repo.Scene().QueryByPathRegex(filter)
if err != nil {
return err
}
ret = &models.FindScenesResultType{
Count: total,
Scenes: scenes,
}
return nil
}); err != nil {
return nil, err
}
return ret, nil
}
func (r *queryResolver) ParseSceneFilenames(ctx context.Context, filter *models.FindFilterType, config models.SceneParserInput) (ret *models.SceneParserResultType, err error) {
parser := manager.NewSceneFilenameParser(filter, config)
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
result, count, err := parser.Parse(repo)
if err != nil {
return err
}
ret = &models.SceneParserResultType{
Count: count,
Results: result,
}
return nil
}); err != nil {
return nil, err
}
return ret, nil
} }

View File

@@ -2,14 +2,25 @@ package api
import ( import (
"context" "context"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
) )
func (r *queryResolver) FindSceneMarkers(ctx context.Context, sceneMarkerFilter *models.SceneMarkerFilterType, filter *models.FindFilterType) (*models.FindSceneMarkersResultType, error) { func (r *queryResolver) FindSceneMarkers(ctx context.Context, sceneMarkerFilter *models.SceneMarkerFilterType, filter *models.FindFilterType) (ret *models.FindSceneMarkersResultType, err error) {
qb := models.NewSceneMarkerQueryBuilder() if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
sceneMarkers, total := qb.Query(sceneMarkerFilter, filter) sceneMarkers, total, err := repo.SceneMarker().Query(sceneMarkerFilter, filter)
return &models.FindSceneMarkersResultType{ if err != nil {
Count: total, return err
SceneMarkers: sceneMarkers, }
}, nil ret = &models.FindSceneMarkersResultType{
Count: total,
SceneMarkers: sceneMarkers,
}
return nil
}); err != nil {
return nil, err
}
return ret, nil
} }

View File

@@ -2,31 +2,66 @@ package api
import ( import (
"context" "context"
"github.com/stashapp/stash/pkg/models"
"strconv" "strconv"
"github.com/stashapp/stash/pkg/models"
) )
func (r *queryResolver) FindStudio(ctx context.Context, id string) (*models.Studio, error) { func (r *queryResolver) FindStudio(ctx context.Context, id string) (ret *models.Studio, err error) {
qb := models.NewStudioQueryBuilder() idInt, err := strconv.Atoi(id)
idInt, _ := strconv.Atoi(id) if err != nil {
return qb.Find(idInt, nil) return nil, err
}
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
var err error
ret, err = repo.Studio().Find(idInt)
return err
}); err != nil {
return nil, err
}
return ret, nil
} }
func (r *queryResolver) FindStudios(ctx context.Context, studioFilter *models.StudioFilterType, filter *models.FindFilterType) (*models.FindStudiosResultType, error) { func (r *queryResolver) FindStudios(ctx context.Context, studioFilter *models.StudioFilterType, filter *models.FindFilterType) (ret *models.FindStudiosResultType, err error) {
qb := models.NewStudioQueryBuilder() if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
studios, total := qb.Query(studioFilter, filter) studios, total, err := repo.Studio().Query(studioFilter, filter)
return &models.FindStudiosResultType{ if err != nil {
Count: total, return err
Studios: studios, }
}, nil
ret = &models.FindStudiosResultType{
Count: total,
Studios: studios,
}
return nil
}); err != nil {
return nil, err
}
return ret, nil
} }
func (r *queryResolver) AllStudios(ctx context.Context) ([]*models.Studio, error) { func (r *queryResolver) AllStudios(ctx context.Context) (ret []*models.Studio, err error) {
qb := models.NewStudioQueryBuilder() if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
return qb.All() ret, err = repo.Studio().All()
return err
}); err != nil {
return nil, err
}
return ret, nil
} }
func (r *queryResolver) AllStudiosSlim(ctx context.Context) ([]*models.Studio, error) { func (r *queryResolver) AllStudiosSlim(ctx context.Context) (ret []*models.Studio, err error) {
qb := models.NewStudioQueryBuilder() if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
return qb.AllSlim() ret, err = repo.Studio().AllSlim()
return err
}); err != nil {
return nil, err
}
return ret, nil
} }

View File

@@ -7,27 +7,60 @@ import (
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
) )
func (r *queryResolver) FindTag(ctx context.Context, id string) (*models.Tag, error) { func (r *queryResolver) FindTag(ctx context.Context, id string) (ret *models.Tag, err error) {
qb := models.NewTagQueryBuilder() idInt, err := strconv.Atoi(id)
idInt, _ := strconv.Atoi(id) if err != nil {
return qb.Find(idInt, nil) return nil, err
}
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Tag().Find(idInt)
return err
}); err != nil {
return nil, err
}
return ret, nil
} }
func (r *queryResolver) FindTags(ctx context.Context, tagFilter *models.TagFilterType, filter *models.FindFilterType) (*models.FindTagsResultType, error) { func (r *queryResolver) FindTags(ctx context.Context, tagFilter *models.TagFilterType, filter *models.FindFilterType) (ret *models.FindTagsResultType, err error) {
qb := models.NewTagQueryBuilder() if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
tags, total := qb.Query(tagFilter, filter) tags, total, err := repo.Tag().Query(tagFilter, filter)
return &models.FindTagsResultType{ if err != nil {
Count: total, return err
Tags: tags, }
}, nil
ret = &models.FindTagsResultType{
Count: total,
Tags: tags,
}
return nil
}); err != nil {
return nil, err
}
return ret, nil
} }
func (r *queryResolver) AllTags(ctx context.Context) ([]*models.Tag, error) { func (r *queryResolver) AllTags(ctx context.Context) (ret []*models.Tag, err error) {
qb := models.NewTagQueryBuilder() if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
return qb.All() ret, err = repo.Tag().All()
return err
}); err != nil {
return nil, err
}
return ret, nil
} }
func (r *queryResolver) AllTagsSlim(ctx context.Context) ([]*models.Tag, error) { func (r *queryResolver) AllTagsSlim(ctx context.Context) (ret []*models.Tag, err error) {
qb := models.NewTagQueryBuilder() if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
return qb.AllSlim() ret, err = repo.Tag().AllSlim()
return err
}); err != nil {
return nil, err
}
return ret, nil
} }

View File

@@ -12,11 +12,13 @@ import (
func (r *queryResolver) SceneStreams(ctx context.Context, id *string) ([]*models.SceneStreamEndpoint, error) { func (r *queryResolver) SceneStreams(ctx context.Context, id *string) ([]*models.SceneStreamEndpoint, error) {
// find the scene // find the scene
qb := models.NewSceneQueryBuilder() var scene *models.Scene
idInt, _ := strconv.Atoi(*id) if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
scene, err := qb.Find(idInt) idInt, _ := strconv.Atoi(*id)
var err error
if err != nil { scene, err = repo.Scene().Find(idInt)
return err
}); err != nil {
return nil, err return nil, err
} }

View File

@@ -95,7 +95,7 @@ func (r *queryResolver) QueryStashBoxScene(ctx context.Context, input models.Sta
return nil, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex) return nil, fmt.Errorf("invalid stash_box_index %d", input.StashBoxIndex)
} }
client := stashbox.NewClient(*boxes[input.StashBoxIndex]) client := stashbox.NewClient(*boxes[input.StashBoxIndex], r.txnManager)
if len(input.SceneIds) > 0 { if len(input.SceneIds) > 0 {
return client.FindStashBoxScenesByFingerprints(input.SceneIds) return client.FindStashBoxScenesByFingerprints(input.SceneIds)

View File

@@ -12,7 +12,9 @@ import (
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
type imageRoutes struct{} type imageRoutes struct {
txnManager models.TransactionManager
}
func (rs imageRoutes) Routes() chi.Router { func (rs imageRoutes) Routes() chi.Router {
r := chi.NewRouter() r := chi.NewRouter()
@@ -57,12 +59,16 @@ func ImageCtx(next http.Handler) http.Handler {
imageID, _ := strconv.Atoi(imageIdentifierQueryParam) imageID, _ := strconv.Atoi(imageIdentifierQueryParam)
var image *models.Image var image *models.Image
qb := models.NewImageQueryBuilder() manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
if imageID == 0 { qb := repo.Image()
image, _ = qb.FindByChecksum(imageIdentifierQueryParam) if imageID == 0 {
} else { image, _ = qb.FindByChecksum(imageIdentifierQueryParam)
image, _ = qb.Find(imageID) } else {
} image, _ = qb.Find(imageID)
}
return nil
})
if image == nil { if image == nil {
http.Error(w, http.StatusText(404), 404) http.Error(w, http.StatusText(404), 404)

View File

@@ -6,11 +6,14 @@ import (
"strconv" "strconv"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/stashapp/stash/pkg/manager"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
type movieRoutes struct{} type movieRoutes struct {
txnManager models.TransactionManager
}
func (rs movieRoutes) Routes() chi.Router { func (rs movieRoutes) Routes() chi.Router {
r := chi.NewRouter() r := chi.NewRouter()
@@ -26,11 +29,16 @@ func (rs movieRoutes) Routes() chi.Router {
func (rs movieRoutes) FrontImage(w http.ResponseWriter, r *http.Request) { func (rs movieRoutes) FrontImage(w http.ResponseWriter, r *http.Request) {
movie := r.Context().Value(movieKey).(*models.Movie) movie := r.Context().Value(movieKey).(*models.Movie)
qb := models.NewMovieQueryBuilder()
image, _ := qb.GetFrontImage(movie.ID, nil)
defaultParam := r.URL.Query().Get("default") defaultParam := r.URL.Query().Get("default")
if len(image) == 0 || defaultParam == "true" { var image []byte
if defaultParam != "true" {
rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
image, _ = repo.Movie().GetFrontImage(movie.ID)
return nil
})
}
if len(image) == 0 {
_, image, _ = utils.ProcessBase64Image(models.DefaultMovieImage) _, image, _ = utils.ProcessBase64Image(models.DefaultMovieImage)
} }
@@ -39,11 +47,16 @@ func (rs movieRoutes) FrontImage(w http.ResponseWriter, r *http.Request) {
func (rs movieRoutes) BackImage(w http.ResponseWriter, r *http.Request) { func (rs movieRoutes) BackImage(w http.ResponseWriter, r *http.Request) {
movie := r.Context().Value(movieKey).(*models.Movie) movie := r.Context().Value(movieKey).(*models.Movie)
qb := models.NewMovieQueryBuilder()
image, _ := qb.GetBackImage(movie.ID, nil)
defaultParam := r.URL.Query().Get("default") defaultParam := r.URL.Query().Get("default")
if len(image) == 0 || defaultParam == "true" { var image []byte
if defaultParam != "true" {
rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
image, _ = repo.Movie().GetBackImage(movie.ID)
return nil
})
}
if len(image) == 0 {
_, image, _ = utils.ProcessBase64Image(models.DefaultMovieImage) _, image, _ = utils.ProcessBase64Image(models.DefaultMovieImage)
} }
@@ -58,9 +71,12 @@ func MovieCtx(next http.Handler) http.Handler {
return return
} }
qb := models.NewMovieQueryBuilder() var movie *models.Movie
movie, err := qb.Find(movieID, nil) if err := manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
if err != nil { var err error
movie, err = repo.Movie().Find(movieID)
return err
}); err != nil {
http.Error(w, http.StatusText(404), 404) http.Error(w, http.StatusText(404), 404)
return return
} }

View File

@@ -6,11 +6,14 @@ import (
"strconv" "strconv"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/stashapp/stash/pkg/manager"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
type performerRoutes struct{} type performerRoutes struct {
txnManager models.TransactionManager
}
func (rs performerRoutes) Routes() chi.Router { func (rs performerRoutes) Routes() chi.Router {
r := chi.NewRouter() r := chi.NewRouter()
@@ -25,10 +28,16 @@ func (rs performerRoutes) Routes() chi.Router {
func (rs performerRoutes) Image(w http.ResponseWriter, r *http.Request) { func (rs performerRoutes) Image(w http.ResponseWriter, r *http.Request) {
performer := r.Context().Value(performerKey).(*models.Performer) performer := r.Context().Value(performerKey).(*models.Performer)
qb := models.NewPerformerQueryBuilder()
image, _ := qb.GetPerformerImage(performer.ID, nil)
defaultParam := r.URL.Query().Get("default") defaultParam := r.URL.Query().Get("default")
var image []byte
if defaultParam != "true" {
rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
image, _ = repo.Performer().GetImage(performer.ID)
return nil
})
}
if len(image) == 0 || defaultParam == "true" { if len(image) == 0 || defaultParam == "true" {
image, _ = getRandomPerformerImageUsingName(performer.Name.String, performer.Gender.String) image, _ = getRandomPerformerImageUsingName(performer.Name.String, performer.Gender.String)
} }
@@ -44,9 +53,12 @@ func PerformerCtx(next http.Handler) http.Handler {
return return
} }
qb := models.NewPerformerQueryBuilder() var performer *models.Performer
performer, err := qb.Find(performerID) if err := manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
if err != nil { var err error
performer, err = repo.Performer().Find(performerID)
return err
}); err != nil {
http.Error(w, http.StatusText(404), 404) http.Error(w, http.StatusText(404), 404)
return return
} }

View File

@@ -15,7 +15,9 @@ import (
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
type sceneRoutes struct{} type sceneRoutes struct {
txnManager models.TransactionManager
}
func (rs sceneRoutes) Routes() chi.Router { func (rs sceneRoutes) Routes() chi.Router {
r := chi.NewRouter() r := chi.NewRouter()
@@ -183,8 +185,11 @@ func (rs sceneRoutes) Screenshot(w http.ResponseWriter, r *http.Request) {
if screenshotExists { if screenshotExists {
http.ServeFile(w, r, filepath) http.ServeFile(w, r, filepath)
} else { } else {
qb := models.NewSceneQueryBuilder() var cover []byte
cover, _ := qb.GetSceneCover(scene.ID, nil) rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
cover, _ = repo.Scene().GetCover(scene.ID)
return nil
})
utils.ServeImage(cover, w, r) utils.ServeImage(cover, w, r)
} }
} }
@@ -201,39 +206,48 @@ func (rs sceneRoutes) Webp(w http.ResponseWriter, r *http.Request) {
http.ServeFile(w, r, filepath) http.ServeFile(w, r, filepath)
} }
func getChapterVttTitle(marker *models.SceneMarker) string { func (rs sceneRoutes) getChapterVttTitle(ctx context.Context, marker *models.SceneMarker) string {
if marker.Title != "" { if marker.Title != "" {
return marker.Title return marker.Title
} }
qb := models.NewTagQueryBuilder() var ret string
primaryTag, err := qb.Find(marker.PrimaryTagID, nil) if err := rs.txnManager.WithReadTxn(ctx, func(repo models.ReaderRepository) error {
if err != nil { qb := repo.Tag()
// should not happen primaryTag, err := qb.Find(marker.PrimaryTagID)
if err != nil {
return err
}
ret = primaryTag.Name
tags, err := qb.FindBySceneMarkerID(marker.ID)
if err != nil {
return err
}
for _, t := range tags {
ret += ", " + t.Name
}
return nil
}); err != nil {
panic(err) panic(err)
} }
ret := primaryTag.Name
tags, err := qb.FindBySceneMarkerID(marker.ID, nil)
if err != nil {
// should not happen
panic(err)
}
for _, t := range tags {
ret += ", " + t.Name
}
return ret return ret
} }
func (rs sceneRoutes) ChapterVtt(w http.ResponseWriter, r *http.Request) { func (rs sceneRoutes) ChapterVtt(w http.ResponseWriter, r *http.Request) {
scene := r.Context().Value(sceneKey).(*models.Scene) scene := r.Context().Value(sceneKey).(*models.Scene)
qb := models.NewSceneMarkerQueryBuilder() var sceneMarkers []*models.SceneMarker
sceneMarkers, err := qb.FindBySceneID(scene.ID, nil) if err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
if err != nil { var err error
panic("invalid scene markers for chapter vtt") sceneMarkers, err = repo.SceneMarker().FindBySceneID(scene.ID)
return err
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
} }
vttLines := []string{"WEBVTT", ""} vttLines := []string{"WEBVTT", ""}
@@ -241,7 +255,7 @@ func (rs sceneRoutes) ChapterVtt(w http.ResponseWriter, r *http.Request) {
vttLines = append(vttLines, strconv.Itoa(i+1)) vttLines = append(vttLines, strconv.Itoa(i+1))
time := utils.GetVTTTime(marker.Seconds) time := utils.GetVTTTime(marker.Seconds)
vttLines = append(vttLines, time+" --> "+time) vttLines = append(vttLines, time+" --> "+time)
vttLines = append(vttLines, getChapterVttTitle(marker)) vttLines = append(vttLines, rs.getChapterVttTitle(r.Context(), marker))
vttLines = append(vttLines, "") vttLines = append(vttLines, "")
} }
vtt := strings.Join(vttLines, "\n") vtt := strings.Join(vttLines, "\n")
@@ -267,11 +281,14 @@ func (rs sceneRoutes) VttSprite(w http.ResponseWriter, r *http.Request) {
func (rs sceneRoutes) SceneMarkerStream(w http.ResponseWriter, r *http.Request) { func (rs sceneRoutes) SceneMarkerStream(w http.ResponseWriter, r *http.Request) {
scene := r.Context().Value(sceneKey).(*models.Scene) scene := r.Context().Value(sceneKey).(*models.Scene)
sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId")) sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId"))
qb := models.NewSceneMarkerQueryBuilder() var sceneMarker *models.SceneMarker
sceneMarker, err := qb.Find(sceneMarkerID) if err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
if err != nil { var err error
logger.Warn("Error when getting scene marker for stream") sceneMarker, err = repo.SceneMarker().Find(sceneMarkerID)
http.Error(w, http.StatusText(404), 404) return err
}); err != nil {
logger.Warnf("Error when getting scene marker for stream: %s", err.Error())
http.Error(w, http.StatusText(500), 500)
return return
} }
filepath := manager.GetInstance().Paths.SceneMarkers.GetStreamPath(scene.GetHash(config.GetVideoFileNamingAlgorithm()), int(sceneMarker.Seconds)) filepath := manager.GetInstance().Paths.SceneMarkers.GetStreamPath(scene.GetHash(config.GetVideoFileNamingAlgorithm()), int(sceneMarker.Seconds))
@@ -281,11 +298,14 @@ func (rs sceneRoutes) SceneMarkerStream(w http.ResponseWriter, r *http.Request)
func (rs sceneRoutes) SceneMarkerPreview(w http.ResponseWriter, r *http.Request) { func (rs sceneRoutes) SceneMarkerPreview(w http.ResponseWriter, r *http.Request) {
scene := r.Context().Value(sceneKey).(*models.Scene) scene := r.Context().Value(sceneKey).(*models.Scene)
sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId")) sceneMarkerID, _ := strconv.Atoi(chi.URLParam(r, "sceneMarkerId"))
qb := models.NewSceneMarkerQueryBuilder() var sceneMarker *models.SceneMarker
sceneMarker, err := qb.Find(sceneMarkerID) if err := rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
if err != nil { var err error
logger.Warn("Error when getting scene marker for stream") sceneMarker, err = repo.SceneMarker().Find(sceneMarkerID)
http.Error(w, http.StatusText(404), 404) return err
}); err != nil {
logger.Warnf("Error when getting scene marker for stream: %s", err.Error())
http.Error(w, http.StatusText(500), 500)
return return
} }
filepath := manager.GetInstance().Paths.SceneMarkers.GetStreamPreviewImagePath(scene.GetHash(config.GetVideoFileNamingAlgorithm()), int(sceneMarker.Seconds)) filepath := manager.GetInstance().Paths.SceneMarkers.GetStreamPreviewImagePath(scene.GetHash(config.GetVideoFileNamingAlgorithm()), int(sceneMarker.Seconds))
@@ -310,17 +330,21 @@ func SceneCtx(next http.Handler) http.Handler {
sceneID, _ := strconv.Atoi(sceneIdentifierQueryParam) sceneID, _ := strconv.Atoi(sceneIdentifierQueryParam)
var scene *models.Scene var scene *models.Scene
qb := models.NewSceneQueryBuilder() manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
if sceneID == 0 { qb := repo.Scene()
// determine checksum/os by the length of the query param if sceneID == 0 {
if len(sceneIdentifierQueryParam) == 32 { // determine checksum/os by the length of the query param
scene, _ = qb.FindByChecksum(sceneIdentifierQueryParam) if len(sceneIdentifierQueryParam) == 32 {
scene, _ = qb.FindByChecksum(sceneIdentifierQueryParam)
} else {
scene, _ = qb.FindByOSHash(sceneIdentifierQueryParam)
}
} else { } else {
scene, _ = qb.FindByOSHash(sceneIdentifierQueryParam) scene, _ = qb.Find(sceneID)
} }
} else {
scene, _ = qb.Find(sceneID) return nil
} })
if scene == nil { if scene == nil {
http.Error(w, http.StatusText(404), 404) http.Error(w, http.StatusText(404), 404)
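SceneCtx distinguishes the two hash types purely by length: an MD5 checksum renders as 32 hexadecimal characters, while the 64-bit oshash renders as 16, so any 32-character identifier is looked up by checksum and everything else falls through to the oshash lookup. A tiny sketch of the heuristic (the concrete hash values are illustrative):

package main

import "fmt"

// identifierKind mirrors the length check in SceneCtx: 32 characters is
// treated as an MD5 checksum, anything else as an oshash.
func identifierKind(identifier string) string {
	if len(identifier) == 32 {
		return "checksum"
	}
	return "oshash"
}

func main() {
	fmt.Println(identifierKind("9e107d9d372bb6826bd81d3542a419d6")) // checksum (32 hex chars)
	fmt.Println(identifierKind("8e245d9679d31e12"))                 // oshash (16 hex chars)
}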

View File

@@ -6,11 +6,14 @@ import (
"strconv" "strconv"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/stashapp/stash/pkg/manager"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
type studioRoutes struct{} type studioRoutes struct {
txnManager models.TransactionManager
}
func (rs studioRoutes) Routes() chi.Router { func (rs studioRoutes) Routes() chi.Router {
r := chi.NewRouter() r := chi.NewRouter()
@@ -25,12 +28,14 @@ func (rs studioRoutes) Routes() chi.Router {
func (rs studioRoutes) Image(w http.ResponseWriter, r *http.Request) { func (rs studioRoutes) Image(w http.ResponseWriter, r *http.Request) {
studio := r.Context().Value(studioKey).(*models.Studio) studio := r.Context().Value(studioKey).(*models.Studio)
qb := models.NewStudioQueryBuilder()
var image []byte
defaultParam := r.URL.Query().Get("default") defaultParam := r.URL.Query().Get("default")
var image []byte
if defaultParam != "true" { if defaultParam != "true" {
image, _ = qb.GetStudioImage(studio.ID, nil) rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
image, _ = repo.Studio().GetImage(studio.ID)
return nil
})
} }
if len(image) == 0 { if len(image) == 0 {
@@ -48,9 +53,12 @@ func StudioCtx(next http.Handler) http.Handler {
return return
} }
qb := models.NewStudioQueryBuilder() var studio *models.Studio
studio, err := qb.Find(studioID, nil) if err := manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
if err != nil { var err error
studio, err = repo.Studio().Find(studioID)
return err
}); err != nil {
http.Error(w, http.StatusText(404), 404) http.Error(w, http.StatusText(404), 404)
return return
} }

View File

@@ -6,11 +6,14 @@ import (
"strconv" "strconv"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/stashapp/stash/pkg/manager"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
type tagRoutes struct{} type tagRoutes struct {
txnManager models.TransactionManager
}
func (rs tagRoutes) Routes() chi.Router { func (rs tagRoutes) Routes() chi.Router {
r := chi.NewRouter() r := chi.NewRouter()
@@ -25,12 +28,17 @@ func (rs tagRoutes) Routes() chi.Router {
func (rs tagRoutes) Image(w http.ResponseWriter, r *http.Request) { func (rs tagRoutes) Image(w http.ResponseWriter, r *http.Request) {
tag := r.Context().Value(tagKey).(*models.Tag) tag := r.Context().Value(tagKey).(*models.Tag)
qb := models.NewTagQueryBuilder()
image, _ := qb.GetTagImage(tag.ID, nil)
// use default image if not present
defaultParam := r.URL.Query().Get("default") defaultParam := r.URL.Query().Get("default")
if len(image) == 0 || defaultParam == "true" {
var image []byte
if defaultParam != "true" {
rs.txnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
image, _ = repo.Tag().GetImage(tag.ID)
return nil
})
}
if len(image) == 0 {
image = models.DefaultTagImage image = models.DefaultTagImage
} }
@@ -45,9 +53,12 @@ func TagCtx(next http.Handler) http.Handler {
return return
} }
qb := models.NewTagQueryBuilder() var tag *models.Tag
tag, err := qb.Find(tagID, nil) if err := manager.GetInstance().TxnManager.WithReadTxn(r.Context(), func(repo models.ReaderRepository) error {
if err != nil { var err error
tag, err = repo.Tag().Find(tagID)
return err
}); err != nil {
http.Error(w, http.StatusText(404), 404) http.Error(w, http.StatusText(404), 404)
return return
} }

View File

@@ -134,7 +134,12 @@ func Start() {
return true return true
}, },
}) })
gqlHandler := handler.GraphQL(models.NewExecutableSchema(models.Config{Resolvers: &Resolver{}}), recoverFunc, websocketUpgrader)
txnManager := manager.GetInstance().TxnManager
resolver := &Resolver{
txnManager: txnManager,
}
gqlHandler := handler.GraphQL(models.NewExecutableSchema(models.Config{Resolvers: resolver}), recoverFunc, websocketUpgrader)
r.Handle("/graphql", gqlHandler) r.Handle("/graphql", gqlHandler)
r.Handle("/playground", handler.Playground("GraphQL playground", "/graphql")) r.Handle("/playground", handler.Playground("GraphQL playground", "/graphql"))
@@ -145,12 +150,24 @@ func Start() {
r.Get(loginEndPoint, getLoginHandler) r.Get(loginEndPoint, getLoginHandler)
r.Mount("/performer", performerRoutes{}.Routes()) r.Mount("/performer", performerRoutes{
r.Mount("/scene", sceneRoutes{}.Routes()) txnManager: txnManager,
r.Mount("/image", imageRoutes{}.Routes()) }.Routes())
r.Mount("/studio", studioRoutes{}.Routes()) r.Mount("/scene", sceneRoutes{
r.Mount("/movie", movieRoutes{}.Routes()) txnManager: txnManager,
r.Mount("/tag", tagRoutes{}.Routes()) }.Routes())
r.Mount("/image", imageRoutes{
txnManager: txnManager,
}.Routes())
r.Mount("/studio", studioRoutes{
txnManager: txnManager,
}.Routes())
r.Mount("/movie", movieRoutes{
txnManager: txnManager,
}.Routes())
r.Mount("/tag", tagRoutes{
txnManager: txnManager,
}.Routes())
r.Mount("/downloads", downloadsRoutes{}.Routes()) r.Mount("/downloads", downloadsRoutes{}.Routes())
r.HandleFunc("/css", func(w http.ResponseWriter, r *http.Request) { r.HandleFunc("/css", func(w http.ResponseWriter, r *http.Request) {

View File

@@ -68,9 +68,9 @@ func Initialize(databasePath string) bool {
func open(databasePath string, disableForeignKeys bool) *sqlx.DB { func open(databasePath string, disableForeignKeys bool) *sqlx.DB {
// https://github.com/mattn/go-sqlite3 // https://github.com/mattn/go-sqlite3
url := "file:" + databasePath url := "file:" + databasePath + "?_journal=WAL"
if !disableForeignKeys { if !disableForeignKeys {
url += "?_fk=true" url += "&_fk=true"
} }
conn, err := sqlx.Open(sqlite3Driver, url) conn, err := sqlx.Open(sqlite3Driver, url)
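Every connection now opens in WAL journal mode, which lets readers proceed while a write transaction is open, and the foreign-key pragma becomes a second query parameter rather than the first. The resulting go-sqlite3 connection strings look like this (helper name and example path are illustrative):

package database

// buildDatabaseURL mirrors the DSN construction above: WAL is always enabled,
// and foreign keys are appended as an extra parameter unless disabled.
func buildDatabaseURL(databasePath string, disableForeignKeys bool) string {
	url := "file:" + databasePath + "?_journal=WAL"
	if !disableForeignKeys {
		url += "&_fk=true"
	}
	return url
	// e.g. "file:stash-go.sqlite?_journal=WAL&_fk=true"
}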

View File

@@ -6,7 +6,6 @@ import (
"github.com/stashapp/stash/pkg/manager/jsonschema" "github.com/stashapp/stash/pkg/manager/jsonschema"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/mocks" "github.com/stashapp/stash/pkg/models/mocks"
"github.com/stashapp/stash/pkg/models/modelstest"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"testing" "testing"
@@ -51,18 +50,18 @@ var updateTime time.Time = time.Date(2002, 01, 01, 0, 0, 0, 0, time.UTC)
func createFullGallery(id int) models.Gallery { func createFullGallery(id int) models.Gallery {
return models.Gallery{ return models.Gallery{
ID: id, ID: id,
Path: modelstest.NullString(path), Path: models.NullString(path),
Zip: zip, Zip: zip,
Title: modelstest.NullString(title), Title: models.NullString(title),
Checksum: checksum, Checksum: checksum,
Date: models.SQLiteDate{ Date: models.SQLiteDate{
String: date, String: date,
Valid: true, Valid: true,
}, },
Details: modelstest.NullString(details), Details: models.NullString(details),
Rating: modelstest.NullInt64(rating), Rating: models.NullInt64(rating),
Organized: organized, Organized: organized,
URL: modelstest.NullString(url), URL: models.NullString(url),
CreatedAt: models.SQLiteTimestamp{ CreatedAt: models.SQLiteTimestamp{
Timestamp: createTime, Timestamp: createTime,
}, },
@@ -146,7 +145,7 @@ func TestToJSON(t *testing.T) {
func createStudioGallery(studioID int) models.Gallery { func createStudioGallery(studioID int) models.Gallery {
return models.Gallery{ return models.Gallery{
StudioID: modelstest.NullInt64(int64(studioID)), StudioID: models.NullInt64(int64(studioID)),
} }
} }
@@ -180,7 +179,7 @@ func TestGetStudioName(t *testing.T) {
studioErr := errors.New("error getting image") studioErr := errors.New("error getting image")
mockStudioReader.On("Find", studioID).Return(&models.Studio{ mockStudioReader.On("Find", studioID).Return(&models.Studio{
Name: modelstest.NullString(studioName), Name: models.NullString(studioName),
}, nil).Once() }, nil).Once()
mockStudioReader.On("Find", missingStudioID).Return(nil, nil).Once() mockStudioReader.On("Find", missingStudioID).Return(nil, nil).Once()
mockStudioReader.On("Find", errStudioID).Return(nil, studioErr).Once() mockStudioReader.On("Find", errStudioID).Return(nil, studioErr).Once()

View File

@@ -1,30 +0,0 @@
package gallery
import (
"github.com/stashapp/stash/pkg/api/urlbuilders"
"github.com/stashapp/stash/pkg/models"
)
func GetFiles(g *models.Gallery, baseURL string) []*models.GalleryFilesType {
var galleryFiles []*models.GalleryFilesType
qb := models.NewImageQueryBuilder()
images, err := qb.FindByGalleryID(g.ID)
if err != nil {
return nil
}
for i, img := range images {
builder := urlbuilders.NewImageURLBuilder(baseURL, img.ID)
imageURL := builder.GetImageURL()
galleryFile := models.GalleryFilesType{
Index: i,
Name: &img.Title.String,
Path: &imageURL,
}
galleryFiles = append(galleryFiles, &galleryFile)
}
return galleryFiles
}

View File

@@ -15,7 +15,6 @@ type Importer struct {
StudioWriter models.StudioReaderWriter StudioWriter models.StudioReaderWriter
PerformerWriter models.PerformerReaderWriter PerformerWriter models.PerformerReaderWriter
TagWriter models.TagReaderWriter TagWriter models.TagReaderWriter
JoinWriter models.JoinReaderWriter
Input jsonschema.Gallery Input jsonschema.Gallery
MissingRefBehaviour models.ImportMissingRefEnum MissingRefBehaviour models.ImportMissingRefEnum
@@ -237,29 +236,22 @@ func (i *Importer) createTags(names []string) ([]*models.Tag, error) {
func (i *Importer) PostImport(id int) error { func (i *Importer) PostImport(id int) error {
if len(i.performers) > 0 { if len(i.performers) > 0 {
var performerJoins []models.PerformersGalleries var performerIDs []int
for _, performer := range i.performers { for _, performer := range i.performers {
join := models.PerformersGalleries{ performerIDs = append(performerIDs, performer.ID)
PerformerID: performer.ID,
GalleryID: id,
}
performerJoins = append(performerJoins, join)
} }
if err := i.JoinWriter.UpdatePerformersGalleries(id, performerJoins); err != nil {
if err := i.ReaderWriter.UpdatePerformers(id, performerIDs); err != nil {
return fmt.Errorf("failed to associate performers: %s", err.Error()) return fmt.Errorf("failed to associate performers: %s", err.Error())
} }
} }
if len(i.tags) > 0 { if len(i.tags) > 0 {
var tagJoins []models.GalleriesTags var tagIDs []int
for _, tag := range i.tags { for _, t := range i.tags {
join := models.GalleriesTags{ tagIDs = append(tagIDs, t.ID)
GalleryID: id,
TagID: tag.ID,
}
tagJoins = append(tagJoins, join)
} }
if err := i.JoinWriter.UpdateGalleriesTags(id, tagJoins); err != nil { if err := i.ReaderWriter.UpdateTags(id, tagIDs); err != nil {
return fmt.Errorf("failed to associate tags: %s", err.Error()) return fmt.Errorf("failed to associate tags: %s", err.Error())
} }
} }
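With the join writer gone, PostImport now hands the gallery reader/writer a complete slice of IDs (UpdatePerformers, UpdateTags) instead of building join rows, which is what the reworked tests in the next file assert against. The extraction step it performs inline is equivalent to a helper like this (the helper name is illustrative; the importer keeps the loop inline):

package gallery

import "github.com/stashapp/stash/pkg/models"

// performerIDs collects the IDs that PostImport passes to UpdatePerformers;
// the same loop is used for tags.
func performerIDs(performers []*models.Performer) []int {
	var ids []int
	for _, p := range performers {
		ids = append(ids, p.ID)
	}
	return ids
}

UpdatePerformers then presumably sets the gallery's full performer list in one call, so callers no longer manage individual join rows.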

View File

@@ -8,7 +8,6 @@ import (
"github.com/stashapp/stash/pkg/manager/jsonschema" "github.com/stashapp/stash/pkg/manager/jsonschema"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/mocks" "github.com/stashapp/stash/pkg/models/mocks"
"github.com/stashapp/stash/pkg/models/modelstest"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
) )
@@ -77,17 +76,17 @@ func TestImporterPreImport(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
expectedGallery := models.Gallery{ expectedGallery := models.Gallery{
Path: modelstest.NullString(path), Path: models.NullString(path),
Checksum: checksum, Checksum: checksum,
Title: modelstest.NullString(title), Title: models.NullString(title),
Date: models.SQLiteDate{ Date: models.SQLiteDate{
String: date, String: date,
Valid: true, Valid: true,
}, },
Details: modelstest.NullString(details), Details: models.NullString(details),
Rating: modelstest.NullInt64(rating), Rating: models.NullInt64(rating),
Organized: organized, Organized: organized,
URL: modelstest.NullString(url), URL: models.NullString(url),
CreatedAt: models.SQLiteTimestamp{ CreatedAt: models.SQLiteTimestamp{
Timestamp: createdAt, Timestamp: createdAt,
}, },
@@ -194,7 +193,7 @@ func TestImporterPreImportWithPerformer(t *testing.T) {
performerReaderWriter.On("FindByNames", []string{existingPerformerName}, false).Return([]*models.Performer{ performerReaderWriter.On("FindByNames", []string{existingPerformerName}, false).Return([]*models.Performer{
{ {
ID: existingPerformerID, ID: existingPerformerID,
Name: modelstest.NullString(existingPerformerName), Name: models.NullString(existingPerformerName),
}, },
}, nil).Once() }, nil).Once()
performerReaderWriter.On("FindByNames", []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once() performerReaderWriter.On("FindByNames", []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once()
@@ -354,10 +353,10 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
} }
func TestImporterPostImportUpdatePerformers(t *testing.T) { func TestImporterPostImportUpdatePerformers(t *testing.T) {
joinReaderWriter := &mocks.JoinReaderWriter{} galleryReaderWriter := &mocks.GalleryReaderWriter{}
i := Importer{ i := Importer{
JoinWriter: joinReaderWriter, ReaderWriter: galleryReaderWriter,
performers: []*models.Performer{ performers: []*models.Performer{
{ {
ID: existingPerformerID, ID: existingPerformerID,
@@ -365,15 +364,10 @@ func TestImporterPostImportUpdatePerformers(t *testing.T) {
}, },
} }
updateErr := errors.New("UpdatePerformersGalleries error") updateErr := errors.New("UpdatePerformers error")
joinReaderWriter.On("UpdatePerformersGalleries", galleryID, []models.PerformersGalleries{ galleryReaderWriter.On("UpdatePerformers", galleryID, []int{existingPerformerID}).Return(nil).Once()
{ galleryReaderWriter.On("UpdatePerformers", errPerformersID, mock.AnythingOfType("[]int")).Return(updateErr).Once()
PerformerID: existingPerformerID,
GalleryID: galleryID,
},
}).Return(nil).Once()
joinReaderWriter.On("UpdatePerformersGalleries", errPerformersID, mock.AnythingOfType("[]models.PerformersGalleries")).Return(updateErr).Once()
err := i.PostImport(galleryID) err := i.PostImport(galleryID)
assert.Nil(t, err) assert.Nil(t, err)
@@ -381,14 +375,14 @@ func TestImporterPostImportUpdatePerformers(t *testing.T) {
err = i.PostImport(errPerformersID) err = i.PostImport(errPerformersID)
assert.NotNil(t, err) assert.NotNil(t, err)
joinReaderWriter.AssertExpectations(t) galleryReaderWriter.AssertExpectations(t)
} }
func TestImporterPostImportUpdateTags(t *testing.T) { func TestImporterPostImportUpdateTags(t *testing.T) {
joinReaderWriter := &mocks.JoinReaderWriter{} galleryReaderWriter := &mocks.GalleryReaderWriter{}
i := Importer{ i := Importer{
JoinWriter: joinReaderWriter, ReaderWriter: galleryReaderWriter,
tags: []*models.Tag{ tags: []*models.Tag{
{ {
ID: existingTagID, ID: existingTagID,
@@ -396,15 +390,10 @@ func TestImporterPostImportUpdateTags(t *testing.T) {
}, },
} }
updateErr := errors.New("UpdateGalleriesTags error") updateErr := errors.New("UpdateTags error")
joinReaderWriter.On("UpdateGalleriesTags", galleryID, []models.GalleriesTags{ galleryReaderWriter.On("UpdateTags", galleryID, []int{existingTagID}).Return(nil).Once()
{ galleryReaderWriter.On("UpdateTags", errTagsID, mock.AnythingOfType("[]int")).Return(updateErr).Once()
TagID: existingTagID,
GalleryID: galleryID,
},
}).Return(nil).Once()
joinReaderWriter.On("UpdateGalleriesTags", errTagsID, mock.AnythingOfType("[]models.GalleriesTags")).Return(updateErr).Once()
err := i.PostImport(galleryID) err := i.PostImport(galleryID)
assert.Nil(t, err) assert.Nil(t, err)
@@ -412,7 +401,7 @@ func TestImporterPostImportUpdateTags(t *testing.T) {
err = i.PostImport(errTagsID) err = i.PostImport(errTagsID)
assert.NotNil(t, err) assert.NotNil(t, err)
joinReaderWriter.AssertExpectations(t) galleryReaderWriter.AssertExpectations(t)
} }
func TestImporterFindExistingID(t *testing.T) { func TestImporterFindExistingID(t *testing.T) {
@@ -454,11 +443,11 @@ func TestCreate(t *testing.T) {
readerWriter := &mocks.GalleryReaderWriter{} readerWriter := &mocks.GalleryReaderWriter{}
gallery := models.Gallery{ gallery := models.Gallery{
Title: modelstest.NullString(title), Title: models.NullString(title),
} }
galleryErr := models.Gallery{ galleryErr := models.Gallery{
Title: modelstest.NullString(galleryNameErr), Title: models.NullString(galleryNameErr),
} }
i := Importer{ i := Importer{
@@ -488,7 +477,7 @@ func TestUpdate(t *testing.T) {
readerWriter := &mocks.GalleryReaderWriter{} readerWriter := &mocks.GalleryReaderWriter{}
gallery := models.Gallery{ gallery := models.Gallery{
Title: modelstest.NullString(title), Title: models.NullString(title),
} }
i := Importer{ i := Importer{

pkg/gallery/update.go (new normal file, 23 lines)
View File

@@ -0,0 +1,23 @@
package gallery
import (
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/utils"
)
func UpdateFileModTime(qb models.GalleryWriter, id int, modTime models.NullSQLiteTimestamp) (*models.Gallery, error) {
return qb.UpdatePartial(models.GalleryPartial{
ID: id,
FileModTime: &modTime,
})
}
func AddImage(qb models.GalleryReaderWriter, galleryID int, imageID int) error {
imageIDs, err := qb.GetImageIDs(galleryID)
if err != nil {
return err
}
imageIDs = utils.IntAppendUnique(imageIDs, imageID)
return qb.UpdateImages(galleryID, imageIDs)
}
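The new helpers take the gallery writer as a parameter, so they compose with whatever transaction the caller already holds rather than opening their own. A hedged usage sketch (the caller function, the IDs, and the assumption that Repository's Gallery() accessor returns a GalleryReaderWriter are all illustrative):

package example

import (
	"context"

	"github.com/stashapp/stash/pkg/gallery"
	"github.com/stashapp/stash/pkg/models"
)

// attachImage shows AddImage being composed inside a caller-owned write
// transaction; it is not part of this commit.
func attachImage(ctx context.Context, txnManager models.TransactionManager, galleryID, imageID int) error {
	return txnManager.WithTxn(ctx, func(repo models.Repository) error {
		// assumes repo.Gallery() returns a models.GalleryReaderWriter, as the
		// write paths elsewhere in this diff suggest
		return gallery.AddImage(repo.Gallery(), galleryID, imageID)
	})
}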

View File

@@ -6,7 +6,6 @@ import (
"github.com/stashapp/stash/pkg/manager/jsonschema" "github.com/stashapp/stash/pkg/manager/jsonschema"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/mocks" "github.com/stashapp/stash/pkg/models/mocks"
"github.com/stashapp/stash/pkg/models/modelstest"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"testing" "testing"
@@ -65,14 +64,14 @@ var updateTime time.Time = time.Date(2002, 01, 01, 0, 0, 0, 0, time.UTC)
func createFullImage(id int) models.Image { func createFullImage(id int) models.Image {
return models.Image{ return models.Image{
ID: id, ID: id,
Title: modelstest.NullString(title), Title: models.NullString(title),
Checksum: checksum, Checksum: checksum,
Height: modelstest.NullInt64(height), Height: models.NullInt64(height),
OCounter: ocounter, OCounter: ocounter,
Rating: modelstest.NullInt64(rating), Rating: models.NullInt64(rating),
Size: models.NullInt64(int64(size)),
Organized: organized, Organized: organized,
Size: modelstest.NullInt64(int64(size)), Width: models.NullInt64(width),
Width: modelstest.NullInt64(width),
CreatedAt: models.SQLiteTimestamp{ CreatedAt: models.SQLiteTimestamp{
Timestamp: createTime, Timestamp: createTime,
}, },
@@ -150,7 +149,7 @@ func TestToJSON(t *testing.T) {
func createStudioImage(studioID int) models.Image { func createStudioImage(studioID int) models.Image {
return models.Image{ return models.Image{
StudioID: modelstest.NullInt64(int64(studioID)), StudioID: models.NullInt64(int64(studioID)),
} }
} }
@@ -184,7 +183,7 @@ func TestGetStudioName(t *testing.T) {
studioErr := errors.New("error getting image") studioErr := errors.New("error getting image")
mockStudioReader.On("Find", studioID).Return(&models.Studio{ mockStudioReader.On("Find", studioID).Return(&models.Studio{
Name: modelstest.NullString(studioName), Name: models.NullString(studioName),
}, nil).Once() }, nil).Once()
mockStudioReader.On("Find", missingStudioID).Return(nil, nil).Once() mockStudioReader.On("Find", missingStudioID).Return(nil, nil).Once()
mockStudioReader.On("Find", errStudioID).Return(nil, studioErr).Once() mockStudioReader.On("Find", errStudioID).Return(nil, studioErr).Once()

View File

@@ -16,7 +16,6 @@ type Importer struct {
GalleryWriter models.GalleryReaderWriter GalleryWriter models.GalleryReaderWriter
PerformerWriter models.PerformerReaderWriter PerformerWriter models.PerformerReaderWriter
TagWriter models.TagReaderWriter TagWriter models.TagReaderWriter
JoinWriter models.JoinReaderWriter
Input jsonschema.Image Input jsonschema.Image
Path string Path string
MissingRefBehaviour models.ImportMissingRefEnum MissingRefBehaviour models.ImportMissingRefEnum
@@ -227,43 +226,33 @@ func (i *Importer) populateTags() error {
func (i *Importer) PostImport(id int) error { func (i *Importer) PostImport(id int) error {
if len(i.galleries) > 0 { if len(i.galleries) > 0 {
var galleryJoins []models.GalleriesImages var galleryIDs []int
for _, gallery := range i.galleries { for _, g := range i.galleries {
join := models.GalleriesImages{ galleryIDs = append(galleryIDs, g.ID)
GalleryID: gallery.ID,
ImageID: id,
}
galleryJoins = append(galleryJoins, join)
} }
if err := i.JoinWriter.UpdateGalleriesImages(id, galleryJoins); err != nil {
if err := i.ReaderWriter.UpdateGalleries(id, galleryIDs); err != nil {
return fmt.Errorf("failed to associate galleries: %s", err.Error()) return fmt.Errorf("failed to associate galleries: %s", err.Error())
} }
} }
if len(i.performers) > 0 { if len(i.performers) > 0 {
var performerJoins []models.PerformersImages var performerIDs []int
for _, performer := range i.performers { for _, performer := range i.performers {
join := models.PerformersImages{ performerIDs = append(performerIDs, performer.ID)
PerformerID: performer.ID,
ImageID: id,
}
performerJoins = append(performerJoins, join)
} }
if err := i.JoinWriter.UpdatePerformersImages(id, performerJoins); err != nil {
if err := i.ReaderWriter.UpdatePerformers(id, performerIDs); err != nil {
return fmt.Errorf("failed to associate performers: %s", err.Error()) return fmt.Errorf("failed to associate performers: %s", err.Error())
} }
} }
if len(i.tags) > 0 { if len(i.tags) > 0 {
var tagJoins []models.ImagesTags var tagIDs []int
for _, tag := range i.tags { for _, t := range i.tags {
join := models.ImagesTags{ tagIDs = append(tagIDs, t.ID)
ImageID: id,
TagID: tag.ID,
}
tagJoins = append(tagJoins, join)
} }
if err := i.JoinWriter.UpdateImagesTags(id, tagJoins); err != nil { if err := i.ReaderWriter.UpdateTags(id, tagIDs); err != nil {
return fmt.Errorf("failed to associate tags: %s", err.Error()) return fmt.Errorf("failed to associate tags: %s", err.Error())
} }
} }

View File

@@ -7,7 +7,6 @@ import (
"github.com/stashapp/stash/pkg/manager/jsonschema" "github.com/stashapp/stash/pkg/manager/jsonschema"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/mocks" "github.com/stashapp/stash/pkg/models/mocks"
"github.com/stashapp/stash/pkg/models/modelstest"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
) )
@@ -227,7 +226,7 @@ func TestImporterPreImportWithPerformer(t *testing.T) {
performerReaderWriter.On("FindByNames", []string{existingPerformerName}, false).Return([]*models.Performer{ performerReaderWriter.On("FindByNames", []string{existingPerformerName}, false).Return([]*models.Performer{
{ {
ID: existingPerformerID, ID: existingPerformerID,
Name: modelstest.NullString(existingPerformerName), Name: models.NullString(existingPerformerName),
}, },
}, nil).Once() }, nil).Once()
performerReaderWriter.On("FindByNames", []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once() performerReaderWriter.On("FindByNames", []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once()
@@ -387,10 +386,10 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
} }
func TestImporterPostImportUpdateGallery(t *testing.T) { func TestImporterPostImportUpdateGallery(t *testing.T) {
joinReaderWriter := &mocks.JoinReaderWriter{} readerWriter := &mocks.ImageReaderWriter{}
i := Importer{ i := Importer{
JoinWriter: joinReaderWriter, ReaderWriter: readerWriter,
galleries: []*models.Gallery{ galleries: []*models.Gallery{
{ {
ID: existingGalleryID, ID: existingGalleryID,
@@ -398,15 +397,10 @@ func TestImporterPostImportUpdateGallery(t *testing.T) {
}, },
} }
updateErr := errors.New("UpdateGalleriesImages error") updateErr := errors.New("UpdateGalleries error")
joinReaderWriter.On("UpdateGalleriesImages", imageID, []models.GalleriesImages{ readerWriter.On("UpdateGalleries", imageID, []int{existingGalleryID}).Return(nil).Once()
{ readerWriter.On("UpdateGalleries", errGalleriesID, mock.AnythingOfType("[]int")).Return(updateErr).Once()
GalleryID: existingGalleryID,
ImageID: imageID,
},
}).Return(nil).Once()
joinReaderWriter.On("UpdateGalleriesImages", errGalleriesID, mock.AnythingOfType("[]models.GalleriesImages")).Return(updateErr).Once()
err := i.PostImport(imageID) err := i.PostImport(imageID)
assert.Nil(t, err) assert.Nil(t, err)
@@ -414,14 +408,14 @@ func TestImporterPostImportUpdateGallery(t *testing.T) {
err = i.PostImport(errGalleriesID) err = i.PostImport(errGalleriesID)
assert.NotNil(t, err) assert.NotNil(t, err)
joinReaderWriter.AssertExpectations(t) readerWriter.AssertExpectations(t)
} }
func TestImporterPostImportUpdatePerformers(t *testing.T) { func TestImporterPostImportUpdatePerformers(t *testing.T) {
joinReaderWriter := &mocks.JoinReaderWriter{} readerWriter := &mocks.ImageReaderWriter{}
i := Importer{ i := Importer{
JoinWriter: joinReaderWriter, ReaderWriter: readerWriter,
performers: []*models.Performer{ performers: []*models.Performer{
{ {
ID: existingPerformerID, ID: existingPerformerID,
@@ -429,15 +423,10 @@ func TestImporterPostImportUpdatePerformers(t *testing.T) {
}, },
} }
updateErr := errors.New("UpdatePerformersImages error") updateErr := errors.New("UpdatePerformers error")
joinReaderWriter.On("UpdatePerformersImages", imageID, []models.PerformersImages{ readerWriter.On("UpdatePerformers", imageID, []int{existingPerformerID}).Return(nil).Once()
{ readerWriter.On("UpdatePerformers", errPerformersID, mock.AnythingOfType("[]int")).Return(updateErr).Once()
PerformerID: existingPerformerID,
ImageID: imageID,
},
}).Return(nil).Once()
joinReaderWriter.On("UpdatePerformersImages", errPerformersID, mock.AnythingOfType("[]models.PerformersImages")).Return(updateErr).Once()
err := i.PostImport(imageID) err := i.PostImport(imageID)
assert.Nil(t, err) assert.Nil(t, err)
@@ -445,14 +434,14 @@ func TestImporterPostImportUpdatePerformers(t *testing.T) {
err = i.PostImport(errPerformersID) err = i.PostImport(errPerformersID)
assert.NotNil(t, err) assert.NotNil(t, err)
joinReaderWriter.AssertExpectations(t) readerWriter.AssertExpectations(t)
} }
func TestImporterPostImportUpdateTags(t *testing.T) { func TestImporterPostImportUpdateTags(t *testing.T) {
joinReaderWriter := &mocks.JoinReaderWriter{} readerWriter := &mocks.ImageReaderWriter{}
i := Importer{ i := Importer{
JoinWriter: joinReaderWriter, ReaderWriter: readerWriter,
tags: []*models.Tag{ tags: []*models.Tag{
{ {
ID: existingTagID, ID: existingTagID,
@@ -460,15 +449,10 @@ func TestImporterPostImportUpdateTags(t *testing.T) {
}, },
} }
updateErr := errors.New("UpdateImagesTags error") updateErr := errors.New("UpdateTags error")
joinReaderWriter.On("UpdateImagesTags", imageID, []models.ImagesTags{ readerWriter.On("UpdateTags", imageID, []int{existingTagID}).Return(nil).Once()
{ readerWriter.On("UpdateTags", errTagsID, mock.AnythingOfType("[]int")).Return(updateErr).Once()
TagID: existingTagID,
ImageID: imageID,
},
}).Return(nil).Once()
joinReaderWriter.On("UpdateImagesTags", errTagsID, mock.AnythingOfType("[]models.ImagesTags")).Return(updateErr).Once()
err := i.PostImport(imageID) err := i.PostImport(imageID)
assert.Nil(t, err) assert.Nil(t, err)
@@ -476,7 +460,7 @@ func TestImporterPostImportUpdateTags(t *testing.T) {
err = i.PostImport(errTagsID) err = i.PostImport(errTagsID)
assert.NotNil(t, err) assert.NotNil(t, err)
joinReaderWriter.AssertExpectations(t) readerWriter.AssertExpectations(t)
} }
func TestImporterFindExistingID(t *testing.T) { func TestImporterFindExistingID(t *testing.T) {
@@ -518,11 +502,11 @@ func TestCreate(t *testing.T) {
readerWriter := &mocks.ImageReaderWriter{} readerWriter := &mocks.ImageReaderWriter{}
image := models.Image{ image := models.Image{
Title: modelstest.NullString(title), Title: models.NullString(title),
} }
imageErr := models.Image{ imageErr := models.Image{
Title: modelstest.NullString(imageNameErr), Title: models.NullString(imageNameErr),
} }
i := Importer{ i := Importer{
@@ -553,11 +537,11 @@ func TestUpdate(t *testing.T) {
readerWriter := &mocks.ImageReaderWriter{} readerWriter := &mocks.ImageReaderWriter{}
image := models.Image{ image := models.Image{
Title: modelstest.NullString(title), Title: models.NullString(title),
} }
imageErr := models.Image{ imageErr := models.Image{
Title: modelstest.NullString(imageNameErr), Title: models.NullString(imageNameErr),
} }
i := Importer{ i := Importer{

pkg/image/update.go (new file)
View File

@@ -0,0 +1,10 @@
package image
import "github.com/stashapp/stash/pkg/models"
func UpdateFileModTime(qb models.ImageWriter, id int, modTime models.NullSQLiteTimestamp) (*models.Image, error) {
return qb.Update(models.ImagePartial{
ID: id,
FileModTime: &modTime,
})
}
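The new helper is a thin wrapper, so its intended call pattern is worth spelling out. Below is a minimal sketch of how a caller might invoke it inside a write transaction; the fileInfo variable, the imageID value and the Timestamp/Valid field names on NullSQLiteTimestamp are assumptions for illustration, not part of this commit.

    // hypothetical caller: record an image's file modification time
    // inside a write transaction provided by the new transaction manager
    if err := txnManager.WithTxn(ctx, func(r models.Repository) error {
        modTime := models.NullSQLiteTimestamp{
            Timestamp: fileInfo.ModTime(), // fileInfo: an os.FileInfo from the scan (assumed)
            Valid:     true,
        }
        _, err := image.UpdateFileModTime(r.Image(), imageID, modTime)
        return err
    }); err != nil {
        logger.Error(err.Error())
    }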

View File

@@ -1,6 +1,7 @@
package manager package manager
import ( import (
"context"
"errors" "errors"
"github.com/spf13/viper" "github.com/spf13/viper"
@@ -9,13 +10,16 @@ import (
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
) )
func setInitialMD5Config() { func setInitialMD5Config(txnManager models.TransactionManager) {
// if there are no scene files in the database, then default the // if there are no scene files in the database, then default the
// VideoFileNamingAlgorithm config setting to oshash and calculateMD5 to // VideoFileNamingAlgorithm config setting to oshash and calculateMD5 to
// false, otherwise set them to true for backwards compatibility purposes // false, otherwise set them to true for backwards compatibility purposes
sqb := models.NewSceneQueryBuilder() var count int
count, err := sqb.Count() if err := txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
if err != nil { var err error
count, err = r.Scene().Count()
return err
}); err != nil {
logger.Errorf("Error while counting scenes: %s", err.Error()) logger.Errorf("Error while counting scenes: %s", err.Error())
return return
} }
@@ -43,28 +47,30 @@ func setInitialMD5Config() {
// //
// Likewise, if VideoFileNamingAlgorithm is set to oshash, then this function // Likewise, if VideoFileNamingAlgorithm is set to oshash, then this function
// will ensure that all oshash values are set on all scenes. // will ensure that all oshash values are set on all scenes.
func ValidateVideoFileNamingAlgorithm(newValue models.HashAlgorithm) error { func ValidateVideoFileNamingAlgorithm(txnManager models.TransactionManager, newValue models.HashAlgorithm) error {
// if algorithm is being set to MD5, then all checksums must be present // if algorithm is being set to MD5, then all checksums must be present
qb := models.NewSceneQueryBuilder() return txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
if newValue == models.HashAlgorithmMd5 { qb := r.Scene()
missingMD5, err := qb.CountMissingChecksum() if newValue == models.HashAlgorithmMd5 {
if err != nil { missingMD5, err := qb.CountMissingChecksum()
return err if err != nil {
return err
}
if missingMD5 > 0 {
return errors.New("some checksums are missing on scenes. Run Scan with calculateMD5 set to true")
}
} else if newValue == models.HashAlgorithmOshash {
missingOSHash, err := qb.CountMissingOSHash()
if err != nil {
return err
}
if missingOSHash > 0 {
return errors.New("some oshash values are missing on scenes. Run Scan to populate")
}
} }
if missingMD5 > 0 { return nil
return errors.New("some checksums are missing on scenes. Run Scan with calculateMD5 set to true") })
}
} else if newValue == models.HashAlgorithmOshash {
missingOSHash, err := qb.CountMissingOSHash()
if err != nil {
return err
}
if missingOSHash > 0 {
return errors.New("some oshash values are missing on scenes. Run Scan to populate")
}
}
return nil
} }
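A hedged sketch of how the reworked validator is expected to be driven: the caller hands in the shared transaction manager rather than the function building its own query builder. The surrounding handler and variable names are illustrative only.

    newAlgo := models.HashAlgorithmOshash
    if err := ValidateVideoFileNamingAlgorithm(GetInstance().TxnManager, newAlgo); err != nil {
        // some scenes lack the required hash; reject the config change
        return err
    }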

View File

@@ -433,12 +433,6 @@ type SceneFilenameParser struct {
studioCache map[string]*models.Studio studioCache map[string]*models.Studio
movieCache map[string]*models.Movie movieCache map[string]*models.Movie
tagCache map[string]*models.Tag tagCache map[string]*models.Tag
performerQuery performerQueryer
sceneQuery sceneQueryer
tagQuery tagQueryer
studioQuery studioQueryer
movieQuery movieQueryer
} }
func NewSceneFilenameParser(filter *models.FindFilterType, config models.SceneParserInput) *SceneFilenameParser { func NewSceneFilenameParser(filter *models.FindFilterType, config models.SceneParserInput) *SceneFilenameParser {
@@ -455,21 +449,6 @@ func NewSceneFilenameParser(filter *models.FindFilterType, config models.ScenePa
p.initWhiteSpaceRegex() p.initWhiteSpaceRegex()
performerQuery := models.NewPerformerQueryBuilder()
p.performerQuery = &performerQuery
sceneQuery := models.NewSceneQueryBuilder()
p.sceneQuery = &sceneQuery
tagQuery := models.NewTagQueryBuilder()
p.tagQuery = &tagQuery
studioQuery := models.NewStudioQueryBuilder()
p.studioQuery = &studioQuery
movieQuery := models.NewMovieQueryBuilder()
p.movieQuery = &movieQuery
return p return p
} }
@@ -489,7 +468,7 @@ func (p *SceneFilenameParser) initWhiteSpaceRegex() {
} }
} }
func (p *SceneFilenameParser) Parse() ([]*models.SceneParserResult, int, error) { func (p *SceneFilenameParser) Parse(repo models.ReaderRepository) ([]*models.SceneParserResult, int, error) {
// perform the query to find the scenes // perform the query to find the scenes
mapper, err := newParseMapper(p.Pattern, p.ParserInput.IgnoreWords) mapper, err := newParseMapper(p.Pattern, p.ParserInput.IgnoreWords)
@@ -499,14 +478,17 @@ func (p *SceneFilenameParser) Parse() ([]*models.SceneParserResult, int, error)
p.Filter.Q = &mapper.regexString p.Filter.Q = &mapper.regexString
scenes, total := p.sceneQuery.QueryByPathRegex(p.Filter) scenes, total, err := repo.Scene().QueryByPathRegex(p.Filter)
if err != nil {
return nil, 0, err
}
ret := p.parseScenes(scenes, mapper) ret := p.parseScenes(repo, scenes, mapper)
return ret, total, nil return ret, total, nil
} }
func (p *SceneFilenameParser) parseScenes(scenes []*models.Scene, mapper *parseMapper) []*models.SceneParserResult { func (p *SceneFilenameParser) parseScenes(repo models.ReaderRepository, scenes []*models.Scene, mapper *parseMapper) []*models.SceneParserResult {
var ret []*models.SceneParserResult var ret []*models.SceneParserResult
for _, scene := range scenes { for _, scene := range scenes {
sceneHolder := mapper.parse(scene) sceneHolder := mapper.parse(scene)
@@ -515,7 +497,7 @@ func (p *SceneFilenameParser) parseScenes(scenes []*models.Scene, mapper *parseM
r := &models.SceneParserResult{ r := &models.SceneParserResult{
Scene: scene, Scene: scene,
} }
p.setParserResult(*sceneHolder, r) p.setParserResult(repo, *sceneHolder, r)
if r != nil { if r != nil {
ret = append(ret, r) ret = append(ret, r)
@@ -536,7 +518,7 @@ func (p SceneFilenameParser) replaceWhitespaceCharacters(value string) string {
return value return value
} }
func (p *SceneFilenameParser) queryPerformer(performerName string) *models.Performer { func (p *SceneFilenameParser) queryPerformer(qb models.PerformerReader, performerName string) *models.Performer {
// massage the performer name // massage the performer name
performerName = delimiterRE.ReplaceAllString(performerName, " ") performerName = delimiterRE.ReplaceAllString(performerName, " ")
@@ -546,7 +528,7 @@ func (p *SceneFilenameParser) queryPerformer(performerName string) *models.Perfo
} }
// perform an exact match and grab the first // perform an exact match and grab the first
performers, _ := p.performerQuery.FindByNames([]string{performerName}, nil, true) performers, _ := qb.FindByNames([]string{performerName}, true)
var ret *models.Performer var ret *models.Performer
if len(performers) > 0 { if len(performers) > 0 {
@@ -559,7 +541,7 @@ func (p *SceneFilenameParser) queryPerformer(performerName string) *models.Perfo
return ret return ret
} }
func (p *SceneFilenameParser) queryStudio(studioName string) *models.Studio { func (p *SceneFilenameParser) queryStudio(qb models.StudioReader, studioName string) *models.Studio {
// massage the studio name // massage the studio name
studioName = delimiterRE.ReplaceAllString(studioName, " ") studioName = delimiterRE.ReplaceAllString(studioName, " ")
@@ -568,7 +550,7 @@ func (p *SceneFilenameParser) queryStudio(studioName string) *models.Studio {
return ret return ret
} }
ret, _ := p.studioQuery.FindByName(studioName, nil, true) ret, _ := qb.FindByName(studioName, true)
// add result to cache // add result to cache
p.studioCache[studioName] = ret p.studioCache[studioName] = ret
@@ -576,7 +558,7 @@ func (p *SceneFilenameParser) queryStudio(studioName string) *models.Studio {
return ret return ret
} }
func (p *SceneFilenameParser) queryMovie(movieName string) *models.Movie { func (p *SceneFilenameParser) queryMovie(qb models.MovieReader, movieName string) *models.Movie {
// massage the movie name // massage the movie name
movieName = delimiterRE.ReplaceAllString(movieName, " ") movieName = delimiterRE.ReplaceAllString(movieName, " ")
@@ -585,7 +567,7 @@ func (p *SceneFilenameParser) queryMovie(movieName string) *models.Movie {
return ret return ret
} }
ret, _ := p.movieQuery.FindByName(movieName, nil, true) ret, _ := qb.FindByName(movieName, true)
// add result to cache // add result to cache
p.movieCache[movieName] = ret p.movieCache[movieName] = ret
@@ -593,7 +575,7 @@ func (p *SceneFilenameParser) queryMovie(movieName string) *models.Movie {
return ret return ret
} }
func (p *SceneFilenameParser) queryTag(tagName string) *models.Tag { func (p *SceneFilenameParser) queryTag(qb models.TagReader, tagName string) *models.Tag {
// massage the tag name // massage the tag name
tagName = delimiterRE.ReplaceAllString(tagName, " ") tagName = delimiterRE.ReplaceAllString(tagName, " ")
@@ -603,7 +585,7 @@ func (p *SceneFilenameParser) queryTag(tagName string) *models.Tag {
} }
// match tag name exactly // match tag name exactly
ret, _ := p.tagQuery.FindByName(tagName, nil, true) ret, _ := qb.FindByName(tagName, true)
// add result to cache // add result to cache
p.tagCache[tagName] = ret p.tagCache[tagName] = ret
@@ -611,12 +593,12 @@ func (p *SceneFilenameParser) queryTag(tagName string) *models.Tag {
return ret return ret
} }
func (p *SceneFilenameParser) setPerformers(h sceneHolder, result *models.SceneParserResult) { func (p *SceneFilenameParser) setPerformers(qb models.PerformerReader, h sceneHolder, result *models.SceneParserResult) {
// query for each performer // query for each performer
performersSet := make(map[int]bool) performersSet := make(map[int]bool)
for _, performerName := range h.performers { for _, performerName := range h.performers {
if performerName != "" { if performerName != "" {
performer := p.queryPerformer(performerName) performer := p.queryPerformer(qb, performerName)
if performer != nil { if performer != nil {
if _, found := performersSet[performer.ID]; !found { if _, found := performersSet[performer.ID]; !found {
result.PerformerIds = append(result.PerformerIds, strconv.Itoa(performer.ID)) result.PerformerIds = append(result.PerformerIds, strconv.Itoa(performer.ID))
@@ -627,12 +609,12 @@ func (p *SceneFilenameParser) setPerformers(h sceneHolder, result *models.SceneP
} }
} }
func (p *SceneFilenameParser) setTags(h sceneHolder, result *models.SceneParserResult) { func (p *SceneFilenameParser) setTags(qb models.TagReader, h sceneHolder, result *models.SceneParserResult) {
// query for each tag // query for each tag
tagsSet := make(map[int]bool) tagsSet := make(map[int]bool)
for _, tagName := range h.tags { for _, tagName := range h.tags {
if tagName != "" { if tagName != "" {
tag := p.queryTag(tagName) tag := p.queryTag(qb, tagName)
if tag != nil { if tag != nil {
if _, found := tagsSet[tag.ID]; !found { if _, found := tagsSet[tag.ID]; !found {
result.TagIds = append(result.TagIds, strconv.Itoa(tag.ID)) result.TagIds = append(result.TagIds, strconv.Itoa(tag.ID))
@@ -643,23 +625,23 @@ func (p *SceneFilenameParser) setTags(h sceneHolder, result *models.SceneParserR
} }
} }
func (p *SceneFilenameParser) setStudio(h sceneHolder, result *models.SceneParserResult) { func (p *SceneFilenameParser) setStudio(qb models.StudioReader, h sceneHolder, result *models.SceneParserResult) {
// query for the studio // query for the studio
if h.studio != "" { if h.studio != "" {
studio := p.queryStudio(h.studio) studio := p.queryStudio(qb, h.studio)
if studio != nil { if studio != nil {
studioId := strconv.Itoa(studio.ID) studioID := strconv.Itoa(studio.ID)
result.StudioID = &studioId result.StudioID = &studioID
} }
} }
} }
func (p *SceneFilenameParser) setMovies(h sceneHolder, result *models.SceneParserResult) { func (p *SceneFilenameParser) setMovies(qb models.MovieReader, h sceneHolder, result *models.SceneParserResult) {
// query for each movie // query for each movie
moviesSet := make(map[int]bool) moviesSet := make(map[int]bool)
for _, movieName := range h.movies { for _, movieName := range h.movies {
if movieName != "" { if movieName != "" {
movie := p.queryMovie(movieName) movie := p.queryMovie(qb, movieName)
if movie != nil { if movie != nil {
if _, found := moviesSet[movie.ID]; !found { if _, found := moviesSet[movie.ID]; !found {
result.Movies = append(result.Movies, &models.SceneMovieID{ result.Movies = append(result.Movies, &models.SceneMovieID{
@@ -672,7 +654,7 @@ func (p *SceneFilenameParser) setMovies(h sceneHolder, result *models.SceneParse
} }
} }
func (p *SceneFilenameParser) setParserResult(h sceneHolder, result *models.SceneParserResult) { func (p *SceneFilenameParser) setParserResult(repo models.ReaderRepository, h sceneHolder, result *models.SceneParserResult) {
if h.result.Title.Valid { if h.result.Title.Valid {
title := h.result.Title.String title := h.result.Title.String
title = p.replaceWhitespaceCharacters(title) title = p.replaceWhitespaceCharacters(title)
@@ -694,15 +676,15 @@ func (p *SceneFilenameParser) setParserResult(h sceneHolder, result *models.Scen
} }
if len(h.performers) > 0 { if len(h.performers) > 0 {
p.setPerformers(h, result) p.setPerformers(repo.Performer(), h, result)
} }
if len(h.tags) > 0 { if len(h.tags) > 0 {
p.setTags(h, result) p.setTags(repo.Tag(), h, result)
} }
p.setStudio(h, result) p.setStudio(repo.Studio(), h, result)
if len(h.movies) > 0 { if len(h.movies) > 0 {
p.setMovies(h, result) p.setMovies(repo.Movie(), h, result)
} }
} }
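Parse now borrows its readers from a ReaderRepository rather than owning query builders, so callers are expected to run it inside a read transaction. A minimal sketch, with the filter, parser input and transaction manager as placeholders:

    p := NewSceneFilenameParser(filter, parserInput)
    var (
        results []*models.SceneParserResult
        count   int
    )
    if err := txnManager.WithReadTxn(ctx, func(r models.ReaderRepository) error {
        var err error
        results, count, err = p.Parse(r)
        return err
    }); err != nil {
        return nil, 0, err
    }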

View File

@@ -5,43 +5,11 @@ import (
"os" "os"
"strings" "strings"
"github.com/jmoiron/sqlx"
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
// DestroyImage deletes an image and its associated relationships from the
// database.
func DestroyImage(imageID int, tx *sqlx.Tx) error {
qb := models.NewImageQueryBuilder()
jqb := models.NewJoinsQueryBuilder()
_, err := qb.Find(imageID)
if err != nil {
return err
}
if err := jqb.DestroyImagesTags(imageID, tx); err != nil {
return err
}
if err := jqb.DestroyPerformersImages(imageID, tx); err != nil {
return err
}
if err := jqb.DestroyImageGalleries(imageID, tx); err != nil {
return err
}
if err := qb.Destroy(imageID, tx); err != nil {
return err
}
return nil
}
// DeleteGeneratedImageFiles deletes generated files for the provided image. // DeleteGeneratedImageFiles deletes generated files for the provided image.
func DeleteGeneratedImageFiles(image *models.Image) { func DeleteGeneratedImageFiles(image *models.Image) {
thumbPath := GetInstance().Paths.Generated.GetThumbnailPath(image.Checksum, models.DefaultGthumbWidth) thumbPath := GetInstance().Paths.Generated.GetThumbnailPath(image.Checksum, models.DefaultGthumbWidth)

View File

@@ -10,8 +10,10 @@ import (
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/manager/config" "github.com/stashapp/stash/pkg/manager/config"
"github.com/stashapp/stash/pkg/manager/paths" "github.com/stashapp/stash/pkg/manager/paths"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin" "github.com/stashapp/stash/pkg/plugin"
"github.com/stashapp/stash/pkg/scraper" "github.com/stashapp/stash/pkg/scraper"
"github.com/stashapp/stash/pkg/sqlite"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
@@ -26,6 +28,8 @@ type singleton struct {
ScraperCache *scraper.Cache ScraperCache *scraper.Cache
DownloadStore *DownloadStore DownloadStore *DownloadStore
TxnManager models.TransactionManager
} }
var instance *singleton var instance *singleton
@@ -53,11 +57,12 @@ func Initialize() *singleton {
Status: TaskStatus{Status: Idle, Progress: -1}, Status: TaskStatus{Status: Idle, Progress: -1},
Paths: paths.NewPaths(), Paths: paths.NewPaths(),
PluginCache: initPluginCache(), PluginCache: initPluginCache(),
ScraperCache: initScraperCache(),
DownloadStore: NewDownloadStore(), DownloadStore: NewDownloadStore(),
TxnManager: sqlite.NewTransactionManager(),
} }
instance.ScraperCache = instance.initScraperCache()
instance.RefreshConfig() instance.RefreshConfig()
@@ -180,13 +185,13 @@ func initPluginCache() *plugin.Cache {
} }
// initScraperCache initializes a new scraper cache and returns it. // initScraperCache initializes a new scraper cache and returns it.
func initScraperCache() *scraper.Cache { func (s *singleton) initScraperCache() *scraper.Cache {
scraperConfig := scraper.GlobalConfig{ scraperConfig := scraper.GlobalConfig{
Path: config.GetScrapersPath(), Path: config.GetScrapersPath(),
UserAgent: config.GetScraperUserAgent(), UserAgent: config.GetScraperUserAgent(),
CDPPath: config.GetScraperCDPPath(), CDPPath: config.GetScraperCDPPath(),
} }
ret, err := scraper.NewCache(scraperConfig) ret, err := scraper.NewCache(scraperConfig, s.TxnManager)
if err != nil { if err != nil {
logger.Errorf("Error reading scraper configs: %s", err.Error()) logger.Errorf("Error reading scraper configs: %s", err.Error())
@@ -210,5 +215,5 @@ func (s *singleton) RefreshConfig() {
// RefreshScraperCache refreshes the scraper cache. Call this when scraper // RefreshScraperCache refreshes the scraper cache. Call this when scraper
// configuration changes. // configuration changes.
func (s *singleton) RefreshScraperCache() { func (s *singleton) RefreshScraperCache() {
s.ScraperCache = initScraperCache() s.ScraperCache = s.initScraperCache()
} }
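With TxnManager held on the singleton, the rest of package manager reaches the database through it instead of through package-level query builders. The read path looks like the following sketch; the scene count is just an example of a repository call shown elsewhere in this commit.

    var sceneCount int
    if err := GetInstance().TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
        var err error
        sceneCount, err = r.Scene().Count()
        return err
    }); err != nil {
        logger.Errorf("error counting scenes: %s", err.Error())
    }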

View File

@@ -1,6 +1,7 @@
package manager package manager
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"os" "os"
@@ -119,7 +120,7 @@ func (s *singleton) neededScan(paths []*models.StashConfig) (total *int, newFile
for _, sp := range paths { for _, sp := range paths {
err := walkFilesToScan(sp, func(path string, info os.FileInfo, err error) error { err := walkFilesToScan(sp, func(path string, info os.FileInfo, err error) error {
t++ t++
task := ScanTask{FilePath: path} task := ScanTask{FilePath: path, TxnManager: s.TxnManager}
if !task.doesPathExist() { if !task.doesPathExist() {
n++ n++
} }
@@ -211,7 +212,17 @@ func (s *singleton) Scan(input models.ScanMetadataInput) {
instance.Paths.Generated.EnsureTmpDir() instance.Paths.Generated.EnsureTmpDir()
wg.Add() wg.Add()
task := ScanTask{FilePath: path, UseFileMetadata: input.UseFileMetadata, StripFileExtension: input.StripFileExtension, fileNamingAlgorithm: fileNamingAlgo, calculateMD5: calculateMD5, GeneratePreview: input.ScanGeneratePreviews, GenerateImagePreview: input.ScanGenerateImagePreviews, GenerateSprite: input.ScanGenerateSprites} task := ScanTask{
TxnManager: s.TxnManager,
FilePath: path,
UseFileMetadata: input.UseFileMetadata,
StripFileExtension: input.StripFileExtension,
fileNamingAlgorithm: fileNamingAlgo,
calculateMD5: calculateMD5,
GeneratePreview: input.ScanGeneratePreviews,
GenerateImagePreview: input.ScanGenerateImagePreviews,
GenerateSprite: input.ScanGenerateSprites,
}
go task.Start(&wg) go task.Start(&wg)
return nil return nil
@@ -240,7 +251,11 @@ func (s *singleton) Scan(input models.ScanMetadataInput) {
for _, path := range galleries { for _, path := range galleries {
wg.Add() wg.Add()
task := ScanTask{FilePath: path, UseFileMetadata: false} task := ScanTask{
TxnManager: s.TxnManager,
FilePath: path,
UseFileMetadata: false,
}
go task.associateGallery(&wg) go task.associateGallery(&wg)
wg.Wait() wg.Wait()
} }
@@ -261,6 +276,7 @@ func (s *singleton) Import() {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
task := ImportTask{ task := ImportTask{
txnManager: s.TxnManager,
BaseDir: config.GetMetadataPath(), BaseDir: config.GetMetadataPath(),
Reset: true, Reset: true,
DuplicateBehaviour: models.ImportDuplicateEnumFail, DuplicateBehaviour: models.ImportDuplicateEnumFail,
@@ -284,7 +300,11 @@ func (s *singleton) Export() {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
task := ExportTask{full: true, fileNamingAlgorithm: config.GetVideoFileNamingAlgorithm()} task := ExportTask{
txnManager: s.TxnManager,
full: true,
fileNamingAlgorithm: config.GetVideoFileNamingAlgorithm(),
}
go task.Start(&wg) go task.Start(&wg)
wg.Wait() wg.Wait()
}() }()
@@ -344,29 +364,47 @@ func (s *singleton) Generate(input models.GenerateMetadataInput) {
s.Status.SetStatus(Generate) s.Status.SetStatus(Generate)
s.Status.indefiniteProgress() s.Status.indefiniteProgress()
qb := models.NewSceneQueryBuilder()
mqb := models.NewSceneMarkerQueryBuilder()
//this.job.total = await ObjectionUtils.getCount(Scene); //this.job.total = await ObjectionUtils.getCount(Scene);
instance.Paths.Generated.EnsureTmpDir() instance.Paths.Generated.EnsureTmpDir()
sceneIDs := utils.StringSliceToIntSlice(input.SceneIDs) sceneIDs, err := utils.StringSliceToIntSlice(input.SceneIDs)
markerIDs := utils.StringSliceToIntSlice(input.MarkerIDs) if err != nil {
logger.Error(err.Error())
}
markerIDs, err := utils.StringSliceToIntSlice(input.MarkerIDs)
if err != nil {
logger.Error(err.Error())
}
go func() { go func() {
defer s.returnToIdleState() defer s.returnToIdleState()
var scenes []*models.Scene var scenes []*models.Scene
var err error var err error
var markers []*models.SceneMarker
if len(sceneIDs) > 0 { if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
scenes, err = qb.FindMany(sceneIDs) qb := r.Scene()
} else { if len(sceneIDs) > 0 {
scenes, err = qb.All() scenes, err = qb.FindMany(sceneIDs)
} } else {
scenes, err = qb.All()
}
if err != nil { if err != nil {
logger.Errorf("failed to get scenes for generate") return err
}
if len(markerIDs) > 0 {
markers, err = r.SceneMarker().FindMany(markerIDs)
if err != nil {
return err
}
}
return nil
}); err != nil {
logger.Error(err.Error())
return return
} }
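StringSliceToIntSlice now returns an error (per the commit summary) instead of failing silently, which is why the Generate code above logs the error and carries on with whatever parsed. A small illustrative call with made-up input:

    ids, err := utils.StringSliceToIntSlice([]string{"12", "34", "not-a-number"})
    if err != nil {
        // previously there was no error to inspect; the caller now decides how to react
        logger.Error(err.Error())
    }
    logger.Infof("parsed %d ids", len(ids))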
@@ -377,14 +415,7 @@ func (s *singleton) Generate(input models.GenerateMetadataInput) {
s.Status.Progress = 0 s.Status.Progress = 0
lenScenes := len(scenes) lenScenes := len(scenes)
total := lenScenes total := lenScenes + len(markers)
var markers []*models.SceneMarker
if len(markerIDs) > 0 {
markers, err = mqb.FindMany(markerIDs)
total += len(markers)
}
if s.Status.stopping { if s.Status.stopping {
logger.Info("Stopping due to user request") logger.Info("Stopping due to user request")
@@ -429,7 +460,11 @@ func (s *singleton) Generate(input models.GenerateMetadataInput) {
} }
if input.Sprites { if input.Sprites {
task := GenerateSpriteTask{Scene: *scene, Overwrite: overwrite, fileNamingAlgorithm: fileNamingAlgo} task := GenerateSpriteTask{
Scene: *scene,
Overwrite: overwrite,
fileNamingAlgorithm: fileNamingAlgo,
}
wg.Add() wg.Add()
go task.Start(&wg) go task.Start(&wg)
} }
@@ -448,13 +483,22 @@ func (s *singleton) Generate(input models.GenerateMetadataInput) {
if input.Markers { if input.Markers {
wg.Add() wg.Add()
task := GenerateMarkersTask{Scene: scene, Overwrite: overwrite, fileNamingAlgorithm: fileNamingAlgo} task := GenerateMarkersTask{
TxnManager: s.TxnManager,
Scene: scene,
Overwrite: overwrite,
fileNamingAlgorithm: fileNamingAlgo,
}
go task.Start(&wg) go task.Start(&wg)
} }
if input.Transcodes { if input.Transcodes {
wg.Add() wg.Add()
task := GenerateTranscodeTask{Scene: *scene, Overwrite: overwrite, fileNamingAlgorithm: fileNamingAlgo} task := GenerateTranscodeTask{
Scene: *scene,
Overwrite: overwrite,
fileNamingAlgorithm: fileNamingAlgo,
}
go task.Start(&wg) go task.Start(&wg)
} }
} }
@@ -474,7 +518,12 @@ func (s *singleton) Generate(input models.GenerateMetadataInput) {
} }
wg.Add() wg.Add()
task := GenerateMarkersTask{Marker: marker, Overwrite: overwrite, fileNamingAlgorithm: fileNamingAlgo} task := GenerateMarkersTask{
TxnManager: s.TxnManager,
Marker: marker,
Overwrite: overwrite,
fileNamingAlgorithm: fileNamingAlgo,
}
go task.Start(&wg) go task.Start(&wg)
} }
@@ -502,7 +551,6 @@ func (s *singleton) generateScreenshot(sceneId string, at *float64) {
s.Status.SetStatus(Generate) s.Status.SetStatus(Generate)
s.Status.indefiniteProgress() s.Status.indefiniteProgress()
qb := models.NewSceneQueryBuilder()
instance.Paths.Generated.EnsureTmpDir() instance.Paths.Generated.EnsureTmpDir()
go func() { go func() {
@@ -514,13 +562,18 @@ func (s *singleton) generateScreenshot(sceneId string, at *float64) {
return return
} }
scene, err := qb.Find(sceneIdInt) var scene *models.Scene
if err != nil || scene == nil { if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
logger.Errorf("failed to get scene for generate") var err error
scene, err = r.Scene().Find(sceneIdInt)
return err
}); err != nil || scene == nil {
logger.Errorf("failed to get scene for generate: %s", err.Error())
return return
} }
task := GenerateScreenshotTask{ task := GenerateScreenshotTask{
txnManager: s.TxnManager,
Scene: *scene, Scene: *scene,
ScreenshotAt: at, ScreenshotAt: at,
fileNamingAlgorithm: config.GetVideoFileNamingAlgorithm(), fileNamingAlgorithm: config.GetVideoFileNamingAlgorithm(),
@@ -551,29 +604,36 @@ func (s *singleton) AutoTag(performerIds []string, studioIds []string, tagIds []
studioCount := len(studioIds) studioCount := len(studioIds)
tagCount := len(tagIds) tagCount := len(tagIds)
performerQuery := models.NewPerformerQueryBuilder() if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
studioQuery := models.NewTagQueryBuilder() performerQuery := r.Performer()
tagQuery := models.NewTagQueryBuilder() studioQuery := r.Studio()
tagQuery := r.Tag()
const wildcard = "*" const wildcard = "*"
var err error var err error
if performerCount == 1 && performerIds[0] == wildcard { if performerCount == 1 && performerIds[0] == wildcard {
performerCount, err = performerQuery.Count() performerCount, err = performerQuery.Count()
if err != nil { if err != nil {
logger.Errorf("Error getting performer count: %s", err.Error()) return fmt.Errorf("Error getting performer count: %s", err.Error())
}
} }
} if studioCount == 1 && studioIds[0] == wildcard {
if studioCount == 1 && studioIds[0] == wildcard { studioCount, err = studioQuery.Count()
studioCount, err = studioQuery.Count() if err != nil {
if err != nil { return fmt.Errorf("Error getting studio count: %s", err.Error())
logger.Errorf("Error getting studio count: %s", err.Error()) }
} }
} if tagCount == 1 && tagIds[0] == wildcard {
if tagCount == 1 && tagIds[0] == wildcard { tagCount, err = tagQuery.Count()
tagCount, err = tagQuery.Count() if err != nil {
if err != nil { return fmt.Errorf("Error getting tag count: %s", err.Error())
logger.Errorf("Error getting tag count: %s", err.Error()) }
} }
return nil
}); err != nil {
logger.Error(err.Error())
return
} }
total := performerCount + studioCount + tagCount total := performerCount + studioCount + tagCount
@@ -586,36 +646,44 @@ func (s *singleton) AutoTag(performerIds []string, studioIds []string, tagIds []
} }
func (s *singleton) autoTagPerformers(performerIds []string) { func (s *singleton) autoTagPerformers(performerIds []string) {
performerQuery := models.NewPerformerQueryBuilder()
var wg sync.WaitGroup var wg sync.WaitGroup
for _, performerId := range performerIds { for _, performerId := range performerIds {
var performers []*models.Performer var performers []*models.Performer
if performerId == "*" {
var err error if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
performers, err = performerQuery.All() performerQuery := r.Performer()
if err != nil {
logger.Errorf("Error querying performers: %s", err.Error()) if performerId == "*" {
continue var err error
} performers, err = performerQuery.All()
} else { if err != nil {
performerIdInt, err := strconv.Atoi(performerId) return fmt.Errorf("Error querying performers: %s", err.Error())
if err != nil { }
logger.Errorf("Error parsing performer id %s: %s", performerId, err.Error()) } else {
continue performerIdInt, err := strconv.Atoi(performerId)
if err != nil {
return fmt.Errorf("Error parsing performer id %s: %s", performerId, err.Error())
}
performer, err := performerQuery.Find(performerIdInt)
if err != nil {
return fmt.Errorf("Error finding performer id %s: %s", performerId, err.Error())
}
performers = append(performers, performer)
} }
performer, err := performerQuery.Find(performerIdInt) return nil
if err != nil { }); err != nil {
logger.Errorf("Error finding performer id %s: %s", performerId, err.Error()) logger.Error(err.Error())
continue continue
}
performers = append(performers, performer)
} }
for _, performer := range performers { for _, performer := range performers {
wg.Add(1) wg.Add(1)
task := AutoTagPerformerTask{performer: performer} task := AutoTagPerformerTask{
txnManager: s.TxnManager,
performer: performer,
}
go task.Start(&wg) go task.Start(&wg)
wg.Wait() wg.Wait()
@@ -625,36 +693,43 @@ func (s *singleton) autoTagPerformers(performerIds []string) {
} }
func (s *singleton) autoTagStudios(studioIds []string) { func (s *singleton) autoTagStudios(studioIds []string) {
studioQuery := models.NewStudioQueryBuilder()
var wg sync.WaitGroup var wg sync.WaitGroup
for _, studioId := range studioIds { for _, studioId := range studioIds {
var studios []*models.Studio var studios []*models.Studio
if studioId == "*" {
var err error if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
studios, err = studioQuery.All() studioQuery := r.Studio()
if err != nil { if studioId == "*" {
logger.Errorf("Error querying studios: %s", err.Error()) var err error
continue studios, err = studioQuery.All()
} if err != nil {
} else { return fmt.Errorf("Error querying studios: %s", err.Error())
studioIdInt, err := strconv.Atoi(studioId) }
if err != nil { } else {
logger.Errorf("Error parsing studio id %s: %s", studioId, err.Error()) studioIdInt, err := strconv.Atoi(studioId)
continue if err != nil {
return fmt.Errorf("Error parsing studio id %s: %s", studioId, err.Error())
}
studio, err := studioQuery.Find(studioIdInt)
if err != nil {
return fmt.Errorf("Error finding studio id %s: %s", studioId, err.Error())
}
studios = append(studios, studio)
} }
studio, err := studioQuery.Find(studioIdInt, nil) return nil
if err != nil { }); err != nil {
logger.Errorf("Error finding studio id %s: %s", studioId, err.Error()) logger.Error(err.Error())
continue continue
}
studios = append(studios, studio)
} }
for _, studio := range studios { for _, studio := range studios {
wg.Add(1) wg.Add(1)
task := AutoTagStudioTask{studio: studio} task := AutoTagStudioTask{
studio: studio,
txnManager: s.TxnManager,
}
go task.Start(&wg) go task.Start(&wg)
wg.Wait() wg.Wait()
@@ -664,36 +739,42 @@ func (s *singleton) autoTagStudios(studioIds []string) {
} }
func (s *singleton) autoTagTags(tagIds []string) { func (s *singleton) autoTagTags(tagIds []string) {
tagQuery := models.NewTagQueryBuilder()
var wg sync.WaitGroup var wg sync.WaitGroup
for _, tagId := range tagIds { for _, tagId := range tagIds {
var tags []*models.Tag var tags []*models.Tag
if tagId == "*" { if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
var err error tagQuery := r.Tag()
tags, err = tagQuery.All() if tagId == "*" {
if err != nil { var err error
logger.Errorf("Error querying tags: %s", err.Error()) tags, err = tagQuery.All()
continue if err != nil {
} return fmt.Errorf("Error querying tags: %s", err.Error())
} else { }
tagIdInt, err := strconv.Atoi(tagId) } else {
if err != nil { tagIdInt, err := strconv.Atoi(tagId)
logger.Errorf("Error parsing tag id %s: %s", tagId, err.Error()) if err != nil {
continue return fmt.Errorf("Error parsing tag id %s: %s", tagId, err.Error())
}
tag, err := tagQuery.Find(tagIdInt)
if err != nil {
return fmt.Errorf("Error finding tag id %s: %s", tagId, err.Error())
}
tags = append(tags, tag)
} }
tag, err := tagQuery.Find(tagIdInt, nil) return nil
if err != nil { }); err != nil {
logger.Errorf("Error finding tag id %s: %s", tagId, err.Error()) logger.Error(err.Error())
continue continue
}
tags = append(tags, tag)
} }
for _, tag := range tags { for _, tag := range tags {
wg.Add(1) wg.Add(1)
task := AutoTagTagTask{tag: tag} task := AutoTagTagTask{
txnManager: s.TxnManager,
tag: tag,
}
go task.Start(&wg) go task.Start(&wg)
wg.Wait() wg.Wait()
@@ -709,28 +790,38 @@ func (s *singleton) Clean() {
s.Status.SetStatus(Clean) s.Status.SetStatus(Clean)
s.Status.indefiniteProgress() s.Status.indefiniteProgress()
qb := models.NewSceneQueryBuilder()
iqb := models.NewImageQueryBuilder()
gqb := models.NewGalleryQueryBuilder()
go func() { go func() {
defer s.returnToIdleState() defer s.returnToIdleState()
logger.Infof("Starting cleaning of tracked files") var scenes []*models.Scene
scenes, err := qb.All() var images []*models.Image
if err != nil { var galleries []*models.Gallery
logger.Errorf("failed to fetch list of scenes for cleaning")
return
}
images, err := iqb.All() if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
if err != nil { qb := r.Scene()
logger.Errorf("failed to fetch list of images for cleaning") iqb := r.Image()
return gqb := r.Gallery()
}
galleries, err := gqb.All() logger.Infof("Starting cleaning of tracked files")
if err != nil { var err error
logger.Errorf("failed to fetch list of galleries for cleaning") scenes, err = qb.All()
if err != nil {
return errors.New("failed to fetch list of scenes for cleaning")
}
images, err = iqb.All()
if err != nil {
return errors.New("failed to fetch list of images for cleaning")
}
galleries, err = gqb.All()
if err != nil {
return errors.New("failed to fetch list of galleries for cleaning")
}
return nil
}); err != nil {
logger.Error(err.Error())
return return
} }
@@ -757,7 +848,11 @@ func (s *singleton) Clean() {
wg.Add(1) wg.Add(1)
task := CleanTask{Scene: scene, fileNamingAlgorithm: fileNamingAlgo} task := CleanTask{
TxnManager: s.TxnManager,
Scene: scene,
fileNamingAlgorithm: fileNamingAlgo,
}
go task.Start(&wg) go task.Start(&wg)
wg.Wait() wg.Wait()
} }
@@ -776,7 +871,10 @@ func (s *singleton) Clean() {
wg.Add(1) wg.Add(1)
task := CleanTask{Image: img} task := CleanTask{
TxnManager: s.TxnManager,
Image: img,
}
go task.Start(&wg) go task.Start(&wg)
wg.Wait() wg.Wait()
} }
@@ -795,7 +893,10 @@ func (s *singleton) Clean() {
wg.Add(1) wg.Add(1)
task := CleanTask{Gallery: gallery} task := CleanTask{
TxnManager: s.TxnManager,
Gallery: gallery,
}
go task.Start(&wg) go task.Start(&wg)
wg.Wait() wg.Wait()
} }
@@ -811,17 +912,19 @@ func (s *singleton) MigrateHash() {
s.Status.SetStatus(Migrate) s.Status.SetStatus(Migrate)
s.Status.indefiniteProgress() s.Status.indefiniteProgress()
qb := models.NewSceneQueryBuilder()
go func() { go func() {
defer s.returnToIdleState() defer s.returnToIdleState()
fileNamingAlgo := config.GetVideoFileNamingAlgorithm() fileNamingAlgo := config.GetVideoFileNamingAlgorithm()
logger.Infof("Migrating generated files for %s naming hash", fileNamingAlgo.String()) logger.Infof("Migrating generated files for %s naming hash", fileNamingAlgo.String())
scenes, err := qb.All() var scenes []*models.Scene
if err != nil { if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
logger.Errorf("failed to fetch list of scenes for migration") var err error
scenes, err = r.Scene().All()
return err
}); err != nil {
logger.Errorf("failed to fetch list of scenes for migration: %s", err.Error())
return return
} }
@@ -926,6 +1029,7 @@ func (s *singleton) neededGenerate(scenes []*models.Scene, input models.Generate
if input.Markers { if input.Markers {
task := GenerateMarkersTask{ task := GenerateMarkersTask{
TxnManager: s.TxnManager,
Scene: scene, Scene: scene,
Overwrite: overwrite, Overwrite: overwrite,
fileNamingAlgorithm: fileNamingAlgo, fileNamingAlgorithm: fileNamingAlgo,

View File

@@ -2,5 +2,5 @@ package manager
// PostMigrate is executed after migrations have been executed. // PostMigrate is executed after migrations have been executed.
func (s *singleton) PostMigrate() { func (s *singleton) PostMigrate() {
setInitialMD5Config() setInitialMD5Config(s.TxnManager)
} }

View File

@@ -4,9 +4,6 @@ import (
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"strconv"
"github.com/jmoiron/sqlx"
"github.com/stashapp/stash/pkg/ffmpeg" "github.com/stashapp/stash/pkg/ffmpeg"
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
@@ -16,37 +13,54 @@ import (
) )
// DestroyScene deletes a scene and its associated relationships from the // DestroyScene deletes a scene and its associated relationships from the
// database. // database. Returns a function to perform any post-commit actions.
func DestroyScene(sceneID int, tx *sqlx.Tx) error { func DestroyScene(scene *models.Scene, repo models.Repository) (func(), error) {
qb := models.NewSceneQueryBuilder() qb := repo.Scene()
jqb := models.NewJoinsQueryBuilder() mqb := repo.SceneMarker()
gqb := repo.Gallery()
_, err := qb.Find(sceneID) if err := gqb.ClearGalleryId(scene.ID); err != nil {
return nil, err
}
markers, err := mqb.FindBySceneID(scene.ID)
if err != nil { if err != nil {
return err return nil, err
} }
if err := jqb.DestroyScenesTags(sceneID, tx); err != nil { var funcs []func()
return err for _, m := range markers {
f, err := DestroySceneMarker(scene, m, mqb)
if err != nil {
return nil, err
}
funcs = append(funcs, f)
} }
if err := jqb.DestroyPerformersScenes(sceneID, tx); err != nil { if err := qb.Destroy(scene.ID); err != nil {
return err return nil, err
} }
if err := jqb.DestroyScenesMarkers(sceneID, tx); err != nil { return func() {
return err for _, f := range funcs {
f()
}
}, nil
}
// DestroySceneMarker deletes the scene marker from the database and returns a
// function that removes the generated files, to be executed after the
// transaction is successfully committed.
func DestroySceneMarker(scene *models.Scene, sceneMarker *models.SceneMarker, qb models.SceneMarkerWriter) (func(), error) {
if err := qb.Destroy(sceneMarker.ID); err != nil {
return nil, err
} }
if err := jqb.DestroyScenesGalleries(sceneID, tx); err != nil { // delete the preview for the marker
return err return func() {
} seconds := int(sceneMarker.Seconds)
DeleteSceneMarkerFiles(scene, seconds, config.GetVideoFileNamingAlgorithm())
if err := qb.Destroy(strconv.Itoa(sceneID), tx); err != nil { }, nil
return err
}
return nil
} }
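The returned closure is the point of this refactor: deleting generated files is deferred until the transaction has committed, so a rolled-back destroy leaves the files in place. A minimal sketch of the expected caller, with the scene lookup elided:

    var postCommit func()
    if err := txnManager.WithTxn(ctx, func(r models.Repository) error {
        var err error
        postCommit, err = DestroyScene(scene, r)
        return err
    }); err != nil {
        return err
    }
    // the transaction committed; only now remove the generated files
    postCommit()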
// DeleteGeneratedSceneFiles deletes generated files for the provided scene. // DeleteGeneratedSceneFiles deletes generated files for the provided scene.

View File

@@ -4,18 +4,16 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/jmoiron/sqlx"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
) )
func ValidateModifyStudio(studio models.StudioPartial, tx *sqlx.Tx) error { func ValidateModifyStudio(studio models.StudioPartial, qb models.StudioReader) error {
if studio.ParentID == nil || !studio.ParentID.Valid { if studio.ParentID == nil || !studio.ParentID.Valid {
return nil return nil
} }
// ensure there is no cyclic dependency // ensure there is no cyclic dependency
thisID := studio.ID thisID := studio.ID
qb := models.NewStudioQueryBuilder()
currentParentID := *studio.ParentID currentParentID := *studio.ParentID
@@ -24,7 +22,7 @@ func ValidateModifyStudio(studio models.StudioPartial, tx *sqlx.Tx) error {
return errors.New("studio cannot be an ancestor of itself") return errors.New("studio cannot be an ancestor of itself")
} }
currentStudio, err := qb.Find(int(currentParentID.Int64), tx) currentStudio, err := qb.Find(int(currentParentID.Int64))
if err != nil { if err != nil {
return fmt.Errorf("error finding parent studio: %s", err.Error()) return fmt.Errorf("error finding parent studio: %s", err.Error())
} }
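A hedged sketch of a caller: the studio reader comes from the repository bound to the active transaction, so the cycle check and the subsequent update observe the same state. The update itself is elided because its exact signature is not shown here.

    if err := txnManager.WithTxn(ctx, func(r models.Repository) error {
        if err := ValidateModifyStudio(partial, r.Studio()); err != nil {
            return err
        }
        // ...perform the studio update through the same repository...
        return nil
    }); err != nil {
        return err
    }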

View File

@@ -3,17 +3,13 @@ package manager
import ( import (
"fmt" "fmt"
"github.com/jmoiron/sqlx"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
) )
func EnsureTagNameUnique(tag models.Tag, tx *sqlx.Tx) error { func EnsureTagNameUnique(tag models.Tag, qb models.TagReader) error {
qb := models.NewTagQueryBuilder()
// ensure name is unique // ensure name is unique
sameNameTag, err := qb.FindByName(tag.Name, tx, true) sameNameTag, err := qb.FindByName(tag.Name, true)
if err != nil { if err != nil {
_ = tx.Rollback()
return err return err
} }
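Tag creation follows the same shape: the uniqueness check and the insert share one transaction, so the manual tx.Rollback the old code performed is now the transaction manager's job. A minimal sketch with an illustrative tag value:

    newTag := models.Tag{Name: "example"}
    if err := txnManager.WithTxn(ctx, func(r models.Repository) error {
        if err := EnsureTagNameUnique(newTag, r.Tag()); err != nil {
            return err
        }
        _, err := r.Tag().Create(newTag)
        return err
    }); err != nil {
        return err
    }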

View File

@@ -3,16 +3,18 @@ package manager
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"strings" "strings"
"sync" "sync"
"github.com/stashapp/stash/pkg/database"
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/scene"
) )
type AutoTagPerformerTask struct { type AutoTagPerformerTask struct {
performer *models.Performer performer *models.Performer
txnManager models.TransactionManager
} }
func (t *AutoTagPerformerTask) Start(wg *sync.WaitGroup) { func (t *AutoTagPerformerTask) Start(wg *sync.WaitGroup) {
@@ -32,44 +34,38 @@ func getQueryRegex(name string) string {
} }
func (t *AutoTagPerformerTask) autoTagPerformer() { func (t *AutoTagPerformerTask) autoTagPerformer() {
qb := models.NewSceneQueryBuilder()
jqb := models.NewJoinsQueryBuilder()
regex := getQueryRegex(t.performer.Name.String) regex := getQueryRegex(t.performer.Name.String)
const ignoreOrganized = true if err := t.txnManager.WithTxn(context.TODO(), func(r models.Repository) error {
scenes, err := qb.QueryAllByPathRegex(regex, ignoreOrganized) qb := r.Scene()
const ignoreOrganized = true
if err != nil { scenes, err := qb.QueryAllByPathRegex(regex, ignoreOrganized)
logger.Infof("Error querying scenes with regex '%s': %s", regex, err.Error())
return
}
ctx := context.TODO()
tx := database.DB.MustBeginTx(ctx, nil)
for _, scene := range scenes {
added, err := jqb.AddPerformerScene(scene.ID, t.performer.ID, tx)
if err != nil { if err != nil {
logger.Infof("Error adding performer '%s' to scene '%s': %s", t.performer.Name.String, scene.GetTitle(), err.Error()) return fmt.Errorf("Error querying scenes with regex '%s': %s", regex, err.Error())
tx.Rollback()
return
} }
if added { for _, s := range scenes {
logger.Infof("Added performer '%s' to scene '%s'", t.performer.Name.String, scene.GetTitle()) added, err := scene.AddPerformer(qb, s.ID, t.performer.ID)
}
}
if err := tx.Commit(); err != nil { if err != nil {
logger.Infof("Error adding performer to scene: %s", err.Error()) return fmt.Errorf("Error adding performer '%s' to scene '%s': %s", t.performer.Name.String, s.GetTitle(), err.Error())
return }
if added {
logger.Infof("Added performer '%s' to scene '%s'", t.performer.Name.String, s.GetTitle())
}
}
return nil
}); err != nil {
logger.Error(err.Error())
} }
} }
type AutoTagStudioTask struct { type AutoTagStudioTask struct {
studio *models.Studio studio *models.Studio
txnManager models.TransactionManager
} }
func (t *AutoTagStudioTask) Start(wg *sync.WaitGroup) { func (t *AutoTagStudioTask) Start(wg *sync.WaitGroup) {
@@ -79,54 +75,47 @@ func (t *AutoTagStudioTask) Start(wg *sync.WaitGroup) {
} }
func (t *AutoTagStudioTask) autoTagStudio() { func (t *AutoTagStudioTask) autoTagStudio() {
qb := models.NewSceneQueryBuilder()
regex := getQueryRegex(t.studio.Name.String) regex := getQueryRegex(t.studio.Name.String)
const ignoreOrganized = true if err := t.txnManager.WithTxn(context.TODO(), func(r models.Repository) error {
scenes, err := qb.QueryAllByPathRegex(regex, ignoreOrganized) qb := r.Scene()
const ignoreOrganized = true
if err != nil { scenes, err := qb.QueryAllByPathRegex(regex, ignoreOrganized)
logger.Infof("Error querying scenes with regex '%s': %s", regex, err.Error())
return
}
ctx := context.TODO()
tx := database.DB.MustBeginTx(ctx, nil)
for _, scene := range scenes {
// #306 - don't overwrite studio if already present
if scene.StudioID.Valid {
// don't modify
continue
}
logger.Infof("Adding studio '%s' to scene '%s'", t.studio.Name.String, scene.GetTitle())
// set the studio id
studioID := sql.NullInt64{Int64: int64(t.studio.ID), Valid: true}
scenePartial := models.ScenePartial{
ID: scene.ID,
StudioID: &studioID,
}
_, err := qb.Update(scenePartial, tx)
if err != nil { if err != nil {
logger.Infof("Error adding studio to scene: %s", err.Error()) return fmt.Errorf("Error querying scenes with regex '%s': %s", regex, err.Error())
tx.Rollback()
return
} }
}
if err := tx.Commit(); err != nil { for _, scene := range scenes {
logger.Infof("Error adding studio to scene: %s", err.Error()) // #306 - don't overwrite studio if already present
return if scene.StudioID.Valid {
// don't modify
continue
}
logger.Infof("Adding studio '%s' to scene '%s'", t.studio.Name.String, scene.GetTitle())
// set the studio id
studioID := sql.NullInt64{Int64: int64(t.studio.ID), Valid: true}
scenePartial := models.ScenePartial{
ID: scene.ID,
StudioID: &studioID,
}
if _, err := qb.Update(scenePartial); err != nil {
return fmt.Errorf("Error adding studio to scene: %s", err.Error())
}
}
return nil
}); err != nil {
logger.Error(err.Error())
} }
} }
type AutoTagTagTask struct { type AutoTagTagTask struct {
tag *models.Tag tag *models.Tag
txnManager models.TransactionManager
} }
func (t *AutoTagTagTask) Start(wg *sync.WaitGroup) { func (t *AutoTagTagTask) Start(wg *sync.WaitGroup) {
@@ -136,38 +125,31 @@ func (t *AutoTagTagTask) Start(wg *sync.WaitGroup) {
} }
func (t *AutoTagTagTask) autoTagTag() { func (t *AutoTagTagTask) autoTagTag() {
qb := models.NewSceneQueryBuilder()
jqb := models.NewJoinsQueryBuilder()
regex := getQueryRegex(t.tag.Name) regex := getQueryRegex(t.tag.Name)
const ignoreOrganized = true if err := t.txnManager.WithTxn(context.TODO(), func(r models.Repository) error {
scenes, err := qb.QueryAllByPathRegex(regex, ignoreOrganized) qb := r.Scene()
const ignoreOrganized = true
if err != nil { scenes, err := qb.QueryAllByPathRegex(regex, ignoreOrganized)
logger.Infof("Error querying scenes with regex '%s': %s", regex, err.Error())
return
}
ctx := context.TODO()
tx := database.DB.MustBeginTx(ctx, nil)
for _, scene := range scenes {
added, err := jqb.AddSceneTag(scene.ID, t.tag.ID, tx)
if err != nil { if err != nil {
logger.Infof("Error adding tag '%s' to scene '%s': %s", t.tag.Name, scene.GetTitle(), err.Error()) return fmt.Errorf("Error querying scenes with regex '%s': %s", regex, err.Error())
tx.Rollback()
return
} }
if added { for _, s := range scenes {
logger.Infof("Added tag '%s' to scene '%s'", t.tag.Name, scene.GetTitle()) added, err := scene.AddTag(qb, s.ID, t.tag.ID)
}
}
if err := tx.Commit(); err != nil { if err != nil {
logger.Infof("Error adding tag to scene: %s", err.Error()) return fmt.Errorf("Error adding tag '%s' to scene '%s': %s", t.tag.Name, s.GetTitle(), err.Error())
return }
if added {
logger.Infof("Added tag '%s' to scene '%s'", t.tag.Name, s.GetTitle())
}
}
return nil
}); err != nil {
logger.Error(err.Error())
} }
} }
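Each auto-tag task now carries its own transaction manager, so driving one from inside package manager mirrors the pattern used by AutoTag above and by the tests below; a minimal sketch:

    task := AutoTagPerformerTask{
        txnManager: s.TxnManager, // unexported fields: constructable only within package manager
        performer:  performer,
    }
    var wg sync.WaitGroup
    wg.Add(1)
    go task.Start(&wg)
    wg.Wait()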

View File

@@ -14,11 +14,11 @@ import (
"github.com/stashapp/stash/pkg/database" "github.com/stashapp/stash/pkg/database"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/sqlite"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
_ "github.com/golang-migrate/migrate/v4/database/sqlite3" _ "github.com/golang-migrate/migrate/v4/database/sqlite3"
_ "github.com/golang-migrate/migrate/v4/source/file" _ "github.com/golang-migrate/migrate/v4/source/file"
"github.com/jmoiron/sqlx"
) )
const testName = "Foo's Bar" const testName = "Foo's Bar"
@@ -106,17 +106,15 @@ func TestMain(m *testing.M) {
os.Exit(ret) os.Exit(ret)
} }
func createPerformer(tx *sqlx.Tx) error { func createPerformer(pqb models.PerformerWriter) error {
// create the performer // create the performer
pqb := models.NewPerformerQueryBuilder()
performer := models.Performer{ performer := models.Performer{
Checksum: testName, Checksum: testName,
Name: sql.NullString{Valid: true, String: testName}, Name: sql.NullString{Valid: true, String: testName},
Favorite: sql.NullBool{Valid: true, Bool: false}, Favorite: sql.NullBool{Valid: true, Bool: false},
} }
_, err := pqb.Create(performer, tx) _, err := pqb.Create(performer)
if err != nil { if err != nil {
return err return err
} }
@@ -124,27 +122,23 @@ func createPerformer(tx *sqlx.Tx) error {
return nil return nil
} }
func createStudio(tx *sqlx.Tx, name string) (*models.Studio, error) { func createStudio(qb models.StudioWriter, name string) (*models.Studio, error) {
// create the studio // create the studio
qb := models.NewStudioQueryBuilder()
studio := models.Studio{ studio := models.Studio{
Checksum: name, Checksum: name,
Name: sql.NullString{Valid: true, String: testName}, Name: sql.NullString{Valid: true, String: testName},
} }
return qb.Create(studio, tx) return qb.Create(studio)
} }
func createTag(tx *sqlx.Tx) error { func createTag(qb models.TagWriter) error {
// create the studio // create the studio
qb := models.NewTagQueryBuilder()
tag := models.Tag{ tag := models.Tag{
Name: testName, Name: testName,
} }
_, err := qb.Create(tag, tx) _, err := qb.Create(tag)
if err != nil { if err != nil {
return err return err
} }
@@ -152,9 +146,7 @@ func createTag(tx *sqlx.Tx) error {
return nil return nil
} }
func createScenes(tx *sqlx.Tx) error { func createScenes(sqb models.SceneReaderWriter) error {
sqb := models.NewSceneQueryBuilder()
// create the scenes // create the scenes
var scenePatterns []string var scenePatterns []string
var falseScenePatterns []string var falseScenePatterns []string
@@ -175,13 +167,13 @@ func createScenes(tx *sqlx.Tx) error {
} }
for _, fn := range scenePatterns { for _, fn := range scenePatterns {
err := createScene(sqb, tx, makeScene(fn, true)) err := createScene(sqb, makeScene(fn, true))
if err != nil { if err != nil {
return err return err
} }
} }
for _, fn := range falseScenePatterns { for _, fn := range falseScenePatterns {
err := createScene(sqb, tx, makeScene(fn, false)) err := createScene(sqb, makeScene(fn, false))
if err != nil { if err != nil {
return err return err
} }
@@ -191,7 +183,7 @@ func createScenes(tx *sqlx.Tx) error {
for _, fn := range scenePatterns { for _, fn := range scenePatterns {
s := makeScene("organized"+fn, false) s := makeScene("organized"+fn, false)
s.Organized = true s.Organized = true
err := createScene(sqb, tx, s) err := createScene(sqb, s)
if err != nil { if err != nil {
return err return err
} }
@@ -200,7 +192,7 @@ func createScenes(tx *sqlx.Tx) error {
// create scene with existing studio id // create scene with existing studio id
studioScene := makeScene(existingStudioSceneName, true) studioScene := makeScene(existingStudioSceneName, true)
studioScene.StudioID = sql.NullInt64{Valid: true, Int64: int64(existingStudioID)} studioScene.StudioID = sql.NullInt64{Valid: true, Int64: int64(existingStudioID)}
err := createScene(sqb, tx, studioScene) err := createScene(sqb, studioScene)
if err != nil { if err != nil {
return err return err
} }
@@ -222,8 +214,8 @@ func makeScene(name string, expectedResult bool) *models.Scene {
return scene return scene
} }
func createScene(sqb models.SceneQueryBuilder, tx *sqlx.Tx, scene *models.Scene) error { func createScene(sqb models.SceneWriter, scene *models.Scene) error {
_, err := sqb.Create(*scene, tx) _, err := sqb.Create(*scene)
if err != nil { if err != nil {
return fmt.Errorf("Failed to create scene with name '%s': %s", scene.Path, err.Error()) return fmt.Errorf("Failed to create scene with name '%s': %s", scene.Path, err.Error())
@@ -232,39 +224,43 @@ func createScene(sqb models.SceneQueryBuilder, tx *sqlx.Tx, scene *models.Scene)
return nil return nil
} }
func withTxn(f func(r models.Repository) error) error {
t := sqlite.NewTransactionManager()
return t.WithTxn(context.TODO(), f)
}
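The withTxn helper above is the shape every converted call site follows: repository access happens inside a closure handed to a transaction manager, which commits or rolls back around it. A minimal sketch of the same pattern with stand-in Repository, PerformerStore and TxnManager interfaces (the project's real types are richer than these):

package example

import "context"

// Stand-in interfaces; the real project defines much richer Repository
// and TransactionManager types than these.
type Repository interface {
	Performer() PerformerStore
}

type PerformerStore interface {
	All() ([]string, error)
}

type TxnManager interface {
	WithTxn(ctx context.Context, fn func(r Repository) error) error
}

// withTxn mirrors the test helper: it hides the manager so callers only
// deal with the repository callback.
func withTxn(m TxnManager, fn func(r Repository) error) error {
	return m.WithTxn(context.TODO(), fn)
}

// readAllPerformers shows the typical call site: capture results from
// inside the closure and surface any error to the caller.
func readAllPerformers(m TxnManager) ([]string, error) {
	var names []string
	err := withTxn(m, func(r Repository) error {
		var err error
		names, err = r.Performer().All()
		return err
	})
	return names, err
}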
func populateDB() error { func populateDB() error {
ctx := context.TODO() if err := withTxn(func(r models.Repository) error {
tx := database.DB.MustBeginTx(ctx, nil) err := createPerformer(r.Performer())
if err != nil {
return err
}
err := createPerformer(tx) _, err = createStudio(r.Studio(), testName)
if err != nil { if err != nil {
return err return err
} }
_, err = createStudio(tx, testName) // create existing studio
if err != nil { existingStudio, err := createStudio(r.Studio(), existingStudioName)
return err if err != nil {
} return err
}
// create existing studio existingStudioID = existingStudio.ID
existingStudio, err := createStudio(tx, existingStudioName)
if err != nil {
return err
}
existingStudioID = existingStudio.ID err = createTag(r.Tag())
if err != nil {
return err
}
err = createTag(tx) err = createScenes(r.Scene())
if err != nil { if err != nil {
return err return err
} }
err = createScenes(tx) return nil
if err != nil { }); err != nil {
return err
}
if err := tx.Commit(); err != nil {
return err return err
} }
@@ -272,16 +268,19 @@ func populateDB() error {
} }
func TestParsePerformers(t *testing.T) { func TestParsePerformers(t *testing.T) {
pqb := models.NewPerformerQueryBuilder() var performers []*models.Performer
performers, err := pqb.All() if err := withTxn(func(r models.Repository) error {
var err error
if err != nil { performers, err = r.Performer().All()
return err
}); err != nil {
t.Errorf("Error getting performer: %s", err) t.Errorf("Error getting performer: %s", err)
return return
} }
task := AutoTagPerformerTask{ task := AutoTagPerformerTask{
performer: performers[0], performer: performers[0],
txnManager: sqlite.NewTransactionManager(),
} }
var wg sync.WaitGroup var wg sync.WaitGroup
@@ -289,38 +288,47 @@ func TestParsePerformers(t *testing.T) {
task.Start(&wg) task.Start(&wg)
// verify that scenes were tagged correctly // verify that scenes were tagged correctly
sqb := models.NewSceneQueryBuilder() withTxn(func(r models.Repository) error {
pqb := r.Performer()
scenes, err := sqb.All()
for _, scene := range scenes {
performers, err := pqb.FindBySceneID(scene.ID, nil)
scenes, err := r.Scene().All()
if err != nil { if err != nil {
t.Errorf("Error getting scene performers: %s", err.Error()) t.Error(err.Error())
return
} }
// title is only set on scenes where we expect performer to be set for _, scene := range scenes {
if scene.Title.String == scene.Path && len(performers) == 0 { performers, err := pqb.FindBySceneID(scene.ID)
t.Errorf("Did not set performer '%s' for path '%s'", testName, scene.Path)
} else if scene.Title.String != scene.Path && len(performers) > 0 { if err != nil {
t.Errorf("Incorrectly set performer '%s' for path '%s'", testName, scene.Path) t.Errorf("Error getting scene performers: %s", err.Error())
}
// title is only set on scenes where we expect performer to be set
if scene.Title.String == scene.Path && len(performers) == 0 {
t.Errorf("Did not set performer '%s' for path '%s'", testName, scene.Path)
} else if scene.Title.String != scene.Path && len(performers) > 0 {
t.Errorf("Incorrectly set performer '%s' for path '%s'", testName, scene.Path)
}
} }
}
return nil
})
} }
func TestParseStudios(t *testing.T) { func TestParseStudios(t *testing.T) {
studioQuery := models.NewStudioQueryBuilder() var studios []*models.Studio
studios, err := studioQuery.All() if err := withTxn(func(r models.Repository) error {
var err error
if err != nil { studios, err = r.Studio().All()
return err
}); err != nil {
t.Errorf("Error getting studio: %s", err) t.Errorf("Error getting studio: %s", err)
return return
} }
task := AutoTagStudioTask{ task := AutoTagStudioTask{
studio: studios[0], studio: studios[0],
txnManager: sqlite.NewTransactionManager(),
} }
var wg sync.WaitGroup var wg sync.WaitGroup
@@ -328,38 +336,46 @@ func TestParseStudios(t *testing.T) {
task.Start(&wg) task.Start(&wg)
// verify that scenes were tagged correctly // verify that scenes were tagged correctly
sqb := models.NewSceneQueryBuilder() withTxn(func(r models.Repository) error {
scenes, err := r.Scene().All()
if err != nil {
t.Error(err.Error())
}
scenes, err := sqb.All() for _, scene := range scenes {
// check for existing studio id scene first
for _, scene := range scenes { if scene.Path == existingStudioSceneName {
// check for existing studio id scene first if scene.StudioID.Int64 != int64(existingStudioID) {
if scene.Path == existingStudioSceneName { t.Error("Incorrectly overwrote studio ID for scene with existing studio ID")
if scene.StudioID.Int64 != int64(existingStudioID) { }
t.Error("Incorrectly overwrote studio ID for scene with existing studio ID") } else {
} // title is only set on scenes where we expect studio to be set
} else { if scene.Title.String == scene.Path && scene.StudioID.Int64 != int64(studios[0].ID) {
// title is only set on scenes where we expect studio to be set t.Errorf("Did not set studio '%s' for path '%s'", testName, scene.Path)
if scene.Title.String == scene.Path && scene.StudioID.Int64 != int64(studios[0].ID) { } else if scene.Title.String != scene.Path && scene.StudioID.Int64 == int64(studios[0].ID) {
t.Errorf("Did not set studio '%s' for path '%s'", testName, scene.Path) t.Errorf("Incorrectly set studio '%s' for path '%s'", testName, scene.Path)
} else if scene.Title.String != scene.Path && scene.StudioID.Int64 == int64(studios[0].ID) { }
t.Errorf("Incorrectly set studio '%s' for path '%s'", testName, scene.Path)
} }
} }
}
return nil
})
} }
func TestParseTags(t *testing.T) { func TestParseTags(t *testing.T) {
tagQuery := models.NewTagQueryBuilder() var tags []*models.Tag
tags, err := tagQuery.All() if err := withTxn(func(r models.Repository) error {
var err error
if err != nil { tags, err = r.Tag().All()
return err
}); err != nil {
t.Errorf("Error getting performer: %s", err) t.Errorf("Error getting performer: %s", err)
return return
} }
task := AutoTagTagTask{ task := AutoTagTagTask{
tag: tags[0], tag: tags[0],
txnManager: sqlite.NewTransactionManager(),
} }
var wg sync.WaitGroup var wg sync.WaitGroup
@@ -367,23 +383,29 @@ func TestParseTags(t *testing.T) {
task.Start(&wg) task.Start(&wg)
// verify that scenes were tagged correctly // verify that scenes were tagged correctly
sqb := models.NewSceneQueryBuilder() withTxn(func(r models.Repository) error {
scenes, err := r.Scene().All()
scenes, err := sqb.All()
for _, scene := range scenes {
tags, err := tagQuery.FindBySceneID(scene.ID, nil)
if err != nil { if err != nil {
t.Errorf("Error getting scene tags: %s", err.Error()) t.Error(err.Error())
return
} }
// title is only set on scenes where we expect performer to be set tqb := r.Tag()
if scene.Title.String == scene.Path && len(tags) == 0 {
t.Errorf("Did not set tag '%s' for path '%s'", testName, scene.Path) for _, scene := range scenes {
} else if scene.Title.String != scene.Path && len(tags) > 0 { tags, err := tqb.FindBySceneID(scene.ID)
t.Errorf("Incorrectly set tag '%s' for path '%s'", testName, scene.Path)
if err != nil {
t.Errorf("Error getting scene tags: %s", err.Error())
}
// title is only set on scenes where we expect performer to be set
if scene.Title.String == scene.Path && len(tags) == 0 {
t.Errorf("Did not set tag '%s' for path '%s'", testName, scene.Path)
} else if scene.Title.String != scene.Path && len(tags) > 0 {
t.Errorf("Incorrectly set tag '%s' for path '%s'", testName, scene.Path)
}
} }
}
return nil
})
} }


@@ -7,7 +7,6 @@ import (
"strings" "strings"
"sync" "sync"
"github.com/stashapp/stash/pkg/database"
"github.com/stashapp/stash/pkg/image" "github.com/stashapp/stash/pkg/image"
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/manager/config" "github.com/stashapp/stash/pkg/manager/config"
@@ -15,6 +14,7 @@ import (
) )
type CleanTask struct { type CleanTask struct {
TxnManager models.TransactionManager
Scene *models.Scene Scene *models.Scene
Gallery *models.Gallery Gallery *models.Gallery
Image *models.Image Image *models.Image
@@ -133,60 +133,45 @@ func (t *CleanTask) shouldCleanImage(s *models.Image) bool {
} }
func (t *CleanTask) deleteScene(sceneID int) { func (t *CleanTask) deleteScene(sceneID int) {
ctx := context.TODO() var postCommitFunc func()
qb := models.NewSceneQueryBuilder() var scene *models.Scene
tx := database.DB.MustBeginTx(ctx, nil) if err := t.TxnManager.WithTxn(context.TODO(), func(repo models.Repository) error {
qb := repo.Scene()
scene, err := qb.Find(sceneID) var err error
err = DestroyScene(sceneID, tx) scene, err = qb.Find(sceneID)
if err != nil {
if err != nil { return err
logger.Errorf("Error deleting scene from database: %s", err.Error()) }
tx.Rollback() postCommitFunc, err = DestroyScene(scene, repo)
return return err
} }); err != nil {
if err := tx.Commit(); err != nil {
logger.Errorf("Error deleting scene from database: %s", err.Error()) logger.Errorf("Error deleting scene from database: %s", err.Error())
return return
} }
postCommitFunc()
DeleteGeneratedSceneFiles(scene, t.fileNamingAlgorithm) DeleteGeneratedSceneFiles(scene, t.fileNamingAlgorithm)
} }
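deleteScene now splits database work from filesystem work: DestroyScene returns a post-commit function, and file removal only happens once the transaction has committed. A minimal sketch of that shape, using invented destroyRecord and withTxn helpers rather than the project's real functions:

package example

import (
	"context"
	"fmt"
)

// destroyRecord stands in for a DB delete that also prepares follow-up
// work (e.g. removing generated files) to run only after commit.
func destroyRecord(id int) (postCommit func(), err error) {
	// ... delete rows inside the caller's transaction here ...
	return func() {
		fmt.Printf("removing files for %d after commit\n", id)
	}, nil
}

// deleteWithPostCommit runs the destroy inside a transaction and only
// touches the filesystem if the transaction succeeded.
func deleteWithPostCommit(ctx context.Context, withTxn func(context.Context, func() error) error, id int) error {
	var postCommit func()
	if err := withTxn(ctx, func() error {
		var err error
		postCommit, err = destroyRecord(id)
		return err
	}); err != nil {
		return err // transaction rolled back; no files touched
	}
	postCommit() // runs only after a successful commit
	return nil
}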
func (t *CleanTask) deleteGallery(galleryID int) { func (t *CleanTask) deleteGallery(galleryID int) {
ctx := context.TODO() if err := t.TxnManager.WithTxn(context.TODO(), func(repo models.Repository) error {
qb := models.NewGalleryQueryBuilder() qb := repo.Gallery()
tx := database.DB.MustBeginTx(ctx, nil) return qb.Destroy(galleryID)
}); err != nil {
err := qb.Destroy(galleryID, tx)
if err != nil {
logger.Errorf("Error deleting gallery from database: %s", err.Error())
tx.Rollback()
return
}
if err := tx.Commit(); err != nil {
logger.Errorf("Error deleting gallery from database: %s", err.Error()) logger.Errorf("Error deleting gallery from database: %s", err.Error())
return return
} }
} }
func (t *CleanTask) deleteImage(imageID int) { func (t *CleanTask) deleteImage(imageID int) {
ctx := context.TODO()
qb := models.NewImageQueryBuilder()
tx := database.DB.MustBeginTx(ctx, nil)
err := qb.Destroy(imageID, tx) if err := t.TxnManager.WithTxn(context.TODO(), func(repo models.Repository) error {
qb := repo.Image()
if err != nil { return qb.Destroy(imageID)
logger.Errorf("Error deleting image from database: %s", err.Error()) }); err != nil {
tx.Rollback()
return
}
if err := tx.Commit(); err != nil {
logger.Errorf("Error deleting image from database: %s", err.Error()) logger.Errorf("Error deleting image from database: %s", err.Error())
return return
} }


@@ -2,6 +2,7 @@ package manager
import ( import (
"archive/zip" "archive/zip"
"context"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@@ -27,7 +28,8 @@ import (
) )
type ExportTask struct { type ExportTask struct {
full bool txnManager models.TransactionManager
full bool
baseDir string baseDir string
json jsonUtils json jsonUtils
@@ -58,8 +60,10 @@ func newExportSpec(input *models.ExportObjectTypeInput) *exportSpec {
return &exportSpec{} return &exportSpec{}
} }
ids, _ := utils.StringSliceToIntSlice(input.Ids)
ret := &exportSpec{ ret := &exportSpec{
IDs: utils.StringSliceToIntSlice(input.Ids), IDs: ids,
} }
if input.All != nil { if input.All != nil {
@@ -76,6 +80,7 @@ func CreateExportTask(a models.HashAlgorithm, input models.ExportObjectsInput) *
} }
return &ExportTask{ return &ExportTask{
txnManager: GetInstance().TxnManager,
fileNamingAlgorithm: a, fileNamingAlgorithm: a,
scenes: newExportSpec(input.Scenes), scenes: newExportSpec(input.Scenes),
images: newExportSpec(input.Images), images: newExportSpec(input.Images),
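newExportSpec now receives both the converted IDs and an error from utils.StringSliceToIntSlice, discarding the error at this call site. A rough sketch of what such a helper looks like, assuming it fails on the first non-numeric value (not the project's actual implementation):

package example

import "strconv"

// stringSliceToIntSlice converts every element, returning an error for
// the first value that is not a valid integer.
func stringSliceToIntSlice(ss []string) ([]int, error) {
	ret := make([]int, 0, len(ss))
	for _, s := range ss {
		i, err := strconv.Atoi(s)
		if err != nil {
			return nil, err
		}
		ret = append(ret, i)
	}
	return ret, nil
}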
@@ -125,34 +130,40 @@ func (t *ExportTask) Start(wg *sync.WaitGroup) {
paths.EnsureJSONDirs(t.baseDir) paths.EnsureJSONDirs(t.baseDir)
// include movie scenes and gallery images t.txnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
if !t.full { // include movie scenes and gallery images
// only include movie scenes if includeDependencies is also set if !t.full {
if !t.scenes.all && t.includeDependencies { // only include movie scenes if includeDependencies is also set
t.populateMovieScenes() if !t.scenes.all && t.includeDependencies {
t.populateMovieScenes(r)
}
// always export gallery images
if !t.images.all {
t.populateGalleryImages(r)
}
} }
// always export gallery images t.ExportScenes(workerCount, r)
if !t.images.all { t.ExportImages(workerCount, r)
t.populateGalleryImages() t.ExportGalleries(workerCount, r)
} t.ExportMovies(workerCount, r)
} t.ExportPerformers(workerCount, r)
t.ExportStudios(workerCount, r)
t.ExportTags(workerCount, r)
t.ExportScenes(workerCount) if t.full {
t.ExportImages(workerCount) t.ExportScrapedItems(r)
t.ExportGalleries(workerCount) }
t.ExportMovies(workerCount)
t.ExportPerformers(workerCount) return nil
t.ExportStudios(workerCount) })
t.ExportTags(workerCount)
if err := t.json.saveMappings(t.Mappings); err != nil { if err := t.json.saveMappings(t.Mappings); err != nil {
logger.Errorf("[mappings] failed to save json: %s", err.Error()) logger.Errorf("[mappings] failed to save json: %s", err.Error())
} }
if t.full { if !t.full {
t.ExportScrapedItems()
} else {
err := t.generateDownload() err := t.generateDownload()
if err != nil { if err != nil {
logger.Errorf("error generating download link: %s", err.Error()) logger.Errorf("error generating download link: %s", err.Error())
@@ -242,9 +253,9 @@ func (t *ExportTask) zipFile(fn, outDir string, z *zip.Writer) error {
return nil return nil
} }
func (t *ExportTask) populateMovieScenes() { func (t *ExportTask) populateMovieScenes(repo models.ReaderRepository) {
reader := models.NewMovieReaderWriter(nil) reader := repo.Movie()
sceneReader := models.NewSceneReaderWriter(nil) sceneReader := repo.Scene()
var movies []*models.Movie var movies []*models.Movie
var err error var err error
@@ -272,9 +283,9 @@ func (t *ExportTask) populateMovieScenes() {
} }
} }
func (t *ExportTask) populateGalleryImages() { func (t *ExportTask) populateGalleryImages(repo models.ReaderRepository) {
reader := models.NewGalleryReaderWriter(nil) reader := repo.Gallery()
imageReader := models.NewImageReaderWriter(nil) imageReader := repo.Image()
var galleries []*models.Gallery var galleries []*models.Gallery
var err error var err error
@@ -302,10 +313,10 @@ func (t *ExportTask) populateGalleryImages() {
} }
} }
func (t *ExportTask) ExportScenes(workers int) { func (t *ExportTask) ExportScenes(workers int, repo models.ReaderRepository) {
var scenesWg sync.WaitGroup var scenesWg sync.WaitGroup
sceneReader := models.NewSceneReaderWriter(nil) sceneReader := repo.Scene()
var scenes []*models.Scene var scenes []*models.Scene
var err error var err error
@@ -327,7 +338,7 @@ func (t *ExportTask) ExportScenes(workers int) {
for w := 0; w < workers; w++ { // create export Scene workers for w := 0; w < workers; w++ { // create export Scene workers
scenesWg.Add(1) scenesWg.Add(1)
go exportScene(&scenesWg, jobCh, t) go exportScene(&scenesWg, jobCh, repo, t)
} }
for i, scene := range scenes { for i, scene := range scenes {
@@ -346,16 +357,15 @@ func (t *ExportTask) ExportScenes(workers int) {
logger.Infof("[scenes] export complete in %s. %d workers used.", time.Since(startTime), workers) logger.Infof("[scenes] export complete in %s. %d workers used.", time.Since(startTime), workers)
} }
func exportScene(wg *sync.WaitGroup, jobChan <-chan *models.Scene, t *ExportTask) { func exportScene(wg *sync.WaitGroup, jobChan <-chan *models.Scene, repo models.ReaderRepository, t *ExportTask) {
defer wg.Done() defer wg.Done()
sceneReader := models.NewSceneReaderWriter(nil) sceneReader := repo.Scene()
studioReader := models.NewStudioReaderWriter(nil) studioReader := repo.Studio()
movieReader := models.NewMovieReaderWriter(nil) movieReader := repo.Movie()
galleryReader := models.NewGalleryReaderWriter(nil) galleryReader := repo.Gallery()
performerReader := models.NewPerformerReaderWriter(nil) performerReader := repo.Performer()
tagReader := models.NewTagReaderWriter(nil) tagReader := repo.Tag()
sceneMarkerReader := models.NewSceneMarkerReaderWriter(nil) sceneMarkerReader := repo.SceneMarker()
joinReader := models.NewJoinReaderWriter(nil)
for s := range jobChan { for s := range jobChan {
sceneHash := s.GetHash(t.fileNamingAlgorithm) sceneHash := s.GetHash(t.fileNamingAlgorithm)
@@ -402,7 +412,7 @@ func exportScene(wg *sync.WaitGroup, jobChan <-chan *models.Scene, t *ExportTask
continue continue
} }
newSceneJSON.Movies, err = scene.GetSceneMoviesJSON(movieReader, joinReader, s) newSceneJSON.Movies, err = scene.GetSceneMoviesJSON(movieReader, sceneReader, s)
if err != nil { if err != nil {
logger.Errorf("[scenes] <%s> error getting scene movies JSON: %s", sceneHash, err.Error()) logger.Errorf("[scenes] <%s> error getting scene movies JSON: %s", sceneHash, err.Error())
continue continue
@@ -417,14 +427,14 @@ func exportScene(wg *sync.WaitGroup, jobChan <-chan *models.Scene, t *ExportTask
t.galleries.IDs = utils.IntAppendUnique(t.galleries.IDs, sceneGallery.ID) t.galleries.IDs = utils.IntAppendUnique(t.galleries.IDs, sceneGallery.ID)
} }
tagIDs, err := scene.GetDependentTagIDs(tagReader, joinReader, sceneMarkerReader, s) tagIDs, err := scene.GetDependentTagIDs(tagReader, sceneMarkerReader, s)
if err != nil { if err != nil {
logger.Errorf("[scenes] <%s> error getting scene tags: %s", sceneHash, err.Error()) logger.Errorf("[scenes] <%s> error getting scene tags: %s", sceneHash, err.Error())
continue continue
} }
t.tags.IDs = utils.IntAppendUniques(t.tags.IDs, tagIDs) t.tags.IDs = utils.IntAppendUniques(t.tags.IDs, tagIDs)
movieIDs, err := scene.GetDependentMovieIDs(joinReader, s) movieIDs, err := scene.GetDependentMovieIDs(sceneReader, s)
if err != nil { if err != nil {
logger.Errorf("[scenes] <%s> error getting scene movies: %s", sceneHash, err.Error()) logger.Errorf("[scenes] <%s> error getting scene movies: %s", sceneHash, err.Error())
continue continue
@@ -445,10 +455,10 @@ func exportScene(wg *sync.WaitGroup, jobChan <-chan *models.Scene, t *ExportTask
} }
} }
func (t *ExportTask) ExportImages(workers int) { func (t *ExportTask) ExportImages(workers int, repo models.ReaderRepository) {
var imagesWg sync.WaitGroup var imagesWg sync.WaitGroup
imageReader := models.NewImageReaderWriter(nil) imageReader := repo.Image()
var images []*models.Image var images []*models.Image
var err error var err error
@@ -470,7 +480,7 @@ func (t *ExportTask) ExportImages(workers int) {
for w := 0; w < workers; w++ { // create export Image workers for w := 0; w < workers; w++ { // create export Image workers
imagesWg.Add(1) imagesWg.Add(1)
go exportImage(&imagesWg, jobCh, t) go exportImage(&imagesWg, jobCh, repo, t)
} }
for i, image := range images { for i, image := range images {
@@ -489,12 +499,12 @@ func (t *ExportTask) ExportImages(workers int) {
logger.Infof("[images] export complete in %s. %d workers used.", time.Since(startTime), workers) logger.Infof("[images] export complete in %s. %d workers used.", time.Since(startTime), workers)
} }
func exportImage(wg *sync.WaitGroup, jobChan <-chan *models.Image, t *ExportTask) { func exportImage(wg *sync.WaitGroup, jobChan <-chan *models.Image, repo models.ReaderRepository, t *ExportTask) {
defer wg.Done() defer wg.Done()
studioReader := models.NewStudioReaderWriter(nil) studioReader := repo.Studio()
galleryReader := models.NewGalleryReaderWriter(nil) galleryReader := repo.Gallery()
performerReader := models.NewPerformerReaderWriter(nil) performerReader := repo.Performer()
tagReader := models.NewTagReaderWriter(nil) tagReader := repo.Tag()
for s := range jobChan { for s := range jobChan {
imageHash := s.Checksum imageHash := s.Checksum
@@ -560,10 +570,10 @@ func (t *ExportTask) getGalleryChecksums(galleries []*models.Gallery) (ret []str
return return
} }
func (t *ExportTask) ExportGalleries(workers int) { func (t *ExportTask) ExportGalleries(workers int, repo models.ReaderRepository) {
var galleriesWg sync.WaitGroup var galleriesWg sync.WaitGroup
reader := models.NewGalleryReaderWriter(nil) reader := repo.Gallery()
var galleries []*models.Gallery var galleries []*models.Gallery
var err error var err error
@@ -585,7 +595,7 @@ func (t *ExportTask) ExportGalleries(workers int) {
for w := 0; w < workers; w++ { // create export Scene workers for w := 0; w < workers; w++ { // create export Scene workers
galleriesWg.Add(1) galleriesWg.Add(1)
go exportGallery(&galleriesWg, jobCh, t) go exportGallery(&galleriesWg, jobCh, repo, t)
} }
for i, gallery := range galleries { for i, gallery := range galleries {
@@ -609,11 +619,11 @@ func (t *ExportTask) ExportGalleries(workers int) {
logger.Infof("[galleries] export complete in %s. %d workers used.", time.Since(startTime), workers) logger.Infof("[galleries] export complete in %s. %d workers used.", time.Since(startTime), workers)
} }
func exportGallery(wg *sync.WaitGroup, jobChan <-chan *models.Gallery, t *ExportTask) { func exportGallery(wg *sync.WaitGroup, jobChan <-chan *models.Gallery, repo models.ReaderRepository, t *ExportTask) {
defer wg.Done() defer wg.Done()
studioReader := models.NewStudioReaderWriter(nil) studioReader := repo.Studio()
performerReader := models.NewPerformerReaderWriter(nil) performerReader := repo.Performer()
tagReader := models.NewTagReaderWriter(nil) tagReader := repo.Tag()
for g := range jobChan { for g := range jobChan {
galleryHash := g.Checksum galleryHash := g.Checksum
@@ -666,10 +676,10 @@ func exportGallery(wg *sync.WaitGroup, jobChan <-chan *models.Gallery, t *Export
} }
} }
func (t *ExportTask) ExportPerformers(workers int) { func (t *ExportTask) ExportPerformers(workers int, repo models.ReaderRepository) {
var performersWg sync.WaitGroup var performersWg sync.WaitGroup
reader := models.NewPerformerReaderWriter(nil) reader := repo.Performer()
var performers []*models.Performer var performers []*models.Performer
var err error var err error
all := t.full || (t.performers != nil && t.performers.all) all := t.full || (t.performers != nil && t.performers.all)
@@ -689,7 +699,7 @@ func (t *ExportTask) ExportPerformers(workers int) {
for w := 0; w < workers; w++ { // create export Performer workers for w := 0; w < workers; w++ { // create export Performer workers
performersWg.Add(1) performersWg.Add(1)
go t.exportPerformer(&performersWg, jobCh) go t.exportPerformer(&performersWg, jobCh, repo)
} }
for i, performer := range performers { for i, performer := range performers {
@@ -706,10 +716,10 @@ func (t *ExportTask) ExportPerformers(workers int) {
logger.Infof("[performers] export complete in %s. %d workers used.", time.Since(startTime), workers) logger.Infof("[performers] export complete in %s. %d workers used.", time.Since(startTime), workers)
} }
func (t *ExportTask) exportPerformer(wg *sync.WaitGroup, jobChan <-chan *models.Performer) { func (t *ExportTask) exportPerformer(wg *sync.WaitGroup, jobChan <-chan *models.Performer, repo models.ReaderRepository) {
defer wg.Done() defer wg.Done()
performerReader := models.NewPerformerReaderWriter(nil) performerReader := repo.Performer()
for p := range jobChan { for p := range jobChan {
newPerformerJSON, err := performer.ToJSON(performerReader, p) newPerformerJSON, err := performer.ToJSON(performerReader, p)
@@ -732,10 +742,10 @@ func (t *ExportTask) exportPerformer(wg *sync.WaitGroup, jobChan <-chan *models.
} }
} }
func (t *ExportTask) ExportStudios(workers int) { func (t *ExportTask) ExportStudios(workers int, repo models.ReaderRepository) {
var studiosWg sync.WaitGroup var studiosWg sync.WaitGroup
reader := models.NewStudioReaderWriter(nil) reader := repo.Studio()
var studios []*models.Studio var studios []*models.Studio
var err error var err error
all := t.full || (t.studios != nil && t.studios.all) all := t.full || (t.studios != nil && t.studios.all)
@@ -756,7 +766,7 @@ func (t *ExportTask) ExportStudios(workers int) {
for w := 0; w < workers; w++ { // create export Studio workers for w := 0; w < workers; w++ { // create export Studio workers
studiosWg.Add(1) studiosWg.Add(1)
go t.exportStudio(&studiosWg, jobCh) go t.exportStudio(&studiosWg, jobCh, repo)
} }
for i, studio := range studios { for i, studio := range studios {
@@ -773,10 +783,10 @@ func (t *ExportTask) ExportStudios(workers int) {
logger.Infof("[studios] export complete in %s. %d workers used.", time.Since(startTime), workers) logger.Infof("[studios] export complete in %s. %d workers used.", time.Since(startTime), workers)
} }
func (t *ExportTask) exportStudio(wg *sync.WaitGroup, jobChan <-chan *models.Studio) { func (t *ExportTask) exportStudio(wg *sync.WaitGroup, jobChan <-chan *models.Studio, repo models.ReaderRepository) {
defer wg.Done() defer wg.Done()
studioReader := models.NewStudioReaderWriter(nil) studioReader := repo.Studio()
for s := range jobChan { for s := range jobChan {
newStudioJSON, err := studio.ToJSON(studioReader, s) newStudioJSON, err := studio.ToJSON(studioReader, s)
@@ -797,10 +807,10 @@ func (t *ExportTask) exportStudio(wg *sync.WaitGroup, jobChan <-chan *models.Stu
} }
} }
func (t *ExportTask) ExportTags(workers int) { func (t *ExportTask) ExportTags(workers int, repo models.ReaderRepository) {
var tagsWg sync.WaitGroup var tagsWg sync.WaitGroup
reader := models.NewTagReaderWriter(nil) reader := repo.Tag()
var tags []*models.Tag var tags []*models.Tag
var err error var err error
all := t.full || (t.tags != nil && t.tags.all) all := t.full || (t.tags != nil && t.tags.all)
@@ -821,7 +831,7 @@ func (t *ExportTask) ExportTags(workers int) {
for w := 0; w < workers; w++ { // create export Tag workers for w := 0; w < workers; w++ { // create export Tag workers
tagsWg.Add(1) tagsWg.Add(1)
go t.exportTag(&tagsWg, jobCh) go t.exportTag(&tagsWg, jobCh, repo)
} }
for i, tag := range tags { for i, tag := range tags {
@@ -841,10 +851,10 @@ func (t *ExportTask) ExportTags(workers int) {
logger.Infof("[tags] export complete in %s. %d workers used.", time.Since(startTime), workers) logger.Infof("[tags] export complete in %s. %d workers used.", time.Since(startTime), workers)
} }
func (t *ExportTask) exportTag(wg *sync.WaitGroup, jobChan <-chan *models.Tag) { func (t *ExportTask) exportTag(wg *sync.WaitGroup, jobChan <-chan *models.Tag, repo models.ReaderRepository) {
defer wg.Done() defer wg.Done()
tagReader := models.NewTagReaderWriter(nil) tagReader := repo.Tag()
for thisTag := range jobChan { for thisTag := range jobChan {
newTagJSON, err := tag.ToJSON(tagReader, thisTag) newTagJSON, err := tag.ToJSON(tagReader, thisTag)
@@ -868,10 +878,10 @@ func (t *ExportTask) exportTag(wg *sync.WaitGroup, jobChan <-chan *models.Tag) {
} }
} }
func (t *ExportTask) ExportMovies(workers int) { func (t *ExportTask) ExportMovies(workers int, repo models.ReaderRepository) {
var moviesWg sync.WaitGroup var moviesWg sync.WaitGroup
reader := models.NewMovieReaderWriter(nil) reader := repo.Movie()
var movies []*models.Movie var movies []*models.Movie
var err error var err error
all := t.full || (t.movies != nil && t.movies.all) all := t.full || (t.movies != nil && t.movies.all)
@@ -892,7 +902,7 @@ func (t *ExportTask) ExportMovies(workers int) {
for w := 0; w < workers; w++ { // create export Studio workers for w := 0; w < workers; w++ { // create export Studio workers
moviesWg.Add(1) moviesWg.Add(1)
go t.exportMovie(&moviesWg, jobCh) go t.exportMovie(&moviesWg, jobCh, repo)
} }
for i, movie := range movies { for i, movie := range movies {
@@ -909,11 +919,11 @@ func (t *ExportTask) ExportMovies(workers int) {
logger.Infof("[movies] export complete in %s. %d workers used.", time.Since(startTime), workers) logger.Infof("[movies] export complete in %s. %d workers used.", time.Since(startTime), workers)
} }
func (t *ExportTask) exportMovie(wg *sync.WaitGroup, jobChan <-chan *models.Movie) { func (t *ExportTask) exportMovie(wg *sync.WaitGroup, jobChan <-chan *models.Movie, repo models.ReaderRepository) {
defer wg.Done() defer wg.Done()
movieReader := models.NewMovieReaderWriter(nil) movieReader := repo.Movie()
studioReader := models.NewStudioReaderWriter(nil) studioReader := repo.Studio()
for m := range jobChan { for m := range jobChan {
newMovieJSON, err := movie.ToJSON(movieReader, studioReader, m) newMovieJSON, err := movie.ToJSON(movieReader, studioReader, m)
@@ -942,9 +952,9 @@ func (t *ExportTask) exportMovie(wg *sync.WaitGroup, jobChan <-chan *models.Movi
} }
} }
func (t *ExportTask) ExportScrapedItems() { func (t *ExportTask) ExportScrapedItems(repo models.ReaderRepository) {
qb := models.NewScrapedItemQueryBuilder() qb := repo.ScrapedItem()
sqb := models.NewStudioQueryBuilder() sqb := repo.Studio()
scrapedItems, err := qb.All() scrapedItems, err := qb.All()
if err != nil { if err != nil {
logger.Errorf("[scraped sites] failed to fetch all items: %s", err.Error()) logger.Errorf("[scraped sites] failed to fetch all items: %s", err.Error())
@@ -960,7 +970,7 @@ func (t *ExportTask) ExportScrapedItems() {
var studioName string var studioName string
if scrapedItem.StudioID.Valid { if scrapedItem.StudioID.Valid {
studio, _ := sqb.Find(int(scrapedItem.StudioID.Int64), nil) studio, _ := sqb.Find(int(scrapedItem.StudioID.Int64))
if studio != nil { if studio != nil {
studioName = studio.Name.String studioName = studio.Name.String
} }
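Across the export task the readers now come from a single WithReadTxn callback and are passed down into the worker goroutines instead of being constructed per helper. A schematic of that flow with stand-in ReaderRepository and ReadTxnManager types; the real interfaces expose far more than shown here:

package example

import (
	"context"
	"sync"
)

// ReaderRepository stands in for the repository of read-only stores.
type ReaderRepository interface {
	SceneNames() ([]string, error)
}

type ReadTxnManager interface {
	WithReadTxn(ctx context.Context, fn func(r ReaderRepository) error) error
}

// exportAll opens one read transaction, lists the items inside it, and
// fans the work out to a fixed pool of workers, mirroring the shape of
// ExportScenes/ExportImages above.
func exportAll(m ReadTxnManager, workers int) error {
	return m.WithReadTxn(context.TODO(), func(r ReaderRepository) error {
		names, err := r.SceneNames()
		if err != nil {
			return err
		}

		jobCh := make(chan string)
		var wg sync.WaitGroup
		for w := 0; w < workers; w++ {
			wg.Add(1)
			go func() {
				defer wg.Done()
				for n := range jobCh {
					_ = n // placeholder: serialize this item to JSON here
				}
			}()
		}
		for _, n := range names {
			jobCh <- n
		}
		close(jobCh)
		wg.Wait()
		return nil
	})
}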


@@ -1,6 +1,7 @@
package manager package manager
import ( import (
"context"
"path/filepath" "path/filepath"
"strconv" "strconv"
@@ -13,6 +14,7 @@ import (
) )
type GenerateMarkersTask struct { type GenerateMarkersTask struct {
TxnManager models.TransactionManager
Scene *models.Scene Scene *models.Scene
Marker *models.SceneMarker Marker *models.SceneMarker
Overwrite bool Overwrite bool
@@ -27,13 +29,21 @@ func (t *GenerateMarkersTask) Start(wg *sizedwaitgroup.SizedWaitGroup) {
} }
if t.Marker != nil { if t.Marker != nil {
qb := models.NewSceneQueryBuilder() var scene *models.Scene
scene, err := qb.Find(int(t.Marker.SceneID.Int64)) if err := t.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
if err != nil { var err error
scene, err = r.Scene().Find(int(t.Marker.SceneID.Int64))
return err
}); err != nil {
logger.Errorf("error finding scene for marker: %s", err.Error()) logger.Errorf("error finding scene for marker: %s", err.Error())
return return
} }
if scene == nil {
logger.Errorf("scene not found for id %d", t.Marker.SceneID.Int64)
return
}
videoFile, err := ffmpeg.NewVideoFile(instance.FFProbePath, t.Scene.Path, false) videoFile, err := ffmpeg.NewVideoFile(instance.FFProbePath, t.Scene.Path, false)
if err != nil { if err != nil {
logger.Errorf("error reading video file: %s", err.Error()) logger.Errorf("error reading video file: %s", err.Error())
@@ -45,8 +55,16 @@ func (t *GenerateMarkersTask) Start(wg *sizedwaitgroup.SizedWaitGroup) {
} }
func (t *GenerateMarkersTask) generateSceneMarkers() { func (t *GenerateMarkersTask) generateSceneMarkers() {
qb := models.NewSceneMarkerQueryBuilder() var sceneMarkers []*models.SceneMarker
sceneMarkers, _ := qb.FindBySceneID(t.Scene.ID, nil) if err := t.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
var err error
sceneMarkers, err = r.SceneMarker().FindBySceneID(t.Scene.ID)
return err
}); err != nil {
logger.Errorf("error getting scene markers: %s", err.Error())
return
}
if len(sceneMarkers) == 0 { if len(sceneMarkers) == 0 {
return return
} }
@@ -117,8 +135,16 @@ func (t *GenerateMarkersTask) generateMarker(videoFile *ffmpeg.VideoFile, scene
func (t *GenerateMarkersTask) isMarkerNeeded() int { func (t *GenerateMarkersTask) isMarkerNeeded() int {
markers := 0 markers := 0
qb := models.NewSceneMarkerQueryBuilder() var sceneMarkers []*models.SceneMarker
sceneMarkers, _ := qb.FindBySceneID(t.Scene.ID, nil) if err := t.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
var err error
sceneMarkers, err = r.SceneMarker().FindBySceneID(t.Scene.ID)
return err
}); err != nil {
logger.Errorf("errror finding scene markers: %s", err.Error())
return 0
}
if len(sceneMarkers) == 0 { if len(sceneMarkers) == 0 {
return 0 return 0
} }


@@ -2,12 +2,12 @@ package manager
import ( import (
"context" "context"
"fmt"
"io/ioutil" "io/ioutil"
"os" "os"
"sync" "sync"
"time" "time"
"github.com/stashapp/stash/pkg/database"
"github.com/stashapp/stash/pkg/ffmpeg" "github.com/stashapp/stash/pkg/ffmpeg"
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
@@ -17,6 +17,7 @@ type GenerateScreenshotTask struct {
Scene models.Scene Scene models.Scene
ScreenshotAt *float64 ScreenshotAt *float64
fileNamingAlgorithm models.HashAlgorithm fileNamingAlgorithm models.HashAlgorithm
txnManager models.TransactionManager
} }
func (t *GenerateScreenshotTask) Start(wg *sync.WaitGroup) { func (t *GenerateScreenshotTask) Start(wg *sync.WaitGroup) {
@@ -60,39 +61,31 @@ func (t *GenerateScreenshotTask) Start(wg *sync.WaitGroup) {
return return
} }
ctx := context.TODO() if err := t.txnManager.WithTxn(context.TODO(), func(r models.Repository) error {
tx := database.DB.MustBeginTx(ctx, nil) qb := r.Scene()
updatedTime := time.Now()
updatedScene := models.ScenePartial{
ID: t.Scene.ID,
UpdatedAt: &models.SQLiteTimestamp{Timestamp: updatedTime},
}
qb := models.NewSceneQueryBuilder() if err := SetSceneScreenshot(checksum, coverImageData); err != nil {
updatedTime := time.Now() return fmt.Errorf("Error writing screenshot: %s", err.Error())
updatedScene := models.ScenePartial{ }
ID: t.Scene.ID,
UpdatedAt: &models.SQLiteTimestamp{Timestamp: updatedTime},
}
if err := SetSceneScreenshot(checksum, coverImageData); err != nil { // update the scene cover table
logger.Errorf("Error writing screenshot: %s", err.Error()) if err := qb.UpdateCover(t.Scene.ID, coverImageData); err != nil {
tx.Rollback() return fmt.Errorf("Error setting screenshot: %s", err.Error())
return }
}
// update the scene cover table // update the scene with the update date
if err := qb.UpdateSceneCover(t.Scene.ID, coverImageData, tx); err != nil { _, err = qb.Update(updatedScene)
logger.Errorf("Error setting screenshot: %s", err.Error()) if err != nil {
tx.Rollback() return fmt.Errorf("Error updating scene: %s", err.Error())
return }
}
// update the scene with the update date return nil
_, err = qb.Update(updatedScene, tx) }); err != nil {
if err != nil { logger.Error(err.Error())
logger.Errorf("Error updating scene: %s", err.Error())
tx.Rollback()
return
}
if err := tx.Commit(); err != nil {
logger.Errorf("Error setting screenshot: %s", err.Error())
return
} }
} }
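The screenshot task drops the manual rollback/commit bookkeeping: each step returns a wrapped error from inside the closure, and the failure is logged once at the call site. A condensed sketch of that structure, with hypothetical writeScreenshotFile and updateCoverRow helpers standing in for SetSceneScreenshot and UpdateCover:

package example

import (
	"context"
	"fmt"
	"log"
)

// updateScreenshot runs both steps inside one transaction; the helper
// handles rollback on error, so the task only wraps and logs.
func updateScreenshot(ctx context.Context, withTxn func(context.Context, func() error) error, id int, data []byte) {
	if err := withTxn(ctx, func() error {
		if err := writeScreenshotFile(id, data); err != nil {
			return fmt.Errorf("error writing screenshot: %w", err)
		}
		if err := updateCoverRow(id, data); err != nil {
			return fmt.Errorf("error setting screenshot: %w", err)
		}
		return nil
	}); err != nil {
		log.Print(err)
	}
}

// Hypothetical stand-ins for the real screenshot and cover-update calls.
func writeScreenshotFile(id int, data []byte) error { return nil }
func updateCoverRow(id int, data []byte) error      { return nil }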


@@ -11,7 +11,6 @@ import (
"sync" "sync"
"time" "time"
"github.com/jmoiron/sqlx"
"github.com/stashapp/stash/pkg/database" "github.com/stashapp/stash/pkg/database"
"github.com/stashapp/stash/pkg/gallery" "github.com/stashapp/stash/pkg/gallery"
"github.com/stashapp/stash/pkg/image" "github.com/stashapp/stash/pkg/image"
@@ -29,7 +28,8 @@ import (
) )
type ImportTask struct { type ImportTask struct {
json jsonUtils txnManager models.TransactionManager
json jsonUtils
BaseDir string BaseDir string
ZipFile io.Reader ZipFile io.Reader
@@ -44,6 +44,7 @@ type ImportTask struct {
func CreateImportTask(a models.HashAlgorithm, input models.ImportObjectsInput) *ImportTask { func CreateImportTask(a models.HashAlgorithm, input models.ImportObjectsInput) *ImportTask {
return &ImportTask{ return &ImportTask{
txnManager: GetInstance().TxnManager,
ZipFile: input.File.File, ZipFile: input.File.File,
Reset: false, Reset: false,
DuplicateBehaviour: input.DuplicateBehaviour, DuplicateBehaviour: input.DuplicateBehaviour,
@@ -200,22 +201,16 @@ func (t *ImportTask) ImportPerformers(ctx context.Context) {
logger.Progressf("[performers] %d of %d", index, len(t.mappings.Performers)) logger.Progressf("[performers] %d of %d", index, len(t.mappings.Performers))
tx := database.DB.MustBeginTx(ctx, nil) if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
readerWriter := models.NewPerformerReaderWriter(tx) readerWriter := r.Performer()
importer := &performer.Importer{ importer := &performer.Importer{
ReaderWriter: readerWriter, ReaderWriter: readerWriter,
Input: *performerJSON, Input: *performerJSON,
} }
if err := performImport(importer, t.DuplicateBehaviour); err != nil { return performImport(importer, t.DuplicateBehaviour)
tx.Rollback() }); err != nil {
logger.Errorf("[performers] <%s> failed to import: %s", mappingJSON.Checksum, err.Error()) logger.Errorf("[performers] <%s> import failed: %s", mappingJSON.Checksum, err.Error())
continue
}
if err := tx.Commit(); err != nil {
tx.Rollback()
logger.Errorf("[performers] <%s> import failed to commit: %s", mappingJSON.Checksum, err.Error())
} }
} }
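Each imported performer now gets its own transaction, so a bad record rolls back and is logged without aborting the rest of the run. A condensed sketch of that loop shape, with an invented performImportItem standing in for the real importer:

package example

import (
	"context"
	"log"
)

// ImportRepository and ImportTxnManager are stand-ins for the project's
// Repository and TransactionManager interfaces.
type ImportRepository interface{}

type ImportTxnManager interface {
	WithTxn(ctx context.Context, fn func(r ImportRepository) error) error
}

// importAll wraps each item in its own transaction so one bad record
// does not abort the whole import.
func importAll(ctx context.Context, m ImportTxnManager, items []string) {
	for _, item := range items {
		if err := m.WithTxn(ctx, func(r ImportRepository) error {
			return performImportItem(r, item) // hypothetical per-item import
		}); err != nil {
			log.Printf("<%s> import failed: %s", item, err)
		}
	}
}

func performImportItem(r ImportRepository, item string) error { return nil }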
@@ -237,12 +232,9 @@ func (t *ImportTask) ImportStudios(ctx context.Context) {
logger.Progressf("[studios] %d of %d", index, len(t.mappings.Studios)) logger.Progressf("[studios] %d of %d", index, len(t.mappings.Studios))
tx := database.DB.MustBeginTx(ctx, nil) if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
return t.ImportStudio(studioJSON, pendingParent, r.Studio())
// fail on missing parent studio to begin with }); err != nil {
if err := t.ImportStudio(studioJSON, pendingParent, tx); err != nil {
tx.Rollback()
if err == studio.ErrParentStudioNotExist { if err == studio.ErrParentStudioNotExist {
// add to the pending parent list so that it is created after the parent // add to the pending parent list so that it is created after the parent
s := pendingParent[studioJSON.ParentStudio] s := pendingParent[studioJSON.ParentStudio]
@@ -254,11 +246,6 @@ func (t *ImportTask) ImportStudios(ctx context.Context) {
logger.Errorf("[studios] <%s> failed to create: %s", mappingJSON.Checksum, err.Error()) logger.Errorf("[studios] <%s> failed to create: %s", mappingJSON.Checksum, err.Error())
continue continue
} }
if err := tx.Commit(); err != nil {
logger.Errorf("[studios] import failed to commit: %s", err.Error())
continue
}
} }
// create the leftover studios, warning for missing parents // create the leftover studios, warning for missing parents
@@ -267,18 +254,12 @@ func (t *ImportTask) ImportStudios(ctx context.Context) {
for _, s := range pendingParent { for _, s := range pendingParent {
for _, orphanStudioJSON := range s { for _, orphanStudioJSON := range s {
tx := database.DB.MustBeginTx(ctx, nil) if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
return t.ImportStudio(orphanStudioJSON, nil, r.Studio())
if err := t.ImportStudio(orphanStudioJSON, nil, tx); err != nil { }); err != nil {
tx.Rollback()
logger.Errorf("[studios] <%s> failed to create: %s", orphanStudioJSON.Name, err.Error()) logger.Errorf("[studios] <%s> failed to create: %s", orphanStudioJSON.Name, err.Error())
continue continue
} }
if err := tx.Commit(); err != nil {
logger.Errorf("[studios] import failed to commit: %s", err.Error())
continue
}
} }
} }
} }
@@ -286,8 +267,7 @@ func (t *ImportTask) ImportStudios(ctx context.Context) {
logger.Info("[studios] import complete") logger.Info("[studios] import complete")
} }
func (t *ImportTask) ImportStudio(studioJSON *jsonschema.Studio, pendingParent map[string][]*jsonschema.Studio, tx *sqlx.Tx) error { func (t *ImportTask) ImportStudio(studioJSON *jsonschema.Studio, pendingParent map[string][]*jsonschema.Studio, readerWriter models.StudioReaderWriter) error {
readerWriter := models.NewStudioReaderWriter(tx)
importer := &studio.Importer{ importer := &studio.Importer{
ReaderWriter: readerWriter, ReaderWriter: readerWriter,
Input: *studioJSON, Input: *studioJSON,
@@ -307,7 +287,7 @@ func (t *ImportTask) ImportStudio(studioJSON *jsonschema.Studio, pendingParent m
s := pendingParent[studioJSON.Name] s := pendingParent[studioJSON.Name]
for _, childStudioJSON := range s { for _, childStudioJSON := range s {
// map is nil since we're not checking parent studios at this point // map is nil since we're not checking parent studios at this point
if err := t.ImportStudio(childStudioJSON, nil, tx); err != nil { if err := t.ImportStudio(childStudioJSON, nil, readerWriter); err != nil {
return fmt.Errorf("failed to create child studio <%s>: %s", childStudioJSON.Name, err.Error()) return fmt.Errorf("failed to create child studio <%s>: %s", childStudioJSON.Name, err.Error())
} }
} }
@@ -331,26 +311,20 @@ func (t *ImportTask) ImportMovies(ctx context.Context) {
logger.Progressf("[movies] %d of %d", index, len(t.mappings.Movies)) logger.Progressf("[movies] %d of %d", index, len(t.mappings.Movies))
tx := database.DB.MustBeginTx(ctx, nil) if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
readerWriter := models.NewMovieReaderWriter(tx) readerWriter := r.Movie()
studioReaderWriter := models.NewStudioReaderWriter(tx) studioReaderWriter := r.Studio()
movieImporter := &movie.Importer{ movieImporter := &movie.Importer{
ReaderWriter: readerWriter, ReaderWriter: readerWriter,
StudioWriter: studioReaderWriter, StudioWriter: studioReaderWriter,
Input: *movieJSON, Input: *movieJSON,
MissingRefBehaviour: t.MissingRefBehaviour, MissingRefBehaviour: t.MissingRefBehaviour,
} }
if err := performImport(movieImporter, t.DuplicateBehaviour); err != nil { return performImport(movieImporter, t.DuplicateBehaviour)
tx.Rollback() }); err != nil {
logger.Errorf("[movies] <%s> failed to import: %s", mappingJSON.Checksum, err.Error()) logger.Errorf("[movies] <%s> import failed: %s", mappingJSON.Checksum, err.Error())
continue
}
if err := tx.Commit(); err != nil {
tx.Rollback()
logger.Errorf("[movies] <%s> import failed to commit: %s", mappingJSON.Checksum, err.Error())
continue continue
} }
} }
@@ -371,31 +345,23 @@ func (t *ImportTask) ImportGalleries(ctx context.Context) {
logger.Progressf("[galleries] %d of %d", index, len(t.mappings.Galleries)) logger.Progressf("[galleries] %d of %d", index, len(t.mappings.Galleries))
tx := database.DB.MustBeginTx(ctx, nil) if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
readerWriter := models.NewGalleryReaderWriter(tx) readerWriter := r.Gallery()
tagWriter := models.NewTagReaderWriter(tx) tagWriter := r.Tag()
joinWriter := models.NewJoinReaderWriter(tx) performerWriter := r.Performer()
performerWriter := models.NewPerformerReaderWriter(tx) studioWriter := r.Studio()
studioWriter := models.NewStudioReaderWriter(tx)
galleryImporter := &gallery.Importer{ galleryImporter := &gallery.Importer{
ReaderWriter: readerWriter, ReaderWriter: readerWriter,
PerformerWriter: performerWriter, PerformerWriter: performerWriter,
StudioWriter: studioWriter, StudioWriter: studioWriter,
TagWriter: tagWriter, TagWriter: tagWriter,
JoinWriter: joinWriter, Input: *galleryJSON,
Input: *galleryJSON, MissingRefBehaviour: t.MissingRefBehaviour,
MissingRefBehaviour: t.MissingRefBehaviour, }
}
if err := performImport(galleryImporter, t.DuplicateBehaviour); err != nil { return performImport(galleryImporter, t.DuplicateBehaviour)
tx.Rollback() }); err != nil {
logger.Errorf("[galleries] <%s> failed to import: %s", mappingJSON.Checksum, err.Error())
continue
}
if err := tx.Commit(); err != nil {
tx.Rollback()
logger.Errorf("[galleries] <%s> import failed to commit: %s", mappingJSON.Checksum, err.Error()) logger.Errorf("[galleries] <%s> import failed to commit: %s", mappingJSON.Checksum, err.Error())
continue continue
} }
@@ -417,74 +383,71 @@ func (t *ImportTask) ImportTags(ctx context.Context) {
logger.Progressf("[tags] %d of %d", index, len(t.mappings.Tags)) logger.Progressf("[tags] %d of %d", index, len(t.mappings.Tags))
tx := database.DB.MustBeginTx(ctx, nil) if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
readerWriter := models.NewTagReaderWriter(tx) readerWriter := r.Tag()
tagImporter := &tag.Importer{ tagImporter := &tag.Importer{
ReaderWriter: readerWriter, ReaderWriter: readerWriter,
Input: *tagJSON, Input: *tagJSON,
} }
if err := performImport(tagImporter, t.DuplicateBehaviour); err != nil { return performImport(tagImporter, t.DuplicateBehaviour)
tx.Rollback() }); err != nil {
logger.Errorf("[tags] <%s> failed to import: %s", mappingJSON.Checksum, err.Error()) logger.Errorf("[tags] <%s> failed to import: %s", mappingJSON.Checksum, err.Error())
continue continue
} }
if err := tx.Commit(); err != nil {
tx.Rollback()
logger.Errorf("[tags] <%s> import failed to commit: %s", mappingJSON.Checksum, err.Error())
}
} }
logger.Info("[tags] import complete") logger.Info("[tags] import complete")
} }
func (t *ImportTask) ImportScrapedItems(ctx context.Context) { func (t *ImportTask) ImportScrapedItems(ctx context.Context) {
tx := database.DB.MustBeginTx(ctx, nil) if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
qb := models.NewScrapedItemQueryBuilder() logger.Info("[scraped sites] importing")
sqb := models.NewStudioQueryBuilder() qb := r.ScrapedItem()
currentTime := time.Now() sqb := r.Studio()
currentTime := time.Now()
for i, mappingJSON := range t.scraped { for i, mappingJSON := range t.scraped {
index := i + 1 index := i + 1
logger.Progressf("[scraped sites] %d of %d", index, len(t.mappings.Scenes)) logger.Progressf("[scraped sites] %d of %d", index, len(t.mappings.Scenes))
newScrapedItem := models.ScrapedItem{ newScrapedItem := models.ScrapedItem{
Title: sql.NullString{String: mappingJSON.Title, Valid: true}, Title: sql.NullString{String: mappingJSON.Title, Valid: true},
Description: sql.NullString{String: mappingJSON.Description, Valid: true}, Description: sql.NullString{String: mappingJSON.Description, Valid: true},
URL: sql.NullString{String: mappingJSON.URL, Valid: true}, URL: sql.NullString{String: mappingJSON.URL, Valid: true},
Date: models.SQLiteDate{String: mappingJSON.Date, Valid: true}, Date: models.SQLiteDate{String: mappingJSON.Date, Valid: true},
Rating: sql.NullString{String: mappingJSON.Rating, Valid: true}, Rating: sql.NullString{String: mappingJSON.Rating, Valid: true},
Tags: sql.NullString{String: mappingJSON.Tags, Valid: true}, Tags: sql.NullString{String: mappingJSON.Tags, Valid: true},
Models: sql.NullString{String: mappingJSON.Models, Valid: true}, Models: sql.NullString{String: mappingJSON.Models, Valid: true},
Episode: sql.NullInt64{Int64: int64(mappingJSON.Episode), Valid: true}, Episode: sql.NullInt64{Int64: int64(mappingJSON.Episode), Valid: true},
GalleryFilename: sql.NullString{String: mappingJSON.GalleryFilename, Valid: true}, GalleryFilename: sql.NullString{String: mappingJSON.GalleryFilename, Valid: true},
GalleryURL: sql.NullString{String: mappingJSON.GalleryURL, Valid: true}, GalleryURL: sql.NullString{String: mappingJSON.GalleryURL, Valid: true},
VideoFilename: sql.NullString{String: mappingJSON.VideoFilename, Valid: true}, VideoFilename: sql.NullString{String: mappingJSON.VideoFilename, Valid: true},
VideoURL: sql.NullString{String: mappingJSON.VideoURL, Valid: true}, VideoURL: sql.NullString{String: mappingJSON.VideoURL, Valid: true},
CreatedAt: models.SQLiteTimestamp{Timestamp: currentTime}, CreatedAt: models.SQLiteTimestamp{Timestamp: currentTime},
UpdatedAt: models.SQLiteTimestamp{Timestamp: t.getTimeFromJSONTime(mappingJSON.UpdatedAt)}, UpdatedAt: models.SQLiteTimestamp{Timestamp: t.getTimeFromJSONTime(mappingJSON.UpdatedAt)},
}
studio, err := sqb.FindByName(mappingJSON.Studio, false)
if err != nil {
logger.Errorf("[scraped sites] failed to fetch studio: %s", err.Error())
}
if studio != nil {
newScrapedItem.StudioID = sql.NullInt64{Int64: int64(studio.ID), Valid: true}
}
_, err = qb.Create(newScrapedItem)
if err != nil {
logger.Errorf("[scraped sites] <%s> failed to create: %s", newScrapedItem.Title.String, err.Error())
}
} }
studio, err := sqb.FindByName(mappingJSON.Studio, tx, false) return nil
if err != nil { }); err != nil {
logger.Errorf("[scraped sites] failed to fetch studio: %s", err.Error())
}
if studio != nil {
newScrapedItem.StudioID = sql.NullInt64{Int64: int64(studio.ID), Valid: true}
}
_, err = qb.Create(newScrapedItem, tx)
if err != nil {
logger.Errorf("[scraped sites] <%s> failed to create: %s", newScrapedItem.Title.String, err.Error())
}
}
logger.Info("[scraped sites] importing")
if err := tx.Commit(); err != nil {
logger.Errorf("[scraped sites] import failed to commit: %s", err.Error()) logger.Errorf("[scraped sites] import failed to commit: %s", err.Error())
} }
logger.Info("[scraped sites] import complete") logger.Info("[scraped sites] import complete")
} }
@@ -504,65 +467,52 @@ func (t *ImportTask) ImportScenes(ctx context.Context) {
sceneHash := mappingJSON.Checksum sceneHash := mappingJSON.Checksum
tx := database.DB.MustBeginTx(ctx, nil) if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
readerWriter := models.NewSceneReaderWriter(tx) readerWriter := r.Scene()
tagWriter := models.NewTagReaderWriter(tx) tagWriter := r.Tag()
galleryWriter := models.NewGalleryReaderWriter(tx) galleryWriter := r.Gallery()
joinWriter := models.NewJoinReaderWriter(tx) movieWriter := r.Movie()
movieWriter := models.NewMovieReaderWriter(tx) performerWriter := r.Performer()
performerWriter := models.NewPerformerReaderWriter(tx) studioWriter := r.Studio()
studioWriter := models.NewStudioReaderWriter(tx) markerWriter := r.SceneMarker()
markerWriter := models.NewSceneMarkerReaderWriter(tx)
sceneImporter := &scene.Importer{ sceneImporter := &scene.Importer{
ReaderWriter: readerWriter, ReaderWriter: readerWriter,
Input: *sceneJSON, Input: *sceneJSON,
Path: mappingJSON.Path, Path: mappingJSON.Path,
FileNamingAlgorithm: t.fileNamingAlgorithm, FileNamingAlgorithm: t.fileNamingAlgorithm,
MissingRefBehaviour: t.MissingRefBehaviour,
GalleryWriter: galleryWriter,
JoinWriter: joinWriter,
MovieWriter: movieWriter,
PerformerWriter: performerWriter,
StudioWriter: studioWriter,
TagWriter: tagWriter,
}
if err := performImport(sceneImporter, t.DuplicateBehaviour); err != nil {
tx.Rollback()
logger.Errorf("[scenes] <%s> failed to import: %s", sceneHash, err.Error())
continue
}
// import the scene markers
failedMarkers := false
for _, m := range sceneJSON.Markers {
markerImporter := &scene.MarkerImporter{
SceneID: sceneImporter.ID,
Input: m,
MissingRefBehaviour: t.MissingRefBehaviour, MissingRefBehaviour: t.MissingRefBehaviour,
ReaderWriter: markerWriter,
JoinWriter: joinWriter, GalleryWriter: galleryWriter,
TagWriter: tagWriter, MovieWriter: movieWriter,
PerformerWriter: performerWriter,
StudioWriter: studioWriter,
TagWriter: tagWriter,
} }
if err := performImport(markerImporter, t.DuplicateBehaviour); err != nil { if err := performImport(sceneImporter, t.DuplicateBehaviour); err != nil {
failedMarkers = true return err
logger.Errorf("[scenes] <%s> failed to import markers: %s", sceneHash, err.Error())
break
} }
}
if failedMarkers { // import the scene markers
tx.Rollback() for _, m := range sceneJSON.Markers {
continue markerImporter := &scene.MarkerImporter{
} SceneID: sceneImporter.ID,
Input: m,
MissingRefBehaviour: t.MissingRefBehaviour,
ReaderWriter: markerWriter,
TagWriter: tagWriter,
}
if err := tx.Commit(); err != nil { if err := performImport(markerImporter, t.DuplicateBehaviour); err != nil {
tx.Rollback() return err
logger.Errorf("[scenes] <%s> import failed to commit: %s", sceneHash, err.Error()) }
}
return nil
}); err != nil {
logger.Errorf("[scenes] <%s> import failed: %s", sceneHash, err.Error())
} }
} }
@@ -585,46 +535,37 @@ func (t *ImportTask) ImportImages(ctx context.Context) {
imageHash := mappingJSON.Checksum imageHash := mappingJSON.Checksum
tx := database.DB.MustBeginTx(ctx, nil) if err := t.txnManager.WithTxn(ctx, func(r models.Repository) error {
readerWriter := models.NewImageReaderWriter(tx) readerWriter := r.Image()
tagWriter := models.NewTagReaderWriter(tx) tagWriter := r.Tag()
galleryWriter := models.NewGalleryReaderWriter(tx) galleryWriter := r.Gallery()
joinWriter := models.NewJoinReaderWriter(tx) performerWriter := r.Performer()
performerWriter := models.NewPerformerReaderWriter(tx) studioWriter := r.Studio()
studioWriter := models.NewStudioReaderWriter(tx)
imageImporter := &image.Importer{ imageImporter := &image.Importer{
ReaderWriter: readerWriter, ReaderWriter: readerWriter,
Input: *imageJSON, Input: *imageJSON,
Path: mappingJSON.Path, Path: mappingJSON.Path,
MissingRefBehaviour: t.MissingRefBehaviour, MissingRefBehaviour: t.MissingRefBehaviour,
GalleryWriter: galleryWriter, GalleryWriter: galleryWriter,
JoinWriter: joinWriter, PerformerWriter: performerWriter,
PerformerWriter: performerWriter, StudioWriter: studioWriter,
StudioWriter: studioWriter, TagWriter: tagWriter,
TagWriter: tagWriter, }
}
if err := performImport(imageImporter, t.DuplicateBehaviour); err != nil { return performImport(imageImporter, t.DuplicateBehaviour)
tx.Rollback() }); err != nil {
logger.Errorf("[images] <%s> failed to import: %s", imageHash, err.Error()) logger.Errorf("[images] <%s> import failed: %s", imageHash, err.Error())
continue
}
if err := tx.Commit(); err != nil {
tx.Rollback()
logger.Errorf("[images] <%s> import failed to commit: %s", imageHash, err.Error())
} }
} }
logger.Info("[images] import complete") logger.Info("[images] import complete")
} }
func (t *ImportTask) getPerformers(names []string, tx *sqlx.Tx) ([]*models.Performer, error) { func (t *ImportTask) getPerformers(names []string, qb models.PerformerReader) ([]*models.Performer, error) {
pqb := models.NewPerformerQueryBuilder() performers, err := qb.FindByNames(names, false)
performers, err := pqb.FindByNames(names, tx, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -648,12 +589,10 @@ func (t *ImportTask) getPerformers(names []string, tx *sqlx.Tx) ([]*models.Perfo
return performers, nil return performers, nil
} }
func (t *ImportTask) getMoviesScenes(input []jsonschema.SceneMovie, sceneID int, tx *sqlx.Tx) ([]models.MoviesScenes, error) { func (t *ImportTask) getMoviesScenes(input []jsonschema.SceneMovie, sceneID int, mqb models.MovieReader) ([]models.MoviesScenes, error) {
mqb := models.NewMovieQueryBuilder()
var movies []models.MoviesScenes var movies []models.MoviesScenes
for _, inputMovie := range input { for _, inputMovie := range input {
movie, err := mqb.FindByName(inputMovie.MovieName, tx, false) movie, err := mqb.FindByName(inputMovie.MovieName, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -680,9 +619,8 @@ func (t *ImportTask) getMoviesScenes(input []jsonschema.SceneMovie, sceneID int,
return movies, nil return movies, nil
} }
func (t *ImportTask) getTags(sceneChecksum string, names []string, tx *sqlx.Tx) ([]*models.Tag, error) { func (t *ImportTask) getTags(sceneChecksum string, names []string, tqb models.TagReader) ([]*models.Tag, error) {
tqb := models.NewTagQueryBuilder() tags, err := tqb.FindByNames(names, false)
tags, err := tqb.FindByNames(names, tx, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }

File diff suppressed because it is too large


@@ -1,74 +1,34 @@
package models

type GalleryReader interface {
	Find(id int) (*Gallery, error)
	FindMany(ids []int) ([]*Gallery, error)
	FindByChecksum(checksum string) (*Gallery, error)
	FindByPath(path string) (*Gallery, error)
	FindBySceneID(sceneID int) (*Gallery, error)
	FindByImageID(imageID int) ([]*Gallery, error)
	ValidGalleriesForScenePath(scenePath string) ([]*Gallery, error)
	Count() (int, error)
	All() ([]*Gallery, error)
	Query(galleryFilter *GalleryFilterType, findFilter *FindFilterType) ([]*Gallery, int, error)
	GetPerformerIDs(galleryID int) ([]int, error)
	GetTagIDs(galleryID int) ([]int, error)
	GetImageIDs(galleryID int) ([]int, error)
}

type GalleryWriter interface {
	Create(newGallery Gallery) (*Gallery, error)
	Update(updatedGallery Gallery) (*Gallery, error)
	UpdatePartial(updatedGallery GalleryPartial) (*Gallery, error)
	UpdateFileModTime(id int, modTime NullSQLiteTimestamp) error
	Destroy(id int) error
	ClearGalleryId(sceneID int) error
	UpdatePerformers(galleryID int, performerIDs []int) error
	UpdateTags(galleryID int, tagIDs []int) error
	UpdateImages(galleryID int, imageIDs []int) error
}

type GalleryReaderWriter interface {
	GalleryReader
	GalleryWriter
}
func NewGalleryReaderWriter(tx *sqlx.Tx) GalleryReaderWriter {
return &galleryReaderWriter{
tx: tx,
qb: NewGalleryQueryBuilder(),
}
}
type galleryReaderWriter struct {
tx *sqlx.Tx
qb GalleryQueryBuilder
}
func (t *galleryReaderWriter) FindMany(ids []int) ([]*Gallery, error) {
return t.qb.FindMany(ids)
}
func (t *galleryReaderWriter) FindByChecksum(checksum string) (*Gallery, error) {
return t.qb.FindByChecksum(checksum, t.tx)
}
func (t *galleryReaderWriter) All() ([]*Gallery, error) {
return t.qb.All()
}
func (t *galleryReaderWriter) FindByPath(path string) (*Gallery, error) {
return t.qb.FindByPath(path)
}
func (t *galleryReaderWriter) FindBySceneID(sceneID int) (*Gallery, error) {
return t.qb.FindBySceneID(sceneID, t.tx)
}
func (t *galleryReaderWriter) FindByImageID(imageID int) ([]*Gallery, error) {
return t.qb.FindByImageID(imageID, t.tx)
}
func (t *galleryReaderWriter) Create(newGallery Gallery) (*Gallery, error) {
return t.qb.Create(newGallery, t.tx)
}
func (t *galleryReaderWriter) Update(updatedGallery Gallery) (*Gallery, error) {
return t.qb.Update(updatedGallery, t.tx)
}
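In place of the removed transaction-bound wrapper above, callers now program against the interface directly; a hedged sketch of a consumer (helper name and the nil-means-not-found assumption are illustrative, not part of the diff):

	// galleryExists reports whether a gallery with the given checksum exists.
	// It works with any GalleryReader: the SQLite-backed, transaction-scoped
	// implementation in production, or a generated mock in tests.
	func galleryExists(qb models.GalleryReader, checksum string) (bool, error) {
		g, err := qb.FindByChecksum(checksum)
		if err != nil {
			return false, err
		}
		return g != nil, nil // assuming a nil result means "not found"
	}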
@@ -1,77 +1,41 @@
package models

type ImageReader interface {
	Find(id int) (*Image, error)
	FindMany(ids []int) ([]*Image, error)
	FindByChecksum(checksum string) (*Image, error)
	FindByGalleryID(galleryID int) ([]*Image, error)
	CountByGalleryID(galleryID int) (int, error)
	FindByPath(path string) (*Image, error)
	// FindByPerformerID(performerID int) ([]*Image, error)
	// CountByPerformerID(performerID int) (int, error)
	// FindByStudioID(studioID int) ([]*Image, error)
	Count() (int, error)
	Size() (float64, error)
	// SizeCount() (string, error)
	// CountByStudioID(studioID int) (int, error)
	// CountByTagID(tagID int) (int, error)
	All() ([]*Image, error)
	Query(imageFilter *ImageFilterType, findFilter *FindFilterType) ([]*Image, int, error)
	GetGalleryIDs(imageID int) ([]int, error)
	GetTagIDs(imageID int) ([]int, error)
	GetPerformerIDs(imageID int) ([]int, error)
}

type ImageWriter interface {
	Create(newImage Image) (*Image, error)
	Update(updatedImage ImagePartial) (*Image, error)
	UpdateFull(updatedImage Image) (*Image, error)
	IncrementOCounter(id int) (int, error)
	DecrementOCounter(id int) (int, error)
	ResetOCounter(id int) (int, error)
	Destroy(id int) error
	UpdateGalleries(imageID int, galleryIDs []int) error
	UpdatePerformers(imageID int, performerIDs []int) error
	UpdateTags(imageID int, tagIDs []int) error
}

type ImageReaderWriter interface {
	ImageReader
	ImageWriter
}
func NewImageReaderWriter(tx *sqlx.Tx) ImageReaderWriter {
return &imageReaderWriter{
tx: tx,
qb: NewImageQueryBuilder(),
}
}
type imageReaderWriter struct {
tx *sqlx.Tx
qb ImageQueryBuilder
}
func (t *imageReaderWriter) FindMany(ids []int) ([]*Image, error) {
return t.qb.FindMany(ids)
}
func (t *imageReaderWriter) FindByChecksum(checksum string) (*Image, error) {
return t.qb.FindByChecksum(checksum)
}
func (t *imageReaderWriter) FindByGalleryID(galleryID int) ([]*Image, error) {
return t.qb.FindByGalleryID(galleryID)
}
func (t *imageReaderWriter) All() ([]*Image, error) {
return t.qb.All()
}
func (t *imageReaderWriter) Create(newImage Image) (*Image, error) {
return t.qb.Create(newImage, t.tx)
}
func (t *imageReaderWriter) Update(updatedImage ImagePartial) (*Image, error) {
return t.qb.Update(updatedImage, t.tx)
}
func (t *imageReaderWriter) UpdateFull(updatedImage Image) (*Image, error) {
return t.qb.UpdateFull(updatedImage, t.tx)
}
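Because reading and writing are now separate interfaces, a helper can ask for exactly the capability it needs, and for the combined type only when it genuinely needs both sides; a small sketch against the methods declared above (function name is illustrative):

	// copyImageTags copies tag assignments from one image to another.
	// It needs GetTagIDs (reader) and UpdateTags (writer), so it takes ImageReaderWriter.
	func copyImageTags(rw models.ImageReaderWriter, fromID, toID int) error {
		tagIDs, err := rw.GetTagIDs(fromID)
		if err != nil {
			return err
		}
		return rw.UpdateTags(toID, tagIDs)
	}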
@@ -1,101 +0,0 @@
package models
import (
"github.com/jmoiron/sqlx"
)
type JoinReader interface {
// GetScenePerformers(sceneID int) ([]PerformersScenes, error)
GetSceneMovies(sceneID int) ([]MoviesScenes, error)
// GetSceneTags(sceneID int) ([]ScenesTags, error)
}
type JoinWriter interface {
CreatePerformersScenes(newJoins []PerformersScenes) error
// AddPerformerScene(sceneID int, performerID int) (bool, error)
UpdatePerformersScenes(sceneID int, updatedJoins []PerformersScenes) error
// DestroyPerformersScenes(sceneID int) error
CreateMoviesScenes(newJoins []MoviesScenes) error
// AddMoviesScene(sceneID int, movieID int, sceneIdx *int) (bool, error)
UpdateMoviesScenes(sceneID int, updatedJoins []MoviesScenes) error
// DestroyMoviesScenes(sceneID int) error
// CreateScenesTags(newJoins []ScenesTags) error
UpdateScenesTags(sceneID int, updatedJoins []ScenesTags) error
// AddSceneTag(sceneID int, tagID int) (bool, error)
// DestroyScenesTags(sceneID int) error
// CreateSceneMarkersTags(newJoins []SceneMarkersTags) error
UpdateSceneMarkersTags(sceneMarkerID int, updatedJoins []SceneMarkersTags) error
// DestroySceneMarkersTags(sceneMarkerID int, updatedJoins []SceneMarkersTags) error
// DestroyScenesGalleries(sceneID int) error
// DestroyScenesMarkers(sceneID int) error
UpdatePerformersGalleries(galleryID int, updatedJoins []PerformersGalleries) error
UpdateGalleriesTags(galleryID int, updatedJoins []GalleriesTags) error
UpdateGalleriesImages(imageID int, updatedJoins []GalleriesImages) error
UpdatePerformersImages(imageID int, updatedJoins []PerformersImages) error
UpdateImagesTags(imageID int, updatedJoins []ImagesTags) error
}
type JoinReaderWriter interface {
JoinReader
JoinWriter
}
func NewJoinReaderWriter(tx *sqlx.Tx) JoinReaderWriter {
return &joinReaderWriter{
tx: tx,
qb: NewJoinsQueryBuilder(),
}
}
type joinReaderWriter struct {
tx *sqlx.Tx
qb JoinsQueryBuilder
}
func (t *joinReaderWriter) GetSceneMovies(sceneID int) ([]MoviesScenes, error) {
return t.qb.GetSceneMovies(sceneID, t.tx)
}
func (t *joinReaderWriter) CreatePerformersScenes(newJoins []PerformersScenes) error {
return t.qb.CreatePerformersScenes(newJoins, t.tx)
}
func (t *joinReaderWriter) UpdatePerformersScenes(sceneID int, updatedJoins []PerformersScenes) error {
return t.qb.UpdatePerformersScenes(sceneID, updatedJoins, t.tx)
}
func (t *joinReaderWriter) CreateMoviesScenes(newJoins []MoviesScenes) error {
return t.qb.CreateMoviesScenes(newJoins, t.tx)
}
func (t *joinReaderWriter) UpdateMoviesScenes(sceneID int, updatedJoins []MoviesScenes) error {
return t.qb.UpdateMoviesScenes(sceneID, updatedJoins, t.tx)
}
func (t *joinReaderWriter) UpdateScenesTags(sceneID int, updatedJoins []ScenesTags) error {
return t.qb.UpdateScenesTags(sceneID, updatedJoins, t.tx)
}
func (t *joinReaderWriter) UpdateSceneMarkersTags(sceneMarkerID int, updatedJoins []SceneMarkersTags) error {
return t.qb.UpdateSceneMarkersTags(sceneMarkerID, updatedJoins, t.tx)
}
func (t *joinReaderWriter) UpdatePerformersGalleries(galleryID int, updatedJoins []PerformersGalleries) error {
return t.qb.UpdatePerformersGalleries(galleryID, updatedJoins, t.tx)
}
func (t *joinReaderWriter) UpdateGalleriesTags(galleryID int, updatedJoins []GalleriesTags) error {
return t.qb.UpdateGalleriesTags(galleryID, updatedJoins, t.tx)
}
func (t *joinReaderWriter) UpdateGalleriesImages(imageID int, updatedJoins []GalleriesImages) error {
return t.qb.UpdateGalleriesImages(imageID, updatedJoins, t.tx)
}
func (t *joinReaderWriter) UpdatePerformersImages(imageID int, updatedJoins []PerformersImages) error {
return t.qb.UpdatePerformersImages(imageID, updatedJoins, t.tx)
}
func (t *joinReaderWriter) UpdateImagesTags(imageID int, updatedJoins []ImagesTags) error {
return t.qb.UpdateImagesTags(imageID, updatedJoins, t.tx)
}
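The join-table helpers deleted here correspond roughly to methods that now live on the entity writers shown earlier in this diff; for example, what previously went through UpdateGalleriesTags can be expressed against GalleryWriter instead (sketch only, based on those interfaces):

	// Before: joinWriter.UpdateGalleriesTags(galleryID, []GalleriesTags{...})
	// After:  the gallery writer owns its own associations and takes plain IDs.
	func setGalleryTags(w models.GalleryWriter, galleryID int, tagIDs []int) error {
		return w.UpdateTags(galleryID, tagIDs)
	}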
@@ -35,6 +35,41 @@ func (_m *GalleryReaderWriter) All() ([]*models.Gallery, error) {
	return r0, r1
}
// ClearGalleryId provides a mock function with given fields: sceneID
func (_m *GalleryReaderWriter) ClearGalleryId(sceneID int) error {
ret := _m.Called(sceneID)
var r0 error
if rf, ok := ret.Get(0).(func(int) error); ok {
r0 = rf(sceneID)
} else {
r0 = ret.Error(0)
}
return r0
}
// Count provides a mock function with given fields:
func (_m *GalleryReaderWriter) Count() (int, error) {
ret := _m.Called()
var r0 int
if rf, ok := ret.Get(0).(func() int); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(int)
}
var r1 error
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Create provides a mock function with given fields: newGallery
func (_m *GalleryReaderWriter) Create(newGallery models.Gallery) (*models.Gallery, error) {
	ret := _m.Called(newGallery)
@@ -58,6 +93,43 @@ func (_m *GalleryReaderWriter) Create(newGallery models.Gallery) (*models.Galler
	return r0, r1
}
// Destroy provides a mock function with given fields: id
func (_m *GalleryReaderWriter) Destroy(id int) error {
ret := _m.Called(id)
var r0 error
if rf, ok := ret.Get(0).(func(int) error); ok {
r0 = rf(id)
} else {
r0 = ret.Error(0)
}
return r0
}
// Find provides a mock function with given fields: id
func (_m *GalleryReaderWriter) Find(id int) (*models.Gallery, error) {
ret := _m.Called(id)
var r0 *models.Gallery
if rf, ok := ret.Get(0).(func(int) *models.Gallery); ok {
r0 = rf(id)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Gallery)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(int) error); ok {
r1 = rf(id)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// FindByChecksum provides a mock function with given fields: checksum
func (_m *GalleryReaderWriter) FindByChecksum(checksum string) (*models.Gallery, error) {
	ret := _m.Called(checksum)
@@ -173,6 +245,105 @@ func (_m *GalleryReaderWriter) FindMany(ids []int) ([]*models.Gallery, error) {
	return r0, r1
}
// GetImageIDs provides a mock function with given fields: galleryID
func (_m *GalleryReaderWriter) GetImageIDs(galleryID int) ([]int, error) {
ret := _m.Called(galleryID)
var r0 []int
if rf, ok := ret.Get(0).(func(int) []int); ok {
r0 = rf(galleryID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]int)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(int) error); ok {
r1 = rf(galleryID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetPerformerIDs provides a mock function with given fields: galleryID
func (_m *GalleryReaderWriter) GetPerformerIDs(galleryID int) ([]int, error) {
ret := _m.Called(galleryID)
var r0 []int
if rf, ok := ret.Get(0).(func(int) []int); ok {
r0 = rf(galleryID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]int)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(int) error); ok {
r1 = rf(galleryID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetTagIDs provides a mock function with given fields: galleryID
func (_m *GalleryReaderWriter) GetTagIDs(galleryID int) ([]int, error) {
ret := _m.Called(galleryID)
var r0 []int
if rf, ok := ret.Get(0).(func(int) []int); ok {
r0 = rf(galleryID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]int)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(int) error); ok {
r1 = rf(galleryID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Query provides a mock function with given fields: galleryFilter, findFilter
func (_m *GalleryReaderWriter) Query(galleryFilter *models.GalleryFilterType, findFilter *models.FindFilterType) ([]*models.Gallery, int, error) {
ret := _m.Called(galleryFilter, findFilter)
var r0 []*models.Gallery
if rf, ok := ret.Get(0).(func(*models.GalleryFilterType, *models.FindFilterType) []*models.Gallery); ok {
r0 = rf(galleryFilter, findFilter)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Gallery)
}
}
var r1 int
if rf, ok := ret.Get(1).(func(*models.GalleryFilterType, *models.FindFilterType) int); ok {
r1 = rf(galleryFilter, findFilter)
} else {
r1 = ret.Get(1).(int)
}
var r2 error
if rf, ok := ret.Get(2).(func(*models.GalleryFilterType, *models.FindFilterType) error); ok {
r2 = rf(galleryFilter, findFilter)
} else {
r2 = ret.Error(2)
}
return r0, r1, r2
}
// Update provides a mock function with given fields: updatedGallery
func (_m *GalleryReaderWriter) Update(updatedGallery models.Gallery) (*models.Gallery, error) {
	ret := _m.Called(updatedGallery)
@@ -195,3 +366,105 @@ func (_m *GalleryReaderWriter) Update(updatedGallery models.Gallery) (*models.Ga
	return r0, r1
}
// UpdateFileModTime provides a mock function with given fields: id, modTime
func (_m *GalleryReaderWriter) UpdateFileModTime(id int, modTime models.NullSQLiteTimestamp) error {
ret := _m.Called(id, modTime)
var r0 error
if rf, ok := ret.Get(0).(func(int, models.NullSQLiteTimestamp) error); ok {
r0 = rf(id, modTime)
} else {
r0 = ret.Error(0)
}
return r0
}
// UpdateImages provides a mock function with given fields: galleryID, imageIDs
func (_m *GalleryReaderWriter) UpdateImages(galleryID int, imageIDs []int) error {
ret := _m.Called(galleryID, imageIDs)
var r0 error
if rf, ok := ret.Get(0).(func(int, []int) error); ok {
r0 = rf(galleryID, imageIDs)
} else {
r0 = ret.Error(0)
}
return r0
}
// UpdatePartial provides a mock function with given fields: updatedGallery
func (_m *GalleryReaderWriter) UpdatePartial(updatedGallery models.GalleryPartial) (*models.Gallery, error) {
ret := _m.Called(updatedGallery)
var r0 *models.Gallery
if rf, ok := ret.Get(0).(func(models.GalleryPartial) *models.Gallery); ok {
r0 = rf(updatedGallery)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Gallery)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(models.GalleryPartial) error); ok {
r1 = rf(updatedGallery)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// UpdatePerformers provides a mock function with given fields: galleryID, performerIDs
func (_m *GalleryReaderWriter) UpdatePerformers(galleryID int, performerIDs []int) error {
ret := _m.Called(galleryID, performerIDs)
var r0 error
if rf, ok := ret.Get(0).(func(int, []int) error); ok {
r0 = rf(galleryID, performerIDs)
} else {
r0 = ret.Error(0)
}
return r0
}
// UpdateTags provides a mock function with given fields: galleryID, tagIDs
func (_m *GalleryReaderWriter) UpdateTags(galleryID int, tagIDs []int) error {
ret := _m.Called(galleryID, tagIDs)
var r0 error
if rf, ok := ret.Get(0).(func(int, []int) error); ok {
r0 = rf(galleryID, tagIDs)
} else {
r0 = ret.Error(0)
}
return r0
}
// ValidGalleriesForScenePath provides a mock function with given fields: scenePath
func (_m *GalleryReaderWriter) ValidGalleriesForScenePath(scenePath string) ([]*models.Gallery, error) {
ret := _m.Called(scenePath)
var r0 []*models.Gallery
if rf, ok := ret.Get(0).(func(string) []*models.Gallery); ok {
r0 = rf(scenePath)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Gallery)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(string) error); ok {
r1 = rf(scenePath)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
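These mockery-generated types are consumed through testify's expectation API; a minimal, illustrative test follows (test name, package layout, and import path are assumptions, not part of the diff):

	package mocks_test

	import (
		"testing"

		"github.com/stashapp/stash/pkg/models"
		"github.com/stashapp/stash/pkg/models/mocks"
	)

	func TestFindByChecksumMock(t *testing.T) {
		m := &mocks.GalleryReaderWriter{}
		expected := &models.Gallery{}

		// Return a canned gallery whenever FindByChecksum is called with this checksum.
		m.On("FindByChecksum", "abc123").Return(expected, nil).Once()

		got, err := m.FindByChecksum("abc123")
		if err != nil || got != expected {
			t.Fatalf("unexpected result: %v, %v", got, err)
		}

		m.AssertExpectations(t)
	}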
@@ -35,6 +35,48 @@ func (_m *ImageReaderWriter) All() ([]*models.Image, error) {
	return r0, r1
}
// Count provides a mock function with given fields:
func (_m *ImageReaderWriter) Count() (int, error) {
ret := _m.Called()
var r0 int
if rf, ok := ret.Get(0).(func() int); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(int)
}
var r1 error
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// CountByGalleryID provides a mock function with given fields: galleryID
func (_m *ImageReaderWriter) CountByGalleryID(galleryID int) (int, error) {
ret := _m.Called(galleryID)
var r0 int
if rf, ok := ret.Get(0).(func(int) int); ok {
r0 = rf(galleryID)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
if rf, ok := ret.Get(1).(func(int) error); ok {
r1 = rf(galleryID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Create provides a mock function with given fields: newImage
func (_m *ImageReaderWriter) Create(newImage models.Image) (*models.Image, error) {
	ret := _m.Called(newImage)
@@ -58,6 +100,64 @@ func (_m *ImageReaderWriter) Create(newImage models.Image) (*models.Image, error
	return r0, r1
}
// DecrementOCounter provides a mock function with given fields: id
func (_m *ImageReaderWriter) DecrementOCounter(id int) (int, error) {
ret := _m.Called(id)
var r0 int
if rf, ok := ret.Get(0).(func(int) int); ok {
r0 = rf(id)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
if rf, ok := ret.Get(1).(func(int) error); ok {
r1 = rf(id)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Destroy provides a mock function with given fields: id
func (_m *ImageReaderWriter) Destroy(id int) error {
ret := _m.Called(id)
var r0 error
if rf, ok := ret.Get(0).(func(int) error); ok {
r0 = rf(id)
} else {
r0 = ret.Error(0)
}
return r0
}
// Find provides a mock function with given fields: id
func (_m *ImageReaderWriter) Find(id int) (*models.Image, error) {
ret := _m.Called(id)
var r0 *models.Image
if rf, ok := ret.Get(0).(func(int) *models.Image); ok {
r0 = rf(id)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Image)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(int) error); ok {
r1 = rf(id)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// FindByChecksum provides a mock function with given fields: checksum
func (_m *ImageReaderWriter) FindByChecksum(checksum string) (*models.Image, error) {
	ret := _m.Called(checksum)
@@ -104,6 +204,29 @@ func (_m *ImageReaderWriter) FindByGalleryID(galleryID int) ([]*models.Image, er
	return r0, r1
}
// FindByPath provides a mock function with given fields: path
func (_m *ImageReaderWriter) FindByPath(path string) (*models.Image, error) {
ret := _m.Called(path)
var r0 *models.Image
if rf, ok := ret.Get(0).(func(string) *models.Image); ok {
r0 = rf(path)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Image)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(string) error); ok {
r1 = rf(path)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// FindMany provides a mock function with given fields: ids
func (_m *ImageReaderWriter) FindMany(ids []int) ([]*models.Image, error) {
	ret := _m.Called(ids)
@@ -127,6 +250,168 @@ func (_m *ImageReaderWriter) FindMany(ids []int) ([]*models.Image, error) {
	return r0, r1
}
// GetGalleryIDs provides a mock function with given fields: imageID
func (_m *ImageReaderWriter) GetGalleryIDs(imageID int) ([]int, error) {
ret := _m.Called(imageID)
var r0 []int
if rf, ok := ret.Get(0).(func(int) []int); ok {
r0 = rf(imageID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]int)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(int) error); ok {
r1 = rf(imageID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetPerformerIDs provides a mock function with given fields: imageID
func (_m *ImageReaderWriter) GetPerformerIDs(imageID int) ([]int, error) {
ret := _m.Called(imageID)
var r0 []int
if rf, ok := ret.Get(0).(func(int) []int); ok {
r0 = rf(imageID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]int)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(int) error); ok {
r1 = rf(imageID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetTagIDs provides a mock function with given fields: imageID
func (_m *ImageReaderWriter) GetTagIDs(imageID int) ([]int, error) {
ret := _m.Called(imageID)
var r0 []int
if rf, ok := ret.Get(0).(func(int) []int); ok {
r0 = rf(imageID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]int)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(int) error); ok {
r1 = rf(imageID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// IncrementOCounter provides a mock function with given fields: id
func (_m *ImageReaderWriter) IncrementOCounter(id int) (int, error) {
ret := _m.Called(id)
var r0 int
if rf, ok := ret.Get(0).(func(int) int); ok {
r0 = rf(id)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
if rf, ok := ret.Get(1).(func(int) error); ok {
r1 = rf(id)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Query provides a mock function with given fields: imageFilter, findFilter
func (_m *ImageReaderWriter) Query(imageFilter *models.ImageFilterType, findFilter *models.FindFilterType) ([]*models.Image, int, error) {
ret := _m.Called(imageFilter, findFilter)
var r0 []*models.Image
if rf, ok := ret.Get(0).(func(*models.ImageFilterType, *models.FindFilterType) []*models.Image); ok {
r0 = rf(imageFilter, findFilter)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Image)
}
}
var r1 int
if rf, ok := ret.Get(1).(func(*models.ImageFilterType, *models.FindFilterType) int); ok {
r1 = rf(imageFilter, findFilter)
} else {
r1 = ret.Get(1).(int)
}
var r2 error
if rf, ok := ret.Get(2).(func(*models.ImageFilterType, *models.FindFilterType) error); ok {
r2 = rf(imageFilter, findFilter)
} else {
r2 = ret.Error(2)
}
return r0, r1, r2
}
// ResetOCounter provides a mock function with given fields: id
func (_m *ImageReaderWriter) ResetOCounter(id int) (int, error) {
ret := _m.Called(id)
var r0 int
if rf, ok := ret.Get(0).(func(int) int); ok {
r0 = rf(id)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
if rf, ok := ret.Get(1).(func(int) error); ok {
r1 = rf(id)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Size provides a mock function with given fields:
func (_m *ImageReaderWriter) Size() (float64, error) {
ret := _m.Called()
var r0 float64
if rf, ok := ret.Get(0).(func() float64); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(float64)
}
var r1 error
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Update provides a mock function with given fields: updatedImage
func (_m *ImageReaderWriter) Update(updatedImage models.ImagePartial) (*models.Image, error) {
	ret := _m.Called(updatedImage)
@@ -172,3 +457,45 @@ func (_m *ImageReaderWriter) UpdateFull(updatedImage models.Image) (*models.Imag
	return r0, r1
}
// UpdateGalleries provides a mock function with given fields: imageID, galleryIDs
func (_m *ImageReaderWriter) UpdateGalleries(imageID int, galleryIDs []int) error {
ret := _m.Called(imageID, galleryIDs)
var r0 error
if rf, ok := ret.Get(0).(func(int, []int) error); ok {
r0 = rf(imageID, galleryIDs)
} else {
r0 = ret.Error(0)
}
return r0
}
// UpdatePerformers provides a mock function with given fields: imageID, performerIDs
func (_m *ImageReaderWriter) UpdatePerformers(imageID int, performerIDs []int) error {
ret := _m.Called(imageID, performerIDs)
var r0 error
if rf, ok := ret.Get(0).(func(int, []int) error); ok {
r0 = rf(imageID, performerIDs)
} else {
r0 = ret.Error(0)
}
return r0
}
// UpdateTags provides a mock function with given fields: imageID, tagIDs
func (_m *ImageReaderWriter) UpdateTags(imageID int, tagIDs []int) error {
ret := _m.Called(imageID, tagIDs)
var r0 error
if rf, ok := ret.Get(0).(func(int, []int) error); ok {
r0 = rf(imageID, tagIDs)
} else {
r0 = ret.Error(0)
}
return r0
}
@@ -1,190 +0,0 @@
// Code generated by mockery v0.0.0-dev. DO NOT EDIT.
package mocks
import (
models "github.com/stashapp/stash/pkg/models"
mock "github.com/stretchr/testify/mock"
)
// JoinReaderWriter is an autogenerated mock type for the JoinReaderWriter type
type JoinReaderWriter struct {
mock.Mock
}
// CreateMoviesScenes provides a mock function with given fields: newJoins
func (_m *JoinReaderWriter) CreateMoviesScenes(newJoins []models.MoviesScenes) error {
ret := _m.Called(newJoins)
var r0 error
if rf, ok := ret.Get(0).(func([]models.MoviesScenes) error); ok {
r0 = rf(newJoins)
} else {
r0 = ret.Error(0)
}
return r0
}
// CreatePerformersScenes provides a mock function with given fields: newJoins
func (_m *JoinReaderWriter) CreatePerformersScenes(newJoins []models.PerformersScenes) error {
ret := _m.Called(newJoins)
var r0 error
if rf, ok := ret.Get(0).(func([]models.PerformersScenes) error); ok {
r0 = rf(newJoins)
} else {
r0 = ret.Error(0)
}
return r0
}
// GetSceneMovies provides a mock function with given fields: sceneID
func (_m *JoinReaderWriter) GetSceneMovies(sceneID int) ([]models.MoviesScenes, error) {
ret := _m.Called(sceneID)
var r0 []models.MoviesScenes
if rf, ok := ret.Get(0).(func(int) []models.MoviesScenes); ok {
r0 = rf(sceneID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]models.MoviesScenes)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(int) error); ok {
r1 = rf(sceneID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// UpdateGalleriesImages provides a mock function with given fields: imageID, updatedJoins
func (_m *JoinReaderWriter) UpdateGalleriesImages(imageID int, updatedJoins []models.GalleriesImages) error {
ret := _m.Called(imageID, updatedJoins)
var r0 error
if rf, ok := ret.Get(0).(func(int, []models.GalleriesImages) error); ok {
r0 = rf(imageID, updatedJoins)
} else {
r0 = ret.Error(0)
}
return r0
}
// UpdateGalleriesTags provides a mock function with given fields: galleryID, updatedJoins
func (_m *JoinReaderWriter) UpdateGalleriesTags(galleryID int, updatedJoins []models.GalleriesTags) error {
ret := _m.Called(galleryID, updatedJoins)
var r0 error
if rf, ok := ret.Get(0).(func(int, []models.GalleriesTags) error); ok {
r0 = rf(galleryID, updatedJoins)
} else {
r0 = ret.Error(0)
}
return r0
}
// UpdateImagesTags provides a mock function with given fields: imageID, updatedJoins
func (_m *JoinReaderWriter) UpdateImagesTags(imageID int, updatedJoins []models.ImagesTags) error {
ret := _m.Called(imageID, updatedJoins)
var r0 error
if rf, ok := ret.Get(0).(func(int, []models.ImagesTags) error); ok {
r0 = rf(imageID, updatedJoins)
} else {
r0 = ret.Error(0)
}
return r0
}
// UpdateMoviesScenes provides a mock function with given fields: sceneID, updatedJoins
func (_m *JoinReaderWriter) UpdateMoviesScenes(sceneID int, updatedJoins []models.MoviesScenes) error {
ret := _m.Called(sceneID, updatedJoins)
var r0 error
if rf, ok := ret.Get(0).(func(int, []models.MoviesScenes) error); ok {
r0 = rf(sceneID, updatedJoins)
} else {
r0 = ret.Error(0)
}
return r0
}
// UpdatePerformersGalleries provides a mock function with given fields: galleryID, updatedJoins
func (_m *JoinReaderWriter) UpdatePerformersGalleries(galleryID int, updatedJoins []models.PerformersGalleries) error {
ret := _m.Called(galleryID, updatedJoins)
var r0 error
if rf, ok := ret.Get(0).(func(int, []models.PerformersGalleries) error); ok {
r0 = rf(galleryID, updatedJoins)
} else {
r0 = ret.Error(0)
}
return r0
}
// UpdatePerformersImages provides a mock function with given fields: imageID, updatedJoins
func (_m *JoinReaderWriter) UpdatePerformersImages(imageID int, updatedJoins []models.PerformersImages) error {
ret := _m.Called(imageID, updatedJoins)
var r0 error
if rf, ok := ret.Get(0).(func(int, []models.PerformersImages) error); ok {
r0 = rf(imageID, updatedJoins)
} else {
r0 = ret.Error(0)
}
return r0
}
// UpdatePerformersScenes provides a mock function with given fields: sceneID, updatedJoins
func (_m *JoinReaderWriter) UpdatePerformersScenes(sceneID int, updatedJoins []models.PerformersScenes) error {
ret := _m.Called(sceneID, updatedJoins)
var r0 error
if rf, ok := ret.Get(0).(func(int, []models.PerformersScenes) error); ok {
r0 = rf(sceneID, updatedJoins)
} else {
r0 = ret.Error(0)
}
return r0
}
// UpdateSceneMarkersTags provides a mock function with given fields: sceneMarkerID, updatedJoins
func (_m *JoinReaderWriter) UpdateSceneMarkersTags(sceneMarkerID int, updatedJoins []models.SceneMarkersTags) error {
ret := _m.Called(sceneMarkerID, updatedJoins)
var r0 error
if rf, ok := ret.Get(0).(func(int, []models.SceneMarkersTags) error); ok {
r0 = rf(sceneMarkerID, updatedJoins)
} else {
r0 = ret.Error(0)
}
return r0
}
// UpdateScenesTags provides a mock function with given fields: sceneID, updatedJoins
func (_m *JoinReaderWriter) UpdateScenesTags(sceneID int, updatedJoins []models.ScenesTags) error {
ret := _m.Called(sceneID, updatedJoins)
var r0 error
if rf, ok := ret.Get(0).(func(int, []models.ScenesTags) error); ok {
r0 = rf(sceneID, updatedJoins)
} else {
r0 = ret.Error(0)
}
return r0
}
@@ -35,6 +35,50 @@ func (_m *MovieReaderWriter) All() ([]*models.Movie, error) {
	return r0, r1
}
// AllSlim provides a mock function with given fields:
func (_m *MovieReaderWriter) AllSlim() ([]*models.Movie, error) {
ret := _m.Called()
var r0 []*models.Movie
if rf, ok := ret.Get(0).(func() []*models.Movie); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Movie)
}
}
var r1 error
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Count provides a mock function with given fields:
func (_m *MovieReaderWriter) Count() (int, error) {
ret := _m.Called()
var r0 int
if rf, ok := ret.Get(0).(func() int); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(int)
}
var r1 error
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Create provides a mock function with given fields: newMovie
func (_m *MovieReaderWriter) Create(newMovie models.Movie) (*models.Movie, error) {
	ret := _m.Called(newMovie)
@@ -58,6 +102,34 @@ func (_m *MovieReaderWriter) Create(newMovie models.Movie) (*models.Movie, error
	return r0, r1
}
// Destroy provides a mock function with given fields: id
func (_m *MovieReaderWriter) Destroy(id int) error {
ret := _m.Called(id)
var r0 error
if rf, ok := ret.Get(0).(func(int) error); ok {
r0 = rf(id)
} else {
r0 = ret.Error(0)
}
return r0
}
// DestroyImages provides a mock function with given fields: movieID
func (_m *MovieReaderWriter) DestroyImages(movieID int) error {
ret := _m.Called(movieID)
var r0 error
if rf, ok := ret.Get(0).(func(int) error); ok {
r0 = rf(movieID)
} else {
r0 = ret.Error(0)
}
return r0
}
// Find provides a mock function with given fields: id
func (_m *MovieReaderWriter) Find(id int) (*models.Movie, error) {
	ret := _m.Called(id)
@@ -196,6 +268,36 @@ func (_m *MovieReaderWriter) GetFrontImage(movieID int) ([]byte, error) {
	return r0, r1
}
// Query provides a mock function with given fields: movieFilter, findFilter
func (_m *MovieReaderWriter) Query(movieFilter *models.MovieFilterType, findFilter *models.FindFilterType) ([]*models.Movie, int, error) {
ret := _m.Called(movieFilter, findFilter)
var r0 []*models.Movie
if rf, ok := ret.Get(0).(func(*models.MovieFilterType, *models.FindFilterType) []*models.Movie); ok {
r0 = rf(movieFilter, findFilter)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Movie)
}
}
var r1 int
if rf, ok := ret.Get(1).(func(*models.MovieFilterType, *models.FindFilterType) int); ok {
r1 = rf(movieFilter, findFilter)
} else {
r1 = ret.Get(1).(int)
}
var r2 error
if rf, ok := ret.Get(2).(func(*models.MovieFilterType, *models.FindFilterType) error); ok {
r2 = rf(movieFilter, findFilter)
} else {
r2 = ret.Error(2)
}
return r0, r1, r2
}
// Update provides a mock function with given fields: updatedMovie
func (_m *MovieReaderWriter) Update(updatedMovie models.MoviePartial) (*models.Movie, error) {
	ret := _m.Called(updatedMovie)
@@ -242,8 +344,8 @@ func (_m *MovieReaderWriter) UpdateFull(updatedMovie models.Movie) (*models.Movi
	return r0, r1
}

// UpdateImages provides a mock function with given fields: movieID, frontImage, backImage
func (_m *MovieReaderWriter) UpdateImages(movieID int, frontImage []byte, backImage []byte) error {
	ret := _m.Called(movieID, frontImage, backImage)
	var r0 error
@@ -35,6 +35,50 @@ func (_m *PerformerReaderWriter) All() ([]*models.Performer, error) {
	return r0, r1
}
// AllSlim provides a mock function with given fields:
func (_m *PerformerReaderWriter) AllSlim() ([]*models.Performer, error) {
ret := _m.Called()
var r0 []*models.Performer
if rf, ok := ret.Get(0).(func() []*models.Performer); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Performer)
}
}
var r1 error
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Count provides a mock function with given fields:
func (_m *PerformerReaderWriter) Count() (int, error) {
ret := _m.Called()
var r0 int
if rf, ok := ret.Get(0).(func() int); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(int)
}
var r1 error
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Create provides a mock function with given fields: newPerformer
func (_m *PerformerReaderWriter) Create(newPerformer models.Performer) (*models.Performer, error) {
	ret := _m.Called(newPerformer)
@@ -58,6 +102,57 @@ func (_m *PerformerReaderWriter) Create(newPerformer models.Performer) (*models.
	return r0, r1
}
// Destroy provides a mock function with given fields: id
func (_m *PerformerReaderWriter) Destroy(id int) error {
ret := _m.Called(id)
var r0 error
if rf, ok := ret.Get(0).(func(int) error); ok {
r0 = rf(id)
} else {
r0 = ret.Error(0)
}
return r0
}
// DestroyImage provides a mock function with given fields: performerID
func (_m *PerformerReaderWriter) DestroyImage(performerID int) error {
ret := _m.Called(performerID)
var r0 error
if rf, ok := ret.Get(0).(func(int) error); ok {
r0 = rf(performerID)
} else {
r0 = ret.Error(0)
}
return r0
}
// Find provides a mock function with given fields: id
func (_m *PerformerReaderWriter) Find(id int) (*models.Performer, error) {
ret := _m.Called(id)
var r0 *models.Performer
if rf, ok := ret.Get(0).(func(int) *models.Performer); ok {
r0 = rf(id)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Performer)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(int) error); ok {
r1 = rf(id)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// FindByGalleryID provides a mock function with given fields: galleryID
func (_m *PerformerReaderWriter) FindByGalleryID(galleryID int) ([]*models.Performer, error) {
	ret := _m.Called(galleryID)
@@ -196,8 +291,8 @@ func (_m *PerformerReaderWriter) FindNamesBySceneID(sceneID int) ([]*models.Perf
	return r0, r1
}

// GetImage provides a mock function with given fields: performerID
func (_m *PerformerReaderWriter) GetImage(performerID int) ([]byte, error) {
	ret := _m.Called(performerID)
	var r0 []byte
@@ -219,6 +314,59 @@ func (_m *PerformerReaderWriter) GetPerformerImage(performerID int) ([]byte, err
	return r0, r1
}
// GetStashIDs provides a mock function with given fields: performerID
func (_m *PerformerReaderWriter) GetStashIDs(performerID int) ([]*models.StashID, error) {
ret := _m.Called(performerID)
var r0 []*models.StashID
if rf, ok := ret.Get(0).(func(int) []*models.StashID); ok {
r0 = rf(performerID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.StashID)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(int) error); ok {
r1 = rf(performerID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Query provides a mock function with given fields: performerFilter, findFilter
func (_m *PerformerReaderWriter) Query(performerFilter *models.PerformerFilterType, findFilter *models.FindFilterType) ([]*models.Performer, int, error) {
ret := _m.Called(performerFilter, findFilter)
var r0 []*models.Performer
if rf, ok := ret.Get(0).(func(*models.PerformerFilterType, *models.FindFilterType) []*models.Performer); ok {
r0 = rf(performerFilter, findFilter)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Performer)
}
}
var r1 int
if rf, ok := ret.Get(1).(func(*models.PerformerFilterType, *models.FindFilterType) int); ok {
r1 = rf(performerFilter, findFilter)
} else {
r1 = ret.Get(1).(int)
}
var r2 error
if rf, ok := ret.Get(2).(func(*models.PerformerFilterType, *models.FindFilterType) error); ok {
r2 = rf(performerFilter, findFilter)
} else {
r2 = ret.Error(2)
}
return r0, r1, r2
}
// Update provides a mock function with given fields: updatedPerformer
func (_m *PerformerReaderWriter) Update(updatedPerformer models.PerformerPartial) (*models.Performer, error) {
	ret := _m.Called(updatedPerformer)
@@ -265,8 +413,8 @@ func (_m *PerformerReaderWriter) UpdateFull(updatedPerformer models.Performer) (
	return r0, r1
}

// UpdateImage provides a mock function with given fields: performerID, image
func (_m *PerformerReaderWriter) UpdateImage(performerID int, image []byte) error {
	ret := _m.Called(performerID, image)
	var r0 error
@@ -278,3 +426,17 @@ func (_m *PerformerReaderWriter) UpdatePerformerImage(performerID int, image []b
	return r0
}
// UpdateStashIDs provides a mock function with given fields: performerID, stashIDs
func (_m *PerformerReaderWriter) UpdateStashIDs(performerID int, stashIDs []models.StashID) error {
ret := _m.Called(performerID, stashIDs)
var r0 error
if rf, ok := ret.Get(0).(func(int, []models.StashID) error); ok {
r0 = rf(performerID, stashIDs)
} else {
r0 = ret.Error(0)
}
return r0
}
@@ -12,6 +12,27 @@ type SceneMarkerReaderWriter struct {
	mock.Mock
}
// CountByTagID provides a mock function with given fields: tagID
func (_m *SceneMarkerReaderWriter) CountByTagID(tagID int) (int, error) {
ret := _m.Called(tagID)
var r0 int
if rf, ok := ret.Get(0).(func(int) int); ok {
r0 = rf(tagID)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
if rf, ok := ret.Get(1).(func(int) error); ok {
r1 = rf(tagID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Create provides a mock function with given fields: newSceneMarker
func (_m *SceneMarkerReaderWriter) Create(newSceneMarker models.SceneMarker) (*models.SceneMarker, error) {
	ret := _m.Called(newSceneMarker)
@@ -35,6 +56,43 @@ func (_m *SceneMarkerReaderWriter) Create(newSceneMarker models.SceneMarker) (*m
	return r0, r1
}
// Destroy provides a mock function with given fields: id
func (_m *SceneMarkerReaderWriter) Destroy(id int) error {
ret := _m.Called(id)
var r0 error
if rf, ok := ret.Get(0).(func(int) error); ok {
r0 = rf(id)
} else {
r0 = ret.Error(0)
}
return r0
}
// Find provides a mock function with given fields: id
func (_m *SceneMarkerReaderWriter) Find(id int) (*models.SceneMarker, error) {
ret := _m.Called(id)
var r0 *models.SceneMarker
if rf, ok := ret.Get(0).(func(int) *models.SceneMarker); ok {
r0 = rf(id)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.SceneMarker)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(int) error); ok {
r1 = rf(id)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// FindBySceneID provides a mock function with given fields: sceneID
func (_m *SceneMarkerReaderWriter) FindBySceneID(sceneID int) ([]*models.SceneMarker, error) {
	ret := _m.Called(sceneID)
@@ -58,6 +116,105 @@ func (_m *SceneMarkerReaderWriter) FindBySceneID(sceneID int) ([]*models.SceneMa
	return r0, r1
}
// FindMany provides a mock function with given fields: ids
func (_m *SceneMarkerReaderWriter) FindMany(ids []int) ([]*models.SceneMarker, error) {
ret := _m.Called(ids)
var r0 []*models.SceneMarker
if rf, ok := ret.Get(0).(func([]int) []*models.SceneMarker); ok {
r0 = rf(ids)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.SceneMarker)
}
}
var r1 error
if rf, ok := ret.Get(1).(func([]int) error); ok {
r1 = rf(ids)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetMarkerStrings provides a mock function with given fields: q, sort
func (_m *SceneMarkerReaderWriter) GetMarkerStrings(q *string, sort *string) ([]*models.MarkerStringsResultType, error) {
ret := _m.Called(q, sort)
var r0 []*models.MarkerStringsResultType
if rf, ok := ret.Get(0).(func(*string, *string) []*models.MarkerStringsResultType); ok {
r0 = rf(q, sort)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.MarkerStringsResultType)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(*string, *string) error); ok {
r1 = rf(q, sort)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetTagIDs provides a mock function with given fields: imageID
func (_m *SceneMarkerReaderWriter) GetTagIDs(imageID int) ([]int, error) {
ret := _m.Called(imageID)
var r0 []int
if rf, ok := ret.Get(0).(func(int) []int); ok {
r0 = rf(imageID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]int)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(int) error); ok {
r1 = rf(imageID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Query provides a mock function with given fields: sceneMarkerFilter, findFilter
func (_m *SceneMarkerReaderWriter) Query(sceneMarkerFilter *models.SceneMarkerFilterType, findFilter *models.FindFilterType) ([]*models.SceneMarker, int, error) {
ret := _m.Called(sceneMarkerFilter, findFilter)
var r0 []*models.SceneMarker
if rf, ok := ret.Get(0).(func(*models.SceneMarkerFilterType, *models.FindFilterType) []*models.SceneMarker); ok {
r0 = rf(sceneMarkerFilter, findFilter)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.SceneMarker)
}
}
var r1 int
if rf, ok := ret.Get(1).(func(*models.SceneMarkerFilterType, *models.FindFilterType) int); ok {
r1 = rf(sceneMarkerFilter, findFilter)
} else {
r1 = ret.Get(1).(int)
}
var r2 error
if rf, ok := ret.Get(2).(func(*models.SceneMarkerFilterType, *models.FindFilterType) error); ok {
r2 = rf(sceneMarkerFilter, findFilter)
} else {
r2 = ret.Error(2)
}
return r0, r1, r2
}
// Update provides a mock function with given fields: updatedSceneMarker
func (_m *SceneMarkerReaderWriter) Update(updatedSceneMarker models.SceneMarker) (*models.SceneMarker, error) {
	ret := _m.Called(updatedSceneMarker)
@@ -80,3 +237,40 @@ func (_m *SceneMarkerReaderWriter) Update(updatedSceneMarker models.SceneMarker)
	return r0, r1
}
// UpdateTags provides a mock function with given fields: markerID, tagIDs
func (_m *SceneMarkerReaderWriter) UpdateTags(markerID int, tagIDs []int) error {
ret := _m.Called(markerID, tagIDs)
var r0 error
if rf, ok := ret.Get(0).(func(int, []int) error); ok {
r0 = rf(markerID, tagIDs)
} else {
r0 = ret.Error(0)
}
return r0
}
// Wall provides a mock function with given fields: q
func (_m *SceneMarkerReaderWriter) Wall(q *string) ([]*models.SceneMarker, error) {
ret := _m.Called(q)
var r0 []*models.SceneMarker
if rf, ok := ret.Get(0).(func(*string) []*models.SceneMarker); ok {
r0 = rf(q)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.SceneMarker)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(*string) error); ok {
r1 = rf(q)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
@@ -35,6 +35,153 @@ func (_m *SceneReaderWriter) All() ([]*models.Scene, error) {
	return r0, r1
}
// Count provides a mock function with given fields:
func (_m *SceneReaderWriter) Count() (int, error) {
ret := _m.Called()
var r0 int
if rf, ok := ret.Get(0).(func() int); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(int)
}
var r1 error
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// CountByMovieID provides a mock function with given fields: movieID
func (_m *SceneReaderWriter) CountByMovieID(movieID int) (int, error) {
ret := _m.Called(movieID)
var r0 int
if rf, ok := ret.Get(0).(func(int) int); ok {
r0 = rf(movieID)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
if rf, ok := ret.Get(1).(func(int) error); ok {
r1 = rf(movieID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// CountByPerformerID provides a mock function with given fields: performerID
func (_m *SceneReaderWriter) CountByPerformerID(performerID int) (int, error) {
ret := _m.Called(performerID)
var r0 int
if rf, ok := ret.Get(0).(func(int) int); ok {
r0 = rf(performerID)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
if rf, ok := ret.Get(1).(func(int) error); ok {
r1 = rf(performerID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// CountByStudioID provides a mock function with given fields: studioID
func (_m *SceneReaderWriter) CountByStudioID(studioID int) (int, error) {
ret := _m.Called(studioID)
var r0 int
if rf, ok := ret.Get(0).(func(int) int); ok {
r0 = rf(studioID)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
if rf, ok := ret.Get(1).(func(int) error); ok {
r1 = rf(studioID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// CountByTagID provides a mock function with given fields: tagID
func (_m *SceneReaderWriter) CountByTagID(tagID int) (int, error) {
ret := _m.Called(tagID)
var r0 int
if rf, ok := ret.Get(0).(func(int) int); ok {
r0 = rf(tagID)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
if rf, ok := ret.Get(1).(func(int) error); ok {
r1 = rf(tagID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// CountMissingChecksum provides a mock function with given fields:
func (_m *SceneReaderWriter) CountMissingChecksum() (int, error) {
ret := _m.Called()
var r0 int
if rf, ok := ret.Get(0).(func() int); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(int)
}
var r1 error
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// CountMissingOSHash provides a mock function with given fields:
func (_m *SceneReaderWriter) CountMissingOSHash() (int, error) {
ret := _m.Called()
var r0 int
if rf, ok := ret.Get(0).(func() int); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(int)
}
var r1 error
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Create provides a mock function with given fields: newScene
func (_m *SceneReaderWriter) Create(newScene models.Scene) (*models.Scene, error) {
	ret := _m.Called(newScene)
@@ -58,6 +205,78 @@ func (_m *SceneReaderWriter) Create(newScene models.Scene) (*models.Scene, error
	return r0, r1
}
// DecrementOCounter provides a mock function with given fields: id
func (_m *SceneReaderWriter) DecrementOCounter(id int) (int, error) {
ret := _m.Called(id)
var r0 int
if rf, ok := ret.Get(0).(func(int) int); ok {
r0 = rf(id)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
if rf, ok := ret.Get(1).(func(int) error); ok {
r1 = rf(id)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Destroy provides a mock function with given fields: id
func (_m *SceneReaderWriter) Destroy(id int) error {
ret := _m.Called(id)
var r0 error
if rf, ok := ret.Get(0).(func(int) error); ok {
r0 = rf(id)
} else {
r0 = ret.Error(0)
}
return r0
}
// DestroyCover provides a mock function with given fields: sceneID
func (_m *SceneReaderWriter) DestroyCover(sceneID int) error {
ret := _m.Called(sceneID)
var r0 error
if rf, ok := ret.Get(0).(func(int) error); ok {
r0 = rf(sceneID)
} else {
r0 = ret.Error(0)
}
return r0
}
// Find provides a mock function with given fields: id
func (_m *SceneReaderWriter) Find(id int) (*models.Scene, error) {
ret := _m.Called(id)
var r0 *models.Scene
if rf, ok := ret.Get(0).(func(int) *models.Scene); ok {
r0 = rf(id)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Scene)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(int) error); ok {
r1 = rf(id)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// FindByChecksum provides a mock function with given fields: checksum
func (_m *SceneReaderWriter) FindByChecksum(checksum string) (*models.Scene, error) {
	ret := _m.Called(checksum)
@@ -127,6 +346,52 @@ func (_m *SceneReaderWriter) FindByOSHash(oshash string) (*models.Scene, error)
	return r0, r1
}
// FindByPath provides a mock function with given fields: path
func (_m *SceneReaderWriter) FindByPath(path string) (*models.Scene, error) {
ret := _m.Called(path)
var r0 *models.Scene
if rf, ok := ret.Get(0).(func(string) *models.Scene); ok {
r0 = rf(path)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Scene)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(string) error); ok {
r1 = rf(path)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// FindByPerformerID provides a mock function with given fields: performerID
func (_m *SceneReaderWriter) FindByPerformerID(performerID int) ([]*models.Scene, error) {
ret := _m.Called(performerID)
var r0 []*models.Scene
if rf, ok := ret.Get(0).(func(int) []*models.Scene); ok {
r0 = rf(performerID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Scene)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(int) error); ok {
r1 = rf(performerID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// FindMany provides a mock function with given fields: ids
func (_m *SceneReaderWriter) FindMany(ids []int) ([]*models.Scene, error) {
	ret := _m.Called(ids)
@@ -150,8 +415,8 @@ func (_m *SceneReaderWriter) FindMany(ids []int) ([]*models.Scene, error) {
	return r0, r1
}

// GetCover provides a mock function with given fields: sceneID
func (_m *SceneReaderWriter) GetCover(sceneID int) ([]byte, error) {
	ret := _m.Called(sceneID)
	var r0 []byte
@@ -173,6 +438,244 @@ func (_m *SceneReaderWriter) GetSceneCover(sceneID int) ([]byte, error) {
	return r0, r1
}
// GetMovies provides a mock function with given fields: sceneID
func (_m *SceneReaderWriter) GetMovies(sceneID int) ([]models.MoviesScenes, error) {
ret := _m.Called(sceneID)
var r0 []models.MoviesScenes
if rf, ok := ret.Get(0).(func(int) []models.MoviesScenes); ok {
r0 = rf(sceneID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]models.MoviesScenes)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(int) error); ok {
r1 = rf(sceneID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetPerformerIDs provides a mock function with given fields: imageID
func (_m *SceneReaderWriter) GetPerformerIDs(imageID int) ([]int, error) {
ret := _m.Called(imageID)
var r0 []int
if rf, ok := ret.Get(0).(func(int) []int); ok {
r0 = rf(imageID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]int)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(int) error); ok {
r1 = rf(imageID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetStashIDs provides a mock function with given fields: performerID
func (_m *SceneReaderWriter) GetStashIDs(performerID int) ([]*models.StashID, error) {
ret := _m.Called(performerID)
var r0 []*models.StashID
if rf, ok := ret.Get(0).(func(int) []*models.StashID); ok {
r0 = rf(performerID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.StashID)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(int) error); ok {
r1 = rf(performerID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetTagIDs provides a mock function with given fields: imageID
func (_m *SceneReaderWriter) GetTagIDs(imageID int) ([]int, error) {
ret := _m.Called(imageID)
var r0 []int
if rf, ok := ret.Get(0).(func(int) []int); ok {
r0 = rf(imageID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]int)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(int) error); ok {
r1 = rf(imageID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// IncrementOCounter provides a mock function with given fields: id
func (_m *SceneReaderWriter) IncrementOCounter(id int) (int, error) {
ret := _m.Called(id)
var r0 int
if rf, ok := ret.Get(0).(func(int) int); ok {
r0 = rf(id)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
if rf, ok := ret.Get(1).(func(int) error); ok {
r1 = rf(id)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Query provides a mock function with given fields: sceneFilter, findFilter
func (_m *SceneReaderWriter) Query(sceneFilter *models.SceneFilterType, findFilter *models.FindFilterType) ([]*models.Scene, int, error) {
ret := _m.Called(sceneFilter, findFilter)
var r0 []*models.Scene
if rf, ok := ret.Get(0).(func(*models.SceneFilterType, *models.FindFilterType) []*models.Scene); ok {
r0 = rf(sceneFilter, findFilter)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Scene)
}
}
var r1 int
if rf, ok := ret.Get(1).(func(*models.SceneFilterType, *models.FindFilterType) int); ok {
r1 = rf(sceneFilter, findFilter)
} else {
r1 = ret.Get(1).(int)
}
var r2 error
if rf, ok := ret.Get(2).(func(*models.SceneFilterType, *models.FindFilterType) error); ok {
r2 = rf(sceneFilter, findFilter)
} else {
r2 = ret.Error(2)
}
return r0, r1, r2
}
// QueryAllByPathRegex provides a mock function with given fields: regex, ignoreOrganized
func (_m *SceneReaderWriter) QueryAllByPathRegex(regex string, ignoreOrganized bool) ([]*models.Scene, error) {
ret := _m.Called(regex, ignoreOrganized)
var r0 []*models.Scene
if rf, ok := ret.Get(0).(func(string, bool) []*models.Scene); ok {
r0 = rf(regex, ignoreOrganized)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Scene)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(string, bool) error); ok {
r1 = rf(regex, ignoreOrganized)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// QueryByPathRegex provides a mock function with given fields: findFilter
func (_m *SceneReaderWriter) QueryByPathRegex(findFilter *models.FindFilterType) ([]*models.Scene, int, error) {
ret := _m.Called(findFilter)
var r0 []*models.Scene
if rf, ok := ret.Get(0).(func(*models.FindFilterType) []*models.Scene); ok {
r0 = rf(findFilter)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Scene)
}
}
var r1 int
if rf, ok := ret.Get(1).(func(*models.FindFilterType) int); ok {
r1 = rf(findFilter)
} else {
r1 = ret.Get(1).(int)
}
var r2 error
if rf, ok := ret.Get(2).(func(*models.FindFilterType) error); ok {
r2 = rf(findFilter)
} else {
r2 = ret.Error(2)
}
return r0, r1, r2
}
// ResetOCounter provides a mock function with given fields: id
func (_m *SceneReaderWriter) ResetOCounter(id int) (int, error) {
ret := _m.Called(id)
var r0 int
if rf, ok := ret.Get(0).(func(int) int); ok {
r0 = rf(id)
} else {
r0 = ret.Get(0).(int)
}
var r1 error
if rf, ok := ret.Get(1).(func(int) error); ok {
r1 = rf(id)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Size provides a mock function with given fields:
func (_m *SceneReaderWriter) Size() (float64, error) {
ret := _m.Called()
var r0 float64
if rf, ok := ret.Get(0).(func() float64); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(float64)
}
var r1 error
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Update provides a mock function with given fields: updatedScene // Update provides a mock function with given fields: updatedScene
func (_m *SceneReaderWriter) Update(updatedScene models.ScenePartial) (*models.Scene, error) { func (_m *SceneReaderWriter) Update(updatedScene models.ScenePartial) (*models.Scene, error) {
ret := _m.Called(updatedScene) ret := _m.Called(updatedScene)
@@ -196,6 +699,34 @@ func (_m *SceneReaderWriter) Update(updatedScene models.ScenePartial) (*models.S
return r0, r1 return r0, r1
} }
// UpdateCover provides a mock function with given fields: sceneID, cover
func (_m *SceneReaderWriter) UpdateCover(sceneID int, cover []byte) error {
ret := _m.Called(sceneID, cover)
var r0 error
if rf, ok := ret.Get(0).(func(int, []byte) error); ok {
r0 = rf(sceneID, cover)
} else {
r0 = ret.Error(0)
}
return r0
}
// UpdateFileModTime provides a mock function with given fields: id, modTime
func (_m *SceneReaderWriter) UpdateFileModTime(id int, modTime models.NullSQLiteTimestamp) error {
ret := _m.Called(id, modTime)
var r0 error
if rf, ok := ret.Get(0).(func(int, models.NullSQLiteTimestamp) error); ok {
r0 = rf(id, modTime)
} else {
r0 = ret.Error(0)
}
return r0
}
// UpdateFull provides a mock function with given fields: updatedScene // UpdateFull provides a mock function with given fields: updatedScene
func (_m *SceneReaderWriter) UpdateFull(updatedScene models.Scene) (*models.Scene, error) { func (_m *SceneReaderWriter) UpdateFull(updatedScene models.Scene) (*models.Scene, error) {
ret := _m.Called(updatedScene) ret := _m.Called(updatedScene)
@@ -219,16 +750,81 @@ func (_m *SceneReaderWriter) UpdateFull(updatedScene models.Scene) (*models.Scen
return r0, r1 return r0, r1
} }
// UpdateSceneCover provides a mock function with given fields: sceneID, cover // UpdateMovies provides a mock function with given fields: sceneID, movies
func (_m *SceneReaderWriter) UpdateSceneCover(sceneID int, cover []byte) error { func (_m *SceneReaderWriter) UpdateMovies(sceneID int, movies []models.MoviesScenes) error {
ret := _m.Called(sceneID, cover) ret := _m.Called(sceneID, movies)
var r0 error var r0 error
if rf, ok := ret.Get(0).(func(int, []byte) error); ok { if rf, ok := ret.Get(0).(func(int, []models.MoviesScenes) error); ok {
r0 = rf(sceneID, cover) r0 = rf(sceneID, movies)
} else { } else {
r0 = ret.Error(0) r0 = ret.Error(0)
} }
return r0 return r0
} }
// UpdatePerformers provides a mock function with given fields: sceneID, performerIDs
func (_m *SceneReaderWriter) UpdatePerformers(sceneID int, performerIDs []int) error {
ret := _m.Called(sceneID, performerIDs)
var r0 error
if rf, ok := ret.Get(0).(func(int, []int) error); ok {
r0 = rf(sceneID, performerIDs)
} else {
r0 = ret.Error(0)
}
return r0
}
// UpdateStashIDs provides a mock function with given fields: sceneID, stashIDs
func (_m *SceneReaderWriter) UpdateStashIDs(sceneID int, stashIDs []models.StashID) error {
ret := _m.Called(sceneID, stashIDs)
var r0 error
if rf, ok := ret.Get(0).(func(int, []models.StashID) error); ok {
r0 = rf(sceneID, stashIDs)
} else {
r0 = ret.Error(0)
}
return r0
}
// UpdateTags provides a mock function with given fields: sceneID, tagIDs
func (_m *SceneReaderWriter) UpdateTags(sceneID int, tagIDs []int) error {
ret := _m.Called(sceneID, tagIDs)
var r0 error
if rf, ok := ret.Get(0).(func(int, []int) error); ok {
r0 = rf(sceneID, tagIDs)
} else {
r0 = ret.Error(0)
}
return r0
}
// Wall provides a mock function with given fields: q
func (_m *SceneReaderWriter) Wall(q *string) ([]*models.Scene, error) {
ret := _m.Called(q)
var r0 []*models.Scene
if rf, ok := ret.Get(0).(func(*string) []*models.Scene); ok {
r0 = rf(q)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Scene)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(*string) error); ok {
r1 = rf(q)
} else {
r1 = ret.Error(1)
}
return r0, r1
}

View File

@@ -0,0 +1,59 @@
// Code generated by mockery v0.0.0-dev. DO NOT EDIT.
package mocks
import (
models "github.com/stashapp/stash/pkg/models"
mock "github.com/stretchr/testify/mock"
)
// ScrapedItemReaderWriter is an autogenerated mock type for the ScrapedItemReaderWriter type
type ScrapedItemReaderWriter struct {
mock.Mock
}
// All provides a mock function with given fields:
func (_m *ScrapedItemReaderWriter) All() ([]*models.ScrapedItem, error) {
ret := _m.Called()
var r0 []*models.ScrapedItem
if rf, ok := ret.Get(0).(func() []*models.ScrapedItem); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.ScrapedItem)
}
}
var r1 error
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Create provides a mock function with given fields: newObject
func (_m *ScrapedItemReaderWriter) Create(newObject models.ScrapedItem) (*models.ScrapedItem, error) {
ret := _m.Called(newObject)
var r0 *models.ScrapedItem
if rf, ok := ret.Get(0).(func(models.ScrapedItem) *models.ScrapedItem); ok {
r0 = rf(newObject)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.ScrapedItem)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(models.ScrapedItem) error); ok {
r1 = rf(newObject)
} else {
r1 = ret.Error(1)
}
return r0, r1
}

View File

@@ -35,6 +35,50 @@ func (_m *StudioReaderWriter) All() ([]*models.Studio, error) {
return r0, r1 return r0, r1
} }
// AllSlim provides a mock function with given fields:
func (_m *StudioReaderWriter) AllSlim() ([]*models.Studio, error) {
ret := _m.Called()
var r0 []*models.Studio
if rf, ok := ret.Get(0).(func() []*models.Studio); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Studio)
}
}
var r1 error
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Count provides a mock function with given fields:
func (_m *StudioReaderWriter) Count() (int, error) {
ret := _m.Called()
var r0 int
if rf, ok := ret.Get(0).(func() int); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(int)
}
var r1 error
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Create provides a mock function with given fields: newStudio // Create provides a mock function with given fields: newStudio
func (_m *StudioReaderWriter) Create(newStudio models.Studio) (*models.Studio, error) { func (_m *StudioReaderWriter) Create(newStudio models.Studio) (*models.Studio, error) {
ret := _m.Called(newStudio) ret := _m.Called(newStudio)
@@ -58,6 +102,34 @@ func (_m *StudioReaderWriter) Create(newStudio models.Studio) (*models.Studio, e
return r0, r1 return r0, r1
} }
// Destroy provides a mock function with given fields: id
func (_m *StudioReaderWriter) Destroy(id int) error {
ret := _m.Called(id)
var r0 error
if rf, ok := ret.Get(0).(func(int) error); ok {
r0 = rf(id)
} else {
r0 = ret.Error(0)
}
return r0
}
// DestroyImage provides a mock function with given fields: studioID
func (_m *StudioReaderWriter) DestroyImage(studioID int) error {
ret := _m.Called(studioID)
var r0 error
if rf, ok := ret.Get(0).(func(int) error); ok {
r0 = rf(studioID)
} else {
r0 = ret.Error(0)
}
return r0
}
// Find provides a mock function with given fields: id // Find provides a mock function with given fields: id
func (_m *StudioReaderWriter) Find(id int) (*models.Studio, error) { func (_m *StudioReaderWriter) Find(id int) (*models.Studio, error) {
ret := _m.Called(id) ret := _m.Called(id)
@@ -104,6 +176,29 @@ func (_m *StudioReaderWriter) FindByName(name string, nocase bool) (*models.Stud
return r0, r1 return r0, r1
} }
// FindChildren provides a mock function with given fields: id
func (_m *StudioReaderWriter) FindChildren(id int) ([]*models.Studio, error) {
ret := _m.Called(id)
var r0 []*models.Studio
if rf, ok := ret.Get(0).(func(int) []*models.Studio); ok {
r0 = rf(id)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Studio)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(int) error); ok {
r1 = rf(id)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// FindMany provides a mock function with given fields: ids // FindMany provides a mock function with given fields: ids
func (_m *StudioReaderWriter) FindMany(ids []int) ([]*models.Studio, error) { func (_m *StudioReaderWriter) FindMany(ids []int) ([]*models.Studio, error) {
ret := _m.Called(ids) ret := _m.Called(ids)
@@ -127,8 +222,8 @@ func (_m *StudioReaderWriter) FindMany(ids []int) ([]*models.Studio, error) {
return r0, r1 return r0, r1
} }
// GetStudioImage provides a mock function with given fields: studioID // GetImage provides a mock function with given fields: studioID
func (_m *StudioReaderWriter) GetStudioImage(studioID int) ([]byte, error) { func (_m *StudioReaderWriter) GetImage(studioID int) ([]byte, error) {
ret := _m.Called(studioID) ret := _m.Called(studioID)
var r0 []byte var r0 []byte
@@ -150,6 +245,80 @@ func (_m *StudioReaderWriter) GetStudioImage(studioID int) ([]byte, error) {
return r0, r1 return r0, r1
} }
// GetStashIDs provides a mock function with given fields: studioID
func (_m *StudioReaderWriter) GetStashIDs(studioID int) ([]*models.StashID, error) {
ret := _m.Called(studioID)
var r0 []*models.StashID
if rf, ok := ret.Get(0).(func(int) []*models.StashID); ok {
r0 = rf(studioID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.StashID)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(int) error); ok {
r1 = rf(studioID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// HasImage provides a mock function with given fields: studioID
func (_m *StudioReaderWriter) HasImage(studioID int) (bool, error) {
ret := _m.Called(studioID)
var r0 bool
if rf, ok := ret.Get(0).(func(int) bool); ok {
r0 = rf(studioID)
} else {
r0 = ret.Get(0).(bool)
}
var r1 error
if rf, ok := ret.Get(1).(func(int) error); ok {
r1 = rf(studioID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Query provides a mock function with given fields: studioFilter, findFilter
func (_m *StudioReaderWriter) Query(studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) ([]*models.Studio, int, error) {
ret := _m.Called(studioFilter, findFilter)
var r0 []*models.Studio
if rf, ok := ret.Get(0).(func(*models.StudioFilterType, *models.FindFilterType) []*models.Studio); ok {
r0 = rf(studioFilter, findFilter)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Studio)
}
}
var r1 int
if rf, ok := ret.Get(1).(func(*models.StudioFilterType, *models.FindFilterType) int); ok {
r1 = rf(studioFilter, findFilter)
} else {
r1 = ret.Get(1).(int)
}
var r2 error
if rf, ok := ret.Get(2).(func(*models.StudioFilterType, *models.FindFilterType) error); ok {
r2 = rf(studioFilter, findFilter)
} else {
r2 = ret.Error(2)
}
return r0, r1, r2
}
// Update provides a mock function with given fields: updatedStudio // Update provides a mock function with given fields: updatedStudio
func (_m *StudioReaderWriter) Update(updatedStudio models.StudioPartial) (*models.Studio, error) { func (_m *StudioReaderWriter) Update(updatedStudio models.StudioPartial) (*models.Studio, error) {
ret := _m.Called(updatedStudio) ret := _m.Called(updatedStudio)
@@ -196,8 +365,8 @@ func (_m *StudioReaderWriter) UpdateFull(updatedStudio models.Studio) (*models.S
return r0, r1 return r0, r1
} }
// UpdateStudioImage provides a mock function with given fields: studioID, image // UpdateImage provides a mock function with given fields: studioID, image
func (_m *StudioReaderWriter) UpdateStudioImage(studioID int, image []byte) error { func (_m *StudioReaderWriter) UpdateImage(studioID int, image []byte) error {
ret := _m.Called(studioID, image) ret := _m.Called(studioID, image)
var r0 error var r0 error
@@ -209,3 +378,17 @@ func (_m *StudioReaderWriter) UpdateStudioImage(studioID int, image []byte) erro
return r0 return r0
} }
// UpdateStashIDs provides a mock function with given fields: studioID, stashIDs
func (_m *StudioReaderWriter) UpdateStashIDs(studioID int, stashIDs []models.StashID) error {
ret := _m.Called(studioID, stashIDs)
var r0 error
if rf, ok := ret.Get(0).(func(int, []models.StashID) error); ok {
r0 = rf(studioID, stashIDs)
} else {
r0 = ret.Error(0)
}
return r0
}

View File

@@ -35,6 +35,50 @@ func (_m *TagReaderWriter) All() ([]*models.Tag, error) {
return r0, r1 return r0, r1
} }
// AllSlim provides a mock function with given fields:
func (_m *TagReaderWriter) AllSlim() ([]*models.Tag, error) {
ret := _m.Called()
var r0 []*models.Tag
if rf, ok := ret.Get(0).(func() []*models.Tag); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Tag)
}
}
var r1 error
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Count provides a mock function with given fields:
func (_m *TagReaderWriter) Count() (int, error) {
ret := _m.Called()
var r0 int
if rf, ok := ret.Get(0).(func() int); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(int)
}
var r1 error
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Create provides a mock function with given fields: newTag // Create provides a mock function with given fields: newTag
func (_m *TagReaderWriter) Create(newTag models.Tag) (*models.Tag, error) { func (_m *TagReaderWriter) Create(newTag models.Tag) (*models.Tag, error) {
ret := _m.Called(newTag) ret := _m.Called(newTag)
@@ -58,6 +102,34 @@ func (_m *TagReaderWriter) Create(newTag models.Tag) (*models.Tag, error) {
return r0, r1 return r0, r1
} }
// Destroy provides a mock function with given fields: id
func (_m *TagReaderWriter) Destroy(id int) error {
ret := _m.Called(id)
var r0 error
if rf, ok := ret.Get(0).(func(int) error); ok {
r0 = rf(id)
} else {
r0 = ret.Error(0)
}
return r0
}
// DestroyImage provides a mock function with given fields: tagID
func (_m *TagReaderWriter) DestroyImage(tagID int) error {
ret := _m.Called(tagID)
var r0 error
if rf, ok := ret.Get(0).(func(int) error); ok {
r0 = rf(tagID)
} else {
r0 = ret.Error(0)
}
return r0
}
// Find provides a mock function with given fields: id // Find provides a mock function with given fields: id
func (_m *TagReaderWriter) Find(id int) (*models.Tag, error) { func (_m *TagReaderWriter) Find(id int) (*models.Tag, error) {
ret := _m.Called(id) ret := _m.Called(id)
@@ -242,8 +314,8 @@ func (_m *TagReaderWriter) FindMany(ids []int) ([]*models.Tag, error) {
return r0, r1 return r0, r1
} }
// GetTagImage provides a mock function with given fields: tagID // GetImage provides a mock function with given fields: tagID
func (_m *TagReaderWriter) GetTagImage(tagID int) ([]byte, error) { func (_m *TagReaderWriter) GetImage(tagID int) ([]byte, error) {
ret := _m.Called(tagID) ret := _m.Called(tagID)
var r0 []byte var r0 []byte
@@ -265,6 +337,36 @@ func (_m *TagReaderWriter) GetTagImage(tagID int) ([]byte, error) {
return r0, r1 return r0, r1
} }
// Query provides a mock function with given fields: tagFilter, findFilter
func (_m *TagReaderWriter) Query(tagFilter *models.TagFilterType, findFilter *models.FindFilterType) ([]*models.Tag, int, error) {
ret := _m.Called(tagFilter, findFilter)
var r0 []*models.Tag
if rf, ok := ret.Get(0).(func(*models.TagFilterType, *models.FindFilterType) []*models.Tag); ok {
r0 = rf(tagFilter, findFilter)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Tag)
}
}
var r1 int
if rf, ok := ret.Get(1).(func(*models.TagFilterType, *models.FindFilterType) int); ok {
r1 = rf(tagFilter, findFilter)
} else {
r1 = ret.Get(1).(int)
}
var r2 error
if rf, ok := ret.Get(2).(func(*models.TagFilterType, *models.FindFilterType) error); ok {
r2 = rf(tagFilter, findFilter)
} else {
r2 = ret.Error(2)
}
return r0, r1, r2
}
// Update provides a mock function with given fields: updatedTag // Update provides a mock function with given fields: updatedTag
func (_m *TagReaderWriter) Update(updatedTag models.Tag) (*models.Tag, error) { func (_m *TagReaderWriter) Update(updatedTag models.Tag) (*models.Tag, error) {
ret := _m.Called(updatedTag) ret := _m.Called(updatedTag)
@@ -288,8 +390,8 @@ func (_m *TagReaderWriter) Update(updatedTag models.Tag) (*models.Tag, error) {
return r0, r1 return r0, r1
} }
// UpdateTagImage provides a mock function with given fields: tagID, image // UpdateImage provides a mock function with given fields: tagID, image
func (_m *TagReaderWriter) UpdateTagImage(tagID int, image []byte) error { func (_m *TagReaderWriter) UpdateImage(tagID int, image []byte) error {
ret := _m.Called(tagID, image) ret := _m.Called(tagID, image)
var r0 error var r0 error

View File

@@ -0,0 +1,117 @@
package mocks
import (
"context"
models "github.com/stashapp/stash/pkg/models"
)
type TransactionManager struct {
gallery models.GalleryReaderWriter
image models.ImageReaderWriter
movie models.MovieReaderWriter
performer models.PerformerReaderWriter
scene models.SceneReaderWriter
sceneMarker models.SceneMarkerReaderWriter
scrapedItem models.ScrapedItemReaderWriter
studio models.StudioReaderWriter
tag models.TagReaderWriter
}
func NewTransactionManager() *TransactionManager {
return &TransactionManager{
gallery: &GalleryReaderWriter{},
image: &ImageReaderWriter{},
movie: &MovieReaderWriter{},
performer: &PerformerReaderWriter{},
scene: &SceneReaderWriter{},
sceneMarker: &SceneMarkerReaderWriter{},
scrapedItem: &ScrapedItemReaderWriter{},
studio: &StudioReaderWriter{},
tag: &TagReaderWriter{},
}
}
func (t *TransactionManager) WithTxn(ctx context.Context, fn func(r models.Repository) error) error {
return fn(t)
}
func (t *TransactionManager) Gallery() models.GalleryReaderWriter {
return t.gallery
}
func (t *TransactionManager) Image() models.ImageReaderWriter {
return t.image
}
func (t *TransactionManager) Movie() models.MovieReaderWriter {
return t.movie
}
func (t *TransactionManager) Performer() models.PerformerReaderWriter {
return t.performer
}
func (t *TransactionManager) SceneMarker() models.SceneMarkerReaderWriter {
return t.sceneMarker
}
func (t *TransactionManager) Scene() models.SceneReaderWriter {
return t.scene
}
func (t *TransactionManager) ScrapedItem() models.ScrapedItemReaderWriter {
return t.scrapedItem
}
func (t *TransactionManager) Studio() models.StudioReaderWriter {
return t.studio
}
func (t *TransactionManager) Tag() models.TagReaderWriter {
return t.tag
}
type ReadTransaction struct {
t *TransactionManager
}
func (t *TransactionManager) WithReadTxn(ctx context.Context, fn func(r models.ReaderRepository) error) error {
return fn(&ReadTransaction{t: t})
}
func (r *ReadTransaction) Gallery() models.GalleryReader {
return r.t.gallery
}
func (r *ReadTransaction) Image() models.ImageReader {
return r.t.image
}
func (r *ReadTransaction) Movie() models.MovieReader {
return r.t.movie
}
func (r *ReadTransaction) Performer() models.PerformerReader {
return r.t.performer
}
func (r *ReadTransaction) SceneMarker() models.SceneMarkerReader {
return r.t.sceneMarker
}
func (r *ReadTransaction) Scene() models.SceneReader {
return r.t.scene
}
func (r *ReadTransaction) ScrapedItem() models.ScrapedItemReader {
return r.t.scrapedItem
}
func (r *ReadTransaction) Studio() models.StudioReader {
return r.t.studio
}
func (r *ReadTransaction) Tag() models.TagReader {
return r.t.tag
}
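
The mock TransactionManager hands back the same mock reader/writers for both read and write transactions, so code written against the transaction system can be exercised without a database. A hedged test sketch, again assuming the pkg/models/mocks import path; the expectation values are placeholders:

package mocks_test

import (
	"context"
	"testing"

	"github.com/stashapp/stash/pkg/models"
	"github.com/stashapp/stash/pkg/models/mocks"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
)

func TestSceneWallWithMockTxn(t *testing.T) {
	txnMgr := mocks.NewTransactionManager()

	// Scene() always returns the one *mocks.SceneReaderWriter instance,
	// so expectations set here apply inside WithReadTxn as well.
	sceneMock := txnMgr.Scene().(*mocks.SceneReaderWriter)
	sceneMock.On("Wall", mock.Anything).Return([]*models.Scene{{}}, nil)

	err := txnMgr.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
		scenes, err := r.Scene().Wall(nil)
		assert.Len(t, scenes, 1)
		return err
	})

	assert.NoError(t, err)
	sceneMock.AssertExpectations(t)
}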

View File

@@ -42,3 +42,13 @@ type GalleryPartial struct {
} }
const DefaultGthumbWidth int = 640 const DefaultGthumbWidth int = 640
type Galleries []*Gallery
func (g *Galleries) Append(o interface{}) {
*g = append(*g, o.(*Gallery))
}
func (g *Galleries) New() interface{} {
return &Gallery{}
}
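
The Append/New pair added here (and repeated for the other model slices below) gives the sqlite query code a type-agnostic way to materialise rows. The interface and helper below are illustrative assumptions about how such a scanner could look, not the actual implementation in the sqlite package:

package sqlite

import "github.com/jmoiron/sqlx"

// objectList is a hypothetical name for the contract that Galleries, Scenes,
// Tags and the other model slices satisfy via Append/New.
type objectList interface {
	Append(o interface{})
	New() interface{}
}

// scanAll shows how a query runner could fill any model slice: New allocates
// a fresh element, StructScan populates it, Append stores it.
func scanAll(rows *sqlx.Rows, out objectList) error {
	defer rows.Close()
	for rows.Next() {
		object := out.New()
		if err := rows.StructScan(object); err != nil {
			return err
		}
		out.Append(object)
	}
	return rows.Err()
}

Called as, for example, var galleries models.Galleries; err := scanAll(rows, &galleries) — the pointer receivers are what let the slice grow in place.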

View File

@@ -46,3 +46,13 @@ type ImageFileType struct {
Width *int `graphql:"width" json:"width"` Width *int `graphql:"width" json:"width"`
Height *int `graphql:"height" json:"height"` Height *int `graphql:"height" json:"height"`
} }
type Images []*Image
func (i *Images) Append(o interface{}) {
*i = append(*i, o.(*Image))
}
func (i *Images) New() interface{} {
return &Image{}
}

View File

@@ -2,53 +2,13 @@ package models
import "database/sql" import "database/sql"
type PerformersScenes struct {
PerformerID int `db:"performer_id" json:"performer_id"`
SceneID int `db:"scene_id" json:"scene_id"`
}
type MoviesScenes struct { type MoviesScenes struct {
MovieID int `db:"movie_id" json:"movie_id"` MovieID int `db:"movie_id" json:"movie_id"`
SceneID int `db:"scene_id" json:"scene_id"` SceneID int `db:"scene_id" json:"scene_id"`
SceneIndex sql.NullInt64 `db:"scene_index" json:"scene_index"` SceneIndex sql.NullInt64 `db:"scene_index" json:"scene_index"`
} }
type ScenesTags struct {
SceneID int `db:"scene_id" json:"scene_id"`
TagID int `db:"tag_id" json:"tag_id"`
}
type SceneMarkersTags struct {
SceneMarkerID int `db:"scene_marker_id" json:"scene_marker_id"`
TagID int `db:"tag_id" json:"tag_id"`
}
type StashID struct { type StashID struct {
StashID string `db:"stash_id" json:"stash_id"` StashID string `db:"stash_id" json:"stash_id"`
Endpoint string `db:"endpoint" json:"endpoint"` Endpoint string `db:"endpoint" json:"endpoint"`
} }
type PerformersImages struct {
PerformerID int `db:"performer_id" json:"performer_id"`
ImageID int `db:"image_id" json:"image_id"`
}
type ImagesTags struct {
ImageID int `db:"image_id" json:"image_id"`
TagID int `db:"tag_id" json:"tag_id"`
}
type GalleriesImages struct {
GalleryID int `db:"gallery_id" json:"gallery_id"`
ImageID int `db:"image_id" json:"image_id"`
}
type PerformersGalleries struct {
PerformerID int `db:"performer_id" json:"performer_id"`
GalleryID int `db:"gallery_id" json:"gallery_id"`
}
type GalleriesTags struct {
TagID int `db:"tag_id" json:"tag_id"`
GalleryID int `db:"gallery_id" json:"gallery_id"`
}

View File

@@ -50,3 +50,13 @@ func NewMovie(name string) *Movie {
UpdatedAt: SQLiteTimestamp{Timestamp: currentTime}, UpdatedAt: SQLiteTimestamp{Timestamp: currentTime},
} }
} }
type Movies []*Movie
func (m *Movies) Append(o interface{}) {
*m = append(*m, o.(*Movie))
}
func (m *Movies) New() interface{} {
return &Movie{}
}

View File

@@ -65,3 +65,13 @@ func NewPerformer(name string) *Performer {
UpdatedAt: SQLiteTimestamp{Timestamp: currentTime}, UpdatedAt: SQLiteTimestamp{Timestamp: currentTime},
} }
} }
type Performers []*Performer
func (p *Performers) Append(o interface{}) {
*p = append(*p, o.(*Performer))
}
func (p *Performers) New() interface{} {
return &Performer{}
}

View File

@@ -95,3 +95,13 @@ type SceneFileType struct {
Framerate *float64 `graphql:"framerate" json:"framerate"` Framerate *float64 `graphql:"framerate" json:"framerate"`
Bitrate *int `graphql:"bitrate" json:"bitrate"` Bitrate *int `graphql:"bitrate" json:"bitrate"`
} }
type Scenes []*Scene
func (s *Scenes) Append(o interface{}) {
*s = append(*s, o.(*Scene))
}
func (s *Scenes) New() interface{} {
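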
return &Scene{}
}

View File

@@ -13,3 +13,13 @@ type SceneMarker struct {
CreatedAt SQLiteTimestamp `db:"created_at" json:"created_at"` CreatedAt SQLiteTimestamp `db:"created_at" json:"created_at"`
UpdatedAt SQLiteTimestamp `db:"updated_at" json:"updated_at"` UpdatedAt SQLiteTimestamp `db:"updated_at" json:"updated_at"`
} }
type SceneMarkers []*SceneMarker
func (m *SceneMarkers) Append(o interface{}) {
*m = append(*m, o.(*SceneMarker))
}
func (m *SceneMarkers) New() interface{} {
return &SceneMarker{}
}

View File

@@ -174,3 +174,13 @@ type ScrapedMovieStudio struct {
Name string `graphql:"name" json:"name"` Name string `graphql:"name" json:"name"`
URL *string `graphql:"url" json:"url"` URL *string `graphql:"url" json:"url"`
} }
type ScrapedItems []*ScrapedItem
func (s *ScrapedItems) Append(o interface{}) {
*s = append(*s, o.(*ScrapedItem))
}
func (s *ScrapedItems) New() interface{} {
return &ScrapedItem{}
}

View File

@@ -38,3 +38,13 @@ func NewStudio(name string) *Studio {
UpdatedAt: SQLiteTimestamp{Timestamp: currentTime}, UpdatedAt: SQLiteTimestamp{Timestamp: currentTime},
} }
} }
type Studios []*Studio
func (s *Studios) Append(o interface{}) {
*s = append(*s, o.(*Studio))
}
func (s *Studios) New() interface{} {
return &Studio{}
}

View File

@@ -18,6 +18,16 @@ func NewTag(name string) *Tag {
} }
} }
type Tags []*Tag
func (t *Tags) Append(o interface{}) {
*t = append(*t, o.(*Tag))
}
func (t *Tags) New() interface{} {
return &Tag{}
}
// Original Tag image from: https://fontawesome.com/icons/tag?style=solid // Original Tag image from: https://fontawesome.com/icons/tag?style=solid
// Modified to change color and rotate // Modified to change color and rotate
// Licensed under CC Attribution 4.0: https://fontawesome.com/license // Licensed under CC Attribution 4.0: https://fontawesome.com/license

View File

@@ -1,9 +1,5 @@
package models package models
import (
"github.com/jmoiron/sqlx"
)
type MovieReader interface { type MovieReader interface {
Find(id int) (*Movie, error) Find(id int) (*Movie, error)
FindMany(ids []int) ([]*Movie, error) FindMany(ids []int) ([]*Movie, error)
@@ -11,8 +7,9 @@ type MovieReader interface {
FindByName(name string, nocase bool) (*Movie, error) FindByName(name string, nocase bool) (*Movie, error)
FindByNames(names []string, nocase bool) ([]*Movie, error) FindByNames(names []string, nocase bool) ([]*Movie, error)
All() ([]*Movie, error) All() ([]*Movie, error)
// AllSlim() ([]*Movie, error) Count() (int, error)
// Query(movieFilter *MovieFilterType, findFilter *FindFilterType) ([]*Movie, int) AllSlim() ([]*Movie, error)
Query(movieFilter *MovieFilterType, findFilter *FindFilterType) ([]*Movie, int, error)
GetFrontImage(movieID int) ([]byte, error) GetFrontImage(movieID int) ([]byte, error)
GetBackImage(movieID int) ([]byte, error) GetBackImage(movieID int) ([]byte, error)
} }
@@ -21,68 +18,12 @@ type MovieWriter interface {
Create(newMovie Movie) (*Movie, error) Create(newMovie Movie) (*Movie, error)
Update(updatedMovie MoviePartial) (*Movie, error) Update(updatedMovie MoviePartial) (*Movie, error)
UpdateFull(updatedMovie Movie) (*Movie, error) UpdateFull(updatedMovie Movie) (*Movie, error)
// Destroy(id string) error Destroy(id int) error
UpdateMovieImages(movieID int, frontImage []byte, backImage []byte) error UpdateImages(movieID int, frontImage []byte, backImage []byte) error
// DestroyMovieImages(movieID int) error DestroyImages(movieID int) error
} }
type MovieReaderWriter interface { type MovieReaderWriter interface {
MovieReader MovieReader
MovieWriter MovieWriter
} }
func NewMovieReaderWriter(tx *sqlx.Tx) MovieReaderWriter {
return &movieReaderWriter{
tx: tx,
qb: NewMovieQueryBuilder(),
}
}
type movieReaderWriter struct {
tx *sqlx.Tx
qb MovieQueryBuilder
}
func (t *movieReaderWriter) Find(id int) (*Movie, error) {
return t.qb.Find(id, t.tx)
}
func (t *movieReaderWriter) FindMany(ids []int) ([]*Movie, error) {
return t.qb.FindMany(ids)
}
func (t *movieReaderWriter) FindByName(name string, nocase bool) (*Movie, error) {
return t.qb.FindByName(name, t.tx, nocase)
}
func (t *movieReaderWriter) FindByNames(names []string, nocase bool) ([]*Movie, error) {
return t.qb.FindByNames(names, t.tx, nocase)
}
func (t *movieReaderWriter) All() ([]*Movie, error) {
return t.qb.All()
}
func (t *movieReaderWriter) GetFrontImage(movieID int) ([]byte, error) {
return t.qb.GetFrontImage(movieID, t.tx)
}
func (t *movieReaderWriter) GetBackImage(movieID int) ([]byte, error) {
return t.qb.GetBackImage(movieID, t.tx)
}
func (t *movieReaderWriter) Create(newMovie Movie) (*Movie, error) {
return t.qb.Create(newMovie, t.tx)
}
func (t *movieReaderWriter) Update(updatedMovie MoviePartial) (*Movie, error) {
return t.qb.Update(updatedMovie, t.tx)
}
func (t *movieReaderWriter) UpdateFull(updatedMovie Movie) (*Movie, error) {
return t.qb.UpdateFull(updatedMovie, t.tx)
}
func (t *movieReaderWriter) UpdateMovieImages(movieID int, frontImage []byte, backImage []byte) error {
return t.qb.UpdateMovieImages(movieID, frontImage, backImage, t.tx)
}

View File

@@ -1,94 +1,32 @@
package models package models
import (
"github.com/jmoiron/sqlx"
)
type PerformerReader interface { type PerformerReader interface {
// Find(id int) (*Performer, error) Find(id int) (*Performer, error)
FindMany(ids []int) ([]*Performer, error) FindMany(ids []int) ([]*Performer, error)
FindBySceneID(sceneID int) ([]*Performer, error) FindBySceneID(sceneID int) ([]*Performer, error)
FindNamesBySceneID(sceneID int) ([]*Performer, error) FindNamesBySceneID(sceneID int) ([]*Performer, error)
FindByImageID(imageID int) ([]*Performer, error) FindByImageID(imageID int) ([]*Performer, error)
FindByGalleryID(galleryID int) ([]*Performer, error) FindByGalleryID(galleryID int) ([]*Performer, error)
FindByNames(names []string, nocase bool) ([]*Performer, error) FindByNames(names []string, nocase bool) ([]*Performer, error)
// Count() (int, error) Count() (int, error)
All() ([]*Performer, error) All() ([]*Performer, error)
// AllSlim() ([]*Performer, error) AllSlim() ([]*Performer, error)
// Query(performerFilter *PerformerFilterType, findFilter *FindFilterType) ([]*Performer, int) Query(performerFilter *PerformerFilterType, findFilter *FindFilterType) ([]*Performer, int, error)
GetPerformerImage(performerID int) ([]byte, error) GetImage(performerID int) ([]byte, error)
GetStashIDs(performerID int) ([]*StashID, error)
} }
type PerformerWriter interface { type PerformerWriter interface {
Create(newPerformer Performer) (*Performer, error) Create(newPerformer Performer) (*Performer, error)
Update(updatedPerformer PerformerPartial) (*Performer, error) Update(updatedPerformer PerformerPartial) (*Performer, error)
UpdateFull(updatedPerformer Performer) (*Performer, error) UpdateFull(updatedPerformer Performer) (*Performer, error)
// Destroy(id string) error Destroy(id int) error
UpdatePerformerImage(performerID int, image []byte) error UpdateImage(performerID int, image []byte) error
// DestroyPerformerImage(performerID int) error DestroyImage(performerID int) error
UpdateStashIDs(performerID int, stashIDs []StashID) error
} }
type PerformerReaderWriter interface { type PerformerReaderWriter interface {
PerformerReader PerformerReader
PerformerWriter PerformerWriter
} }
func NewPerformerReaderWriter(tx *sqlx.Tx) PerformerReaderWriter {
return &performerReaderWriter{
tx: tx,
qb: NewPerformerQueryBuilder(),
}
}
type performerReaderWriter struct {
tx *sqlx.Tx
qb PerformerQueryBuilder
}
func (t *performerReaderWriter) FindMany(ids []int) ([]*Performer, error) {
return t.qb.FindMany(ids)
}
func (t *performerReaderWriter) FindByNames(names []string, nocase bool) ([]*Performer, error) {
return t.qb.FindByNames(names, t.tx, nocase)
}
func (t *performerReaderWriter) All() ([]*Performer, error) {
return t.qb.All()
}
func (t *performerReaderWriter) GetPerformerImage(performerID int) ([]byte, error) {
return t.qb.GetPerformerImage(performerID, t.tx)
}
func (t *performerReaderWriter) FindBySceneID(id int) ([]*Performer, error) {
return t.qb.FindBySceneID(id, t.tx)
}
func (t *performerReaderWriter) FindNamesBySceneID(sceneID int) ([]*Performer, error) {
return t.qb.FindNameBySceneID(sceneID, t.tx)
}
func (t *performerReaderWriter) FindByImageID(id int) ([]*Performer, error) {
return t.qb.FindByImageID(id, t.tx)
}
func (t *performerReaderWriter) FindByGalleryID(id int) ([]*Performer, error) {
return t.qb.FindByGalleryID(id, t.tx)
}
func (t *performerReaderWriter) Create(newPerformer Performer) (*Performer, error) {
return t.qb.Create(newPerformer, t.tx)
}
func (t *performerReaderWriter) Update(updatedPerformer PerformerPartial) (*Performer, error) {
return t.qb.Update(updatedPerformer, t.tx)
}
func (t *performerReaderWriter) UpdateFull(updatedPerformer Performer) (*Performer, error) {
return t.qb.UpdateFull(updatedPerformer, t.tx)
}
func (t *performerReaderWriter) UpdatePerformerImage(performerID int, image []byte) error {
return t.qb.UpdatePerformerImage(performerID, image, t.tx)
}

View File

@@ -1,225 +0,0 @@
// +build integration
package models_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stashapp/stash/pkg/models"
)
func TestGalleryFind(t *testing.T) {
gqb := models.NewGalleryQueryBuilder()
const galleryIdx = 0
gallery, err := gqb.Find(galleryIDs[galleryIdx], nil)
if err != nil {
t.Fatalf("Error finding gallery: %s", err.Error())
}
assert.Equal(t, getGalleryStringValue(galleryIdx, "Path"), gallery.Path.String)
gallery, err = gqb.Find(0, nil)
if err != nil {
t.Fatalf("Error finding gallery: %s", err.Error())
}
assert.Nil(t, gallery)
}
func TestGalleryFindByChecksum(t *testing.T) {
gqb := models.NewGalleryQueryBuilder()
const galleryIdx = 0
galleryChecksum := getGalleryStringValue(galleryIdx, "Checksum")
gallery, err := gqb.FindByChecksum(galleryChecksum, nil)
if err != nil {
t.Fatalf("Error finding gallery: %s", err.Error())
}
assert.Equal(t, getGalleryStringValue(galleryIdx, "Path"), gallery.Path.String)
galleryChecksum = "not exist"
gallery, err = gqb.FindByChecksum(galleryChecksum, nil)
if err != nil {
t.Fatalf("Error finding gallery: %s", err.Error())
}
assert.Nil(t, gallery)
}
func TestGalleryFindByPath(t *testing.T) {
gqb := models.NewGalleryQueryBuilder()
const galleryIdx = 0
galleryPath := getGalleryStringValue(galleryIdx, "Path")
gallery, err := gqb.FindByPath(galleryPath)
if err != nil {
t.Fatalf("Error finding gallery: %s", err.Error())
}
assert.Equal(t, galleryPath, gallery.Path.String)
galleryPath = "not exist"
gallery, err = gqb.FindByPath(galleryPath)
if err != nil {
t.Fatalf("Error finding gallery: %s", err.Error())
}
assert.Nil(t, gallery)
}
func TestGalleryFindBySceneID(t *testing.T) {
gqb := models.NewGalleryQueryBuilder()
sceneID := sceneIDs[sceneIdxWithGallery]
gallery, err := gqb.FindBySceneID(sceneID, nil)
if err != nil {
t.Fatalf("Error finding gallery: %s", err.Error())
}
assert.Equal(t, getGalleryStringValue(galleryIdxWithScene, "Path"), gallery.Path.String)
gallery, err = gqb.FindBySceneID(0, nil)
if err != nil {
t.Fatalf("Error finding gallery: %s", err.Error())
}
assert.Nil(t, gallery)
}
func TestGalleryQueryQ(t *testing.T) {
const galleryIdx = 0
q := getGalleryStringValue(galleryIdx, pathField)
sqb := models.NewGalleryQueryBuilder()
galleryQueryQ(t, sqb, q, galleryIdx)
}
func galleryQueryQ(t *testing.T, qb models.GalleryQueryBuilder, q string, expectedGalleryIdx int) {
filter := models.FindFilterType{
Q: &q,
}
galleries, _ := qb.Query(nil, &filter)
assert.Len(t, galleries, 1)
gallery := galleries[0]
assert.Equal(t, galleryIDs[expectedGalleryIdx], gallery.ID)
// no Q should return all results
filter.Q = nil
galleries, _ = qb.Query(nil, &filter)
assert.Len(t, galleries, totalGalleries)
}
func TestGalleryQueryPath(t *testing.T) {
const galleryIdx = 1
galleryPath := getGalleryStringValue(galleryIdx, "Path")
pathCriterion := models.StringCriterionInput{
Value: galleryPath,
Modifier: models.CriterionModifierEquals,
}
verifyGalleriesPath(t, pathCriterion)
pathCriterion.Modifier = models.CriterionModifierNotEquals
verifyGalleriesPath(t, pathCriterion)
}
func verifyGalleriesPath(t *testing.T, pathCriterion models.StringCriterionInput) {
sqb := models.NewGalleryQueryBuilder()
galleryFilter := models.GalleryFilterType{
Path: &pathCriterion,
}
galleries, _ := sqb.Query(&galleryFilter, nil)
for _, gallery := range galleries {
verifyNullString(t, gallery.Path, pathCriterion)
}
}
func TestGalleryQueryRating(t *testing.T) {
const rating = 3
ratingCriterion := models.IntCriterionInput{
Value: rating,
Modifier: models.CriterionModifierEquals,
}
verifyGalleriesRating(t, ratingCriterion)
ratingCriterion.Modifier = models.CriterionModifierNotEquals
verifyGalleriesRating(t, ratingCriterion)
ratingCriterion.Modifier = models.CriterionModifierGreaterThan
verifyGalleriesRating(t, ratingCriterion)
ratingCriterion.Modifier = models.CriterionModifierLessThan
verifyGalleriesRating(t, ratingCriterion)
ratingCriterion.Modifier = models.CriterionModifierIsNull
verifyGalleriesRating(t, ratingCriterion)
ratingCriterion.Modifier = models.CriterionModifierNotNull
verifyGalleriesRating(t, ratingCriterion)
}
func verifyGalleriesRating(t *testing.T, ratingCriterion models.IntCriterionInput) {
sqb := models.NewGalleryQueryBuilder()
galleryFilter := models.GalleryFilterType{
Rating: &ratingCriterion,
}
galleries, _ := sqb.Query(&galleryFilter, nil)
for _, gallery := range galleries {
verifyInt64(t, gallery.Rating, ratingCriterion)
}
}
func TestGalleryQueryIsMissingScene(t *testing.T) {
qb := models.NewGalleryQueryBuilder()
isMissing := "scene"
galleryFilter := models.GalleryFilterType{
IsMissing: &isMissing,
}
q := getGalleryStringValue(galleryIdxWithScene, titleField)
findFilter := models.FindFilterType{
Q: &q,
}
galleries, _ := qb.Query(&galleryFilter, &findFilter)
assert.Len(t, galleries, 0)
findFilter.Q = nil
galleries, _ = qb.Query(&galleryFilter, &findFilter)
// ensure none of the ids equal the one with gallery
for _, gallery := range galleries {
assert.NotEqual(t, galleryIDs[galleryIdxWithScene], gallery.ID)
}
}
// TODO ValidGalleriesForScenePath
// TODO Count
// TODO All
// TODO Query
// TODO Update
// TODO Destroy
// TODO ClearGalleryId

View File

@@ -1,652 +0,0 @@
// +build integration
package models_test
import (
"database/sql"
"strconv"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stashapp/stash/pkg/models"
)
func TestImageFind(t *testing.T) {
// assume that the first image is imageWithGalleryPath
sqb := models.NewImageQueryBuilder()
const imageIdx = 0
imageID := imageIDs[imageIdx]
image, err := sqb.Find(imageID)
if err != nil {
t.Fatalf("Error finding image: %s", err.Error())
}
assert.Equal(t, getImageStringValue(imageIdx, "Path"), image.Path)
imageID = 0
image, err = sqb.Find(imageID)
if err != nil {
t.Fatalf("Error finding image: %s", err.Error())
}
assert.Nil(t, image)
}
func TestImageFindByPath(t *testing.T) {
sqb := models.NewImageQueryBuilder()
const imageIdx = 1
imagePath := getImageStringValue(imageIdx, "Path")
image, err := sqb.FindByPath(imagePath)
if err != nil {
t.Fatalf("Error finding image: %s", err.Error())
}
assert.Equal(t, imageIDs[imageIdx], image.ID)
assert.Equal(t, imagePath, image.Path)
imagePath = "not exist"
image, err = sqb.FindByPath(imagePath)
if err != nil {
t.Fatalf("Error finding image: %s", err.Error())
}
assert.Nil(t, image)
}
func TestImageCountByPerformerID(t *testing.T) {
sqb := models.NewImageQueryBuilder()
count, err := sqb.CountByPerformerID(performerIDs[performerIdxWithImage])
if err != nil {
t.Fatalf("Error counting images: %s", err.Error())
}
assert.Equal(t, 1, count)
count, err = sqb.CountByPerformerID(0)
if err != nil {
t.Fatalf("Error counting images: %s", err.Error())
}
assert.Equal(t, 0, count)
}
func TestImageQueryQ(t *testing.T) {
const imageIdx = 2
q := getImageStringValue(imageIdx, titleField)
sqb := models.NewImageQueryBuilder()
imageQueryQ(t, sqb, q, imageIdx)
}
func imageQueryQ(t *testing.T, sqb models.ImageQueryBuilder, q string, expectedImageIdx int) {
filter := models.FindFilterType{
Q: &q,
}
images, _ := sqb.Query(nil, &filter)
assert.Len(t, images, 1)
image := images[0]
assert.Equal(t, imageIDs[expectedImageIdx], image.ID)
// no Q should return all results
filter.Q = nil
images, _ = sqb.Query(nil, &filter)
assert.Len(t, images, totalImages)
}
func TestImageQueryPath(t *testing.T) {
const imageIdx = 1
imagePath := getImageStringValue(imageIdx, "Path")
pathCriterion := models.StringCriterionInput{
Value: imagePath,
Modifier: models.CriterionModifierEquals,
}
verifyImagePath(t, pathCriterion)
pathCriterion.Modifier = models.CriterionModifierNotEquals
verifyImagePath(t, pathCriterion)
}
func verifyImagePath(t *testing.T, pathCriterion models.StringCriterionInput) {
sqb := models.NewImageQueryBuilder()
imageFilter := models.ImageFilterType{
Path: &pathCriterion,
}
images, _ := sqb.Query(&imageFilter, nil)
for _, image := range images {
verifyString(t, image.Path, pathCriterion)
}
}
func TestImageQueryRating(t *testing.T) {
const rating = 3
ratingCriterion := models.IntCriterionInput{
Value: rating,
Modifier: models.CriterionModifierEquals,
}
verifyImagesRating(t, ratingCriterion)
ratingCriterion.Modifier = models.CriterionModifierNotEquals
verifyImagesRating(t, ratingCriterion)
ratingCriterion.Modifier = models.CriterionModifierGreaterThan
verifyImagesRating(t, ratingCriterion)
ratingCriterion.Modifier = models.CriterionModifierLessThan
verifyImagesRating(t, ratingCriterion)
ratingCriterion.Modifier = models.CriterionModifierIsNull
verifyImagesRating(t, ratingCriterion)
ratingCriterion.Modifier = models.CriterionModifierNotNull
verifyImagesRating(t, ratingCriterion)
}
func verifyImagesRating(t *testing.T, ratingCriterion models.IntCriterionInput) {
sqb := models.NewImageQueryBuilder()
imageFilter := models.ImageFilterType{
Rating: &ratingCriterion,
}
images, _ := sqb.Query(&imageFilter, nil)
for _, image := range images {
verifyInt64(t, image.Rating, ratingCriterion)
}
}
func TestImageQueryOCounter(t *testing.T) {
const oCounter = 1
oCounterCriterion := models.IntCriterionInput{
Value: oCounter,
Modifier: models.CriterionModifierEquals,
}
verifyImagesOCounter(t, oCounterCriterion)
oCounterCriterion.Modifier = models.CriterionModifierNotEquals
verifyImagesOCounter(t, oCounterCriterion)
oCounterCriterion.Modifier = models.CriterionModifierGreaterThan
verifyImagesOCounter(t, oCounterCriterion)
oCounterCriterion.Modifier = models.CriterionModifierLessThan
verifyImagesOCounter(t, oCounterCriterion)
}
func verifyImagesOCounter(t *testing.T, oCounterCriterion models.IntCriterionInput) {
sqb := models.NewImageQueryBuilder()
imageFilter := models.ImageFilterType{
OCounter: &oCounterCriterion,
}
images, _ := sqb.Query(&imageFilter, nil)
for _, image := range images {
verifyInt(t, image.OCounter, oCounterCriterion)
}
}
func TestImageQueryResolution(t *testing.T) {
verifyImagesResolution(t, models.ResolutionEnumLow)
verifyImagesResolution(t, models.ResolutionEnumStandard)
verifyImagesResolution(t, models.ResolutionEnumStandardHd)
verifyImagesResolution(t, models.ResolutionEnumFullHd)
verifyImagesResolution(t, models.ResolutionEnumFourK)
verifyImagesResolution(t, models.ResolutionEnum("unknown"))
}
func verifyImagesResolution(t *testing.T, resolution models.ResolutionEnum) {
sqb := models.NewImageQueryBuilder()
imageFilter := models.ImageFilterType{
Resolution: &resolution,
}
images, _ := sqb.Query(&imageFilter, nil)
for _, image := range images {
verifyImageResolution(t, image.Height, resolution)
}
}
func verifyImageResolution(t *testing.T, height sql.NullInt64, resolution models.ResolutionEnum) {
assert := assert.New(t)
h := height.Int64
switch resolution {
case models.ResolutionEnumLow:
assert.True(h < 480)
case models.ResolutionEnumStandard:
assert.True(h >= 480 && h < 720)
case models.ResolutionEnumStandardHd:
assert.True(h >= 720 && h < 1080)
case models.ResolutionEnumFullHd:
assert.True(h >= 1080 && h < 2160)
case models.ResolutionEnumFourK:
assert.True(h >= 2160)
}
}
func TestImageQueryIsMissingGalleries(t *testing.T) {
sqb := models.NewImageQueryBuilder()
isMissing := "galleries"
imageFilter := models.ImageFilterType{
IsMissing: &isMissing,
}
q := getImageStringValue(imageIdxWithGallery, titleField)
findFilter := models.FindFilterType{
Q: &q,
}
images, _ := sqb.Query(&imageFilter, &findFilter)
assert.Len(t, images, 0)
findFilter.Q = nil
images, _ = sqb.Query(&imageFilter, &findFilter)
// ensure none of the ids equal the one with gallery
for _, image := range images {
assert.NotEqual(t, imageIDs[imageIdxWithGallery], image.ID)
}
}
func TestImageQueryIsMissingStudio(t *testing.T) {
sqb := models.NewImageQueryBuilder()
isMissing := "studio"
imageFilter := models.ImageFilterType{
IsMissing: &isMissing,
}
q := getImageStringValue(imageIdxWithStudio, titleField)
findFilter := models.FindFilterType{
Q: &q,
}
images, _ := sqb.Query(&imageFilter, &findFilter)
assert.Len(t, images, 0)
findFilter.Q = nil
images, _ = sqb.Query(&imageFilter, &findFilter)
// ensure none of the ids equal the one with studio
for _, image := range images {
assert.NotEqual(t, imageIDs[imageIdxWithStudio], image.ID)
}
}
func TestImageQueryIsMissingPerformers(t *testing.T) {
sqb := models.NewImageQueryBuilder()
isMissing := "performers"
imageFilter := models.ImageFilterType{
IsMissing: &isMissing,
}
q := getImageStringValue(imageIdxWithPerformer, titleField)
findFilter := models.FindFilterType{
Q: &q,
}
images, _ := sqb.Query(&imageFilter, &findFilter)
assert.Len(t, images, 0)
findFilter.Q = nil
images, _ = sqb.Query(&imageFilter, &findFilter)
assert.True(t, len(images) > 0)
// ensure none of the ids equal the one with a performer
for _, image := range images {
assert.NotEqual(t, imageIDs[imageIdxWithPerformer], image.ID)
}
}
func TestImageQueryIsMissingTags(t *testing.T) {
sqb := models.NewImageQueryBuilder()
isMissing := "tags"
imageFilter := models.ImageFilterType{
IsMissing: &isMissing,
}
q := getImageStringValue(imageIdxWithTwoTags, titleField)
findFilter := models.FindFilterType{
Q: &q,
}
images, _ := sqb.Query(&imageFilter, &findFilter)
assert.Len(t, images, 0)
findFilter.Q = nil
images, _ = sqb.Query(&imageFilter, &findFilter)
assert.True(t, len(images) > 0)
}
func TestImageQueryIsMissingRating(t *testing.T) {
sqb := models.NewImageQueryBuilder()
isMissing := "rating"
imageFilter := models.ImageFilterType{
IsMissing: &isMissing,
}
images, _ := sqb.Query(&imageFilter, nil)
assert.True(t, len(images) > 0)
// ensure the rating is null
for _, image := range images {
assert.True(t, !image.Rating.Valid)
}
}
func TestImageQueryPerformers(t *testing.T) {
sqb := models.NewImageQueryBuilder()
performerCriterion := models.MultiCriterionInput{
Value: []string{
strconv.Itoa(performerIDs[performerIdxWithImage]),
strconv.Itoa(performerIDs[performerIdx1WithImage]),
},
Modifier: models.CriterionModifierIncludes,
}
imageFilter := models.ImageFilterType{
Performers: &performerCriterion,
}
images, _ := sqb.Query(&imageFilter, nil)
assert.Len(t, images, 2)
// ensure ids are correct
for _, image := range images {
assert.True(t, image.ID == imageIDs[imageIdxWithPerformer] || image.ID == imageIDs[imageIdxWithTwoPerformers])
}
performerCriterion = models.MultiCriterionInput{
Value: []string{
strconv.Itoa(performerIDs[performerIdx1WithImage]),
strconv.Itoa(performerIDs[performerIdx2WithImage]),
},
Modifier: models.CriterionModifierIncludesAll,
}
images, _ = sqb.Query(&imageFilter, nil)
assert.Len(t, images, 1)
assert.Equal(t, imageIDs[imageIdxWithTwoPerformers], images[0].ID)
performerCriterion = models.MultiCriterionInput{
Value: []string{
strconv.Itoa(performerIDs[performerIdx1WithImage]),
},
Modifier: models.CriterionModifierExcludes,
}
q := getImageStringValue(imageIdxWithTwoPerformers, titleField)
findFilter := models.FindFilterType{
Q: &q,
}
images, _ = sqb.Query(&imageFilter, &findFilter)
assert.Len(t, images, 0)
}
func TestImageQueryTags(t *testing.T) {
sqb := models.NewImageQueryBuilder()
tagCriterion := models.MultiCriterionInput{
Value: []string{
strconv.Itoa(tagIDs[tagIdxWithImage]),
strconv.Itoa(tagIDs[tagIdx1WithImage]),
},
Modifier: models.CriterionModifierIncludes,
}
imageFilter := models.ImageFilterType{
Tags: &tagCriterion,
}
images, _ := sqb.Query(&imageFilter, nil)
assert.Len(t, images, 2)
// ensure ids are correct
for _, image := range images {
assert.True(t, image.ID == imageIDs[imageIdxWithTag] || image.ID == imageIDs[imageIdxWithTwoTags])
}
tagCriterion = models.MultiCriterionInput{
Value: []string{
strconv.Itoa(tagIDs[tagIdx1WithImage]),
strconv.Itoa(tagIDs[tagIdx2WithImage]),
},
Modifier: models.CriterionModifierIncludesAll,
}
images, _ = sqb.Query(&imageFilter, nil)
assert.Len(t, images, 1)
assert.Equal(t, imageIDs[imageIdxWithTwoTags], images[0].ID)
tagCriterion = models.MultiCriterionInput{
Value: []string{
strconv.Itoa(tagIDs[tagIdx1WithImage]),
},
Modifier: models.CriterionModifierExcludes,
}
q := getImageStringValue(imageIdxWithTwoTags, titleField)
findFilter := models.FindFilterType{
Q: &q,
}
images, _ = sqb.Query(&imageFilter, &findFilter)
assert.Len(t, images, 0)
}
func TestImageQueryStudio(t *testing.T) {
sqb := models.NewImageQueryBuilder()
studioCriterion := models.MultiCriterionInput{
Value: []string{
strconv.Itoa(studioIDs[studioIdxWithImage]),
},
Modifier: models.CriterionModifierIncludes,
}
imageFilter := models.ImageFilterType{
Studios: &studioCriterion,
}
images, _ := sqb.Query(&imageFilter, nil)
assert.Len(t, images, 1)
// ensure id is correct
assert.Equal(t, imageIDs[imageIdxWithStudio], images[0].ID)
studioCriterion = models.MultiCriterionInput{
Value: []string{
strconv.Itoa(studioIDs[studioIdxWithImage]),
},
Modifier: models.CriterionModifierExcludes,
}
q := getImageStringValue(imageIdxWithStudio, titleField)
findFilter := models.FindFilterType{
Q: &q,
}
images, _ = sqb.Query(&imageFilter, &findFilter)
assert.Len(t, images, 0)
}
func TestImageQuerySorting(t *testing.T) {
sort := titleField
direction := models.SortDirectionEnumAsc
findFilter := models.FindFilterType{
Sort: &sort,
Direction: &direction,
}
sqb := models.NewImageQueryBuilder()
images, _ := sqb.Query(nil, &findFilter)
// images should be in the same order as the indexes
firstImage := images[0]
lastImage := images[len(images)-1]
assert.Equal(t, imageIDs[0], firstImage.ID)
assert.Equal(t, imageIDs[len(imageIDs)-1], lastImage.ID)
// sort in descending order
direction = models.SortDirectionEnumDesc
images, _ = sqb.Query(nil, &findFilter)
firstImage = images[0]
lastImage = images[len(images)-1]
assert.Equal(t, imageIDs[len(imageIDs)-1], firstImage.ID)
assert.Equal(t, imageIDs[0], lastImage.ID)
}
func TestImageQueryPagination(t *testing.T) {
perPage := 1
findFilter := models.FindFilterType{
PerPage: &perPage,
}
sqb := models.NewImageQueryBuilder()
images, _ := sqb.Query(nil, &findFilter)
assert.Len(t, images, 1)
firstID := images[0].ID
page := 2
findFilter.Page = &page
images, _ = sqb.Query(nil, &findFilter)
assert.Len(t, images, 1)
secondID := images[0].ID
assert.NotEqual(t, firstID, secondID)
perPage = 2
page = 1
images, _ = sqb.Query(nil, &findFilter)
assert.Len(t, images, 2)
assert.Equal(t, firstID, images[0].ID)
assert.Equal(t, secondID, images[1].ID)
}
func TestImageCountByTagID(t *testing.T) {
sqb := models.NewImageQueryBuilder()
imageCount, err := sqb.CountByTagID(tagIDs[tagIdxWithImage])
if err != nil {
t.Fatalf("error calling CountByTagID: %s", err.Error())
}
assert.Equal(t, 1, imageCount)
imageCount, err = sqb.CountByTagID(0)
if err != nil {
t.Fatalf("error calling CountByTagID: %s", err.Error())
}
assert.Equal(t, 0, imageCount)
}
func TestImageCountByStudioID(t *testing.T) {
sqb := models.NewImageQueryBuilder()
imageCount, err := sqb.CountByStudioID(studioIDs[studioIdxWithImage])
if err != nil {
t.Fatalf("error calling CountByStudioID: %s", err.Error())
}
assert.Equal(t, 1, imageCount)
imageCount, err = sqb.CountByStudioID(0)
if err != nil {
t.Fatalf("error calling CountByStudioID: %s", err.Error())
}
assert.Equal(t, 0, imageCount)
}
func TestImageFindByPerformerID(t *testing.T) {
sqb := models.NewImageQueryBuilder()
images, err := sqb.FindByPerformerID(performerIDs[performerIdxWithImage])
if err != nil {
t.Fatalf("error calling FindByPerformerID: %s", err.Error())
}
assert.Len(t, images, 1)
assert.Equal(t, imageIDs[imageIdxWithPerformer], images[0].ID)
images, err = sqb.FindByPerformerID(0)
if err != nil {
t.Fatalf("error calling FindByPerformerID: %s", err.Error())
}
assert.Len(t, images, 0)
}
func TestImageFindByStudioID(t *testing.T) {
sqb := models.NewImageQueryBuilder()
images, err := sqb.FindByStudioID(performerIDs[studioIdxWithImage])
if err != nil {
t.Fatalf("error calling FindByStudioID: %s", err.Error())
}
assert.Len(t, images, 1)
assert.Equal(t, imageIDs[imageIdxWithStudio], images[0].ID)
images, err = sqb.FindByStudioID(0)
if err != nil {
t.Fatalf("error calling FindByStudioID: %s", err.Error())
}
assert.Len(t, images, 0)
}
// TODO Update
// TODO IncrementOCounter
// TODO DecrementOCounter
// TODO ResetOCounter
// TODO Destroy
// TODO FindByChecksum
// TODO Count
// TODO SizeCount
// TODO All

File diff suppressed because it is too large

View File

@@ -1,312 +0,0 @@
package models
import (
"database/sql"
"fmt"
"github.com/jmoiron/sqlx"
"github.com/stashapp/stash/pkg/database"
)
type MovieQueryBuilder struct{}
func NewMovieQueryBuilder() MovieQueryBuilder {
return MovieQueryBuilder{}
}
func (qb *MovieQueryBuilder) Create(newMovie Movie, tx *sqlx.Tx) (*Movie, error) {
ensureTx(tx)
result, err := tx.NamedExec(
`INSERT INTO movies (checksum, name, aliases, duration, date, rating, studio_id, director, synopsis, url, created_at, updated_at)
VALUES (:checksum, :name, :aliases, :duration, :date, :rating, :studio_id, :director, :synopsis, :url, :created_at, :updated_at)
`,
newMovie,
)
if err != nil {
return nil, err
}
movieID, err := result.LastInsertId()
if err != nil {
return nil, err
}
if err := tx.Get(&newMovie, `SELECT * FROM movies WHERE id = ? LIMIT 1`, movieID); err != nil {
return nil, err
}
return &newMovie, nil
}
func (qb *MovieQueryBuilder) Update(updatedMovie MoviePartial, tx *sqlx.Tx) (*Movie, error) {
ensureTx(tx)
_, err := tx.NamedExec(
`UPDATE movies SET `+SQLGenKeysPartial(updatedMovie)+` WHERE movies.id = :id`,
updatedMovie,
)
if err != nil {
return nil, err
}
return qb.Find(updatedMovie.ID, tx)
}
func (qb *MovieQueryBuilder) UpdateFull(updatedMovie Movie, tx *sqlx.Tx) (*Movie, error) {
ensureTx(tx)
_, err := tx.NamedExec(
`UPDATE movies SET `+SQLGenKeys(updatedMovie)+` WHERE movies.id = :id`,
updatedMovie,
)
if err != nil {
return nil, err
}
return qb.Find(updatedMovie.ID, tx)
}
func (qb *MovieQueryBuilder) Destroy(id string, tx *sqlx.Tx) error {
// delete movie from movies_scenes
_, err := tx.Exec("DELETE FROM movies_scenes WHERE movie_id = ?", id)
if err != nil {
return err
}
// // remove movie from scraped items
// _, err = tx.Exec("UPDATE scraped_items SET movie_id = null WHERE movie_id = ?", id)
// if err != nil {
// return err
// }
return executeDeleteQuery("movies", id, tx)
}
func (qb *MovieQueryBuilder) Find(id int, tx *sqlx.Tx) (*Movie, error) {
query := "SELECT * FROM movies WHERE id = ? LIMIT 1"
args := []interface{}{id}
return qb.queryMovie(query, args, tx)
}
func (qb *MovieQueryBuilder) FindMany(ids []int) ([]*Movie, error) {
var movies []*Movie
for _, id := range ids {
movie, err := qb.Find(id, nil)
if err != nil {
return nil, err
}
if movie == nil {
return nil, fmt.Errorf("movie with id %d not found", id)
}
movies = append(movies, movie)
}
return movies, nil
}
func (qb *MovieQueryBuilder) FindBySceneID(sceneID int, tx *sqlx.Tx) ([]*Movie, error) {
query := `
SELECT movies.* FROM movies
LEFT JOIN movies_scenes as scenes_join on scenes_join.movie_id = movies.id
WHERE scenes_join.scene_id = ?
GROUP BY movies.id
`
args := []interface{}{sceneID}
return qb.queryMovies(query, args, tx)
}
func (qb *MovieQueryBuilder) FindByName(name string, tx *sqlx.Tx, nocase bool) (*Movie, error) {
query := "SELECT * FROM movies WHERE name = ?"
if nocase {
query += " COLLATE NOCASE"
}
query += " LIMIT 1"
args := []interface{}{name}
return qb.queryMovie(query, args, tx)
}
func (qb *MovieQueryBuilder) FindByNames(names []string, tx *sqlx.Tx, nocase bool) ([]*Movie, error) {
query := "SELECT * FROM movies WHERE name"
if nocase {
query += " COLLATE NOCASE"
}
query += " IN " + getInBinding(len(names))
var args []interface{}
for _, name := range names {
args = append(args, name)
}
return qb.queryMovies(query, args, tx)
}
func (qb *MovieQueryBuilder) Count() (int, error) {
return runCountQuery(buildCountQuery("SELECT movies.id FROM movies"), nil)
}
func (qb *MovieQueryBuilder) All() ([]*Movie, error) {
return qb.queryMovies(selectAll("movies")+qb.getMovieSort(nil), nil, nil)
}
func (qb *MovieQueryBuilder) AllSlim() ([]*Movie, error) {
return qb.queryMovies("SELECT movies.id, movies.name FROM movies "+qb.getMovieSort(nil), nil, nil)
}
func (qb *MovieQueryBuilder) Query(movieFilter *MovieFilterType, findFilter *FindFilterType) ([]*Movie, int) {
if findFilter == nil {
findFilter = &FindFilterType{}
}
if movieFilter == nil {
movieFilter = &MovieFilterType{}
}
var whereClauses []string
var havingClauses []string
var args []interface{}
body := selectDistinctIDs("movies")
body += `
left join movies_scenes as scenes_join on scenes_join.movie_id = movies.id
left join scenes on scenes_join.scene_id = scenes.id
left join studios as studio on studio.id = movies.studio_id
`
if q := findFilter.Q; q != nil && *q != "" {
searchColumns := []string{"movies.name"}
clause, thisArgs := getSearchBinding(searchColumns, *q, false)
whereClauses = append(whereClauses, clause)
args = append(args, thisArgs...)
}
if studiosFilter := movieFilter.Studios; studiosFilter != nil && len(studiosFilter.Value) > 0 {
for _, studioID := range studiosFilter.Value {
args = append(args, studioID)
}
whereClause, havingClause := getMultiCriterionClause("movies", "studio", "", "", "studio_id", studiosFilter)
whereClauses = appendClause(whereClauses, whereClause)
havingClauses = appendClause(havingClauses, havingClause)
}
if isMissingFilter := movieFilter.IsMissing; isMissingFilter != nil && *isMissingFilter != "" {
switch *isMissingFilter {
case "front_image":
body += `left join movies_images on movies_images.movie_id = movies.id
`
whereClauses = appendClause(whereClauses, "movies_images.front_image IS NULL")
case "back_image":
body += `left join movies_images on movies_images.movie_id = movies.id
`
whereClauses = appendClause(whereClauses, "movies_images.back_image IS NULL")
case "scenes":
body += `left join movies_scenes on movies_scenes.movie_id = movies.id
`
whereClauses = appendClause(whereClauses, "movies_scenes.scene_id IS NULL")
default:
whereClauses = appendClause(whereClauses, "movies."+*isMissingFilter+" IS NULL")
}
}
sortAndPagination := qb.getMovieSort(findFilter) + getPagination(findFilter)
idsResult, countResult := executeFindQuery("movies", body, args, sortAndPagination, whereClauses, havingClauses)
var movies []*Movie
for _, id := range idsResult {
movie, _ := qb.Find(id, nil)
movies = append(movies, movie)
}
return movies, countResult
}
func (qb *MovieQueryBuilder) getMovieSort(findFilter *FindFilterType) string {
var sort string
var direction string
if findFilter == nil {
sort = "name"
direction = "ASC"
} else {
sort = findFilter.GetSort("name")
direction = findFilter.GetDirection()
}
// #943 - override name sorting to use natural sort
if sort == "name" {
return " ORDER BY " + getColumn("movies", sort) + " COLLATE NATURAL_CS " + direction
}
return getSort(sort, direction, "movies")
}
func (qb *MovieQueryBuilder) queryMovie(query string, args []interface{}, tx *sqlx.Tx) (*Movie, error) {
results, err := qb.queryMovies(query, args, tx)
if err != nil || len(results) < 1 {
return nil, err
}
return results[0], nil
}
func (qb *MovieQueryBuilder) queryMovies(query string, args []interface{}, tx *sqlx.Tx) ([]*Movie, error) {
var rows *sqlx.Rows
var err error
if tx != nil {
rows, err = tx.Queryx(query, args...)
} else {
rows, err = database.DB.Queryx(query, args...)
}
if err != nil && err != sql.ErrNoRows {
return nil, err
}
defer rows.Close()
movies := make([]*Movie, 0)
for rows.Next() {
movie := Movie{}
if err := rows.StructScan(&movie); err != nil {
return nil, err
}
movies = append(movies, &movie)
}
if err := rows.Err(); err != nil {
return nil, err
}
return movies, nil
}
func (qb *MovieQueryBuilder) UpdateMovieImages(movieID int, frontImage []byte, backImage []byte, tx *sqlx.Tx) error {
ensureTx(tx)
// Delete the existing cover and then create new
if err := qb.DestroyMovieImages(movieID, tx); err != nil {
return err
}
_, err := tx.Exec(
`INSERT INTO movies_images (movie_id, front_image, back_image) VALUES (?, ?, ?)`,
movieID,
frontImage,
backImage,
)
return err
}
func (qb *MovieQueryBuilder) DestroyMovieImages(movieID int, tx *sqlx.Tx) error {
ensureTx(tx)
// Delete the existing joins
_, err := tx.Exec("DELETE FROM movies_images WHERE movie_id = ?", movieID)
if err != nil {
return err
}
return err
}
func (qb *MovieQueryBuilder) GetFrontImage(movieID int, tx *sqlx.Tx) ([]byte, error) {
query := `SELECT front_image from movies_images WHERE movie_id = ?`
return getImage(tx, query, movieID)
}
func (qb *MovieQueryBuilder) GetBackImage(movieID int, tx *sqlx.Tx) ([]byte, error) {
query := `SELECT back_image from movies_images WHERE movie_id = ?`
return getImage(tx, query, movieID)
}
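For context, the builder above expects the caller to create and manage the *sqlx.Tx itself (ensureTx panics on nil). The following is a minimal sketch of that calling pattern, mirroring the integration tests further down in this diff; the package name and the createMovieInTx helper are illustrative only, not part of this commit.

package sketch

import (
	"context"
	"database/sql"

	"github.com/stashapp/stash/pkg/database"
	"github.com/stashapp/stash/pkg/models"
	"github.com/stashapp/stash/pkg/utils"
)

// createMovieInTx drives MovieQueryBuilder.Create with an explicit,
// caller-managed transaction, as the pre-refactor API required.
func createMovieInTx(ctx context.Context, name string) (*models.Movie, error) {
	qb := models.NewMovieQueryBuilder()
	tx := database.DB.MustBeginTx(ctx, nil)

	newMovie := models.Movie{
		Name:     sql.NullString{String: name, Valid: true},
		Checksum: utils.MD5FromString(name),
	}

	created, err := qb.Create(newMovie, tx)
	if err != nil {
		tx.Rollback()
		return nil, err
	}

	if err := tx.Commit(); err != nil {
		tx.Rollback()
		return nil, err
	}

	return created, nil
}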


@@ -1,275 +0,0 @@
// +build integration
package models_test
import (
"context"
"database/sql"
"strconv"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stashapp/stash/pkg/database"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/utils"
)
func TestMovieFindBySceneID(t *testing.T) {
mqb := models.NewMovieQueryBuilder()
sceneID := sceneIDs[sceneIdxWithMovie]
movies, err := mqb.FindBySceneID(sceneID, nil)
if err != nil {
t.Fatalf("Error finding movie: %s", err.Error())
}
assert.Equal(t, 1, len(movies), "expect 1 movie")
movie := movies[0]
assert.Equal(t, getMovieStringValue(movieIdxWithScene, "Name"), movie.Name.String)
movies, err = mqb.FindBySceneID(0, nil)
if err != nil {
t.Fatalf("Error finding movie: %s", err.Error())
}
assert.Equal(t, 0, len(movies))
}
func TestMovieFindByName(t *testing.T) {
mqb := models.NewMovieQueryBuilder()
name := movieNames[movieIdxWithScene] // find a movie by name
movie, err := mqb.FindByName(name, nil, false)
if err != nil {
t.Fatalf("Error finding movies: %s", err.Error())
}
assert.Equal(t, movieNames[movieIdxWithScene], movie.Name.String)
name = movieNames[movieIdxWithDupName] // find a movie by name nocase
movie, err = mqb.FindByName(name, nil, true)
if err != nil {
t.Fatalf("Error finding movies: %s", err.Error())
}
// movieIdxWithDupName and movieIdxWithScene should have similar names (only difference should be Name vs NaMe)
// movie.Name should match movieIdxWithScene since its ID comes before movieIdxWithDupName
assert.Equal(t, movieNames[movieIdxWithScene], movie.Name.String)
// movie.Name should also match movieIdxWithDupName when compared case-insensitively
assert.Equal(t, strings.ToLower(movieNames[movieIdxWithDupName]), strings.ToLower(movie.Name.String))
}
func TestMovieFindByNames(t *testing.T) {
var names []string
mqb := models.NewMovieQueryBuilder()
names = append(names, movieNames[movieIdxWithScene]) // find movies by names
movies, err := mqb.FindByNames(names, nil, false)
if err != nil {
t.Fatalf("Error finding movies: %s", err.Error())
}
assert.Len(t, movies, 1)
assert.Equal(t, movieNames[movieIdxWithScene], movies[0].Name.String)
movies, err = mqb.FindByNames(names, nil, true) // find movies by names nocase
if err != nil {
t.Fatalf("Error finding movies: %s", err.Error())
}
assert.Len(t, movies, 2) // movieIdxWithScene and movieIdxWithDupName
assert.Equal(t, strings.ToLower(movieNames[movieIdxWithScene]), strings.ToLower(movies[0].Name.String))
assert.Equal(t, strings.ToLower(movieNames[movieIdxWithScene]), strings.ToLower(movies[1].Name.String))
}
func TestMovieQueryStudio(t *testing.T) {
mqb := models.NewMovieQueryBuilder()
studioCriterion := models.MultiCriterionInput{
Value: []string{
strconv.Itoa(studioIDs[studioIdxWithMovie]),
},
Modifier: models.CriterionModifierIncludes,
}
movieFilter := models.MovieFilterType{
Studios: &studioCriterion,
}
movies, _ := mqb.Query(&movieFilter, nil)
assert.Len(t, movies, 1)
// ensure id is correct
assert.Equal(t, movieIDs[movieIdxWithStudio], movies[0].ID)
studioCriterion = models.MultiCriterionInput{
Value: []string{
strconv.Itoa(studioIDs[studioIdxWithMovie]),
},
Modifier: models.CriterionModifierExcludes,
}
q := getMovieStringValue(movieIdxWithStudio, titleField)
findFilter := models.FindFilterType{
Q: &q,
}
movies, _ = mqb.Query(&movieFilter, &findFilter)
assert.Len(t, movies, 0)
}
func TestMovieUpdateMovieImages(t *testing.T) {
mqb := models.NewMovieQueryBuilder()
// create movie to test against
ctx := context.TODO()
tx := database.DB.MustBeginTx(ctx, nil)
const name = "TestMovieUpdateMovieImages"
movie := models.Movie{
Name: sql.NullString{String: name, Valid: true},
Checksum: utils.MD5FromString(name),
}
created, err := mqb.Create(movie, tx)
if err != nil {
tx.Rollback()
t.Fatalf("Error creating movie: %s", err.Error())
}
frontImage := []byte("frontImage")
backImage := []byte("backImage")
err = mqb.UpdateMovieImages(created.ID, frontImage, backImage, tx)
if err != nil {
tx.Rollback()
t.Fatalf("Error updating movie images: %s", err.Error())
}
if err := tx.Commit(); err != nil {
tx.Rollback()
t.Fatalf("Error committing: %s", err.Error())
}
// ensure images are set
storedFront, err := mqb.GetFrontImage(created.ID, nil)
if err != nil {
t.Fatalf("Error getting front image: %s", err.Error())
}
assert.Equal(t, storedFront, frontImage)
storedBack, err := mqb.GetBackImage(created.ID, nil)
if err != nil {
t.Fatalf("Error getting back image: %s", err.Error())
}
assert.Equal(t, storedBack, backImage)
// set front image only
newImage := []byte("newImage")
tx = database.DB.MustBeginTx(ctx, nil)
err = mqb.UpdateMovieImages(created.ID, newImage, nil, tx)
if err != nil {
tx.Rollback()
t.Fatalf("Error updating movie images: %s", err.Error())
}
storedFront, err = mqb.GetFrontImage(created.ID, tx)
if err != nil {
tx.Rollback()
t.Fatalf("Error getting front image: %s", err.Error())
}
assert.Equal(t, storedFront, newImage)
// back image should be nil
storedBack, err = mqb.GetBackImage(created.ID, tx)
if err != nil {
tx.Rollback()
t.Fatalf("Error getting back image: %s", err.Error())
}
assert.Nil(t, storedBack)
// attempt to set back image only - should fail since the front image cannot be nil
err = mqb.UpdateMovieImages(created.ID, nil, newImage, tx)
if err == nil {
tx.Rollback()
t.Fatalf("Expected error setting nil front image")
}
if err := tx.Commit(); err != nil {
tx.Rollback()
t.Fatalf("Error committing: %s", err.Error())
}
}
func TestMovieDestroyMovieImages(t *testing.T) {
mqb := models.NewMovieQueryBuilder()
// create movie to test against
ctx := context.TODO()
tx := database.DB.MustBeginTx(ctx, nil)
const name = "TestMovieDestroyMovieImages"
movie := models.Movie{
Name: sql.NullString{String: name, Valid: true},
Checksum: utils.MD5FromString(name),
}
created, err := mqb.Create(movie, tx)
if err != nil {
tx.Rollback()
t.Fatalf("Error creating movie: %s", err.Error())
}
frontImage := []byte("frontImage")
backImage := []byte("backImage")
err = mqb.UpdateMovieImages(created.ID, frontImage, backImage, tx)
if err != nil {
tx.Rollback()
t.Fatalf("Error updating movie images: %s", err.Error())
}
if err := tx.Commit(); err != nil {
tx.Rollback()
t.Fatalf("Error committing: %s", err.Error())
}
tx = database.DB.MustBeginTx(ctx, nil)
err = mqb.DestroyMovieImages(created.ID, tx)
if err != nil {
tx.Rollback()
t.Fatalf("Error destroying movie images: %s", err.Error())
}
if err := tx.Commit(); err != nil {
tx.Rollback()
t.Fatalf("Error committing: %s", err.Error())
}
// front image should be nil
storedFront, err := mqb.GetFrontImage(created.ID, nil)
if err != nil {
t.Fatalf("Error getting front image: %s", err.Error())
}
assert.Nil(t, storedFront)
// back image should be nil
storedBack, err := mqb.GetBackImage(created.ID, nil)
if err != nil {
t.Fatalf("Error getting back image: %s", err.Error())
}
assert.Nil(t, storedBack)
}
// TODO Update
// TODO Destroy
// TODO Find
// TODO Count
// TODO All
// TODO Query


@@ -1,257 +0,0 @@
// +build integration
package models_test
import (
"context"
"database/sql"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stashapp/stash/pkg/database"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/utils"
)
func TestPerformerFindBySceneID(t *testing.T) {
pqb := models.NewPerformerQueryBuilder()
sceneID := sceneIDs[sceneIdxWithPerformer]
performers, err := pqb.FindBySceneID(sceneID, nil)
if err != nil {
t.Fatalf("Error finding performer: %s", err.Error())
}
assert.Equal(t, 1, len(performers))
performer := performers[0]
assert.Equal(t, getPerformerStringValue(performerIdxWithScene, "Name"), performer.Name.String)
performers, err = pqb.FindBySceneID(0, nil)
if err != nil {
t.Fatalf("Error finding performer: %s", err.Error())
}
assert.Equal(t, 0, len(performers))
}
func TestPerformerFindNameBySceneID(t *testing.T) {
pqb := models.NewPerformerQueryBuilder()
sceneID := sceneIDs[sceneIdxWithPerformer]
performers, err := pqb.FindNameBySceneID(sceneID, nil)
if err != nil {
t.Fatalf("Error finding performer: %s", err.Error())
}
assert.Equal(t, 1, len(performers))
performer := performers[0]
assert.Equal(t, getPerformerStringValue(performerIdxWithScene, "Name"), performer.Name.String)
performers, err = pqb.FindNameBySceneID(0, nil)
if err != nil {
t.Fatalf("Error finding performer: %s", err.Error())
}
assert.Equal(t, 0, len(performers))
}
func TestPerformerFindByNames(t *testing.T) {
var names []string
pqb := models.NewPerformerQueryBuilder()
names = append(names, performerNames[performerIdxWithScene]) // find performers by names
performers, err := pqb.FindByNames(names, nil, false)
if err != nil {
t.Fatalf("Error finding performers: %s", err.Error())
}
assert.Len(t, performers, 1)
assert.Equal(t, performerNames[performerIdxWithScene], performers[0].Name.String)
performers, err = pqb.FindByNames(names, nil, true) // find performers by names nocase
if err != nil {
t.Fatalf("Error finding performers: %s", err.Error())
}
assert.Len(t, performers, 2) // performerIdxWithScene and performerIdxWithDupName
assert.Equal(t, strings.ToLower(performerNames[performerIdxWithScene]), strings.ToLower(performers[0].Name.String))
assert.Equal(t, strings.ToLower(performerNames[performerIdxWithScene]), strings.ToLower(performers[1].Name.String))
names = append(names, performerNames[performerIdx1WithScene]) // find performers by names ( 2 names )
performers, err = pqb.FindByNames(names, nil, false)
if err != nil {
t.Fatalf("Error finding performers: %s", err.Error())
}
assert.Len(t, performers, 2) // performerIdxWithScene and performerIdx1WithScene
assert.Equal(t, performerNames[performerIdxWithScene], performers[0].Name.String)
assert.Equal(t, performerNames[performerIdx1WithScene], performers[1].Name.String)
performers, err = pqb.FindByNames(names, nil, true) // find performers by names ( 2 names nocase)
if err != nil {
t.Fatalf("Error finding performers: %s", err.Error())
}
assert.Len(t, performers, 4) // performerIdxWithScene and performerIdxWithDupName , performerIdx1WithScene and performerIdx1WithDupName
assert.Equal(t, performerNames[performerIdxWithScene], performers[0].Name.String)
assert.Equal(t, performerNames[performerIdx1WithScene], performers[1].Name.String)
assert.Equal(t, performerNames[performerIdx1WithDupName], performers[2].Name.String)
assert.Equal(t, performerNames[performerIdxWithDupName], performers[3].Name.String)
}
func TestPerformerUpdatePerformerImage(t *testing.T) {
qb := models.NewPerformerQueryBuilder()
// create performer to test against
ctx := context.TODO()
tx := database.DB.MustBeginTx(ctx, nil)
const name = "TestPerformerUpdatePerformerImage"
performer := models.Performer{
Name: sql.NullString{String: name, Valid: true},
Checksum: utils.MD5FromString(name),
Favorite: sql.NullBool{Bool: false, Valid: true},
}
created, err := qb.Create(performer, tx)
if err != nil {
tx.Rollback()
t.Fatalf("Error creating performer: %s", err.Error())
}
image := []byte("image")
err = qb.UpdatePerformerImage(created.ID, image, tx)
if err != nil {
tx.Rollback()
t.Fatalf("Error updating performer image: %s", err.Error())
}
if err := tx.Commit(); err != nil {
tx.Rollback()
t.Fatalf("Error committing: %s", err.Error())
}
// ensure image set
storedImage, err := qb.GetPerformerImage(created.ID, nil)
if err != nil {
t.Fatalf("Error getting image: %s", err.Error())
}
assert.Equal(t, storedImage, image)
// set nil image
tx = database.DB.MustBeginTx(ctx, nil)
err = qb.UpdatePerformerImage(created.ID, nil, tx)
if err == nil {
t.Fatalf("Expected error setting nil image")
}
tx.Rollback()
}
func TestPerformerDestroyPerformerImage(t *testing.T) {
qb := models.NewPerformerQueryBuilder()
// create performer to test against
ctx := context.TODO()
tx := database.DB.MustBeginTx(ctx, nil)
const name = "TestPerformerDestroyPerformerImage"
performer := models.Performer{
Name: sql.NullString{String: name, Valid: true},
Checksum: utils.MD5FromString(name),
Favorite: sql.NullBool{Bool: false, Valid: true},
}
created, err := qb.Create(performer, tx)
if err != nil {
tx.Rollback()
t.Fatalf("Error creating performer: %s", err.Error())
}
image := []byte("image")
err = qb.UpdatePerformerImage(created.ID, image, tx)
if err != nil {
tx.Rollback()
t.Fatalf("Error updating performer image: %s", err.Error())
}
if err := tx.Commit(); err != nil {
tx.Rollback()
t.Fatalf("Error committing: %s", err.Error())
}
tx = database.DB.MustBeginTx(ctx, nil)
err = qb.DestroyPerformerImage(created.ID, tx)
if err != nil {
tx.Rollback()
t.Fatalf("Error destroying performer image: %s", err.Error())
}
if err := tx.Commit(); err != nil {
tx.Rollback()
t.Fatalf("Error committing: %s", err.Error())
}
// image should be nil
storedImage, err := qb.GetPerformerImage(created.ID, nil)
if err != nil {
t.Fatalf("Error getting image: %s", err.Error())
}
assert.Nil(t, storedImage)
}
func TestPerformerQueryAge(t *testing.T) {
const age = 19
ageCriterion := models.IntCriterionInput{
Value: age,
Modifier: models.CriterionModifierEquals,
}
verifyPerformerAge(t, ageCriterion)
ageCriterion.Modifier = models.CriterionModifierNotEquals
verifyPerformerAge(t, ageCriterion)
ageCriterion.Modifier = models.CriterionModifierGreaterThan
verifyPerformerAge(t, ageCriterion)
ageCriterion.Modifier = models.CriterionModifierLessThan
verifyPerformerAge(t, ageCriterion)
}
func verifyPerformerAge(t *testing.T, ageCriterion models.IntCriterionInput) {
qb := models.NewPerformerQueryBuilder()
performerFilter := models.PerformerFilterType{
Age: &ageCriterion,
}
performers, _ := qb.Query(&performerFilter, nil)
now := time.Now()
for _, performer := range performers {
bd := performer.Birthdate.String
d, _ := time.Parse("2006-01-02", bd)
age := now.Year() - d.Year()
if now.YearDay() < d.YearDay() {
age = age - 1
}
verifyInt(t, age, ageCriterion)
}
}
// TODO Update
// TODO Destroy
// TODO Find
// TODO Count
// TODO All
// TODO AllSlim
// TODO Query


@@ -1,68 +0,0 @@
// +build integration
package models_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stashapp/stash/pkg/models"
)
func TestMarkerFindBySceneID(t *testing.T) {
mqb := models.NewSceneMarkerQueryBuilder()
sceneID := sceneIDs[sceneIdxWithMarker]
markers, err := mqb.FindBySceneID(sceneID, nil)
if err != nil {
t.Fatalf("Error finding markers: %s", err.Error())
}
assert.Len(t, markers, 1)
assert.Equal(t, markerIDs[markerIdxWithScene], markers[0].ID)
markers, err = mqb.FindBySceneID(0, nil)
if err != nil {
t.Fatalf("Error finding marker: %s", err.Error())
}
assert.Len(t, markers, 0)
}
func TestMarkerCountByTagID(t *testing.T) {
mqb := models.NewSceneMarkerQueryBuilder()
markerCount, err := mqb.CountByTagID(tagIDs[tagIdxWithPrimaryMarker])
if err != nil {
t.Fatalf("error calling CountByTagID: %s", err.Error())
}
assert.Equal(t, 1, markerCount)
markerCount, err = mqb.CountByTagID(tagIDs[tagIdxWithMarker])
if err != nil {
t.Fatalf("error calling CountByTagID: %s", err.Error())
}
assert.Equal(t, 1, markerCount)
markerCount, err = mqb.CountByTagID(0)
if err != nil {
t.Fatalf("error calling CountByTagID: %s", err.Error())
}
assert.Equal(t, 0, markerCount)
}
// TODO Update
// TODO Destroy
// TODO Find
// TODO GetMarkerStrings
// TODO Wall
// TODO Query

File diff suppressed because it is too large


@@ -1,113 +0,0 @@
package models
import (
"database/sql"
"github.com/jmoiron/sqlx"
"github.com/stashapp/stash/pkg/database"
)
type ScrapedItemQueryBuilder struct{}
func NewScrapedItemQueryBuilder() ScrapedItemQueryBuilder {
return ScrapedItemQueryBuilder{}
}
func (qb *ScrapedItemQueryBuilder) Create(newScrapedItem ScrapedItem, tx *sqlx.Tx) (*ScrapedItem, error) {
ensureTx(tx)
result, err := tx.NamedExec(
`INSERT INTO scraped_items (title, description, url, date, rating, tags, models, episode, gallery_filename,
gallery_url, video_filename, video_url, studio_id, created_at, updated_at)
VALUES (:title, :description, :url, :date, :rating, :tags, :models, :episode, :gallery_filename,
:gallery_url, :video_filename, :video_url, :studio_id, :created_at, :updated_at)
`,
newScrapedItem,
)
if err != nil {
return nil, err
}
scrapedItemID, err := result.LastInsertId()
if err != nil {
return nil, err
}
if err := tx.Get(&newScrapedItem, `SELECT * FROM scraped_items WHERE id = ? LIMIT 1`, scrapedItemID); err != nil {
return nil, err
}
return &newScrapedItem, nil
}
func (qb *ScrapedItemQueryBuilder) Update(updatedScrapedItem ScrapedItem, tx *sqlx.Tx) (*ScrapedItem, error) {
ensureTx(tx)
_, err := tx.NamedExec(
`UPDATE scraped_items SET `+SQLGenKeys(updatedScrapedItem)+` WHERE scraped_items.id = :id`,
updatedScrapedItem,
)
if err != nil {
return nil, err
}
if err := tx.Get(&updatedScrapedItem, `SELECT * FROM scraped_items WHERE id = ? LIMIT 1`, updatedScrapedItem.ID); err != nil {
return nil, err
}
return &updatedScrapedItem, nil
}
func (qb *ScrapedItemQueryBuilder) Find(id int) (*ScrapedItem, error) {
query := "SELECT * FROM scraped_items WHERE id = ? LIMIT 1"
args := []interface{}{id}
return qb.queryScrapedItem(query, args, nil)
}
func (qb *ScrapedItemQueryBuilder) All() ([]*ScrapedItem, error) {
return qb.queryScrapedItems(selectAll("scraped_items")+qb.getScrapedItemsSort(nil), nil, nil)
}
func (qb *ScrapedItemQueryBuilder) getScrapedItemsSort(findFilter *FindFilterType) string {
var sort string
var direction string
if findFilter == nil {
sort = "id" // TODO studio_id and title
direction = "ASC"
} else {
sort = findFilter.GetSort("id")
direction = findFilter.GetDirection()
}
return getSort(sort, direction, "scraped_items")
}
func (qb *ScrapedItemQueryBuilder) queryScrapedItem(query string, args []interface{}, tx *sqlx.Tx) (*ScrapedItem, error) {
results, err := qb.queryScrapedItems(query, args, tx)
if err != nil || len(results) < 1 {
return nil, err
}
return results[0], nil
}
func (qb *ScrapedItemQueryBuilder) queryScrapedItems(query string, args []interface{}, tx *sqlx.Tx) ([]*ScrapedItem, error) {
var rows *sqlx.Rows
var err error
if tx != nil {
rows, err = tx.Queryx(query, args...)
} else {
rows, err = database.DB.Queryx(query, args...)
}
if err != nil && err != sql.ErrNoRows {
return nil, err
}
defer rows.Close()
scrapedItems := make([]*ScrapedItem, 0)
for rows.Next() {
scrapedItem := ScrapedItem{}
if err := rows.StructScan(&scrapedItem); err != nil {
return nil, err
}
scrapedItems = append(scrapedItems, &scrapedItem)
}
if err := rows.Err(); err != nil {
return nil, err
}
return scrapedItems, nil
}


@@ -1,484 +0,0 @@
package models
import (
"database/sql"
"fmt"
"math/rand"
"reflect"
"strconv"
"strings"
"github.com/jmoiron/sqlx"
"github.com/stashapp/stash/pkg/database"
"github.com/stashapp/stash/pkg/logger"
)
type queryBuilder struct {
tableName string
body string
whereClauses []string
havingClauses []string
args []interface{}
sortAndPagination string
}
func (qb queryBuilder) executeFind() ([]int, int) {
return executeFindQuery(qb.tableName, qb.body, qb.args, qb.sortAndPagination, qb.whereClauses, qb.havingClauses)
}
func (qb *queryBuilder) addWhere(clauses ...string) {
for _, clause := range clauses {
if len(clause) > 0 {
qb.whereClauses = append(qb.whereClauses, clause)
}
}
}
func (qb *queryBuilder) addHaving(clauses ...string) {
for _, clause := range clauses {
if len(clause) > 0 {
qb.havingClauses = append(qb.havingClauses, clause)
}
}
}
func (qb *queryBuilder) addArg(args ...interface{}) {
qb.args = append(qb.args, args...)
}
func (qb *queryBuilder) handleIntCriterionInput(c *IntCriterionInput, column string) {
if c != nil {
clause, count := getIntCriterionWhereClause(column, *c)
qb.addWhere(clause)
if count == 1 {
qb.addArg(c.Value)
}
}
}
func (qb *queryBuilder) handleStringCriterionInput(c *StringCriterionInput, column string) {
if c != nil {
if modifier := c.Modifier; c.Modifier.IsValid() {
switch modifier {
case CriterionModifierIncludes:
clause, thisArgs := getSearchBinding([]string{column}, c.Value, false)
qb.addWhere(clause)
qb.addArg(thisArgs...)
case CriterionModifierExcludes:
clause, thisArgs := getSearchBinding([]string{column}, c.Value, true)
qb.addWhere(clause)
qb.addArg(thisArgs...)
case CriterionModifierEquals:
qb.addWhere(column + " LIKE ?")
qb.addArg(c.Value)
case CriterionModifierNotEquals:
qb.addWhere(column + " NOT LIKE ?")
qb.addArg(c.Value)
default:
clause, count := getSimpleCriterionClause(modifier, "?")
qb.addWhere(column + " " + clause)
if count == 1 {
qb.addArg(c.Value)
}
}
}
}
}
var randomSortFloat = rand.Float64()
func selectAll(tableName string) string {
idColumn := getColumn(tableName, "*")
return "SELECT " + idColumn + " FROM " + tableName + " "
}
func selectDistinctIDs(tableName string) string {
idColumn := getColumn(tableName, "id")
return "SELECT DISTINCT " + idColumn + " FROM " + tableName + " "
}
func buildCountQuery(query string) string {
return "SELECT COUNT(*) as count FROM (" + query + ") as temp"
}
func getColumn(tableName string, columnName string) string {
return tableName + "." + columnName
}
func getPagination(findFilter *FindFilterType) string {
if findFilter == nil {
panic("nil find filter for pagination")
}
var page int
if findFilter.Page == nil || *findFilter.Page < 1 {
page = 1
} else {
page = *findFilter.Page
}
var perPage int
if findFilter.PerPage == nil {
perPage = 25
} else {
perPage = *findFilter.PerPage
}
if perPage > 1000 {
perPage = 1000
} else if perPage < 1 {
perPage = 1
}
page = (page - 1) * perPage
return " LIMIT " + strconv.Itoa(perPage) + " OFFSET " + strconv.Itoa(page) + " "
}
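For reference, getPagination turns a 1-based page and a clamped per-page value into a LIMIT/OFFSET fragment. A minimal sketch of the expected output, assuming it sits in a _test.go file inside the models package (the helper is unexported); the literal values are illustrative:

package models

import "testing"

// Illustrative only: page 3 with 40 per page skips the first 80 rows.
func TestGetPaginationSketch(t *testing.T) {
	page := 3
	perPage := 40
	got := getPagination(&FindFilterType{Page: &page, PerPage: &perPage})
	want := " LIMIT 40 OFFSET 80 "
	if got != want {
		t.Errorf("got %q, want %q", got, want)
	}
}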
func getSort(sort string, direction string, tableName string) string {
if direction != "ASC" && direction != "DESC" {
direction = "ASC"
}
const randomSeedPrefix = "random_"
if strings.HasSuffix(sort, "_count") {
var relationTableName = strings.TrimSuffix(sort, "_count") // TODO: pluralize?
colName := getColumn(relationTableName, "id")
return " ORDER BY COUNT(distinct " + colName + ") " + direction
} else if strings.Compare(sort, "filesize") == 0 {
colName := getColumn(tableName, "size")
return " ORDER BY cast(" + colName + " as integer) " + direction
} else if strings.HasPrefix(sort, randomSeedPrefix) {
// seed as a parameter from the UI
// turn the provided seed into a float
seedStr := "0." + sort[len(randomSeedPrefix):]
seed, err := strconv.ParseFloat(seedStr, 32)
if err != nil {
// fallback to default seed
seed = randomSortFloat
}
return getRandomSort(tableName, direction, seed)
} else if strings.Compare(sort, "random") == 0 {
return getRandomSort(tableName, direction, randomSortFloat)
} else {
colName := getColumn(tableName, sort)
var additional string
if tableName == "scenes" {
additional = ", bitrate DESC, framerate DESC, scenes.rating DESC, scenes.duration DESC"
} else if tableName == "scene_markers" {
additional = ", scene_markers.scene_id ASC, scene_markers.seconds ASC"
}
if strings.Compare(sort, "name") == 0 {
return " ORDER BY " + colName + " COLLATE NOCASE " + direction + additional
}
if strings.Compare(sort, "title") == 0 {
return " ORDER BY " + colName + " COLLATE NATURAL_CS " + direction + additional
}
return " ORDER BY " + colName + " " + direction + additional
}
}
func getRandomSort(tableName string, direction string, seed float64) string {
// https://stackoverflow.com/a/24511461
colName := getColumn(tableName, "id")
randomSortString := strconv.FormatFloat(seed, 'f', 16, 32)
return " ORDER BY " + "(substr(" + colName + " * " + randomSortString + ", length(" + colName + ") + 2))" + " " + direction
}
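The random sort above is deterministic for a given seed: each id is multiplied by the seed and the digits past the id's length become the sort key. A rough sketch of the expected fragment, under the same models-package assumption and a seed of 0.5 (chosen because it is exactly representable as float32, so the formatted string is stable):

package models

import "testing"

// Illustrative only: the ORDER BY fragment produced for a fixed seed.
func TestGetRandomSortSketch(t *testing.T) {
	got := getRandomSort("scenes", "ASC", 0.5)
	want := " ORDER BY (substr(scenes.id * 0.5000000000000000, length(scenes.id) + 2)) ASC"
	if got != want {
		t.Errorf("got %q, want %q", got, want)
	}
}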
func getSearchBinding(columns []string, q string, not bool) (string, []interface{}) {
var likeClauses []string
var args []interface{}
notStr := ""
binaryType := " OR "
if not {
notStr = " NOT "
binaryType = " AND "
}
queryWords := strings.Split(q, " ")
trimmedQuery := strings.Trim(q, "\"")
if trimmedQuery == q {
// Search for any word
for _, word := range queryWords {
for _, column := range columns {
likeClauses = append(likeClauses, column+notStr+" LIKE ?")
args = append(args, "%"+word+"%")
}
}
} else {
// Search the exact query
for _, column := range columns {
likeClauses = append(likeClauses, column+notStr+" LIKE ?")
args = append(args, "%"+trimmedQuery+"%")
}
}
likes := strings.Join(likeClauses, binaryType)
return "(" + likes + ")", args
}
func getInBinding(length int) string {
bindings := strings.Repeat("?, ", length)
bindings = strings.TrimRight(bindings, ", ")
return "(" + bindings + ")"
}
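getSearchBinding and getInBinding are the two building blocks most of the filters above rely on: the former expands a free-text query into per-word LIKE clauses, the latter produces an IN placeholder list. A small sketch of their expected outputs, again assuming the unexported helpers are reachable from the models package; the query text is made up:

package models

import "fmt"

// Illustrative only: print the SQL fragments the two helpers build.
func sketchSearchAndInBinding() {
	clause, args := getSearchBinding([]string{"movies.name"}, "big trouble", false)
	fmt.Println(clause) // (movies.name LIKE ? OR movies.name LIKE ?)
	fmt.Println(args)   // [%big% %trouble%]

	fmt.Println(getInBinding(3)) // (?, ?, ?)
}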
func getCriterionModifierBinding(criterionModifier CriterionModifier, value interface{}) (string, int) {
var length int
switch x := value.(type) {
case []string:
length = len(x)
case []int:
length = len(x)
default:
length = 1
}
if modifier := criterionModifier.String(); criterionModifier.IsValid() {
switch modifier {
case "EQUALS", "NOT_EQUALS", "GREATER_THAN", "LESS_THAN", "IS_NULL", "NOT_NULL":
return getSimpleCriterionClause(criterionModifier, "?")
case "INCLUDES":
return "IN " + getInBinding(length), length // TODO?
case "EXCLUDES":
return "NOT IN " + getInBinding(length), length // TODO?
default:
logger.Errorf("todo")
return "= ?", 1 // TODO
}
}
return "= ?", 1 // TODO
}
func getSimpleCriterionClause(criterionModifier CriterionModifier, rhs string) (string, int) {
if modifier := criterionModifier.String(); criterionModifier.IsValid() {
switch modifier {
case "EQUALS":
return "= " + rhs, 1
case "NOT_EQUALS":
return "!= " + rhs, 1
case "GREATER_THAN":
return "> " + rhs, 1
case "LESS_THAN":
return "< " + rhs, 1
case "IS_NULL":
return "IS NULL", 0
case "NOT_NULL":
return "IS NOT NULL", 0
default:
logger.Errorf("todo")
return "= ?", 1 // TODO
}
}
return "= ?", 1 // TODO
}
func getIntCriterionWhereClause(column string, input IntCriterionInput) (string, int) {
binding, count := getCriterionModifierBinding(input.Modifier, input.Value)
return column + " " + binding, count
}
// returns where clause and having clause
func getMultiCriterionClause(primaryTable, foreignTable, joinTable, primaryFK, foreignFK string, criterion *MultiCriterionInput) (string, string) {
whereClause := ""
havingClause := ""
if criterion.Modifier == CriterionModifierIncludes {
// includes any of the provided ids
whereClause = foreignTable + ".id IN " + getInBinding(len(criterion.Value))
} else if criterion.Modifier == CriterionModifierIncludesAll {
// includes all of the provided ids
whereClause = foreignTable + ".id IN " + getInBinding(len(criterion.Value))
havingClause = "count(distinct " + foreignTable + ".id) IS " + strconv.Itoa(len(criterion.Value))
} else if criterion.Modifier == CriterionModifierExcludes {
// excludes all of the provided ids
if joinTable != "" {
whereClause = "not exists (select " + joinTable + "." + primaryFK + " from " + joinTable + " where " + joinTable + "." + primaryFK + " = " + primaryTable + ".id and " + joinTable + "." + foreignFK + " in " + getInBinding(len(criterion.Value)) + ")"
} else {
whereClause = "not exists (select s.id from " + primaryTable + " as s where s.id = " + primaryTable + ".id and s." + foreignFK + " in " + getInBinding(len(criterion.Value)) + ")"
}
}
return whereClause, havingClause
}
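To illustrate the clause pair returned above: an INCLUDES_ALL criterion needs both a WHERE restriction and a HAVING count check, while EXCLUDES is rewritten as a NOT EXISTS subquery. A sketch under the same models-package assumption; the scenes/tags join-table names are a made-up example, not taken from this commit:

package models

import "fmt"

// Illustrative only: the clause pair for an "includes all" tag filter
// with two selected tag ids.
func sketchMultiCriterionClause() {
	criterion := &MultiCriterionInput{
		Value:    []string{"3", "7"},
		Modifier: CriterionModifierIncludesAll,
	}
	where, having := getMultiCriterionClause("scenes", "tags", "scenes_tags", "scene_id", "tag_id", criterion)
	fmt.Println(where)  // tags.id IN (?, ?)
	fmt.Println(having) // count(distinct tags.id) IS 2
}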
func runIdsQuery(query string, args []interface{}) ([]int, error) {
var result []struct {
Int int `db:"id"`
}
if err := database.DB.Select(&result, query, args...); err != nil && err != sql.ErrNoRows {
return []int{}, err
}
vsm := make([]int, len(result))
for i, v := range result {
vsm[i] = v.Int
}
return vsm, nil
}
func runCountQuery(query string, args []interface{}) (int, error) {
// Perform query and fetch result
result := struct {
Int int `db:"count"`
}{0}
if err := database.DB.Get(&result, query, args...); err != nil && err != sql.ErrNoRows {
return 0, err
}
return result.Int, nil
}
func runSumQuery(query string, args []interface{}) (float64, error) {
// Perform query and fetch result
result := struct {
Float64 float64 `db:"sum"`
}{0}
if err := database.DB.Get(&result, query, args...); err != nil && err != sql.ErrNoRows {
return 0, err
}
return result.Float64, nil
}
func executeFindQuery(tableName string, body string, args []interface{}, sortAndPagination string, whereClauses []string, havingClauses []string) ([]int, int) {
if len(whereClauses) > 0 {
body = body + " WHERE " + strings.Join(whereClauses, " AND ") // TODO handle AND or OR
}
body = body + " GROUP BY " + tableName + ".id "
if len(havingClauses) > 0 {
body = body + " HAVING " + strings.Join(havingClauses, " AND ") // TODO handle AND or OR
}
countQuery := buildCountQuery(body)
idsQuery := body + sortAndPagination
// Perform query and fetch result
logger.Tracef("SQL: %s, args: %v", idsQuery, args)
countResult, countErr := runCountQuery(countQuery, args)
idsResult, idsErr := runIdsQuery(idsQuery, args)
if countErr != nil {
logger.Errorf("Error executing count query with SQL: %s, args: %v, error: %s", countQuery, args, countErr.Error())
panic(countErr)
}
if idsErr != nil {
logger.Errorf("Error executing find query with SQL: %s, args: %v, error: %s", idsQuery, args, idsErr.Error())
panic(idsErr)
}
return idsResult, countResult
}
func executeDeleteQuery(tableName string, id string, tx *sqlx.Tx) error {
if tx == nil {
panic("must use a transaction")
}
idColumnName := getColumn(tableName, "id")
_, err := tx.Exec(
`DELETE FROM `+tableName+` WHERE `+idColumnName+` = ?`,
id,
)
return err
}
func ensureTx(tx *sqlx.Tx) {
if tx == nil {
panic("must use a transaction")
}
}
// https://github.com/jmoiron/sqlx/issues/410
// sqlGenKeys is used for passing a struct and returning a string
// of keys for non-empty key:values. These keys are formatted
// keyname=:keyname with a comma separating them
func SQLGenKeys(i interface{}) string {
return sqlGenKeys(i, false)
}
// support a partial interface. When a partial interface is provided,
// keys will always be included if the value is not null. The partial
// interface must therefore consist of pointers
func SQLGenKeysPartial(i interface{}) string {
return sqlGenKeys(i, true)
}
func sqlGenKeys(i interface{}, partial bool) string {
var query []string
v := reflect.ValueOf(i)
for i := 0; i < v.NumField(); i++ {
//get key for struct tag
rawKey := v.Type().Field(i).Tag.Get("db")
key := strings.Split(rawKey, ",")[0]
if key == "id" {
continue
}
var add bool
switch t := v.Field(i).Interface().(type) {
case string:
add = partial || t != ""
case int:
add = partial || t != 0
case float64:
add = partial || t != 0
case bool:
add = true
case SQLiteTimestamp:
add = partial || !t.Timestamp.IsZero()
case NullSQLiteTimestamp:
add = partial || t.Valid
case SQLiteDate:
add = partial || t.Valid
case sql.NullString:
add = partial || t.Valid
case sql.NullBool:
add = partial || t.Valid
case sql.NullInt64:
add = partial || t.Valid
case sql.NullFloat64:
add = partial || t.Valid
default:
reflectValue := reflect.ValueOf(t)
isNil := reflectValue.IsNil()
add = !isNil
}
if add {
query = append(query, fmt.Sprintf("%s=:%s", key, key))
}
}
return strings.Join(query, ", ")
}
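The comment block above is the key contract: SQLGenKeys emits key=:key pairs only for non-empty values, while SQLGenKeysPartial includes every non-nil pointer field so that explicit zero/NULL values still reach the UPDATE. A small sketch of the difference, using hypothetical structs (not the real models) inside the models package:

package models

import (
	"database/sql"
	"fmt"
)

// Illustrative only: hypothetical row structs, not part of this commit.
type exampleRow struct {
	ID     int            `db:"id"`
	Title  sql.NullString `db:"title"`
	Rating sql.NullInt64  `db:"rating"`
}

type exampleRowPartial struct {
	ID     int             `db:"id"`
	Title  *sql.NullString `db:"title"`
	Rating *sql.NullInt64  `db:"rating"`
}

func sketchSQLGenKeys() {
	full := exampleRow{
		ID:    1,
		Title: sql.NullString{String: "example", Valid: true},
		// Rating left invalid, so the non-partial form skips it.
	}
	fmt.Println(SQLGenKeys(full)) // title=:title

	rating := sql.NullInt64{} // explicit NULL, still included because the pointer is set
	partial := exampleRowPartial{
		ID:     1,
		Rating: &rating,
		// Title is nil, so it is omitted entirely.
	}
	fmt.Println(SQLGenKeysPartial(partial)) // rating=:rating
}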
func getImage(tx *sqlx.Tx, query string, args ...interface{}) ([]byte, error) {
var rows *sqlx.Rows
var err error
if tx != nil {
rows, err = tx.Queryx(query, args...)
} else {
rows, err = database.DB.Queryx(query, args...)
}
if err != nil && err != sql.ErrNoRows {
return nil, err
}
defer rows.Close()
var ret []byte
if rows.Next() {
if err := rows.Scan(&ret); err != nil {
return nil, err
}
}
if err := rows.Err(); err != nil {
return nil, err
}
return ret, nil
}


@@ -1,307 +0,0 @@
package models
import (
"database/sql"
"fmt"
"github.com/jmoiron/sqlx"
"github.com/stashapp/stash/pkg/database"
)
type StudioQueryBuilder struct{}
func NewStudioQueryBuilder() StudioQueryBuilder {
return StudioQueryBuilder{}
}
func (qb *StudioQueryBuilder) Create(newStudio Studio, tx *sqlx.Tx) (*Studio, error) {
ensureTx(tx)
result, err := tx.NamedExec(
`INSERT INTO studios (checksum, name, url, parent_id, created_at, updated_at)
VALUES (:checksum, :name, :url, :parent_id, :created_at, :updated_at)
`,
newStudio,
)
if err != nil {
return nil, err
}
studioID, err := result.LastInsertId()
if err != nil {
return nil, err
}
if err := tx.Get(&newStudio, `SELECT * FROM studios WHERE id = ? LIMIT 1`, studioID); err != nil {
return nil, err
}
return &newStudio, nil
}
func (qb *StudioQueryBuilder) Update(updatedStudio StudioPartial, tx *sqlx.Tx) (*Studio, error) {
ensureTx(tx)
_, err := tx.NamedExec(
`UPDATE studios SET `+SQLGenKeysPartial(updatedStudio)+` WHERE studios.id = :id`,
updatedStudio,
)
if err != nil {
return nil, err
}
var ret Studio
if err := tx.Get(&ret, `SELECT * FROM studios WHERE id = ? LIMIT 1`, updatedStudio.ID); err != nil {
return nil, err
}
return &ret, nil
}
func (qb *StudioQueryBuilder) UpdateFull(updatedStudio Studio, tx *sqlx.Tx) (*Studio, error) {
ensureTx(tx)
_, err := tx.NamedExec(
`UPDATE studios SET `+SQLGenKeys(updatedStudio)+` WHERE studios.id = :id`,
updatedStudio,
)
if err != nil {
return nil, err
}
var ret Studio
if err := tx.Get(&ret, `SELECT * FROM studios WHERE id = ? LIMIT 1`, updatedStudio.ID); err != nil {
return nil, err
}
return &ret, nil
}
func (qb *StudioQueryBuilder) Destroy(id string, tx *sqlx.Tx) error {
// remove studio from scenes
_, err := tx.Exec("UPDATE scenes SET studio_id = null WHERE studio_id = ?", id)
if err != nil {
return err
}
// remove studio from scraped items
_, err = tx.Exec("UPDATE scraped_items SET studio_id = null WHERE studio_id = ?", id)
if err != nil {
return err
}
return executeDeleteQuery("studios", id, tx)
}
func (qb *StudioQueryBuilder) Find(id int, tx *sqlx.Tx) (*Studio, error) {
query := "SELECT * FROM studios WHERE id = ? LIMIT 1"
args := []interface{}{id}
return qb.queryStudio(query, args, tx)
}
func (qb *StudioQueryBuilder) FindMany(ids []int) ([]*Studio, error) {
var studios []*Studio
for _, id := range ids {
studio, err := qb.Find(id, nil)
if err != nil {
return nil, err
}
if studio == nil {
return nil, fmt.Errorf("studio with id %d not found", id)
}
studios = append(studios, studio)
}
return studios, nil
}
func (qb *StudioQueryBuilder) FindChildren(id int, tx *sqlx.Tx) ([]*Studio, error) {
query := "SELECT studios.* FROM studios WHERE studios.parent_id = ?"
args := []interface{}{id}
return qb.queryStudios(query, args, tx)
}
func (qb *StudioQueryBuilder) FindBySceneID(sceneID int) (*Studio, error) {
query := "SELECT studios.* FROM studios JOIN scenes ON studios.id = scenes.studio_id WHERE scenes.id = ? LIMIT 1"
args := []interface{}{sceneID}
return qb.queryStudio(query, args, nil)
}
func (qb *StudioQueryBuilder) FindByName(name string, tx *sqlx.Tx, nocase bool) (*Studio, error) {
query := "SELECT * FROM studios WHERE name = ?"
if nocase {
query += " COLLATE NOCASE"
}
query += " LIMIT 1"
args := []interface{}{name}
return qb.queryStudio(query, args, tx)
}
func (qb *StudioQueryBuilder) Count() (int, error) {
return runCountQuery(buildCountQuery("SELECT studios.id FROM studios"), nil)
}
func (qb *StudioQueryBuilder) All() ([]*Studio, error) {
return qb.queryStudios(selectAll("studios")+qb.getStudioSort(nil), nil, nil)
}
func (qb *StudioQueryBuilder) AllSlim() ([]*Studio, error) {
return qb.queryStudios("SELECT studios.id, studios.name, studios.parent_id FROM studios "+qb.getStudioSort(nil), nil, nil)
}
func (qb *StudioQueryBuilder) Query(studioFilter *StudioFilterType, findFilter *FindFilterType) ([]*Studio, int) {
if studioFilter == nil {
studioFilter = &StudioFilterType{}
}
if findFilter == nil {
findFilter = &FindFilterType{}
}
var whereClauses []string
var havingClauses []string
var args []interface{}
body := selectDistinctIDs("studios")
body += `
left join scenes on studios.id = scenes.studio_id
left join studio_stash_ids on studio_stash_ids.studio_id = studios.id
`
if q := findFilter.Q; q != nil && *q != "" {
searchColumns := []string{"studios.name"}
clause, thisArgs := getSearchBinding(searchColumns, *q, false)
whereClauses = append(whereClauses, clause)
args = append(args, thisArgs...)
}
if parentsFilter := studioFilter.Parents; parentsFilter != nil && len(parentsFilter.Value) > 0 {
body += `
left join studios as parent_studio on parent_studio.id = studios.parent_id
`
for _, studioID := range parentsFilter.Value {
args = append(args, studioID)
}
whereClause, havingClause := getMultiCriterionClause("studios", "parent_studio", "", "", "parent_id", parentsFilter)
whereClauses = appendClause(whereClauses, whereClause)
havingClauses = appendClause(havingClauses, havingClause)
}
if stashIDFilter := studioFilter.StashID; stashIDFilter != nil {
whereClauses = append(whereClauses, "studio_stash_ids.stash_id = ?")
args = append(args, stashIDFilter)
}
if isMissingFilter := studioFilter.IsMissing; isMissingFilter != nil && *isMissingFilter != "" {
switch *isMissingFilter {
case "image":
body += `left join studios_image on studios_image.studio_id = studios.id
`
whereClauses = appendClause(whereClauses, "studios_image.studio_id IS NULL")
case "stash_id":
whereClauses = appendClause(whereClauses, "studio_stash_ids.studio_id IS NULL")
default:
whereClauses = appendClause(whereClauses, "studios."+*isMissingFilter+" IS NULL")
}
}
sortAndPagination := qb.getStudioSort(findFilter) + getPagination(findFilter)
idsResult, countResult := executeFindQuery("studios", body, args, sortAndPagination, whereClauses, havingClauses)
var studios []*Studio
for _, id := range idsResult {
studio, _ := qb.Find(id, nil)
studios = append(studios, studio)
}
return studios, countResult
}
func (qb *StudioQueryBuilder) getStudioSort(findFilter *FindFilterType) string {
var sort string
var direction string
if findFilter == nil {
sort = "name"
direction = "ASC"
} else {
sort = findFilter.GetSort("name")
direction = findFilter.GetDirection()
}
return getSort(sort, direction, "studios")
}
func (qb *StudioQueryBuilder) queryStudio(query string, args []interface{}, tx *sqlx.Tx) (*Studio, error) {
results, err := qb.queryStudios(query, args, tx)
if err != nil || len(results) < 1 {
return nil, err
}
return results[0], nil
}
func (qb *StudioQueryBuilder) queryStudios(query string, args []interface{}, tx *sqlx.Tx) ([]*Studio, error) {
var rows *sqlx.Rows
var err error
if tx != nil {
rows, err = tx.Queryx(query, args...)
} else {
rows, err = database.DB.Queryx(query, args...)
}
if err != nil && err != sql.ErrNoRows {
return nil, err
}
defer rows.Close()
studios := make([]*Studio, 0)
for rows.Next() {
studio := Studio{}
if err := rows.StructScan(&studio); err != nil {
return nil, err
}
studios = append(studios, &studio)
}
if err := rows.Err(); err != nil {
return nil, err
}
return studios, nil
}
func (qb *StudioQueryBuilder) UpdateStudioImage(studioID int, image []byte, tx *sqlx.Tx) error {
ensureTx(tx)
// Delete the existing cover and then create new
if err := qb.DestroyStudioImage(studioID, tx); err != nil {
return err
}
_, err := tx.Exec(
`INSERT INTO studios_image (studio_id, image) VALUES (?, ?)`,
studioID,
image,
)
return err
}
func (qb *StudioQueryBuilder) DestroyStudioImage(studioID int, tx *sqlx.Tx) error {
ensureTx(tx)
// Delete the existing joins
_, err := tx.Exec("DELETE FROM studios_image WHERE studio_id = ?", studioID)
if err != nil {
return err
}
return err
}
func (qb *StudioQueryBuilder) GetStudioImage(studioID int, tx *sqlx.Tx) ([]byte, error) {
query := `SELECT image from studios_image WHERE studio_id = ?`
return getImage(tx, query, studioID)
}
func (qb *StudioQueryBuilder) HasStudioImage(studioID int) (bool, error) {
ret, err := runCountQuery(buildCountQuery("SELECT studio_id from studios_image WHERE studio_id = ?"), []interface{}{studioID})
if err != nil {
return false, err
}
return ret == 1, nil
}

Some files were not shown because too many files have changed in this diff.