Performer custom fields (#5487)

* Backend changes
* Show custom field values
* Add custom fields table input
* Add custom field filtering
* Add unit tests
* Include custom fields in import/export
* Anonymise performer custom fields
* Move json.Number handler functions to api
* Handle json.Number conversion in api
This commit is contained in:
WithoutPants
2024-12-03 13:49:55 +11:00
committed by GitHub
parent a0e09bbe5c
commit 8c8be22fe4
56 changed files with 2158 additions and 277 deletions

View File

@@ -188,7 +188,9 @@ func (i *Importer) createPerformers(ctx context.Context, names []string) ([]*mod
newPerformer := models.NewPerformer()
newPerformer.Name = name
err := i.PerformerWriter.Create(ctx, &newPerformer)
err := i.PerformerWriter.Create(ctx, &models.CreatePerformerInput{
Performer: &newPerformer,
})
if err != nil {
return nil, err
}

View File

@@ -201,8 +201,8 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
}
db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3)
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) {
performer := args.Get(1).(*models.Performer)
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.CreatePerformerInput")).Run(func(args mock.Arguments) {
performer := args.Get(1).(*models.CreatePerformerInput)
performer.ID = existingPerformerID
}).Return(nil)
@@ -235,7 +235,7 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
}
db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once()
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error"))
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.CreatePerformerInput")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)

View File

@@ -274,7 +274,9 @@ func (i *Importer) createPerformers(ctx context.Context, names []string) ([]*mod
newPerformer := models.NewPerformer()
newPerformer.Name = name
err := i.PerformerWriter.Create(ctx, &newPerformer)
err := i.PerformerWriter.Create(ctx, &models.CreatePerformerInput{
Performer: &newPerformer,
})
if err != nil {
return nil, err
}

View File

@@ -163,8 +163,8 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
}
db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3)
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) {
performer := args.Get(1).(*models.Performer)
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.CreatePerformerInput")).Run(func(args mock.Arguments) {
performer := args.Get(1).(*models.CreatePerformerInput)
performer.ID = existingPerformerID
}).Return(nil)
@@ -197,7 +197,7 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
}
db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once()
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error"))
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.CreatePerformerInput")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)

View File

@@ -0,0 +1,17 @@
package models
import "context"
type CustomFieldMap map[string]interface{}
type CustomFieldsInput struct {
// If populated, the entire custom fields map will be replaced with this value
Full map[string]interface{} `json:"full"`
// If populated, only the keys in this map will be updated
Partial map[string]interface{} `json:"partial"`
}
type CustomFieldsReader interface {
GetCustomFields(ctx context.Context, id int) (map[string]interface{}, error)
GetCustomFieldsBulk(ctx context.Context, ids []int) ([]CustomFieldMap, error)
}

View File

@@ -194,3 +194,9 @@ type PhashDistanceCriterionInput struct {
type OrientationCriterionInput struct {
Value []OrientationEnum `json:"value"`
}
type CustomFieldCriterionInput struct {
Field string `json:"field"`
Value []any `json:"value"`
Modifier CriterionModifier `json:"modifier"`
}

View File

@@ -65,6 +65,8 @@ type Performer struct {
StashIDs []models.StashID `json:"stash_ids,omitempty"`
IgnoreAutoTag bool `json:"ignore_auto_tag,omitempty"`
CustomFields map[string]interface{} `json:"custom_fields,omitempty"`
// deprecated - for import only
URL string `json:"url,omitempty"`
Twitter string `json:"twitter,omitempty"`

View File

@@ -80,11 +80,11 @@ func (_m *PerformerReaderWriter) CountByTagID(ctx context.Context, tagID int) (i
}
// Create provides a mock function with given fields: ctx, newPerformer
func (_m *PerformerReaderWriter) Create(ctx context.Context, newPerformer *models.Performer) error {
func (_m *PerformerReaderWriter) Create(ctx context.Context, newPerformer *models.CreatePerformerInput) error {
ret := _m.Called(ctx, newPerformer)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *models.Performer) error); ok {
if rf, ok := ret.Get(0).(func(context.Context, *models.CreatePerformerInput) error); ok {
r0 = rf(ctx, newPerformer)
} else {
r0 = ret.Error(0)
@@ -314,6 +314,52 @@ func (_m *PerformerReaderWriter) GetAliases(ctx context.Context, relatedID int)
return r0, r1
}
// GetCustomFields provides a mock function with given fields: ctx, id
func (_m *PerformerReaderWriter) GetCustomFields(ctx context.Context, id int) (map[string]interface{}, error) {
ret := _m.Called(ctx, id)
var r0 map[string]interface{}
if rf, ok := ret.Get(0).(func(context.Context, int) map[string]interface{}); ok {
r0 = rf(ctx, id)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(map[string]interface{})
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
r1 = rf(ctx, id)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetCustomFieldsBulk provides a mock function with given fields: ctx, ids
func (_m *PerformerReaderWriter) GetCustomFieldsBulk(ctx context.Context, ids []int) ([]models.CustomFieldMap, error) {
ret := _m.Called(ctx, ids)
var r0 []models.CustomFieldMap
if rf, ok := ret.Get(0).(func(context.Context, []int) []models.CustomFieldMap); ok {
r0 = rf(ctx, ids)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]models.CustomFieldMap)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, []int) error); ok {
r1 = rf(ctx, ids)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetImage provides a mock function with given fields: ctx, performerID
func (_m *PerformerReaderWriter) GetImage(ctx context.Context, performerID int) ([]byte, error) {
ret := _m.Called(ctx, performerID)
@@ -502,11 +548,11 @@ func (_m *PerformerReaderWriter) QueryForAutoTag(ctx context.Context, words []st
}
// Update provides a mock function with given fields: ctx, updatedPerformer
func (_m *PerformerReaderWriter) Update(ctx context.Context, updatedPerformer *models.Performer) error {
func (_m *PerformerReaderWriter) Update(ctx context.Context, updatedPerformer *models.UpdatePerformerInput) error {
ret := _m.Called(ctx, updatedPerformer)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *models.Performer) error); ok {
if rf, ok := ret.Get(0).(func(context.Context, *models.UpdatePerformerInput) error); ok {
r0 = rf(ctx, updatedPerformer)
} else {
r0 = ret.Error(0)

View File

@@ -39,6 +39,18 @@ type Performer struct {
StashIDs RelatedStashIDs `json:"stash_ids"`
}
type CreatePerformerInput struct {
*Performer
CustomFields map[string]interface{} `json:"custom_fields"`
}
type UpdatePerformerInput struct {
*Performer
CustomFields CustomFieldsInput `json:"custom_fields"`
}
func NewPerformer() Performer {
currentTime := time.Now()
return Performer{
@@ -80,6 +92,8 @@ type PerformerPartial struct {
Aliases *UpdateStrings
TagIDs *UpdateIDs
StashIDs *UpdateStashIDs
CustomFields CustomFieldsInput
}
func NewPerformerPartial() PerformerPartial {

View File

@@ -198,6 +198,9 @@ type PerformerFilterType struct {
CreatedAt *TimestampCriterionInput `json:"created_at"`
// Filter by updated at
UpdatedAt *TimestampCriterionInput `json:"updated_at"`
// Filter by custom fields
CustomFields []CustomFieldCriterionInput `json:"custom_fields"`
}
type PerformerCreateInput struct {
@@ -234,6 +237,8 @@ type PerformerCreateInput struct {
HairColor *string `json:"hair_color"`
Weight *int `json:"weight"`
IgnoreAutoTag *bool `json:"ignore_auto_tag"`
CustomFields map[string]interface{} `json:"custom_fields"`
}
type PerformerUpdateInput struct {
@@ -271,4 +276,6 @@ type PerformerUpdateInput struct {
HairColor *string `json:"hair_color"`
Weight *int `json:"weight"`
IgnoreAutoTag *bool `json:"ignore_auto_tag"`
CustomFields CustomFieldsInput `json:"custom_fields"`
}

View File

@@ -43,12 +43,12 @@ type PerformerCounter interface {
// PerformerCreator provides methods to create performers.
type PerformerCreator interface {
Create(ctx context.Context, newPerformer *Performer) error
Create(ctx context.Context, newPerformer *CreatePerformerInput) error
}
// PerformerUpdater provides methods to update performers.
type PerformerUpdater interface {
Update(ctx context.Context, updatedPerformer *Performer) error
Update(ctx context.Context, updatedPerformer *UpdatePerformerInput) error
UpdatePartial(ctx context.Context, id int, updatedPerformer PerformerPartial) (*Performer, error)
UpdateImage(ctx context.Context, performerID int, image []byte) error
}
@@ -80,6 +80,8 @@ type PerformerReader interface {
TagIDLoader
URLLoader
CustomFieldsReader
All(ctx context.Context) ([]*Performer, error)
GetImage(ctx context.Context, performerID int) ([]byte, error)
HasImage(ctx context.Context, performerID int) (bool, error)

View File

@@ -17,6 +17,7 @@ type ImageAliasStashIDGetter interface {
models.AliasLoader
models.StashIDLoader
models.URLLoader
models.CustomFieldsReader
}
// ToJSON converts a Performer object into its JSON equivalent.
@@ -87,6 +88,12 @@ func ToJSON(ctx context.Context, reader ImageAliasStashIDGetter, performer *mode
newPerformerJSON.StashIDs = performer.StashIDs.List()
var err error
newPerformerJSON.CustomFields, err = reader.GetCustomFields(ctx, performer.ID)
if err != nil {
return nil, fmt.Errorf("getting performer custom fields: %v", err)
}
image, err := reader.GetImage(ctx, performer.ID)
if err != nil {
logger.Errorf("Error getting performer image: %v", err)

View File

@@ -15,9 +15,11 @@ import (
)
const (
performerID = 1
noImageID = 2
errImageID = 3
performerID = 1
noImageID = 2
errImageID = 3
customFieldsID = 4
errCustomFieldsID = 5
)
const (
@@ -50,6 +52,11 @@ var (
penisLength = 1.23
circumcisedEnum = models.CircumisedEnumCut
circumcised = circumcisedEnum.String()
emptyCustomFields = make(map[string]interface{})
customFields = map[string]interface{}{
"customField1": "customValue1",
}
)
var imageBytes = []byte("imageBytes")
@@ -118,8 +125,8 @@ func createEmptyPerformer(id int) models.Performer {
}
}
func createFullJSONPerformer(name string, image string) *jsonschema.Performer {
return &jsonschema.Performer{
func createFullJSONPerformer(name string, image string, withCustomFields bool) *jsonschema.Performer {
ret := &jsonschema.Performer{
Name: name,
Disambiguation: disambiguation,
URLs: []string{url, twitter, instagram},
@@ -152,7 +159,13 @@ func createFullJSONPerformer(name string, image string) *jsonschema.Performer {
Weight: weight,
StashIDs: stashIDs,
IgnoreAutoTag: autoTagIgnored,
CustomFields: emptyCustomFields,
}
if withCustomFields {
ret.CustomFields = customFields
}
return ret
}
func createEmptyJSONPerformer() *jsonschema.Performer {
@@ -166,13 +179,15 @@ func createEmptyJSONPerformer() *jsonschema.Performer {
UpdatedAt: json.JSONTime{
Time: updateTime,
},
CustomFields: emptyCustomFields,
}
}
type testScenario struct {
input models.Performer
expected *jsonschema.Performer
err bool
input models.Performer
customFields map[string]interface{}
expected *jsonschema.Performer
err bool
}
var scenarios []testScenario
@@ -181,20 +196,36 @@ func initTestTable() {
scenarios = []testScenario{
{
*createFullPerformer(performerID, performerName),
createFullJSONPerformer(performerName, image),
emptyCustomFields,
createFullJSONPerformer(performerName, image, false),
false,
},
{
*createFullPerformer(customFieldsID, performerName),
customFields,
createFullJSONPerformer(performerName, image, true),
false,
},
{
createEmptyPerformer(noImageID),
emptyCustomFields,
createEmptyJSONPerformer(),
false,
},
{
*createFullPerformer(errImageID, performerName),
createFullJSONPerformer(performerName, ""),
emptyCustomFields,
createFullJSONPerformer(performerName, "", false),
// failure to get image should not cause an error
false,
},
{
*createFullPerformer(errCustomFieldsID, performerName),
customFields,
nil,
// failure to get custom fields should cause an error
true,
},
}
}
@@ -204,11 +235,19 @@ func TestToJSON(t *testing.T) {
db := mocks.NewDatabase()
imageErr := errors.New("error getting image")
customFieldsErr := errors.New("error getting custom fields")
db.Performer.On("GetImage", testCtx, performerID).Return(imageBytes, nil).Once()
db.Performer.On("GetImage", testCtx, customFieldsID).Return(imageBytes, nil).Once()
db.Performer.On("GetImage", testCtx, noImageID).Return(nil, nil).Once()
db.Performer.On("GetImage", testCtx, errImageID).Return(nil, imageErr).Once()
db.Performer.On("GetCustomFields", testCtx, performerID).Return(emptyCustomFields, nil).Once()
db.Performer.On("GetCustomFields", testCtx, customFieldsID).Return(customFields, nil).Once()
db.Performer.On("GetCustomFields", testCtx, noImageID).Return(emptyCustomFields, nil).Once()
db.Performer.On("GetCustomFields", testCtx, errImageID).Return(emptyCustomFields, nil).Once()
db.Performer.On("GetCustomFields", testCtx, errCustomFieldsID).Return(nil, customFieldsErr).Once()
for i, s := range scenarios {
tag := s.input
json, err := ToJSON(testCtx, db.Performer, &tag)

View File

@@ -25,13 +25,15 @@ type Importer struct {
Input jsonschema.Performer
MissingRefBehaviour models.ImportMissingRefEnum
ID int
performer models.Performer
imageData []byte
ID int
performer models.Performer
customFields models.CustomFieldMap
imageData []byte
}
func (i *Importer) PreImport(ctx context.Context) error {
i.performer = performerJSONToPerformer(i.Input)
i.customFields = i.Input.CustomFields
if err := i.populateTags(ctx); err != nil {
return err
@@ -165,7 +167,10 @@ func (i *Importer) FindExistingID(ctx context.Context) (*int, error) {
}
func (i *Importer) Create(ctx context.Context) (*int, error) {
err := i.ReaderWriter.Create(ctx, &i.performer)
err := i.ReaderWriter.Create(ctx, &models.CreatePerformerInput{
Performer: &i.performer,
CustomFields: i.customFields,
})
if err != nil {
return nil, fmt.Errorf("error creating performer: %v", err)
}
@@ -175,9 +180,13 @@ func (i *Importer) Create(ctx context.Context) (*int, error) {
}
func (i *Importer) Update(ctx context.Context, id int) error {
performer := i.performer
performer.ID = id
err := i.ReaderWriter.Update(ctx, &performer)
i.performer.ID = id
err := i.ReaderWriter.Update(ctx, &models.UpdatePerformerInput{
Performer: &i.performer,
CustomFields: models.CustomFieldsInput{
Full: i.customFields,
},
})
if err != nil {
return fmt.Errorf("error updating existing performer: %v", err)
}

View File

@@ -53,13 +53,14 @@ func TestImporterPreImport(t *testing.T) {
assert.NotNil(t, err)
i.Input = *createFullJSONPerformer(performerName, image)
i.Input = *createFullJSONPerformer(performerName, image, true)
err = i.PreImport(testCtx)
assert.Nil(t, err)
expectedPerformer := *createFullPerformer(0, performerName)
assert.Equal(t, expectedPerformer, i.performer)
assert.Equal(t, models.CustomFieldMap(customFields), i.customFields)
}
func TestImporterPreImportWithTag(t *testing.T) {
@@ -234,10 +235,18 @@ func TestCreate(t *testing.T) {
Name: performerName,
}
performerInput := models.CreatePerformerInput{
Performer: &performer,
}
performerErr := models.Performer{
Name: performerNameErr,
}
performerErrInput := models.CreatePerformerInput{
Performer: &performerErr,
}
i := Importer{
ReaderWriter: db.Performer,
TagWriter: db.Tag,
@@ -245,11 +254,11 @@ func TestCreate(t *testing.T) {
}
errCreate := errors.New("Create error")
db.Performer.On("Create", testCtx, &performer).Run(func(args mock.Arguments) {
arg := args.Get(1).(*models.Performer)
db.Performer.On("Create", testCtx, &performerInput).Run(func(args mock.Arguments) {
arg := args.Get(1).(*models.CreatePerformerInput)
arg.ID = performerID
}).Return(nil).Once()
db.Performer.On("Create", testCtx, &performerErr).Return(errCreate).Once()
db.Performer.On("Create", testCtx, &performerErrInput).Return(errCreate).Once()
id, err := i.Create(testCtx)
assert.Equal(t, performerID, *id)
@@ -284,7 +293,10 @@ func TestUpdate(t *testing.T) {
// id needs to be set for the mock input
performer.ID = performerID
db.Performer.On("Update", testCtx, &performer).Return(nil).Once()
performerInput := models.UpdatePerformerInput{
Performer: &performer,
}
db.Performer.On("Update", testCtx, &performerInput).Return(nil).Once()
err := i.Update(testCtx, performerID)
assert.Nil(t, err)
@@ -293,7 +305,10 @@ func TestUpdate(t *testing.T) {
// need to set id separately
performerErr.ID = errImageID
db.Performer.On("Update", testCtx, &performerErr).Return(errUpdate).Once()
performerErrInput := models.UpdatePerformerInput{
Performer: &performerErr,
}
db.Performer.On("Update", testCtx, &performerErrInput).Return(errUpdate).Once()
err = i.Update(testCtx, errImageID)
assert.NotNil(t, err)

View File

@@ -325,7 +325,9 @@ func (i *Importer) createPerformers(ctx context.Context, names []string) ([]*mod
newPerformer := models.NewPerformer()
newPerformer.Name = name
err := i.PerformerWriter.Create(ctx, &newPerformer)
err := i.PerformerWriter.Create(ctx, &models.CreatePerformerInput{
Performer: &newPerformer,
})
if err != nil {
return nil, err
}

View File

@@ -327,8 +327,8 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
}
db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3)
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) {
p := args.Get(1).(*models.Performer)
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.CreatePerformerInput")).Run(func(args mock.Arguments) {
p := args.Get(1).(*models.CreatePerformerInput)
p.ID = existingPerformerID
}).Return(nil)
@@ -361,7 +361,7 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
}
db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once()
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error"))
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.CreatePerformerInput")).Return(errors.New("Create error"))
err := i.PreImport(testCtx)
assert.NotNil(t, err)

View File

@@ -600,6 +600,10 @@ func (db *Anonymiser) anonymisePerformers(ctx context.Context) error {
return err
}
if err := db.anonymiseCustomFields(ctx, goqu.T(performersCustomFieldsTable.GetTable()), "performer_id"); err != nil {
return err
}
return nil
}
@@ -1050,3 +1054,73 @@ func (db *Anonymiser) obfuscateString(in string, dict string) string {
return out.String()
}
func (db *Anonymiser) anonymiseCustomFields(ctx context.Context, table exp.IdentifierExpression, idColumn string) error {
lastID := 0
lastField := ""
total := 0
const logEvery = 10000
for gotSome := true; gotSome; {
if err := txn.WithTxn(ctx, db, func(ctx context.Context) error {
query := dialect.From(table).Select(
table.Col(idColumn),
table.Col("field"),
table.Col("value"),
).Where(
goqu.L("("+idColumn+", field)").Gt(goqu.L("(?, ?)", lastID, lastField)),
).Order(
table.Col(idColumn).Asc(), table.Col("field").Asc(),
).Limit(1000)
gotSome = false
const single = false
return queryFunc(ctx, query, single, func(rows *sqlx.Rows) error {
var (
id int
field string
value string
)
if err := rows.Scan(
&id,
&field,
&value,
); err != nil {
return err
}
set := goqu.Record{}
set["field"] = db.obfuscateString(field, letters)
set["value"] = db.obfuscateString(value, letters)
if len(set) > 0 {
stmt := dialect.Update(table).Set(set).Where(
table.Col(idColumn).Eq(id),
table.Col("field").Eq(field),
)
if _, err := exec(ctx, stmt); err != nil {
return fmt.Errorf("anonymising %s: %w", table.GetTable(), err)
}
}
lastID = id
lastField = field
gotSome = true
total++
if total%logEvery == 0 {
logger.Infof("Anonymised %d %s custom fields", total, table.GetTable())
}
return nil
})
}); err != nil {
return err
}
}
return nil
}

308
pkg/sqlite/custom_fields.go Normal file
View File

@@ -0,0 +1,308 @@
package sqlite
import (
"context"
"fmt"
"regexp"
"strings"
"github.com/doug-martin/goqu/v9"
"github.com/doug-martin/goqu/v9/exp"
"github.com/jmoiron/sqlx"
"github.com/stashapp/stash/pkg/models"
)
const maxCustomFieldNameLength = 64
type customFieldsStore struct {
table exp.IdentifierExpression
fk exp.IdentifierExpression
}
func (s *customFieldsStore) deleteForID(ctx context.Context, id int) error {
table := s.table
q := dialect.Delete(table).Where(s.fk.Eq(id))
_, err := exec(ctx, q)
if err != nil {
return fmt.Errorf("deleting from %s: %w", s.table.GetTable(), err)
}
return nil
}
func (s *customFieldsStore) SetCustomFields(ctx context.Context, id int, values models.CustomFieldsInput) error {
var partial bool
var valMap map[string]interface{}
switch {
case values.Full != nil:
partial = false
valMap = values.Full
case values.Partial != nil:
partial = true
valMap = values.Partial
default:
return nil
}
if err := s.validateCustomFields(valMap); err != nil {
return err
}
return s.setCustomFields(ctx, id, valMap, partial)
}
func (s *customFieldsStore) validateCustomFields(values map[string]interface{}) error {
// ensure that custom field names are valid
// no leading or trailing whitespace, no empty strings
for k := range values {
if err := s.validateCustomFieldName(k); err != nil {
return fmt.Errorf("custom field name %q: %w", k, err)
}
}
return nil
}
func (s *customFieldsStore) validateCustomFieldName(fieldName string) error {
// ensure that custom field names are valid
// no leading or trailing whitespace, no empty strings
if strings.TrimSpace(fieldName) == "" {
return fmt.Errorf("custom field name cannot be empty")
}
if fieldName != strings.TrimSpace(fieldName) {
return fmt.Errorf("custom field name cannot have leading or trailing whitespace")
}
if len(fieldName) > maxCustomFieldNameLength {
return fmt.Errorf("custom field name must be less than %d characters", maxCustomFieldNameLength+1)
}
return nil
}
func getSQLValueFromCustomFieldInput(input interface{}) (interface{}, error) {
switch v := input.(type) {
case []interface{}, map[string]interface{}:
// TODO - in future it would be nice to convert to a JSON string
// however, we would need some way to differentiate between a JSON string and a regular string
// for now, we will not support objects and arrays
return nil, fmt.Errorf("unsupported custom field value type: %T", input)
default:
return v, nil
}
}
func (s *customFieldsStore) sqlValueToValue(value interface{}) interface{} {
// TODO - if we ever support objects and arrays we will need to add support here
return value
}
func (s *customFieldsStore) setCustomFields(ctx context.Context, id int, values map[string]interface{}, partial bool) error {
if !partial {
// delete existing custom fields
if err := s.deleteForID(ctx, id); err != nil {
return err
}
}
if len(values) == 0 {
return nil
}
conflictKey := s.fk.GetCol().(string) + ", field"
// upsert new custom fields
q := dialect.Insert(s.table).Prepared(true).Cols(s.fk, "field", "value").
OnConflict(goqu.DoUpdate(conflictKey, goqu.Record{"value": goqu.I("excluded.value")}))
r := make([]interface{}, len(values))
var i int
for key, value := range values {
v, err := getSQLValueFromCustomFieldInput(value)
if err != nil {
return fmt.Errorf("getting SQL value for field %q: %w", key, err)
}
r[i] = goqu.Record{"field": key, "value": v, s.fk.GetCol().(string): id}
i++
}
if _, err := exec(ctx, q.Rows(r...)); err != nil {
return fmt.Errorf("inserting custom fields: %w", err)
}
return nil
}
func (s *customFieldsStore) GetCustomFields(ctx context.Context, id int) (map[string]interface{}, error) {
q := dialect.Select("field", "value").From(s.table).Where(s.fk.Eq(id))
const single = false
ret := make(map[string]interface{})
err := queryFunc(ctx, q, single, func(rows *sqlx.Rows) error {
var field string
var value interface{}
if err := rows.Scan(&field, &value); err != nil {
return fmt.Errorf("scanning custom fields: %w", err)
}
ret[field] = s.sqlValueToValue(value)
return nil
})
if err != nil {
return nil, fmt.Errorf("getting custom fields: %w", err)
}
return ret, nil
}
func (s *customFieldsStore) GetCustomFieldsBulk(ctx context.Context, ids []int) ([]models.CustomFieldMap, error) {
q := dialect.Select(s.fk.As("id"), "field", "value").From(s.table).Where(s.fk.In(ids))
const single = false
ret := make([]models.CustomFieldMap, len(ids))
idi := make(map[int]int, len(ids))
for i, id := range ids {
idi[id] = i
}
err := queryFunc(ctx, q, single, func(rows *sqlx.Rows) error {
var id int
var field string
var value interface{}
if err := rows.Scan(&id, &field, &value); err != nil {
return fmt.Errorf("scanning custom fields: %w", err)
}
i := idi[id]
m := ret[i]
if m == nil {
m = make(map[string]interface{})
ret[i] = m
}
m[field] = s.sqlValueToValue(value)
return nil
})
if err != nil {
return nil, fmt.Errorf("getting custom fields: %w", err)
}
return ret, nil
}
type customFieldsFilterHandler struct {
table string
fkCol string
c []models.CustomFieldCriterionInput
idCol string
}
func (h *customFieldsFilterHandler) innerJoin(f *filterBuilder, as string, field string) {
joinOn := fmt.Sprintf("%s = %s.%s AND %s.field = ?", h.idCol, as, h.fkCol, as)
f.addInnerJoin(h.table, as, joinOn, field)
}
func (h *customFieldsFilterHandler) leftJoin(f *filterBuilder, as string, field string) {
joinOn := fmt.Sprintf("%s = %s.%s AND %s.field = ?", h.idCol, as, h.fkCol, as)
f.addLeftJoin(h.table, as, joinOn, field)
}
func (h *customFieldsFilterHandler) handleCriterion(f *filterBuilder, joinAs string, cc models.CustomFieldCriterionInput) {
// convert values
cv := make([]interface{}, len(cc.Value))
for i, v := range cc.Value {
var err error
cv[i], err = getSQLValueFromCustomFieldInput(v)
if err != nil {
f.setError(err)
return
}
}
switch cc.Modifier {
case models.CriterionModifierEquals:
h.innerJoin(f, joinAs, cc.Field)
f.addWhere(fmt.Sprintf("%[1]s.value IN %s", joinAs, getInBinding(len(cv))), cv...)
case models.CriterionModifierNotEquals:
h.innerJoin(f, joinAs, cc.Field)
f.addWhere(fmt.Sprintf("%[1]s.value NOT IN %s", joinAs, getInBinding(len(cv))), cv...)
case models.CriterionModifierIncludes:
clauses := make([]sqlClause, len(cv))
for i, v := range cv {
clauses[i] = makeClause(fmt.Sprintf("%s.value LIKE ?", joinAs), fmt.Sprintf("%%%v%%", v))
}
h.innerJoin(f, joinAs, cc.Field)
f.whereClauses = append(f.whereClauses, clauses...)
case models.CriterionModifierExcludes:
for _, v := range cv {
f.addWhere(fmt.Sprintf("%[1]s.value NOT LIKE ?", joinAs), fmt.Sprintf("%%%v%%", v))
}
h.leftJoin(f, joinAs, cc.Field)
case models.CriterionModifierMatchesRegex:
for _, v := range cv {
vs, ok := v.(string)
if !ok {
f.setError(fmt.Errorf("unsupported custom field criterion value type: %T", v))
}
if _, err := regexp.Compile(vs); err != nil {
f.setError(err)
return
}
f.addWhere(fmt.Sprintf("(%s.value regexp ?)", joinAs), v)
}
h.innerJoin(f, joinAs, cc.Field)
case models.CriterionModifierNotMatchesRegex:
for _, v := range cv {
vs, ok := v.(string)
if !ok {
f.setError(fmt.Errorf("unsupported custom field criterion value type: %T", v))
}
if _, err := regexp.Compile(vs); err != nil {
f.setError(err)
return
}
f.addWhere(fmt.Sprintf("(%s.value IS NULL OR %[1]s.value NOT regexp ?)", joinAs), v)
}
h.leftJoin(f, joinAs, cc.Field)
case models.CriterionModifierIsNull:
h.leftJoin(f, joinAs, cc.Field)
f.addWhere(fmt.Sprintf("%s.value IS NULL OR TRIM(%[1]s.value) = ''", joinAs))
case models.CriterionModifierNotNull:
h.innerJoin(f, joinAs, cc.Field)
f.addWhere(fmt.Sprintf("TRIM(%[1]s.value) != ''", joinAs))
case models.CriterionModifierBetween:
if len(cv) != 2 {
f.setError(fmt.Errorf("expected 2 values for custom field criterion modifier BETWEEN, got %d", len(cv)))
return
}
h.innerJoin(f, joinAs, cc.Field)
f.addWhere(fmt.Sprintf("%s.value BETWEEN ? AND ?", joinAs), cv[0], cv[1])
case models.CriterionModifierNotBetween:
h.innerJoin(f, joinAs, cc.Field)
f.addWhere(fmt.Sprintf("%s.value NOT BETWEEN ? AND ?", joinAs), cv[0], cv[1])
case models.CriterionModifierLessThan:
if len(cv) != 1 {
f.setError(fmt.Errorf("expected 1 value for custom field criterion modifier LESS_THAN, got %d", len(cv)))
return
}
h.innerJoin(f, joinAs, cc.Field)
f.addWhere(fmt.Sprintf("%s.value < ?", joinAs), cv[0])
case models.CriterionModifierGreaterThan:
if len(cv) != 1 {
f.setError(fmt.Errorf("expected 1 value for custom field criterion modifier LESS_THAN, got %d", len(cv)))
return
}
h.innerJoin(f, joinAs, cc.Field)
f.addWhere(fmt.Sprintf("%s.value > ?", joinAs), cv[0])
default:
f.setError(fmt.Errorf("unsupported custom field criterion modifier: %s", cc.Modifier))
}
}
func (h *customFieldsFilterHandler) handle(ctx context.Context, f *filterBuilder) {
if len(h.c) == 0 {
return
}
for i, cc := range h.c {
join := fmt.Sprintf("custom_fields_%d", i)
h.handleCriterion(f, join, cc)
}
}

View File

@@ -0,0 +1,176 @@
//go:build integration
// +build integration
package sqlite_test
import (
"context"
"testing"
"github.com/stashapp/stash/pkg/models"
"github.com/stretchr/testify/assert"
)
func TestSetCustomFields(t *testing.T) {
performerIdx := performerIdx1WithScene
mergeCustomFields := func(i map[string]interface{}) map[string]interface{} {
m := getPerformerCustomFields(performerIdx)
for k, v := range i {
m[k] = v
}
return m
}
tests := []struct {
name string
input models.CustomFieldsInput
expected map[string]interface{}
wantErr bool
}{
{
"valid full",
models.CustomFieldsInput{
Full: map[string]interface{}{
"key": "value",
},
},
map[string]interface{}{
"key": "value",
},
false,
},
{
"valid partial",
models.CustomFieldsInput{
Partial: map[string]interface{}{
"key": "value",
},
},
mergeCustomFields(map[string]interface{}{
"key": "value",
}),
false,
},
{
"valid partial overwrite",
models.CustomFieldsInput{
Partial: map[string]interface{}{
"real": float64(4.56),
},
},
mergeCustomFields(map[string]interface{}{
"real": float64(4.56),
}),
false,
},
{
"leading space full",
models.CustomFieldsInput{
Full: map[string]interface{}{
" key": "value",
},
},
nil,
true,
},
{
"trailing space full",
models.CustomFieldsInput{
Full: map[string]interface{}{
"key ": "value",
},
},
nil,
true,
},
{
"leading space partial",
models.CustomFieldsInput{
Partial: map[string]interface{}{
" key": "value",
},
},
nil,
true,
},
{
"trailing space partial",
models.CustomFieldsInput{
Partial: map[string]interface{}{
"key ": "value",
},
},
nil,
true,
},
{
"big key full",
models.CustomFieldsInput{
Full: map[string]interface{}{
"12345678901234567890123456789012345678901234567890123456789012345": "value",
},
},
nil,
true,
},
{
"big key partial",
models.CustomFieldsInput{
Partial: map[string]interface{}{
"12345678901234567890123456789012345678901234567890123456789012345": "value",
},
},
nil,
true,
},
{
"empty key full",
models.CustomFieldsInput{
Full: map[string]interface{}{
"": "value",
},
},
nil,
true,
},
{
"empty key partial",
models.CustomFieldsInput{
Partial: map[string]interface{}{
"": "value",
},
},
nil,
true,
},
}
// use performer custom fields store
store := db.Performer
id := performerIDs[performerIdx]
assert := assert.New(t)
for _, tt := range tests {
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
err := store.SetCustomFields(ctx, id, tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("SetCustomFields() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
return
}
actual, err := store.GetCustomFields(ctx, id)
if err != nil {
t.Errorf("GetCustomFields() error = %v", err)
return
}
assert.Equal(tt.expected, actual)
})
}
}

View File

@@ -34,7 +34,7 @@ const (
cacheSizeEnv = "STASH_SQLITE_CACHE_SIZE"
)
var appSchemaVersion uint = 70
var appSchemaVersion uint = 71
//go:embed migrations/*.sql
var migrationsBox embed.FS

View File

@@ -95,6 +95,7 @@ type join struct {
as string
onClause string
joinType string
args []interface{}
}
// equals returns true if the other join alias/table is equal to this one
@@ -229,12 +230,13 @@ func (f *filterBuilder) not(n *filterBuilder) {
// The AS is omitted if as is empty.
// This method does not add a join if it its alias/table name is already
// present in another existing join.
func (f *filterBuilder) addLeftJoin(table, as, onClause string) {
func (f *filterBuilder) addLeftJoin(table, as, onClause string, args ...interface{}) {
newJoin := join{
table: table,
as: as,
onClause: onClause,
joinType: "LEFT",
args: args,
}
f.joins.add(newJoin)
@@ -245,12 +247,13 @@ func (f *filterBuilder) addLeftJoin(table, as, onClause string) {
// The AS is omitted if as is empty.
// This method does not add a join if it its alias/table name is already
// present in another existing join.
func (f *filterBuilder) addInnerJoin(table, as, onClause string) {
func (f *filterBuilder) addInnerJoin(table, as, onClause string, args ...interface{}) {
newJoin := join{
table: table,
as: as,
onClause: onClause,
joinType: "INNER",
args: args,
}
f.joins.add(newJoin)

View File

@@ -0,0 +1,9 @@
CREATE TABLE `performer_custom_fields` (
`performer_id` integer NOT NULL,
`field` varchar(64) NOT NULL,
`value` BLOB NOT NULL,
PRIMARY KEY (`performer_id`, `field`),
foreign key(`performer_id`) references `performers`(`id`) on delete CASCADE
);
CREATE INDEX `index_performer_custom_fields_field_value` ON `performer_custom_fields` (`field`, `value`);

View File

@@ -226,6 +226,7 @@ var (
type PerformerStore struct {
blobJoinQueryBuilder
customFieldsStore
tableMgr *table
}
@@ -236,6 +237,10 @@ func NewPerformerStore(blobStore *BlobStore) *PerformerStore {
blobStore: blobStore,
joinTable: performerTable,
},
customFieldsStore: customFieldsStore{
table: performersCustomFieldsTable,
fk: performersCustomFieldsTable.Col(performerIDColumn),
},
tableMgr: performerTableMgr,
}
}
@@ -248,9 +253,9 @@ func (qb *PerformerStore) selectDataset() *goqu.SelectDataset {
return dialect.From(qb.table()).Select(qb.table().All())
}
func (qb *PerformerStore) Create(ctx context.Context, newObject *models.Performer) error {
func (qb *PerformerStore) Create(ctx context.Context, newObject *models.CreatePerformerInput) error {
var r performerRow
r.fromPerformer(*newObject)
r.fromPerformer(*newObject.Performer)
id, err := qb.tableMgr.insertID(ctx, r)
if err != nil {
@@ -282,12 +287,17 @@ func (qb *PerformerStore) Create(ctx context.Context, newObject *models.Performe
}
}
const partial = false
if err := qb.setCustomFields(ctx, id, newObject.CustomFields, partial); err != nil {
return err
}
updated, err := qb.find(ctx, id)
if err != nil {
return fmt.Errorf("finding after create: %w", err)
}
*newObject = *updated
*newObject.Performer = *updated
return nil
}
@@ -330,12 +340,16 @@ func (qb *PerformerStore) UpdatePartial(ctx context.Context, id int, partial mod
}
}
if err := qb.SetCustomFields(ctx, id, partial.CustomFields); err != nil {
return nil, err
}
return qb.find(ctx, id)
}
func (qb *PerformerStore) Update(ctx context.Context, updatedObject *models.Performer) error {
func (qb *PerformerStore) Update(ctx context.Context, updatedObject *models.UpdatePerformerInput) error {
var r performerRow
r.fromPerformer(*updatedObject)
r.fromPerformer(*updatedObject.Performer)
if err := qb.tableMgr.updateByID(ctx, updatedObject.ID, r); err != nil {
return err
@@ -365,6 +379,10 @@ func (qb *PerformerStore) Update(ctx context.Context, updatedObject *models.Perf
}
}
if err := qb.SetCustomFields(ctx, updatedObject.ID, updatedObject.CustomFields); err != nil {
return err
}
return nil
}

View File

@@ -203,6 +203,13 @@ func (qb *performerFilterHandler) criterionHandler() criterionHandler {
performerRepository.tags.innerJoin(f, "performer_tag", "performers.id")
},
},
&customFieldsFilterHandler{
table: performersCustomFieldsTable.GetTable(),
fkCol: performerIDColumn,
c: filter.CustomFields,
idCol: "performers.id",
},
}
}

View File

@@ -16,6 +16,12 @@ import (
"github.com/stretchr/testify/assert"
)
var testCustomFields = map[string]interface{}{
"string": "aaa",
"int": int64(123), // int64 to match the type of the field in the database
"real": 1.23,
}
func loadPerformerRelationships(ctx context.Context, expected models.Performer, actual *models.Performer) error {
if expected.Aliases.Loaded() {
if err := actual.LoadAliases(ctx, db.Performer); err != nil {
@@ -81,57 +87,62 @@ func Test_PerformerStore_Create(t *testing.T) {
tests := []struct {
name string
newObject models.Performer
newObject models.CreatePerformerInput
wantErr bool
}{
{
"full",
models.Performer{
Name: name,
Disambiguation: disambiguation,
Gender: &gender,
URLs: models.NewRelatedStrings(urls),
Birthdate: &birthdate,
Ethnicity: ethnicity,
Country: country,
EyeColor: eyeColor,
Height: &height,
Measurements: measurements,
FakeTits: fakeTits,
PenisLength: &penisLength,
Circumcised: &circumcised,
CareerLength: careerLength,
Tattoos: tattoos,
Piercings: piercings,
Favorite: favorite,
Rating: &rating,
Details: details,
DeathDate: &deathdate,
HairColor: hairColor,
Weight: &weight,
IgnoreAutoTag: ignoreAutoTag,
TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithPerformer], tagIDs[tagIdx1WithDupName]}),
Aliases: models.NewRelatedStrings(aliases),
StashIDs: models.NewRelatedStashIDs([]models.StashID{
{
StashID: stashID1,
Endpoint: endpoint1,
},
{
StashID: stashID2,
Endpoint: endpoint2,
},
}),
CreatedAt: createdAt,
UpdatedAt: updatedAt,
models.CreatePerformerInput{
Performer: &models.Performer{
Name: name,
Disambiguation: disambiguation,
Gender: &gender,
URLs: models.NewRelatedStrings(urls),
Birthdate: &birthdate,
Ethnicity: ethnicity,
Country: country,
EyeColor: eyeColor,
Height: &height,
Measurements: measurements,
FakeTits: fakeTits,
PenisLength: &penisLength,
Circumcised: &circumcised,
CareerLength: careerLength,
Tattoos: tattoos,
Piercings: piercings,
Favorite: favorite,
Rating: &rating,
Details: details,
DeathDate: &deathdate,
HairColor: hairColor,
Weight: &weight,
IgnoreAutoTag: ignoreAutoTag,
TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithPerformer], tagIDs[tagIdx1WithDupName]}),
Aliases: models.NewRelatedStrings(aliases),
StashIDs: models.NewRelatedStashIDs([]models.StashID{
{
StashID: stashID1,
Endpoint: endpoint1,
},
{
StashID: stashID2,
Endpoint: endpoint2,
},
}),
CreatedAt: createdAt,
UpdatedAt: updatedAt,
},
CustomFields: testCustomFields,
},
false,
},
{
"invalid tag id",
models.Performer{
Name: name,
TagIDs: models.NewRelatedIDs([]int{invalidID}),
models.CreatePerformerInput{
Performer: &models.Performer{
Name: name,
TagIDs: models.NewRelatedIDs([]int{invalidID}),
},
},
true,
},
@@ -155,16 +166,16 @@ func Test_PerformerStore_Create(t *testing.T) {
assert.NotZero(p.ID)
copy := tt.newObject
copy := *tt.newObject.Performer
copy.ID = p.ID
// load relationships
if err := loadPerformerRelationships(ctx, copy, &p); err != nil {
if err := loadPerformerRelationships(ctx, copy, p.Performer); err != nil {
t.Errorf("loadPerformerRelationships() error = %v", err)
return
}
assert.Equal(copy, p)
assert.Equal(copy, *p.Performer)
// ensure can find the performer
found, err := qb.Find(ctx, p.ID)
@@ -183,6 +194,15 @@ func Test_PerformerStore_Create(t *testing.T) {
}
assert.Equal(copy, *found)
// ensure custom fields are set
cf, err := qb.GetCustomFields(ctx, p.ID)
if err != nil {
t.Errorf("PerformerStore.GetCustomFields() error = %v", err)
return
}
assert.Equal(tt.newObject.CustomFields, cf)
return
})
}
@@ -228,77 +248,109 @@ func Test_PerformerStore_Update(t *testing.T) {
tests := []struct {
name string
updatedObject *models.Performer
updatedObject models.UpdatePerformerInput
wantErr bool
}{
{
"full",
&models.Performer{
ID: performerIDs[performerIdxWithGallery],
Name: name,
Disambiguation: disambiguation,
Gender: &gender,
URLs: models.NewRelatedStrings(urls),
Birthdate: &birthdate,
Ethnicity: ethnicity,
Country: country,
EyeColor: eyeColor,
Height: &height,
Measurements: measurements,
FakeTits: fakeTits,
PenisLength: &penisLength,
Circumcised: &circumcised,
CareerLength: careerLength,
Tattoos: tattoos,
Piercings: piercings,
Favorite: favorite,
Rating: &rating,
Details: details,
DeathDate: &deathdate,
HairColor: hairColor,
Weight: &weight,
IgnoreAutoTag: ignoreAutoTag,
Aliases: models.NewRelatedStrings(aliases),
TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithPerformer], tagIDs[tagIdx1WithDupName]}),
StashIDs: models.NewRelatedStashIDs([]models.StashID{
{
StashID: stashID1,
Endpoint: endpoint1,
},
{
StashID: stashID2,
Endpoint: endpoint2,
},
}),
CreatedAt: createdAt,
UpdatedAt: updatedAt,
models.UpdatePerformerInput{
Performer: &models.Performer{
ID: performerIDs[performerIdxWithGallery],
Name: name,
Disambiguation: disambiguation,
Gender: &gender,
URLs: models.NewRelatedStrings(urls),
Birthdate: &birthdate,
Ethnicity: ethnicity,
Country: country,
EyeColor: eyeColor,
Height: &height,
Measurements: measurements,
FakeTits: fakeTits,
PenisLength: &penisLength,
Circumcised: &circumcised,
CareerLength: careerLength,
Tattoos: tattoos,
Piercings: piercings,
Favorite: favorite,
Rating: &rating,
Details: details,
DeathDate: &deathdate,
HairColor: hairColor,
Weight: &weight,
IgnoreAutoTag: ignoreAutoTag,
Aliases: models.NewRelatedStrings(aliases),
TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithPerformer], tagIDs[tagIdx1WithDupName]}),
StashIDs: models.NewRelatedStashIDs([]models.StashID{
{
StashID: stashID1,
Endpoint: endpoint1,
},
{
StashID: stashID2,
Endpoint: endpoint2,
},
}),
CreatedAt: createdAt,
UpdatedAt: updatedAt,
},
},
false,
},
{
"clear nullables",
&models.Performer{
ID: performerIDs[performerIdxWithGallery],
Aliases: models.NewRelatedStrings([]string{}),
URLs: models.NewRelatedStrings([]string{}),
TagIDs: models.NewRelatedIDs([]int{}),
StashIDs: models.NewRelatedStashIDs([]models.StashID{}),
models.UpdatePerformerInput{
Performer: &models.Performer{
ID: performerIDs[performerIdxWithGallery],
Aliases: models.NewRelatedStrings([]string{}),
URLs: models.NewRelatedStrings([]string{}),
TagIDs: models.NewRelatedIDs([]int{}),
StashIDs: models.NewRelatedStashIDs([]models.StashID{}),
},
},
false,
},
{
"clear tag ids",
&models.Performer{
ID: performerIDs[sceneIdxWithTag],
TagIDs: models.NewRelatedIDs([]int{}),
models.UpdatePerformerInput{
Performer: &models.Performer{
ID: performerIDs[sceneIdxWithTag],
TagIDs: models.NewRelatedIDs([]int{}),
},
},
false,
},
{
"set custom fields",
models.UpdatePerformerInput{
Performer: &models.Performer{
ID: performerIDs[performerIdxWithGallery],
},
CustomFields: models.CustomFieldsInput{
Full: testCustomFields,
},
},
false,
},
{
"clear custom fields",
models.UpdatePerformerInput{
Performer: &models.Performer{
ID: performerIDs[performerIdxWithGallery],
},
CustomFields: models.CustomFieldsInput{
Full: map[string]interface{}{},
},
},
false,
},
{
"invalid tag id",
&models.Performer{
ID: performerIDs[sceneIdxWithGallery],
TagIDs: models.NewRelatedIDs([]int{invalidID}),
models.UpdatePerformerInput{
Performer: &models.Performer{
ID: performerIDs[sceneIdxWithGallery],
TagIDs: models.NewRelatedIDs([]int{invalidID}),
},
},
true,
},
@@ -309,9 +361,9 @@ func Test_PerformerStore_Update(t *testing.T) {
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
assert := assert.New(t)
copy := *tt.updatedObject
copy := *tt.updatedObject.Performer
if err := qb.Update(ctx, tt.updatedObject); (err != nil) != tt.wantErr {
if err := qb.Update(ctx, &tt.updatedObject); (err != nil) != tt.wantErr {
t.Errorf("PerformerStore.Update() error = %v, wantErr %v", err, tt.wantErr)
}
@@ -331,6 +383,17 @@ func Test_PerformerStore_Update(t *testing.T) {
}
assert.Equal(copy, *s)
// ensure custom fields are correct
if tt.updatedObject.CustomFields.Full != nil {
cf, err := qb.GetCustomFields(ctx, tt.updatedObject.ID)
if err != nil {
t.Errorf("PerformerStore.GetCustomFields() error = %v", err)
return
}
assert.Equal(tt.updatedObject.CustomFields.Full, cf)
}
})
}
}
@@ -573,6 +636,79 @@ func Test_PerformerStore_UpdatePartial(t *testing.T) {
}
}
func Test_PerformerStore_UpdatePartialCustomFields(t *testing.T) {
tests := []struct {
name string
id int
partial models.PerformerPartial
expected map[string]interface{} // nil to use the partial
}{
{
"set custom fields",
performerIDs[performerIdxWithGallery],
models.PerformerPartial{
CustomFields: models.CustomFieldsInput{
Full: testCustomFields,
},
},
nil,
},
{
"clear custom fields",
performerIDs[performerIdxWithGallery],
models.PerformerPartial{
CustomFields: models.CustomFieldsInput{
Full: map[string]interface{}{},
},
},
nil,
},
{
"partial custom fields",
performerIDs[performerIdxWithGallery],
models.PerformerPartial{
CustomFields: models.CustomFieldsInput{
Partial: map[string]interface{}{
"string": "bbb",
"new_field": "new",
},
},
},
map[string]interface{}{
"int": int64(3),
"real": 1.3,
"string": "bbb",
"new_field": "new",
},
},
}
for _, tt := range tests {
qb := db.Performer
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
assert := assert.New(t)
_, err := qb.UpdatePartial(ctx, tt.id, tt.partial)
if err != nil {
t.Errorf("PerformerStore.UpdatePartial() error = %v", err)
return
}
// ensure custom fields are correct
cf, err := qb.GetCustomFields(ctx, tt.id)
if err != nil {
t.Errorf("PerformerStore.GetCustomFields() error = %v", err)
return
}
if tt.expected == nil {
assert.Equal(tt.partial.CustomFields.Full, cf)
} else {
assert.Equal(tt.expected, cf)
}
})
}
}
func TestPerformerFindBySceneID(t *testing.T) {
withTxn(func(ctx context.Context) error {
pqb := db.Performer
@@ -1042,6 +1178,242 @@ func TestPerformerQuery(t *testing.T) {
}
}
func TestPerformerQueryCustomFields(t *testing.T) {
tests := []struct {
name string
filter *models.PerformerFilterType
includeIdxs []int
excludeIdxs []int
wantErr bool
}{
{
"equals",
&models.PerformerFilterType{
CustomFields: []models.CustomFieldCriterionInput{
{
Field: "string",
Modifier: models.CriterionModifierEquals,
Value: []any{getPerformerStringValue(performerIdxWithGallery, "custom")},
},
},
},
[]int{performerIdxWithGallery},
nil,
false,
},
{
"not equals",
&models.PerformerFilterType{
Name: &models.StringCriterionInput{
Value: getPerformerStringValue(performerIdxWithGallery, "Name"),
Modifier: models.CriterionModifierEquals,
},
CustomFields: []models.CustomFieldCriterionInput{
{
Field: "string",
Modifier: models.CriterionModifierNotEquals,
Value: []any{getPerformerStringValue(performerIdxWithGallery, "custom")},
},
},
},
nil,
[]int{performerIdxWithGallery},
false,
},
{
"includes",
&models.PerformerFilterType{
CustomFields: []models.CustomFieldCriterionInput{
{
Field: "string",
Modifier: models.CriterionModifierIncludes,
Value: []any{getPerformerStringValue(performerIdxWithGallery, "custom")[9:]},
},
},
},
[]int{performerIdxWithGallery},
nil,
false,
},
{
"excludes",
&models.PerformerFilterType{
Name: &models.StringCriterionInput{
Value: getPerformerStringValue(performerIdxWithGallery, "Name"),
Modifier: models.CriterionModifierEquals,
},
CustomFields: []models.CustomFieldCriterionInput{
{
Field: "string",
Modifier: models.CriterionModifierExcludes,
Value: []any{getPerformerStringValue(performerIdxWithGallery, "custom")[9:]},
},
},
},
nil,
[]int{performerIdxWithGallery},
false,
},
{
"regex",
&models.PerformerFilterType{
CustomFields: []models.CustomFieldCriterionInput{
{
Field: "string",
Modifier: models.CriterionModifierMatchesRegex,
Value: []any{".*13_custom"},
},
},
},
[]int{performerIdxWithGallery},
nil,
false,
},
{
"invalid regex",
&models.PerformerFilterType{
CustomFields: []models.CustomFieldCriterionInput{
{
Field: "string",
Modifier: models.CriterionModifierMatchesRegex,
Value: []any{"["},
},
},
},
nil,
nil,
true,
},
{
"not matches regex",
&models.PerformerFilterType{
Name: &models.StringCriterionInput{
Value: getPerformerStringValue(performerIdxWithGallery, "Name"),
Modifier: models.CriterionModifierEquals,
},
CustomFields: []models.CustomFieldCriterionInput{
{
Field: "string",
Modifier: models.CriterionModifierNotMatchesRegex,
Value: []any{".*13_custom"},
},
},
},
nil,
[]int{performerIdxWithGallery},
false,
},
{
"invalid not matches regex",
&models.PerformerFilterType{
CustomFields: []models.CustomFieldCriterionInput{
{
Field: "string",
Modifier: models.CriterionModifierNotMatchesRegex,
Value: []any{"["},
},
},
},
nil,
nil,
true,
},
{
"null",
&models.PerformerFilterType{
Name: &models.StringCriterionInput{
Value: getPerformerStringValue(performerIdxWithGallery, "Name"),
Modifier: models.CriterionModifierEquals,
},
CustomFields: []models.CustomFieldCriterionInput{
{
Field: "not existing",
Modifier: models.CriterionModifierIsNull,
},
},
},
[]int{performerIdxWithGallery},
nil,
false,
},
{
"null",
&models.PerformerFilterType{
Name: &models.StringCriterionInput{
Value: getPerformerStringValue(performerIdxWithGallery, "Name"),
Modifier: models.CriterionModifierEquals,
},
CustomFields: []models.CustomFieldCriterionInput{
{
Field: "string",
Modifier: models.CriterionModifierNotNull,
},
},
},
[]int{performerIdxWithGallery},
nil,
false,
},
{
"between",
&models.PerformerFilterType{
CustomFields: []models.CustomFieldCriterionInput{
{
Field: "real",
Modifier: models.CriterionModifierBetween,
Value: []any{0.05, 0.15},
},
},
},
[]int{performerIdx1WithScene},
nil,
false,
},
{
"not between",
&models.PerformerFilterType{
Name: &models.StringCriterionInput{
Value: getPerformerStringValue(performerIdx1WithScene, "Name"),
Modifier: models.CriterionModifierEquals,
},
CustomFields: []models.CustomFieldCriterionInput{
{
Field: "real",
Modifier: models.CriterionModifierNotBetween,
Value: []any{0.05, 0.15},
},
},
},
nil,
[]int{performerIdx1WithScene},
false,
},
}
for _, tt := range tests {
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
assert := assert.New(t)
performers, _, err := db.Performer.Query(ctx, tt.filter, nil)
if (err != nil) != tt.wantErr {
t.Errorf("PerformerStore.Query() error = %v, wantErr %v", err, tt.wantErr)
return
}
ids := performersToIDs(performers)
include := indexesToIDs(performerIDs, tt.includeIdxs)
exclude := indexesToIDs(performerIDs, tt.excludeIdxs)
for _, i := range include {
assert.Contains(ids, i)
}
for _, e := range exclude {
assert.NotContains(ids, e)
}
})
}
}
func TestPerformerQueryPenisLength(t *testing.T) {
var upper = 4.0
@@ -1172,7 +1544,7 @@ func TestPerformerUpdatePerformerImage(t *testing.T) {
performer := models.Performer{
Name: name,
}
err := qb.Create(ctx, &performer)
err := qb.Create(ctx, &models.CreatePerformerInput{Performer: &performer})
if err != nil {
return fmt.Errorf("Error creating performer: %s", err.Error())
}
@@ -1680,7 +2052,7 @@ func TestPerformerStashIDs(t *testing.T) {
performer := &models.Performer{
Name: name,
}
if err := qb.Create(ctx, performer); err != nil {
if err := qb.Create(ctx, &models.CreatePerformerInput{Performer: performer}); err != nil {
return fmt.Errorf("Error creating performer: %s", err.Error())
}

View File

@@ -133,6 +133,9 @@ func (qb *queryBuilder) join(table, as, onClause string) {
func (qb *queryBuilder) addJoins(joins ...join) {
qb.joins.add(joins...)
for _, j := range joins {
qb.args = append(qb.args, j.args...)
}
}
func (qb *queryBuilder) addFilter(f *filterBuilder) error {
@@ -151,6 +154,9 @@ func (qb *queryBuilder) addFilter(f *filterBuilder) error {
qb.args = append(args, qb.args...)
}
// add joins here to insert args
qb.addJoins(f.getAllJoins()...)
clause, args = f.generateWhereClauses()
if len(clause) > 0 {
qb.addWhere(clause)
@@ -169,8 +175,6 @@ func (qb *queryBuilder) addFilter(f *filterBuilder) error {
qb.addArg(args...)
}
qb.addJoins(f.getAllJoins()...)
return nil
}

View File

@@ -222,8 +222,8 @@ func (r *repository) innerJoin(j joiner, as string, parentIDCol string) {
}
type joiner interface {
addLeftJoin(table, as, onClause string)
addInnerJoin(table, as, onClause string)
addLeftJoin(table, as, onClause string, args ...interface{})
addInnerJoin(table, as, onClause string, args ...interface{})
}
type joinRepository struct {

View File

@@ -1508,6 +1508,18 @@ func performerAliases(i int) []string {
return []string{getPerformerStringValue(i, "alias")}
}
func getPerformerCustomFields(index int) map[string]interface{} {
if index%5 == 0 {
return nil
}
return map[string]interface{}{
"string": getPerformerStringValue(index, "custom"),
"int": int64(index % 5),
"real": float64(index) / 10,
}
}
// createPerformers creates n performers with plain Name and o performers with camel cased NaMe included
func createPerformers(ctx context.Context, n int, o int) error {
pqb := db.Performer
@@ -1558,7 +1570,10 @@ func createPerformers(ctx context.Context, n int, o int) error {
})
}
err := pqb.Create(ctx, &performer)
err := pqb.Create(ctx, &models.CreatePerformerInput{
Performer: &performer,
CustomFields: getPerformerCustomFields(i),
})
if err != nil {
return fmt.Errorf("Error creating performer %v+: %s", performer, err.Error())

View File

@@ -32,6 +32,7 @@ var (
performersURLsJoinTable = goqu.T(performerURLsTable)
performersTagsJoinTable = goqu.T(performersTagsTable)
performersStashIDsJoinTable = goqu.T("performer_stash_ids")
performersCustomFieldsTable = goqu.T("performer_custom_fields")
studiosAliasesJoinTable = goqu.T(studioAliasesTable)
studiosTagsJoinTable = goqu.T(studiosTagsTable)

View File

@@ -1,16 +0,0 @@
package utils
import (
"encoding/json"
"strings"
)
// JSONNumberToNumber converts a JSON number to either a float64 or int64.
func JSONNumberToNumber(n json.Number) interface{} {
if strings.Contains(string(n), ".") {
f, _ := n.Float64()
return f
}
ret, _ := n.Int64()
return ret
}

View File

@@ -1,7 +1,6 @@
package utils
import (
"encoding/json"
"strings"
)
@@ -80,19 +79,3 @@ func MergeMaps(dest map[string]interface{}, src map[string]interface{}) {
dest[k] = v
}
}
// ConvertMapJSONNumbers converts all JSON numbers in a map to either float64 or int64.
func ConvertMapJSONNumbers(m map[string]interface{}) (ret map[string]interface{}) {
ret = make(map[string]interface{})
for k, v := range m {
if n, ok := v.(json.Number); ok {
ret[k] = JSONNumberToNumber(n)
} else if mm, ok := v.(map[string]interface{}); ok {
ret[k] = ConvertMapJSONNumbers(mm)
} else {
ret[k] = v
}
}
return ret
}

View File

@@ -1,11 +1,8 @@
package utils
import (
"encoding/json"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
)
func TestNestedMapGet(t *testing.T) {
@@ -282,55 +279,3 @@ func TestMergeMaps(t *testing.T) {
})
}
}
func TestConvertMapJSONNumbers(t *testing.T) {
tests := []struct {
name string
input map[string]interface{}
expected map[string]interface{}
}{
{
name: "Convert JSON numbers to numbers",
input: map[string]interface{}{
"int": json.Number("12"),
"float": json.Number("12.34"),
"string": "foo",
},
expected: map[string]interface{}{
"int": int64(12),
"float": 12.34,
"string": "foo",
},
},
{
name: "Convert JSON numbers to numbers in nested maps",
input: map[string]interface{}{
"foo": map[string]interface{}{
"int": json.Number("56"),
"float": json.Number("56.78"),
"nested-string": "bar",
},
"int": json.Number("12"),
"float": json.Number("12.34"),
"string": "foo",
},
expected: map[string]interface{}{
"foo": map[string]interface{}{
"int": int64(56),
"float": 56.78,
"nested-string": "bar",
},
"int": int64(12),
"float": 12.34,
"string": "foo",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ConvertMapJSONNumbers(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}