mirror of
https://github.com/stashapp/stash.git
synced 2025-12-16 20:07:05 +03:00
Performer custom fields (#5487)
* Backend changes * Show custom field values * Add custom fields table input * Add custom field filtering * Add unit tests * Include custom fields in import/export * Anonymise performer custom fields * Move json.Number handler functions to api * Handle json.Number conversion in api
This commit is contained in:
@@ -188,7 +188,9 @@ func (i *Importer) createPerformers(ctx context.Context, names []string) ([]*mod
|
||||
newPerformer := models.NewPerformer()
|
||||
newPerformer.Name = name
|
||||
|
||||
err := i.PerformerWriter.Create(ctx, &newPerformer)
|
||||
err := i.PerformerWriter.Create(ctx, &models.CreatePerformerInput{
|
||||
Performer: &newPerformer,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -201,8 +201,8 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
|
||||
}
|
||||
|
||||
db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3)
|
||||
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) {
|
||||
performer := args.Get(1).(*models.Performer)
|
||||
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.CreatePerformerInput")).Run(func(args mock.Arguments) {
|
||||
performer := args.Get(1).(*models.CreatePerformerInput)
|
||||
performer.ID = existingPerformerID
|
||||
}).Return(nil)
|
||||
|
||||
@@ -235,7 +235,7 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
|
||||
}
|
||||
|
||||
db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once()
|
||||
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error"))
|
||||
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.CreatePerformerInput")).Return(errors.New("Create error"))
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
@@ -274,7 +274,9 @@ func (i *Importer) createPerformers(ctx context.Context, names []string) ([]*mod
|
||||
newPerformer := models.NewPerformer()
|
||||
newPerformer.Name = name
|
||||
|
||||
err := i.PerformerWriter.Create(ctx, &newPerformer)
|
||||
err := i.PerformerWriter.Create(ctx, &models.CreatePerformerInput{
|
||||
Performer: &newPerformer,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -163,8 +163,8 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
|
||||
}
|
||||
|
||||
db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3)
|
||||
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) {
|
||||
performer := args.Get(1).(*models.Performer)
|
||||
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.CreatePerformerInput")).Run(func(args mock.Arguments) {
|
||||
performer := args.Get(1).(*models.CreatePerformerInput)
|
||||
performer.ID = existingPerformerID
|
||||
}).Return(nil)
|
||||
|
||||
@@ -197,7 +197,7 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
|
||||
}
|
||||
|
||||
db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once()
|
||||
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error"))
|
||||
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.CreatePerformerInput")).Return(errors.New("Create error"))
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
17
pkg/models/custom_fields.go
Normal file
17
pkg/models/custom_fields.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package models
|
||||
|
||||
import "context"
|
||||
|
||||
type CustomFieldMap map[string]interface{}
|
||||
|
||||
type CustomFieldsInput struct {
|
||||
// If populated, the entire custom fields map will be replaced with this value
|
||||
Full map[string]interface{} `json:"full"`
|
||||
// If populated, only the keys in this map will be updated
|
||||
Partial map[string]interface{} `json:"partial"`
|
||||
}
|
||||
|
||||
type CustomFieldsReader interface {
|
||||
GetCustomFields(ctx context.Context, id int) (map[string]interface{}, error)
|
||||
GetCustomFieldsBulk(ctx context.Context, ids []int) ([]CustomFieldMap, error)
|
||||
}
|
||||
@@ -194,3 +194,9 @@ type PhashDistanceCriterionInput struct {
|
||||
type OrientationCriterionInput struct {
|
||||
Value []OrientationEnum `json:"value"`
|
||||
}
|
||||
|
||||
type CustomFieldCriterionInput struct {
|
||||
Field string `json:"field"`
|
||||
Value []any `json:"value"`
|
||||
Modifier CriterionModifier `json:"modifier"`
|
||||
}
|
||||
|
||||
@@ -65,6 +65,8 @@ type Performer struct {
|
||||
StashIDs []models.StashID `json:"stash_ids,omitempty"`
|
||||
IgnoreAutoTag bool `json:"ignore_auto_tag,omitempty"`
|
||||
|
||||
CustomFields map[string]interface{} `json:"custom_fields,omitempty"`
|
||||
|
||||
// deprecated - for import only
|
||||
URL string `json:"url,omitempty"`
|
||||
Twitter string `json:"twitter,omitempty"`
|
||||
|
||||
@@ -80,11 +80,11 @@ func (_m *PerformerReaderWriter) CountByTagID(ctx context.Context, tagID int) (i
|
||||
}
|
||||
|
||||
// Create provides a mock function with given fields: ctx, newPerformer
|
||||
func (_m *PerformerReaderWriter) Create(ctx context.Context, newPerformer *models.Performer) error {
|
||||
func (_m *PerformerReaderWriter) Create(ctx context.Context, newPerformer *models.CreatePerformerInput) error {
|
||||
ret := _m.Called(ctx, newPerformer)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *models.Performer) error); ok {
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *models.CreatePerformerInput) error); ok {
|
||||
r0 = rf(ctx, newPerformer)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
@@ -314,6 +314,52 @@ func (_m *PerformerReaderWriter) GetAliases(ctx context.Context, relatedID int)
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// GetCustomFields provides a mock function with given fields: ctx, id
|
||||
func (_m *PerformerReaderWriter) GetCustomFields(ctx context.Context, id int) (map[string]interface{}, error) {
|
||||
ret := _m.Called(ctx, id)
|
||||
|
||||
var r0 map[string]interface{}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, int) map[string]interface{}); ok {
|
||||
r0 = rf(ctx, id)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(map[string]interface{})
|
||||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
|
||||
r1 = rf(ctx, id)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// GetCustomFieldsBulk provides a mock function with given fields: ctx, ids
|
||||
func (_m *PerformerReaderWriter) GetCustomFieldsBulk(ctx context.Context, ids []int) ([]models.CustomFieldMap, error) {
|
||||
ret := _m.Called(ctx, ids)
|
||||
|
||||
var r0 []models.CustomFieldMap
|
||||
if rf, ok := ret.Get(0).(func(context.Context, []int) []models.CustomFieldMap); ok {
|
||||
r0 = rf(ctx, ids)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]models.CustomFieldMap)
|
||||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(context.Context, []int) error); ok {
|
||||
r1 = rf(ctx, ids)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// GetImage provides a mock function with given fields: ctx, performerID
|
||||
func (_m *PerformerReaderWriter) GetImage(ctx context.Context, performerID int) ([]byte, error) {
|
||||
ret := _m.Called(ctx, performerID)
|
||||
@@ -502,11 +548,11 @@ func (_m *PerformerReaderWriter) QueryForAutoTag(ctx context.Context, words []st
|
||||
}
|
||||
|
||||
// Update provides a mock function with given fields: ctx, updatedPerformer
|
||||
func (_m *PerformerReaderWriter) Update(ctx context.Context, updatedPerformer *models.Performer) error {
|
||||
func (_m *PerformerReaderWriter) Update(ctx context.Context, updatedPerformer *models.UpdatePerformerInput) error {
|
||||
ret := _m.Called(ctx, updatedPerformer)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *models.Performer) error); ok {
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *models.UpdatePerformerInput) error); ok {
|
||||
r0 = rf(ctx, updatedPerformer)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
|
||||
@@ -39,6 +39,18 @@ type Performer struct {
|
||||
StashIDs RelatedStashIDs `json:"stash_ids"`
|
||||
}
|
||||
|
||||
type CreatePerformerInput struct {
|
||||
*Performer
|
||||
|
||||
CustomFields map[string]interface{} `json:"custom_fields"`
|
||||
}
|
||||
|
||||
type UpdatePerformerInput struct {
|
||||
*Performer
|
||||
|
||||
CustomFields CustomFieldsInput `json:"custom_fields"`
|
||||
}
|
||||
|
||||
func NewPerformer() Performer {
|
||||
currentTime := time.Now()
|
||||
return Performer{
|
||||
@@ -80,6 +92,8 @@ type PerformerPartial struct {
|
||||
Aliases *UpdateStrings
|
||||
TagIDs *UpdateIDs
|
||||
StashIDs *UpdateStashIDs
|
||||
|
||||
CustomFields CustomFieldsInput
|
||||
}
|
||||
|
||||
func NewPerformerPartial() PerformerPartial {
|
||||
|
||||
@@ -198,6 +198,9 @@ type PerformerFilterType struct {
|
||||
CreatedAt *TimestampCriterionInput `json:"created_at"`
|
||||
// Filter by updated at
|
||||
UpdatedAt *TimestampCriterionInput `json:"updated_at"`
|
||||
|
||||
// Filter by custom fields
|
||||
CustomFields []CustomFieldCriterionInput `json:"custom_fields"`
|
||||
}
|
||||
|
||||
type PerformerCreateInput struct {
|
||||
@@ -234,6 +237,8 @@ type PerformerCreateInput struct {
|
||||
HairColor *string `json:"hair_color"`
|
||||
Weight *int `json:"weight"`
|
||||
IgnoreAutoTag *bool `json:"ignore_auto_tag"`
|
||||
|
||||
CustomFields map[string]interface{} `json:"custom_fields"`
|
||||
}
|
||||
|
||||
type PerformerUpdateInput struct {
|
||||
@@ -271,4 +276,6 @@ type PerformerUpdateInput struct {
|
||||
HairColor *string `json:"hair_color"`
|
||||
Weight *int `json:"weight"`
|
||||
IgnoreAutoTag *bool `json:"ignore_auto_tag"`
|
||||
|
||||
CustomFields CustomFieldsInput `json:"custom_fields"`
|
||||
}
|
||||
|
||||
@@ -43,12 +43,12 @@ type PerformerCounter interface {
|
||||
|
||||
// PerformerCreator provides methods to create performers.
|
||||
type PerformerCreator interface {
|
||||
Create(ctx context.Context, newPerformer *Performer) error
|
||||
Create(ctx context.Context, newPerformer *CreatePerformerInput) error
|
||||
}
|
||||
|
||||
// PerformerUpdater provides methods to update performers.
|
||||
type PerformerUpdater interface {
|
||||
Update(ctx context.Context, updatedPerformer *Performer) error
|
||||
Update(ctx context.Context, updatedPerformer *UpdatePerformerInput) error
|
||||
UpdatePartial(ctx context.Context, id int, updatedPerformer PerformerPartial) (*Performer, error)
|
||||
UpdateImage(ctx context.Context, performerID int, image []byte) error
|
||||
}
|
||||
@@ -80,6 +80,8 @@ type PerformerReader interface {
|
||||
TagIDLoader
|
||||
URLLoader
|
||||
|
||||
CustomFieldsReader
|
||||
|
||||
All(ctx context.Context) ([]*Performer, error)
|
||||
GetImage(ctx context.Context, performerID int) ([]byte, error)
|
||||
HasImage(ctx context.Context, performerID int) (bool, error)
|
||||
|
||||
@@ -17,6 +17,7 @@ type ImageAliasStashIDGetter interface {
|
||||
models.AliasLoader
|
||||
models.StashIDLoader
|
||||
models.URLLoader
|
||||
models.CustomFieldsReader
|
||||
}
|
||||
|
||||
// ToJSON converts a Performer object into its JSON equivalent.
|
||||
@@ -87,6 +88,12 @@ func ToJSON(ctx context.Context, reader ImageAliasStashIDGetter, performer *mode
|
||||
|
||||
newPerformerJSON.StashIDs = performer.StashIDs.List()
|
||||
|
||||
var err error
|
||||
newPerformerJSON.CustomFields, err = reader.GetCustomFields(ctx, performer.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting performer custom fields: %v", err)
|
||||
}
|
||||
|
||||
image, err := reader.GetImage(ctx, performer.ID)
|
||||
if err != nil {
|
||||
logger.Errorf("Error getting performer image: %v", err)
|
||||
|
||||
@@ -15,9 +15,11 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
performerID = 1
|
||||
noImageID = 2
|
||||
errImageID = 3
|
||||
performerID = 1
|
||||
noImageID = 2
|
||||
errImageID = 3
|
||||
customFieldsID = 4
|
||||
errCustomFieldsID = 5
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -50,6 +52,11 @@ var (
|
||||
penisLength = 1.23
|
||||
circumcisedEnum = models.CircumisedEnumCut
|
||||
circumcised = circumcisedEnum.String()
|
||||
|
||||
emptyCustomFields = make(map[string]interface{})
|
||||
customFields = map[string]interface{}{
|
||||
"customField1": "customValue1",
|
||||
}
|
||||
)
|
||||
|
||||
var imageBytes = []byte("imageBytes")
|
||||
@@ -118,8 +125,8 @@ func createEmptyPerformer(id int) models.Performer {
|
||||
}
|
||||
}
|
||||
|
||||
func createFullJSONPerformer(name string, image string) *jsonschema.Performer {
|
||||
return &jsonschema.Performer{
|
||||
func createFullJSONPerformer(name string, image string, withCustomFields bool) *jsonschema.Performer {
|
||||
ret := &jsonschema.Performer{
|
||||
Name: name,
|
||||
Disambiguation: disambiguation,
|
||||
URLs: []string{url, twitter, instagram},
|
||||
@@ -152,7 +159,13 @@ func createFullJSONPerformer(name string, image string) *jsonschema.Performer {
|
||||
Weight: weight,
|
||||
StashIDs: stashIDs,
|
||||
IgnoreAutoTag: autoTagIgnored,
|
||||
CustomFields: emptyCustomFields,
|
||||
}
|
||||
|
||||
if withCustomFields {
|
||||
ret.CustomFields = customFields
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func createEmptyJSONPerformer() *jsonschema.Performer {
|
||||
@@ -166,13 +179,15 @@ func createEmptyJSONPerformer() *jsonschema.Performer {
|
||||
UpdatedAt: json.JSONTime{
|
||||
Time: updateTime,
|
||||
},
|
||||
CustomFields: emptyCustomFields,
|
||||
}
|
||||
}
|
||||
|
||||
type testScenario struct {
|
||||
input models.Performer
|
||||
expected *jsonschema.Performer
|
||||
err bool
|
||||
input models.Performer
|
||||
customFields map[string]interface{}
|
||||
expected *jsonschema.Performer
|
||||
err bool
|
||||
}
|
||||
|
||||
var scenarios []testScenario
|
||||
@@ -181,20 +196,36 @@ func initTestTable() {
|
||||
scenarios = []testScenario{
|
||||
{
|
||||
*createFullPerformer(performerID, performerName),
|
||||
createFullJSONPerformer(performerName, image),
|
||||
emptyCustomFields,
|
||||
createFullJSONPerformer(performerName, image, false),
|
||||
false,
|
||||
},
|
||||
{
|
||||
*createFullPerformer(customFieldsID, performerName),
|
||||
customFields,
|
||||
createFullJSONPerformer(performerName, image, true),
|
||||
false,
|
||||
},
|
||||
{
|
||||
createEmptyPerformer(noImageID),
|
||||
emptyCustomFields,
|
||||
createEmptyJSONPerformer(),
|
||||
false,
|
||||
},
|
||||
{
|
||||
*createFullPerformer(errImageID, performerName),
|
||||
createFullJSONPerformer(performerName, ""),
|
||||
emptyCustomFields,
|
||||
createFullJSONPerformer(performerName, "", false),
|
||||
// failure to get image should not cause an error
|
||||
false,
|
||||
},
|
||||
{
|
||||
*createFullPerformer(errCustomFieldsID, performerName),
|
||||
customFields,
|
||||
nil,
|
||||
// failure to get custom fields should cause an error
|
||||
true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -204,11 +235,19 @@ func TestToJSON(t *testing.T) {
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
imageErr := errors.New("error getting image")
|
||||
customFieldsErr := errors.New("error getting custom fields")
|
||||
|
||||
db.Performer.On("GetImage", testCtx, performerID).Return(imageBytes, nil).Once()
|
||||
db.Performer.On("GetImage", testCtx, customFieldsID).Return(imageBytes, nil).Once()
|
||||
db.Performer.On("GetImage", testCtx, noImageID).Return(nil, nil).Once()
|
||||
db.Performer.On("GetImage", testCtx, errImageID).Return(nil, imageErr).Once()
|
||||
|
||||
db.Performer.On("GetCustomFields", testCtx, performerID).Return(emptyCustomFields, nil).Once()
|
||||
db.Performer.On("GetCustomFields", testCtx, customFieldsID).Return(customFields, nil).Once()
|
||||
db.Performer.On("GetCustomFields", testCtx, noImageID).Return(emptyCustomFields, nil).Once()
|
||||
db.Performer.On("GetCustomFields", testCtx, errImageID).Return(emptyCustomFields, nil).Once()
|
||||
db.Performer.On("GetCustomFields", testCtx, errCustomFieldsID).Return(nil, customFieldsErr).Once()
|
||||
|
||||
for i, s := range scenarios {
|
||||
tag := s.input
|
||||
json, err := ToJSON(testCtx, db.Performer, &tag)
|
||||
|
||||
@@ -25,13 +25,15 @@ type Importer struct {
|
||||
Input jsonschema.Performer
|
||||
MissingRefBehaviour models.ImportMissingRefEnum
|
||||
|
||||
ID int
|
||||
performer models.Performer
|
||||
imageData []byte
|
||||
ID int
|
||||
performer models.Performer
|
||||
customFields models.CustomFieldMap
|
||||
imageData []byte
|
||||
}
|
||||
|
||||
func (i *Importer) PreImport(ctx context.Context) error {
|
||||
i.performer = performerJSONToPerformer(i.Input)
|
||||
i.customFields = i.Input.CustomFields
|
||||
|
||||
if err := i.populateTags(ctx); err != nil {
|
||||
return err
|
||||
@@ -165,7 +167,10 @@ func (i *Importer) FindExistingID(ctx context.Context) (*int, error) {
|
||||
}
|
||||
|
||||
func (i *Importer) Create(ctx context.Context) (*int, error) {
|
||||
err := i.ReaderWriter.Create(ctx, &i.performer)
|
||||
err := i.ReaderWriter.Create(ctx, &models.CreatePerformerInput{
|
||||
Performer: &i.performer,
|
||||
CustomFields: i.customFields,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating performer: %v", err)
|
||||
}
|
||||
@@ -175,9 +180,13 @@ func (i *Importer) Create(ctx context.Context) (*int, error) {
|
||||
}
|
||||
|
||||
func (i *Importer) Update(ctx context.Context, id int) error {
|
||||
performer := i.performer
|
||||
performer.ID = id
|
||||
err := i.ReaderWriter.Update(ctx, &performer)
|
||||
i.performer.ID = id
|
||||
err := i.ReaderWriter.Update(ctx, &models.UpdatePerformerInput{
|
||||
Performer: &i.performer,
|
||||
CustomFields: models.CustomFieldsInput{
|
||||
Full: i.customFields,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error updating existing performer: %v", err)
|
||||
}
|
||||
|
||||
@@ -53,13 +53,14 @@ func TestImporterPreImport(t *testing.T) {
|
||||
|
||||
assert.NotNil(t, err)
|
||||
|
||||
i.Input = *createFullJSONPerformer(performerName, image)
|
||||
i.Input = *createFullJSONPerformer(performerName, image, true)
|
||||
|
||||
err = i.PreImport(testCtx)
|
||||
|
||||
assert.Nil(t, err)
|
||||
expectedPerformer := *createFullPerformer(0, performerName)
|
||||
assert.Equal(t, expectedPerformer, i.performer)
|
||||
assert.Equal(t, models.CustomFieldMap(customFields), i.customFields)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithTag(t *testing.T) {
|
||||
@@ -234,10 +235,18 @@ func TestCreate(t *testing.T) {
|
||||
Name: performerName,
|
||||
}
|
||||
|
||||
performerInput := models.CreatePerformerInput{
|
||||
Performer: &performer,
|
||||
}
|
||||
|
||||
performerErr := models.Performer{
|
||||
Name: performerNameErr,
|
||||
}
|
||||
|
||||
performerErrInput := models.CreatePerformerInput{
|
||||
Performer: &performerErr,
|
||||
}
|
||||
|
||||
i := Importer{
|
||||
ReaderWriter: db.Performer,
|
||||
TagWriter: db.Tag,
|
||||
@@ -245,11 +254,11 @@ func TestCreate(t *testing.T) {
|
||||
}
|
||||
|
||||
errCreate := errors.New("Create error")
|
||||
db.Performer.On("Create", testCtx, &performer).Run(func(args mock.Arguments) {
|
||||
arg := args.Get(1).(*models.Performer)
|
||||
db.Performer.On("Create", testCtx, &performerInput).Run(func(args mock.Arguments) {
|
||||
arg := args.Get(1).(*models.CreatePerformerInput)
|
||||
arg.ID = performerID
|
||||
}).Return(nil).Once()
|
||||
db.Performer.On("Create", testCtx, &performerErr).Return(errCreate).Once()
|
||||
db.Performer.On("Create", testCtx, &performerErrInput).Return(errCreate).Once()
|
||||
|
||||
id, err := i.Create(testCtx)
|
||||
assert.Equal(t, performerID, *id)
|
||||
@@ -284,7 +293,10 @@ func TestUpdate(t *testing.T) {
|
||||
|
||||
// id needs to be set for the mock input
|
||||
performer.ID = performerID
|
||||
db.Performer.On("Update", testCtx, &performer).Return(nil).Once()
|
||||
performerInput := models.UpdatePerformerInput{
|
||||
Performer: &performer,
|
||||
}
|
||||
db.Performer.On("Update", testCtx, &performerInput).Return(nil).Once()
|
||||
|
||||
err := i.Update(testCtx, performerID)
|
||||
assert.Nil(t, err)
|
||||
@@ -293,7 +305,10 @@ func TestUpdate(t *testing.T) {
|
||||
|
||||
// need to set id separately
|
||||
performerErr.ID = errImageID
|
||||
db.Performer.On("Update", testCtx, &performerErr).Return(errUpdate).Once()
|
||||
performerErrInput := models.UpdatePerformerInput{
|
||||
Performer: &performerErr,
|
||||
}
|
||||
db.Performer.On("Update", testCtx, &performerErrInput).Return(errUpdate).Once()
|
||||
|
||||
err = i.Update(testCtx, errImageID)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
@@ -325,7 +325,9 @@ func (i *Importer) createPerformers(ctx context.Context, names []string) ([]*mod
|
||||
newPerformer := models.NewPerformer()
|
||||
newPerformer.Name = name
|
||||
|
||||
err := i.PerformerWriter.Create(ctx, &newPerformer)
|
||||
err := i.PerformerWriter.Create(ctx, &models.CreatePerformerInput{
|
||||
Performer: &newPerformer,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -327,8 +327,8 @@ func TestImporterPreImportWithMissingPerformer(t *testing.T) {
|
||||
}
|
||||
|
||||
db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Times(3)
|
||||
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Run(func(args mock.Arguments) {
|
||||
p := args.Get(1).(*models.Performer)
|
||||
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.CreatePerformerInput")).Run(func(args mock.Arguments) {
|
||||
p := args.Get(1).(*models.CreatePerformerInput)
|
||||
p.ID = existingPerformerID
|
||||
}).Return(nil)
|
||||
|
||||
@@ -361,7 +361,7 @@ func TestImporterPreImportWithMissingPerformerCreateErr(t *testing.T) {
|
||||
}
|
||||
|
||||
db.Performer.On("FindByNames", testCtx, []string{missingPerformerName}, false).Return(nil, nil).Once()
|
||||
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.Performer")).Return(errors.New("Create error"))
|
||||
db.Performer.On("Create", testCtx, mock.AnythingOfType("*models.CreatePerformerInput")).Return(errors.New("Create error"))
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
@@ -600,6 +600,10 @@ func (db *Anonymiser) anonymisePerformers(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := db.anonymiseCustomFields(ctx, goqu.T(performersCustomFieldsTable.GetTable()), "performer_id"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1050,3 +1054,73 @@ func (db *Anonymiser) obfuscateString(in string, dict string) string {
|
||||
|
||||
return out.String()
|
||||
}
|
||||
|
||||
func (db *Anonymiser) anonymiseCustomFields(ctx context.Context, table exp.IdentifierExpression, idColumn string) error {
|
||||
lastID := 0
|
||||
lastField := ""
|
||||
total := 0
|
||||
const logEvery = 10000
|
||||
|
||||
for gotSome := true; gotSome; {
|
||||
if err := txn.WithTxn(ctx, db, func(ctx context.Context) error {
|
||||
query := dialect.From(table).Select(
|
||||
table.Col(idColumn),
|
||||
table.Col("field"),
|
||||
table.Col("value"),
|
||||
).Where(
|
||||
goqu.L("("+idColumn+", field)").Gt(goqu.L("(?, ?)", lastID, lastField)),
|
||||
).Order(
|
||||
table.Col(idColumn).Asc(), table.Col("field").Asc(),
|
||||
).Limit(1000)
|
||||
|
||||
gotSome = false
|
||||
|
||||
const single = false
|
||||
return queryFunc(ctx, query, single, func(rows *sqlx.Rows) error {
|
||||
var (
|
||||
id int
|
||||
field string
|
||||
value string
|
||||
)
|
||||
|
||||
if err := rows.Scan(
|
||||
&id,
|
||||
&field,
|
||||
&value,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
set := goqu.Record{}
|
||||
set["field"] = db.obfuscateString(field, letters)
|
||||
set["value"] = db.obfuscateString(value, letters)
|
||||
|
||||
if len(set) > 0 {
|
||||
stmt := dialect.Update(table).Set(set).Where(
|
||||
table.Col(idColumn).Eq(id),
|
||||
table.Col("field").Eq(field),
|
||||
)
|
||||
|
||||
if _, err := exec(ctx, stmt); err != nil {
|
||||
return fmt.Errorf("anonymising %s: %w", table.GetTable(), err)
|
||||
}
|
||||
}
|
||||
|
||||
lastID = id
|
||||
lastField = field
|
||||
gotSome = true
|
||||
total++
|
||||
|
||||
if total%logEvery == 0 {
|
||||
logger.Infof("Anonymised %d %s custom fields", total, table.GetTable())
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
308
pkg/sqlite/custom_fields.go
Normal file
308
pkg/sqlite/custom_fields.go
Normal file
@@ -0,0 +1,308 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/doug-martin/goqu/v9"
|
||||
"github.com/doug-martin/goqu/v9/exp"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
)
|
||||
|
||||
const maxCustomFieldNameLength = 64
|
||||
|
||||
type customFieldsStore struct {
|
||||
table exp.IdentifierExpression
|
||||
fk exp.IdentifierExpression
|
||||
}
|
||||
|
||||
func (s *customFieldsStore) deleteForID(ctx context.Context, id int) error {
|
||||
table := s.table
|
||||
q := dialect.Delete(table).Where(s.fk.Eq(id))
|
||||
_, err := exec(ctx, q)
|
||||
if err != nil {
|
||||
return fmt.Errorf("deleting from %s: %w", s.table.GetTable(), err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *customFieldsStore) SetCustomFields(ctx context.Context, id int, values models.CustomFieldsInput) error {
|
||||
var partial bool
|
||||
var valMap map[string]interface{}
|
||||
|
||||
switch {
|
||||
case values.Full != nil:
|
||||
partial = false
|
||||
valMap = values.Full
|
||||
case values.Partial != nil:
|
||||
partial = true
|
||||
valMap = values.Partial
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.validateCustomFields(valMap); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.setCustomFields(ctx, id, valMap, partial)
|
||||
}
|
||||
|
||||
func (s *customFieldsStore) validateCustomFields(values map[string]interface{}) error {
|
||||
// ensure that custom field names are valid
|
||||
// no leading or trailing whitespace, no empty strings
|
||||
for k := range values {
|
||||
if err := s.validateCustomFieldName(k); err != nil {
|
||||
return fmt.Errorf("custom field name %q: %w", k, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *customFieldsStore) validateCustomFieldName(fieldName string) error {
|
||||
// ensure that custom field names are valid
|
||||
// no leading or trailing whitespace, no empty strings
|
||||
if strings.TrimSpace(fieldName) == "" {
|
||||
return fmt.Errorf("custom field name cannot be empty")
|
||||
}
|
||||
if fieldName != strings.TrimSpace(fieldName) {
|
||||
return fmt.Errorf("custom field name cannot have leading or trailing whitespace")
|
||||
}
|
||||
if len(fieldName) > maxCustomFieldNameLength {
|
||||
return fmt.Errorf("custom field name must be less than %d characters", maxCustomFieldNameLength+1)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getSQLValueFromCustomFieldInput(input interface{}) (interface{}, error) {
|
||||
switch v := input.(type) {
|
||||
case []interface{}, map[string]interface{}:
|
||||
// TODO - in future it would be nice to convert to a JSON string
|
||||
// however, we would need some way to differentiate between a JSON string and a regular string
|
||||
// for now, we will not support objects and arrays
|
||||
return nil, fmt.Errorf("unsupported custom field value type: %T", input)
|
||||
default:
|
||||
return v, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *customFieldsStore) sqlValueToValue(value interface{}) interface{} {
|
||||
// TODO - if we ever support objects and arrays we will need to add support here
|
||||
return value
|
||||
}
|
||||
|
||||
func (s *customFieldsStore) setCustomFields(ctx context.Context, id int, values map[string]interface{}, partial bool) error {
|
||||
if !partial {
|
||||
// delete existing custom fields
|
||||
if err := s.deleteForID(ctx, id); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
conflictKey := s.fk.GetCol().(string) + ", field"
|
||||
// upsert new custom fields
|
||||
q := dialect.Insert(s.table).Prepared(true).Cols(s.fk, "field", "value").
|
||||
OnConflict(goqu.DoUpdate(conflictKey, goqu.Record{"value": goqu.I("excluded.value")}))
|
||||
r := make([]interface{}, len(values))
|
||||
var i int
|
||||
for key, value := range values {
|
||||
v, err := getSQLValueFromCustomFieldInput(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting SQL value for field %q: %w", key, err)
|
||||
}
|
||||
r[i] = goqu.Record{"field": key, "value": v, s.fk.GetCol().(string): id}
|
||||
i++
|
||||
}
|
||||
|
||||
if _, err := exec(ctx, q.Rows(r...)); err != nil {
|
||||
return fmt.Errorf("inserting custom fields: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *customFieldsStore) GetCustomFields(ctx context.Context, id int) (map[string]interface{}, error) {
|
||||
q := dialect.Select("field", "value").From(s.table).Where(s.fk.Eq(id))
|
||||
|
||||
const single = false
|
||||
ret := make(map[string]interface{})
|
||||
err := queryFunc(ctx, q, single, func(rows *sqlx.Rows) error {
|
||||
var field string
|
||||
var value interface{}
|
||||
if err := rows.Scan(&field, &value); err != nil {
|
||||
return fmt.Errorf("scanning custom fields: %w", err)
|
||||
}
|
||||
ret[field] = s.sqlValueToValue(value)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting custom fields: %w", err)
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (s *customFieldsStore) GetCustomFieldsBulk(ctx context.Context, ids []int) ([]models.CustomFieldMap, error) {
|
||||
q := dialect.Select(s.fk.As("id"), "field", "value").From(s.table).Where(s.fk.In(ids))
|
||||
|
||||
const single = false
|
||||
ret := make([]models.CustomFieldMap, len(ids))
|
||||
|
||||
idi := make(map[int]int, len(ids))
|
||||
for i, id := range ids {
|
||||
idi[id] = i
|
||||
}
|
||||
|
||||
err := queryFunc(ctx, q, single, func(rows *sqlx.Rows) error {
|
||||
var id int
|
||||
var field string
|
||||
var value interface{}
|
||||
if err := rows.Scan(&id, &field, &value); err != nil {
|
||||
return fmt.Errorf("scanning custom fields: %w", err)
|
||||
}
|
||||
|
||||
i := idi[id]
|
||||
m := ret[i]
|
||||
if m == nil {
|
||||
m = make(map[string]interface{})
|
||||
ret[i] = m
|
||||
}
|
||||
|
||||
m[field] = s.sqlValueToValue(value)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting custom fields: %w", err)
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
type customFieldsFilterHandler struct {
|
||||
table string
|
||||
fkCol string
|
||||
c []models.CustomFieldCriterionInput
|
||||
idCol string
|
||||
}
|
||||
|
||||
func (h *customFieldsFilterHandler) innerJoin(f *filterBuilder, as string, field string) {
|
||||
joinOn := fmt.Sprintf("%s = %s.%s AND %s.field = ?", h.idCol, as, h.fkCol, as)
|
||||
f.addInnerJoin(h.table, as, joinOn, field)
|
||||
}
|
||||
|
||||
func (h *customFieldsFilterHandler) leftJoin(f *filterBuilder, as string, field string) {
|
||||
joinOn := fmt.Sprintf("%s = %s.%s AND %s.field = ?", h.idCol, as, h.fkCol, as)
|
||||
f.addLeftJoin(h.table, as, joinOn, field)
|
||||
}
|
||||
|
||||
func (h *customFieldsFilterHandler) handleCriterion(f *filterBuilder, joinAs string, cc models.CustomFieldCriterionInput) {
|
||||
// convert values
|
||||
cv := make([]interface{}, len(cc.Value))
|
||||
for i, v := range cc.Value {
|
||||
var err error
|
||||
cv[i], err = getSQLValueFromCustomFieldInput(v)
|
||||
if err != nil {
|
||||
f.setError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
switch cc.Modifier {
|
||||
case models.CriterionModifierEquals:
|
||||
h.innerJoin(f, joinAs, cc.Field)
|
||||
f.addWhere(fmt.Sprintf("%[1]s.value IN %s", joinAs, getInBinding(len(cv))), cv...)
|
||||
case models.CriterionModifierNotEquals:
|
||||
h.innerJoin(f, joinAs, cc.Field)
|
||||
f.addWhere(fmt.Sprintf("%[1]s.value NOT IN %s", joinAs, getInBinding(len(cv))), cv...)
|
||||
case models.CriterionModifierIncludes:
|
||||
clauses := make([]sqlClause, len(cv))
|
||||
for i, v := range cv {
|
||||
clauses[i] = makeClause(fmt.Sprintf("%s.value LIKE ?", joinAs), fmt.Sprintf("%%%v%%", v))
|
||||
}
|
||||
h.innerJoin(f, joinAs, cc.Field)
|
||||
f.whereClauses = append(f.whereClauses, clauses...)
|
||||
case models.CriterionModifierExcludes:
|
||||
for _, v := range cv {
|
||||
f.addWhere(fmt.Sprintf("%[1]s.value NOT LIKE ?", joinAs), fmt.Sprintf("%%%v%%", v))
|
||||
}
|
||||
h.leftJoin(f, joinAs, cc.Field)
|
||||
case models.CriterionModifierMatchesRegex:
|
||||
for _, v := range cv {
|
||||
vs, ok := v.(string)
|
||||
if !ok {
|
||||
f.setError(fmt.Errorf("unsupported custom field criterion value type: %T", v))
|
||||
}
|
||||
if _, err := regexp.Compile(vs); err != nil {
|
||||
f.setError(err)
|
||||
return
|
||||
}
|
||||
f.addWhere(fmt.Sprintf("(%s.value regexp ?)", joinAs), v)
|
||||
}
|
||||
h.innerJoin(f, joinAs, cc.Field)
|
||||
case models.CriterionModifierNotMatchesRegex:
|
||||
for _, v := range cv {
|
||||
vs, ok := v.(string)
|
||||
if !ok {
|
||||
f.setError(fmt.Errorf("unsupported custom field criterion value type: %T", v))
|
||||
}
|
||||
if _, err := regexp.Compile(vs); err != nil {
|
||||
f.setError(err)
|
||||
return
|
||||
}
|
||||
f.addWhere(fmt.Sprintf("(%s.value IS NULL OR %[1]s.value NOT regexp ?)", joinAs), v)
|
||||
}
|
||||
h.leftJoin(f, joinAs, cc.Field)
|
||||
case models.CriterionModifierIsNull:
|
||||
h.leftJoin(f, joinAs, cc.Field)
|
||||
f.addWhere(fmt.Sprintf("%s.value IS NULL OR TRIM(%[1]s.value) = ''", joinAs))
|
||||
case models.CriterionModifierNotNull:
|
||||
h.innerJoin(f, joinAs, cc.Field)
|
||||
f.addWhere(fmt.Sprintf("TRIM(%[1]s.value) != ''", joinAs))
|
||||
case models.CriterionModifierBetween:
|
||||
if len(cv) != 2 {
|
||||
f.setError(fmt.Errorf("expected 2 values for custom field criterion modifier BETWEEN, got %d", len(cv)))
|
||||
return
|
||||
}
|
||||
h.innerJoin(f, joinAs, cc.Field)
|
||||
f.addWhere(fmt.Sprintf("%s.value BETWEEN ? AND ?", joinAs), cv[0], cv[1])
|
||||
case models.CriterionModifierNotBetween:
|
||||
h.innerJoin(f, joinAs, cc.Field)
|
||||
f.addWhere(fmt.Sprintf("%s.value NOT BETWEEN ? AND ?", joinAs), cv[0], cv[1])
|
||||
case models.CriterionModifierLessThan:
|
||||
if len(cv) != 1 {
|
||||
f.setError(fmt.Errorf("expected 1 value for custom field criterion modifier LESS_THAN, got %d", len(cv)))
|
||||
return
|
||||
}
|
||||
h.innerJoin(f, joinAs, cc.Field)
|
||||
f.addWhere(fmt.Sprintf("%s.value < ?", joinAs), cv[0])
|
||||
case models.CriterionModifierGreaterThan:
|
||||
if len(cv) != 1 {
|
||||
f.setError(fmt.Errorf("expected 1 value for custom field criterion modifier LESS_THAN, got %d", len(cv)))
|
||||
return
|
||||
}
|
||||
h.innerJoin(f, joinAs, cc.Field)
|
||||
f.addWhere(fmt.Sprintf("%s.value > ?", joinAs), cv[0])
|
||||
default:
|
||||
f.setError(fmt.Errorf("unsupported custom field criterion modifier: %s", cc.Modifier))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *customFieldsFilterHandler) handle(ctx context.Context, f *filterBuilder) {
|
||||
if len(h.c) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for i, cc := range h.c {
|
||||
join := fmt.Sprintf("custom_fields_%d", i)
|
||||
h.handleCriterion(f, join, cc)
|
||||
}
|
||||
}
|
||||
176
pkg/sqlite/custom_fields_test.go
Normal file
176
pkg/sqlite/custom_fields_test.go
Normal file
@@ -0,0 +1,176 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package sqlite_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSetCustomFields(t *testing.T) {
|
||||
performerIdx := performerIdx1WithScene
|
||||
|
||||
mergeCustomFields := func(i map[string]interface{}) map[string]interface{} {
|
||||
m := getPerformerCustomFields(performerIdx)
|
||||
for k, v := range i {
|
||||
m[k] = v
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input models.CustomFieldsInput
|
||||
expected map[string]interface{}
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"valid full",
|
||||
models.CustomFieldsInput{
|
||||
Full: map[string]interface{}{
|
||||
"key": "value",
|
||||
},
|
||||
},
|
||||
map[string]interface{}{
|
||||
"key": "value",
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"valid partial",
|
||||
models.CustomFieldsInput{
|
||||
Partial: map[string]interface{}{
|
||||
"key": "value",
|
||||
},
|
||||
},
|
||||
mergeCustomFields(map[string]interface{}{
|
||||
"key": "value",
|
||||
}),
|
||||
false,
|
||||
},
|
||||
{
|
||||
"valid partial overwrite",
|
||||
models.CustomFieldsInput{
|
||||
Partial: map[string]interface{}{
|
||||
"real": float64(4.56),
|
||||
},
|
||||
},
|
||||
mergeCustomFields(map[string]interface{}{
|
||||
"real": float64(4.56),
|
||||
}),
|
||||
false,
|
||||
},
|
||||
{
|
||||
"leading space full",
|
||||
models.CustomFieldsInput{
|
||||
Full: map[string]interface{}{
|
||||
" key": "value",
|
||||
},
|
||||
},
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"trailing space full",
|
||||
models.CustomFieldsInput{
|
||||
Full: map[string]interface{}{
|
||||
"key ": "value",
|
||||
},
|
||||
},
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"leading space partial",
|
||||
models.CustomFieldsInput{
|
||||
Partial: map[string]interface{}{
|
||||
" key": "value",
|
||||
},
|
||||
},
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"trailing space partial",
|
||||
models.CustomFieldsInput{
|
||||
Partial: map[string]interface{}{
|
||||
"key ": "value",
|
||||
},
|
||||
},
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"big key full",
|
||||
models.CustomFieldsInput{
|
||||
Full: map[string]interface{}{
|
||||
"12345678901234567890123456789012345678901234567890123456789012345": "value",
|
||||
},
|
||||
},
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"big key partial",
|
||||
models.CustomFieldsInput{
|
||||
Partial: map[string]interface{}{
|
||||
"12345678901234567890123456789012345678901234567890123456789012345": "value",
|
||||
},
|
||||
},
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"empty key full",
|
||||
models.CustomFieldsInput{
|
||||
Full: map[string]interface{}{
|
||||
"": "value",
|
||||
},
|
||||
},
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"empty key partial",
|
||||
models.CustomFieldsInput{
|
||||
Partial: map[string]interface{}{
|
||||
"": "value",
|
||||
},
|
||||
},
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
// use performer custom fields store
|
||||
store := db.Performer
|
||||
id := performerIDs[performerIdx]
|
||||
|
||||
assert := assert.New(t)
|
||||
|
||||
for _, tt := range tests {
|
||||
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
|
||||
err := store.SetCustomFields(ctx, id, tt.input)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("SetCustomFields() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
return
|
||||
}
|
||||
|
||||
actual, err := store.GetCustomFields(ctx, id)
|
||||
if err != nil {
|
||||
t.Errorf("GetCustomFields() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(tt.expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -34,7 +34,7 @@ const (
|
||||
cacheSizeEnv = "STASH_SQLITE_CACHE_SIZE"
|
||||
)
|
||||
|
||||
var appSchemaVersion uint = 70
|
||||
var appSchemaVersion uint = 71
|
||||
|
||||
//go:embed migrations/*.sql
|
||||
var migrationsBox embed.FS
|
||||
|
||||
@@ -95,6 +95,7 @@ type join struct {
|
||||
as string
|
||||
onClause string
|
||||
joinType string
|
||||
args []interface{}
|
||||
}
|
||||
|
||||
// equals returns true if the other join alias/table is equal to this one
|
||||
@@ -229,12 +230,13 @@ func (f *filterBuilder) not(n *filterBuilder) {
|
||||
// The AS is omitted if as is empty.
|
||||
// This method does not add a join if it its alias/table name is already
|
||||
// present in another existing join.
|
||||
func (f *filterBuilder) addLeftJoin(table, as, onClause string) {
|
||||
func (f *filterBuilder) addLeftJoin(table, as, onClause string, args ...interface{}) {
|
||||
newJoin := join{
|
||||
table: table,
|
||||
as: as,
|
||||
onClause: onClause,
|
||||
joinType: "LEFT",
|
||||
args: args,
|
||||
}
|
||||
|
||||
f.joins.add(newJoin)
|
||||
@@ -245,12 +247,13 @@ func (f *filterBuilder) addLeftJoin(table, as, onClause string) {
|
||||
// The AS is omitted if as is empty.
|
||||
// This method does not add a join if it its alias/table name is already
|
||||
// present in another existing join.
|
||||
func (f *filterBuilder) addInnerJoin(table, as, onClause string) {
|
||||
func (f *filterBuilder) addInnerJoin(table, as, onClause string, args ...interface{}) {
|
||||
newJoin := join{
|
||||
table: table,
|
||||
as: as,
|
||||
onClause: onClause,
|
||||
joinType: "INNER",
|
||||
args: args,
|
||||
}
|
||||
|
||||
f.joins.add(newJoin)
|
||||
|
||||
9
pkg/sqlite/migrations/71_custom_fields.up.sql
Normal file
9
pkg/sqlite/migrations/71_custom_fields.up.sql
Normal file
@@ -0,0 +1,9 @@
|
||||
CREATE TABLE `performer_custom_fields` (
|
||||
`performer_id` integer NOT NULL,
|
||||
`field` varchar(64) NOT NULL,
|
||||
`value` BLOB NOT NULL,
|
||||
PRIMARY KEY (`performer_id`, `field`),
|
||||
foreign key(`performer_id`) references `performers`(`id`) on delete CASCADE
|
||||
);
|
||||
|
||||
CREATE INDEX `index_performer_custom_fields_field_value` ON `performer_custom_fields` (`field`, `value`);
|
||||
@@ -226,6 +226,7 @@ var (
|
||||
|
||||
type PerformerStore struct {
|
||||
blobJoinQueryBuilder
|
||||
customFieldsStore
|
||||
|
||||
tableMgr *table
|
||||
}
|
||||
@@ -236,6 +237,10 @@ func NewPerformerStore(blobStore *BlobStore) *PerformerStore {
|
||||
blobStore: blobStore,
|
||||
joinTable: performerTable,
|
||||
},
|
||||
customFieldsStore: customFieldsStore{
|
||||
table: performersCustomFieldsTable,
|
||||
fk: performersCustomFieldsTable.Col(performerIDColumn),
|
||||
},
|
||||
tableMgr: performerTableMgr,
|
||||
}
|
||||
}
|
||||
@@ -248,9 +253,9 @@ func (qb *PerformerStore) selectDataset() *goqu.SelectDataset {
|
||||
return dialect.From(qb.table()).Select(qb.table().All())
|
||||
}
|
||||
|
||||
func (qb *PerformerStore) Create(ctx context.Context, newObject *models.Performer) error {
|
||||
func (qb *PerformerStore) Create(ctx context.Context, newObject *models.CreatePerformerInput) error {
|
||||
var r performerRow
|
||||
r.fromPerformer(*newObject)
|
||||
r.fromPerformer(*newObject.Performer)
|
||||
|
||||
id, err := qb.tableMgr.insertID(ctx, r)
|
||||
if err != nil {
|
||||
@@ -282,12 +287,17 @@ func (qb *PerformerStore) Create(ctx context.Context, newObject *models.Performe
|
||||
}
|
||||
}
|
||||
|
||||
const partial = false
|
||||
if err := qb.setCustomFields(ctx, id, newObject.CustomFields, partial); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updated, err := qb.find(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("finding after create: %w", err)
|
||||
}
|
||||
|
||||
*newObject = *updated
|
||||
*newObject.Performer = *updated
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -330,12 +340,16 @@ func (qb *PerformerStore) UpdatePartial(ctx context.Context, id int, partial mod
|
||||
}
|
||||
}
|
||||
|
||||
if err := qb.SetCustomFields(ctx, id, partial.CustomFields); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return qb.find(ctx, id)
|
||||
}
|
||||
|
||||
func (qb *PerformerStore) Update(ctx context.Context, updatedObject *models.Performer) error {
|
||||
func (qb *PerformerStore) Update(ctx context.Context, updatedObject *models.UpdatePerformerInput) error {
|
||||
var r performerRow
|
||||
r.fromPerformer(*updatedObject)
|
||||
r.fromPerformer(*updatedObject.Performer)
|
||||
|
||||
if err := qb.tableMgr.updateByID(ctx, updatedObject.ID, r); err != nil {
|
||||
return err
|
||||
@@ -365,6 +379,10 @@ func (qb *PerformerStore) Update(ctx context.Context, updatedObject *models.Perf
|
||||
}
|
||||
}
|
||||
|
||||
if err := qb.SetCustomFields(ctx, updatedObject.ID, updatedObject.CustomFields); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -203,6 +203,13 @@ func (qb *performerFilterHandler) criterionHandler() criterionHandler {
|
||||
performerRepository.tags.innerJoin(f, "performer_tag", "performers.id")
|
||||
},
|
||||
},
|
||||
|
||||
&customFieldsFilterHandler{
|
||||
table: performersCustomFieldsTable.GetTable(),
|
||||
fkCol: performerIDColumn,
|
||||
c: filter.CustomFields,
|
||||
idCol: "performers.id",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -16,6 +16,12 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var testCustomFields = map[string]interface{}{
|
||||
"string": "aaa",
|
||||
"int": int64(123), // int64 to match the type of the field in the database
|
||||
"real": 1.23,
|
||||
}
|
||||
|
||||
func loadPerformerRelationships(ctx context.Context, expected models.Performer, actual *models.Performer) error {
|
||||
if expected.Aliases.Loaded() {
|
||||
if err := actual.LoadAliases(ctx, db.Performer); err != nil {
|
||||
@@ -81,57 +87,62 @@ func Test_PerformerStore_Create(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
newObject models.Performer
|
||||
newObject models.CreatePerformerInput
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"full",
|
||||
models.Performer{
|
||||
Name: name,
|
||||
Disambiguation: disambiguation,
|
||||
Gender: &gender,
|
||||
URLs: models.NewRelatedStrings(urls),
|
||||
Birthdate: &birthdate,
|
||||
Ethnicity: ethnicity,
|
||||
Country: country,
|
||||
EyeColor: eyeColor,
|
||||
Height: &height,
|
||||
Measurements: measurements,
|
||||
FakeTits: fakeTits,
|
||||
PenisLength: &penisLength,
|
||||
Circumcised: &circumcised,
|
||||
CareerLength: careerLength,
|
||||
Tattoos: tattoos,
|
||||
Piercings: piercings,
|
||||
Favorite: favorite,
|
||||
Rating: &rating,
|
||||
Details: details,
|
||||
DeathDate: &deathdate,
|
||||
HairColor: hairColor,
|
||||
Weight: &weight,
|
||||
IgnoreAutoTag: ignoreAutoTag,
|
||||
TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithPerformer], tagIDs[tagIdx1WithDupName]}),
|
||||
Aliases: models.NewRelatedStrings(aliases),
|
||||
StashIDs: models.NewRelatedStashIDs([]models.StashID{
|
||||
{
|
||||
StashID: stashID1,
|
||||
Endpoint: endpoint1,
|
||||
},
|
||||
{
|
||||
StashID: stashID2,
|
||||
Endpoint: endpoint2,
|
||||
},
|
||||
}),
|
||||
CreatedAt: createdAt,
|
||||
UpdatedAt: updatedAt,
|
||||
models.CreatePerformerInput{
|
||||
Performer: &models.Performer{
|
||||
Name: name,
|
||||
Disambiguation: disambiguation,
|
||||
Gender: &gender,
|
||||
URLs: models.NewRelatedStrings(urls),
|
||||
Birthdate: &birthdate,
|
||||
Ethnicity: ethnicity,
|
||||
Country: country,
|
||||
EyeColor: eyeColor,
|
||||
Height: &height,
|
||||
Measurements: measurements,
|
||||
FakeTits: fakeTits,
|
||||
PenisLength: &penisLength,
|
||||
Circumcised: &circumcised,
|
||||
CareerLength: careerLength,
|
||||
Tattoos: tattoos,
|
||||
Piercings: piercings,
|
||||
Favorite: favorite,
|
||||
Rating: &rating,
|
||||
Details: details,
|
||||
DeathDate: &deathdate,
|
||||
HairColor: hairColor,
|
||||
Weight: &weight,
|
||||
IgnoreAutoTag: ignoreAutoTag,
|
||||
TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithPerformer], tagIDs[tagIdx1WithDupName]}),
|
||||
Aliases: models.NewRelatedStrings(aliases),
|
||||
StashIDs: models.NewRelatedStashIDs([]models.StashID{
|
||||
{
|
||||
StashID: stashID1,
|
||||
Endpoint: endpoint1,
|
||||
},
|
||||
{
|
||||
StashID: stashID2,
|
||||
Endpoint: endpoint2,
|
||||
},
|
||||
}),
|
||||
CreatedAt: createdAt,
|
||||
UpdatedAt: updatedAt,
|
||||
},
|
||||
CustomFields: testCustomFields,
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"invalid tag id",
|
||||
models.Performer{
|
||||
Name: name,
|
||||
TagIDs: models.NewRelatedIDs([]int{invalidID}),
|
||||
models.CreatePerformerInput{
|
||||
Performer: &models.Performer{
|
||||
Name: name,
|
||||
TagIDs: models.NewRelatedIDs([]int{invalidID}),
|
||||
},
|
||||
},
|
||||
true,
|
||||
},
|
||||
@@ -155,16 +166,16 @@ func Test_PerformerStore_Create(t *testing.T) {
|
||||
|
||||
assert.NotZero(p.ID)
|
||||
|
||||
copy := tt.newObject
|
||||
copy := *tt.newObject.Performer
|
||||
copy.ID = p.ID
|
||||
|
||||
// load relationships
|
||||
if err := loadPerformerRelationships(ctx, copy, &p); err != nil {
|
||||
if err := loadPerformerRelationships(ctx, copy, p.Performer); err != nil {
|
||||
t.Errorf("loadPerformerRelationships() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(copy, p)
|
||||
assert.Equal(copy, *p.Performer)
|
||||
|
||||
// ensure can find the performer
|
||||
found, err := qb.Find(ctx, p.ID)
|
||||
@@ -183,6 +194,15 @@ func Test_PerformerStore_Create(t *testing.T) {
|
||||
}
|
||||
assert.Equal(copy, *found)
|
||||
|
||||
// ensure custom fields are set
|
||||
cf, err := qb.GetCustomFields(ctx, p.ID)
|
||||
if err != nil {
|
||||
t.Errorf("PerformerStore.GetCustomFields() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(tt.newObject.CustomFields, cf)
|
||||
|
||||
return
|
||||
})
|
||||
}
|
||||
@@ -228,77 +248,109 @@ func Test_PerformerStore_Update(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
updatedObject *models.Performer
|
||||
updatedObject models.UpdatePerformerInput
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"full",
|
||||
&models.Performer{
|
||||
ID: performerIDs[performerIdxWithGallery],
|
||||
Name: name,
|
||||
Disambiguation: disambiguation,
|
||||
Gender: &gender,
|
||||
URLs: models.NewRelatedStrings(urls),
|
||||
Birthdate: &birthdate,
|
||||
Ethnicity: ethnicity,
|
||||
Country: country,
|
||||
EyeColor: eyeColor,
|
||||
Height: &height,
|
||||
Measurements: measurements,
|
||||
FakeTits: fakeTits,
|
||||
PenisLength: &penisLength,
|
||||
Circumcised: &circumcised,
|
||||
CareerLength: careerLength,
|
||||
Tattoos: tattoos,
|
||||
Piercings: piercings,
|
||||
Favorite: favorite,
|
||||
Rating: &rating,
|
||||
Details: details,
|
||||
DeathDate: &deathdate,
|
||||
HairColor: hairColor,
|
||||
Weight: &weight,
|
||||
IgnoreAutoTag: ignoreAutoTag,
|
||||
Aliases: models.NewRelatedStrings(aliases),
|
||||
TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithPerformer], tagIDs[tagIdx1WithDupName]}),
|
||||
StashIDs: models.NewRelatedStashIDs([]models.StashID{
|
||||
{
|
||||
StashID: stashID1,
|
||||
Endpoint: endpoint1,
|
||||
},
|
||||
{
|
||||
StashID: stashID2,
|
||||
Endpoint: endpoint2,
|
||||
},
|
||||
}),
|
||||
CreatedAt: createdAt,
|
||||
UpdatedAt: updatedAt,
|
||||
models.UpdatePerformerInput{
|
||||
Performer: &models.Performer{
|
||||
ID: performerIDs[performerIdxWithGallery],
|
||||
Name: name,
|
||||
Disambiguation: disambiguation,
|
||||
Gender: &gender,
|
||||
URLs: models.NewRelatedStrings(urls),
|
||||
Birthdate: &birthdate,
|
||||
Ethnicity: ethnicity,
|
||||
Country: country,
|
||||
EyeColor: eyeColor,
|
||||
Height: &height,
|
||||
Measurements: measurements,
|
||||
FakeTits: fakeTits,
|
||||
PenisLength: &penisLength,
|
||||
Circumcised: &circumcised,
|
||||
CareerLength: careerLength,
|
||||
Tattoos: tattoos,
|
||||
Piercings: piercings,
|
||||
Favorite: favorite,
|
||||
Rating: &rating,
|
||||
Details: details,
|
||||
DeathDate: &deathdate,
|
||||
HairColor: hairColor,
|
||||
Weight: &weight,
|
||||
IgnoreAutoTag: ignoreAutoTag,
|
||||
Aliases: models.NewRelatedStrings(aliases),
|
||||
TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithPerformer], tagIDs[tagIdx1WithDupName]}),
|
||||
StashIDs: models.NewRelatedStashIDs([]models.StashID{
|
||||
{
|
||||
StashID: stashID1,
|
||||
Endpoint: endpoint1,
|
||||
},
|
||||
{
|
||||
StashID: stashID2,
|
||||
Endpoint: endpoint2,
|
||||
},
|
||||
}),
|
||||
CreatedAt: createdAt,
|
||||
UpdatedAt: updatedAt,
|
||||
},
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"clear nullables",
|
||||
&models.Performer{
|
||||
ID: performerIDs[performerIdxWithGallery],
|
||||
Aliases: models.NewRelatedStrings([]string{}),
|
||||
URLs: models.NewRelatedStrings([]string{}),
|
||||
TagIDs: models.NewRelatedIDs([]int{}),
|
||||
StashIDs: models.NewRelatedStashIDs([]models.StashID{}),
|
||||
models.UpdatePerformerInput{
|
||||
Performer: &models.Performer{
|
||||
ID: performerIDs[performerIdxWithGallery],
|
||||
Aliases: models.NewRelatedStrings([]string{}),
|
||||
URLs: models.NewRelatedStrings([]string{}),
|
||||
TagIDs: models.NewRelatedIDs([]int{}),
|
||||
StashIDs: models.NewRelatedStashIDs([]models.StashID{}),
|
||||
},
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"clear tag ids",
|
||||
&models.Performer{
|
||||
ID: performerIDs[sceneIdxWithTag],
|
||||
TagIDs: models.NewRelatedIDs([]int{}),
|
||||
models.UpdatePerformerInput{
|
||||
Performer: &models.Performer{
|
||||
ID: performerIDs[sceneIdxWithTag],
|
||||
TagIDs: models.NewRelatedIDs([]int{}),
|
||||
},
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"set custom fields",
|
||||
models.UpdatePerformerInput{
|
||||
Performer: &models.Performer{
|
||||
ID: performerIDs[performerIdxWithGallery],
|
||||
},
|
||||
CustomFields: models.CustomFieldsInput{
|
||||
Full: testCustomFields,
|
||||
},
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"clear custom fields",
|
||||
models.UpdatePerformerInput{
|
||||
Performer: &models.Performer{
|
||||
ID: performerIDs[performerIdxWithGallery],
|
||||
},
|
||||
CustomFields: models.CustomFieldsInput{
|
||||
Full: map[string]interface{}{},
|
||||
},
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"invalid tag id",
|
||||
&models.Performer{
|
||||
ID: performerIDs[sceneIdxWithGallery],
|
||||
TagIDs: models.NewRelatedIDs([]int{invalidID}),
|
||||
models.UpdatePerformerInput{
|
||||
Performer: &models.Performer{
|
||||
ID: performerIDs[sceneIdxWithGallery],
|
||||
TagIDs: models.NewRelatedIDs([]int{invalidID}),
|
||||
},
|
||||
},
|
||||
true,
|
||||
},
|
||||
@@ -309,9 +361,9 @@ func Test_PerformerStore_Update(t *testing.T) {
|
||||
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
|
||||
assert := assert.New(t)
|
||||
|
||||
copy := *tt.updatedObject
|
||||
copy := *tt.updatedObject.Performer
|
||||
|
||||
if err := qb.Update(ctx, tt.updatedObject); (err != nil) != tt.wantErr {
|
||||
if err := qb.Update(ctx, &tt.updatedObject); (err != nil) != tt.wantErr {
|
||||
t.Errorf("PerformerStore.Update() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
|
||||
@@ -331,6 +383,17 @@ func Test_PerformerStore_Update(t *testing.T) {
|
||||
}
|
||||
|
||||
assert.Equal(copy, *s)
|
||||
|
||||
// ensure custom fields are correct
|
||||
if tt.updatedObject.CustomFields.Full != nil {
|
||||
cf, err := qb.GetCustomFields(ctx, tt.updatedObject.ID)
|
||||
if err != nil {
|
||||
t.Errorf("PerformerStore.GetCustomFields() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(tt.updatedObject.CustomFields.Full, cf)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -573,6 +636,79 @@ func Test_PerformerStore_UpdatePartial(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_PerformerStore_UpdatePartialCustomFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
id int
|
||||
partial models.PerformerPartial
|
||||
expected map[string]interface{} // nil to use the partial
|
||||
}{
|
||||
{
|
||||
"set custom fields",
|
||||
performerIDs[performerIdxWithGallery],
|
||||
models.PerformerPartial{
|
||||
CustomFields: models.CustomFieldsInput{
|
||||
Full: testCustomFields,
|
||||
},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"clear custom fields",
|
||||
performerIDs[performerIdxWithGallery],
|
||||
models.PerformerPartial{
|
||||
CustomFields: models.CustomFieldsInput{
|
||||
Full: map[string]interface{}{},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"partial custom fields",
|
||||
performerIDs[performerIdxWithGallery],
|
||||
models.PerformerPartial{
|
||||
CustomFields: models.CustomFieldsInput{
|
||||
Partial: map[string]interface{}{
|
||||
"string": "bbb",
|
||||
"new_field": "new",
|
||||
},
|
||||
},
|
||||
},
|
||||
map[string]interface{}{
|
||||
"int": int64(3),
|
||||
"real": 1.3,
|
||||
"string": "bbb",
|
||||
"new_field": "new",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
qb := db.Performer
|
||||
|
||||
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
|
||||
assert := assert.New(t)
|
||||
|
||||
_, err := qb.UpdatePartial(ctx, tt.id, tt.partial)
|
||||
if err != nil {
|
||||
t.Errorf("PerformerStore.UpdatePartial() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// ensure custom fields are correct
|
||||
cf, err := qb.GetCustomFields(ctx, tt.id)
|
||||
if err != nil {
|
||||
t.Errorf("PerformerStore.GetCustomFields() error = %v", err)
|
||||
return
|
||||
}
|
||||
if tt.expected == nil {
|
||||
assert.Equal(tt.partial.CustomFields.Full, cf)
|
||||
} else {
|
||||
assert.Equal(tt.expected, cf)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPerformerFindBySceneID(t *testing.T) {
|
||||
withTxn(func(ctx context.Context) error {
|
||||
pqb := db.Performer
|
||||
@@ -1042,6 +1178,242 @@ func TestPerformerQuery(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPerformerQueryCustomFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
filter *models.PerformerFilterType
|
||||
includeIdxs []int
|
||||
excludeIdxs []int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"equals",
|
||||
&models.PerformerFilterType{
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "string",
|
||||
Modifier: models.CriterionModifierEquals,
|
||||
Value: []any{getPerformerStringValue(performerIdxWithGallery, "custom")},
|
||||
},
|
||||
},
|
||||
},
|
||||
[]int{performerIdxWithGallery},
|
||||
nil,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"not equals",
|
||||
&models.PerformerFilterType{
|
||||
Name: &models.StringCriterionInput{
|
||||
Value: getPerformerStringValue(performerIdxWithGallery, "Name"),
|
||||
Modifier: models.CriterionModifierEquals,
|
||||
},
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "string",
|
||||
Modifier: models.CriterionModifierNotEquals,
|
||||
Value: []any{getPerformerStringValue(performerIdxWithGallery, "custom")},
|
||||
},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
[]int{performerIdxWithGallery},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"includes",
|
||||
&models.PerformerFilterType{
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "string",
|
||||
Modifier: models.CriterionModifierIncludes,
|
||||
Value: []any{getPerformerStringValue(performerIdxWithGallery, "custom")[9:]},
|
||||
},
|
||||
},
|
||||
},
|
||||
[]int{performerIdxWithGallery},
|
||||
nil,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"excludes",
|
||||
&models.PerformerFilterType{
|
||||
Name: &models.StringCriterionInput{
|
||||
Value: getPerformerStringValue(performerIdxWithGallery, "Name"),
|
||||
Modifier: models.CriterionModifierEquals,
|
||||
},
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "string",
|
||||
Modifier: models.CriterionModifierExcludes,
|
||||
Value: []any{getPerformerStringValue(performerIdxWithGallery, "custom")[9:]},
|
||||
},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
[]int{performerIdxWithGallery},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"regex",
|
||||
&models.PerformerFilterType{
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "string",
|
||||
Modifier: models.CriterionModifierMatchesRegex,
|
||||
Value: []any{".*13_custom"},
|
||||
},
|
||||
},
|
||||
},
|
||||
[]int{performerIdxWithGallery},
|
||||
nil,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"invalid regex",
|
||||
&models.PerformerFilterType{
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "string",
|
||||
Modifier: models.CriterionModifierMatchesRegex,
|
||||
Value: []any{"["},
|
||||
},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"not matches regex",
|
||||
&models.PerformerFilterType{
|
||||
Name: &models.StringCriterionInput{
|
||||
Value: getPerformerStringValue(performerIdxWithGallery, "Name"),
|
||||
Modifier: models.CriterionModifierEquals,
|
||||
},
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "string",
|
||||
Modifier: models.CriterionModifierNotMatchesRegex,
|
||||
Value: []any{".*13_custom"},
|
||||
},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
[]int{performerIdxWithGallery},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"invalid not matches regex",
|
||||
&models.PerformerFilterType{
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "string",
|
||||
Modifier: models.CriterionModifierNotMatchesRegex,
|
||||
Value: []any{"["},
|
||||
},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"null",
|
||||
&models.PerformerFilterType{
|
||||
Name: &models.StringCriterionInput{
|
||||
Value: getPerformerStringValue(performerIdxWithGallery, "Name"),
|
||||
Modifier: models.CriterionModifierEquals,
|
||||
},
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "not existing",
|
||||
Modifier: models.CriterionModifierIsNull,
|
||||
},
|
||||
},
|
||||
},
|
||||
[]int{performerIdxWithGallery},
|
||||
nil,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"null",
|
||||
&models.PerformerFilterType{
|
||||
Name: &models.StringCriterionInput{
|
||||
Value: getPerformerStringValue(performerIdxWithGallery, "Name"),
|
||||
Modifier: models.CriterionModifierEquals,
|
||||
},
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "string",
|
||||
Modifier: models.CriterionModifierNotNull,
|
||||
},
|
||||
},
|
||||
},
|
||||
[]int{performerIdxWithGallery},
|
||||
nil,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"between",
|
||||
&models.PerformerFilterType{
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "real",
|
||||
Modifier: models.CriterionModifierBetween,
|
||||
Value: []any{0.05, 0.15},
|
||||
},
|
||||
},
|
||||
},
|
||||
[]int{performerIdx1WithScene},
|
||||
nil,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"not between",
|
||||
&models.PerformerFilterType{
|
||||
Name: &models.StringCriterionInput{
|
||||
Value: getPerformerStringValue(performerIdx1WithScene, "Name"),
|
||||
Modifier: models.CriterionModifierEquals,
|
||||
},
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "real",
|
||||
Modifier: models.CriterionModifierNotBetween,
|
||||
Value: []any{0.05, 0.15},
|
||||
},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
[]int{performerIdx1WithScene},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
|
||||
assert := assert.New(t)
|
||||
|
||||
performers, _, err := db.Performer.Query(ctx, tt.filter, nil)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("PerformerStore.Query() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
ids := performersToIDs(performers)
|
||||
include := indexesToIDs(performerIDs, tt.includeIdxs)
|
||||
exclude := indexesToIDs(performerIDs, tt.excludeIdxs)
|
||||
|
||||
for _, i := range include {
|
||||
assert.Contains(ids, i)
|
||||
}
|
||||
for _, e := range exclude {
|
||||
assert.NotContains(ids, e)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPerformerQueryPenisLength(t *testing.T) {
|
||||
var upper = 4.0
|
||||
|
||||
@@ -1172,7 +1544,7 @@ func TestPerformerUpdatePerformerImage(t *testing.T) {
|
||||
performer := models.Performer{
|
||||
Name: name,
|
||||
}
|
||||
err := qb.Create(ctx, &performer)
|
||||
err := qb.Create(ctx, &models.CreatePerformerInput{Performer: &performer})
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error creating performer: %s", err.Error())
|
||||
}
|
||||
@@ -1680,7 +2052,7 @@ func TestPerformerStashIDs(t *testing.T) {
|
||||
performer := &models.Performer{
|
||||
Name: name,
|
||||
}
|
||||
if err := qb.Create(ctx, performer); err != nil {
|
||||
if err := qb.Create(ctx, &models.CreatePerformerInput{Performer: performer}); err != nil {
|
||||
return fmt.Errorf("Error creating performer: %s", err.Error())
|
||||
}
|
||||
|
||||
|
||||
@@ -133,6 +133,9 @@ func (qb *queryBuilder) join(table, as, onClause string) {
|
||||
|
||||
func (qb *queryBuilder) addJoins(joins ...join) {
|
||||
qb.joins.add(joins...)
|
||||
for _, j := range joins {
|
||||
qb.args = append(qb.args, j.args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (qb *queryBuilder) addFilter(f *filterBuilder) error {
|
||||
@@ -151,6 +154,9 @@ func (qb *queryBuilder) addFilter(f *filterBuilder) error {
|
||||
qb.args = append(args, qb.args...)
|
||||
}
|
||||
|
||||
// add joins here to insert args
|
||||
qb.addJoins(f.getAllJoins()...)
|
||||
|
||||
clause, args = f.generateWhereClauses()
|
||||
if len(clause) > 0 {
|
||||
qb.addWhere(clause)
|
||||
@@ -169,8 +175,6 @@ func (qb *queryBuilder) addFilter(f *filterBuilder) error {
|
||||
qb.addArg(args...)
|
||||
}
|
||||
|
||||
qb.addJoins(f.getAllJoins()...)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -222,8 +222,8 @@ func (r *repository) innerJoin(j joiner, as string, parentIDCol string) {
|
||||
}
|
||||
|
||||
type joiner interface {
|
||||
addLeftJoin(table, as, onClause string)
|
||||
addInnerJoin(table, as, onClause string)
|
||||
addLeftJoin(table, as, onClause string, args ...interface{})
|
||||
addInnerJoin(table, as, onClause string, args ...interface{})
|
||||
}
|
||||
|
||||
type joinRepository struct {
|
||||
|
||||
@@ -1508,6 +1508,18 @@ func performerAliases(i int) []string {
|
||||
return []string{getPerformerStringValue(i, "alias")}
|
||||
}
|
||||
|
||||
func getPerformerCustomFields(index int) map[string]interface{} {
|
||||
if index%5 == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"string": getPerformerStringValue(index, "custom"),
|
||||
"int": int64(index % 5),
|
||||
"real": float64(index) / 10,
|
||||
}
|
||||
}
|
||||
|
||||
// createPerformers creates n performers with plain Name and o performers with camel cased NaMe included
|
||||
func createPerformers(ctx context.Context, n int, o int) error {
|
||||
pqb := db.Performer
|
||||
@@ -1558,7 +1570,10 @@ func createPerformers(ctx context.Context, n int, o int) error {
|
||||
})
|
||||
}
|
||||
|
||||
err := pqb.Create(ctx, &performer)
|
||||
err := pqb.Create(ctx, &models.CreatePerformerInput{
|
||||
Performer: &performer,
|
||||
CustomFields: getPerformerCustomFields(i),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error creating performer %v+: %s", performer, err.Error())
|
||||
|
||||
@@ -32,6 +32,7 @@ var (
|
||||
performersURLsJoinTable = goqu.T(performerURLsTable)
|
||||
performersTagsJoinTable = goqu.T(performersTagsTable)
|
||||
performersStashIDsJoinTable = goqu.T("performer_stash_ids")
|
||||
performersCustomFieldsTable = goqu.T("performer_custom_fields")
|
||||
|
||||
studiosAliasesJoinTable = goqu.T(studioAliasesTable)
|
||||
studiosTagsJoinTable = goqu.T(studiosTagsTable)
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// JSONNumberToNumber converts a JSON number to either a float64 or int64.
|
||||
func JSONNumberToNumber(n json.Number) interface{} {
|
||||
if strings.Contains(string(n), ".") {
|
||||
f, _ := n.Float64()
|
||||
return f
|
||||
}
|
||||
ret, _ := n.Int64()
|
||||
return ret
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -80,19 +79,3 @@ func MergeMaps(dest map[string]interface{}, src map[string]interface{}) {
|
||||
dest[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertMapJSONNumbers converts all JSON numbers in a map to either float64 or int64.
|
||||
func ConvertMapJSONNumbers(m map[string]interface{}) (ret map[string]interface{}) {
|
||||
ret = make(map[string]interface{})
|
||||
for k, v := range m {
|
||||
if n, ok := v.(json.Number); ok {
|
||||
ret[k] = JSONNumberToNumber(n)
|
||||
} else if mm, ok := v.(map[string]interface{}); ok {
|
||||
ret[k] = ConvertMapJSONNumbers(mm)
|
||||
} else {
|
||||
ret[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNestedMapGet(t *testing.T) {
|
||||
@@ -282,55 +279,3 @@ func TestMergeMaps(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertMapJSONNumbers(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input map[string]interface{}
|
||||
expected map[string]interface{}
|
||||
}{
|
||||
{
|
||||
name: "Convert JSON numbers to numbers",
|
||||
input: map[string]interface{}{
|
||||
"int": json.Number("12"),
|
||||
"float": json.Number("12.34"),
|
||||
"string": "foo",
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"int": int64(12),
|
||||
"float": 12.34,
|
||||
"string": "foo",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Convert JSON numbers to numbers in nested maps",
|
||||
input: map[string]interface{}{
|
||||
"foo": map[string]interface{}{
|
||||
"int": json.Number("56"),
|
||||
"float": json.Number("56.78"),
|
||||
"nested-string": "bar",
|
||||
},
|
||||
"int": json.Number("12"),
|
||||
"float": json.Number("12.34"),
|
||||
"string": "foo",
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"foo": map[string]interface{}{
|
||||
"int": int64(56),
|
||||
"float": 56.78,
|
||||
"nested-string": "bar",
|
||||
},
|
||||
"int": int64(12),
|
||||
"float": 12.34,
|
||||
"string": "foo",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ConvertMapJSONNumbers(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user