Performer custom fields (#5487)

* Backend changes
* Show custom field values
* Add custom fields table input
* Add custom field filtering
* Add unit tests
* Include custom fields in import/export
* Anonymise performer custom fields
* Move json.Number handler functions to api
* Handle json.Number conversion in api
This commit is contained in:
WithoutPants
2024-12-03 13:49:55 +11:00
committed by GitHub
parent a0e09bbe5c
commit 8c8be22fe4
56 changed files with 2158 additions and 277 deletions

View File

@@ -600,6 +600,10 @@ func (db *Anonymiser) anonymisePerformers(ctx context.Context) error {
return err
}
if err := db.anonymiseCustomFields(ctx, goqu.T(performersCustomFieldsTable.GetTable()), "performer_id"); err != nil {
return err
}
return nil
}
@@ -1050,3 +1054,73 @@ func (db *Anonymiser) obfuscateString(in string, dict string) string {
return out.String()
}
func (db *Anonymiser) anonymiseCustomFields(ctx context.Context, table exp.IdentifierExpression, idColumn string) error {
lastID := 0
lastField := ""
total := 0
const logEvery = 10000
for gotSome := true; gotSome; {
if err := txn.WithTxn(ctx, db, func(ctx context.Context) error {
query := dialect.From(table).Select(
table.Col(idColumn),
table.Col("field"),
table.Col("value"),
).Where(
goqu.L("("+idColumn+", field)").Gt(goqu.L("(?, ?)", lastID, lastField)),
).Order(
table.Col(idColumn).Asc(), table.Col("field").Asc(),
).Limit(1000)
gotSome = false
const single = false
return queryFunc(ctx, query, single, func(rows *sqlx.Rows) error {
var (
id int
field string
value string
)
if err := rows.Scan(
&id,
&field,
&value,
); err != nil {
return err
}
set := goqu.Record{}
set["field"] = db.obfuscateString(field, letters)
set["value"] = db.obfuscateString(value, letters)
if len(set) > 0 {
stmt := dialect.Update(table).Set(set).Where(
table.Col(idColumn).Eq(id),
table.Col("field").Eq(field),
)
if _, err := exec(ctx, stmt); err != nil {
return fmt.Errorf("anonymising %s: %w", table.GetTable(), err)
}
}
lastID = id
lastField = field
gotSome = true
total++
if total%logEvery == 0 {
logger.Infof("Anonymised %d %s custom fields", total, table.GetTable())
}
return nil
})
}); err != nil {
return err
}
}
return nil
}

308
pkg/sqlite/custom_fields.go Normal file
View File

@@ -0,0 +1,308 @@
package sqlite
import (
"context"
"fmt"
"regexp"
"strings"
"github.com/doug-martin/goqu/v9"
"github.com/doug-martin/goqu/v9/exp"
"github.com/jmoiron/sqlx"
"github.com/stashapp/stash/pkg/models"
)
const maxCustomFieldNameLength = 64
type customFieldsStore struct {
table exp.IdentifierExpression
fk exp.IdentifierExpression
}
func (s *customFieldsStore) deleteForID(ctx context.Context, id int) error {
table := s.table
q := dialect.Delete(table).Where(s.fk.Eq(id))
_, err := exec(ctx, q)
if err != nil {
return fmt.Errorf("deleting from %s: %w", s.table.GetTable(), err)
}
return nil
}
func (s *customFieldsStore) SetCustomFields(ctx context.Context, id int, values models.CustomFieldsInput) error {
var partial bool
var valMap map[string]interface{}
switch {
case values.Full != nil:
partial = false
valMap = values.Full
case values.Partial != nil:
partial = true
valMap = values.Partial
default:
return nil
}
if err := s.validateCustomFields(valMap); err != nil {
return err
}
return s.setCustomFields(ctx, id, valMap, partial)
}
func (s *customFieldsStore) validateCustomFields(values map[string]interface{}) error {
// ensure that custom field names are valid
// no leading or trailing whitespace, no empty strings
for k := range values {
if err := s.validateCustomFieldName(k); err != nil {
return fmt.Errorf("custom field name %q: %w", k, err)
}
}
return nil
}
func (s *customFieldsStore) validateCustomFieldName(fieldName string) error {
// ensure that custom field names are valid
// no leading or trailing whitespace, no empty strings
if strings.TrimSpace(fieldName) == "" {
return fmt.Errorf("custom field name cannot be empty")
}
if fieldName != strings.TrimSpace(fieldName) {
return fmt.Errorf("custom field name cannot have leading or trailing whitespace")
}
if len(fieldName) > maxCustomFieldNameLength {
return fmt.Errorf("custom field name must be less than %d characters", maxCustomFieldNameLength+1)
}
return nil
}
func getSQLValueFromCustomFieldInput(input interface{}) (interface{}, error) {
switch v := input.(type) {
case []interface{}, map[string]interface{}:
// TODO - in future it would be nice to convert to a JSON string
// however, we would need some way to differentiate between a JSON string and a regular string
// for now, we will not support objects and arrays
return nil, fmt.Errorf("unsupported custom field value type: %T", input)
default:
return v, nil
}
}
func (s *customFieldsStore) sqlValueToValue(value interface{}) interface{} {
// TODO - if we ever support objects and arrays we will need to add support here
return value
}
func (s *customFieldsStore) setCustomFields(ctx context.Context, id int, values map[string]interface{}, partial bool) error {
if !partial {
// delete existing custom fields
if err := s.deleteForID(ctx, id); err != nil {
return err
}
}
if len(values) == 0 {
return nil
}
conflictKey := s.fk.GetCol().(string) + ", field"
// upsert new custom fields
q := dialect.Insert(s.table).Prepared(true).Cols(s.fk, "field", "value").
OnConflict(goqu.DoUpdate(conflictKey, goqu.Record{"value": goqu.I("excluded.value")}))
r := make([]interface{}, len(values))
var i int
for key, value := range values {
v, err := getSQLValueFromCustomFieldInput(value)
if err != nil {
return fmt.Errorf("getting SQL value for field %q: %w", key, err)
}
r[i] = goqu.Record{"field": key, "value": v, s.fk.GetCol().(string): id}
i++
}
if _, err := exec(ctx, q.Rows(r...)); err != nil {
return fmt.Errorf("inserting custom fields: %w", err)
}
return nil
}
func (s *customFieldsStore) GetCustomFields(ctx context.Context, id int) (map[string]interface{}, error) {
q := dialect.Select("field", "value").From(s.table).Where(s.fk.Eq(id))
const single = false
ret := make(map[string]interface{})
err := queryFunc(ctx, q, single, func(rows *sqlx.Rows) error {
var field string
var value interface{}
if err := rows.Scan(&field, &value); err != nil {
return fmt.Errorf("scanning custom fields: %w", err)
}
ret[field] = s.sqlValueToValue(value)
return nil
})
if err != nil {
return nil, fmt.Errorf("getting custom fields: %w", err)
}
return ret, nil
}
func (s *customFieldsStore) GetCustomFieldsBulk(ctx context.Context, ids []int) ([]models.CustomFieldMap, error) {
q := dialect.Select(s.fk.As("id"), "field", "value").From(s.table).Where(s.fk.In(ids))
const single = false
ret := make([]models.CustomFieldMap, len(ids))
idi := make(map[int]int, len(ids))
for i, id := range ids {
idi[id] = i
}
err := queryFunc(ctx, q, single, func(rows *sqlx.Rows) error {
var id int
var field string
var value interface{}
if err := rows.Scan(&id, &field, &value); err != nil {
return fmt.Errorf("scanning custom fields: %w", err)
}
i := idi[id]
m := ret[i]
if m == nil {
m = make(map[string]interface{})
ret[i] = m
}
m[field] = s.sqlValueToValue(value)
return nil
})
if err != nil {
return nil, fmt.Errorf("getting custom fields: %w", err)
}
return ret, nil
}
type customFieldsFilterHandler struct {
table string
fkCol string
c []models.CustomFieldCriterionInput
idCol string
}
func (h *customFieldsFilterHandler) innerJoin(f *filterBuilder, as string, field string) {
joinOn := fmt.Sprintf("%s = %s.%s AND %s.field = ?", h.idCol, as, h.fkCol, as)
f.addInnerJoin(h.table, as, joinOn, field)
}
func (h *customFieldsFilterHandler) leftJoin(f *filterBuilder, as string, field string) {
joinOn := fmt.Sprintf("%s = %s.%s AND %s.field = ?", h.idCol, as, h.fkCol, as)
f.addLeftJoin(h.table, as, joinOn, field)
}
func (h *customFieldsFilterHandler) handleCriterion(f *filterBuilder, joinAs string, cc models.CustomFieldCriterionInput) {
// convert values
cv := make([]interface{}, len(cc.Value))
for i, v := range cc.Value {
var err error
cv[i], err = getSQLValueFromCustomFieldInput(v)
if err != nil {
f.setError(err)
return
}
}
switch cc.Modifier {
case models.CriterionModifierEquals:
h.innerJoin(f, joinAs, cc.Field)
f.addWhere(fmt.Sprintf("%[1]s.value IN %s", joinAs, getInBinding(len(cv))), cv...)
case models.CriterionModifierNotEquals:
h.innerJoin(f, joinAs, cc.Field)
f.addWhere(fmt.Sprintf("%[1]s.value NOT IN %s", joinAs, getInBinding(len(cv))), cv...)
case models.CriterionModifierIncludes:
clauses := make([]sqlClause, len(cv))
for i, v := range cv {
clauses[i] = makeClause(fmt.Sprintf("%s.value LIKE ?", joinAs), fmt.Sprintf("%%%v%%", v))
}
h.innerJoin(f, joinAs, cc.Field)
f.whereClauses = append(f.whereClauses, clauses...)
case models.CriterionModifierExcludes:
for _, v := range cv {
f.addWhere(fmt.Sprintf("%[1]s.value NOT LIKE ?", joinAs), fmt.Sprintf("%%%v%%", v))
}
h.leftJoin(f, joinAs, cc.Field)
case models.CriterionModifierMatchesRegex:
for _, v := range cv {
vs, ok := v.(string)
if !ok {
f.setError(fmt.Errorf("unsupported custom field criterion value type: %T", v))
}
if _, err := regexp.Compile(vs); err != nil {
f.setError(err)
return
}
f.addWhere(fmt.Sprintf("(%s.value regexp ?)", joinAs), v)
}
h.innerJoin(f, joinAs, cc.Field)
case models.CriterionModifierNotMatchesRegex:
for _, v := range cv {
vs, ok := v.(string)
if !ok {
f.setError(fmt.Errorf("unsupported custom field criterion value type: %T", v))
}
if _, err := regexp.Compile(vs); err != nil {
f.setError(err)
return
}
f.addWhere(fmt.Sprintf("(%s.value IS NULL OR %[1]s.value NOT regexp ?)", joinAs), v)
}
h.leftJoin(f, joinAs, cc.Field)
case models.CriterionModifierIsNull:
h.leftJoin(f, joinAs, cc.Field)
f.addWhere(fmt.Sprintf("%s.value IS NULL OR TRIM(%[1]s.value) = ''", joinAs))
case models.CriterionModifierNotNull:
h.innerJoin(f, joinAs, cc.Field)
f.addWhere(fmt.Sprintf("TRIM(%[1]s.value) != ''", joinAs))
case models.CriterionModifierBetween:
if len(cv) != 2 {
f.setError(fmt.Errorf("expected 2 values for custom field criterion modifier BETWEEN, got %d", len(cv)))
return
}
h.innerJoin(f, joinAs, cc.Field)
f.addWhere(fmt.Sprintf("%s.value BETWEEN ? AND ?", joinAs), cv[0], cv[1])
case models.CriterionModifierNotBetween:
h.innerJoin(f, joinAs, cc.Field)
f.addWhere(fmt.Sprintf("%s.value NOT BETWEEN ? AND ?", joinAs), cv[0], cv[1])
case models.CriterionModifierLessThan:
if len(cv) != 1 {
f.setError(fmt.Errorf("expected 1 value for custom field criterion modifier LESS_THAN, got %d", len(cv)))
return
}
h.innerJoin(f, joinAs, cc.Field)
f.addWhere(fmt.Sprintf("%s.value < ?", joinAs), cv[0])
case models.CriterionModifierGreaterThan:
if len(cv) != 1 {
f.setError(fmt.Errorf("expected 1 value for custom field criterion modifier LESS_THAN, got %d", len(cv)))
return
}
h.innerJoin(f, joinAs, cc.Field)
f.addWhere(fmt.Sprintf("%s.value > ?", joinAs), cv[0])
default:
f.setError(fmt.Errorf("unsupported custom field criterion modifier: %s", cc.Modifier))
}
}
func (h *customFieldsFilterHandler) handle(ctx context.Context, f *filterBuilder) {
if len(h.c) == 0 {
return
}
for i, cc := range h.c {
join := fmt.Sprintf("custom_fields_%d", i)
h.handleCriterion(f, join, cc)
}
}

View File

@@ -0,0 +1,176 @@
//go:build integration
// +build integration
package sqlite_test
import (
"context"
"testing"
"github.com/stashapp/stash/pkg/models"
"github.com/stretchr/testify/assert"
)
func TestSetCustomFields(t *testing.T) {
performerIdx := performerIdx1WithScene
mergeCustomFields := func(i map[string]interface{}) map[string]interface{} {
m := getPerformerCustomFields(performerIdx)
for k, v := range i {
m[k] = v
}
return m
}
tests := []struct {
name string
input models.CustomFieldsInput
expected map[string]interface{}
wantErr bool
}{
{
"valid full",
models.CustomFieldsInput{
Full: map[string]interface{}{
"key": "value",
},
},
map[string]interface{}{
"key": "value",
},
false,
},
{
"valid partial",
models.CustomFieldsInput{
Partial: map[string]interface{}{
"key": "value",
},
},
mergeCustomFields(map[string]interface{}{
"key": "value",
}),
false,
},
{
"valid partial overwrite",
models.CustomFieldsInput{
Partial: map[string]interface{}{
"real": float64(4.56),
},
},
mergeCustomFields(map[string]interface{}{
"real": float64(4.56),
}),
false,
},
{
"leading space full",
models.CustomFieldsInput{
Full: map[string]interface{}{
" key": "value",
},
},
nil,
true,
},
{
"trailing space full",
models.CustomFieldsInput{
Full: map[string]interface{}{
"key ": "value",
},
},
nil,
true,
},
{
"leading space partial",
models.CustomFieldsInput{
Partial: map[string]interface{}{
" key": "value",
},
},
nil,
true,
},
{
"trailing space partial",
models.CustomFieldsInput{
Partial: map[string]interface{}{
"key ": "value",
},
},
nil,
true,
},
{
"big key full",
models.CustomFieldsInput{
Full: map[string]interface{}{
"12345678901234567890123456789012345678901234567890123456789012345": "value",
},
},
nil,
true,
},
{
"big key partial",
models.CustomFieldsInput{
Partial: map[string]interface{}{
"12345678901234567890123456789012345678901234567890123456789012345": "value",
},
},
nil,
true,
},
{
"empty key full",
models.CustomFieldsInput{
Full: map[string]interface{}{
"": "value",
},
},
nil,
true,
},
{
"empty key partial",
models.CustomFieldsInput{
Partial: map[string]interface{}{
"": "value",
},
},
nil,
true,
},
}
// use performer custom fields store
store := db.Performer
id := performerIDs[performerIdx]
assert := assert.New(t)
for _, tt := range tests {
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
err := store.SetCustomFields(ctx, id, tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("SetCustomFields() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
return
}
actual, err := store.GetCustomFields(ctx, id)
if err != nil {
t.Errorf("GetCustomFields() error = %v", err)
return
}
assert.Equal(tt.expected, actual)
})
}
}

View File

@@ -34,7 +34,7 @@ const (
cacheSizeEnv = "STASH_SQLITE_CACHE_SIZE"
)
var appSchemaVersion uint = 70
var appSchemaVersion uint = 71
//go:embed migrations/*.sql
var migrationsBox embed.FS

View File

@@ -95,6 +95,7 @@ type join struct {
as string
onClause string
joinType string
args []interface{}
}
// equals returns true if the other join alias/table is equal to this one
@@ -229,12 +230,13 @@ func (f *filterBuilder) not(n *filterBuilder) {
// The AS is omitted if as is empty.
// This method does not add a join if it its alias/table name is already
// present in another existing join.
func (f *filterBuilder) addLeftJoin(table, as, onClause string) {
func (f *filterBuilder) addLeftJoin(table, as, onClause string, args ...interface{}) {
newJoin := join{
table: table,
as: as,
onClause: onClause,
joinType: "LEFT",
args: args,
}
f.joins.add(newJoin)
@@ -245,12 +247,13 @@ func (f *filterBuilder) addLeftJoin(table, as, onClause string) {
// The AS is omitted if as is empty.
// This method does not add a join if it its alias/table name is already
// present in another existing join.
func (f *filterBuilder) addInnerJoin(table, as, onClause string) {
func (f *filterBuilder) addInnerJoin(table, as, onClause string, args ...interface{}) {
newJoin := join{
table: table,
as: as,
onClause: onClause,
joinType: "INNER",
args: args,
}
f.joins.add(newJoin)

View File

@@ -0,0 +1,9 @@
CREATE TABLE `performer_custom_fields` (
`performer_id` integer NOT NULL,
`field` varchar(64) NOT NULL,
`value` BLOB NOT NULL,
PRIMARY KEY (`performer_id`, `field`),
foreign key(`performer_id`) references `performers`(`id`) on delete CASCADE
);
CREATE INDEX `index_performer_custom_fields_field_value` ON `performer_custom_fields` (`field`, `value`);

View File

@@ -226,6 +226,7 @@ var (
type PerformerStore struct {
blobJoinQueryBuilder
customFieldsStore
tableMgr *table
}
@@ -236,6 +237,10 @@ func NewPerformerStore(blobStore *BlobStore) *PerformerStore {
blobStore: blobStore,
joinTable: performerTable,
},
customFieldsStore: customFieldsStore{
table: performersCustomFieldsTable,
fk: performersCustomFieldsTable.Col(performerIDColumn),
},
tableMgr: performerTableMgr,
}
}
@@ -248,9 +253,9 @@ func (qb *PerformerStore) selectDataset() *goqu.SelectDataset {
return dialect.From(qb.table()).Select(qb.table().All())
}
func (qb *PerformerStore) Create(ctx context.Context, newObject *models.Performer) error {
func (qb *PerformerStore) Create(ctx context.Context, newObject *models.CreatePerformerInput) error {
var r performerRow
r.fromPerformer(*newObject)
r.fromPerformer(*newObject.Performer)
id, err := qb.tableMgr.insertID(ctx, r)
if err != nil {
@@ -282,12 +287,17 @@ func (qb *PerformerStore) Create(ctx context.Context, newObject *models.Performe
}
}
const partial = false
if err := qb.setCustomFields(ctx, id, newObject.CustomFields, partial); err != nil {
return err
}
updated, err := qb.find(ctx, id)
if err != nil {
return fmt.Errorf("finding after create: %w", err)
}
*newObject = *updated
*newObject.Performer = *updated
return nil
}
@@ -330,12 +340,16 @@ func (qb *PerformerStore) UpdatePartial(ctx context.Context, id int, partial mod
}
}
if err := qb.SetCustomFields(ctx, id, partial.CustomFields); err != nil {
return nil, err
}
return qb.find(ctx, id)
}
func (qb *PerformerStore) Update(ctx context.Context, updatedObject *models.Performer) error {
func (qb *PerformerStore) Update(ctx context.Context, updatedObject *models.UpdatePerformerInput) error {
var r performerRow
r.fromPerformer(*updatedObject)
r.fromPerformer(*updatedObject.Performer)
if err := qb.tableMgr.updateByID(ctx, updatedObject.ID, r); err != nil {
return err
@@ -365,6 +379,10 @@ func (qb *PerformerStore) Update(ctx context.Context, updatedObject *models.Perf
}
}
if err := qb.SetCustomFields(ctx, updatedObject.ID, updatedObject.CustomFields); err != nil {
return err
}
return nil
}

View File

@@ -203,6 +203,13 @@ func (qb *performerFilterHandler) criterionHandler() criterionHandler {
performerRepository.tags.innerJoin(f, "performer_tag", "performers.id")
},
},
&customFieldsFilterHandler{
table: performersCustomFieldsTable.GetTable(),
fkCol: performerIDColumn,
c: filter.CustomFields,
idCol: "performers.id",
},
}
}

View File

@@ -16,6 +16,12 @@ import (
"github.com/stretchr/testify/assert"
)
var testCustomFields = map[string]interface{}{
"string": "aaa",
"int": int64(123), // int64 to match the type of the field in the database
"real": 1.23,
}
func loadPerformerRelationships(ctx context.Context, expected models.Performer, actual *models.Performer) error {
if expected.Aliases.Loaded() {
if err := actual.LoadAliases(ctx, db.Performer); err != nil {
@@ -81,57 +87,62 @@ func Test_PerformerStore_Create(t *testing.T) {
tests := []struct {
name string
newObject models.Performer
newObject models.CreatePerformerInput
wantErr bool
}{
{
"full",
models.Performer{
Name: name,
Disambiguation: disambiguation,
Gender: &gender,
URLs: models.NewRelatedStrings(urls),
Birthdate: &birthdate,
Ethnicity: ethnicity,
Country: country,
EyeColor: eyeColor,
Height: &height,
Measurements: measurements,
FakeTits: fakeTits,
PenisLength: &penisLength,
Circumcised: &circumcised,
CareerLength: careerLength,
Tattoos: tattoos,
Piercings: piercings,
Favorite: favorite,
Rating: &rating,
Details: details,
DeathDate: &deathdate,
HairColor: hairColor,
Weight: &weight,
IgnoreAutoTag: ignoreAutoTag,
TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithPerformer], tagIDs[tagIdx1WithDupName]}),
Aliases: models.NewRelatedStrings(aliases),
StashIDs: models.NewRelatedStashIDs([]models.StashID{
{
StashID: stashID1,
Endpoint: endpoint1,
},
{
StashID: stashID2,
Endpoint: endpoint2,
},
}),
CreatedAt: createdAt,
UpdatedAt: updatedAt,
models.CreatePerformerInput{
Performer: &models.Performer{
Name: name,
Disambiguation: disambiguation,
Gender: &gender,
URLs: models.NewRelatedStrings(urls),
Birthdate: &birthdate,
Ethnicity: ethnicity,
Country: country,
EyeColor: eyeColor,
Height: &height,
Measurements: measurements,
FakeTits: fakeTits,
PenisLength: &penisLength,
Circumcised: &circumcised,
CareerLength: careerLength,
Tattoos: tattoos,
Piercings: piercings,
Favorite: favorite,
Rating: &rating,
Details: details,
DeathDate: &deathdate,
HairColor: hairColor,
Weight: &weight,
IgnoreAutoTag: ignoreAutoTag,
TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithPerformer], tagIDs[tagIdx1WithDupName]}),
Aliases: models.NewRelatedStrings(aliases),
StashIDs: models.NewRelatedStashIDs([]models.StashID{
{
StashID: stashID1,
Endpoint: endpoint1,
},
{
StashID: stashID2,
Endpoint: endpoint2,
},
}),
CreatedAt: createdAt,
UpdatedAt: updatedAt,
},
CustomFields: testCustomFields,
},
false,
},
{
"invalid tag id",
models.Performer{
Name: name,
TagIDs: models.NewRelatedIDs([]int{invalidID}),
models.CreatePerformerInput{
Performer: &models.Performer{
Name: name,
TagIDs: models.NewRelatedIDs([]int{invalidID}),
},
},
true,
},
@@ -155,16 +166,16 @@ func Test_PerformerStore_Create(t *testing.T) {
assert.NotZero(p.ID)
copy := tt.newObject
copy := *tt.newObject.Performer
copy.ID = p.ID
// load relationships
if err := loadPerformerRelationships(ctx, copy, &p); err != nil {
if err := loadPerformerRelationships(ctx, copy, p.Performer); err != nil {
t.Errorf("loadPerformerRelationships() error = %v", err)
return
}
assert.Equal(copy, p)
assert.Equal(copy, *p.Performer)
// ensure can find the performer
found, err := qb.Find(ctx, p.ID)
@@ -183,6 +194,15 @@ func Test_PerformerStore_Create(t *testing.T) {
}
assert.Equal(copy, *found)
// ensure custom fields are set
cf, err := qb.GetCustomFields(ctx, p.ID)
if err != nil {
t.Errorf("PerformerStore.GetCustomFields() error = %v", err)
return
}
assert.Equal(tt.newObject.CustomFields, cf)
return
})
}
@@ -228,77 +248,109 @@ func Test_PerformerStore_Update(t *testing.T) {
tests := []struct {
name string
updatedObject *models.Performer
updatedObject models.UpdatePerformerInput
wantErr bool
}{
{
"full",
&models.Performer{
ID: performerIDs[performerIdxWithGallery],
Name: name,
Disambiguation: disambiguation,
Gender: &gender,
URLs: models.NewRelatedStrings(urls),
Birthdate: &birthdate,
Ethnicity: ethnicity,
Country: country,
EyeColor: eyeColor,
Height: &height,
Measurements: measurements,
FakeTits: fakeTits,
PenisLength: &penisLength,
Circumcised: &circumcised,
CareerLength: careerLength,
Tattoos: tattoos,
Piercings: piercings,
Favorite: favorite,
Rating: &rating,
Details: details,
DeathDate: &deathdate,
HairColor: hairColor,
Weight: &weight,
IgnoreAutoTag: ignoreAutoTag,
Aliases: models.NewRelatedStrings(aliases),
TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithPerformer], tagIDs[tagIdx1WithDupName]}),
StashIDs: models.NewRelatedStashIDs([]models.StashID{
{
StashID: stashID1,
Endpoint: endpoint1,
},
{
StashID: stashID2,
Endpoint: endpoint2,
},
}),
CreatedAt: createdAt,
UpdatedAt: updatedAt,
models.UpdatePerformerInput{
Performer: &models.Performer{
ID: performerIDs[performerIdxWithGallery],
Name: name,
Disambiguation: disambiguation,
Gender: &gender,
URLs: models.NewRelatedStrings(urls),
Birthdate: &birthdate,
Ethnicity: ethnicity,
Country: country,
EyeColor: eyeColor,
Height: &height,
Measurements: measurements,
FakeTits: fakeTits,
PenisLength: &penisLength,
Circumcised: &circumcised,
CareerLength: careerLength,
Tattoos: tattoos,
Piercings: piercings,
Favorite: favorite,
Rating: &rating,
Details: details,
DeathDate: &deathdate,
HairColor: hairColor,
Weight: &weight,
IgnoreAutoTag: ignoreAutoTag,
Aliases: models.NewRelatedStrings(aliases),
TagIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithPerformer], tagIDs[tagIdx1WithDupName]}),
StashIDs: models.NewRelatedStashIDs([]models.StashID{
{
StashID: stashID1,
Endpoint: endpoint1,
},
{
StashID: stashID2,
Endpoint: endpoint2,
},
}),
CreatedAt: createdAt,
UpdatedAt: updatedAt,
},
},
false,
},
{
"clear nullables",
&models.Performer{
ID: performerIDs[performerIdxWithGallery],
Aliases: models.NewRelatedStrings([]string{}),
URLs: models.NewRelatedStrings([]string{}),
TagIDs: models.NewRelatedIDs([]int{}),
StashIDs: models.NewRelatedStashIDs([]models.StashID{}),
models.UpdatePerformerInput{
Performer: &models.Performer{
ID: performerIDs[performerIdxWithGallery],
Aliases: models.NewRelatedStrings([]string{}),
URLs: models.NewRelatedStrings([]string{}),
TagIDs: models.NewRelatedIDs([]int{}),
StashIDs: models.NewRelatedStashIDs([]models.StashID{}),
},
},
false,
},
{
"clear tag ids",
&models.Performer{
ID: performerIDs[sceneIdxWithTag],
TagIDs: models.NewRelatedIDs([]int{}),
models.UpdatePerformerInput{
Performer: &models.Performer{
ID: performerIDs[sceneIdxWithTag],
TagIDs: models.NewRelatedIDs([]int{}),
},
},
false,
},
{
"set custom fields",
models.UpdatePerformerInput{
Performer: &models.Performer{
ID: performerIDs[performerIdxWithGallery],
},
CustomFields: models.CustomFieldsInput{
Full: testCustomFields,
},
},
false,
},
{
"clear custom fields",
models.UpdatePerformerInput{
Performer: &models.Performer{
ID: performerIDs[performerIdxWithGallery],
},
CustomFields: models.CustomFieldsInput{
Full: map[string]interface{}{},
},
},
false,
},
{
"invalid tag id",
&models.Performer{
ID: performerIDs[sceneIdxWithGallery],
TagIDs: models.NewRelatedIDs([]int{invalidID}),
models.UpdatePerformerInput{
Performer: &models.Performer{
ID: performerIDs[sceneIdxWithGallery],
TagIDs: models.NewRelatedIDs([]int{invalidID}),
},
},
true,
},
@@ -309,9 +361,9 @@ func Test_PerformerStore_Update(t *testing.T) {
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
assert := assert.New(t)
copy := *tt.updatedObject
copy := *tt.updatedObject.Performer
if err := qb.Update(ctx, tt.updatedObject); (err != nil) != tt.wantErr {
if err := qb.Update(ctx, &tt.updatedObject); (err != nil) != tt.wantErr {
t.Errorf("PerformerStore.Update() error = %v, wantErr %v", err, tt.wantErr)
}
@@ -331,6 +383,17 @@ func Test_PerformerStore_Update(t *testing.T) {
}
assert.Equal(copy, *s)
// ensure custom fields are correct
if tt.updatedObject.CustomFields.Full != nil {
cf, err := qb.GetCustomFields(ctx, tt.updatedObject.ID)
if err != nil {
t.Errorf("PerformerStore.GetCustomFields() error = %v", err)
return
}
assert.Equal(tt.updatedObject.CustomFields.Full, cf)
}
})
}
}
@@ -573,6 +636,79 @@ func Test_PerformerStore_UpdatePartial(t *testing.T) {
}
}
func Test_PerformerStore_UpdatePartialCustomFields(t *testing.T) {
tests := []struct {
name string
id int
partial models.PerformerPartial
expected map[string]interface{} // nil to use the partial
}{
{
"set custom fields",
performerIDs[performerIdxWithGallery],
models.PerformerPartial{
CustomFields: models.CustomFieldsInput{
Full: testCustomFields,
},
},
nil,
},
{
"clear custom fields",
performerIDs[performerIdxWithGallery],
models.PerformerPartial{
CustomFields: models.CustomFieldsInput{
Full: map[string]interface{}{},
},
},
nil,
},
{
"partial custom fields",
performerIDs[performerIdxWithGallery],
models.PerformerPartial{
CustomFields: models.CustomFieldsInput{
Partial: map[string]interface{}{
"string": "bbb",
"new_field": "new",
},
},
},
map[string]interface{}{
"int": int64(3),
"real": 1.3,
"string": "bbb",
"new_field": "new",
},
},
}
for _, tt := range tests {
qb := db.Performer
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
assert := assert.New(t)
_, err := qb.UpdatePartial(ctx, tt.id, tt.partial)
if err != nil {
t.Errorf("PerformerStore.UpdatePartial() error = %v", err)
return
}
// ensure custom fields are correct
cf, err := qb.GetCustomFields(ctx, tt.id)
if err != nil {
t.Errorf("PerformerStore.GetCustomFields() error = %v", err)
return
}
if tt.expected == nil {
assert.Equal(tt.partial.CustomFields.Full, cf)
} else {
assert.Equal(tt.expected, cf)
}
})
}
}
func TestPerformerFindBySceneID(t *testing.T) {
withTxn(func(ctx context.Context) error {
pqb := db.Performer
@@ -1042,6 +1178,242 @@ func TestPerformerQuery(t *testing.T) {
}
}
func TestPerformerQueryCustomFields(t *testing.T) {
tests := []struct {
name string
filter *models.PerformerFilterType
includeIdxs []int
excludeIdxs []int
wantErr bool
}{
{
"equals",
&models.PerformerFilterType{
CustomFields: []models.CustomFieldCriterionInput{
{
Field: "string",
Modifier: models.CriterionModifierEquals,
Value: []any{getPerformerStringValue(performerIdxWithGallery, "custom")},
},
},
},
[]int{performerIdxWithGallery},
nil,
false,
},
{
"not equals",
&models.PerformerFilterType{
Name: &models.StringCriterionInput{
Value: getPerformerStringValue(performerIdxWithGallery, "Name"),
Modifier: models.CriterionModifierEquals,
},
CustomFields: []models.CustomFieldCriterionInput{
{
Field: "string",
Modifier: models.CriterionModifierNotEquals,
Value: []any{getPerformerStringValue(performerIdxWithGallery, "custom")},
},
},
},
nil,
[]int{performerIdxWithGallery},
false,
},
{
"includes",
&models.PerformerFilterType{
CustomFields: []models.CustomFieldCriterionInput{
{
Field: "string",
Modifier: models.CriterionModifierIncludes,
Value: []any{getPerformerStringValue(performerIdxWithGallery, "custom")[9:]},
},
},
},
[]int{performerIdxWithGallery},
nil,
false,
},
{
"excludes",
&models.PerformerFilterType{
Name: &models.StringCriterionInput{
Value: getPerformerStringValue(performerIdxWithGallery, "Name"),
Modifier: models.CriterionModifierEquals,
},
CustomFields: []models.CustomFieldCriterionInput{
{
Field: "string",
Modifier: models.CriterionModifierExcludes,
Value: []any{getPerformerStringValue(performerIdxWithGallery, "custom")[9:]},
},
},
},
nil,
[]int{performerIdxWithGallery},
false,
},
{
"regex",
&models.PerformerFilterType{
CustomFields: []models.CustomFieldCriterionInput{
{
Field: "string",
Modifier: models.CriterionModifierMatchesRegex,
Value: []any{".*13_custom"},
},
},
},
[]int{performerIdxWithGallery},
nil,
false,
},
{
"invalid regex",
&models.PerformerFilterType{
CustomFields: []models.CustomFieldCriterionInput{
{
Field: "string",
Modifier: models.CriterionModifierMatchesRegex,
Value: []any{"["},
},
},
},
nil,
nil,
true,
},
{
"not matches regex",
&models.PerformerFilterType{
Name: &models.StringCriterionInput{
Value: getPerformerStringValue(performerIdxWithGallery, "Name"),
Modifier: models.CriterionModifierEquals,
},
CustomFields: []models.CustomFieldCriterionInput{
{
Field: "string",
Modifier: models.CriterionModifierNotMatchesRegex,
Value: []any{".*13_custom"},
},
},
},
nil,
[]int{performerIdxWithGallery},
false,
},
{
"invalid not matches regex",
&models.PerformerFilterType{
CustomFields: []models.CustomFieldCriterionInput{
{
Field: "string",
Modifier: models.CriterionModifierNotMatchesRegex,
Value: []any{"["},
},
},
},
nil,
nil,
true,
},
{
"null",
&models.PerformerFilterType{
Name: &models.StringCriterionInput{
Value: getPerformerStringValue(performerIdxWithGallery, "Name"),
Modifier: models.CriterionModifierEquals,
},
CustomFields: []models.CustomFieldCriterionInput{
{
Field: "not existing",
Modifier: models.CriterionModifierIsNull,
},
},
},
[]int{performerIdxWithGallery},
nil,
false,
},
{
"null",
&models.PerformerFilterType{
Name: &models.StringCriterionInput{
Value: getPerformerStringValue(performerIdxWithGallery, "Name"),
Modifier: models.CriterionModifierEquals,
},
CustomFields: []models.CustomFieldCriterionInput{
{
Field: "string",
Modifier: models.CriterionModifierNotNull,
},
},
},
[]int{performerIdxWithGallery},
nil,
false,
},
{
"between",
&models.PerformerFilterType{
CustomFields: []models.CustomFieldCriterionInput{
{
Field: "real",
Modifier: models.CriterionModifierBetween,
Value: []any{0.05, 0.15},
},
},
},
[]int{performerIdx1WithScene},
nil,
false,
},
{
"not between",
&models.PerformerFilterType{
Name: &models.StringCriterionInput{
Value: getPerformerStringValue(performerIdx1WithScene, "Name"),
Modifier: models.CriterionModifierEquals,
},
CustomFields: []models.CustomFieldCriterionInput{
{
Field: "real",
Modifier: models.CriterionModifierNotBetween,
Value: []any{0.05, 0.15},
},
},
},
nil,
[]int{performerIdx1WithScene},
false,
},
}
for _, tt := range tests {
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
assert := assert.New(t)
performers, _, err := db.Performer.Query(ctx, tt.filter, nil)
if (err != nil) != tt.wantErr {
t.Errorf("PerformerStore.Query() error = %v, wantErr %v", err, tt.wantErr)
return
}
ids := performersToIDs(performers)
include := indexesToIDs(performerIDs, tt.includeIdxs)
exclude := indexesToIDs(performerIDs, tt.excludeIdxs)
for _, i := range include {
assert.Contains(ids, i)
}
for _, e := range exclude {
assert.NotContains(ids, e)
}
})
}
}
func TestPerformerQueryPenisLength(t *testing.T) {
var upper = 4.0
@@ -1172,7 +1544,7 @@ func TestPerformerUpdatePerformerImage(t *testing.T) {
performer := models.Performer{
Name: name,
}
err := qb.Create(ctx, &performer)
err := qb.Create(ctx, &models.CreatePerformerInput{Performer: &performer})
if err != nil {
return fmt.Errorf("Error creating performer: %s", err.Error())
}
@@ -1680,7 +2052,7 @@ func TestPerformerStashIDs(t *testing.T) {
performer := &models.Performer{
Name: name,
}
if err := qb.Create(ctx, performer); err != nil {
if err := qb.Create(ctx, &models.CreatePerformerInput{Performer: performer}); err != nil {
return fmt.Errorf("Error creating performer: %s", err.Error())
}

View File

@@ -133,6 +133,9 @@ func (qb *queryBuilder) join(table, as, onClause string) {
func (qb *queryBuilder) addJoins(joins ...join) {
qb.joins.add(joins...)
for _, j := range joins {
qb.args = append(qb.args, j.args...)
}
}
func (qb *queryBuilder) addFilter(f *filterBuilder) error {
@@ -151,6 +154,9 @@ func (qb *queryBuilder) addFilter(f *filterBuilder) error {
qb.args = append(args, qb.args...)
}
// add joins here to insert args
qb.addJoins(f.getAllJoins()...)
clause, args = f.generateWhereClauses()
if len(clause) > 0 {
qb.addWhere(clause)
@@ -169,8 +175,6 @@ func (qb *queryBuilder) addFilter(f *filterBuilder) error {
qb.addArg(args...)
}
qb.addJoins(f.getAllJoins()...)
return nil
}

View File

@@ -222,8 +222,8 @@ func (r *repository) innerJoin(j joiner, as string, parentIDCol string) {
}
type joiner interface {
addLeftJoin(table, as, onClause string)
addInnerJoin(table, as, onClause string)
addLeftJoin(table, as, onClause string, args ...interface{})
addInnerJoin(table, as, onClause string, args ...interface{})
}
type joinRepository struct {

View File

@@ -1508,6 +1508,18 @@ func performerAliases(i int) []string {
return []string{getPerformerStringValue(i, "alias")}
}
func getPerformerCustomFields(index int) map[string]interface{} {
if index%5 == 0 {
return nil
}
return map[string]interface{}{
"string": getPerformerStringValue(index, "custom"),
"int": int64(index % 5),
"real": float64(index) / 10,
}
}
// createPerformers creates n performers with plain Name and o performers with camel cased NaMe included
func createPerformers(ctx context.Context, n int, o int) error {
pqb := db.Performer
@@ -1558,7 +1570,10 @@ func createPerformers(ctx context.Context, n int, o int) error {
})
}
err := pqb.Create(ctx, &performer)
err := pqb.Create(ctx, &models.CreatePerformerInput{
Performer: &performer,
CustomFields: getPerformerCustomFields(i),
})
if err != nil {
return fmt.Errorf("Error creating performer %v+: %s", performer, err.Error())

View File

@@ -32,6 +32,7 @@ var (
performersURLsJoinTable = goqu.T(performerURLsTable)
performersTagsJoinTable = goqu.T(performersTagsTable)
performersStashIDsJoinTable = goqu.T("performer_stash_ids")
performersCustomFieldsTable = goqu.T("performer_custom_fields")
studiosAliasesJoinTable = goqu.T(studioAliasesTable)
studiosTagsJoinTable = goqu.T(studiosTagsTable)