diff --git a/graphql/schema/schema.graphql b/graphql/schema/schema.graphql index 3a4f6e738..2491061d4 100644 --- a/graphql/schema/schema.graphql +++ b/graphql/schema/schema.graphql @@ -335,6 +335,12 @@ type Mutation { """Backup the database. Optionally returns a link to download the database file""" backupDatabase(input: BackupDatabaseInput!): String + """DANGEROUS: Execute an arbitrary SQL statement that returns rows.""" + querySQL(sql: String!, args: [Any]): SQLQueryResult! + + """DANGEROUS: Execute an arbitrary SQL statement without returning any rows.""" + execSQL(sql: String!, args: [Any]): SQLExecResult! + """Run batch performer tag task. Returns the job ID.""" stashBoxBatchPerformerTag(input: StashBoxBatchPerformerTagInput!): String! diff --git a/graphql/schema/types/sql.graphql b/graphql/schema/types/sql.graphql new file mode 100644 index 000000000..055c3a21c --- /dev/null +++ b/graphql/schema/types/sql.graphql @@ -0,0 +1,19 @@ +type SQLQueryResult { + """The column names, in the order they appear in the result set.""" + columns: [String!]! + """The returned rows.""" + rows: [[Any]!]! +} + +type SQLExecResult { + """ + The number of rows affected by the query, usually an UPDATE, INSERT, or DELETE. + Not all queries or databases support this feature. + """ + rows_affected: Int64 + """ + The integer generated by the database in response to a command. Typically this will be from an "auto increment" column when inserting a new row. + Not all databases support this feature, and the syntax of such statements varies. + """ + last_insert_id: Int64 +} diff --git a/internal/api/resolver.go b/internal/api/resolver.go index 3932270b7..ff74a4456 100644 --- a/internal/api/resolver.go +++ b/internal/api/resolver.go @@ -216,6 +216,44 @@ func (r *queryResolver) Latestversion(ctx context.Context) (*LatestVersion, erro }, nil } +func (r *mutationResolver) ExecSQL(ctx context.Context, sql string, args []interface{}) (*SQLExecResult, error) { + var rowsAffected *int64 + var lastInsertID *int64 + + db := manager.GetInstance().Database + if err := r.withTxn(ctx, func(ctx context.Context) error { + var err error + rowsAffected, lastInsertID, err = db.ExecSQL(ctx, sql, args) + return err + }); err != nil { + return nil, err + } + + return &SQLExecResult{ + RowsAffected: rowsAffected, + LastInsertID: lastInsertID, + }, nil +} + +func (r *mutationResolver) QuerySQL(ctx context.Context, sql string, args []interface{}) (*SQLQueryResult, error) { + var cols []string + var rows [][]interface{} + + db := manager.GetInstance().Database + if err := r.withTxn(ctx, func(ctx context.Context) error { + var err error + cols, rows, err = db.QuerySQL(ctx, sql, args) + return err + }); err != nil { + return nil, err + } + + return &SQLQueryResult{ + Columns: cols, + Rows: rows, + }, nil +} + // Get scene marker tags which show up under the video. func (r *queryResolver) SceneMarkerTags(ctx context.Context, scene_id string) ([]*SceneMarkerTag, error) { sceneID, err := strconv.Atoi(scene_id) diff --git a/pkg/sqlite/database.go b/pkg/sqlite/database.go index 12562b5c6..e6f6adbeb 100644 --- a/pkg/sqlite/database.go +++ b/pkg/sqlite/database.go @@ -2,6 +2,7 @@ package sqlite import ( "context" + "database/sql" "embed" "errors" "fmt" @@ -458,6 +459,60 @@ func (db *Database) Vacuum(ctx context.Context) error { return err } +func (db *Database) ExecSQL(ctx context.Context, query string, args []interface{}) (*int64, *int64, error) { + wrapper := dbWrapper{} + + result, err := wrapper.Exec(ctx, query, args...) + if err != nil { + return nil, nil, err + } + + var rowsAffected *int64 + ra, err := result.RowsAffected() + if err == nil { + rowsAffected = &ra + } + + var lastInsertId *int64 + li, err := result.LastInsertId() + if err == nil { + lastInsertId = &li + } + + return rowsAffected, lastInsertId, nil +} + +func (db *Database) QuerySQL(ctx context.Context, query string, args []interface{}) ([]string, [][]interface{}, error) { + wrapper := dbWrapper{} + + rows, err := wrapper.QueryxContext(ctx, query, args...) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return nil, nil, err + } + defer rows.Close() + + cols, err := rows.Columns() + if err != nil { + return nil, nil, err + } + + var ret [][]interface{} + + for rows.Next() { + row, err := rows.SliceScan() + if err != nil { + return nil, nil, err + } + ret = append(ret, row) + } + + if err := rows.Err(); err != nil { + return nil, nil, err + } + + return cols, ret, nil +} + func (db *Database) runCustomMigrations(ctx context.Context, fns []customMigrationFunc) error { for _, fn := range fns { if err := db.runCustomMigration(ctx, fn); err != nil {