mirror of
https://github.com/stashapp/stash.git
synced 2025-12-17 12:24:38 +03:00
Data layer restructuring (#997)
* Move query builders to sqlite package * Add transaction system * Wrap model resolvers in transaction * Add error return value for StringSliceToIntSlice * Update/refactor mutation resolvers * Convert query builders * Remove unused join types * Add stash id unit tests * Use WAL journal mode
This commit is contained in:
254
pkg/sqlite/sql.go
Normal file
254
pkg/sqlite/sql.go
Normal file
@@ -0,0 +1,254 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/stashapp/stash/pkg/logger"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
)
|
||||
|
||||
var randomSortFloat = rand.Float64()
|
||||
|
||||
func selectAll(tableName string) string {
|
||||
idColumn := getColumn(tableName, "*")
|
||||
return "SELECT " + idColumn + " FROM " + tableName + " "
|
||||
}
|
||||
|
||||
func selectDistinctIDs(tableName string) string {
|
||||
idColumn := getColumn(tableName, "id")
|
||||
return "SELECT DISTINCT " + idColumn + " FROM " + tableName + " "
|
||||
}
|
||||
|
||||
func getColumn(tableName string, columnName string) string {
|
||||
return tableName + "." + columnName
|
||||
}
|
||||
|
||||
func getPagination(findFilter *models.FindFilterType) string {
|
||||
if findFilter == nil {
|
||||
panic("nil find filter for pagination")
|
||||
}
|
||||
|
||||
var page int
|
||||
if findFilter.Page == nil || *findFilter.Page < 1 {
|
||||
page = 1
|
||||
} else {
|
||||
page = *findFilter.Page
|
||||
}
|
||||
|
||||
var perPage int
|
||||
if findFilter.PerPage == nil {
|
||||
perPage = 25
|
||||
} else {
|
||||
perPage = *findFilter.PerPage
|
||||
}
|
||||
|
||||
if perPage > 1000 {
|
||||
perPage = 1000
|
||||
} else if perPage < 1 {
|
||||
perPage = 1
|
||||
}
|
||||
|
||||
page = (page - 1) * perPage
|
||||
return " LIMIT " + strconv.Itoa(perPage) + " OFFSET " + strconv.Itoa(page) + " "
|
||||
}
|
||||
|
||||
func getSort(sort string, direction string, tableName string) string {
|
||||
if direction != "ASC" && direction != "DESC" {
|
||||
direction = "ASC"
|
||||
}
|
||||
|
||||
const randomSeedPrefix = "random_"
|
||||
|
||||
if strings.HasSuffix(sort, "_count") {
|
||||
var relationTableName = strings.TrimSuffix(sort, "_count") // TODO: pluralize?
|
||||
colName := getColumn(relationTableName, "id")
|
||||
return " ORDER BY COUNT(distinct " + colName + ") " + direction
|
||||
} else if strings.Compare(sort, "filesize") == 0 {
|
||||
colName := getColumn(tableName, "size")
|
||||
return " ORDER BY cast(" + colName + " as integer) " + direction
|
||||
} else if strings.HasPrefix(sort, randomSeedPrefix) {
|
||||
// seed as a parameter from the UI
|
||||
// turn the provided seed into a float
|
||||
seedStr := "0." + sort[len(randomSeedPrefix):]
|
||||
seed, err := strconv.ParseFloat(seedStr, 32)
|
||||
if err != nil {
|
||||
// fallback to default seed
|
||||
seed = randomSortFloat
|
||||
}
|
||||
return getRandomSort(tableName, direction, seed)
|
||||
} else if strings.Compare(sort, "random") == 0 {
|
||||
return getRandomSort(tableName, direction, randomSortFloat)
|
||||
} else {
|
||||
colName := getColumn(tableName, sort)
|
||||
var additional string
|
||||
if tableName == "scenes" {
|
||||
additional = ", bitrate DESC, framerate DESC, scenes.rating DESC, scenes.duration DESC"
|
||||
} else if tableName == "scene_markers" {
|
||||
additional = ", scene_markers.scene_id ASC, scene_markers.seconds ASC"
|
||||
}
|
||||
if strings.Compare(sort, "name") == 0 {
|
||||
return " ORDER BY " + colName + " COLLATE NOCASE " + direction + additional
|
||||
}
|
||||
if strings.Compare(sort, "title") == 0 {
|
||||
return " ORDER BY " + colName + " COLLATE NATURAL_CS " + direction + additional
|
||||
}
|
||||
|
||||
return " ORDER BY " + colName + " " + direction + additional
|
||||
}
|
||||
}
|
||||
|
||||
func getRandomSort(tableName string, direction string, seed float64) string {
|
||||
// https://stackoverflow.com/a/24511461
|
||||
colName := getColumn(tableName, "id")
|
||||
randomSortString := strconv.FormatFloat(seed, 'f', 16, 32)
|
||||
return " ORDER BY " + "(substr(" + colName + " * " + randomSortString + ", length(" + colName + ") + 2))" + " " + direction
|
||||
}
|
||||
|
||||
func getSearchBinding(columns []string, q string, not bool) (string, []interface{}) {
|
||||
var likeClauses []string
|
||||
var args []interface{}
|
||||
|
||||
notStr := ""
|
||||
binaryType := " OR "
|
||||
if not {
|
||||
notStr = " NOT "
|
||||
binaryType = " AND "
|
||||
}
|
||||
|
||||
queryWords := strings.Split(q, " ")
|
||||
trimmedQuery := strings.Trim(q, "\"")
|
||||
if trimmedQuery == q {
|
||||
// Search for any word
|
||||
for _, word := range queryWords {
|
||||
for _, column := range columns {
|
||||
likeClauses = append(likeClauses, column+notStr+" LIKE ?")
|
||||
args = append(args, "%"+word+"%")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Search the exact query
|
||||
for _, column := range columns {
|
||||
likeClauses = append(likeClauses, column+notStr+" LIKE ?")
|
||||
args = append(args, "%"+trimmedQuery+"%")
|
||||
}
|
||||
}
|
||||
likes := strings.Join(likeClauses, binaryType)
|
||||
|
||||
return "(" + likes + ")", args
|
||||
}
|
||||
|
||||
func getInBinding(length int) string {
|
||||
bindings := strings.Repeat("?, ", length)
|
||||
bindings = strings.TrimRight(bindings, ", ")
|
||||
return "(" + bindings + ")"
|
||||
}
|
||||
|
||||
func getCriterionModifierBinding(criterionModifier models.CriterionModifier, value interface{}) (string, int) {
|
||||
var length int
|
||||
switch x := value.(type) {
|
||||
case []string:
|
||||
length = len(x)
|
||||
case []int:
|
||||
length = len(x)
|
||||
default:
|
||||
length = 1
|
||||
}
|
||||
if modifier := criterionModifier.String(); criterionModifier.IsValid() {
|
||||
switch modifier {
|
||||
case "EQUALS", "NOT_EQUALS", "GREATER_THAN", "LESS_THAN", "IS_NULL", "NOT_NULL":
|
||||
return getSimpleCriterionClause(criterionModifier, "?")
|
||||
case "INCLUDES":
|
||||
return "IN " + getInBinding(length), length // TODO?
|
||||
case "EXCLUDES":
|
||||
return "NOT IN " + getInBinding(length), length // TODO?
|
||||
default:
|
||||
logger.Errorf("todo")
|
||||
return "= ?", 1 // TODO
|
||||
}
|
||||
}
|
||||
return "= ?", 1 // TODO
|
||||
}
|
||||
|
||||
func getSimpleCriterionClause(criterionModifier models.CriterionModifier, rhs string) (string, int) {
|
||||
if modifier := criterionModifier.String(); criterionModifier.IsValid() {
|
||||
switch modifier {
|
||||
case "EQUALS":
|
||||
return "= " + rhs, 1
|
||||
case "NOT_EQUALS":
|
||||
return "!= " + rhs, 1
|
||||
case "GREATER_THAN":
|
||||
return "> " + rhs, 1
|
||||
case "LESS_THAN":
|
||||
return "< " + rhs, 1
|
||||
case "IS_NULL":
|
||||
return "IS NULL", 0
|
||||
case "NOT_NULL":
|
||||
return "IS NOT NULL", 0
|
||||
default:
|
||||
logger.Errorf("todo")
|
||||
return "= ?", 1 // TODO
|
||||
}
|
||||
}
|
||||
|
||||
return "= ?", 1 // TODO
|
||||
}
|
||||
|
||||
func getIntCriterionWhereClause(column string, input models.IntCriterionInput) (string, int) {
|
||||
binding, count := getCriterionModifierBinding(input.Modifier, input.Value)
|
||||
return column + " " + binding, count
|
||||
}
|
||||
|
||||
// returns where clause and having clause
|
||||
func getMultiCriterionClause(primaryTable, foreignTable, joinTable, primaryFK, foreignFK string, criterion *models.MultiCriterionInput) (string, string) {
|
||||
whereClause := ""
|
||||
havingClause := ""
|
||||
if criterion.Modifier == models.CriterionModifierIncludes {
|
||||
// includes any of the provided ids
|
||||
whereClause = foreignTable + ".id IN " + getInBinding(len(criterion.Value))
|
||||
} else if criterion.Modifier == models.CriterionModifierIncludesAll {
|
||||
// includes all of the provided ids
|
||||
whereClause = foreignTable + ".id IN " + getInBinding(len(criterion.Value))
|
||||
havingClause = "count(distinct " + foreignTable + ".id) IS " + strconv.Itoa(len(criterion.Value))
|
||||
} else if criterion.Modifier == models.CriterionModifierExcludes {
|
||||
// excludes all of the provided ids
|
||||
if joinTable != "" {
|
||||
whereClause = "not exists (select " + joinTable + "." + primaryFK + " from " + joinTable + " where " + joinTable + "." + primaryFK + " = " + primaryTable + ".id and " + joinTable + "." + foreignFK + " in " + getInBinding(len(criterion.Value)) + ")"
|
||||
} else {
|
||||
whereClause = "not exists (select s.id from " + primaryTable + " as s where s.id = " + primaryTable + ".id and s." + foreignFK + " in " + getInBinding(len(criterion.Value)) + ")"
|
||||
}
|
||||
}
|
||||
|
||||
return whereClause, havingClause
|
||||
}
|
||||
|
||||
func ensureTx(tx *sqlx.Tx) {
|
||||
if tx == nil {
|
||||
panic("must use a transaction")
|
||||
}
|
||||
}
|
||||
|
||||
func getImage(tx dbi, query string, args ...interface{}) ([]byte, error) {
|
||||
rows, err := tx.Queryx(query, args...)
|
||||
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var ret []byte
|
||||
if rows.Next() {
|
||||
if err := rows.Scan(&ret); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
Reference in New Issue
Block a user