diff --git a/internal/api/changeset_translator.go b/internal/api/changeset_translator.go index 3dfb4a6a1..741885680 100644 --- a/internal/api/changeset_translator.go +++ b/internal/api/changeset_translator.go @@ -185,21 +185,6 @@ func (t changesetTranslator) optionalIntFromString(value *string, field string) return models.NewOptionalInt(vv), nil } -func (t changesetTranslator) nullBool(value *bool, field string) *sql.NullBool { - if !t.hasField(field) { - return nil - } - - ret := &sql.NullBool{} - - if value != nil { - ret.Bool = *value - ret.Valid = true - } - - return ret -} - func (t changesetTranslator) optionalBool(value *bool, field string) models.OptionalBool { if !t.hasField(field) { return models.OptionalBool{} diff --git a/internal/api/images.go b/internal/api/images.go index 19156436f..ddcaee629 100644 --- a/internal/api/images.go +++ b/internal/api/images.go @@ -10,6 +10,7 @@ import ( "github.com/stashapp/stash/internal/static" "github.com/stashapp/stash/pkg/hash" "github.com/stashapp/stash/pkg/logger" + "github.com/stashapp/stash/pkg/models" ) type imageBox struct { @@ -86,7 +87,7 @@ func initialiseCustomImages() { } } -func getRandomPerformerImageUsingName(name, gender, customPath string) ([]byte, error) { +func getRandomPerformerImageUsingName(name string, gender models.GenderEnum, customPath string) ([]byte, error) { var box *imageBox // If we have a custom path, we should return a new box in the given path. @@ -95,10 +96,10 @@ func getRandomPerformerImageUsingName(name, gender, customPath string) ([]byte, } if box == nil { - switch strings.ToUpper(gender) { - case "FEMALE": + switch gender { + case models.GenderEnumFemale: box = performerBox - case "MALE": + case models.GenderEnumMale: box = performerBoxMale default: box = performerBox diff --git a/internal/api/resolver_model_performer.go b/internal/api/resolver_model_performer.go index 9e66fb38d..ba7226d10 100644 --- a/internal/api/resolver_model_performer.go +++ b/internal/api/resolver_model_performer.go @@ -2,7 +2,6 @@ package api import ( "context" - "time" "github.com/stashapp/stash/internal/api/urlbuilders" "github.com/stashapp/stash/pkg/gallery" @@ -10,131 +9,14 @@ import ( "github.com/stashapp/stash/pkg/models" ) -func (r *performerResolver) Name(ctx context.Context, obj *models.Performer) (*string, error) { - if obj.Name.Valid { - return &obj.Name.String, nil - } - return nil, nil -} - -func (r *performerResolver) URL(ctx context.Context, obj *models.Performer) (*string, error) { - if obj.URL.Valid { - return &obj.URL.String, nil - } - return nil, nil -} - -func (r *performerResolver) Gender(ctx context.Context, obj *models.Performer) (*models.GenderEnum, error) { - var ret models.GenderEnum - - if obj.Gender.Valid { - ret = models.GenderEnum(obj.Gender.String) - if ret.IsValid() { - return &ret, nil - } - } - - return nil, nil -} - -func (r *performerResolver) Twitter(ctx context.Context, obj *models.Performer) (*string, error) { - if obj.Twitter.Valid { - return &obj.Twitter.String, nil - } - return nil, nil -} - -func (r *performerResolver) Instagram(ctx context.Context, obj *models.Performer) (*string, error) { - if obj.Instagram.Valid { - return &obj.Instagram.String, nil - } - return nil, nil -} - func (r *performerResolver) Birthdate(ctx context.Context, obj *models.Performer) (*string, error) { - if obj.Birthdate.Valid { - return &obj.Birthdate.String, nil + if obj.Birthdate != nil { + ret := obj.Birthdate.String() + return &ret, nil } return nil, nil } -func (r *performerResolver) Ethnicity(ctx context.Context, obj *models.Performer) (*string, error) { - if obj.Ethnicity.Valid { - return &obj.Ethnicity.String, nil - } - return nil, nil -} - -func (r *performerResolver) Country(ctx context.Context, obj *models.Performer) (*string, error) { - if obj.Country.Valid { - return &obj.Country.String, nil - } - return nil, nil -} - -func (r *performerResolver) EyeColor(ctx context.Context, obj *models.Performer) (*string, error) { - if obj.EyeColor.Valid { - return &obj.EyeColor.String, nil - } - return nil, nil -} - -func (r *performerResolver) Height(ctx context.Context, obj *models.Performer) (*string, error) { - if obj.Height.Valid { - return &obj.Height.String, nil - } - return nil, nil -} - -func (r *performerResolver) Measurements(ctx context.Context, obj *models.Performer) (*string, error) { - if obj.Measurements.Valid { - return &obj.Measurements.String, nil - } - return nil, nil -} - -func (r *performerResolver) FakeTits(ctx context.Context, obj *models.Performer) (*string, error) { - if obj.FakeTits.Valid { - return &obj.FakeTits.String, nil - } - return nil, nil -} - -func (r *performerResolver) CareerLength(ctx context.Context, obj *models.Performer) (*string, error) { - if obj.CareerLength.Valid { - return &obj.CareerLength.String, nil - } - return nil, nil -} - -func (r *performerResolver) Tattoos(ctx context.Context, obj *models.Performer) (*string, error) { - if obj.Tattoos.Valid { - return &obj.Tattoos.String, nil - } - return nil, nil -} - -func (r *performerResolver) Piercings(ctx context.Context, obj *models.Performer) (*string, error) { - if obj.Piercings.Valid { - return &obj.Piercings.String, nil - } - return nil, nil -} - -func (r *performerResolver) Aliases(ctx context.Context, obj *models.Performer) (*string, error) { - if obj.Aliases.Valid { - return &obj.Aliases.String, nil - } - return nil, nil -} - -func (r *performerResolver) Favorite(ctx context.Context, obj *models.Performer) (bool, error) { - if obj.Favorite.Valid { - return obj.Favorite.Bool, nil - } - return false, nil -} - func (r *performerResolver) ImagePath(ctx context.Context, obj *models.Performer) (*string, error) { baseURL, _ := ctx.Value(BaseURLCtxKey).(string) imagePath := urlbuilders.NewPerformerURLBuilder(baseURL, obj).GetPerformerImageURL() @@ -212,51 +94,14 @@ func (r *performerResolver) StashIds(ctx context.Context, obj *models.Performer) return stashIDsSliceToPtrSlice(ret), nil } -func (r *performerResolver) Rating(ctx context.Context, obj *models.Performer) (*int, error) { - if obj.Rating.Valid { - rating := int(obj.Rating.Int64) - return &rating, nil - } - return nil, nil -} - -func (r *performerResolver) Details(ctx context.Context, obj *models.Performer) (*string, error) { - if obj.Details.Valid { - return &obj.Details.String, nil - } - return nil, nil -} - func (r *performerResolver) DeathDate(ctx context.Context, obj *models.Performer) (*string, error) { - if obj.DeathDate.Valid { - return &obj.DeathDate.String, nil + if obj.DeathDate != nil { + ret := obj.DeathDate.String() + return &ret, nil } return nil, nil } -func (r *performerResolver) HairColor(ctx context.Context, obj *models.Performer) (*string, error) { - if obj.HairColor.Valid { - return &obj.HairColor.String, nil - } - return nil, nil -} - -func (r *performerResolver) Weight(ctx context.Context, obj *models.Performer) (*int, error) { - if obj.Weight.Valid { - weight := int(obj.Weight.Int64) - return &weight, nil - } - return nil, nil -} - -func (r *performerResolver) CreatedAt(ctx context.Context, obj *models.Performer) (*time.Time, error) { - return &obj.CreatedAt.Timestamp, nil -} - -func (r *performerResolver) UpdatedAt(ctx context.Context, obj *models.Performer) (*time.Time, error) { - return &obj.UpdatedAt.Timestamp, nil -} - func (r *performerResolver) Movies(ctx context.Context, obj *models.Performer) (ret []*models.Movie, err error) { if err := r.withTxn(ctx, func(ctx context.Context) error { ret, err = r.repository.Movie.FindByPerformerID(ctx, obj.ID) diff --git a/internal/api/resolver_mutation_performer.go b/internal/api/resolver_mutation_performer.go index b0ab18852..7ce3f624f 100644 --- a/internal/api/resolver_mutation_performer.go +++ b/internal/api/resolver_mutation_performer.go @@ -2,7 +2,6 @@ package api import ( "context" - "database/sql" "fmt" "strconv" "time" @@ -54,78 +53,75 @@ func (r *mutationResolver) PerformerCreate(ctx context.Context, input PerformerC // Populate a new performer from the input currentTime := time.Now() newPerformer := models.Performer{ + Name: input.Name, Checksum: checksum, - CreatedAt: models.SQLiteTimestamp{Timestamp: currentTime}, - UpdatedAt: models.SQLiteTimestamp{Timestamp: currentTime}, + CreatedAt: currentTime, + UpdatedAt: currentTime, } - newPerformer.Name = sql.NullString{String: input.Name, Valid: true} if input.URL != nil { - newPerformer.URL = sql.NullString{String: *input.URL, Valid: true} + newPerformer.URL = *input.URL } if input.Gender != nil { - newPerformer.Gender = sql.NullString{String: input.Gender.String(), Valid: true} + newPerformer.Gender = *input.Gender } if input.Birthdate != nil { - newPerformer.Birthdate = models.SQLiteDate{String: *input.Birthdate, Valid: true} + d := models.NewDate(*input.Birthdate) + newPerformer.Birthdate = &d } if input.Ethnicity != nil { - newPerformer.Ethnicity = sql.NullString{String: *input.Ethnicity, Valid: true} + newPerformer.Ethnicity = *input.Ethnicity } if input.Country != nil { - newPerformer.Country = sql.NullString{String: *input.Country, Valid: true} + newPerformer.Country = *input.Country } if input.EyeColor != nil { - newPerformer.EyeColor = sql.NullString{String: *input.EyeColor, Valid: true} + newPerformer.EyeColor = *input.EyeColor } if input.Height != nil { - newPerformer.Height = sql.NullString{String: *input.Height, Valid: true} + newPerformer.Height = *input.Height } if input.Measurements != nil { - newPerformer.Measurements = sql.NullString{String: *input.Measurements, Valid: true} + newPerformer.Measurements = *input.Measurements } if input.FakeTits != nil { - newPerformer.FakeTits = sql.NullString{String: *input.FakeTits, Valid: true} + newPerformer.FakeTits = *input.FakeTits } if input.CareerLength != nil { - newPerformer.CareerLength = sql.NullString{String: *input.CareerLength, Valid: true} + newPerformer.CareerLength = *input.CareerLength } if input.Tattoos != nil { - newPerformer.Tattoos = sql.NullString{String: *input.Tattoos, Valid: true} + newPerformer.Tattoos = *input.Tattoos } if input.Piercings != nil { - newPerformer.Piercings = sql.NullString{String: *input.Piercings, Valid: true} + newPerformer.Piercings = *input.Piercings } if input.Aliases != nil { - newPerformer.Aliases = sql.NullString{String: *input.Aliases, Valid: true} + newPerformer.Aliases = *input.Aliases } if input.Twitter != nil { - newPerformer.Twitter = sql.NullString{String: *input.Twitter, Valid: true} + newPerformer.Twitter = *input.Twitter } if input.Instagram != nil { - newPerformer.Instagram = sql.NullString{String: *input.Instagram, Valid: true} + newPerformer.Instagram = *input.Instagram } if input.Favorite != nil { - newPerformer.Favorite = sql.NullBool{Bool: *input.Favorite, Valid: true} - } else { - newPerformer.Favorite = sql.NullBool{Bool: false, Valid: true} + newPerformer.Favorite = *input.Favorite } if input.Rating != nil { - newPerformer.Rating = sql.NullInt64{Int64: int64(*input.Rating), Valid: true} - } else { - newPerformer.Rating = sql.NullInt64{Valid: false} + newPerformer.Rating = input.Rating } if input.Details != nil { - newPerformer.Details = sql.NullString{String: *input.Details, Valid: true} + newPerformer.Details = *input.Details } if input.DeathDate != nil { - newPerformer.DeathDate = models.SQLiteDate{String: *input.DeathDate, Valid: true} + d := models.NewDate(*input.DeathDate) + newPerformer.DeathDate = &d } if input.HairColor != nil { - newPerformer.HairColor = sql.NullString{String: *input.HairColor, Valid: true} + newPerformer.HairColor = *input.HairColor } if input.Weight != nil { - weight := int64(*input.Weight) - newPerformer.Weight = sql.NullInt64{Int64: weight, Valid: true} + newPerformer.Weight = input.Weight } if input.IgnoreAutoTag != nil { newPerformer.IgnoreAutoTag = *input.IgnoreAutoTag @@ -138,24 +134,23 @@ func (r *mutationResolver) PerformerCreate(ctx context.Context, input PerformerC } // Start the transaction and save the performer - var performer *models.Performer if err := r.withTxn(ctx, func(ctx context.Context) error { qb := r.repository.Performer - performer, err = qb.Create(ctx, newPerformer) + err = qb.Create(ctx, &newPerformer) if err != nil { return err } if len(input.TagIds) > 0 { - if err := r.updatePerformerTags(ctx, performer.ID, input.TagIds); err != nil { + if err := r.updatePerformerTags(ctx, newPerformer.ID, input.TagIds); err != nil { return err } } // update image table if len(imageData) > 0 { - if err := qb.UpdateImage(ctx, performer.ID, imageData); err != nil { + if err := qb.UpdateImage(ctx, newPerformer.ID, imageData); err != nil { return err } } @@ -163,7 +158,7 @@ func (r *mutationResolver) PerformerCreate(ctx context.Context, input PerformerC // Save the stash_ids if input.StashIds != nil { stashIDJoins := stashIDPtrSliceToSlice(input.StashIds) - if err := qb.UpdateStashIDs(ctx, performer.ID, stashIDJoins); err != nil { + if err := qb.UpdateStashIDs(ctx, newPerformer.ID, stashIDJoins); err != nil { return err } } @@ -173,17 +168,14 @@ func (r *mutationResolver) PerformerCreate(ctx context.Context, input PerformerC return nil, err } - r.hookExecutor.ExecutePostHooks(ctx, performer.ID, plugin.PerformerCreatePost, input, nil) - return r.getPerformer(ctx, performer.ID) + r.hookExecutor.ExecutePostHooks(ctx, newPerformer.ID, plugin.PerformerCreatePost, input, nil) + return r.getPerformer(ctx, newPerformer.ID) } func (r *mutationResolver) PerformerUpdate(ctx context.Context, input PerformerUpdateInput) (*models.Performer, error) { // Populate performer from the input performerID, _ := strconv.Atoi(input.ID) - updatedPerformer := models.PerformerPartial{ - ID: performerID, - UpdatedAt: &models.SQLiteTimestamp{Timestamp: time.Now()}, - } + updatedPerformer := models.NewPerformerPartial() translator := changesetTranslator{ inputMap: getUpdateInputMap(ctx), @@ -203,54 +195,53 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input PerformerU // generate checksum from performer name rather than image checksum := md5.FromString(*input.Name) - updatedPerformer.Name = &sql.NullString{String: *input.Name, Valid: true} - updatedPerformer.Checksum = &checksum + updatedPerformer.Name = models.NewOptionalString(*input.Name) + updatedPerformer.Checksum = models.NewOptionalString(checksum) } - updatedPerformer.URL = translator.nullString(input.URL, "url") + updatedPerformer.URL = translator.optionalString(input.URL, "url") if translator.hasField("gender") { if input.Gender != nil { - updatedPerformer.Gender = &sql.NullString{String: input.Gender.String(), Valid: true} + updatedPerformer.Gender = models.NewOptionalString(input.Gender.String()) } else { - updatedPerformer.Gender = &sql.NullString{String: "", Valid: false} + updatedPerformer.Gender = models.NewOptionalStringPtr(nil) } } - updatedPerformer.Birthdate = translator.sqliteDate(input.Birthdate, "birthdate") - updatedPerformer.Country = translator.nullString(input.Country, "country") - updatedPerformer.EyeColor = translator.nullString(input.EyeColor, "eye_color") - updatedPerformer.Measurements = translator.nullString(input.Measurements, "measurements") - updatedPerformer.Height = translator.nullString(input.Height, "height") - updatedPerformer.Ethnicity = translator.nullString(input.Ethnicity, "ethnicity") - updatedPerformer.FakeTits = translator.nullString(input.FakeTits, "fake_tits") - updatedPerformer.CareerLength = translator.nullString(input.CareerLength, "career_length") - updatedPerformer.Tattoos = translator.nullString(input.Tattoos, "tattoos") - updatedPerformer.Piercings = translator.nullString(input.Piercings, "piercings") - updatedPerformer.Aliases = translator.nullString(input.Aliases, "aliases") - updatedPerformer.Twitter = translator.nullString(input.Twitter, "twitter") - updatedPerformer.Instagram = translator.nullString(input.Instagram, "instagram") - updatedPerformer.Favorite = translator.nullBool(input.Favorite, "favorite") - updatedPerformer.Rating = translator.nullInt64(input.Rating, "rating") - updatedPerformer.Details = translator.nullString(input.Details, "details") - updatedPerformer.DeathDate = translator.sqliteDate(input.DeathDate, "death_date") - updatedPerformer.HairColor = translator.nullString(input.HairColor, "hair_color") - updatedPerformer.Weight = translator.nullInt64(input.Weight, "weight") - updatedPerformer.IgnoreAutoTag = input.IgnoreAutoTag + updatedPerformer.Birthdate = translator.optionalDate(input.Birthdate, "birthdate") + updatedPerformer.Country = translator.optionalString(input.Country, "country") + updatedPerformer.EyeColor = translator.optionalString(input.EyeColor, "eye_color") + updatedPerformer.Measurements = translator.optionalString(input.Measurements, "measurements") + updatedPerformer.Height = translator.optionalString(input.Height, "height") + updatedPerformer.Ethnicity = translator.optionalString(input.Ethnicity, "ethnicity") + updatedPerformer.FakeTits = translator.optionalString(input.FakeTits, "fake_tits") + updatedPerformer.CareerLength = translator.optionalString(input.CareerLength, "career_length") + updatedPerformer.Tattoos = translator.optionalString(input.Tattoos, "tattoos") + updatedPerformer.Piercings = translator.optionalString(input.Piercings, "piercings") + updatedPerformer.Aliases = translator.optionalString(input.Aliases, "aliases") + updatedPerformer.Twitter = translator.optionalString(input.Twitter, "twitter") + updatedPerformer.Instagram = translator.optionalString(input.Instagram, "instagram") + updatedPerformer.Favorite = translator.optionalBool(input.Favorite, "favorite") + updatedPerformer.Rating = translator.optionalInt(input.Rating, "rating") + updatedPerformer.Details = translator.optionalString(input.Details, "details") + updatedPerformer.DeathDate = translator.optionalDate(input.DeathDate, "death_date") + updatedPerformer.HairColor = translator.optionalString(input.HairColor, "hair_color") + updatedPerformer.Weight = translator.optionalInt(input.Weight, "weight") + updatedPerformer.IgnoreAutoTag = translator.optionalBool(input.IgnoreAutoTag, "ignore_auto_tag") // Start the transaction and save the p - var p *models.Performer if err := r.withTxn(ctx, func(ctx context.Context) error { qb := r.repository.Performer // need to get existing performer - existing, err := qb.Find(ctx, updatedPerformer.ID) + existing, err := qb.Find(ctx, performerID) if err != nil { return err } if existing == nil { - return fmt.Errorf("performer with id %d not found", updatedPerformer.ID) + return fmt.Errorf("performer with id %d not found", performerID) } if err := performer.ValidateDeathDate(existing, input.Birthdate, input.DeathDate); err != nil { @@ -259,26 +250,26 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input PerformerU } } - p, err = qb.Update(ctx, updatedPerformer) + _, err = qb.UpdatePartial(ctx, performerID, updatedPerformer) if err != nil { return err } // Save the tags if translator.hasField("tag_ids") { - if err := r.updatePerformerTags(ctx, p.ID, input.TagIds); err != nil { + if err := r.updatePerformerTags(ctx, performerID, input.TagIds); err != nil { return err } } // update image table if len(imageData) > 0 { - if err := qb.UpdateImage(ctx, p.ID, imageData); err != nil { + if err := qb.UpdateImage(ctx, performerID, imageData); err != nil { return err } } else if imageIncluded { // must be unsetting - if err := qb.DestroyImage(ctx, p.ID); err != nil { + if err := qb.DestroyImage(ctx, performerID); err != nil { return err } } @@ -296,8 +287,8 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input PerformerU return nil, err } - r.hookExecutor.ExecutePostHooks(ctx, p.ID, plugin.PerformerUpdatePost, input, translator.getFields()) - return r.getPerformer(ctx, p.ID) + r.hookExecutor.ExecutePostHooks(ctx, performerID, plugin.PerformerUpdatePost, input, translator.getFields()) + return r.getPerformer(ctx, performerID) } func (r *mutationResolver) updatePerformerTags(ctx context.Context, performerID int, tagsIDs []string) error { @@ -315,43 +306,39 @@ func (r *mutationResolver) BulkPerformerUpdate(ctx context.Context, input BulkPe } // Populate performer from the input - updatedTime := time.Now() - translator := changesetTranslator{ inputMap: getUpdateInputMap(ctx), } - updatedPerformer := models.PerformerPartial{ - UpdatedAt: &models.SQLiteTimestamp{Timestamp: updatedTime}, - } + updatedPerformer := models.NewPerformerPartial() - updatedPerformer.URL = translator.nullString(input.URL, "url") - updatedPerformer.Birthdate = translator.sqliteDate(input.Birthdate, "birthdate") - updatedPerformer.Ethnicity = translator.nullString(input.Ethnicity, "ethnicity") - updatedPerformer.Country = translator.nullString(input.Country, "country") - updatedPerformer.EyeColor = translator.nullString(input.EyeColor, "eye_color") - updatedPerformer.Height = translator.nullString(input.Height, "height") - updatedPerformer.Measurements = translator.nullString(input.Measurements, "measurements") - updatedPerformer.FakeTits = translator.nullString(input.FakeTits, "fake_tits") - updatedPerformer.CareerLength = translator.nullString(input.CareerLength, "career_length") - updatedPerformer.Tattoos = translator.nullString(input.Tattoos, "tattoos") - updatedPerformer.Piercings = translator.nullString(input.Piercings, "piercings") - updatedPerformer.Aliases = translator.nullString(input.Aliases, "aliases") - updatedPerformer.Twitter = translator.nullString(input.Twitter, "twitter") - updatedPerformer.Instagram = translator.nullString(input.Instagram, "instagram") - updatedPerformer.Favorite = translator.nullBool(input.Favorite, "favorite") - updatedPerformer.Rating = translator.nullInt64(input.Rating, "rating") - updatedPerformer.Details = translator.nullString(input.Details, "details") - updatedPerformer.DeathDate = translator.sqliteDate(input.DeathDate, "death_date") - updatedPerformer.HairColor = translator.nullString(input.HairColor, "hair_color") - updatedPerformer.Weight = translator.nullInt64(input.Weight, "weight") - updatedPerformer.IgnoreAutoTag = input.IgnoreAutoTag + updatedPerformer.URL = translator.optionalString(input.URL, "url") + updatedPerformer.Birthdate = translator.optionalDate(input.Birthdate, "birthdate") + updatedPerformer.Ethnicity = translator.optionalString(input.Ethnicity, "ethnicity") + updatedPerformer.Country = translator.optionalString(input.Country, "country") + updatedPerformer.EyeColor = translator.optionalString(input.EyeColor, "eye_color") + updatedPerformer.Height = translator.optionalString(input.Height, "height") + updatedPerformer.Measurements = translator.optionalString(input.Measurements, "measurements") + updatedPerformer.FakeTits = translator.optionalString(input.FakeTits, "fake_tits") + updatedPerformer.CareerLength = translator.optionalString(input.CareerLength, "career_length") + updatedPerformer.Tattoos = translator.optionalString(input.Tattoos, "tattoos") + updatedPerformer.Piercings = translator.optionalString(input.Piercings, "piercings") + updatedPerformer.Aliases = translator.optionalString(input.Aliases, "aliases") + updatedPerformer.Twitter = translator.optionalString(input.Twitter, "twitter") + updatedPerformer.Instagram = translator.optionalString(input.Instagram, "instagram") + updatedPerformer.Favorite = translator.optionalBool(input.Favorite, "favorite") + updatedPerformer.Rating = translator.optionalInt(input.Rating, "rating") + updatedPerformer.Details = translator.optionalString(input.Details, "details") + updatedPerformer.DeathDate = translator.optionalDate(input.DeathDate, "death_date") + updatedPerformer.HairColor = translator.optionalString(input.HairColor, "hair_color") + updatedPerformer.Weight = translator.optionalInt(input.Weight, "weight") + updatedPerformer.IgnoreAutoTag = translator.optionalBool(input.IgnoreAutoTag, "ignore_auto_tag") if translator.hasField("gender") { if input.Gender != nil { - updatedPerformer.Gender = &sql.NullString{String: input.Gender.String(), Valid: true} + updatedPerformer.Gender = models.NewOptionalString(input.Gender.String()) } else { - updatedPerformer.Gender = &sql.NullString{String: "", Valid: false} + updatedPerformer.Gender = models.NewOptionalStringPtr(nil) } } @@ -378,7 +365,7 @@ func (r *mutationResolver) BulkPerformerUpdate(ctx context.Context, input BulkPe return err } - performer, err := qb.Update(ctx, updatedPerformer) + performer, err := qb.UpdatePartial(ctx, performerID, updatedPerformer) if err != nil { return err } diff --git a/internal/api/routes_performer.go b/internal/api/routes_performer.go index 40e41833c..dcc338103 100644 --- a/internal/api/routes_performer.go +++ b/internal/api/routes_performer.go @@ -54,7 +54,7 @@ func (rs performerRoutes) Image(w http.ResponseWriter, r *http.Request) { } if len(image) == 0 || defaultParam == "true" { - image, _ = getRandomPerformerImageUsingName(performer.Name.String, performer.Gender.String, config.GetInstance().GetCustomPerformerImageLocation()) + image, _ = getRandomPerformerImageUsingName(performer.Name, performer.Gender, config.GetInstance().GetCustomPerformerImageLocation()) } if err := utils.ServeImage(image, w, r); err != nil { diff --git a/internal/api/urlbuilders/performer.go b/internal/api/urlbuilders/performer.go index e7e0b2626..841078e2c 100644 --- a/internal/api/urlbuilders/performer.go +++ b/internal/api/urlbuilders/performer.go @@ -1,8 +1,9 @@ package urlbuilders import ( - "github.com/stashapp/stash/pkg/models" "strconv" + + "github.com/stashapp/stash/pkg/models" ) type PerformerURLBuilder struct { @@ -15,7 +16,7 @@ func NewPerformerURLBuilder(baseURL string, performer *models.Performer) Perform return PerformerURLBuilder{ BaseURL: baseURL, PerformerID: strconv.Itoa(performer.ID), - UpdatedAt: strconv.FormatInt(performer.UpdatedAt.Timestamp.Unix(), 10), + UpdatedAt: strconv.FormatInt(performer.UpdatedAt.Unix(), 10), } } diff --git a/internal/autotag/gallery_test.go b/internal/autotag/gallery_test.go index ac7da4e26..cae45e6c9 100644 --- a/internal/autotag/gallery_test.go +++ b/internal/autotag/gallery_test.go @@ -22,14 +22,14 @@ func TestGalleryPerformers(t *testing.T) { const performerID = 2 performer := models.Performer{ ID: performerID, - Name: models.NullString(performerName), + Name: performerName, } const reversedPerformerName = "name performer" const reversedPerformerID = 3 reversedPerformer := models.Performer{ ID: reversedPerformerID, - Name: models.NullString(reversedPerformerName), + Name: reversedPerformerName, } testTables := generateTestTable(performerName, galleryExt) diff --git a/internal/autotag/image_test.go b/internal/autotag/image_test.go index 653cb2c2d..b98842398 100644 --- a/internal/autotag/image_test.go +++ b/internal/autotag/image_test.go @@ -19,14 +19,14 @@ func TestImagePerformers(t *testing.T) { const performerID = 2 performer := models.Performer{ ID: performerID, - Name: models.NullString(performerName), + Name: performerName, } const reversedPerformerName = "name performer" const reversedPerformerID = 3 reversedPerformer := models.Performer{ ID: reversedPerformerID, - Name: models.NullString(reversedPerformerName), + Name: reversedPerformerName, } testTables := generateTestTable(performerName, imageExt) diff --git a/internal/autotag/integration_test.go b/internal/autotag/integration_test.go index 7c5952652..ddf9adc95 100644 --- a/internal/autotag/integration_test.go +++ b/internal/autotag/integration_test.go @@ -87,11 +87,10 @@ func createPerformer(ctx context.Context, pqb models.PerformerWriter) error { // create the performer performer := models.Performer{ Checksum: testName, - Name: sql.NullString{Valid: true, String: testName}, - Favorite: sql.NullBool{Valid: true, Bool: false}, + Name: testName, } - _, err := pqb.Create(ctx, performer) + err := pqb.Create(ctx, &performer) if err != nil { return err } diff --git a/internal/autotag/performer.go b/internal/autotag/performer.go index f240dc0c5..b879fe341 100644 --- a/internal/autotag/performer.go +++ b/internal/autotag/performer.go @@ -33,7 +33,7 @@ func getPerformerTagger(p *models.Performer, cache *match.Cache) tagger { return tagger{ ID: p.ID, Type: "performer", - Name: p.Name.String, + Name: p.Name, cache: cache, } } diff --git a/internal/autotag/performer_test.go b/internal/autotag/performer_test.go index 71161cbfe..422e40645 100644 --- a/internal/autotag/performer_test.go +++ b/internal/autotag/performer_test.go @@ -60,7 +60,7 @@ func testPerformerScenes(t *testing.T, performerName, expectedRegex string) { performer := models.Performer{ ID: performerID, - Name: models.NullString(performerName), + Name: performerName, } organized := false @@ -140,7 +140,7 @@ func testPerformerImages(t *testing.T, performerName, expectedRegex string) { performer := models.Performer{ ID: performerID, - Name: models.NullString(performerName), + Name: performerName, } organized := false @@ -221,7 +221,7 @@ func testPerformerGalleries(t *testing.T, performerName, expectedRegex string) { performer := models.Performer{ ID: performerID, - Name: models.NullString(performerName), + Name: performerName, } organized := false diff --git a/internal/autotag/scene_test.go b/internal/autotag/scene_test.go index 1e9766836..cb0ff32db 100644 --- a/internal/autotag/scene_test.go +++ b/internal/autotag/scene_test.go @@ -152,14 +152,14 @@ func TestScenePerformers(t *testing.T) { const performerID = 2 performer := models.Performer{ ID: performerID, - Name: models.NullString(performerName), + Name: performerName, } const reversedPerformerName = "name performer" const reversedPerformerID = 3 reversedPerformer := models.Performer{ ID: reversedPerformerID, - Name: models.NullString(reversedPerformerName), + Name: reversedPerformerName, } testTables := generateTestTable(performerName, sceneExt) diff --git a/internal/autotag/tagger.go b/internal/autotag/tagger.go index 0e53200ec..4e3437da7 100644 --- a/internal/autotag/tagger.go +++ b/internal/autotag/tagger.go @@ -58,11 +58,11 @@ func (t *tagger) tagPerformers(ctx context.Context, performerReader match.Perfor added, err := addFunc(t.ID, p.ID) if err != nil { - return t.addError("performer", p.Name.String, err) + return t.addError("performer", p.Name, err) } if added { - t.addLog("performer", p.Name.String) + t.addLog("performer", p.Name) } } diff --git a/internal/dlna/cds.go b/internal/dlna/cds.go index eedeef082..948a808d6 100644 --- a/internal/dlna/cds.go +++ b/internal/dlna/cds.go @@ -612,7 +612,7 @@ func (me *contentDirectoryService) getPerformers() []interface{} { } for _, s := range performers { - objs = append(objs, makeStorageFolder("performers/"+strconv.Itoa(s.ID), s.Name.String, "performers")) + objs = append(objs, makeStorageFolder("performers/"+strconv.Itoa(s.ID), s.Name, "performers")) } return nil diff --git a/internal/identify/performer.go b/internal/identify/performer.go index 435524cc4..97621502f 100644 --- a/internal/identify/performer.go +++ b/internal/identify/performer.go @@ -2,7 +2,6 @@ package identify import ( "context" - "database/sql" "fmt" "strconv" "time" @@ -12,7 +11,7 @@ import ( ) type PerformerCreator interface { - Create(ctx context.Context, newPerformer models.Performer) (*models.Performer, error) + Create(ctx context.Context, newPerformer *models.Performer) error UpdateStashIDs(ctx context.Context, performerID int, stashIDs []models.StashID) error } @@ -33,13 +32,14 @@ func getPerformerID(ctx context.Context, endpoint string, w PerformerCreator, p } func createMissingPerformer(ctx context.Context, endpoint string, w PerformerCreator, p *models.ScrapedPerformer) (*int, error) { - created, err := w.Create(ctx, scrapedToPerformerInput(p)) + performerInput := scrapedToPerformerInput(p) + err := w.Create(ctx, &performerInput) if err != nil { return nil, fmt.Errorf("error creating performer: %w", err) } if endpoint != "" && p.RemoteSiteID != nil { - if err := w.UpdateStashIDs(ctx, created.ID, []models.StashID{ + if err := w.UpdateStashIDs(ctx, performerInput.ID, []models.StashID{ { Endpoint: endpoint, StashID: *p.RemoteSiteID, @@ -49,65 +49,66 @@ func createMissingPerformer(ctx context.Context, endpoint string, w PerformerCre } } - return &created.ID, nil + return &performerInput.ID, nil } func scrapedToPerformerInput(performer *models.ScrapedPerformer) models.Performer { currentTime := time.Now() ret := models.Performer{ - Name: sql.NullString{String: *performer.Name, Valid: true}, + Name: *performer.Name, Checksum: md5.FromString(*performer.Name), - CreatedAt: models.SQLiteTimestamp{Timestamp: currentTime}, - UpdatedAt: models.SQLiteTimestamp{Timestamp: currentTime}, - Favorite: sql.NullBool{Bool: false, Valid: true}, + CreatedAt: currentTime, + UpdatedAt: currentTime, } if performer.Birthdate != nil { - ret.Birthdate = models.SQLiteDate{String: *performer.Birthdate, Valid: true} + d := models.NewDate(*performer.Birthdate) + ret.Birthdate = &d } if performer.DeathDate != nil { - ret.DeathDate = models.SQLiteDate{String: *performer.DeathDate, Valid: true} + d := models.NewDate(*performer.DeathDate) + ret.DeathDate = &d } if performer.Gender != nil { - ret.Gender = sql.NullString{String: *performer.Gender, Valid: true} + ret.Gender = models.GenderEnum(*performer.Gender) } if performer.Ethnicity != nil { - ret.Ethnicity = sql.NullString{String: *performer.Ethnicity, Valid: true} + ret.Ethnicity = *performer.Ethnicity } if performer.Country != nil { - ret.Country = sql.NullString{String: *performer.Country, Valid: true} + ret.Country = *performer.Country } if performer.EyeColor != nil { - ret.EyeColor = sql.NullString{String: *performer.EyeColor, Valid: true} + ret.EyeColor = *performer.EyeColor } if performer.HairColor != nil { - ret.HairColor = sql.NullString{String: *performer.HairColor, Valid: true} + ret.HairColor = *performer.HairColor } if performer.Height != nil { - ret.Height = sql.NullString{String: *performer.Height, Valid: true} + ret.Height = *performer.Height } if performer.Measurements != nil { - ret.Measurements = sql.NullString{String: *performer.Measurements, Valid: true} + ret.Measurements = *performer.Measurements } if performer.FakeTits != nil { - ret.FakeTits = sql.NullString{String: *performer.FakeTits, Valid: true} + ret.FakeTits = *performer.FakeTits } if performer.CareerLength != nil { - ret.CareerLength = sql.NullString{String: *performer.CareerLength, Valid: true} + ret.CareerLength = *performer.CareerLength } if performer.Tattoos != nil { - ret.Tattoos = sql.NullString{String: *performer.Tattoos, Valid: true} + ret.Tattoos = *performer.Tattoos } if performer.Piercings != nil { - ret.Piercings = sql.NullString{String: *performer.Piercings, Valid: true} + ret.Piercings = *performer.Piercings } if performer.Aliases != nil { - ret.Aliases = sql.NullString{String: *performer.Aliases, Valid: true} + ret.Aliases = *performer.Aliases } if performer.Twitter != nil { - ret.Twitter = sql.NullString{String: *performer.Twitter, Valid: true} + ret.Twitter = *performer.Twitter } if performer.Instagram != nil { - ret.Instagram = sql.NullString{String: *performer.Instagram, Valid: true} + ret.Instagram = *performer.Instagram } return ret diff --git a/internal/identify/performer_test.go b/internal/identify/performer_test.go index eeed8a1e7..df18b1d09 100644 --- a/internal/identify/performer_test.go +++ b/internal/identify/performer_test.go @@ -1,11 +1,11 @@ package identify import ( - "database/sql" "errors" "reflect" "strconv" "testing" + "time" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models/mocks" @@ -24,9 +24,10 @@ func Test_getPerformerID(t *testing.T) { name := "name" mockPerformerReaderWriter := mocks.PerformerReaderWriter{} - mockPerformerReaderWriter.On("Create", testCtx, mock.Anything).Return(&models.Performer{ - ID: validStoredID, - }, nil) + mockPerformerReaderWriter.On("Create", testCtx, mock.Anything).Run(func(args mock.Arguments) { + p := args.Get(1).(*models.Performer) + p.ID = validStoredID + }).Return(nil) type args struct { endpoint string @@ -132,14 +133,16 @@ func Test_createMissingPerformer(t *testing.T) { performerID := 1 mockPerformerReaderWriter := mocks.PerformerReaderWriter{} - mockPerformerReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p models.Performer) bool { - return p.Name.String == validName - })).Return(&models.Performer{ - ID: performerID, - }, nil) - mockPerformerReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p models.Performer) bool { - return p.Name.String == invalidName - })).Return(nil, errors.New("error creating performer")) + mockPerformerReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p *models.Performer) bool { + return p.Name == validName + })).Run(func(args mock.Arguments) { + p := args.Get(1).(*models.Performer) + p.ID = performerID + }).Return(nil) + + mockPerformerReaderWriter.On("Create", testCtx, mock.MatchedBy(func(p *models.Performer) bool { + return p.Name == invalidName + })).Return(errors.New("error creating performer")) mockPerformerReaderWriter.On("UpdateStashIDs", testCtx, performerID, []models.StashID{ { @@ -241,6 +244,10 @@ func Test_scrapedToPerformerInput(t *testing.T) { return &ret } + dateToDatePtr := func(d models.Date) *models.Date { + return &d + } + tests := []struct { name string performer *models.ScrapedPerformer @@ -268,34 +275,24 @@ func Test_scrapedToPerformerInput(t *testing.T) { Instagram: nextVal(), }, models.Performer{ - Name: models.NullString(name), - Checksum: md5, - Favorite: sql.NullBool{ - Bool: false, - Valid: true, - }, - Birthdate: models.SQLiteDate{ - String: *nextVal(), - Valid: true, - }, - DeathDate: models.SQLiteDate{ - String: *nextVal(), - Valid: true, - }, - Gender: models.NullString(*nextVal()), - Ethnicity: models.NullString(*nextVal()), - Country: models.NullString(*nextVal()), - EyeColor: models.NullString(*nextVal()), - HairColor: models.NullString(*nextVal()), - Height: models.NullString(*nextVal()), - Measurements: models.NullString(*nextVal()), - FakeTits: models.NullString(*nextVal()), - CareerLength: models.NullString(*nextVal()), - Tattoos: models.NullString(*nextVal()), - Piercings: models.NullString(*nextVal()), - Aliases: models.NullString(*nextVal()), - Twitter: models.NullString(*nextVal()), - Instagram: models.NullString(*nextVal()), + Name: name, + Checksum: md5, + Birthdate: dateToDatePtr(models.NewDate(*nextVal())), + DeathDate: dateToDatePtr(models.NewDate(*nextVal())), + Gender: models.GenderEnum(*nextVal()), + Ethnicity: *nextVal(), + Country: *nextVal(), + EyeColor: *nextVal(), + HairColor: *nextVal(), + Height: *nextVal(), + Measurements: *nextVal(), + FakeTits: *nextVal(), + CareerLength: *nextVal(), + Tattoos: *nextVal(), + Piercings: *nextVal(), + Aliases: *nextVal(), + Twitter: *nextVal(), + Instagram: *nextVal(), }, }, { @@ -304,12 +301,8 @@ func Test_scrapedToPerformerInput(t *testing.T) { Name: &name, }, models.Performer{ - Name: models.NullString(name), + Name: name, Checksum: md5, - Favorite: sql.NullBool{ - Bool: false, - Valid: true, - }, }, }, } @@ -318,7 +311,7 @@ func Test_scrapedToPerformerInput(t *testing.T) { got := scrapedToPerformerInput(tt.performer) // clear created/updated dates - got.CreatedAt = models.SQLiteTimestamp{} + got.CreatedAt = time.Time{} got.UpdatedAt = got.CreatedAt if !reflect.DeepEqual(got, tt.want) { diff --git a/internal/manager/task_autotag.go b/internal/manager/task_autotag.go index 09086117a..ee487b724 100644 --- a/internal/manager/task_autotag.go +++ b/internal/manager/task_autotag.go @@ -176,7 +176,7 @@ func (j *autoTagJob) autoTagPerformers(ctx context.Context, progress *job.Progre return nil }); err != nil { - return fmt.Errorf("error auto-tagging performer '%s': %s", performer.Name.String, err.Error()) + return fmt.Errorf("error auto-tagging performer '%s': %s", performer.Name, err.Error()) } progress.Increment() diff --git a/internal/manager/task_stash_box_tag.go b/internal/manager/task_stash_box_tag.go index cf7add510..36d3a3207 100644 --- a/internal/manager/task_stash_box_tag.go +++ b/internal/manager/task_stash_box_tag.go @@ -2,7 +2,6 @@ package manager import ( "context" - "database/sql" "fmt" "time" @@ -31,7 +30,7 @@ func (t *StashBoxPerformerTagTask) Description() string { if t.name != nil { name = *t.name } else if t.performer != nil { - name = t.performer.Name.String + name = t.performer.Name } return fmt.Sprintf("Tagging performer %s from stash-box", name) @@ -70,7 +69,7 @@ func (t *StashBoxPerformerTagTask) stashBoxPerformerTag(ctx context.Context) { if t.name != nil { name = *t.name } else { - name = t.performer.Name.String + name = t.performer.Name } performer, err = client.FindStashBoxPerformerByName(ctx, name) } @@ -86,84 +85,64 @@ func (t *StashBoxPerformerTagTask) stashBoxPerformerTag(ctx context.Context) { } if performer != nil { - updatedTime := time.Now() - if t.performer != nil { - partial := models.PerformerPartial{ - ID: t.performer.ID, - UpdatedAt: &models.SQLiteTimestamp{Timestamp: updatedTime}, - } + partial := models.NewPerformerPartial() if performer.Aliases != nil && !excluded["aliases"] { - value := getNullString(performer.Aliases) - partial.Aliases = &value + partial.Aliases = models.NewOptionalString(*performer.Aliases) } if performer.Birthdate != nil && *performer.Birthdate != "" && !excluded["birthdate"] { value := getDate(performer.Birthdate) - partial.Birthdate = &value + partial.Birthdate = models.NewOptionalDate(*value) } if performer.CareerLength != nil && !excluded["career_length"] { - value := getNullString(performer.CareerLength) - partial.CareerLength = &value + partial.CareerLength = models.NewOptionalString(*performer.CareerLength) } if performer.Country != nil && !excluded["country"] { - value := getNullString(performer.Country) - partial.Country = &value + partial.Country = models.NewOptionalString(*performer.Country) } if performer.Ethnicity != nil && !excluded["ethnicity"] { - value := getNullString(performer.Ethnicity) - partial.Ethnicity = &value + partial.Ethnicity = models.NewOptionalString(*performer.Ethnicity) } if performer.EyeColor != nil && !excluded["eye_color"] { - value := getNullString(performer.EyeColor) - partial.EyeColor = &value + partial.EyeColor = models.NewOptionalString(*performer.EyeColor) } if performer.FakeTits != nil && !excluded["fake_tits"] { - value := getNullString(performer.FakeTits) - partial.FakeTits = &value + partial.FakeTits = models.NewOptionalString(*performer.FakeTits) } if performer.Gender != nil && !excluded["gender"] { - value := getNullString(performer.Gender) - partial.Gender = &value + partial.Gender = models.NewOptionalString(*performer.Gender) } if performer.Height != nil && !excluded["height"] { - value := getNullString(performer.Height) - partial.Height = &value + partial.Height = models.NewOptionalString(*performer.Height) } if performer.Instagram != nil && !excluded["instagram"] { - value := getNullString(performer.Instagram) - partial.Instagram = &value + partial.Instagram = models.NewOptionalString(*performer.Instagram) } if performer.Measurements != nil && !excluded["measurements"] { - value := getNullString(performer.Measurements) - partial.Measurements = &value + partial.Measurements = models.NewOptionalString(*performer.Measurements) } if excluded["name"] && performer.Name != nil { - value := sql.NullString{String: *performer.Name, Valid: true} - partial.Name = &value + partial.Name = models.NewOptionalString(*performer.Name) checksum := md5.FromString(*performer.Name) - partial.Checksum = &checksum + partial.Checksum = models.NewOptionalString(checksum) } if performer.Piercings != nil && !excluded["piercings"] { - value := getNullString(performer.Piercings) - partial.Piercings = &value + partial.Piercings = models.NewOptionalString(*performer.Piercings) } if performer.Tattoos != nil && !excluded["tattoos"] { - value := getNullString(performer.Tattoos) - partial.Tattoos = &value + partial.Tattoos = models.NewOptionalString(*performer.Tattoos) } if performer.Twitter != nil && !excluded["twitter"] { - value := getNullString(performer.Twitter) - partial.Twitter = &value + partial.Twitter = models.NewOptionalString(*performer.Twitter) } if performer.URL != nil && !excluded["url"] { - value := getNullString(performer.URL) - partial.URL = &value + partial.URL = models.NewOptionalString(*performer.URL) } txnErr := txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error { r := instance.Repository - _, err := r.Performer.Update(ctx, partial) + _, err := r.Performer.UpdatePartial(ctx, t.performer.ID, partial) if !t.refresh { err = r.Performer.UpdateStashIDs(ctx, t.performer.ID, []models.StashID{ @@ -203,35 +182,34 @@ func (t *StashBoxPerformerTagTask) stashBoxPerformerTag(ctx context.Context) { } else if t.name != nil && performer.Name != nil { currentTime := time.Now() newPerformer := models.Performer{ - Aliases: getNullString(performer.Aliases), + Aliases: getString(performer.Aliases), Birthdate: getDate(performer.Birthdate), - CareerLength: getNullString(performer.CareerLength), + CareerLength: getString(performer.CareerLength), Checksum: md5.FromString(*performer.Name), - Country: getNullString(performer.Country), - CreatedAt: models.SQLiteTimestamp{Timestamp: currentTime}, - Ethnicity: getNullString(performer.Ethnicity), - EyeColor: getNullString(performer.EyeColor), - FakeTits: getNullString(performer.FakeTits), - Favorite: sql.NullBool{Bool: false, Valid: true}, - Gender: getNullString(performer.Gender), - Height: getNullString(performer.Height), - Instagram: getNullString(performer.Instagram), - Measurements: getNullString(performer.Measurements), - Name: sql.NullString{String: *performer.Name, Valid: true}, - Piercings: getNullString(performer.Piercings), - Tattoos: getNullString(performer.Tattoos), - Twitter: getNullString(performer.Twitter), - URL: getNullString(performer.URL), - UpdatedAt: models.SQLiteTimestamp{Timestamp: currentTime}, + Country: getString(performer.Country), + CreatedAt: currentTime, + Ethnicity: getString(performer.Ethnicity), + EyeColor: getString(performer.EyeColor), + FakeTits: getString(performer.FakeTits), + Gender: models.GenderEnum(getString(performer.Gender)), + Height: getString(performer.Height), + Instagram: getString(performer.Instagram), + Measurements: getString(performer.Measurements), + Name: *performer.Name, + Piercings: getString(performer.Piercings), + Tattoos: getString(performer.Tattoos), + Twitter: getString(performer.Twitter), + URL: getString(performer.URL), + UpdatedAt: currentTime, } err := txn.WithTxn(ctx, instance.Repository, func(ctx context.Context) error { r := instance.Repository - createdPerformer, err := r.Performer.Create(ctx, newPerformer) + err := r.Performer.Create(ctx, &newPerformer) if err != nil { return err } - err = r.Performer.UpdateStashIDs(ctx, createdPerformer.ID, []models.StashID{ + err = r.Performer.UpdateStashIDs(ctx, newPerformer.ID, []models.StashID{ { Endpoint: t.box.Endpoint, StashID: *performer.RemoteSiteID, @@ -246,7 +224,7 @@ func (t *StashBoxPerformerTagTask) stashBoxPerformerTag(ctx context.Context) { if imageErr != nil { return imageErr } - err = r.Performer.UpdateImage(ctx, createdPerformer.ID, image) + err = r.Performer.UpdateImage(ctx, newPerformer.ID, image) } return err }) @@ -261,24 +239,25 @@ func (t *StashBoxPerformerTagTask) stashBoxPerformerTag(ctx context.Context) { if t.name != nil { name = *t.name } else if t.performer != nil { - name = t.performer.Name.String + name = t.performer.Name } logger.Infof("No match found for %s", name) } } -func getDate(val *string) models.SQLiteDate { +func getDate(val *string) *models.Date { if val == nil { - return models.SQLiteDate{Valid: false} - } else { - return models.SQLiteDate{String: *val, Valid: true} + return nil } + + ret := models.NewDate(*val) + return &ret } -func getNullString(val *string) sql.NullString { +func getString(val *string) string { if val == nil { - return sql.NullString{Valid: false} + return "" } else { - return sql.NullString{String: *val, Valid: true} + return *val } } diff --git a/pkg/gallery/import.go b/pkg/gallery/import.go index c324d8d72..c3bd83527 100644 --- a/pkg/gallery/import.go +++ b/pkg/gallery/import.go @@ -136,10 +136,10 @@ func (i *Importer) populatePerformers(ctx context.Context) error { var pluckedNames []string for _, performer := range performers { - if !performer.Name.Valid { + if performer.Name == "" { continue } - pluckedNames = append(pluckedNames, performer.Name.String) + pluckedNames = append(pluckedNames, performer.Name) } missingPerformers := stringslice.StrFilter(names, func(name string) bool { @@ -176,12 +176,12 @@ func (i *Importer) createPerformers(ctx context.Context, names []string) ([]*mod for _, name := range names { newPerformer := *models.NewPerformer(name) - created, err := i.PerformerWriter.Create(ctx, newPerformer) + err := i.PerformerWriter.Create(ctx, &newPerformer) if err != nil { return nil, err } - ret = append(ret, created) + ret = append(ret, &newPerformer) } return ret, nil diff --git a/pkg/gallery/import_test.go b/pkg/gallery/import_test.go index 43634fd13..73f2aed7d 100644 --- a/pkg/gallery/import_test.go +++ b/pkg/gallery/import_test.go @@ -169,7 +169,7 @@ func TestImporterPreImportWithPerformer(t *testing.T) { performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{ { ID: existingPerformerID, - Name: models.NullString(existingPerformerName), + Name: existingPerformerName, }, }, nil).Once() performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once() @@ -199,9 +199,10 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) { } performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3) - performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Performer")).Return(&models.Performer{ - ID: existingPerformerID, - }, nil) + performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) { + performer := args.Get(1).(*models.Performer) + performer.ID = existingPerformerID + }).Return(nil) err := i.PreImport(testCtx) assert.NotNil(t, err) @@ -232,7 +233,7 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) { } performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once() - performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Performer")).Return(nil, errors.New("Create error")) + performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error")) err := i.PreImport(testCtx) assert.NotNil(t, err) diff --git a/pkg/image/import.go b/pkg/image/import.go index 7c19a5629..018dbf9fb 100644 --- a/pkg/image/import.go +++ b/pkg/image/import.go @@ -216,10 +216,10 @@ func (i *Importer) populatePerformers(ctx context.Context) error { var pluckedNames []string for _, performer := range performers { - if !performer.Name.Valid { + if performer.Name == "" { continue } - pluckedNames = append(pluckedNames, performer.Name.String) + pluckedNames = append(pluckedNames, performer.Name) } missingPerformers := stringslice.StrFilter(names, func(name string) bool { @@ -256,12 +256,12 @@ func (i *Importer) createPerformers(ctx context.Context, names []string) ([]*mod for _, name := range names { newPerformer := *models.NewPerformer(name) - created, err := i.PerformerWriter.Create(ctx, newPerformer) + err := i.PerformerWriter.Create(ctx, &newPerformer) if err != nil { return nil, err } - ret = append(ret, created) + ret = append(ret, &newPerformer) } return ret, nil diff --git a/pkg/image/import_test.go b/pkg/image/import_test.go index 647815127..6724b87cb 100644 --- a/pkg/image/import_test.go +++ b/pkg/image/import_test.go @@ -130,7 +130,7 @@ func TestImporterPreImportWithPerformer(t *testing.T) { performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{ { ID: existingPerformerID, - Name: models.NullString(existingPerformerName), + Name: existingPerformerName, }, }, nil).Once() performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once() @@ -160,9 +160,10 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) { } performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3) - performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Performer")).Return(&models.Performer{ - ID: existingPerformerID, - }, nil) + performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) { + performer := args.Get(1).(*models.Performer) + performer.ID = existingPerformerID + }).Return(nil) err := i.PreImport(testCtx) assert.NotNil(t, err) @@ -193,7 +194,7 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) { } performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once() - performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Performer")).Return(nil, errors.New("Create error")) + performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error")) err := i.PreImport(testCtx) assert.NotNil(t, err) diff --git a/pkg/match/path.go b/pkg/match/path.go index e77fc2e59..b150809a3 100644 --- a/pkg/match/path.go +++ b/pkg/match/path.go @@ -169,7 +169,7 @@ func PathToPerformers(ctx context.Context, path string, reader PerformerAutoTagQ var ret []*models.Performer for _, p := range performers { // TODO - commenting out alias handling until both sides work correctly - if nameMatchesPath(p.Name.String, path) != -1 { // || nameMatchesPath(p.Aliases.String, path) { + if nameMatchesPath(p.Name, path) != -1 { // || nameMatchesPath(p.Aliases.String, path) { ret = append(ret, p) } } diff --git a/pkg/models/mocks/PerformerReaderWriter.go b/pkg/models/mocks/PerformerReaderWriter.go index f3fece8e6..cf1d965fa 100644 --- a/pkg/models/mocks/PerformerReaderWriter.go +++ b/pkg/models/mocks/PerformerReaderWriter.go @@ -80,26 +80,17 @@ func (_m *PerformerReaderWriter) CountByTagID(ctx context.Context, tagID int) (i } // Create provides a mock function with given fields: ctx, newPerformer -func (_m *PerformerReaderWriter) Create(ctx context.Context, newPerformer models.Performer) (*models.Performer, error) { +func (_m *PerformerReaderWriter) Create(ctx context.Context, newPerformer *models.Performer) error { ret := _m.Called(ctx, newPerformer) - var r0 *models.Performer - if rf, ok := ret.Get(0).(func(context.Context, models.Performer) *models.Performer); ok { + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *models.Performer) error); ok { r0 = rf(ctx, newPerformer) } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*models.Performer) - } + r0 = ret.Error(0) } - var r1 error - if rf, ok := ret.Get(1).(func(context.Context, models.Performer) error); ok { - r1 = rf(ctx, newPerformer) - } else { - r1 = ret.Error(1) - } - - return r0, r1 + return r0 } // Destroy provides a mock function with given fields: ctx, id @@ -314,29 +305,6 @@ func (_m *PerformerReaderWriter) FindMany(ctx context.Context, ids []int) ([]*mo return r0, r1 } -// FindNamesBySceneID provides a mock function with given fields: ctx, sceneID -func (_m *PerformerReaderWriter) FindNamesBySceneID(ctx context.Context, sceneID int) ([]*models.Performer, error) { - ret := _m.Called(ctx, sceneID) - - var r0 []*models.Performer - if rf, ok := ret.Get(0).(func(context.Context, int) []*models.Performer); ok { - r0 = rf(ctx, sceneID) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]*models.Performer) - } - } - - var r1 error - if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { - r1 = rf(ctx, sceneID) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // GetImage provides a mock function with given fields: ctx, performerID func (_m *PerformerReaderWriter) GetImage(ctx context.Context, performerID int) ([]byte, error) { ret := _m.Called(ctx, performerID) @@ -460,49 +428,17 @@ func (_m *PerformerReaderWriter) QueryForAutoTag(ctx context.Context, words []st } // Update provides a mock function with given fields: ctx, updatedPerformer -func (_m *PerformerReaderWriter) Update(ctx context.Context, updatedPerformer models.PerformerPartial) (*models.Performer, error) { +func (_m *PerformerReaderWriter) Update(ctx context.Context, updatedPerformer *models.Performer) error { ret := _m.Called(ctx, updatedPerformer) - var r0 *models.Performer - if rf, ok := ret.Get(0).(func(context.Context, models.PerformerPartial) *models.Performer); ok { + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *models.Performer) error); ok { r0 = rf(ctx, updatedPerformer) } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*models.Performer) - } + r0 = ret.Error(0) } - var r1 error - if rf, ok := ret.Get(1).(func(context.Context, models.PerformerPartial) error); ok { - r1 = rf(ctx, updatedPerformer) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// UpdateFull provides a mock function with given fields: ctx, updatedPerformer -func (_m *PerformerReaderWriter) UpdateFull(ctx context.Context, updatedPerformer models.Performer) (*models.Performer, error) { - ret := _m.Called(ctx, updatedPerformer) - - var r0 *models.Performer - if rf, ok := ret.Get(0).(func(context.Context, models.Performer) *models.Performer); ok { - r0 = rf(ctx, updatedPerformer) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*models.Performer) - } - } - - var r1 error - if rf, ok := ret.Get(1).(func(context.Context, models.Performer) error); ok { - r1 = rf(ctx, updatedPerformer) - } else { - r1 = ret.Error(1) - } - - return r0, r1 + return r0 } // UpdateImage provides a mock function with given fields: ctx, performerID, image @@ -519,6 +455,29 @@ func (_m *PerformerReaderWriter) UpdateImage(ctx context.Context, performerID in return r0 } +// UpdatePartial provides a mock function with given fields: ctx, id, updatedPerformer +func (_m *PerformerReaderWriter) UpdatePartial(ctx context.Context, id int, updatedPerformer models.PerformerPartial) (*models.Performer, error) { + ret := _m.Called(ctx, id, updatedPerformer) + + var r0 *models.Performer + if rf, ok := ret.Get(0).(func(context.Context, int, models.PerformerPartial) *models.Performer); ok { + r0 = rf(ctx, id, updatedPerformer) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*models.Performer) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, int, models.PerformerPartial) error); ok { + r1 = rf(ctx, id, updatedPerformer) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // UpdateStashIDs provides a mock function with given fields: ctx, performerID, stashIDs func (_m *PerformerReaderWriter) UpdateStashIDs(ctx context.Context, performerID int, stashIDs []models.StashID) error { ret := _m.Called(ctx, performerID, stashIDs) diff --git a/pkg/models/model_performer.go b/pkg/models/model_performer.go index 0f10344f7..f8b7d9597 100644 --- a/pkg/models/model_performer.go +++ b/pkg/models/model_performer.go @@ -1,80 +1,87 @@ package models import ( - "database/sql" "time" "github.com/stashapp/stash/pkg/hash/md5" ) type Performer struct { - ID int `db:"id" json:"id"` - Checksum string `db:"checksum" json:"checksum"` - Name sql.NullString `db:"name" json:"name"` - Gender sql.NullString `db:"gender" json:"gender"` - URL sql.NullString `db:"url" json:"url"` - Twitter sql.NullString `db:"twitter" json:"twitter"` - Instagram sql.NullString `db:"instagram" json:"instagram"` - Birthdate SQLiteDate `db:"birthdate" json:"birthdate"` - Ethnicity sql.NullString `db:"ethnicity" json:"ethnicity"` - Country sql.NullString `db:"country" json:"country"` - EyeColor sql.NullString `db:"eye_color" json:"eye_color"` - Height sql.NullString `db:"height" json:"height"` - Measurements sql.NullString `db:"measurements" json:"measurements"` - FakeTits sql.NullString `db:"fake_tits" json:"fake_tits"` - CareerLength sql.NullString `db:"career_length" json:"career_length"` - Tattoos sql.NullString `db:"tattoos" json:"tattoos"` - Piercings sql.NullString `db:"piercings" json:"piercings"` - Aliases sql.NullString `db:"aliases" json:"aliases"` - Favorite sql.NullBool `db:"favorite" json:"favorite"` - CreatedAt SQLiteTimestamp `db:"created_at" json:"created_at"` - UpdatedAt SQLiteTimestamp `db:"updated_at" json:"updated_at"` - Rating sql.NullInt64 `db:"rating" json:"rating"` - Details sql.NullString `db:"details" json:"details"` - DeathDate SQLiteDate `db:"death_date" json:"death_date"` - HairColor sql.NullString `db:"hair_color" json:"hair_color"` - Weight sql.NullInt64 `db:"weight" json:"weight"` - IgnoreAutoTag bool `db:"ignore_auto_tag" json:"ignore_auto_tag"` + ID int `json:"id"` + Checksum string `json:"checksum"` + Name string `json:"name"` + Gender GenderEnum `json:"gender"` + URL string `json:"url"` + Twitter string `json:"twitter"` + Instagram string `json:"instagram"` + Birthdate *Date `json:"birthdate"` + Ethnicity string `json:"ethnicity"` + Country string `json:"country"` + EyeColor string `json:"eye_color"` + Height string `json:"height"` + Measurements string `json:"measurements"` + FakeTits string `json:"fake_tits"` + CareerLength string `json:"career_length"` + Tattoos string `json:"tattoos"` + Piercings string `json:"piercings"` + Aliases string `json:"aliases"` + Favorite bool `json:"favorite"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + Rating *int `json:"rating"` + Details string `json:"details"` + DeathDate *Date `json:"death_date"` + HairColor string `json:"hair_color"` + Weight *int `json:"weight"` + IgnoreAutoTag bool `json:"ignore_auto_tag"` } +// PerformerPartial represents part of a Performer object. It is used to update +// the database entry. type PerformerPartial struct { - ID int `db:"id" json:"id"` - Checksum *string `db:"checksum" json:"checksum"` - Name *sql.NullString `db:"name" json:"name"` - Gender *sql.NullString `db:"gender" json:"gender"` - URL *sql.NullString `db:"url" json:"url"` - Twitter *sql.NullString `db:"twitter" json:"twitter"` - Instagram *sql.NullString `db:"instagram" json:"instagram"` - Birthdate *SQLiteDate `db:"birthdate" json:"birthdate"` - Ethnicity *sql.NullString `db:"ethnicity" json:"ethnicity"` - Country *sql.NullString `db:"country" json:"country"` - EyeColor *sql.NullString `db:"eye_color" json:"eye_color"` - Height *sql.NullString `db:"height" json:"height"` - Measurements *sql.NullString `db:"measurements" json:"measurements"` - FakeTits *sql.NullString `db:"fake_tits" json:"fake_tits"` - CareerLength *sql.NullString `db:"career_length" json:"career_length"` - Tattoos *sql.NullString `db:"tattoos" json:"tattoos"` - Piercings *sql.NullString `db:"piercings" json:"piercings"` - Aliases *sql.NullString `db:"aliases" json:"aliases"` - Favorite *sql.NullBool `db:"favorite" json:"favorite"` - CreatedAt *SQLiteTimestamp `db:"created_at" json:"created_at"` - UpdatedAt *SQLiteTimestamp `db:"updated_at" json:"updated_at"` - Rating *sql.NullInt64 `db:"rating" json:"rating"` - Details *sql.NullString `db:"details" json:"details"` - DeathDate *SQLiteDate `db:"death_date" json:"death_date"` - HairColor *sql.NullString `db:"hair_color" json:"hair_color"` - Weight *sql.NullInt64 `db:"weight" json:"weight"` - IgnoreAutoTag *bool `db:"ignore_auto_tag" json:"ignore_auto_tag"` + ID int + Checksum OptionalString + Name OptionalString + Gender OptionalString + URL OptionalString + Twitter OptionalString + Instagram OptionalString + Birthdate OptionalDate + Ethnicity OptionalString + Country OptionalString + EyeColor OptionalString + Height OptionalString + Measurements OptionalString + FakeTits OptionalString + CareerLength OptionalString + Tattoos OptionalString + Piercings OptionalString + Aliases OptionalString + Favorite OptionalBool + CreatedAt OptionalTime + UpdatedAt OptionalTime + Rating OptionalInt + Details OptionalString + DeathDate OptionalDate + HairColor OptionalString + Weight OptionalInt + IgnoreAutoTag OptionalBool } func NewPerformer(name string) *Performer { currentTime := time.Now() return &Performer{ Checksum: md5.FromString(name), - Name: sql.NullString{String: name, Valid: true}, - Favorite: sql.NullBool{Bool: false, Valid: true}, - CreatedAt: SQLiteTimestamp{Timestamp: currentTime}, - UpdatedAt: SQLiteTimestamp{Timestamp: currentTime}, + Name: name, + CreatedAt: currentTime, + UpdatedAt: currentTime, + } +} + +func NewPerformerPartial() PerformerPartial { + updatedTime := time.Now() + return PerformerPartial{ + UpdatedAt: NewOptionalTime(updatedTime), } } diff --git a/pkg/models/performer.go b/pkg/models/performer.go index a00eea7fc..924ffb205 100644 --- a/pkg/models/performer.go +++ b/pkg/models/performer.go @@ -133,7 +133,6 @@ type PerformerReader interface { Find(ctx context.Context, id int) (*Performer, error) PerformerFinder FindBySceneID(ctx context.Context, sceneID int) ([]*Performer, error) - FindNamesBySceneID(ctx context.Context, sceneID int) ([]*Performer, error) FindByImageID(ctx context.Context, imageID int) ([]*Performer, error) FindByGalleryID(ctx context.Context, galleryID int) ([]*Performer, error) FindByNames(ctx context.Context, names []string, nocase bool) ([]*Performer, error) @@ -152,9 +151,9 @@ type PerformerReader interface { } type PerformerWriter interface { - Create(ctx context.Context, newPerformer Performer) (*Performer, error) - Update(ctx context.Context, updatedPerformer PerformerPartial) (*Performer, error) - UpdateFull(ctx context.Context, updatedPerformer Performer) (*Performer, error) + Create(ctx context.Context, newPerformer *Performer) error + UpdatePartial(ctx context.Context, id int, updatedPerformer PerformerPartial) (*Performer, error) + Update(ctx context.Context, updatedPerformer *Performer) error Destroy(ctx context.Context, id int) error UpdateImage(ctx context.Context, performerID int, image []byte) error DestroyImage(ctx context.Context, performerID int) error diff --git a/pkg/performer/export.go b/pkg/performer/export.go index 9a1a9c701..e366a0393 100644 --- a/pkg/performer/export.go +++ b/pkg/performer/export.go @@ -18,76 +18,40 @@ type ImageStashIDGetter interface { // ToJSON converts a Performer object into its JSON equivalent. func ToJSON(ctx context.Context, reader ImageStashIDGetter, performer *models.Performer) (*jsonschema.Performer, error) { newPerformerJSON := jsonschema.Performer{ + Name: performer.Name, + Gender: performer.Gender.String(), + URL: performer.URL, + Ethnicity: performer.Ethnicity, + Country: performer.Country, + EyeColor: performer.EyeColor, + Height: performer.Height, + Measurements: performer.Measurements, + FakeTits: performer.FakeTits, + CareerLength: performer.CareerLength, + Tattoos: performer.Tattoos, + Piercings: performer.Piercings, + Aliases: performer.Aliases, + Twitter: performer.Twitter, + Instagram: performer.Instagram, + Favorite: performer.Favorite, + Details: performer.Details, + HairColor: performer.HairColor, IgnoreAutoTag: performer.IgnoreAutoTag, - CreatedAt: json.JSONTime{Time: performer.CreatedAt.Timestamp}, - UpdatedAt: json.JSONTime{Time: performer.UpdatedAt.Timestamp}, + CreatedAt: json.JSONTime{Time: performer.CreatedAt}, + UpdatedAt: json.JSONTime{Time: performer.UpdatedAt}, } - if performer.Name.Valid { - newPerformerJSON.Name = performer.Name.String + if performer.Birthdate != nil { + newPerformerJSON.Birthdate = performer.Birthdate.String() } - if performer.Gender.Valid { - newPerformerJSON.Gender = performer.Gender.String + if performer.Rating != nil { + newPerformerJSON.Rating = *performer.Rating } - if performer.URL.Valid { - newPerformerJSON.URL = performer.URL.String + if performer.DeathDate != nil { + newPerformerJSON.DeathDate = performer.DeathDate.String() } - if performer.Birthdate.Valid { - newPerformerJSON.Birthdate = utils.GetYMDFromDatabaseDate(performer.Birthdate.String) - } - if performer.Ethnicity.Valid { - newPerformerJSON.Ethnicity = performer.Ethnicity.String - } - if performer.Country.Valid { - newPerformerJSON.Country = performer.Country.String - } - if performer.EyeColor.Valid { - newPerformerJSON.EyeColor = performer.EyeColor.String - } - if performer.Height.Valid { - newPerformerJSON.Height = performer.Height.String - } - if performer.Measurements.Valid { - newPerformerJSON.Measurements = performer.Measurements.String - } - if performer.FakeTits.Valid { - newPerformerJSON.FakeTits = performer.FakeTits.String - } - if performer.CareerLength.Valid { - newPerformerJSON.CareerLength = performer.CareerLength.String - } - if performer.Tattoos.Valid { - newPerformerJSON.Tattoos = performer.Tattoos.String - } - if performer.Piercings.Valid { - newPerformerJSON.Piercings = performer.Piercings.String - } - if performer.Aliases.Valid { - newPerformerJSON.Aliases = performer.Aliases.String - } - if performer.Twitter.Valid { - newPerformerJSON.Twitter = performer.Twitter.String - } - if performer.Instagram.Valid { - newPerformerJSON.Instagram = performer.Instagram.String - } - if performer.Favorite.Valid { - newPerformerJSON.Favorite = performer.Favorite.Bool - } - if performer.Rating.Valid { - newPerformerJSON.Rating = int(performer.Rating.Int64) - } - if performer.Details.Valid { - newPerformerJSON.Details = performer.Details.String - } - if performer.DeathDate.Valid { - newPerformerJSON.DeathDate = utils.GetYMDFromDatabaseDate(performer.DeathDate.String) - } - if performer.HairColor.Valid { - newPerformerJSON.HairColor = performer.HairColor.String - } - if performer.Weight.Valid { - newPerformerJSON.Weight = int(performer.Weight.Int64) + if performer.Weight != nil { + newPerformerJSON.Weight = *performer.Weight } image, err := reader.GetImage(ctx, performer.ID) @@ -126,8 +90,8 @@ func GetIDs(performers []*models.Performer) []int { func GetNames(performers []*models.Performer) []string { var results []string for _, performer := range performers { - if performer.Name.Valid { - results = append(results, performer.Name.String) + if performer.Name != "" { + results = append(results, performer.Name) } } diff --git a/pkg/performer/export_test.go b/pkg/performer/export_test.go index f328c0d8c..2271b6cbc 100644 --- a/pkg/performer/export_test.go +++ b/pkg/performer/export_test.go @@ -1,7 +1,6 @@ package performer import ( - "database/sql" "errors" "github.com/stashapp/stash/pkg/hash/md5" @@ -22,28 +21,32 @@ const ( ) const ( - performerName = "testPerformer" - url = "url" - aliases = "aliases" - careerLength = "careerLength" - country = "country" - ethnicity = "ethnicity" - eyeColor = "eyeColor" - fakeTits = "fakeTits" - gender = "gender" - height = "height" - instagram = "instagram" - measurements = "measurements" - piercings = "piercings" - tattoos = "tattoos" - twitter = "twitter" - rating = 5 - details = "details" - hairColor = "hairColor" - weight = 60 + performerName = "testPerformer" + url = "url" + aliases = "aliases" + careerLength = "careerLength" + country = "country" + ethnicity = "ethnicity" + eyeColor = "eyeColor" + fakeTits = "fakeTits" + gender = "gender" + height = "height" + instagram = "instagram" + measurements = "measurements" + piercings = "piercings" + tattoos = "tattoos" + twitter = "twitter" + details = "details" + hairColor = "hairColor" + autoTagIgnored = true ) +var ( + rating = 5 + weight = 60 +) + var imageBytes = []byte("imageBytes") var stashID = models.StashID{ @@ -56,14 +59,8 @@ var stashIDs = []models.StashID{ const image = "aW1hZ2VCeXRlcw==" -var birthDate = models.SQLiteDate{ - String: "2001-01-01", - Valid: true, -} -var deathDate = models.SQLiteDate{ - String: "2021-02-02", - Valid: true, -} +var birthDate = models.NewDate("2001-01-01") +var deathDate = models.NewDate("2021-02-02") var ( createTime = time.Date(2001, 01, 01, 0, 0, 0, 0, time.Local) @@ -72,55 +69,41 @@ var ( func createFullPerformer(id int, name string) *models.Performer { return &models.Performer{ - ID: id, - Name: models.NullString(name), - Checksum: md5.FromString(name), - URL: models.NullString(url), - Aliases: models.NullString(aliases), - Birthdate: birthDate, - CareerLength: models.NullString(careerLength), - Country: models.NullString(country), - Ethnicity: models.NullString(ethnicity), - EyeColor: models.NullString(eyeColor), - FakeTits: models.NullString(fakeTits), - Favorite: sql.NullBool{ - Bool: true, - Valid: true, - }, - Gender: models.NullString(gender), - Height: models.NullString(height), - Instagram: models.NullString(instagram), - Measurements: models.NullString(measurements), - Piercings: models.NullString(piercings), - Tattoos: models.NullString(tattoos), - Twitter: models.NullString(twitter), - CreatedAt: models.SQLiteTimestamp{ - Timestamp: createTime, - }, - UpdatedAt: models.SQLiteTimestamp{ - Timestamp: updateTime, - }, - Rating: models.NullInt64(rating), - Details: models.NullString(details), - DeathDate: deathDate, - HairColor: models.NullString(hairColor), - Weight: sql.NullInt64{ - Int64: weight, - Valid: true, - }, + ID: id, + Name: name, + Checksum: md5.FromString(name), + URL: url, + Aliases: aliases, + Birthdate: &birthDate, + CareerLength: careerLength, + Country: country, + Ethnicity: ethnicity, + EyeColor: eyeColor, + FakeTits: fakeTits, + Favorite: true, + Gender: gender, + Height: height, + Instagram: instagram, + Measurements: measurements, + Piercings: piercings, + Tattoos: tattoos, + Twitter: twitter, + CreatedAt: createTime, + UpdatedAt: updateTime, + Rating: &rating, + Details: details, + DeathDate: &deathDate, + HairColor: hairColor, + Weight: &weight, IgnoreAutoTag: autoTagIgnored, } } func createEmptyPerformer(id int) models.Performer { return models.Performer{ - ID: id, - CreatedAt: models.SQLiteTimestamp{ - Timestamp: createTime, - }, - UpdatedAt: models.SQLiteTimestamp{ - Timestamp: updateTime, - }, + ID: id, + CreatedAt: createTime, + UpdatedAt: updateTime, } } @@ -129,7 +112,7 @@ func createFullJSONPerformer(name string, image string) *jsonschema.Performer { Name: name, URL: url, Aliases: aliases, - Birthdate: birthDate.String, + Birthdate: birthDate.String(), CareerLength: careerLength, Country: country, Ethnicity: ethnicity, @@ -152,7 +135,7 @@ func createFullJSONPerformer(name string, image string) *jsonschema.Performer { Rating: rating, Image: image, Details: details, - DeathDate: deathDate.String, + DeathDate: deathDate.String(), HairColor: hairColor, Weight: weight, StashIDs: []models.StashID{ diff --git a/pkg/performer/import.go b/pkg/performer/import.go index 7c673fb34..05b4c4b2c 100644 --- a/pkg/performer/import.go +++ b/pkg/performer/import.go @@ -2,7 +2,6 @@ package performer import ( "context" - "database/sql" "fmt" "strings" @@ -16,7 +15,7 @@ import ( type NameFinderCreatorUpdater interface { NameFinderCreator - UpdateFull(ctx context.Context, updatedPerformer models.Performer) (*models.Performer, error) + Update(ctx context.Context, updatedPerformer *models.Performer) error UpdateTags(ctx context.Context, performerID int, tagIDs []int) error UpdateImage(ctx context.Context, performerID int, image []byte) error UpdateStashIDs(ctx context.Context, performerID int, stashIDs []models.StashID) error @@ -164,19 +163,19 @@ func (i *Importer) FindExistingID(ctx context.Context) (*int, error) { } func (i *Importer) Create(ctx context.Context) (*int, error) { - created, err := i.ReaderWriter.Create(ctx, i.performer) + err := i.ReaderWriter.Create(ctx, &i.performer) if err != nil { return nil, fmt.Errorf("error creating performer: %v", err) } - id := created.ID + id := i.performer.ID return &id, nil } func (i *Importer) Update(ctx context.Context, id int) error { performer := i.performer performer.ID = id - _, err := i.ReaderWriter.UpdateFull(ctx, performer) + err := i.ReaderWriter.Update(ctx, &performer) if err != nil { return fmt.Errorf("error updating existing performer: %v", err) } @@ -188,75 +187,52 @@ func performerJSONToPerformer(performerJSON jsonschema.Performer) models.Perform checksum := md5.FromString(performerJSON.Name) newPerformer := models.Performer{ + Name: performerJSON.Name, Checksum: checksum, - Favorite: sql.NullBool{Bool: performerJSON.Favorite, Valid: true}, + Gender: models.GenderEnum(performerJSON.Gender), + URL: performerJSON.URL, + Ethnicity: performerJSON.Ethnicity, + Country: performerJSON.Country, + EyeColor: performerJSON.EyeColor, + Height: performerJSON.Height, + Measurements: performerJSON.Measurements, + FakeTits: performerJSON.FakeTits, + CareerLength: performerJSON.CareerLength, + Tattoos: performerJSON.Tattoos, + Piercings: performerJSON.Piercings, + Aliases: performerJSON.Aliases, + Twitter: performerJSON.Twitter, + Instagram: performerJSON.Instagram, + Details: performerJSON.Details, + HairColor: performerJSON.HairColor, + Favorite: performerJSON.Favorite, IgnoreAutoTag: performerJSON.IgnoreAutoTag, - CreatedAt: models.SQLiteTimestamp{Timestamp: performerJSON.CreatedAt.GetTime()}, - UpdatedAt: models.SQLiteTimestamp{Timestamp: performerJSON.UpdatedAt.GetTime()}, + CreatedAt: performerJSON.CreatedAt.GetTime(), + UpdatedAt: performerJSON.UpdatedAt.GetTime(), } - if performerJSON.Name != "" { - newPerformer.Name = sql.NullString{String: performerJSON.Name, Valid: true} - } - if performerJSON.Gender != "" { - newPerformer.Gender = sql.NullString{String: performerJSON.Gender, Valid: true} - } - if performerJSON.URL != "" { - newPerformer.URL = sql.NullString{String: performerJSON.URL, Valid: true} - } if performerJSON.Birthdate != "" { - newPerformer.Birthdate = models.SQLiteDate{String: performerJSON.Birthdate, Valid: true} - } - if performerJSON.Ethnicity != "" { - newPerformer.Ethnicity = sql.NullString{String: performerJSON.Ethnicity, Valid: true} - } - if performerJSON.Country != "" { - newPerformer.Country = sql.NullString{String: performerJSON.Country, Valid: true} - } - if performerJSON.EyeColor != "" { - newPerformer.EyeColor = sql.NullString{String: performerJSON.EyeColor, Valid: true} - } - if performerJSON.Height != "" { - newPerformer.Height = sql.NullString{String: performerJSON.Height, Valid: true} - } - if performerJSON.Measurements != "" { - newPerformer.Measurements = sql.NullString{String: performerJSON.Measurements, Valid: true} - } - if performerJSON.FakeTits != "" { - newPerformer.FakeTits = sql.NullString{String: performerJSON.FakeTits, Valid: true} - } - if performerJSON.CareerLength != "" { - newPerformer.CareerLength = sql.NullString{String: performerJSON.CareerLength, Valid: true} - } - if performerJSON.Tattoos != "" { - newPerformer.Tattoos = sql.NullString{String: performerJSON.Tattoos, Valid: true} - } - if performerJSON.Piercings != "" { - newPerformer.Piercings = sql.NullString{String: performerJSON.Piercings, Valid: true} - } - if performerJSON.Aliases != "" { - newPerformer.Aliases = sql.NullString{String: performerJSON.Aliases, Valid: true} - } - if performerJSON.Twitter != "" { - newPerformer.Twitter = sql.NullString{String: performerJSON.Twitter, Valid: true} - } - if performerJSON.Instagram != "" { - newPerformer.Instagram = sql.NullString{String: performerJSON.Instagram, Valid: true} + d, err := utils.ParseDateStringAsTime(performerJSON.Birthdate) + if err == nil { + newPerformer.Birthdate = &models.Date{ + Time: d, + } + } } if performerJSON.Rating != 0 { - newPerformer.Rating = sql.NullInt64{Int64: int64(performerJSON.Rating), Valid: true} - } - if performerJSON.Details != "" { - newPerformer.Details = sql.NullString{String: performerJSON.Details, Valid: true} + newPerformer.Rating = &performerJSON.Rating } if performerJSON.DeathDate != "" { - newPerformer.DeathDate = models.SQLiteDate{String: performerJSON.DeathDate, Valid: true} - } - if performerJSON.HairColor != "" { - newPerformer.HairColor = sql.NullString{String: performerJSON.HairColor, Valid: true} + d, err := utils.ParseDateStringAsTime(performerJSON.DeathDate) + if err == nil { + newPerformer.DeathDate = &models.Date{ + Time: d, + } + } } + if performerJSON.Weight != 0 { - newPerformer.Weight = sql.NullInt64{Int64: int64(performerJSON.Weight), Valid: true} + newPerformer.Weight = &performerJSON.Weight } return newPerformer diff --git a/pkg/performer/import_test.go b/pkg/performer/import_test.go index 4f80a67c0..08f2c5b0c 100644 --- a/pkg/performer/import_test.go +++ b/pkg/performer/import_test.go @@ -237,11 +237,11 @@ func TestCreate(t *testing.T) { readerWriter := &mocks.PerformerReaderWriter{} performer := models.Performer{ - Name: models.NullString(performerName), + Name: performerName, } performerErr := models.Performer{ - Name: models.NullString(performerNameErr), + Name: performerNameErr, } i := Importer{ @@ -250,10 +250,11 @@ func TestCreate(t *testing.T) { } errCreate := errors.New("Create error") - readerWriter.On("Create", testCtx, performer).Return(&models.Performer{ - ID: performerID, - }, nil).Once() - readerWriter.On("Create", testCtx, performerErr).Return(nil, errCreate).Once() + readerWriter.On("Create", testCtx, &performer).Run(func(args mock.Arguments) { + arg := args.Get(1).(*models.Performer) + arg.ID = performerID + }).Return(nil).Once() + readerWriter.On("Create", testCtx, &performerErr).Return(errCreate).Once() id, err := i.Create(testCtx) assert.Equal(t, performerID, *id) @@ -271,11 +272,11 @@ func TestUpdate(t *testing.T) { readerWriter := &mocks.PerformerReaderWriter{} performer := models.Performer{ - Name: models.NullString(performerName), + Name: performerName, } performerErr := models.Performer{ - Name: models.NullString(performerNameErr), + Name: performerNameErr, } i := Importer{ @@ -287,7 +288,7 @@ func TestUpdate(t *testing.T) { // id needs to be set for the mock input performer.ID = performerID - readerWriter.On("UpdateFull", testCtx, performer).Return(nil, nil).Once() + readerWriter.On("Update", testCtx, &performer).Return(nil).Once() err := i.Update(testCtx, performerID) assert.Nil(t, err) @@ -296,7 +297,7 @@ func TestUpdate(t *testing.T) { // need to set id separately performerErr.ID = errImageID - readerWriter.On("UpdateFull", testCtx, performerErr).Return(nil, errUpdate).Once() + readerWriter.On("Update", testCtx, &performerErr).Return(errUpdate).Once() err = i.Update(testCtx, errImageID) assert.NotNil(t, err) diff --git a/pkg/performer/update.go b/pkg/performer/update.go index 5974a5eab..ed10246fa 100644 --- a/pkg/performer/update.go +++ b/pkg/performer/update.go @@ -8,5 +8,5 @@ import ( type NameFinderCreator interface { FindByNames(ctx context.Context, names []string, nocase bool) ([]*models.Performer, error) - Create(ctx context.Context, newPerformer models.Performer) (*models.Performer, error) + Create(ctx context.Context, newPerformer *models.Performer) error } diff --git a/pkg/performer/validate.go b/pkg/performer/validate.go index 374262590..4c0b5a919 100644 --- a/pkg/performer/validate.go +++ b/pkg/performer/validate.go @@ -14,11 +14,13 @@ func ValidateDeathDate(performer *models.Performer, birthdate *string, deathDate } if performer != nil { - if birthdate == nil && performer.Birthdate.Valid { - birthdate = &performer.Birthdate.String + if birthdate == nil && performer.Birthdate != nil { + s := performer.Birthdate.String() + birthdate = &s } - if deathDate == nil && performer.DeathDate.Valid { - deathDate = &performer.DeathDate.String + if deathDate == nil && performer.DeathDate != nil { + s := performer.DeathDate.String() + deathDate = &s } } diff --git a/pkg/performer/validate_test.go b/pkg/performer/validate_test.go index 33616e184..dbbb7fc98 100644 --- a/pkg/performer/validate_test.go +++ b/pkg/performer/validate_test.go @@ -16,26 +16,17 @@ func TestValidateDeathDate(t *testing.T) { date4 := "2004-01-01" empty := "" + md2 := models.NewDate(date2) + md3 := models.NewDate(date3) + emptyPerformer := models.Performer{} invalidPerformer := models.Performer{ - Birthdate: models.SQLiteDate{ - String: date3, - Valid: true, - }, - DeathDate: models.SQLiteDate{ - String: date2, - Valid: true, - }, + Birthdate: &md3, + DeathDate: &md2, } validPerformer := models.Performer{ - Birthdate: models.SQLiteDate{ - String: date2, - Valid: true, - }, - DeathDate: models.SQLiteDate{ - String: date3, - Valid: true, - }, + Birthdate: &md2, + DeathDate: &md3, } // nil values should always return nil diff --git a/pkg/scene/import.go b/pkg/scene/import.go index 79d95aa04..1ca7d2b58 100644 --- a/pkg/scene/import.go +++ b/pkg/scene/import.go @@ -231,10 +231,10 @@ func (i *Importer) populatePerformers(ctx context.Context) error { var pluckedNames []string for _, performer := range performers { - if !performer.Name.Valid { + if performer.Name == "" { continue } - pluckedNames = append(pluckedNames, performer.Name.String) + pluckedNames = append(pluckedNames, performer.Name) } missingPerformers := stringslice.StrFilter(names, func(name string) bool { @@ -271,12 +271,12 @@ func (i *Importer) createPerformers(ctx context.Context, names []string) ([]*mod for _, name := range names { newPerformer := *models.NewPerformer(name) - created, err := i.PerformerWriter.Create(ctx, newPerformer) + err := i.PerformerWriter.Create(ctx, &newPerformer) if err != nil { return nil, err } - ret = append(ret, created) + ret = append(ret, &newPerformer) } return ret, nil diff --git a/pkg/scene/import_test.go b/pkg/scene/import_test.go index 5a5fd5026..2e4d65f05 100644 --- a/pkg/scene/import_test.go +++ b/pkg/scene/import_test.go @@ -147,7 +147,7 @@ func TestImporterPreImportWithPerformer(t *testing.T) { performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerName}, false).Return([]*models.Performer{ { ID: existingPerformerID, - Name: models.NullString(existingPerformerName), + Name: existingPerformerName, }, }, nil).Once() performerReaderWriter.On("FindByNames", testCtx, []string{existingPerformerErr}, false).Return(nil, errors.New("FindByNames error")).Once() @@ -177,9 +177,10 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) { } performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3) - performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Performer")).Return(&models.Performer{ - ID: existingPerformerID, - }, nil) + performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) { + p := args.Get(1).(*models.Performer) + p.ID = existingPerformerID + }).Return(nil) err := i.PreImport(testCtx) assert.NotNil(t, err) @@ -210,7 +211,7 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) { } performerReaderWriter.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once() - performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("models.Performer")).Return(nil, errors.New("Create error")) + performerReaderWriter.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error")) err := i.PreImport(testCtx) assert.NotNil(t, err) diff --git a/pkg/scraper/autotag.go b/pkg/scraper/autotag.go index cbcd38cfa..5af7b8c70 100644 --- a/pkg/scraper/autotag.go +++ b/pkg/scraper/autotag.go @@ -38,11 +38,12 @@ func autotagMatchPerformers(ctx context.Context, path string, performerReader ma id := strconv.Itoa(pp.ID) sp := &models.ScrapedPerformer{ - Name: &pp.Name.String, + Name: &pp.Name, StoredID: &id, } - if pp.Gender.Valid { - sp.Gender = &pp.Gender.String + if pp.Gender.IsValid() { + v := pp.Gender.String() + sp.Gender = &v } ret = append(ret, sp) diff --git a/pkg/scraper/stashbox/stash_box.go b/pkg/scraper/stashbox/stash_box.go index a560c25d4..a0323dfce 100644 --- a/pkg/scraper/stashbox/stash_box.go +++ b/pkg/scraper/stashbox/stash_box.go @@ -379,7 +379,7 @@ func (c Client) FindStashBoxPerformersByNames(ctx context.Context, performerIDs return fmt.Errorf("performer with id %d not found", performerID) } - if performer.Name.Valid { + if performer.Name != "" { performers = append(performers, performer) } } @@ -413,7 +413,7 @@ func (c Client) FindStashBoxPerformersByPerformerNames(ctx context.Context, perf return fmt.Errorf("performer with id %d not found", performerID) } - if performer.Name.Valid { + if performer.Name != "" { performers = append(performers, performer) } } @@ -439,8 +439,8 @@ func (c Client) FindStashBoxPerformersByPerformerNames(ctx context.Context, perf func (c Client) findStashBoxPerformersByNames(ctx context.Context, performers []*models.Performer) ([]*StashBoxPerformerQueryResult, error) { var ret []*StashBoxPerformerQueryResult for _, performer := range performers { - if performer.Name.Valid { - performerResults, err := c.queryStashBoxPerformer(ctx, performer.Name.String) + if performer.Name != "" { + performerResults, err := c.queryStashBoxPerformer(ctx, performer.Name) if err != nil { return nil, err } @@ -867,7 +867,7 @@ func (c Client) SubmitSceneDraft(ctx context.Context, scene *models.Scene, endpo performers := []*graphql.DraftEntityInput{} for _, p := range scenePerformers { performerDraft := graphql.DraftEntityInput{ - Name: p.Name.String, + Name: p.Name, } stashIDs, err := pqb.GetStashIDs(ctx, p.ID) @@ -947,55 +947,57 @@ func (c Client) SubmitPerformerDraft(ctx context.Context, performer *models.Perf image = bytes.NewReader(img) } - if performer.Name.Valid { - draft.Name = performer.Name.String + if performer.Name != "" { + draft.Name = performer.Name } - if performer.Birthdate.Valid { - draft.Birthdate = &performer.Birthdate.String + if performer.Birthdate != nil { + d := performer.Birthdate.String() + draft.Birthdate = &d } - if performer.Country.Valid { - draft.Country = &performer.Country.String + if performer.Country != "" { + draft.Country = &performer.Country } - if performer.Ethnicity.Valid { - draft.Ethnicity = &performer.Ethnicity.String + if performer.Ethnicity != "" { + draft.Ethnicity = &performer.Ethnicity } - if performer.EyeColor.Valid { - draft.EyeColor = &performer.EyeColor.String + if performer.EyeColor != "" { + draft.EyeColor = &performer.EyeColor } - if performer.FakeTits.Valid { - draft.BreastType = &performer.FakeTits.String + if performer.FakeTits != "" { + draft.BreastType = &performer.FakeTits } - if performer.Gender.Valid { - draft.Gender = &performer.Gender.String + if performer.Gender.IsValid() { + v := performer.Gender.String() + draft.Gender = &v } - if performer.HairColor.Valid { - draft.HairColor = &performer.HairColor.String + if performer.HairColor != "" { + draft.HairColor = &performer.HairColor } - if performer.Height.Valid { - draft.Height = &performer.Height.String + if performer.Height != "" { + draft.Height = &performer.Height } - if performer.Measurements.Valid { - draft.Measurements = &performer.Measurements.String + if performer.Measurements != "" { + draft.Measurements = &performer.Measurements } - if performer.Piercings.Valid { - draft.Piercings = &performer.Piercings.String + if performer.Piercings != "" { + draft.Piercings = &performer.Piercings } - if performer.Tattoos.Valid { - draft.Tattoos = &performer.Tattoos.String + if performer.Tattoos != "" { + draft.Tattoos = &performer.Tattoos } - if performer.Aliases.Valid { - draft.Aliases = &performer.Aliases.String + if performer.Aliases != "" { + draft.Aliases = &performer.Aliases } var urls []string - if len(strings.TrimSpace(performer.Twitter.String)) > 0 { - urls = append(urls, "https://twitter.com/"+strings.TrimSpace(performer.Twitter.String)) + if len(strings.TrimSpace(performer.Twitter)) > 0 { + urls = append(urls, "https://twitter.com/"+strings.TrimSpace(performer.Twitter)) } - if len(strings.TrimSpace(performer.Instagram.String)) > 0 { - urls = append(urls, "https://instagram.com/"+strings.TrimSpace(performer.Instagram.String)) + if len(strings.TrimSpace(performer.Instagram)) > 0 { + urls = append(urls, "https://instagram.com/"+strings.TrimSpace(performer.Instagram)) } - if len(strings.TrimSpace(performer.URL.String)) > 0 { - urls = append(urls, strings.TrimSpace(performer.URL.String)) + if len(strings.TrimSpace(performer.URL)) > 0 { + urls = append(urls, strings.TrimSpace(performer.URL)) } if len(urls) > 0 { draft.Urls = urls diff --git a/pkg/sqlite/database.go b/pkg/sqlite/database.go index 33f604cda..9f15d0a3b 100644 --- a/pkg/sqlite/database.go +++ b/pkg/sqlite/database.go @@ -61,11 +61,12 @@ func init() { } type Database struct { - File *FileStore - Folder *FolderStore - Image *ImageStore - Gallery *GalleryStore - Scene *SceneStore + File *FileStore + Folder *FolderStore + Image *ImageStore + Gallery *GalleryStore + Scene *SceneStore + Performer *PerformerStore db *sqlx.DB dbPath string @@ -80,11 +81,12 @@ func NewDatabase() *Database { folderStore := NewFolderStore() ret := &Database{ - File: fileStore, - Folder: folderStore, - Scene: NewSceneStore(fileStore), - Image: NewImageStore(fileStore), - Gallery: NewGalleryStore(fileStore, folderStore), + File: fileStore, + Folder: folderStore, + Scene: NewSceneStore(fileStore), + Image: NewImageStore(fileStore), + Gallery: NewGalleryStore(fileStore, folderStore), + Performer: NewPerformerStore(), } return ret diff --git a/pkg/sqlite/performer.go b/pkg/sqlite/performer.go index 6bf42dc18..3cf6fc15b 100644 --- a/pkg/sqlite/performer.go +++ b/pkg/sqlite/performer.go @@ -3,15 +3,17 @@ package sqlite import ( "context" "database/sql" - "errors" "fmt" "strings" "github.com/doug-martin/goqu/v9" + "github.com/doug-martin/goqu/v9/exp" "github.com/jmoiron/sqlx" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/sliceutil/intslice" "github.com/stashapp/stash/pkg/utils" + "gopkg.in/guregu/null.v4" + "gopkg.in/guregu/null.v4/zero" ) const performerTable = "performers" @@ -19,82 +21,227 @@ const performerIDColumn = "performer_id" const performersTagsTable = "performers_tags" const performersImageTable = "performers_image" // performer cover image -var countPerformersForTagQuery = ` -SELECT tag_id AS id FROM performers_tags -WHERE performers_tags.tag_id = ? -GROUP BY performers_tags.performer_id -` +type performerRow struct { + ID int `db:"id" goqu:"skipinsert"` + Checksum string `db:"checksum"` + Name zero.String `db:"name"` + Gender zero.String `db:"gender"` + URL zero.String `db:"url"` + Twitter zero.String `db:"twitter"` + Instagram zero.String `db:"instagram"` + Birthdate models.SQLiteDate `db:"birthdate"` + Ethnicity zero.String `db:"ethnicity"` + Country zero.String `db:"country"` + EyeColor zero.String `db:"eye_color"` + Height zero.String `db:"height"` + Measurements zero.String `db:"measurements"` + FakeTits zero.String `db:"fake_tits"` + CareerLength zero.String `db:"career_length"` + Tattoos zero.String `db:"tattoos"` + Piercings zero.String `db:"piercings"` + Aliases zero.String `db:"aliases"` + Favorite sql.NullBool `db:"favorite"` + CreatedAt models.SQLiteTimestamp `db:"created_at"` + UpdatedAt models.SQLiteTimestamp `db:"updated_at"` + Rating null.Int `db:"rating"` + Details zero.String `db:"details"` + DeathDate models.SQLiteDate `db:"death_date"` + HairColor zero.String `db:"hair_color"` + Weight null.Int `db:"weight"` + IgnoreAutoTag bool `db:"ignore_auto_tag"` +} -type performerQueryBuilder struct { +func (r *performerRow) fromPerformer(o models.Performer) { + r.ID = o.ID + r.Checksum = o.Checksum + r.Name = zero.StringFrom(o.Name) + if o.Gender.IsValid() { + r.Gender = zero.StringFrom(o.Gender.String()) + } + r.URL = zero.StringFrom(o.URL) + r.Twitter = zero.StringFrom(o.Twitter) + r.Instagram = zero.StringFrom(o.Instagram) + if o.Birthdate != nil { + _ = r.Birthdate.Scan(o.Birthdate.Time) + } + r.Ethnicity = zero.StringFrom(o.Ethnicity) + r.Country = zero.StringFrom(o.Country) + r.EyeColor = zero.StringFrom(o.EyeColor) + r.Height = zero.StringFrom(o.Height) + r.Measurements = zero.StringFrom(o.Measurements) + r.FakeTits = zero.StringFrom(o.FakeTits) + r.CareerLength = zero.StringFrom(o.CareerLength) + r.Tattoos = zero.StringFrom(o.Tattoos) + r.Piercings = zero.StringFrom(o.Piercings) + r.Aliases = zero.StringFrom(o.Aliases) + r.Favorite = sql.NullBool{Bool: o.Favorite, Valid: true} + r.CreatedAt = models.SQLiteTimestamp{Timestamp: o.CreatedAt} + r.UpdatedAt = models.SQLiteTimestamp{Timestamp: o.UpdatedAt} + r.Rating = intFromPtr(o.Rating) + r.Details = zero.StringFrom(o.Details) + if o.DeathDate != nil { + _ = r.DeathDate.Scan(o.DeathDate.Time) + } + r.HairColor = zero.StringFrom(o.HairColor) + r.Weight = intFromPtr(o.Weight) + r.IgnoreAutoTag = o.IgnoreAutoTag +} + +func (r *performerRow) resolve() *models.Performer { + ret := &models.Performer{ + ID: r.ID, + Checksum: r.Checksum, + Name: r.Name.String, + Gender: models.GenderEnum(r.Gender.String), + URL: r.URL.String, + Twitter: r.Twitter.String, + Instagram: r.Instagram.String, + Birthdate: r.Birthdate.DatePtr(), + Ethnicity: r.Ethnicity.String, + Country: r.Country.String, + EyeColor: r.EyeColor.String, + Height: r.Height.String, + Measurements: r.Measurements.String, + FakeTits: r.FakeTits.String, + CareerLength: r.CareerLength.String, + Tattoos: r.Tattoos.String, + Piercings: r.Piercings.String, + Aliases: r.Aliases.String, + Favorite: r.Favorite.Bool, + CreatedAt: r.CreatedAt.Timestamp, + UpdatedAt: r.UpdatedAt.Timestamp, + Rating: nullIntPtr(r.Rating), + Details: r.Details.String, + DeathDate: r.DeathDate.DatePtr(), + HairColor: r.HairColor.String, + Weight: nullIntPtr(r.Weight), + IgnoreAutoTag: r.IgnoreAutoTag, + } + + return ret +} + +type performerRowRecord struct { + updateRecord +} + +func (r *performerRowRecord) fromPartial(o models.PerformerPartial) { + r.setNullString("checksum", o.Checksum) + r.setNullString("name", o.Name) + r.setNullString("gender", o.Gender) + r.setNullString("url", o.URL) + r.setNullString("twitter", o.Twitter) + r.setNullString("instagram", o.Instagram) + r.setSQLiteDate("birthdate", o.Birthdate) + r.setNullString("ethnicity", o.Ethnicity) + r.setNullString("country", o.Country) + r.setNullString("eye_color", o.EyeColor) + r.setNullString("height", o.Height) + r.setNullString("measurements", o.Measurements) + r.setNullString("fake_tits", o.FakeTits) + r.setNullString("career_length", o.CareerLength) + r.setNullString("tattoos", o.Tattoos) + r.setNullString("piercings", o.Piercings) + r.setNullString("aliases", o.Aliases) + r.setBool("favorite", o.Favorite) + r.setSQLiteTimestamp("created_at", o.CreatedAt) + r.setSQLiteTimestamp("updated_at", o.UpdatedAt) + r.setNullInt("rating", o.Rating) + r.setNullString("details", o.Details) + r.setSQLiteDate("death_date", o.DeathDate) + r.setNullString("hair_color", o.HairColor) + r.setNullInt("weight", o.Weight) + r.setBool("ignore_auto_tag", o.IgnoreAutoTag) +} + +type PerformerStore struct { repository + + tableMgr *table } -var PerformerReaderWriter = &performerQueryBuilder{ - repository{ - tableName: performerTable, - idColumn: idColumn, - }, +func NewPerformerStore() *PerformerStore { + return &PerformerStore{ + repository: repository{ + tableName: performerTable, + idColumn: idColumn, + }, + tableMgr: performerTableMgr, + } } -func (qb *performerQueryBuilder) Create(ctx context.Context, newObject models.Performer) (*models.Performer, error) { - var ret models.Performer - if err := qb.insertObject(ctx, newObject, &ret); err != nil { - return nil, err - } +func (qb *PerformerStore) Create(ctx context.Context, newObject *models.Performer) error { + var r performerRow + r.fromPerformer(*newObject) - return &ret, nil -} - -func (qb *performerQueryBuilder) Update(ctx context.Context, updatedObject models.PerformerPartial) (*models.Performer, error) { - const partial = true - if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil { - return nil, err - } - - var ret models.Performer - if err := qb.getByID(ctx, updatedObject.ID, &ret); err != nil { - return nil, err - } - - return &ret, nil -} - -func (qb *performerQueryBuilder) UpdateFull(ctx context.Context, updatedObject models.Performer) (*models.Performer, error) { - const partial = false - if err := qb.update(ctx, updatedObject.ID, updatedObject, partial); err != nil { - return nil, err - } - - var ret models.Performer - if err := qb.getByID(ctx, updatedObject.ID, &ret); err != nil { - return nil, err - } - - return &ret, nil -} - -func (qb *performerQueryBuilder) Destroy(ctx context.Context, id int) error { - // TODO - add on delete cascade to performers_scenes - _, err := qb.tx.Exec(ctx, "DELETE FROM performers_scenes WHERE performer_id = ?", id) + id, err := qb.tableMgr.insertID(ctx, r) if err != nil { return err } + updated, err := qb.Find(ctx, id) + if err != nil { + return fmt.Errorf("finding after create: %w", err) + } + + *newObject = *updated + + return nil +} + +func (qb *PerformerStore) UpdatePartial(ctx context.Context, id int, updatedObject models.PerformerPartial) (*models.Performer, error) { + r := performerRowRecord{ + updateRecord{ + Record: make(exp.Record), + }, + } + + r.fromPartial(updatedObject) + + if len(r.Record) > 0 { + if err := qb.tableMgr.updateByID(ctx, id, r.Record); err != nil { + return nil, err + } + } + + return qb.Find(ctx, id) +} + +func (qb *PerformerStore) Update(ctx context.Context, updatedObject *models.Performer) error { + var r performerRow + r.fromPerformer(*updatedObject) + + if err := qb.tableMgr.updateByID(ctx, updatedObject.ID, r); err != nil { + return err + } + + return nil +} + +func (qb *PerformerStore) Destroy(ctx context.Context, id int) error { return qb.destroyExisting(ctx, []int{id}) } -func (qb *performerQueryBuilder) Find(ctx context.Context, id int) (*models.Performer, error) { - var ret models.Performer - if err := qb.getByID(ctx, id, &ret); err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, nil - } - return nil, err - } - return &ret, nil +func (qb *PerformerStore) table() exp.IdentifierExpression { + return qb.tableMgr.table } -func (qb *performerQueryBuilder) FindMany(ctx context.Context, ids []int) ([]*models.Performer, error) { +func (qb *PerformerStore) selectDataset() *goqu.SelectDataset { + return dialect.From(qb.table()).Select(qb.table().All()) +} + +func (qb *PerformerStore) Find(ctx context.Context, id int) (*models.Performer, error) { + q := qb.selectDataset().Where(qb.tableMgr.byID(id)) + + ret, err := qb.get(ctx, q) + if err != nil { + return nil, fmt.Errorf("getting scene by id %d: %w", id, err) + } + + return ret, nil +} + +func (qb *PerformerStore) FindMany(ctx context.Context, ids []int) ([]*models.Performer, error) { tableMgr := performerTableMgr q := goqu.Select("*").From(tableMgr.table).Where(tableMgr.byIDInts(ids...)) unsorted, err := qb.getMany(ctx, q) @@ -118,16 +265,31 @@ func (qb *performerQueryBuilder) FindMany(ctx context.Context, ids []int) ([]*mo return ret, nil } -func (qb *performerQueryBuilder) getMany(ctx context.Context, q *goqu.SelectDataset) ([]*models.Performer, error) { +func (qb *PerformerStore) get(ctx context.Context, q *goqu.SelectDataset) (*models.Performer, error) { + ret, err := qb.getMany(ctx, q) + if err != nil { + return nil, err + } + + if len(ret) == 0 { + return nil, sql.ErrNoRows + } + + return ret[0], nil +} + +func (qb *PerformerStore) getMany(ctx context.Context, q *goqu.SelectDataset) ([]*models.Performer, error) { const single = false var ret []*models.Performer if err := queryFunc(ctx, q, single, func(r *sqlx.Rows) error { - var f models.Performer + var f performerRow if err := r.StructScan(&f); err != nil { return err } - ret = append(ret, &f) + s := f.resolve() + + ret = append(ret, s) return nil }); err != nil { return nil, err @@ -136,95 +298,126 @@ func (qb *performerQueryBuilder) getMany(ctx context.Context, q *goqu.SelectData return ret, nil } -func (qb *performerQueryBuilder) FindBySceneID(ctx context.Context, sceneID int) ([]*models.Performer, error) { - query := selectAll("performers") + ` - LEFT JOIN performers_scenes as scenes_join on scenes_join.performer_id = performers.id - WHERE scenes_join.scene_id = ? - ` - args := []interface{}{sceneID} - return qb.queryPerformers(ctx, query, args) +func (qb *PerformerStore) findBySubquery(ctx context.Context, sq *goqu.SelectDataset) ([]*models.Performer, error) { + table := qb.table() + + q := qb.selectDataset().Where( + table.Col(idColumn).Eq( + sq, + ), + ) + + return qb.getMany(ctx, q) } -func (qb *performerQueryBuilder) FindByImageID(ctx context.Context, imageID int) ([]*models.Performer, error) { - query := selectAll("performers") + ` - LEFT JOIN performers_images as images_join on images_join.performer_id = performers.id - WHERE images_join.image_id = ? - ` - args := []interface{}{imageID} - return qb.queryPerformers(ctx, query, args) -} +func (qb *PerformerStore) FindBySceneID(ctx context.Context, sceneID int) ([]*models.Performer, error) { + sq := dialect.From(scenesPerformersJoinTable).Select(scenesPerformersJoinTable.Col(performerIDColumn)).Where( + scenesPerformersJoinTable.Col(sceneIDColumn).Eq(sceneID), + ) + ret, err := qb.findBySubquery(ctx, sq) -func (qb *performerQueryBuilder) FindByGalleryID(ctx context.Context, galleryID int) ([]*models.Performer, error) { - query := selectAll("performers") + ` - LEFT JOIN performers_galleries as galleries_join on galleries_join.performer_id = performers.id - WHERE galleries_join.gallery_id = ? - ` - args := []interface{}{galleryID} - return qb.queryPerformers(ctx, query, args) -} - -func (qb *performerQueryBuilder) FindNamesBySceneID(ctx context.Context, sceneID int) ([]*models.Performer, error) { - query := ` - SELECT performers.name FROM performers - LEFT JOIN performers_scenes as scenes_join on scenes_join.performer_id = performers.id - WHERE scenes_join.scene_id = ? - ` - args := []interface{}{sceneID} - return qb.queryPerformers(ctx, query, args) -} - -func (qb *performerQueryBuilder) FindByNames(ctx context.Context, names []string, nocase bool) ([]*models.Performer, error) { - query := "SELECT * FROM performers WHERE name" - if nocase { - query += " COLLATE NOCASE" + if err != nil { + return nil, fmt.Errorf("getting performers for scene %d: %w", sceneID, err) } - query += " IN " + getInBinding(len(names)) + + return ret, nil +} + +func (qb *PerformerStore) FindByImageID(ctx context.Context, imageID int) ([]*models.Performer, error) { + sq := dialect.From(performersImagesJoinTable).Select(performersImagesJoinTable.Col(performerIDColumn)).Where( + performersImagesJoinTable.Col(imageIDColumn).Eq(imageID), + ) + ret, err := qb.findBySubquery(ctx, sq) + + if err != nil { + return nil, fmt.Errorf("getting performers for image %d: %w", imageID, err) + } + + return ret, nil +} + +func (qb *PerformerStore) FindByGalleryID(ctx context.Context, galleryID int) ([]*models.Performer, error) { + sq := dialect.From(performersGalleriesJoinTable).Select(performersGalleriesJoinTable.Col(performerIDColumn)).Where( + performersGalleriesJoinTable.Col(galleryIDColumn).Eq(galleryID), + ) + ret, err := qb.findBySubquery(ctx, sq) + + if err != nil { + return nil, fmt.Errorf("getting performers for gallery %d: %w", galleryID, err) + } + + return ret, nil +} + +func (qb *PerformerStore) FindByNames(ctx context.Context, names []string, nocase bool) ([]*models.Performer, error) { + clause := "name " + if nocase { + clause += "COLLATE NOCASE " + } + clause += "IN " + getInBinding(len(names)) var args []interface{} for _, name := range names { args = append(args, name) } - return qb.queryPerformers(ctx, query, args) -} -func (qb *performerQueryBuilder) CountByTagID(ctx context.Context, tagID int) (int, error) { - args := []interface{}{tagID} - return qb.runCountQuery(ctx, qb.buildCountQuery(countPerformersForTagQuery), args) -} + sq := qb.selectDataset().Prepared(true).Where( + goqu.L(clause, args...), + ) + ret, err := qb.getMany(ctx, sq) -func (qb *performerQueryBuilder) Count(ctx context.Context) (int, error) { - return qb.runCountQuery(ctx, qb.buildCountQuery("SELECT performers.id FROM performers"), nil) -} - -func (qb *performerQueryBuilder) All(ctx context.Context) ([]*models.Performer, error) { - return qb.queryPerformers(ctx, selectAll("performers")+qb.getPerformerSort(nil), nil) -} - -func (qb *performerQueryBuilder) QueryForAutoTag(ctx context.Context, words []string) ([]*models.Performer, error) { - // TODO - Query needs to be changed to support queries of this type, and - // this method should be removed - query := selectAll(performerTable) - - var whereClauses []string - var args []interface{} - - for _, w := range words { - whereClauses = append(whereClauses, "name like ?") - args = append(args, w+"%") - // TODO - commented out until alias matching works both ways - // whereClauses = append(whereClauses, "aliases like ?") - // args = append(args, w+"%") + if err != nil { + return nil, fmt.Errorf("getting performers by names: %w", err) } - whereOr := "(" + strings.Join(whereClauses, " OR ") + ")" - where := strings.Join([]string{ - "ignore_auto_tag = 0", - whereOr, - }, " AND ") - return qb.queryPerformers(ctx, query+" WHERE "+where, args) + return ret, nil } -func (qb *performerQueryBuilder) validateFilter(filter *models.PerformerFilterType) error { +func (qb *PerformerStore) CountByTagID(ctx context.Context, tagID int) (int, error) { + joinTable := performersTagsJoinTable + + q := dialect.Select(goqu.COUNT("*")).From(joinTable).Where(joinTable.Col(tagIDColumn).Eq(tagID)) + return count(ctx, q) +} + +func (qb *PerformerStore) Count(ctx context.Context) (int, error) { + q := dialect.Select(goqu.COUNT("*")).From(qb.table()) + return count(ctx, q) +} + +func (qb *PerformerStore) All(ctx context.Context) ([]*models.Performer, error) { + return qb.getMany(ctx, qb.selectDataset()) +} + +func (qb *PerformerStore) QueryForAutoTag(ctx context.Context, words []string) ([]*models.Performer, error) { + // TODO - Query needs to be changed to support queries of this type, and + // this method should be removed + table := qb.table() + sq := dialect.From(table).Select(table.Col(idColumn)).Where() + + var whereClauses []exp.Expression + + for _, w := range words { + whereClauses = append(whereClauses, table.Col("name").Like(w+"%")) + // TODO - commented out until alias matching works both ways + // whereClauses = append(whereClauses, table.Col("aliases").Like(w+"%") + } + + sq = sq.Where( + goqu.Or(whereClauses...), + table.Col("ignore_auto_tag").Eq(0), + ) + + ret, err := qb.findBySubquery(ctx, sq) + + if err != nil { + return nil, fmt.Errorf("getting performers for autotag: %w", err) + } + + return ret, nil +} + +func (qb *PerformerStore) validateFilter(filter *models.PerformerFilterType) error { const and = "AND" const or = "OR" const not = "NOT" @@ -255,7 +448,7 @@ func (qb *performerQueryBuilder) validateFilter(filter *models.PerformerFilterTy return nil } -func (qb *performerQueryBuilder) makeFilter(ctx context.Context, filter *models.PerformerFilterType) *filterBuilder { +func (qb *PerformerStore) makeFilter(ctx context.Context, filter *models.PerformerFilterType) *filterBuilder { query := &filterBuilder{} if filter.And != nil { @@ -322,7 +515,7 @@ func (qb *performerQueryBuilder) makeFilter(ctx context.Context, filter *models. return query } -func (qb *performerQueryBuilder) Query(ctx context.Context, performerFilter *models.PerformerFilterType, findFilter *models.FindFilterType) ([]*models.Performer, int, error) { +func (qb *PerformerStore) Query(ctx context.Context, performerFilter *models.PerformerFilterType, findFilter *models.FindFilterType) ([]*models.Performer, int, error) { if performerFilter == nil { performerFilter = &models.PerformerFilterType{} } @@ -359,7 +552,7 @@ func (qb *performerQueryBuilder) Query(ctx context.Context, performerFilter *mod return performers, countResult, nil } -func performerIsMissingCriterionHandler(qb *performerQueryBuilder, isMissing *string) criterionHandlerFunc { +func performerIsMissingCriterionHandler(qb *PerformerStore, isMissing *string) criterionHandlerFunc { return func(ctx context.Context, f *filterBuilder) { if isMissing != nil && *isMissing != "" { switch *isMissing { @@ -400,7 +593,7 @@ func performerAgeFilterCriterionHandler(age *models.IntCriterionInput) criterion } } -func performerTagsCriterionHandler(qb *performerQueryBuilder, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { +func performerTagsCriterionHandler(qb *PerformerStore, tags *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { h := joinedHierarchicalMultiCriterionHandlerBuilder{ tx: qb.tx, @@ -417,7 +610,7 @@ func performerTagsCriterionHandler(qb *performerQueryBuilder, tags *models.Hiera return h.handler(tags) } -func performerTagCountCriterionHandler(qb *performerQueryBuilder, count *models.IntCriterionInput) criterionHandlerFunc { +func performerTagCountCriterionHandler(qb *PerformerStore, count *models.IntCriterionInput) criterionHandlerFunc { h := countCriterionHandlerBuilder{ primaryTable: performerTable, joinTable: performersTagsTable, @@ -427,7 +620,7 @@ func performerTagCountCriterionHandler(qb *performerQueryBuilder, count *models. return h.handler(count) } -func performerSceneCountCriterionHandler(qb *performerQueryBuilder, count *models.IntCriterionInput) criterionHandlerFunc { +func performerSceneCountCriterionHandler(qb *PerformerStore, count *models.IntCriterionInput) criterionHandlerFunc { h := countCriterionHandlerBuilder{ primaryTable: performerTable, joinTable: performersScenesTable, @@ -437,7 +630,7 @@ func performerSceneCountCriterionHandler(qb *performerQueryBuilder, count *model return h.handler(count) } -func performerImageCountCriterionHandler(qb *performerQueryBuilder, count *models.IntCriterionInput) criterionHandlerFunc { +func performerImageCountCriterionHandler(qb *PerformerStore, count *models.IntCriterionInput) criterionHandlerFunc { h := countCriterionHandlerBuilder{ primaryTable: performerTable, joinTable: performersImagesTable, @@ -447,7 +640,7 @@ func performerImageCountCriterionHandler(qb *performerQueryBuilder, count *model return h.handler(count) } -func performerGalleryCountCriterionHandler(qb *performerQueryBuilder, count *models.IntCriterionInput) criterionHandlerFunc { +func performerGalleryCountCriterionHandler(qb *PerformerStore, count *models.IntCriterionInput) criterionHandlerFunc { h := countCriterionHandlerBuilder{ primaryTable: performerTable, joinTable: performersGalleriesTable, @@ -457,7 +650,7 @@ func performerGalleryCountCriterionHandler(qb *performerQueryBuilder, count *mod return h.handler(count) } -func performerStudiosCriterionHandler(qb *performerQueryBuilder, studios *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { +func performerStudiosCriterionHandler(qb *PerformerStore, studios *models.HierarchicalMultiCriterionInput) criterionHandlerFunc { return func(ctx context.Context, f *filterBuilder) { if studios != nil { formatMaps := []utils.StrFormatMap{ @@ -534,7 +727,7 @@ func performerStudiosCriterionHandler(qb *performerQueryBuilder, studios *models } } -func (qb *performerQueryBuilder) getPerformerSort(findFilter *models.FindFilterType) string { +func (qb *PerformerStore) getPerformerSort(findFilter *models.FindFilterType) string { var sort string var direction string if findFilter == nil { @@ -561,16 +754,7 @@ func (qb *performerQueryBuilder) getPerformerSort(findFilter *models.FindFilterT return getSort(sort, direction, "performers") } -func (qb *performerQueryBuilder) queryPerformers(ctx context.Context, query string, args []interface{}) ([]*models.Performer, error) { - var ret models.Performers - if err := qb.query(ctx, query, args, &ret); err != nil { - return nil, err - } - - return []*models.Performer(ret), nil -} - -func (qb *performerQueryBuilder) tagsRepository() *joinRepository { +func (qb *PerformerStore) tagsRepository() *joinRepository { return &joinRepository{ repository: repository{ tx: qb.tx, @@ -581,16 +765,16 @@ func (qb *performerQueryBuilder) tagsRepository() *joinRepository { } } -func (qb *performerQueryBuilder) GetTagIDs(ctx context.Context, id int) ([]int, error) { +func (qb *PerformerStore) GetTagIDs(ctx context.Context, id int) ([]int, error) { return qb.tagsRepository().getIDs(ctx, id) } -func (qb *performerQueryBuilder) UpdateTags(ctx context.Context, id int, tagIDs []int) error { +func (qb *PerformerStore) UpdateTags(ctx context.Context, id int, tagIDs []int) error { // Delete the existing joins and then create new ones return qb.tagsRepository().replace(ctx, id, tagIDs) } -func (qb *performerQueryBuilder) imageRepository() *imageRepository { +func (qb *PerformerStore) imageRepository() *imageRepository { return &imageRepository{ repository: repository{ tx: qb.tx, @@ -601,19 +785,19 @@ func (qb *performerQueryBuilder) imageRepository() *imageRepository { } } -func (qb *performerQueryBuilder) GetImage(ctx context.Context, performerID int) ([]byte, error) { +func (qb *PerformerStore) GetImage(ctx context.Context, performerID int) ([]byte, error) { return qb.imageRepository().get(ctx, performerID) } -func (qb *performerQueryBuilder) UpdateImage(ctx context.Context, performerID int, image []byte) error { +func (qb *PerformerStore) UpdateImage(ctx context.Context, performerID int, image []byte) error { return qb.imageRepository().replace(ctx, performerID, image) } -func (qb *performerQueryBuilder) DestroyImage(ctx context.Context, performerID int) error { +func (qb *PerformerStore) DestroyImage(ctx context.Context, performerID int) error { return qb.imageRepository().destroy(ctx, []int{performerID}) } -func (qb *performerQueryBuilder) stashIDRepository() *stashIDRepository { +func (qb *PerformerStore) stashIDRepository() *stashIDRepository { return &stashIDRepository{ repository{ tx: qb.tx, @@ -623,40 +807,51 @@ func (qb *performerQueryBuilder) stashIDRepository() *stashIDRepository { } } -func (qb *performerQueryBuilder) GetStashIDs(ctx context.Context, performerID int) ([]models.StashID, error) { +func (qb *PerformerStore) GetStashIDs(ctx context.Context, performerID int) ([]models.StashID, error) { return qb.stashIDRepository().get(ctx, performerID) } -func (qb *performerQueryBuilder) UpdateStashIDs(ctx context.Context, performerID int, stashIDs []models.StashID) error { +func (qb *PerformerStore) UpdateStashIDs(ctx context.Context, performerID int, stashIDs []models.StashID) error { return qb.stashIDRepository().replace(ctx, performerID, stashIDs) } -func (qb *performerQueryBuilder) FindByStashID(ctx context.Context, stashID models.StashID) ([]*models.Performer, error) { - query := selectAll("performers") + ` - LEFT JOIN performer_stash_ids on performer_stash_ids.performer_id = performers.id - WHERE performer_stash_ids.stash_id = ? - AND performer_stash_ids.endpoint = ? - ` - args := []interface{}{stashID.StashID, stashID.Endpoint} - return qb.queryPerformers(ctx, query, args) -} +func (qb *PerformerStore) FindByStashID(ctx context.Context, stashID models.StashID) ([]*models.Performer, error) { + sq := dialect.From(performersStashIDsJoinTable).Select(performersStashIDsJoinTable.Col(performerIDColumn)).Where( + performersStashIDsJoinTable.Col("stash_id").Eq(stashID.StashID), + performersStashIDsJoinTable.Col("endpoint").Eq(stashID.Endpoint), + ) + ret, err := qb.findBySubquery(ctx, sq) -func (qb *performerQueryBuilder) FindByStashIDStatus(ctx context.Context, hasStashID bool, stashboxEndpoint string) ([]*models.Performer, error) { - query := selectAll("performers") + ` - LEFT JOIN performer_stash_ids on performer_stash_ids.performer_id = performers.id - ` - - if hasStashID { - query += ` - WHERE performer_stash_ids.stash_id IS NOT NULL - AND performer_stash_ids.endpoint = ? - ` - } else { - query += ` - WHERE performer_stash_ids.stash_id IS NULL - ` + if err != nil { + return nil, fmt.Errorf("getting performers for stash ID %s: %w", stashID.StashID, err) } - args := []interface{}{stashboxEndpoint} - return qb.queryPerformers(ctx, query, args) + return ret, nil +} + +func (qb *PerformerStore) FindByStashIDStatus(ctx context.Context, hasStashID bool, stashboxEndpoint string) ([]*models.Performer, error) { + table := qb.table() + sq := dialect.From(table).LeftJoin( + performersStashIDsJoinTable, + goqu.On(table.Col(idColumn).Eq(performersStashIDsJoinTable.Col(performerIDColumn))), + ).Select(table.Col(idColumn)) + + if hasStashID { + sq = sq.Where( + performersStashIDsJoinTable.Col("stash_id").IsNotNull(), + performersStashIDsJoinTable.Col("endpoint").Eq(stashboxEndpoint), + ) + } else { + sq = sq.Where( + performersStashIDsJoinTable.Col("stash_id").IsNull(), + ) + } + + ret, err := qb.findBySubquery(ctx, sq) + + if err != nil { + return nil, fmt.Errorf("getting performers for stash-box endpoint %s: %w", stashboxEndpoint, err) + } + + return ret, nil } diff --git a/pkg/sqlite/performer_test.go b/pkg/sqlite/performer_test.go index 2075407a5..40736f617 100644 --- a/pkg/sqlite/performer_test.go +++ b/pkg/sqlite/performer_test.go @@ -5,7 +5,6 @@ package sqlite_test import ( "context" - "database/sql" "fmt" "math" "strconv" @@ -13,16 +12,246 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/stashapp/stash/pkg/hash/md5" "github.com/stashapp/stash/pkg/models" - "github.com/stashapp/stash/pkg/sqlite" + "github.com/stretchr/testify/assert" ) +func Test_PerformerStore_Update(t *testing.T) { + var ( + name = "name" + gender = models.GenderEnumFemale + checksum = "checksum" + details = "details" + url = "url" + twitter = "twitter" + instagram = "instagram" + rating = 3 + ethnicity = "ethnicity" + country = "country" + eyeColor = "eyeColor" + height = "height" + measurements = "measurements" + fakeTits = "fakeTits" + careerLength = "careerLength" + tattoos = "tattoos" + piercings = "piercings" + aliases = "aliases" + hairColor = "hairColor" + weight = 123 + ignoreAutoTag = true + favorite = true + createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) + updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) + + birthdate = models.NewDate("2003-02-01") + deathdate = models.NewDate("2023-02-01") + ) + + tests := []struct { + name string + updatedObject *models.Performer + wantErr bool + }{ + { + "full", + &models.Performer{ + ID: performerIDs[performerIdxWithGallery], + Name: name, + Checksum: checksum, + Gender: gender, + URL: url, + Twitter: twitter, + Instagram: instagram, + Birthdate: &birthdate, + Ethnicity: ethnicity, + Country: country, + EyeColor: eyeColor, + Height: height, + Measurements: measurements, + FakeTits: fakeTits, + CareerLength: careerLength, + Tattoos: tattoos, + Piercings: piercings, + Aliases: aliases, + Favorite: favorite, + Rating: &rating, + Details: details, + DeathDate: &deathdate, + HairColor: hairColor, + Weight: &weight, + IgnoreAutoTag: ignoreAutoTag, + CreatedAt: createdAt, + UpdatedAt: updatedAt, + }, + false, + }, + { + "clear all", + &models.Performer{ + ID: performerIDs[performerIdxWithGallery], + }, + false, + }, + } + + qb := db.Performer + for _, tt := range tests { + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + assert := assert.New(t) + + copy := *tt.updatedObject + + if err := qb.Update(ctx, tt.updatedObject); (err != nil) != tt.wantErr { + t.Errorf("PerformerStore.Update() error = %v, wantErr %v", err, tt.wantErr) + } + + if tt.wantErr { + return + } + + s, err := qb.Find(ctx, tt.updatedObject.ID) + if err != nil { + t.Errorf("PerformerStore.Find() error = %v", err) + } + + assert.Equal(copy, *s) + }) + } +} + +func Test_PerformerStore_UpdatePartial(t *testing.T) { + var ( + name = "name" + gender = models.GenderEnumFemale + checksum = "checksum" + details = "details" + url = "url" + twitter = "twitter" + instagram = "instagram" + rating = 3 + ethnicity = "ethnicity" + country = "country" + eyeColor = "eyeColor" + height = "height" + measurements = "measurements" + fakeTits = "fakeTits" + careerLength = "careerLength" + tattoos = "tattoos" + piercings = "piercings" + aliases = "aliases" + hairColor = "hairColor" + weight = 123 + ignoreAutoTag = true + favorite = true + createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) + updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) + + birthdate = models.NewDate("2003-02-01") + deathdate = models.NewDate("2023-02-01") + ) + + tests := []struct { + name string + id int + partial models.PerformerPartial + want models.Performer + wantErr bool + }{ + { + "full", + performerIDs[performerIdxWithDupName], + models.PerformerPartial{ + Name: models.NewOptionalString(name), + Checksum: models.NewOptionalString(checksum), + Gender: models.NewOptionalString(gender.String()), + URL: models.NewOptionalString(url), + Twitter: models.NewOptionalString(twitter), + Instagram: models.NewOptionalString(instagram), + Birthdate: models.NewOptionalDate(birthdate), + Ethnicity: models.NewOptionalString(ethnicity), + Country: models.NewOptionalString(country), + EyeColor: models.NewOptionalString(eyeColor), + Height: models.NewOptionalString(height), + Measurements: models.NewOptionalString(measurements), + FakeTits: models.NewOptionalString(fakeTits), + CareerLength: models.NewOptionalString(careerLength), + Tattoos: models.NewOptionalString(tattoos), + Piercings: models.NewOptionalString(piercings), + Aliases: models.NewOptionalString(aliases), + Favorite: models.NewOptionalBool(favorite), + Rating: models.NewOptionalInt(rating), + Details: models.NewOptionalString(details), + DeathDate: models.NewOptionalDate(deathdate), + HairColor: models.NewOptionalString(hairColor), + Weight: models.NewOptionalInt(weight), + IgnoreAutoTag: models.NewOptionalBool(ignoreAutoTag), + CreatedAt: models.NewOptionalTime(createdAt), + UpdatedAt: models.NewOptionalTime(updatedAt), + }, + models.Performer{ + ID: performerIDs[performerIdxWithDupName], + Name: name, + Checksum: checksum, + Gender: gender, + URL: url, + Twitter: twitter, + Instagram: instagram, + Birthdate: &birthdate, + Ethnicity: ethnicity, + Country: country, + EyeColor: eyeColor, + Height: height, + Measurements: measurements, + FakeTits: fakeTits, + CareerLength: careerLength, + Tattoos: tattoos, + Piercings: piercings, + Aliases: aliases, + Favorite: favorite, + Rating: &rating, + Details: details, + DeathDate: &deathdate, + HairColor: hairColor, + Weight: &weight, + IgnoreAutoTag: ignoreAutoTag, + CreatedAt: createdAt, + UpdatedAt: updatedAt, + }, + false, + }, + } + for _, tt := range tests { + qb := db.Performer + + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + assert := assert.New(t) + + got, err := qb.UpdatePartial(ctx, tt.id, tt.partial) + if (err != nil) != tt.wantErr { + t.Errorf("PerformerStore.UpdatePartial() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantErr { + return + } + + assert.Equal(tt.want, *got) + + s, err := qb.Find(ctx, tt.id) + if err != nil { + t.Errorf("PerformerStore.Find() error = %v", err) + } + + assert.Equal(tt.want, *s) + }) + } +} + func TestPerformerFindBySceneID(t *testing.T) { withTxn(func(ctx context.Context) error { - pqb := sqlite.PerformerReaderWriter + pqb := db.Performer sceneID := sceneIDs[sceneIdxWithPerformer] performers, err := pqb.FindBySceneID(ctx, sceneID) @@ -31,10 +260,13 @@ func TestPerformerFindBySceneID(t *testing.T) { t.Errorf("Error finding performer: %s", err.Error()) } - assert.Equal(t, 1, len(performers)) + if !assert.Equal(t, 1, len(performers)) { + return nil + } + performer := performers[0] - assert.Equal(t, getPerformerStringValue(performerIdxWithScene, "Name"), performer.Name.String) + assert.Equal(t, getPerformerStringValue(performerIdxWithScene, "Name"), performer.Name) performers, err = pqb.FindBySceneID(ctx, 0) @@ -48,11 +280,73 @@ func TestPerformerFindBySceneID(t *testing.T) { }) } +func TestPerformerFindByImageID(t *testing.T) { + withTxn(func(ctx context.Context) error { + pqb := db.Performer + imageID := imageIDs[imageIdxWithPerformer] + + performers, err := pqb.FindByImageID(ctx, imageID) + + if err != nil { + t.Errorf("Error finding performer: %s", err.Error()) + } + + if !assert.Equal(t, 1, len(performers)) { + return nil + } + + performer := performers[0] + + assert.Equal(t, getPerformerStringValue(performerIdxWithImage, "Name"), performer.Name) + + performers, err = pqb.FindByImageID(ctx, 0) + + if err != nil { + t.Errorf("Error finding performer: %s", err.Error()) + } + + assert.Equal(t, 0, len(performers)) + + return nil + }) +} + +func TestPerformerFindByGalleryID(t *testing.T) { + withTxn(func(ctx context.Context) error { + pqb := db.Performer + galleryID := galleryIDs[galleryIdxWithPerformer] + + performers, err := pqb.FindByGalleryID(ctx, galleryID) + + if err != nil { + t.Errorf("Error finding performer: %s", err.Error()) + } + + if !assert.Equal(t, 1, len(performers)) { + return nil + } + + performer := performers[0] + + assert.Equal(t, getPerformerStringValue(performerIdxWithGallery, "Name"), performer.Name) + + performers, err = pqb.FindByGalleryID(ctx, 0) + + if err != nil { + t.Errorf("Error finding performer: %s", err.Error()) + } + + assert.Equal(t, 0, len(performers)) + + return nil + }) +} + func TestPerformerFindByNames(t *testing.T) { getNames := func(p []*models.Performer) []string { var ret []string for _, pp := range p { - ret = append(ret, pp.Name.String) + ret = append(ret, pp.Name) } return ret } @@ -60,7 +354,7 @@ func TestPerformerFindByNames(t *testing.T) { withTxn(func(ctx context.Context) error { var names []string - pqb := sqlite.PerformerReaderWriter + pqb := db.Performer names = append(names, performerNames[performerIdxWithScene]) // find performers by names @@ -69,15 +363,15 @@ func TestPerformerFindByNames(t *testing.T) { t.Errorf("Error finding performers: %s", err.Error()) } assert.Len(t, performers, 1) - assert.Equal(t, performerNames[performerIdxWithScene], performers[0].Name.String) + assert.Equal(t, performerNames[performerIdxWithScene], performers[0].Name) performers, err = pqb.FindByNames(ctx, names, true) // find performers by names nocase if err != nil { t.Errorf("Error finding performers: %s", err.Error()) } assert.Len(t, performers, 2) // performerIdxWithScene and performerIdxWithDupName - assert.Equal(t, strings.ToLower(performerNames[performerIdxWithScene]), strings.ToLower(performers[0].Name.String)) - assert.Equal(t, strings.ToLower(performerNames[performerIdxWithScene]), strings.ToLower(performers[1].Name.String)) + assert.Equal(t, strings.ToLower(performerNames[performerIdxWithScene]), strings.ToLower(performers[0].Name)) + assert.Equal(t, strings.ToLower(performerNames[performerIdxWithScene]), strings.ToLower(performers[1].Name)) names = append(names, performerNames[performerIdx1WithScene]) // find performers by names ( 2 names ) @@ -125,13 +419,11 @@ func TestPerformerQueryEthnicityOr(t *testing.T) { } withTxn(func(ctx context.Context) error { - sqb := sqlite.PerformerReaderWriter - - performers := queryPerformers(ctx, t, sqb, &performerFilter, nil) + performers := queryPerformers(ctx, t, &performerFilter, nil) assert.Len(t, performers, 2) - assert.Equal(t, performer1Eth, performers[0].Ethnicity.String) - assert.Equal(t, performer2Eth, performers[1].Ethnicity.String) + assert.Equal(t, performer1Eth, performers[0].Ethnicity) + assert.Equal(t, performer2Eth, performers[1].Ethnicity) return nil }) @@ -140,7 +432,7 @@ func TestPerformerQueryEthnicityOr(t *testing.T) { func TestPerformerQueryEthnicityAndRating(t *testing.T) { const performerIdx = 1 performerEth := getPerformerStringValue(performerIdx, "Ethnicity") - performerRating := getRating(performerIdx) + performerRating := int(getRating(performerIdx).Int64) performerFilter := models.PerformerFilterType{ Ethnicity: &models.StringCriterionInput{ @@ -149,20 +441,20 @@ func TestPerformerQueryEthnicityAndRating(t *testing.T) { }, And: &models.PerformerFilterType{ Rating: &models.IntCriterionInput{ - Value: int(performerRating.Int64), + Value: performerRating, Modifier: models.CriterionModifierEquals, }, }, } withTxn(func(ctx context.Context) error { - sqb := sqlite.PerformerReaderWriter - - performers := queryPerformers(ctx, t, sqb, &performerFilter, nil) + performers := queryPerformers(ctx, t, &performerFilter, nil) assert.Len(t, performers, 1) - assert.Equal(t, performerEth, performers[0].Ethnicity.String) - assert.Equal(t, performerRating.Int64, performers[0].Rating.Int64) + assert.Equal(t, performerEth, performers[0].Ethnicity) + if assert.NotNil(t, performers[0].Rating) { + assert.Equal(t, performerRating, *performers[0].Rating) + } return nil }) @@ -191,14 +483,12 @@ func TestPerformerQueryEthnicityNotRating(t *testing.T) { } withTxn(func(ctx context.Context) error { - sqb := sqlite.PerformerReaderWriter - - performers := queryPerformers(ctx, t, sqb, &performerFilter, nil) + performers := queryPerformers(ctx, t, &performerFilter, nil) for _, performer := range performers { - verifyString(t, performer.Ethnicity.String, ethCriterion) + verifyString(t, performer.Ethnicity, ethCriterion) ratingCriterion.Modifier = models.CriterionModifierNotEquals - verifyInt64(t, performer.Rating, ratingCriterion) + verifyIntPtr(t, performer.Rating, ratingCriterion) } return nil @@ -222,7 +512,7 @@ func TestPerformerIllegalQuery(t *testing.T) { } withTxn(func(ctx context.Context) error { - sqb := sqlite.PerformerReaderWriter + sqb := db.Performer _, _, err := sqb.Query(ctx, performerFilter, nil) assert.NotNil(err) @@ -248,9 +538,7 @@ func TestPerformerQueryIgnoreAutoTag(t *testing.T) { IgnoreAutoTag: &ignoreAutoTag, } - sqb := sqlite.PerformerReaderWriter - - performers := queryPerformers(ctx, t, sqb, &performerFilter, nil) + performers := queryPerformers(ctx, t, &performerFilter, nil) assert.Len(t, performers, int(math.Ceil(float64(totalPerformers)/5))) for _, p := range performers { @@ -263,7 +551,7 @@ func TestPerformerQueryIgnoreAutoTag(t *testing.T) { func TestPerformerQueryForAutoTag(t *testing.T) { withTxn(func(ctx context.Context) error { - tqb := sqlite.PerformerReaderWriter + tqb := db.Performer name := performerNames[performerIdx1WithScene] // find a performer by name @@ -274,44 +562,43 @@ func TestPerformerQueryForAutoTag(t *testing.T) { } assert.Len(t, performers, 2) - assert.Equal(t, strings.ToLower(performerNames[performerIdx1WithScene]), strings.ToLower(performers[0].Name.String)) - assert.Equal(t, strings.ToLower(performerNames[performerIdx1WithScene]), strings.ToLower(performers[1].Name.String)) + assert.Equal(t, strings.ToLower(performerNames[performerIdx1WithScene]), strings.ToLower(performers[0].Name)) + assert.Equal(t, strings.ToLower(performerNames[performerIdx1WithScene]), strings.ToLower(performers[1].Name)) return nil }) } func TestPerformerUpdatePerformerImage(t *testing.T) { - if err := withTxn(func(ctx context.Context) error { - qb := sqlite.PerformerReaderWriter + if err := withRollbackTxn(func(ctx context.Context) error { + qb := db.Performer // create performer to test against const name = "TestPerformerUpdatePerformerImage" performer := models.Performer{ - Name: sql.NullString{String: name, Valid: true}, + Name: name, Checksum: md5.FromString(name), - Favorite: sql.NullBool{Bool: false, Valid: true}, } - created, err := qb.Create(ctx, performer) + err := qb.Create(ctx, &performer) if err != nil { return fmt.Errorf("Error creating performer: %s", err.Error()) } image := []byte("image") - err = qb.UpdateImage(ctx, created.ID, image) + err = qb.UpdateImage(ctx, performer.ID, image) if err != nil { return fmt.Errorf("Error updating performer image: %s", err.Error()) } // ensure image set - storedImage, err := qb.GetImage(ctx, created.ID) + storedImage, err := qb.GetImage(ctx, performer.ID) if err != nil { return fmt.Errorf("Error getting image: %s", err.Error()) } assert.Equal(t, storedImage, image) // set nil image - err = qb.UpdateImage(ctx, created.ID, nil) + err = qb.UpdateImage(ctx, performer.ID, nil) if err == nil { return fmt.Errorf("Expected error setting nil image") } @@ -323,34 +610,33 @@ func TestPerformerUpdatePerformerImage(t *testing.T) { } func TestPerformerDestroyPerformerImage(t *testing.T) { - if err := withTxn(func(ctx context.Context) error { - qb := sqlite.PerformerReaderWriter + if err := withRollbackTxn(func(ctx context.Context) error { + qb := db.Performer // create performer to test against const name = "TestPerformerDestroyPerformerImage" performer := models.Performer{ - Name: sql.NullString{String: name, Valid: true}, + Name: name, Checksum: md5.FromString(name), - Favorite: sql.NullBool{Bool: false, Valid: true}, } - created, err := qb.Create(ctx, performer) + err := qb.Create(ctx, &performer) if err != nil { return fmt.Errorf("Error creating performer: %s", err.Error()) } image := []byte("image") - err = qb.UpdateImage(ctx, created.ID, image) + err = qb.UpdateImage(ctx, performer.ID, image) if err != nil { return fmt.Errorf("Error updating performer image: %s", err.Error()) } - err = qb.DestroyImage(ctx, created.ID) + err = qb.DestroyImage(ctx, performer.ID) if err != nil { return fmt.Errorf("Error destroying performer image: %s", err.Error()) } // image should be nil - storedImage, err := qb.GetImage(ctx, created.ID) + storedImage, err := qb.GetImage(ctx, performer.ID) if err != nil { return fmt.Errorf("Error getting image: %s", err.Error()) } @@ -383,7 +669,7 @@ func TestPerformerQueryAge(t *testing.T) { func verifyPerformerAge(t *testing.T, ageCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - qb := sqlite.PerformerReaderWriter + qb := db.Performer performerFilter := models.PerformerFilterType{ Age: &ageCriterion, } @@ -397,12 +683,11 @@ func verifyPerformerAge(t *testing.T, ageCriterion models.IntCriterionInput) { for _, performer := range performers { cd := now - if performer.DeathDate.Valid { - cd, _ = time.Parse("2006-01-02", performer.DeathDate.String) + if performer.DeathDate != nil { + cd = performer.DeathDate.Time } - bd := performer.Birthdate.String - d, _ := time.Parse("2006-01-02", bd) + d := performer.Birthdate.Time age := cd.Year() - d.Year() if cd.YearDay() < d.YearDay() { age = age - 1 @@ -436,7 +721,7 @@ func TestPerformerQueryCareerLength(t *testing.T) { func verifyPerformerCareerLength(t *testing.T, criterion models.StringCriterionInput) { withTxn(func(ctx context.Context) error { - qb := sqlite.PerformerReaderWriter + qb := db.Performer performerFilter := models.PerformerFilterType{ CareerLength: &criterion, } @@ -448,7 +733,7 @@ func verifyPerformerCareerLength(t *testing.T, criterion models.StringCriterionI for _, performer := range performers { cl := performer.CareerLength - verifyNullString(t, cl, criterion) + verifyString(t, cl, criterion) } return nil @@ -470,7 +755,7 @@ func TestPerformerQueryURL(t *testing.T) { verifyFn := func(g *models.Performer) { t.Helper() - verifyNullString(t, g.URL, urlCriterion) + verifyString(t, g.URL, urlCriterion) } verifyPerformerQuery(t, filter, verifyFn) @@ -496,9 +781,7 @@ func TestPerformerQueryURL(t *testing.T) { func verifyPerformerQuery(t *testing.T, filter models.PerformerFilterType, verifyFn func(s *models.Performer)) { withTxn(func(ctx context.Context) error { t.Helper() - sqb := sqlite.PerformerReaderWriter - - performers := queryPerformers(ctx, t, sqb, &filter, nil) + performers := queryPerformers(ctx, t, &filter, nil) // assume it should find at least one assert.Greater(t, len(performers), 0) @@ -511,8 +794,8 @@ func verifyPerformerQuery(t *testing.T, filter models.PerformerFilterType, verif }) } -func queryPerformers(ctx context.Context, t *testing.T, qb models.PerformerReader, performerFilter *models.PerformerFilterType, findFilter *models.FindFilterType) []*models.Performer { - performers, _, err := qb.Query(ctx, performerFilter, findFilter) +func queryPerformers(ctx context.Context, t *testing.T, performerFilter *models.PerformerFilterType, findFilter *models.FindFilterType) []*models.Performer { + performers, _, err := db.Performer.Query(ctx, performerFilter, findFilter) if err != nil { t.Errorf("Error querying performers: %s", err.Error()) } @@ -522,7 +805,6 @@ func queryPerformers(ctx context.Context, t *testing.T, qb models.PerformerReade func TestPerformerQueryTags(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := sqlite.PerformerReaderWriter tagCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdxWithPerformer]), @@ -536,7 +818,7 @@ func TestPerformerQueryTags(t *testing.T) { } // ensure ids are correct - performers := queryPerformers(ctx, t, sqb, &performerFilter, nil) + performers := queryPerformers(ctx, t, &performerFilter, nil) assert.Len(t, performers, 2) for _, performer := range performers { assert.True(t, performer.ID == performerIDs[performerIdxWithTag] || performer.ID == performerIDs[performerIdxWithTwoTags]) @@ -550,7 +832,7 @@ func TestPerformerQueryTags(t *testing.T) { Modifier: models.CriterionModifierIncludesAll, } - performers = queryPerformers(ctx, t, sqb, &performerFilter, nil) + performers = queryPerformers(ctx, t, &performerFilter, nil) assert.Len(t, performers, 1) assert.Equal(t, sceneIDs[performerIdxWithTwoTags], performers[0].ID) @@ -567,7 +849,7 @@ func TestPerformerQueryTags(t *testing.T) { Q: &q, } - performers = queryPerformers(ctx, t, sqb, &performerFilter, &findFilter) + performers = queryPerformers(ctx, t, &performerFilter, &findFilter) assert.Len(t, performers, 0) return nil @@ -595,12 +877,12 @@ func TestPerformerQueryTagCount(t *testing.T) { func verifyPerformersTagCount(t *testing.T, tagCountCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - sqb := sqlite.PerformerReaderWriter + sqb := db.Performer performerFilter := models.PerformerFilterType{ TagCount: &tagCountCriterion, } - performers := queryPerformers(ctx, t, sqb, &performerFilter, nil) + performers := queryPerformers(ctx, t, &performerFilter, nil) assert.Greater(t, len(performers), 0) for _, performer := range performers { @@ -636,12 +918,11 @@ func TestPerformerQuerySceneCount(t *testing.T) { func verifyPerformersSceneCount(t *testing.T, sceneCountCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - sqb := sqlite.PerformerReaderWriter performerFilter := models.PerformerFilterType{ SceneCount: &sceneCountCriterion, } - performers := queryPerformers(ctx, t, sqb, &performerFilter, nil) + performers := queryPerformers(ctx, t, &performerFilter, nil) assert.Greater(t, len(performers), 0) for _, performer := range performers { @@ -677,12 +958,11 @@ func TestPerformerQueryImageCount(t *testing.T) { func verifyPerformersImageCount(t *testing.T, imageCountCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - sqb := sqlite.PerformerReaderWriter performerFilter := models.PerformerFilterType{ ImageCount: &imageCountCriterion, } - performers := queryPerformers(ctx, t, sqb, &performerFilter, nil) + performers := queryPerformers(ctx, t, &performerFilter, nil) assert.Greater(t, len(performers), 0) for _, performer := range performers { @@ -733,12 +1013,11 @@ func TestPerformerQueryGalleryCount(t *testing.T) { func verifyPerformersGalleryCount(t *testing.T, galleryCountCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - sqb := sqlite.PerformerReaderWriter performerFilter := models.PerformerFilterType{ GalleryCount: &galleryCountCriterion, } - performers := queryPerformers(ctx, t, sqb, &performerFilter, nil) + performers := queryPerformers(ctx, t, &performerFilter, nil) assert.Greater(t, len(performers), 0) for _, performer := range performers { @@ -773,8 +1052,6 @@ func TestPerformerQueryStudio(t *testing.T) { {studioIndex: studioIdxWithGalleryPerformer, performerIndex: performerIdxWithGalleryStudio}, } - sqb := sqlite.PerformerReaderWriter - for _, tc := range testCases { studioCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ @@ -787,7 +1064,7 @@ func TestPerformerQueryStudio(t *testing.T) { Studios: &studioCriterion, } - performers := queryPerformers(ctx, t, sqb, &performerFilter, nil) + performers := queryPerformers(ctx, t, &performerFilter, nil) assert.Len(t, performers, 1) @@ -806,7 +1083,7 @@ func TestPerformerQueryStudio(t *testing.T) { Q: &q, } - performers = queryPerformers(ctx, t, sqb, &performerFilter, &findFilter) + performers = queryPerformers(ctx, t, &performerFilter, &findFilter) assert.Len(t, performers, 0) } @@ -821,21 +1098,21 @@ func TestPerformerQueryStudio(t *testing.T) { Q: &q, } - performers := queryPerformers(ctx, t, sqb, performerFilter, findFilter) + performers := queryPerformers(ctx, t, performerFilter, findFilter) assert.Len(t, performers, 1) assert.Equal(t, imageIDs[performerIdx1WithImage], performers[0].ID) q = getPerformerStringValue(performerIdxWithSceneStudio, "Name") - performers = queryPerformers(ctx, t, sqb, performerFilter, findFilter) + performers = queryPerformers(ctx, t, performerFilter, findFilter) assert.Len(t, performers, 0) performerFilter.Studios.Modifier = models.CriterionModifierNotNull - performers = queryPerformers(ctx, t, sqb, performerFilter, findFilter) + performers = queryPerformers(ctx, t, performerFilter, findFilter) assert.Len(t, performers, 1) assert.Equal(t, imageIDs[performerIdxWithSceneStudio], performers[0].ID) q = getPerformerStringValue(performerIdx1WithImage, "Name") - performers = queryPerformers(ctx, t, sqb, performerFilter, findFilter) + performers = queryPerformers(ctx, t, performerFilter, findFilter) assert.Len(t, performers, 0) return nil @@ -843,22 +1120,21 @@ func TestPerformerQueryStudio(t *testing.T) { } func TestPerformerStashIDs(t *testing.T) { - if err := withTxn(func(ctx context.Context) error { - qb := sqlite.PerformerReaderWriter + if err := withRollbackTxn(func(ctx context.Context) error { + qb := db.Performer // create performer to test against const name = "TestStashIDs" performer := models.Performer{ - Name: sql.NullString{String: name, Valid: true}, + Name: name, Checksum: md5.FromString(name), - Favorite: sql.NullBool{Bool: false, Valid: true}, } - created, err := qb.Create(ctx, performer) + err := qb.Create(ctx, &performer) if err != nil { return fmt.Errorf("Error creating performer: %s", err.Error()) } - testStashIDReaderWriter(ctx, t, qb, created.ID) + testStashIDReaderWriter(ctx, t, qb, performer.ID) return nil }); err != nil { t.Error(err.Error()) @@ -891,15 +1167,14 @@ func TestPerformerQueryRating(t *testing.T) { func verifyPerformersRating(t *testing.T, ratingCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - sqb := sqlite.PerformerReaderWriter performerFilter := models.PerformerFilterType{ Rating: &ratingCriterion, } - performers := queryPerformers(ctx, t, sqb, &performerFilter, nil) + performers := queryPerformers(ctx, t, &performerFilter, nil) for _, performer := range performers { - verifyInt64(t, performer.Rating, ratingCriterion) + verifyIntPtr(t, performer.Rating, ratingCriterion) } return nil @@ -908,18 +1183,17 @@ func verifyPerformersRating(t *testing.T, ratingCriterion models.IntCriterionInp func TestPerformerQueryIsMissingRating(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := sqlite.PerformerReaderWriter isMissing := "rating" performerFilter := models.PerformerFilterType{ IsMissing: &isMissing, } - performers := queryPerformers(ctx, t, sqb, &performerFilter, nil) + performers := queryPerformers(ctx, t, &performerFilter, nil) assert.True(t, len(performers) > 0) for _, performer := range performers { - assert.True(t, !performer.Rating.Valid) + assert.Nil(t, performer.Rating) } return nil @@ -934,7 +1208,7 @@ func TestPerformerQueryIsMissingImage(t *testing.T) { } // ensure query does not error - performers, _, err := sqlite.PerformerReaderWriter.Query(ctx, performerFilter, nil) + performers, _, err := db.Performer.Query(ctx, performerFilter, nil) if err != nil { t.Errorf("Error querying performers: %s", err.Error()) } @@ -942,7 +1216,7 @@ func TestPerformerQueryIsMissingImage(t *testing.T) { assert.True(t, len(performers) > 0) for _, performer := range performers { - img, err := sqlite.PerformerReaderWriter.GetImage(ctx, performer.ID) + img, err := db.Performer.GetImage(ctx, performer.ID) if err != nil { t.Errorf("error getting performer image: %s", err.Error()) } @@ -963,7 +1237,7 @@ func TestPerformerQuerySortScenesCount(t *testing.T) { withTxn(func(ctx context.Context) error { // just ensure it queries without error - performers, _, err := sqlite.PerformerReaderWriter.Query(ctx, nil, findFilter) + performers, _, err := db.Performer.Query(ctx, nil, findFilter) if err != nil { t.Errorf("Error querying performers: %s", err.Error()) } @@ -978,7 +1252,7 @@ func TestPerformerQuerySortScenesCount(t *testing.T) { // sort in ascending order direction = models.SortDirectionEnumAsc - performers, _, err = sqlite.PerformerReaderWriter.Query(ctx, nil, findFilter) + performers, _, err = db.Performer.Query(ctx, nil, findFilter) if err != nil { t.Errorf("Error querying performers: %s", err.Error()) } @@ -992,10 +1266,173 @@ func TestPerformerQuerySortScenesCount(t *testing.T) { }) } +func TestPerformerCountByTagID(t *testing.T) { + withTxn(func(ctx context.Context) error { + sqb := db.Performer + count, err := sqb.CountByTagID(ctx, tagIDs[tagIdxWithPerformer]) + + if err != nil { + t.Errorf("Error counting performers: %s", err.Error()) + } + + assert.Equal(t, 1, count) + + count, err = sqb.CountByTagID(ctx, 0) + + if err != nil { + t.Errorf("Error counting performers: %s", err.Error()) + } + + assert.Equal(t, 0, count) + + return nil + }) +} + +func TestPerformerCount(t *testing.T) { + withTxn(func(ctx context.Context) error { + sqb := db.Performer + count, err := sqb.Count(ctx) + + if err != nil { + t.Errorf("Error counting performers: %s", err.Error()) + } + + assert.Equal(t, totalPerformers, count) + + return nil + }) +} + +func TestPerformerAll(t *testing.T) { + withTxn(func(ctx context.Context) error { + sqb := db.Performer + all, err := sqb.All(ctx) + + if err != nil { + t.Errorf("Error counting performers: %s", err.Error()) + } + + assert.Len(t, all, totalPerformers) + + return nil + }) +} + +func performersToIDs(i []*models.Performer) []int { + ret := make([]int, len(i)) + for i, v := range i { + ret[i] = v.ID + } + + return ret +} + +func TestPerformerStore_FindByStashID(t *testing.T) { + type args struct { + stashID models.StashID + } + tests := []struct { + name string + stashID models.StashID + expectedIDs []int + wantErr bool + }{ + { + name: "existing", + stashID: performerStashID(performerIdxWithScene), + expectedIDs: []int{performerIDs[performerIdxWithScene]}, + wantErr: false, + }, + { + name: "non-existing", + stashID: models.StashID{ + StashID: getPerformerStringValue(performerIdxWithScene, "stashid"), + Endpoint: "non-existing", + }, + expectedIDs: []int{}, + wantErr: false, + }, + } + + qb := db.Performer + + for _, tt := range tests { + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + got, err := qb.FindByStashID(ctx, tt.stashID) + if (err != nil) != tt.wantErr { + t.Errorf("PerformerStore.FindByStashID() error = %v, wantErr %v", err, tt.wantErr) + return + } + + assert.ElementsMatch(t, performersToIDs(got), tt.expectedIDs) + }) + } +} + +func TestPerformerStore_FindByStashIDStatus(t *testing.T) { + type args struct { + stashID models.StashID + } + tests := []struct { + name string + hasStashID bool + stashboxEndpoint string + include []int + exclude []int + wantErr bool + }{ + { + name: "existing", + hasStashID: true, + stashboxEndpoint: getPerformerStringValue(performerIdxWithScene, "endpoint"), + include: []int{performerIdxWithScene}, + wantErr: false, + }, + { + name: "non-existing", + hasStashID: true, + stashboxEndpoint: getPerformerStringValue(performerIdxWithScene, "non-existing"), + exclude: []int{performerIdxWithScene}, + wantErr: false, + }, + { + name: "!hasStashID", + hasStashID: false, + stashboxEndpoint: getPerformerStringValue(performerIdxWithScene, "endpoint"), + include: []int{performerIdxWithImage}, + exclude: []int{performerIdx2WithScene}, + wantErr: false, + }, + } + + qb := db.Performer + + for _, tt := range tests { + runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { + got, err := qb.FindByStashIDStatus(ctx, tt.hasStashID, tt.stashboxEndpoint) + if (err != nil) != tt.wantErr { + t.Errorf("PerformerStore.FindByStashIDStatus() error = %v, wantErr %v", err, tt.wantErr) + return + } + + include := indexesToIDs(performerIDs, tt.include) + exclude := indexesToIDs(performerIDs, tt.exclude) + + ids := performersToIDs(got) + + assert := assert.New(t) + for _, i := range include { + assert.Contains(ids, i) + } + for _, e := range exclude { + assert.NotContains(ids, e) + } + }) + } +} + // TODO Update // TODO Destroy // TODO Find -// TODO Count -// TODO All -// TODO AllSlim // TODO Query diff --git a/pkg/sqlite/scene.go b/pkg/sqlite/scene.go index 8a1807c12..dfc9412fd 100644 --- a/pkg/sqlite/scene.go +++ b/pkg/sqlite/scene.go @@ -317,12 +317,6 @@ func (qb *SceneStore) Update(ctx context.Context, updatedObject *models.Scene) e } func (qb *SceneStore) Destroy(ctx context.Context, id int) error { - // delete all related table rows - // TODO - this should be handled by a delete cascade - if err := qb.performersRepository().destroy(ctx, []int{id}); err != nil { - return err - } - // scene markers should be handled prior to calling destroy // galleries should be handled prior to calling destroy diff --git a/pkg/sqlite/setup_test.go b/pkg/sqlite/setup_test.go index d93ac82f8..6ac2d6114 100644 --- a/pkg/sqlite/setup_test.go +++ b/pkg/sqlite/setup_test.go @@ -552,7 +552,7 @@ func populateDB() error { return fmt.Errorf("error creating movies: %s", err.Error()) } - if err := createPerformers(ctx, sqlite.PerformerReaderWriter, performersNameCase, performersNameNoCase); err != nil { + if err := createPerformers(ctx, performersNameCase, performersNameNoCase); err != nil { return fmt.Errorf("error creating performers: %s", err.Error()) } @@ -584,7 +584,7 @@ func populateDB() error { return fmt.Errorf("error creating saved filters: %s", err.Error()) } - if err := linkPerformerTags(ctx, sqlite.PerformerReaderWriter); err != nil { + if err := linkPerformerTags(ctx); err != nil { return fmt.Errorf("error linking performer tags: %s", err.Error()) } @@ -1226,8 +1226,10 @@ func getPerformerStringValue(index int, field string) string { return getPrefixedStringValue("performer", index, field) } -func getPerformerNullStringValue(index int, field string) sql.NullString { - return getPrefixedNullStringValue("performer", index, field) +func getPerformerNullStringValue(index int, field string) string { + ret := getPrefixedNullStringValue("performer", index, field) + + return ret.String } func getPerformerBoolValue(index int) bool { @@ -1235,24 +1237,29 @@ func getPerformerBoolValue(index int) bool { return index == 1 } -func getPerformerBirthdate(index int) string { +func getPerformerBirthdate(index int) *models.Date { const minAge = 18 birthdate := time.Now() birthdate = birthdate.AddDate(-minAge-index, -1, -1) - return birthdate.Format("2006-01-02") + + ret := models.Date{ + Time: birthdate, + } + return &ret } -func getPerformerDeathDate(index int) models.SQLiteDate { +func getPerformerDeathDate(index int) *models.Date { if index != 5 { - return models.SQLiteDate{} + return nil } deathDate := time.Now() deathDate = deathDate.AddDate(-index+1, -1, -1) - return models.SQLiteDate{ - String: deathDate.Format("2006-01-02"), - Valid: true, + + ret := models.Date{ + Time: deathDate, } + return &ret } func getPerformerCareerLength(index int) *string { @@ -1268,8 +1275,17 @@ func getIgnoreAutoTag(index int) bool { return index%5 == 0 } +func performerStashID(i int) models.StashID { + return models.StashID{ + StashID: getPerformerStringValue(i, "stashid"), + Endpoint: getPerformerStringValue(i, "endpoint"), + } +} + // createPerformers creates n performers with plain Name and o performers with camel cased NaMe included -func createPerformers(ctx context.Context, pqb models.PerformerReaderWriter, n int, o int) error { +func createPerformers(ctx context.Context, n int, o int) error { + pqb := db.Performer + const namePlain = "Name" const nameNoCase = "NaMe" @@ -1285,34 +1301,39 @@ func createPerformers(ctx context.Context, pqb models.PerformerReaderWriter, n i // performers [ i ] and [ n + o - i - 1 ] should have similar names with only the Name!=NaMe part different performer := models.Performer{ - Name: sql.NullString{String: getPerformerStringValue(index, name), Valid: true}, - Checksum: getPerformerStringValue(i, checksumField), - URL: getPerformerNullStringValue(i, urlField), - Favorite: sql.NullBool{Bool: getPerformerBoolValue(i), Valid: true}, - Birthdate: models.SQLiteDate{ - String: getPerformerBirthdate(i), - Valid: true, - }, + Name: getPerformerStringValue(index, name), + Checksum: getPerformerStringValue(i, checksumField), + URL: getPerformerNullStringValue(i, urlField), + Favorite: getPerformerBoolValue(i), + Birthdate: getPerformerBirthdate(i), DeathDate: getPerformerDeathDate(i), - Details: sql.NullString{String: getPerformerStringValue(i, "Details"), Valid: true}, - Ethnicity: sql.NullString{String: getPerformerStringValue(i, "Ethnicity"), Valid: true}, - Rating: getRating(i), + Details: getPerformerStringValue(i, "Details"), + Ethnicity: getPerformerStringValue(i, "Ethnicity"), + Rating: getIntPtr(getRating(i)), IgnoreAutoTag: getIgnoreAutoTag(i), } careerLength := getPerformerCareerLength(i) if careerLength != nil { - performer.CareerLength = models.NullString(*careerLength) + performer.CareerLength = *careerLength } - created, err := pqb.Create(ctx, performer) + err := pqb.Create(ctx, &performer) if err != nil { return fmt.Errorf("Error creating performer %v+: %s", performer, err.Error()) } - performerIDs = append(performerIDs, created.ID) - performerNames = append(performerNames, created.Name.String) + if (index+1)%5 != 0 { + if err := pqb.UpdateStashIDs(ctx, performer.ID, []models.StashID{ + performerStashID(i), + }); err != nil { + return fmt.Errorf("setting performer stash ids: %w", err) + } + } + + performerIDs = append(performerIDs, performer.ID) + performerNames = append(performerNames, performer.Name) } return nil @@ -1581,7 +1602,8 @@ func doLinks(links [][2]int, fn func(idx1, idx2 int) error) error { return nil } -func linkPerformerTags(ctx context.Context, qb models.PerformerReaderWriter) error { +func linkPerformerTags(ctx context.Context) error { + qb := db.Performer return doLinks(performerTagLinks, func(performerIndex, tagIndex int) error { performerID := performerIDs[performerIndex] tagID := tagIDs[tagIndex] diff --git a/pkg/sqlite/tables.go b/pkg/sqlite/tables.go index 6acc985e7..99301d046 100644 --- a/pkg/sqlite/tables.go +++ b/pkg/sqlite/tables.go @@ -24,6 +24,9 @@ var ( scenesPerformersJoinTable = goqu.T(performersScenesTable) scenesStashIDsJoinTable = goqu.T("scene_stash_ids") scenesMoviesJoinTable = goqu.T(moviesScenesTable) + + performersTagsJoinTable = goqu.T(performersTagsTable) + performersStashIDsJoinTable = goqu.T("performer_stash_ids") ) var ( diff --git a/pkg/sqlite/tag_test.go b/pkg/sqlite/tag_test.go index 9309a826c..a351f28f4 100644 --- a/pkg/sqlite/tag_test.go +++ b/pkg/sqlite/tag_test.go @@ -1011,7 +1011,7 @@ func TestTagMerge(t *testing.T) { assert.Contains(g.TagIDs.List(), destID) // ensure performer points to new tag - performerTagIDs, err := sqlite.PerformerReaderWriter.GetTagIDs(ctx, performerIDs[performerIdxWithTwoTags]) + performerTagIDs, err := db.Performer.GetTagIDs(ctx, performerIDs[performerIdxWithTwoTags]) if err != nil { return err } diff --git a/pkg/sqlite/transaction.go b/pkg/sqlite/transaction.go index 42b65ad7b..c1f81d569 100644 --- a/pkg/sqlite/transaction.go +++ b/pkg/sqlite/transaction.go @@ -108,7 +108,7 @@ func (db *Database) TxnRepository() models.Repository { Gallery: db.Gallery, Image: db.Image, Movie: MovieReaderWriter, - Performer: PerformerReaderWriter, + Performer: db.Performer, Scene: db.Scene, SceneMarker: SceneMarkerReaderWriter, ScrapedItem: ScrapedItemReaderWriter,