mirror of
https://github.com/stashapp/stash.git
synced 2025-12-18 04:44:37 +03:00
Restructure data layer (#2532)
* Add new txn manager interface * Add txn management to sqlite * Rename get to getByID * Add contexts to repository methods * Update query builders * Add context to reader writer interfaces * Use repository in resolver * Tighten interfaces * Tighten interfaces in dlna * Tighten interfaces in match package * Tighten interfaces in scraper package * Tighten interfaces in scan code * Tighten interfaces on autotag package * Remove ReaderWriter usage * Merge database package into sqlite
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -26,23 +27,23 @@ type repository struct {
|
||||
idColumn string
|
||||
}
|
||||
|
||||
func (r *repository) get(id int, dest interface{}) error {
|
||||
func (r *repository) getByID(ctx context.Context, id int, dest interface{}) error {
|
||||
stmt := fmt.Sprintf("SELECT * FROM %s WHERE %s = ? LIMIT 1", r.tableName, r.idColumn)
|
||||
return r.tx.Get(dest, stmt, id)
|
||||
return r.tx.Get(ctx, dest, stmt, id)
|
||||
}
|
||||
|
||||
func (r *repository) getAll(id int, f func(rows *sqlx.Rows) error) error {
|
||||
func (r *repository) getAll(ctx context.Context, id int, f func(rows *sqlx.Rows) error) error {
|
||||
stmt := fmt.Sprintf("SELECT * FROM %s WHERE %s = ?", r.tableName, r.idColumn)
|
||||
return r.queryFunc(stmt, []interface{}{id}, false, f)
|
||||
return r.queryFunc(ctx, stmt, []interface{}{id}, false, f)
|
||||
}
|
||||
|
||||
func (r *repository) insert(obj interface{}) (sql.Result, error) {
|
||||
func (r *repository) insert(ctx context.Context, obj interface{}) (sql.Result, error) {
|
||||
stmt := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", r.tableName, listKeys(obj, false), listKeys(obj, true))
|
||||
return r.tx.NamedExec(stmt, obj)
|
||||
return r.tx.NamedExec(ctx, stmt, obj)
|
||||
}
|
||||
|
||||
func (r *repository) insertObject(obj interface{}, out interface{}) error {
|
||||
result, err := r.insert(obj)
|
||||
func (r *repository) insertObject(ctx context.Context, obj interface{}, out interface{}) error {
|
||||
result, err := r.insert(ctx, obj)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -50,11 +51,11 @@ func (r *repository) insertObject(obj interface{}, out interface{}) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return r.get(int(id), out)
|
||||
return r.getByID(ctx, int(id), out)
|
||||
}
|
||||
|
||||
func (r *repository) update(id int, obj interface{}, partial bool) error {
|
||||
exists, err := r.exists(id)
|
||||
func (r *repository) update(ctx context.Context, id int, obj interface{}, partial bool) error {
|
||||
exists, err := r.exists(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -64,13 +65,13 @@ func (r *repository) update(id int, obj interface{}, partial bool) error {
|
||||
}
|
||||
|
||||
stmt := fmt.Sprintf("UPDATE %s SET %s WHERE %s.%s = :id", r.tableName, updateSet(obj, partial), r.tableName, r.idColumn)
|
||||
_, err = r.tx.NamedExec(stmt, obj)
|
||||
_, err = r.tx.NamedExec(ctx, stmt, obj)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *repository) updateMap(id int, m map[string]interface{}) error {
|
||||
exists, err := r.exists(id)
|
||||
func (r *repository) updateMap(ctx context.Context, id int, m map[string]interface{}) error {
|
||||
exists, err := r.exists(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -80,14 +81,14 @@ func (r *repository) updateMap(id int, m map[string]interface{}) error {
|
||||
}
|
||||
|
||||
stmt := fmt.Sprintf("UPDATE %s SET %s WHERE %s.%s = :id", r.tableName, updateSetMap(m), r.tableName, r.idColumn)
|
||||
_, err = r.tx.NamedExec(stmt, m)
|
||||
_, err = r.tx.NamedExec(ctx, stmt, m)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *repository) destroyExisting(ids []int) error {
|
||||
func (r *repository) destroyExisting(ctx context.Context, ids []int) error {
|
||||
for _, id := range ids {
|
||||
exists, err := r.exists(id)
|
||||
exists, err := r.exists(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -97,13 +98,13 @@ func (r *repository) destroyExisting(ids []int) error {
|
||||
}
|
||||
}
|
||||
|
||||
return r.destroy(ids)
|
||||
return r.destroy(ctx, ids)
|
||||
}
|
||||
|
||||
func (r *repository) destroy(ids []int) error {
|
||||
func (r *repository) destroy(ctx context.Context, ids []int) error {
|
||||
for _, id := range ids {
|
||||
stmt := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", r.tableName, r.idColumn)
|
||||
if _, err := r.tx.Exec(stmt, id); err != nil {
|
||||
if _, err := r.tx.Exec(ctx, stmt, id); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -111,11 +112,11 @@ func (r *repository) destroy(ids []int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *repository) exists(id int) (bool, error) {
|
||||
func (r *repository) exists(ctx context.Context, id int) (bool, error) {
|
||||
stmt := fmt.Sprintf("SELECT %s FROM %s WHERE %s = ? LIMIT 1", r.idColumn, r.tableName, r.idColumn)
|
||||
stmt = r.buildCountQuery(stmt)
|
||||
|
||||
c, err := r.runCountQuery(stmt, []interface{}{id})
|
||||
c, err := r.runCountQuery(ctx, stmt, []interface{}{id})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -127,25 +128,25 @@ func (r *repository) buildCountQuery(query string) string {
|
||||
return "SELECT COUNT(*) as count FROM (" + query + ") as temp"
|
||||
}
|
||||
|
||||
func (r *repository) runCountQuery(query string, args []interface{}) (int, error) {
|
||||
func (r *repository) runCountQuery(ctx context.Context, query string, args []interface{}) (int, error) {
|
||||
result := struct {
|
||||
Int int `db:"count"`
|
||||
}{0}
|
||||
|
||||
// Perform query and fetch result
|
||||
if err := r.tx.Get(&result, query, args...); err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
if err := r.tx.Get(ctx, &result, query, args...); err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return result.Int, nil
|
||||
}
|
||||
|
||||
func (r *repository) runIdsQuery(query string, args []interface{}) ([]int, error) {
|
||||
func (r *repository) runIdsQuery(ctx context.Context, query string, args []interface{}) ([]int, error) {
|
||||
var result []struct {
|
||||
Int int `db:"id"`
|
||||
}
|
||||
|
||||
if err := r.tx.Select(&result, query, args...); err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
if err := r.tx.Select(ctx, &result, query, args...); err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return []int{}, err
|
||||
}
|
||||
|
||||
@@ -156,24 +157,24 @@ func (r *repository) runIdsQuery(query string, args []interface{}) ([]int, error
|
||||
return vsm, nil
|
||||
}
|
||||
|
||||
func (r *repository) runSumQuery(query string, args []interface{}) (float64, error) {
|
||||
func (r *repository) runSumQuery(ctx context.Context, query string, args []interface{}) (float64, error) {
|
||||
// Perform query and fetch result
|
||||
result := struct {
|
||||
Float64 float64 `db:"sum"`
|
||||
}{0}
|
||||
|
||||
// Perform query and fetch result
|
||||
if err := r.tx.Get(&result, query, args...); err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
if err := r.tx.Get(ctx, &result, query, args...); err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return result.Float64, nil
|
||||
}
|
||||
|
||||
func (r *repository) queryFunc(query string, args []interface{}, single bool, f func(rows *sqlx.Rows) error) error {
|
||||
func (r *repository) queryFunc(ctx context.Context, query string, args []interface{}, single bool, f func(rows *sqlx.Rows) error) error {
|
||||
logger.Tracef("SQL: %s, args: %v", query, args)
|
||||
|
||||
rows, err := r.tx.Queryx(query, args...)
|
||||
rows, err := r.tx.Queryx(ctx, query, args...)
|
||||
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return err
|
||||
@@ -196,8 +197,8 @@ func (r *repository) queryFunc(query string, args []interface{}, single bool, f
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *repository) query(query string, args []interface{}, out objectList) error {
|
||||
return r.queryFunc(query, args, false, func(rows *sqlx.Rows) error {
|
||||
func (r *repository) query(ctx context.Context, query string, args []interface{}, out objectList) error {
|
||||
return r.queryFunc(ctx, query, args, false, func(rows *sqlx.Rows) error {
|
||||
object := out.New()
|
||||
if err := rows.StructScan(object); err != nil {
|
||||
return err
|
||||
@@ -207,8 +208,8 @@ func (r *repository) query(query string, args []interface{}, out objectList) err
|
||||
})
|
||||
}
|
||||
|
||||
func (r *repository) queryStruct(query string, args []interface{}, out interface{}) error {
|
||||
return r.queryFunc(query, args, true, func(rows *sqlx.Rows) error {
|
||||
func (r *repository) queryStruct(ctx context.Context, query string, args []interface{}, out interface{}) error {
|
||||
return r.queryFunc(ctx, query, args, true, func(rows *sqlx.Rows) error {
|
||||
if err := rows.StructScan(out); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -216,8 +217,8 @@ func (r *repository) queryStruct(query string, args []interface{}, out interface
|
||||
})
|
||||
}
|
||||
|
||||
func (r *repository) querySimple(query string, args []interface{}, out interface{}) error {
|
||||
rows, err := r.tx.Queryx(query, args...)
|
||||
func (r *repository) querySimple(ctx context.Context, query string, args []interface{}, out interface{}) error {
|
||||
rows, err := r.tx.Queryx(ctx, query, args...)
|
||||
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return err
|
||||
@@ -249,7 +250,7 @@ func (r *repository) buildQueryBody(body string, whereClauses []string, havingCl
|
||||
return body
|
||||
}
|
||||
|
||||
func (r *repository) executeFindQuery(body string, args []interface{}, sortAndPagination string, whereClauses []string, havingClauses []string, withClauses []string, recursiveWith bool) ([]int, int, error) {
|
||||
func (r *repository) executeFindQuery(ctx context.Context, body string, args []interface{}, sortAndPagination string, whereClauses []string, havingClauses []string, withClauses []string, recursiveWith bool) ([]int, int, error) {
|
||||
body = r.buildQueryBody(body, whereClauses, havingClauses)
|
||||
|
||||
withClause := ""
|
||||
@@ -272,8 +273,8 @@ func (r *repository) executeFindQuery(body string, args []interface{}, sortAndPa
|
||||
var idsResult []int
|
||||
var idsErr error
|
||||
|
||||
countResult, countErr = r.runCountQuery(countQuery, args)
|
||||
idsResult, idsErr = r.runIdsQuery(idsQuery, args)
|
||||
countResult, countErr = r.runCountQuery(ctx, countQuery, args)
|
||||
idsResult, idsErr = r.runIdsQuery(ctx, idsQuery, args)
|
||||
|
||||
if countErr != nil {
|
||||
return nil, 0, fmt.Errorf("error executing count query with SQL: %s, args: %v, error: %s", countQuery, args, countErr.Error())
|
||||
@@ -318,23 +319,23 @@ type joinRepository struct {
|
||||
fkColumn string
|
||||
}
|
||||
|
||||
func (r *joinRepository) getIDs(id int) ([]int, error) {
|
||||
func (r *joinRepository) getIDs(ctx context.Context, id int) ([]int, error) {
|
||||
query := fmt.Sprintf(`SELECT %s as id from %s WHERE %s = ?`, r.fkColumn, r.tableName, r.idColumn)
|
||||
return r.runIdsQuery(query, []interface{}{id})
|
||||
return r.runIdsQuery(ctx, query, []interface{}{id})
|
||||
}
|
||||
|
||||
func (r *joinRepository) insert(id, foreignID int) (sql.Result, error) {
|
||||
func (r *joinRepository) insert(ctx context.Context, id, foreignID int) (sql.Result, error) {
|
||||
stmt := fmt.Sprintf("INSERT INTO %s (%s, %s) VALUES (?, ?)", r.tableName, r.idColumn, r.fkColumn)
|
||||
return r.tx.Exec(stmt, id, foreignID)
|
||||
return r.tx.Exec(ctx, stmt, id, foreignID)
|
||||
}
|
||||
|
||||
func (r *joinRepository) replace(id int, foreignIDs []int) error {
|
||||
if err := r.destroy([]int{id}); err != nil {
|
||||
func (r *joinRepository) replace(ctx context.Context, id int, foreignIDs []int) error {
|
||||
if err := r.destroy(ctx, []int{id}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, fk := range foreignIDs {
|
||||
if _, err := r.insert(id, fk); err != nil {
|
||||
if _, err := r.insert(ctx, id, fk); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -347,20 +348,20 @@ type imageRepository struct {
|
||||
imageColumn string
|
||||
}
|
||||
|
||||
func (r *imageRepository) get(id int) ([]byte, error) {
|
||||
func (r *imageRepository) get(ctx context.Context, id int) ([]byte, error) {
|
||||
query := fmt.Sprintf("SELECT %s from %s WHERE %s = ?", r.imageColumn, r.tableName, r.idColumn)
|
||||
var ret []byte
|
||||
err := r.querySimple(query, []interface{}{id}, &ret)
|
||||
err := r.querySimple(ctx, query, []interface{}{id}, &ret)
|
||||
return ret, err
|
||||
}
|
||||
|
||||
func (r *imageRepository) replace(id int, image []byte) error {
|
||||
if err := r.destroy([]int{id}); err != nil {
|
||||
func (r *imageRepository) replace(ctx context.Context, id int, image []byte) error {
|
||||
if err := r.destroy(ctx, []int{id}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stmt := fmt.Sprintf("INSERT INTO %s (%s, %s) VALUES (?, ?)", r.tableName, r.idColumn, r.imageColumn)
|
||||
_, err := r.tx.Exec(stmt, id, image)
|
||||
_, err := r.tx.Exec(ctx, stmt, id, image)
|
||||
|
||||
return err
|
||||
}
|
||||
@@ -369,10 +370,10 @@ type captionRepository struct {
|
||||
repository
|
||||
}
|
||||
|
||||
func (r *captionRepository) get(id int) ([]*models.SceneCaption, error) {
|
||||
func (r *captionRepository) get(ctx context.Context, id int) ([]*models.SceneCaption, error) {
|
||||
query := fmt.Sprintf("SELECT %s, %s, %s from %s WHERE %s = ?", sceneCaptionCodeColumn, sceneCaptionFilenameColumn, sceneCaptionTypeColumn, r.tableName, r.idColumn)
|
||||
var ret []*models.SceneCaption
|
||||
err := r.queryFunc(query, []interface{}{id}, false, func(rows *sqlx.Rows) error {
|
||||
err := r.queryFunc(ctx, query, []interface{}{id}, false, func(rows *sqlx.Rows) error {
|
||||
var captionCode string
|
||||
var captionFilename string
|
||||
var captionType string
|
||||
@@ -392,18 +393,18 @@ func (r *captionRepository) get(id int) ([]*models.SceneCaption, error) {
|
||||
return ret, err
|
||||
}
|
||||
|
||||
func (r *captionRepository) insert(id int, caption *models.SceneCaption) (sql.Result, error) {
|
||||
func (r *captionRepository) insert(ctx context.Context, id int, caption *models.SceneCaption) (sql.Result, error) {
|
||||
stmt := fmt.Sprintf("INSERT INTO %s (%s, %s, %s, %s) VALUES (?, ?, ?, ?)", r.tableName, r.idColumn, sceneCaptionCodeColumn, sceneCaptionFilenameColumn, sceneCaptionTypeColumn)
|
||||
return r.tx.Exec(stmt, id, caption.LanguageCode, caption.Filename, caption.CaptionType)
|
||||
return r.tx.Exec(ctx, stmt, id, caption.LanguageCode, caption.Filename, caption.CaptionType)
|
||||
}
|
||||
|
||||
func (r *captionRepository) replace(id int, captions []*models.SceneCaption) error {
|
||||
if err := r.destroy([]int{id}); err != nil {
|
||||
func (r *captionRepository) replace(ctx context.Context, id int, captions []*models.SceneCaption) error {
|
||||
if err := r.destroy(ctx, []int{id}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, caption := range captions {
|
||||
if _, err := r.insert(id, caption); err != nil {
|
||||
if _, err := r.insert(ctx, id, caption); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -416,10 +417,10 @@ type stringRepository struct {
|
||||
stringColumn string
|
||||
}
|
||||
|
||||
func (r *stringRepository) get(id int) ([]string, error) {
|
||||
func (r *stringRepository) get(ctx context.Context, id int) ([]string, error) {
|
||||
query := fmt.Sprintf("SELECT %s from %s WHERE %s = ?", r.stringColumn, r.tableName, r.idColumn)
|
||||
var ret []string
|
||||
err := r.queryFunc(query, []interface{}{id}, false, func(rows *sqlx.Rows) error {
|
||||
err := r.queryFunc(ctx, query, []interface{}{id}, false, func(rows *sqlx.Rows) error {
|
||||
var out string
|
||||
if err := rows.Scan(&out); err != nil {
|
||||
return err
|
||||
@@ -431,18 +432,18 @@ func (r *stringRepository) get(id int) ([]string, error) {
|
||||
return ret, err
|
||||
}
|
||||
|
||||
func (r *stringRepository) insert(id int, s string) (sql.Result, error) {
|
||||
func (r *stringRepository) insert(ctx context.Context, id int, s string) (sql.Result, error) {
|
||||
stmt := fmt.Sprintf("INSERT INTO %s (%s, %s) VALUES (?, ?)", r.tableName, r.idColumn, r.stringColumn)
|
||||
return r.tx.Exec(stmt, id, s)
|
||||
return r.tx.Exec(ctx, stmt, id, s)
|
||||
}
|
||||
|
||||
func (r *stringRepository) replace(id int, newStrings []string) error {
|
||||
if err := r.destroy([]int{id}); err != nil {
|
||||
func (r *stringRepository) replace(ctx context.Context, id int, newStrings []string) error {
|
||||
if err := r.destroy(ctx, []int{id}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, s := range newStrings {
|
||||
if _, err := r.insert(id, s); err != nil {
|
||||
if _, err := r.insert(ctx, id, s); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -464,21 +465,21 @@ func (s *stashIDs) New() interface{} {
|
||||
return &models.StashID{}
|
||||
}
|
||||
|
||||
func (r *stashIDRepository) get(id int) ([]*models.StashID, error) {
|
||||
func (r *stashIDRepository) get(ctx context.Context, id int) ([]*models.StashID, error) {
|
||||
query := fmt.Sprintf("SELECT stash_id, endpoint from %s WHERE %s = ?", r.tableName, r.idColumn)
|
||||
var ret stashIDs
|
||||
err := r.query(query, []interface{}{id}, &ret)
|
||||
err := r.query(ctx, query, []interface{}{id}, &ret)
|
||||
return []*models.StashID(ret), err
|
||||
}
|
||||
|
||||
func (r *stashIDRepository) replace(id int, newIDs []models.StashID) error {
|
||||
if err := r.destroy([]int{id}); err != nil {
|
||||
func (r *stashIDRepository) replace(ctx context.Context, id int, newIDs []models.StashID) error {
|
||||
if err := r.destroy(ctx, []int{id}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("INSERT INTO %s (%s, endpoint, stash_id) VALUES (?, ?, ?)", r.tableName, r.idColumn)
|
||||
for _, stashID := range newIDs {
|
||||
_, err := r.tx.Exec(query, id, stashID.Endpoint, stashID.StashID)
|
||||
_, err := r.tx.Exec(ctx, query, id, stashID.Endpoint, stashID.StashID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user