Only update non-nil database fields

This commit is contained in:
WithoutPants
2019-10-15 08:57:53 +11:00
parent 470c64b840
commit 0852199e27
5 changed files with 72 additions and 34 deletions

View File

@@ -14,36 +14,36 @@ func (r *mutationResolver) SceneUpdate(ctx context.Context, input models.SceneUp
// Populate scene from the input // Populate scene from the input
sceneID, _ := strconv.Atoi(input.ID) sceneID, _ := strconv.Atoi(input.ID)
updatedTime := time.Now() updatedTime := time.Now()
updatedScene := models.Scene{ updatedScene := models.ScenePartial{
ID: sceneID, ID: sceneID,
UpdatedAt: models.SQLiteTimestamp{Timestamp: updatedTime}, UpdatedAt: &models.SQLiteTimestamp{Timestamp: updatedTime},
} }
if input.Title != nil { if input.Title != nil {
updatedScene.Title = sql.NullString{String: *input.Title, Valid: true} updatedScene.Title = &sql.NullString{String: *input.Title, Valid: true}
} }
if input.Details != nil { if input.Details != nil {
updatedScene.Details = sql.NullString{String: *input.Details, Valid: true} updatedScene.Details = &sql.NullString{String: *input.Details, Valid: true}
} }
if input.URL != nil { if input.URL != nil {
updatedScene.URL = sql.NullString{String: *input.URL, Valid: true} updatedScene.URL = &sql.NullString{String: *input.URL, Valid: true}
} }
if input.Date != nil { if input.Date != nil {
updatedScene.Date = models.SQLiteDate{String: *input.Date, Valid: true} updatedScene.Date = &models.SQLiteDate{String: *input.Date, Valid: true}
} }
if input.Rating != nil { if input.Rating != nil {
updatedScene.Rating = sql.NullInt64{Int64: int64(*input.Rating), Valid: true} updatedScene.Rating = &sql.NullInt64{Int64: int64(*input.Rating), Valid: true}
} else { } else {
// rating must be nullable // rating must be nullable
updatedScene.Rating = sql.NullInt64{Valid: false} updatedScene.Rating = &sql.NullInt64{Valid: false}
} }
if input.StudioID != nil { if input.StudioID != nil {
studioID, _ := strconv.ParseInt(*input.StudioID, 10, 64) studioID, _ := strconv.ParseInt(*input.StudioID, 10, 64)
updatedScene.StudioID = sql.NullInt64{Int64: studioID, Valid: true} updatedScene.StudioID = &sql.NullInt64{Int64: studioID, Valid: true}
} else { } else {
// studio must be nullable // studio must be nullable
updatedScene.StudioID = sql.NullInt64{Valid: false} updatedScene.StudioID = &sql.NullInt64{Valid: false}
} }
// Start the transaction and save the scene marker // Start the transaction and save the scene marker

View File

@@ -107,8 +107,11 @@ func (t *ScanTask) scanScene() {
logger.Infof("%s already exists. Duplicate of %s ", t.FilePath, scene.Path) logger.Infof("%s already exists. Duplicate of %s ", t.FilePath, scene.Path)
} else { } else {
logger.Infof("%s already exists. Updating path...", t.FilePath) logger.Infof("%s already exists. Updating path...", t.FilePath)
scene.Path = t.FilePath scenePartial := models.ScenePartial{
_, err = qb.Update(*scene, tx) ID: scene.ID,
Path: &t.FilePath,
}
_, err = qb.Update(scenePartial, tx)
} }
} else { } else {
logger.Infof("%s doesn't exist. Creating new item...", t.FilePath) logger.Infof("%s doesn't exist. Creating new item...", t.FilePath)

View File

@@ -25,3 +25,25 @@ type Scene struct {
CreatedAt SQLiteTimestamp `db:"created_at" json:"created_at"` CreatedAt SQLiteTimestamp `db:"created_at" json:"created_at"`
UpdatedAt SQLiteTimestamp `db:"updated_at" json:"updated_at"` UpdatedAt SQLiteTimestamp `db:"updated_at" json:"updated_at"`
} }
type ScenePartial struct {
ID int `db:"id" json:"id"`
Checksum *string `db:"checksum" json:"checksum"`
Path *string `db:"path" json:"path"`
Title *sql.NullString `db:"title" json:"title"`
Details *sql.NullString `db:"details" json:"details"`
URL *sql.NullString `db:"url" json:"url"`
Date *SQLiteDate `db:"date" json:"date"`
Rating *sql.NullInt64 `db:"rating" json:"rating"`
Size *sql.NullString `db:"size" json:"size"`
Duration *sql.NullFloat64 `db:"duration" json:"duration"`
VideoCodec *sql.NullString `db:"video_codec" json:"video_codec"`
AudioCodec *sql.NullString `db:"audio_codec" json:"audio_codec"`
Width *sql.NullInt64 `db:"width" json:"width"`
Height *sql.NullInt64 `db:"height" json:"height"`
Framerate *sql.NullFloat64 `db:"framerate" json:"framerate"`
Bitrate *sql.NullInt64 `db:"bitrate" json:"bitrate"`
StudioID *sql.NullInt64 `db:"studio_id,omitempty" json:"studio_id"`
CreatedAt *SQLiteTimestamp `db:"created_at" json:"created_at"`
UpdatedAt *SQLiteTimestamp `db:"updated_at" json:"updated_at"`
}

View File

@@ -2,10 +2,11 @@ package models
import ( import (
"database/sql" "database/sql"
"github.com/jmoiron/sqlx"
"github.com/stashapp/stash/pkg/database"
"strconv" "strconv"
"strings" "strings"
"github.com/jmoiron/sqlx"
"github.com/stashapp/stash/pkg/database"
) )
const scenesForPerformerQuery = ` const scenesForPerformerQuery = `
@@ -60,20 +61,17 @@ func (qb *SceneQueryBuilder) Create(newScene Scene, tx *sqlx.Tx) (*Scene, error)
return &newScene, nil return &newScene, nil
} }
func (qb *SceneQueryBuilder) Update(updatedScene Scene, tx *sqlx.Tx) (*Scene, error) { func (qb *SceneQueryBuilder) Update(updatedScene ScenePartial, tx *sqlx.Tx) (*Scene, error) {
ensureTx(tx) ensureTx(tx)
_, err := tx.NamedExec( _, err := tx.NamedExec(
`UPDATE scenes SET `+SQLGenKeys(updatedScene)+` WHERE scenes.id = :id`, `UPDATE scenes SET `+SQLGenKeysPartial(updatedScene)+` WHERE scenes.id = :id`,
updatedScene, updatedScene,
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := tx.Get(&updatedScene, `SELECT * FROM scenes WHERE id = ? LIMIT 1`, updatedScene.ID); err != nil { return qb.Find(updatedScene.ID)
return nil, err
}
return &updatedScene, nil
} }
func (qb *SceneQueryBuilder) Find(id int) (*Scene, error) { func (qb *SceneQueryBuilder) Find(id int) (*Scene, error) {

View File

@@ -236,6 +236,17 @@ func ensureTx(tx *sqlx.Tx) {
// of keys for non empty key:values. These keys are formated // of keys for non empty key:values. These keys are formated
// keyname=:keyname with a comma seperating them // keyname=:keyname with a comma seperating them
func SQLGenKeys(i interface{}) string { func SQLGenKeys(i interface{}) string {
return sqlGenKeys(i, false)
}
// support a partial interface. When a partial interface is provided,
// keys will always be included if the value is not null. The partial
// interface must therefore consist of pointers
func SQLGenKeysPartial(i interface{}) string {
return sqlGenKeys(i, true)
}
func sqlGenKeys(i interface{}, partial bool) string {
var query []string var query []string
v := reflect.ValueOf(i) v := reflect.ValueOf(i)
for i := 0; i < v.NumField(); i++ { for i := 0; i < v.NumField(); i++ {
@@ -247,42 +258,46 @@ func SQLGenKeys(i interface{}) string {
} }
switch t := v.Field(i).Interface().(type) { switch t := v.Field(i).Interface().(type) {
case string: case string:
if t != "" { if partial || t != "" {
query = append(query, fmt.Sprintf("%s=:%s", key, key)) query = append(query, fmt.Sprintf("%s=:%s", key, key))
} }
case int: case int:
if t != 0 { if partial || t != 0 {
query = append(query, fmt.Sprintf("%s=:%s", key, key)) query = append(query, fmt.Sprintf("%s=:%s", key, key))
} }
case float64: case float64:
if t != 0 { if partial || t != 0 {
query = append(query, fmt.Sprintf("%s=:%s", key, key)) query = append(query, fmt.Sprintf("%s=:%s", key, key))
} }
case SQLiteTimestamp: case SQLiteTimestamp:
if !t.Timestamp.IsZero() { if partial || !t.Timestamp.IsZero() {
query = append(query, fmt.Sprintf("%s=:%s", key, key)) query = append(query, fmt.Sprintf("%s=:%s", key, key))
} }
case SQLiteDate: case SQLiteDate:
if t.Valid { if partial || t.Valid {
query = append(query, fmt.Sprintf("%s=:%s", key, key)) query = append(query, fmt.Sprintf("%s=:%s", key, key))
} }
case sql.NullString: case sql.NullString:
// should be nullable if partial || t.Valid {
query = append(query, fmt.Sprintf("%s=:%s", key, key)) query = append(query, fmt.Sprintf("%s=:%s", key, key))
}
case sql.NullBool: case sql.NullBool:
// should be nullable if partial || t.Valid {
query = append(query, fmt.Sprintf("%s=:%s", key, key)) query = append(query, fmt.Sprintf("%s=:%s", key, key))
}
case sql.NullInt64: case sql.NullInt64:
// should be nullable if partial || t.Valid {
query = append(query, fmt.Sprintf("%s=:%s", key, key)) query = append(query, fmt.Sprintf("%s=:%s", key, key))
}
case sql.NullFloat64: case sql.NullFloat64:
// should be nullable if partial || t.Valid {
query = append(query, fmt.Sprintf("%s=:%s", key, key)) query = append(query, fmt.Sprintf("%s=:%s", key, key))
}
default: default:
reflectValue := reflect.ValueOf(t) reflectValue := reflect.ValueOf(t)
kind := reflectValue.Kind() kind := reflectValue.Kind()
isNil := reflectValue.IsNil() isNil := reflectValue.IsNil()
if kind != reflect.Ptr && !isNil { if kind != reflect.Ptr || !isNil {
query = append(query, fmt.Sprintf("%s=:%s", key, key)) query = append(query, fmt.Sprintf("%s=:%s", key, key))
} }
} }