diff --git a/pkg/api/resolver_mutation_scene.go b/pkg/api/resolver_mutation_scene.go index 8cbfa1e86..2cdf55635 100644 --- a/pkg/api/resolver_mutation_scene.go +++ b/pkg/api/resolver_mutation_scene.go @@ -14,36 +14,36 @@ func (r *mutationResolver) SceneUpdate(ctx context.Context, input models.SceneUp // Populate scene from the input sceneID, _ := strconv.Atoi(input.ID) updatedTime := time.Now() - updatedScene := models.Scene{ + updatedScene := models.ScenePartial{ ID: sceneID, - UpdatedAt: models.SQLiteTimestamp{Timestamp: updatedTime}, + UpdatedAt: &models.SQLiteTimestamp{Timestamp: updatedTime}, } if input.Title != nil { - updatedScene.Title = sql.NullString{String: *input.Title, Valid: true} + updatedScene.Title = &sql.NullString{String: *input.Title, Valid: true} } if input.Details != nil { - updatedScene.Details = sql.NullString{String: *input.Details, Valid: true} + updatedScene.Details = &sql.NullString{String: *input.Details, Valid: true} } if input.URL != nil { - updatedScene.URL = sql.NullString{String: *input.URL, Valid: true} + updatedScene.URL = &sql.NullString{String: *input.URL, Valid: true} } if input.Date != nil { - updatedScene.Date = models.SQLiteDate{String: *input.Date, Valid: true} + updatedScene.Date = &models.SQLiteDate{String: *input.Date, Valid: true} } if input.Rating != nil { - updatedScene.Rating = sql.NullInt64{Int64: int64(*input.Rating), Valid: true} + updatedScene.Rating = &sql.NullInt64{Int64: int64(*input.Rating), Valid: true} } else { // rating must be nullable - updatedScene.Rating = sql.NullInt64{Valid: false} + updatedScene.Rating = &sql.NullInt64{Valid: false} } if input.StudioID != nil { studioID, _ := strconv.ParseInt(*input.StudioID, 10, 64) - updatedScene.StudioID = sql.NullInt64{Int64: studioID, Valid: true} + updatedScene.StudioID = &sql.NullInt64{Int64: studioID, Valid: true} } else { // studio must be nullable - updatedScene.StudioID = sql.NullInt64{Valid: false} + updatedScene.StudioID = &sql.NullInt64{Valid: false} } // Start the transaction and save the scene marker diff --git a/pkg/manager/task_scan.go b/pkg/manager/task_scan.go index b2e11d7cc..26e1fda32 100644 --- a/pkg/manager/task_scan.go +++ b/pkg/manager/task_scan.go @@ -107,8 +107,11 @@ func (t *ScanTask) scanScene() { logger.Infof("%s already exists. Duplicate of %s ", t.FilePath, scene.Path) } else { logger.Infof("%s already exists. Updating path...", t.FilePath) - scene.Path = t.FilePath - _, err = qb.Update(*scene, tx) + scenePartial := models.ScenePartial{ + ID: scene.ID, + Path: &t.FilePath, + } + _, err = qb.Update(scenePartial, tx) } } else { logger.Infof("%s doesn't exist. Creating new item...", t.FilePath) diff --git a/pkg/models/model_scene.go b/pkg/models/model_scene.go index 0b480c30e..2f487488f 100644 --- a/pkg/models/model_scene.go +++ b/pkg/models/model_scene.go @@ -25,3 +25,25 @@ type Scene struct { CreatedAt SQLiteTimestamp `db:"created_at" json:"created_at"` UpdatedAt SQLiteTimestamp `db:"updated_at" json:"updated_at"` } + +type ScenePartial struct { + ID int `db:"id" json:"id"` + Checksum *string `db:"checksum" json:"checksum"` + Path *string `db:"path" json:"path"` + Title *sql.NullString `db:"title" json:"title"` + Details *sql.NullString `db:"details" json:"details"` + URL *sql.NullString `db:"url" json:"url"` + Date *SQLiteDate `db:"date" json:"date"` + Rating *sql.NullInt64 `db:"rating" json:"rating"` + Size *sql.NullString `db:"size" json:"size"` + Duration *sql.NullFloat64 `db:"duration" json:"duration"` + VideoCodec *sql.NullString `db:"video_codec" json:"video_codec"` + AudioCodec *sql.NullString `db:"audio_codec" json:"audio_codec"` + Width *sql.NullInt64 `db:"width" json:"width"` + Height *sql.NullInt64 `db:"height" json:"height"` + Framerate *sql.NullFloat64 `db:"framerate" json:"framerate"` + Bitrate *sql.NullInt64 `db:"bitrate" json:"bitrate"` + StudioID *sql.NullInt64 `db:"studio_id,omitempty" json:"studio_id"` + CreatedAt *SQLiteTimestamp `db:"created_at" json:"created_at"` + UpdatedAt *SQLiteTimestamp `db:"updated_at" json:"updated_at"` +} diff --git a/pkg/models/querybuilder_scene.go b/pkg/models/querybuilder_scene.go index a6f32fc2b..bd4f00e93 100644 --- a/pkg/models/querybuilder_scene.go +++ b/pkg/models/querybuilder_scene.go @@ -2,10 +2,11 @@ package models import ( "database/sql" - "github.com/jmoiron/sqlx" - "github.com/stashapp/stash/pkg/database" "strconv" "strings" + + "github.com/jmoiron/sqlx" + "github.com/stashapp/stash/pkg/database" ) const scenesForPerformerQuery = ` @@ -60,20 +61,17 @@ func (qb *SceneQueryBuilder) Create(newScene Scene, tx *sqlx.Tx) (*Scene, error) return &newScene, nil } -func (qb *SceneQueryBuilder) Update(updatedScene Scene, tx *sqlx.Tx) (*Scene, error) { +func (qb *SceneQueryBuilder) Update(updatedScene ScenePartial, tx *sqlx.Tx) (*Scene, error) { ensureTx(tx) _, err := tx.NamedExec( - `UPDATE scenes SET `+SQLGenKeys(updatedScene)+` WHERE scenes.id = :id`, + `UPDATE scenes SET `+SQLGenKeysPartial(updatedScene)+` WHERE scenes.id = :id`, updatedScene, ) if err != nil { return nil, err } - if err := tx.Get(&updatedScene, `SELECT * FROM scenes WHERE id = ? LIMIT 1`, updatedScene.ID); err != nil { - return nil, err - } - return &updatedScene, nil + return qb.Find(updatedScene.ID) } func (qb *SceneQueryBuilder) Find(id int) (*Scene, error) { diff --git a/pkg/models/querybuilder_sql.go b/pkg/models/querybuilder_sql.go index 3f5021cb3..263e558de 100644 --- a/pkg/models/querybuilder_sql.go +++ b/pkg/models/querybuilder_sql.go @@ -236,6 +236,17 @@ func ensureTx(tx *sqlx.Tx) { // of keys for non empty key:values. These keys are formated // keyname=:keyname with a comma seperating them func SQLGenKeys(i interface{}) string { + return sqlGenKeys(i, false) +} + +// support a partial interface. When a partial interface is provided, +// keys will always be included if the value is not null. The partial +// interface must therefore consist of pointers +func SQLGenKeysPartial(i interface{}) string { + return sqlGenKeys(i, true) +} + +func sqlGenKeys(i interface{}, partial bool) string { var query []string v := reflect.ValueOf(i) for i := 0; i < v.NumField(); i++ { @@ -247,42 +258,46 @@ func SQLGenKeys(i interface{}) string { } switch t := v.Field(i).Interface().(type) { case string: - if t != "" { + if partial || t != "" { query = append(query, fmt.Sprintf("%s=:%s", key, key)) } case int: - if t != 0 { + if partial || t != 0 { query = append(query, fmt.Sprintf("%s=:%s", key, key)) } case float64: - if t != 0 { + if partial || t != 0 { query = append(query, fmt.Sprintf("%s=:%s", key, key)) } case SQLiteTimestamp: - if !t.Timestamp.IsZero() { + if partial || !t.Timestamp.IsZero() { query = append(query, fmt.Sprintf("%s=:%s", key, key)) } case SQLiteDate: - if t.Valid { + if partial || t.Valid { query = append(query, fmt.Sprintf("%s=:%s", key, key)) } case sql.NullString: - // should be nullable - query = append(query, fmt.Sprintf("%s=:%s", key, key)) + if partial || t.Valid { + query = append(query, fmt.Sprintf("%s=:%s", key, key)) + } case sql.NullBool: - // should be nullable - query = append(query, fmt.Sprintf("%s=:%s", key, key)) + if partial || t.Valid { + query = append(query, fmt.Sprintf("%s=:%s", key, key)) + } case sql.NullInt64: - // should be nullable - query = append(query, fmt.Sprintf("%s=:%s", key, key)) + if partial || t.Valid { + query = append(query, fmt.Sprintf("%s=:%s", key, key)) + } case sql.NullFloat64: - // should be nullable - query = append(query, fmt.Sprintf("%s=:%s", key, key)) + if partial || t.Valid { + query = append(query, fmt.Sprintf("%s=:%s", key, key)) + } default: reflectValue := reflect.ValueOf(t) kind := reflectValue.Kind() isNil := reflectValue.IsNil() - if kind != reflect.Ptr && !isNil { + if kind != reflect.Ptr || !isNil { query = append(query, fmt.Sprintf("%s=:%s", key, key)) } }