Parent studios (#595)

* Refactor getMultiCriterionClause
Co-authored-by: Anon247 <61889302+Anon247@users.noreply.github.com>
This commit is contained in:
WithoutPants
2020-06-15 21:34:39 +10:00
committed by GitHub
parent a77fea5724
commit 96e6e16507
37 changed files with 818 additions and 146 deletions

View File

@@ -3,6 +3,22 @@ fragment StudioData on Studio {
checksum checksum
name name
url url
parent_studio {
id
checksum
name
url
image_path
scene_count
}
child_studios {
id
checksum
name
url
image_path
scene_count
}
image_path image_path
scene_count scene_count
} }

View File

@@ -1,9 +1,10 @@
mutation StudioCreate( mutation StudioCreate(
$name: String!, $name: String!,
$url: String, $url: String,
$image: String) { $image: String
$parent_id: ID) {
studioCreate(input: { name: $name, url: $url, image: $image }) { studioCreate(input: { name: $name, url: $url, image: $image, parent_id: $parent_id }) {
...StudioData ...StudioData
} }
} }
@@ -12,9 +13,10 @@ mutation StudioUpdate(
$id: ID! $id: ID!
$name: String, $name: String,
$url: String, $url: String,
$image: String) { $image: String
$parent_id: ID) {
studioUpdate(input: { id: $id, name: $name, url: $url, image: $image }) { studioUpdate(input: { id: $id, name: $name, url: $url, image: $image, parent_id: $parent_id }) {
...StudioData ...StudioData
} }
} }

View File

@@ -1,5 +1,5 @@
query FindStudios($filter: FindFilterType) { query FindStudios($filter: FindFilterType, $studio_filter: StudioFilterType ) {
findStudios(filter: $filter) { findStudios(filter: $filter, studio_filter: $studio_filter) {
count count
studios { studios {
...StudioData ...StudioData

View File

@@ -20,7 +20,7 @@ type Query {
"""Find a studio by ID""" """Find a studio by ID"""
findStudio(id: ID!): Studio findStudio(id: ID!): Studio
"""A function which queries Studio objects""" """A function which queries Studio objects"""
findStudios(filter: FindFilterType): FindStudiosResultType! findStudios(studio_filter: StudioFilterType, filter: FindFilterType): FindStudiosResultType!
"""Find a movie by ID""" """Find a movie by ID"""
findMovie(id: ID!): Movie findMovie(id: ID!): Movie

View File

@@ -91,6 +91,11 @@ input MovieFilterType {
studios: MultiCriterionInput studios: MultiCriterionInput
} }
input StudioFilterType {
"""Filter to only include studios with this parent studio"""
parents: MultiCriterionInput
}
enum CriterionModifier { enum CriterionModifier {
"""=""" """="""
EQUALS, EQUALS,

View File

@@ -3,7 +3,8 @@ type Studio {
checksum: String! checksum: String!
name: String! name: String!
url: String url: String
parent_studio: Studio
child_studios: [Studio!]!
image_path: String # Resolver image_path: String # Resolver
scene_count: Int # Resolver scene_count: Int # Resolver
} }
@@ -11,6 +12,7 @@ type Studio {
input StudioCreateInput { input StudioCreateInput {
name: String! name: String!
url: String url: String
parent_id: ID
"""This should be base64 encoded""" """This should be base64 encoded"""
image: String image: String
} }
@@ -19,6 +21,7 @@ input StudioUpdateInput {
id: ID! id: ID!
name: String name: String
url: String url: String
parent_id: ID,
"""This should be base64 encoded""" """This should be base64 encoded"""
image: String image: String
} }

View File

@@ -2,6 +2,7 @@ package api
import ( import (
"context" "context"
"github.com/stashapp/stash/pkg/api/urlbuilders" "github.com/stashapp/stash/pkg/api/urlbuilders"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
) )
@@ -31,3 +32,17 @@ func (r *studioResolver) SceneCount(ctx context.Context, obj *models.Studio) (*i
res, err := qb.CountByStudioID(obj.ID) res, err := qb.CountByStudioID(obj.ID)
return &res, err return &res, err
} }
func (r *studioResolver) ParentStudio(ctx context.Context, obj *models.Studio) (*models.Studio, error) {
if !obj.ParentID.Valid {
return nil, nil
}
qb := models.NewStudioQueryBuilder()
return qb.Find(int(obj.ParentID.Int64), nil)
}
func (r *studioResolver) ChildStudios(ctx context.Context, obj *models.Studio) ([]*models.Studio, error) {
qb := models.NewStudioQueryBuilder()
return qb.FindChildren(obj.ID, nil)
}

View File

@@ -7,6 +7,7 @@ import (
"time" "time"
"github.com/stashapp/stash/pkg/database" "github.com/stashapp/stash/pkg/database"
"github.com/stashapp/stash/pkg/manager"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
@@ -40,6 +41,10 @@ func (r *mutationResolver) StudioCreate(ctx context.Context, input models.Studio
if input.URL != nil { if input.URL != nil {
newStudio.URL = sql.NullString{String: *input.URL, Valid: true} newStudio.URL = sql.NullString{String: *input.URL, Valid: true}
} }
if input.ParentID != nil {
parentID, _ := strconv.ParseInt(*input.ParentID, 10, 64)
newStudio.ParentID = sql.NullInt64{Int64: parentID, Valid: true}
}
// Start the transaction and save the studio // Start the transaction and save the studio
tx := database.DB.MustBeginTx(ctx, nil) tx := database.DB.MustBeginTx(ctx, nil)
@@ -61,33 +66,48 @@ func (r *mutationResolver) StudioCreate(ctx context.Context, input models.Studio
func (r *mutationResolver) StudioUpdate(ctx context.Context, input models.StudioUpdateInput) (*models.Studio, error) { func (r *mutationResolver) StudioUpdate(ctx context.Context, input models.StudioUpdateInput) (*models.Studio, error) {
// Populate studio from the input // Populate studio from the input
studioID, _ := strconv.Atoi(input.ID) studioID, _ := strconv.Atoi(input.ID)
updatedStudio := models.Studio{
updatedStudio := models.StudioPartial{
ID: studioID, ID: studioID,
UpdatedAt: models.SQLiteTimestamp{Timestamp: time.Now()}, UpdatedAt: &models.SQLiteTimestamp{Timestamp: time.Now()},
} }
if input.Image != nil { if input.Image != nil {
_, imageData, err := utils.ProcessBase64Image(*input.Image) _, imageData, err := utils.ProcessBase64Image(*input.Image)
if err != nil { if err != nil {
return nil, err return nil, err
} }
updatedStudio.Image = imageData updatedStudio.Image = &imageData
} }
if input.Name != nil { if input.Name != nil {
// generate checksum from studio name rather than image // generate checksum from studio name rather than image
checksum := utils.MD5FromString(*input.Name) checksum := utils.MD5FromString(*input.Name)
updatedStudio.Name = sql.NullString{String: *input.Name, Valid: true} updatedStudio.Name = &sql.NullString{String: *input.Name, Valid: true}
updatedStudio.Checksum = checksum updatedStudio.Checksum = &checksum
} }
if input.URL != nil { if input.URL != nil {
updatedStudio.URL = sql.NullString{String: *input.URL, Valid: true} updatedStudio.URL = &sql.NullString{String: *input.URL, Valid: true}
}
if input.ParentID != nil {
parentID, _ := strconv.ParseInt(*input.ParentID, 10, 64)
updatedStudio.ParentID = &sql.NullInt64{Int64: parentID, Valid: true}
} else {
// parent studio must be nullable
updatedStudio.ParentID = &sql.NullInt64{Valid: false}
} }
// Start the transaction and save the studio // Start the transaction and save the studio
tx := database.DB.MustBeginTx(ctx, nil) tx := database.DB.MustBeginTx(ctx, nil)
qb := models.NewStudioQueryBuilder() qb := models.NewStudioQueryBuilder()
if err := manager.ValidateModifyStudio(updatedStudio, tx); err != nil {
tx.Rollback()
return nil, err
}
studio, err := qb.Update(updatedStudio, tx) studio, err := qb.Update(updatedStudio, tx)
if err != nil { if err != nil {
_ = tx.Rollback() tx.Rollback()
return nil, err return nil, err
} }

View File

@@ -12,9 +12,9 @@ func (r *queryResolver) FindStudio(ctx context.Context, id string) (*models.Stud
return qb.Find(idInt, nil) return qb.Find(idInt, nil)
} }
func (r *queryResolver) FindStudios(ctx context.Context, filter *models.FindFilterType) (*models.FindStudiosResultType, error) { func (r *queryResolver) FindStudios(ctx context.Context, studioFilter *models.StudioFilterType, filter *models.FindFilterType) (*models.FindStudiosResultType, error) {
qb := models.NewStudioQueryBuilder() qb := models.NewStudioQueryBuilder()
studios, total := qb.Query(filter) studios, total := qb.Query(studioFilter, filter)
return &models.FindStudiosResultType{ return &models.FindStudiosResultType{
Count: total, Count: total,
Studios: studios, Studios: studios,

View File

@@ -19,7 +19,7 @@ import (
var DB *sqlx.DB var DB *sqlx.DB
var dbPath string var dbPath string
var appSchemaVersion uint = 8 var appSchemaVersion uint = 9
var databaseSchemaVersion uint var databaseSchemaVersion uint
const sqlite3Driver = "sqlite3ex" const sqlite3Driver = "sqlite3ex"

View File

@@ -0,0 +1,3 @@
ALTER TABLE studios
ADD COLUMN parent_id INTEGER DEFAULT NULL CHECK ( id IS NOT parent_id ) REFERENCES studios(id) on delete set null;
CREATE INDEX index_studios_on_parent_id on studios (parent_id);

View File

@@ -2,17 +2,19 @@ package jsonschema
import ( import (
"fmt" "fmt"
"github.com/json-iterator/go"
"github.com/stashapp/stash/pkg/models"
"os" "os"
jsoniter "github.com/json-iterator/go"
"github.com/stashapp/stash/pkg/models"
) )
type Studio struct { type Studio struct {
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
URL string `json:"url,omitempty"` URL string `json:"url,omitempty"`
Image string `json:"image,omitempty"` ParentStudio string `json:"parent_studio,omitempty"`
CreatedAt models.JSONTime `json:"created_at,omitempty"` Image string `json:"image,omitempty"`
UpdatedAt models.JSONTime `json:"updated_at,omitempty"` CreatedAt models.JSONTime `json:"created_at,omitempty"`
UpdatedAt models.JSONTime `json:"updated_at,omitempty"`
} }
func LoadStudioFile(filePath string) (*Studio, error) { func LoadStudioFile(filePath string) (*Studio, error) {

36
pkg/manager/studio.go Normal file
View File

@@ -0,0 +1,36 @@
package manager
import (
"errors"
"fmt"
"github.com/jmoiron/sqlx"
"github.com/stashapp/stash/pkg/models"
)
func ValidateModifyStudio(studio models.StudioPartial, tx *sqlx.Tx) error {
if studio.ParentID == nil || !studio.ParentID.Valid {
return nil
}
// ensure there is no cyclic dependency
thisID := studio.ID
qb := models.NewStudioQueryBuilder()
currentParentID := *studio.ParentID
for currentParentID.Valid {
if currentParentID.Int64 == int64(thisID) {
return errors.New("studio cannot be an ancestor of itself")
}
currentStudio, err := qb.Find(int(currentParentID.Int64), tx)
if err != nil {
return fmt.Errorf("error finding parent studio: %s", err.Error())
}
currentParentID = currentStudio.ParentID
}
return nil
}

View File

@@ -3,6 +3,12 @@ package manager
import ( import (
"context" "context"
"fmt" "fmt"
"math"
"runtime"
"strconv"
"sync"
"time"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/stashapp/stash/pkg/database" "github.com/stashapp/stash/pkg/database"
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
@@ -10,11 +16,6 @@ import (
"github.com/stashapp/stash/pkg/manager/paths" "github.com/stashapp/stash/pkg/manager/paths"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
"math"
"runtime"
"strconv"
"sync"
"time"
) )
type ExportTask struct { type ExportTask struct {
@@ -395,6 +396,8 @@ func (t *ExportTask) ExportStudios(ctx context.Context, workers int) {
func exportStudio(wg *sync.WaitGroup, jobChan <-chan *models.Studio) { func exportStudio(wg *sync.WaitGroup, jobChan <-chan *models.Studio) {
defer wg.Done() defer wg.Done()
studioQB := models.NewStudioQueryBuilder()
for studio := range jobChan { for studio := range jobChan {
newStudioJSON := jsonschema.Studio{ newStudioJSON := jsonschema.Studio{
@@ -408,6 +411,12 @@ func exportStudio(wg *sync.WaitGroup, jobChan <-chan *models.Studio) {
if studio.URL.Valid { if studio.URL.Valid {
newStudioJSON.URL = studio.URL.String newStudioJSON.URL = studio.URL.String
} }
if studio.ParentID.Valid {
parent, _ := studioQB.Find(int(studio.ParentID.Int64), nil)
if parent != nil {
newStudioJSON.ParentStudio = parent.Name.String
}
}
newStudioJSON.Image = utils.GetBase64StringFromData(studio.Image) newStudioJSON.Image = utils.GetBase64StringFromData(studio.Image)

View File

@@ -3,6 +3,7 @@ package manager
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"strconv" "strconv"
"sync" "sync"
"time" "time"
@@ -157,7 +158,8 @@ func (t *ImportTask) ImportPerformers(ctx context.Context) {
func (t *ImportTask) ImportStudios(ctx context.Context) { func (t *ImportTask) ImportStudios(ctx context.Context) {
tx := database.DB.MustBeginTx(ctx, nil) tx := database.DB.MustBeginTx(ctx, nil)
qb := models.NewStudioQueryBuilder()
pendingParent := make(map[string][]*jsonschema.Studio)
for i, mappingJSON := range t.Mappings.Studios { for i, mappingJSON := range t.Mappings.Studios {
index := i + 1 index := i + 1
@@ -172,35 +174,28 @@ func (t *ImportTask) ImportStudios(ctx context.Context) {
logger.Progressf("[studios] %d of %d", index, len(t.Mappings.Studios)) logger.Progressf("[studios] %d of %d", index, len(t.Mappings.Studios))
// generate checksum from studio name rather than image if err := t.ImportStudio(studioJSON, pendingParent, tx); err != nil {
checksum := utils.MD5FromString(studioJSON.Name) tx.Rollback()
// Process the base 64 encoded image string
_, imageData, err := utils.ProcessBase64Image(studioJSON.Image)
if err != nil {
_ = tx.Rollback()
logger.Errorf("[studios] <%s> invalid image: %s", mappingJSON.Checksum, err.Error())
return
}
// Populate a new studio from the input
newStudio := models.Studio{
Image: imageData,
Checksum: checksum,
Name: sql.NullString{String: studioJSON.Name, Valid: true},
URL: sql.NullString{String: studioJSON.URL, Valid: true},
CreatedAt: models.SQLiteTimestamp{Timestamp: t.getTimeFromJSONTime(studioJSON.CreatedAt)},
UpdatedAt: models.SQLiteTimestamp{Timestamp: t.getTimeFromJSONTime(studioJSON.UpdatedAt)},
}
_, err = qb.Create(newStudio, tx)
if err != nil {
_ = tx.Rollback()
logger.Errorf("[studios] <%s> failed to create: %s", mappingJSON.Checksum, err.Error()) logger.Errorf("[studios] <%s> failed to create: %s", mappingJSON.Checksum, err.Error())
return return
} }
} }
// create the leftover studios, warning for missing parents
if len(pendingParent) > 0 {
logger.Warnf("[studios] importing studios with missing parents")
for _, s := range pendingParent {
for _, orphanStudioJSON := range s {
if err := t.ImportStudio(orphanStudioJSON, nil, tx); err != nil {
tx.Rollback()
logger.Errorf("[studios] <%s> failed to create: %s", orphanStudioJSON.Name, err.Error())
return
}
}
}
}
logger.Info("[studios] importing") logger.Info("[studios] importing")
if err := tx.Commit(); err != nil { if err := tx.Commit(); err != nil {
logger.Errorf("[studios] import failed to commit: %s", err.Error()) logger.Errorf("[studios] import failed to commit: %s", err.Error())
@@ -208,6 +203,74 @@ func (t *ImportTask) ImportStudios(ctx context.Context) {
logger.Info("[studios] import complete") logger.Info("[studios] import complete")
} }
func (t *ImportTask) ImportStudio(studioJSON *jsonschema.Studio, pendingParent map[string][]*jsonschema.Studio, tx *sqlx.Tx) error {
qb := models.NewStudioQueryBuilder()
// generate checksum from studio name rather than image
checksum := utils.MD5FromString(studioJSON.Name)
// Process the base 64 encoded image string
_, imageData, err := utils.ProcessBase64Image(studioJSON.Image)
if err != nil {
return fmt.Errorf("invalid image: %s", err.Error())
}
// Populate a new studio from the input
newStudio := models.Studio{
Image: imageData,
Checksum: checksum,
Name: sql.NullString{String: studioJSON.Name, Valid: true},
URL: sql.NullString{String: studioJSON.URL, Valid: true},
CreatedAt: models.SQLiteTimestamp{Timestamp: t.getTimeFromJSONTime(studioJSON.CreatedAt)},
UpdatedAt: models.SQLiteTimestamp{Timestamp: t.getTimeFromJSONTime(studioJSON.UpdatedAt)},
}
// Populate the parent ID
if studioJSON.ParentStudio != "" {
studio, err := qb.FindByName(studioJSON.ParentStudio, tx, false)
if err != nil {
return fmt.Errorf("error finding studio by name <%s>: %s", studioJSON.ParentStudio, err.Error())
}
if studio == nil {
// its possible that the parent hasn't been created yet
// do it after it is created
if pendingParent == nil {
logger.Warnf("[studios] studio <%s> does not exist", studioJSON.ParentStudio)
} else {
// add to the pending parent list so that it is created after the parent
s := pendingParent[studioJSON.ParentStudio]
s = append(s, studioJSON)
pendingParent[studioJSON.ParentStudio] = s
// skip
return nil
}
} else {
newStudio.ParentID = sql.NullInt64{Int64: int64(studio.ID), Valid: true}
}
}
_, err = qb.Create(newStudio, tx)
if err != nil {
return err
}
// now create the studios pending this studios creation
s := pendingParent[studioJSON.Name]
for _, childStudioJSON := range s {
// map is nil since we're not checking parent studios at this point
if err := t.ImportStudio(childStudioJSON, nil, tx); err != nil {
return fmt.Errorf("failed to create child studio <%s>: %s", childStudioJSON.Name, err.Error())
}
}
// delete the entry from the map so that we know its not left over
delete(pendingParent, studioJSON.Name)
return err
}
func (t *ImportTask) ImportMovies(ctx context.Context) { func (t *ImportTask) ImportMovies(ctx context.Context) {
tx := database.DB.MustBeginTx(ctx, nil) tx := database.DB.MustBeginTx(ctx, nil)
qb := models.NewMovieQueryBuilder() qb := models.NewMovieQueryBuilder()

View File

@@ -10,8 +10,20 @@ type Studio struct {
Checksum string `db:"checksum" json:"checksum"` Checksum string `db:"checksum" json:"checksum"`
Name sql.NullString `db:"name" json:"name"` Name sql.NullString `db:"name" json:"name"`
URL sql.NullString `db:"url" json:"url"` URL sql.NullString `db:"url" json:"url"`
ParentID sql.NullInt64 `db:"parent_id,omitempty" json:"parent_id"`
CreatedAt SQLiteTimestamp `db:"created_at" json:"created_at"` CreatedAt SQLiteTimestamp `db:"created_at" json:"created_at"`
UpdatedAt SQLiteTimestamp `db:"updated_at" json:"updated_at"` UpdatedAt SQLiteTimestamp `db:"updated_at" json:"updated_at"`
} }
type StudioPartial struct {
ID int `db:"id" json:"id"`
Image *[]byte `db:"image" json:"image"`
Checksum *string `db:"checksum" json:"checksum"`
Name *sql.NullString `db:"name" json:"name"`
URL *sql.NullString `db:"url" json:"url"`
ParentID *sql.NullInt64 `db:"parent_id,omitempty" json:"parent_id"`
CreatedAt *SQLiteTimestamp `db:"created_at" json:"created_at"`
UpdatedAt *SQLiteTimestamp `db:"updated_at" json:"updated_at"`
}
var DefaultStudioImage = "" var DefaultStudioImage = ""

View File

@@ -2,7 +2,6 @@ package models
import ( import (
"database/sql" "database/sql"
"strconv"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/stashapp/stash/pkg/database" "github.com/stashapp/stash/pkg/database"
@@ -148,7 +147,7 @@ func (qb *MovieQueryBuilder) Query(movieFilter *MovieFilterType, findFilter *Fin
args = append(args, studioID) args = append(args, studioID)
} }
whereClause, havingClause := qb.getMultiCriterionClause("studio", "", "studio_id", studiosFilter) whereClause, havingClause := getMultiCriterionClause("movies", "studio", "", "", "studio_id", studiosFilter)
whereClauses = appendClause(whereClauses, whereClause) whereClauses = appendClause(whereClauses, whereClause)
havingClauses = appendClause(havingClauses, havingClause) havingClauses = appendClause(havingClauses, havingClause)
} }
@@ -165,29 +164,6 @@ func (qb *MovieQueryBuilder) Query(movieFilter *MovieFilterType, findFilter *Fin
return movies, countResult return movies, countResult
} }
// returns where clause and having clause
func (qb *MovieQueryBuilder) getMultiCriterionClause(table string, joinTable string, joinTableField string, criterion *MultiCriterionInput) (string, string) {
whereClause := ""
havingClause := ""
if criterion.Modifier == CriterionModifierIncludes {
// includes any of the provided ids
whereClause = table + ".id IN " + getInBinding(len(criterion.Value))
} else if criterion.Modifier == CriterionModifierIncludesAll {
// includes all of the provided ids
whereClause = table + ".id IN " + getInBinding(len(criterion.Value))
havingClause = "count(distinct " + table + ".id) IS " + strconv.Itoa(len(criterion.Value))
} else if criterion.Modifier == CriterionModifierExcludes {
// excludes all of the provided ids
if joinTable != "" {
whereClause = "not exists (select " + joinTable + ".movie_id from " + joinTable + " where " + joinTable + ".movie_id = movies.id and " + joinTable + "." + joinTableField + " in " + getInBinding(len(criterion.Value)) + ")"
} else {
whereClause = "not exists (select m.id from movies as m where m.id = movies.id and m." + joinTableField + " in " + getInBinding(len(criterion.Value)) + ")"
}
}
return whereClause, havingClause
}
func (qb *MovieQueryBuilder) getMovieSort(findFilter *FindFilterType) string { func (qb *MovieQueryBuilder) getMovieSort(findFilter *FindFilterType) string {
var sort string var sort string
var direction string var direction string

View File

@@ -3,6 +3,7 @@
package models_test package models_test
import ( import (
"strconv"
"strings" "strings"
"testing" "testing"
@@ -86,6 +87,42 @@ func TestMovieFindByNames(t *testing.T) {
assert.Equal(t, strings.ToLower(movieNames[movieIdxWithScene]), strings.ToLower(movies[1].Name.String)) assert.Equal(t, strings.ToLower(movieNames[movieIdxWithScene]), strings.ToLower(movies[1].Name.String))
} }
func TestMovieQueryStudio(t *testing.T) {
mqb := models.NewMovieQueryBuilder()
studioCriterion := models.MultiCriterionInput{
Value: []string{
strconv.Itoa(studioIDs[studioIdxWithMovie]),
},
Modifier: models.CriterionModifierIncludes,
}
movieFilter := models.MovieFilterType{
Studios: &studioCriterion,
}
movies, _ := mqb.Query(&movieFilter, nil)
assert.Len(t, movies, 1)
// ensure id is correct
assert.Equal(t, movieIDs[movieIdxWithStudio], movies[0].ID)
studioCriterion = models.MultiCriterionInput{
Value: []string{
strconv.Itoa(studioIDs[studioIdxWithMovie]),
},
Modifier: models.CriterionModifierExcludes,
}
q := getMovieStringValue(movieIdxWithStudio, titleField)
findFilter := models.FindFilterType{
Q: &q,
}
movies, _ = mqb.Query(&movieFilter, &findFilter)
assert.Len(t, movies, 0)
}
// TODO Update // TODO Update
// TODO Destroy // TODO Destroy
// TODO Find // TODO Find

View File

@@ -2,7 +2,6 @@ package models
import ( import (
"database/sql" "database/sql"
"strconv"
"strings" "strings"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
@@ -329,7 +328,7 @@ func (qb *SceneQueryBuilder) Query(sceneFilter *SceneFilterType, findFilter *Fin
} }
query.body += " LEFT JOIN tags on tags_join.tag_id = tags.id" query.body += " LEFT JOIN tags on tags_join.tag_id = tags.id"
whereClause, havingClause := getMultiCriterionClause("tags", "scenes_tags", "tag_id", tagsFilter) whereClause, havingClause := getMultiCriterionClause("scenes", "tags", "scenes_tags", "scene_id", "tag_id", tagsFilter)
query.addWhere(whereClause) query.addWhere(whereClause)
query.addHaving(havingClause) query.addHaving(havingClause)
} }
@@ -340,7 +339,7 @@ func (qb *SceneQueryBuilder) Query(sceneFilter *SceneFilterType, findFilter *Fin
} }
query.body += " LEFT JOIN performers ON performers_join.performer_id = performers.id" query.body += " LEFT JOIN performers ON performers_join.performer_id = performers.id"
whereClause, havingClause := getMultiCriterionClause("performers", "performers_scenes", "performer_id", performersFilter) whereClause, havingClause := getMultiCriterionClause("scenes", "performers", "performers_scenes", "scene_id", "performer_id", performersFilter)
query.addWhere(whereClause) query.addWhere(whereClause)
query.addHaving(havingClause) query.addHaving(havingClause)
} }
@@ -350,7 +349,7 @@ func (qb *SceneQueryBuilder) Query(sceneFilter *SceneFilterType, findFilter *Fin
query.addArg(studioID) query.addArg(studioID)
} }
whereClause, havingClause := getMultiCriterionClause("studio", "", "studio_id", studiosFilter) whereClause, havingClause := getMultiCriterionClause("scenes", "studio", "", "", "studio_id", studiosFilter)
query.addWhere(whereClause) query.addWhere(whereClause)
query.addHaving(havingClause) query.addHaving(havingClause)
} }
@@ -361,7 +360,7 @@ func (qb *SceneQueryBuilder) Query(sceneFilter *SceneFilterType, findFilter *Fin
} }
query.body += " LEFT JOIN movies ON movies_join.movie_id = movies.id" query.body += " LEFT JOIN movies ON movies_join.movie_id = movies.id"
whereClause, havingClause := getMultiCriterionClause("movies", "movies_scenes", "movie_id", moviesFilter) whereClause, havingClause := getMultiCriterionClause("scenes", "movies", "movies_scenes", "scene_id", "movie_id", moviesFilter)
query.addWhere(whereClause) query.addWhere(whereClause)
query.addHaving(havingClause) query.addHaving(havingClause)
} }
@@ -414,29 +413,6 @@ func getDurationWhereClause(durationFilter IntCriterionInput) (string, []interfa
return clause, args return clause, args
} }
// returns where clause and having clause
func getMultiCriterionClause(table string, joinTable string, joinTableField string, criterion *MultiCriterionInput) (string, string) {
whereClause := ""
havingClause := ""
if criterion.Modifier == CriterionModifierIncludes {
// includes any of the provided ids
whereClause = table + ".id IN " + getInBinding(len(criterion.Value))
} else if criterion.Modifier == CriterionModifierIncludesAll {
// includes all of the provided ids
whereClause = table + ".id IN " + getInBinding(len(criterion.Value))
havingClause = "count(distinct " + table + ".id) IS " + strconv.Itoa(len(criterion.Value))
} else if criterion.Modifier == CriterionModifierExcludes {
// excludes all of the provided ids
if joinTable != "" {
whereClause = "not exists (select " + joinTable + ".scene_id from " + joinTable + " where " + joinTable + ".scene_id = scenes.id and " + joinTable + "." + joinTableField + " in " + getInBinding(len(criterion.Value)) + ")"
} else {
whereClause = "not exists (select s.id from scenes as s where s.id = scenes.id and s." + joinTableField + " in " + getInBinding(len(criterion.Value)) + ")"
}
}
return whereClause, havingClause
}
func (qb *SceneQueryBuilder) QueryAllByPathRegex(regex string) ([]*Scene, error) { func (qb *SceneQueryBuilder) QueryAllByPathRegex(regex string) ([]*Scene, error) {
var args []interface{} var args []interface{}
body := selectDistinctIDs("scenes") + " WHERE scenes.path regexp ?" body := selectDistinctIDs("scenes") + " WHERE scenes.path regexp ?"

View File

@@ -676,6 +676,42 @@ func TestSceneQueryStudio(t *testing.T) {
assert.Len(t, scenes, 0) assert.Len(t, scenes, 0)
} }
func TestSceneQueryMovies(t *testing.T) {
sqb := models.NewSceneQueryBuilder()
movieCriterion := models.MultiCriterionInput{
Value: []string{
strconv.Itoa(movieIDs[movieIdxWithScene]),
},
Modifier: models.CriterionModifierIncludes,
}
sceneFilter := models.SceneFilterType{
Movies: &movieCriterion,
}
scenes, _ := sqb.Query(&sceneFilter, nil)
assert.Len(t, scenes, 1)
// ensure id is correct
assert.Equal(t, sceneIDs[sceneIdxWithMovie], scenes[0].ID)
movieCriterion = models.MultiCriterionInput{
Value: []string{
strconv.Itoa(movieIDs[movieIdxWithScene]),
},
Modifier: models.CriterionModifierExcludes,
}
q := getSceneStringValue(sceneIdxWithMovie, titleField)
findFilter := models.FindFilterType{
Q: &q,
}
scenes, _ = sqb.Query(&sceneFilter, &findFilter)
assert.Len(t, scenes, 0)
}
func TestSceneQuerySorting(t *testing.T) { func TestSceneQuerySorting(t *testing.T) {
sort := titleField sort := titleField
direction := models.SortDirectionEnumAsc direction := models.SortDirectionEnumAsc

View File

@@ -239,6 +239,29 @@ func getIntCriterionWhereClause(column string, input IntCriterionInput) (string,
return column + " " + binding, count return column + " " + binding, count
} }
// returns where clause and having clause
func getMultiCriterionClause(primaryTable, foreignTable, joinTable, primaryFK, foreignFK string, criterion *MultiCriterionInput) (string, string) {
whereClause := ""
havingClause := ""
if criterion.Modifier == CriterionModifierIncludes {
// includes any of the provided ids
whereClause = foreignTable + ".id IN " + getInBinding(len(criterion.Value))
} else if criterion.Modifier == CriterionModifierIncludesAll {
// includes all of the provided ids
whereClause = foreignTable + ".id IN " + getInBinding(len(criterion.Value))
havingClause = "count(distinct " + foreignTable + ".id) IS " + strconv.Itoa(len(criterion.Value))
} else if criterion.Modifier == CriterionModifierExcludes {
// excludes all of the provided ids
if joinTable != "" {
whereClause = "not exists (select " + joinTable + "." + primaryFK + " from " + joinTable + " where " + joinTable + "." + primaryFK + " = " + primaryTable + ".id and " + joinTable + "." + foreignFK + " in " + getInBinding(len(criterion.Value)) + ")"
} else {
whereClause = "not exists (select s.id from " + primaryTable + " as s where s.id = " + primaryTable + ".id and s." + foreignFK + " in " + getInBinding(len(criterion.Value)) + ")"
}
}
return whereClause, havingClause
}
func runIdsQuery(query string, args []interface{}) ([]int, error) { func runIdsQuery(query string, args []interface{}) ([]int, error) {
var result []struct { var result []struct {
Int int `db:"id"` Int int `db:"id"`

View File

@@ -16,8 +16,8 @@ func NewStudioQueryBuilder() StudioQueryBuilder {
func (qb *StudioQueryBuilder) Create(newStudio Studio, tx *sqlx.Tx) (*Studio, error) { func (qb *StudioQueryBuilder) Create(newStudio Studio, tx *sqlx.Tx) (*Studio, error) {
ensureTx(tx) ensureTx(tx)
result, err := tx.NamedExec( result, err := tx.NamedExec(
`INSERT INTO studios (image, checksum, name, url, created_at, updated_at) `INSERT INTO studios (image, checksum, name, url, parent_id, created_at, updated_at)
VALUES (:image, :checksum, :name, :url, :created_at, :updated_at) VALUES (:image, :checksum, :name, :url, :parent_id, :created_at, :updated_at)
`, `,
newStudio, newStudio,
) )
@@ -35,20 +35,21 @@ func (qb *StudioQueryBuilder) Create(newStudio Studio, tx *sqlx.Tx) (*Studio, er
return &newStudio, nil return &newStudio, nil
} }
func (qb *StudioQueryBuilder) Update(updatedStudio Studio, tx *sqlx.Tx) (*Studio, error) { func (qb *StudioQueryBuilder) Update(updatedStudio StudioPartial, tx *sqlx.Tx) (*Studio, error) {
ensureTx(tx) ensureTx(tx)
_, err := tx.NamedExec( _, err := tx.NamedExec(
`UPDATE studios SET `+SQLGenKeys(updatedStudio)+` WHERE studios.id = :id`, `UPDATE studios SET `+SQLGenKeysPartial(updatedStudio)+` WHERE studios.id = :id`,
updatedStudio, updatedStudio,
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := tx.Get(&updatedStudio, `SELECT * FROM studios WHERE id = ? LIMIT 1`, updatedStudio.ID); err != nil { var ret Studio
if err := tx.Get(&ret, `SELECT * FROM studios WHERE id = ? LIMIT 1`, updatedStudio.ID); err != nil {
return nil, err return nil, err
} }
return &updatedStudio, nil return &ret, nil
} }
func (qb *StudioQueryBuilder) Destroy(id string, tx *sqlx.Tx) error { func (qb *StudioQueryBuilder) Destroy(id string, tx *sqlx.Tx) error {
@@ -73,6 +74,12 @@ func (qb *StudioQueryBuilder) Find(id int, tx *sqlx.Tx) (*Studio, error) {
return qb.queryStudio(query, args, tx) return qb.queryStudio(query, args, tx)
} }
func (qb *StudioQueryBuilder) FindChildren(id int, tx *sqlx.Tx) ([]*Studio, error) {
query := "SELECT studios.* FROM studios WHERE studios.parent_id = ?"
args := []interface{}{id}
return qb.queryStudios(query, args, tx)
}
func (qb *StudioQueryBuilder) FindBySceneID(sceneID int) (*Studio, error) { func (qb *StudioQueryBuilder) FindBySceneID(sceneID int) (*Studio, error) {
query := "SELECT studios.* FROM studios JOIN scenes ON studios.id = scenes.studio_id WHERE scenes.id = ? LIMIT 1" query := "SELECT studios.* FROM studios JOIN scenes ON studios.id = scenes.studio_id WHERE scenes.id = ? LIMIT 1"
args := []interface{}{sceneID} args := []interface{}{sceneID}
@@ -101,7 +108,10 @@ func (qb *StudioQueryBuilder) AllSlim() ([]*Studio, error) {
return qb.queryStudios("SELECT studios.id, studios.name FROM studios "+qb.getStudioSort(nil), nil, nil) return qb.queryStudios("SELECT studios.id, studios.name FROM studios "+qb.getStudioSort(nil), nil, nil)
} }
func (qb *StudioQueryBuilder) Query(findFilter *FindFilterType) ([]*Studio, int) { func (qb *StudioQueryBuilder) Query(studioFilter *StudioFilterType, findFilter *FindFilterType) ([]*Studio, int) {
if studioFilter == nil {
studioFilter = &StudioFilterType{}
}
if findFilter == nil { if findFilter == nil {
findFilter = &FindFilterType{} findFilter = &FindFilterType{}
} }
@@ -116,11 +126,26 @@ func (qb *StudioQueryBuilder) Query(findFilter *FindFilterType) ([]*Studio, int)
if q := findFilter.Q; q != nil && *q != "" { if q := findFilter.Q; q != nil && *q != "" {
searchColumns := []string{"studios.name"} searchColumns := []string{"studios.name"}
clause, thisArgs := getSearchBinding(searchColumns, *q, false) clause, thisArgs := getSearchBinding(searchColumns, *q, false)
whereClauses = append(whereClauses, clause) whereClauses = append(whereClauses, clause)
args = append(args, thisArgs...) args = append(args, thisArgs...)
} }
if parentsFilter := studioFilter.Parents; parentsFilter != nil && len(parentsFilter.Value) > 0 {
body += `
left join studios as parent_studio on parent_studio.id = studios.parent_id
`
for _, studioID := range parentsFilter.Value {
args = append(args, studioID)
}
whereClause, havingClause := getMultiCriterionClause("studios", "parent_studio", "", "", "parent_id", parentsFilter)
whereClauses = appendClause(whereClauses, whereClause)
havingClauses = appendClause(havingClauses, havingClause)
}
sortAndPagination := qb.getStudioSort(findFilter) + getPagination(findFilter) sortAndPagination := qb.getStudioSort(findFilter) + getPagination(findFilter)
idsResult, countResult := executeFindQuery("studios", body, args, sortAndPagination, whereClauses, havingClauses) idsResult, countResult := executeFindQuery("studios", body, args, sortAndPagination, whereClauses, havingClauses)

View File

@@ -3,9 +3,13 @@
package models_test package models_test
import ( import (
"context"
"database/sql"
"strconv"
"strings" "strings"
"testing" "testing"
"github.com/stashapp/stash/pkg/database"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@@ -39,6 +43,173 @@ func TestStudioFindByName(t *testing.T) {
} }
func TestStudioQueryParent(t *testing.T) {
sqb := models.NewStudioQueryBuilder()
studioCriterion := models.MultiCriterionInput{
Value: []string{
strconv.Itoa(studioIDs[studioIdxWithChildStudio]),
},
Modifier: models.CriterionModifierIncludes,
}
studioFilter := models.StudioFilterType{
Parents: &studioCriterion,
}
studios, _ := sqb.Query(&studioFilter, nil)
assert.Len(t, studios, 1)
// ensure id is correct
assert.Equal(t, sceneIDs[studioIdxWithParentStudio], studios[0].ID)
studioCriterion = models.MultiCriterionInput{
Value: []string{
strconv.Itoa(studioIDs[studioIdxWithChildStudio]),
},
Modifier: models.CriterionModifierExcludes,
}
q := getStudioStringValue(studioIdxWithParentStudio, titleField)
findFilter := models.FindFilterType{
Q: &q,
}
studios, _ = sqb.Query(&studioFilter, &findFilter)
assert.Len(t, studios, 0)
}
func TestStudioDestroyParent(t *testing.T) {
const parentName = "parent"
const childName = "child"
// create parent and child studios
ctx := context.TODO()
tx := database.DB.MustBeginTx(ctx, nil)
createdParent, err := createStudio(tx, parentName, nil)
if err != nil {
tx.Rollback()
t.Fatalf("Error creating parent studio: %s", err.Error())
}
parentID := int64(createdParent.ID)
createdChild, err := createStudio(tx, childName, &parentID)
if err != nil {
tx.Rollback()
t.Fatalf("Error creating child studio: %s", err.Error())
}
if err := tx.Commit(); err != nil {
tx.Rollback()
t.Fatalf("Error committing: %s", err.Error())
}
sqb := models.NewStudioQueryBuilder()
// destroy the parent
tx = database.DB.MustBeginTx(ctx, nil)
err = sqb.Destroy(strconv.Itoa(createdParent.ID), tx)
if err != nil {
tx.Rollback()
t.Fatalf("Error destroying parent studio: %s", err.Error())
}
if err := tx.Commit(); err != nil {
tx.Rollback()
t.Fatalf("Error committing: %s", err.Error())
}
// destroy the child
tx = database.DB.MustBeginTx(ctx, nil)
err = sqb.Destroy(strconv.Itoa(createdChild.ID), tx)
if err != nil {
tx.Rollback()
t.Fatalf("Error destroying child studio: %s", err.Error())
}
if err := tx.Commit(); err != nil {
tx.Rollback()
t.Fatalf("Error committing: %s", err.Error())
}
}
func TestStudioFindChildren(t *testing.T) {
sqb := models.NewStudioQueryBuilder()
studios, err := sqb.FindChildren(studioIDs[studioIdxWithChildStudio], nil)
if err != nil {
t.Fatalf("error calling FindChildren: %s", err.Error())
}
assert.Len(t, studios, 1)
assert.Equal(t, studioIDs[studioIdxWithParentStudio], studios[0].ID)
studios, err = sqb.FindChildren(0, nil)
if err != nil {
t.Fatalf("error calling FindChildren: %s", err.Error())
}
assert.Len(t, studios, 0)
}
func TestStudioUpdateClearParent(t *testing.T) {
const parentName = "clearParent_parent"
const childName = "clearParent_child"
// create parent and child studios
ctx := context.TODO()
tx := database.DB.MustBeginTx(ctx, nil)
createdParent, err := createStudio(tx, parentName, nil)
if err != nil {
tx.Rollback()
t.Fatalf("Error creating parent studio: %s", err.Error())
}
parentID := int64(createdParent.ID)
createdChild, err := createStudio(tx, childName, &parentID)
if err != nil {
tx.Rollback()
t.Fatalf("Error creating child studio: %s", err.Error())
}
if err := tx.Commit(); err != nil {
tx.Rollback()
t.Fatalf("Error committing: %s", err.Error())
}
sqb := models.NewStudioQueryBuilder()
// clear the parent id from the child
tx = database.DB.MustBeginTx(ctx, nil)
updatePartial := models.StudioPartial{
ID: createdChild.ID,
ParentID: &sql.NullInt64{Valid: false},
}
updatedStudio, err := sqb.Update(updatePartial, tx)
if err != nil {
tx.Rollback()
t.Fatalf("Error updated studio: %s", err.Error())
}
if updatedStudio.ParentID.Valid {
t.Error("updated studio has parent ID set")
}
if err := tx.Commit(); err != nil {
tx.Rollback()
t.Fatalf("Error committing: %s", err.Error())
}
}
// TODO Create // TODO Create
// TODO Update // TODO Update
// TODO Destroy // TODO Destroy

View File

@@ -22,12 +22,12 @@ import (
const totalScenes = 12 const totalScenes = 12
const performersNameCase = 3 const performersNameCase = 3
const performersNameNoCase = 2 const performersNameNoCase = 2
const moviesNameCase = 1 const moviesNameCase = 2
const moviesNameNoCase = 1 const moviesNameNoCase = 1
const totalGalleries = 1 const totalGalleries = 1
const tagsNameNoCase = 2 const tagsNameNoCase = 2
const tagsNameCase = 5 const tagsNameCase = 5
const studiosNameCase = 1 const studiosNameCase = 4
const studiosNameNoCase = 1 const studiosNameNoCase = 1
var sceneIDs []int var sceneIDs []int
@@ -61,9 +61,10 @@ const performerIdx1WithDupName = 3
const performerIdxWithDupName = 4 const performerIdxWithDupName = 4
const movieIdxWithScene = 0 const movieIdxWithScene = 0
const movieIdxWithStudio = 1
// movies with dup names start from the end // movies with dup names start from the end
const movieIdxWithDupName = 1 const movieIdxWithDupName = 2
const galleryIdxWithScene = 0 const galleryIdxWithScene = 0
@@ -78,9 +79,12 @@ const tagIdx1WithDupName = 5
const tagIdxWithDupName = 6 const tagIdxWithDupName = 6
const studioIdxWithScene = 0 const studioIdxWithScene = 0
const studioIdxWithMovie = 1
const studioIdxWithChildStudio = 2
const studioIdxWithParentStudio = 3
// studios with dup names start from the end // studios with dup names start from the end
const studioIdxWithDupName = 1 const studioIdxWithDupName = 4
const markerIdxWithScene = 0 const markerIdxWithScene = 0
@@ -196,6 +200,16 @@ func populateDB() error {
return err return err
} }
if err := linkMovieStudio(tx, movieIdxWithStudio, studioIdxWithMovie); err != nil {
tx.Rollback()
return err
}
if err := linkStudioParent(tx, studioIdxWithChildStudio, studioIdxWithParentStudio); err != nil {
tx.Rollback()
return err
}
if err := createMarker(tx, sceneIdxWithMarker, tagIdxWithPrimaryMarker, []int{tagIdxWithMarker}); err != nil { if err := createMarker(tx, sceneIdxWithMarker, tagIdxWithPrimaryMarker, []int{tagIdxWithMarker}); err != nil {
tx.Rollback() tx.Rollback()
return err return err
@@ -312,10 +326,9 @@ func createMovies(tx *sqlx.Tx, n int, o int) error {
const namePlain = "Name" const namePlain = "Name"
const nameNoCase = "NaMe" const nameNoCase = "NaMe"
name := namePlain
for i := 0; i < n+o; i++ { for i := 0; i < n+o; i++ {
index := i index := i
name := namePlain
if i >= n { // i<n tags get normal names if i >= n { // i<n tags get normal names
name = nameNoCase // i>=n movies get dup names if case is not checked name = nameNoCase // i>=n movies get dup names if case is not checked
@@ -323,8 +336,9 @@ func createMovies(tx *sqlx.Tx, n int, o int) error {
} // so count backwards to 0 as needed } // so count backwards to 0 as needed
// movies [ i ] and [ n + o - i - 1 ] should have similar names with only the Name!=NaMe part different // movies [ i ] and [ n + o - i - 1 ] should have similar names with only the Name!=NaMe part different
name = getMovieStringValue(index, name)
movie := models.Movie{ movie := models.Movie{
Name: sql.NullString{String: getMovieStringValue(index, name), Valid: true}, Name: sql.NullString{String: name, Valid: true},
FrontImage: []byte(models.DefaultMovieImage), FrontImage: []byte(models.DefaultMovieImage),
Checksum: utils.MD5FromString(name), Checksum: utils.MD5FromString(name),
} }
@@ -332,7 +346,7 @@ func createMovies(tx *sqlx.Tx, n int, o int) error {
created, err := mqb.Create(movie, tx) created, err := mqb.Create(movie, tx)
if err != nil { if err != nil {
return fmt.Errorf("Error creating movie %v+: %s", movie, err.Error()) return fmt.Errorf("Error creating movie [%d] %v+: %s", i, movie, err.Error())
} }
movieIDs = append(movieIDs, created.ID) movieIDs = append(movieIDs, created.ID)
@@ -432,16 +446,35 @@ func getStudioStringValue(index int, field string) string {
return "studio_" + strconv.FormatInt(int64(index), 10) + "_" + field return "studio_" + strconv.FormatInt(int64(index), 10) + "_" + field
} }
func createStudio(tx *sqlx.Tx, name string, parentID *int64) (*models.Studio, error) {
sqb := models.NewStudioQueryBuilder()
studio := models.Studio{
Name: sql.NullString{String: name, Valid: true},
Image: []byte(models.DefaultStudioImage),
Checksum: utils.MD5FromString(name),
}
if parentID != nil {
studio.ParentID = sql.NullInt64{Int64: *parentID, Valid: true}
}
created, err := sqb.Create(studio, tx)
if err != nil {
return nil, fmt.Errorf("Error creating studio %v+: %s", studio, err.Error())
}
return created, nil
}
//createStudios creates n studios with plain Name and o studios with camel cased NaMe included //createStudios creates n studios with plain Name and o studios with camel cased NaMe included
func createStudios(tx *sqlx.Tx, n int, o int) error { func createStudios(tx *sqlx.Tx, n int, o int) error {
sqb := models.NewStudioQueryBuilder()
const namePlain = "Name" const namePlain = "Name"
const nameNoCase = "NaMe" const nameNoCase = "NaMe"
name := namePlain
for i := 0; i < n+o; i++ { for i := 0; i < n+o; i++ {
index := i index := i
name := namePlain
if i >= n { // i<n studios get normal names if i >= n { // i<n studios get normal names
name = nameNoCase // i>=n studios get dup names if case is not checked name = nameNoCase // i>=n studios get dup names if case is not checked
@@ -449,16 +482,11 @@ func createStudios(tx *sqlx.Tx, n int, o int) error {
} // so count backwards to 0 as needed } // so count backwards to 0 as needed
// studios [ i ] and [ n + o - i - 1 ] should have similar names with only the Name!=NaMe part different // studios [ i ] and [ n + o - i - 1 ] should have similar names with only the Name!=NaMe part different
tag := models.Studio{ name = getStudioStringValue(index, name)
Name: sql.NullString{String: getStudioStringValue(index, name), Valid: true}, created, err := createStudio(tx, name, nil)
Image: []byte(models.DefaultStudioImage),
Checksum: utils.MD5FromString(name),
}
created, err := sqb.Create(tag, tx)
if err != nil { if err != nil {
return fmt.Errorf("Error creating studio %v+: %s", tag, err.Error()) return err
} }
studioIDs = append(studioIDs, created.ID) studioIDs = append(studioIDs, created.ID)
@@ -582,3 +610,27 @@ func linkSceneStudio(tx *sqlx.Tx, sceneIndex, studioIndex int) error {
return err return err
} }
func linkMovieStudio(tx *sqlx.Tx, movieIndex, studioIndex int) error {
mqb := models.NewMovieQueryBuilder()
movie := models.MoviePartial{
ID: movieIDs[movieIndex],
StudioID: &sql.NullInt64{Int64: int64(studioIDs[studioIndex]), Valid: true},
}
_, err := mqb.Update(movie, tx)
return err
}
func linkStudioParent(tx *sqlx.Tx, parentIndex, childIndex int) error {
sqb := models.NewStudioQueryBuilder()
studio := models.StudioPartial{
ID: studioIDs[childIndex],
ParentID: &sql.NullInt64{Int64: int64(studioIDs[parentIndex]), Valid: true},
}
_, err := sqb.Update(studio, tx)
return err
}

View File

@@ -2,6 +2,9 @@ import React from "react";
import ReactMarkdown from "react-markdown"; import ReactMarkdown from "react-markdown";
const markup = ` const markup = `
### ✨ New Features
* Add support for parent/child studios.
### 🎨 Improvements ### 🎨 Improvements
* Show rating as stars in scene page. * Show rating as stars in scene page.
* Add reload scrapers button. * Add reload scrapers button.

View File

@@ -140,6 +140,7 @@ export const AddFilter: React.FC<IAddFilterProps> = (
if ( if (
criterion.type !== "performers" && criterion.type !== "performers" &&
criterion.type !== "studios" && criterion.type !== "studios" &&
criterion.type !== "parent_studios" &&
criterion.type !== "tags" && criterion.type !== "tags" &&
criterion.type !== "sceneTags" && criterion.type !== "sceneTags" &&
criterion.type !== "movies" criterion.type !== "movies"

View File

@@ -23,7 +23,13 @@ type ValidTypes =
type Option = { value: string; label: string }; type Option = { value: string; label: string };
interface ITypeProps { interface ITypeProps {
type?: "performers" | "studios" | "tags" | "sceneTags" | "movies"; type?:
| "performers"
| "studios"
| "parent_studios"
| "tags"
| "sceneTags"
| "movies";
} }
interface IFilterProps { interface IFilterProps {
ids?: string[]; ids?: string[];
@@ -175,7 +181,7 @@ export const MarkerTitleSuggest: React.FC<IMarkerSuggestProps> = (props) => {
export const FilterSelect: React.FC<IFilterProps & ITypeProps> = (props) => export const FilterSelect: React.FC<IFilterProps & ITypeProps> = (props) =>
props.type === "performers" ? ( props.type === "performers" ? (
<PerformerSelect {...(props as IFilterProps)} /> <PerformerSelect {...(props as IFilterProps)} />
) : props.type === "studios" ? ( ) : props.type === "studios" || props.type === "parent_studios" ? (
<StudioSelect {...(props as IFilterProps)} /> <StudioSelect {...(props as IFilterProps)} />
) : props.type === "movies" ? ( ) : props.type === "movies" ? (
<MovieSelect {...(props as IFilterProps)} /> <MovieSelect {...(props as IFilterProps)} />

View File

@@ -3,12 +3,45 @@ import React from "react";
import { Link } from "react-router-dom"; import { Link } from "react-router-dom";
import * as GQL from "src/core/generated-graphql"; import * as GQL from "src/core/generated-graphql";
import { FormattedPlural } from "react-intl"; import { FormattedPlural } from "react-intl";
import { NavUtils } from "src/utils";
interface IProps { interface IProps {
studio: GQL.StudioDataFragment; studio: GQL.StudioDataFragment;
hideParent?: boolean;
} }
export const StudioCard: React.FC<IProps> = ({ studio }) => { function maybeRenderParent(
studio: GQL.StudioDataFragment,
hideParent?: boolean
) {
if (!hideParent && studio.parent_studio) {
return (
<div>
Part of&nbsp;
<Link to={`/studios/${studio.parent_studio.id}`}>
{studio.parent_studio.name}
</Link>
.
</div>
);
}
}
function maybeRenderChildren(studio: GQL.StudioDataFragment) {
if (studio.child_studios.length > 0) {
return (
<div>
Parent of&nbsp;
<Link to={NavUtils.makeChildStudiosUrl(studio)}>
{studio.child_studios.length} studios
</Link>
.
</div>
);
}
}
export const StudioCard: React.FC<IProps> = ({ studio, hideParent }) => {
return ( return (
<Card className="studio-card"> <Card className="studio-card">
<Link to={`/studios/${studio.id}`} className="studio-card-header"> <Link to={`/studios/${studio.id}`} className="studio-card-header">
@@ -29,6 +62,8 @@ export const StudioCard: React.FC<IProps> = ({ studio }) => {
/> />
. .
</span> </span>
{maybeRenderParent(studio, hideParent)}
{maybeRenderChildren(studio)}
</div> </div>
</Card> </Card>
); );

View File

@@ -1,6 +1,6 @@
/* eslint-disable react/no-this-in-sfc */ /* eslint-disable react/no-this-in-sfc */
import { Table } from "react-bootstrap"; import { Table, Tabs, Tab } from "react-bootstrap";
import React, { useEffect, useState } from "react"; import React, { useEffect, useState } from "react";
import { useParams, useHistory } from "react-router-dom"; import { useParams, useHistory } from "react-router-dom";
import cx from "classnames"; import cx from "classnames";
@@ -18,9 +18,11 @@ import {
DetailsEditNavbar, DetailsEditNavbar,
Modal, Modal,
LoadingIndicator, LoadingIndicator,
StudioSelect,
} from "src/components/Shared"; } from "src/components/Shared";
import { useToast } from "src/hooks"; import { useToast } from "src/hooks";
import { StudioScenesPanel } from "./StudioScenesPanel"; import { StudioScenesPanel } from "./StudioScenesPanel";
import { StudioChildrenPanel } from "./StudioChildrenPanel";
export const Studio: React.FC = () => { export const Studio: React.FC = () => {
const history = useHistory(); const history = useHistory();
@@ -36,6 +38,7 @@ export const Studio: React.FC = () => {
const [image, setImage] = useState<string>(); const [image, setImage] = useState<string>();
const [name, setName] = useState<string>(); const [name, setName] = useState<string>();
const [url, setUrl] = useState<string>(); const [url, setUrl] = useState<string>();
const [parentStudioId, setParentStudioId] = useState<string>();
// Studio state // Studio state
const [studio, setStudio] = useState<Partial<GQL.StudioDataFragment>>({}); const [studio, setStudio] = useState<Partial<GQL.StudioDataFragment>>({});
@@ -55,6 +58,7 @@ export const Studio: React.FC = () => {
function updateStudioEditState(state: Partial<GQL.StudioDataFragment>) { function updateStudioEditState(state: Partial<GQL.StudioDataFragment>) {
setName(state.name); setName(state.name);
setUrl(state.url ?? undefined); setUrl(state.url ?? undefined);
setParentStudioId(state?.parent_studio?.id ?? undefined);
} }
function updateStudioData(studioData: Partial<GQL.StudioDataFragment>) { function updateStudioData(studioData: Partial<GQL.StudioDataFragment>) {
@@ -89,6 +93,7 @@ export const Studio: React.FC = () => {
const input: Partial<GQL.StudioCreateInput | GQL.StudioUpdateInput> = { const input: Partial<GQL.StudioCreateInput | GQL.StudioUpdateInput> = {
name, name,
url, url,
parent_id: parentStudioId,
image, image,
}; };
@@ -189,6 +194,20 @@ export const Studio: React.FC = () => {
isEditing: !!isEditing, isEditing: !!isEditing,
onChange: setUrl, onChange: setUrl,
})} })}
<tr>
<td>Parent Studio</td>
<td>
<StudioSelect
onSelect={(items) =>
setParentStudioId(
items.length > 0 ? items[0]?.id : undefined
)
}
ids={parentStudioId ? [parentStudioId] : []}
isDisabled={!isEditing}
/>
</td>
</tr>
</tbody> </tbody>
</Table> </Table>
<DetailsEditNavbar <DetailsEditNavbar
@@ -205,7 +224,14 @@ export const Studio: React.FC = () => {
</div> </div>
{!isNew && ( {!isNew && (
<div className="col-12 col-sm-8"> <div className="col-12 col-sm-8">
<StudioScenesPanel studio={studio} /> <Tabs id="studio-tabs" mountOnEnter>
<Tab eventKey="studio-scenes-panel" title="Scenes">
<StudioScenesPanel studio={studio} />
</Tab>
<Tab eventKey="studio-children-panel" title="Child Studios">
<StudioChildrenPanel studio={studio} />
</Tab>
</Tabs>
</div> </div>
)} )}
{renderDeleteAlert()} {renderDeleteAlert()}

View File

@@ -0,0 +1,47 @@
import React from "react";
import * as GQL from "src/core/generated-graphql";
import { ParentStudiosCriterion } from "src/models/list-filter/criteria/studios";
import { ListFilterModel } from "src/models/list-filter/filter";
import { StudioList } from "../StudioList";
interface IStudioChildrenPanel {
studio: Partial<GQL.StudioDataFragment>;
}
export const StudioChildrenPanel: React.FC<IStudioChildrenPanel> = ({
studio,
}) => {
function filterHook(filter: ListFilterModel) {
const studioValue = { id: studio.id!, label: studio.name! };
// if studio is already present, then we modify it, otherwise add
let parentStudioCriterion = filter.criteria.find((c) => {
return c.type === "parent_studios";
}) as ParentStudiosCriterion;
if (
parentStudioCriterion &&
(parentStudioCriterion.modifier === GQL.CriterionModifier.IncludesAll ||
parentStudioCriterion.modifier === GQL.CriterionModifier.Includes)
) {
// add the studio if not present
if (
!parentStudioCriterion.value.find((p) => {
return p.id === studio.id;
})
) {
parentStudioCriterion.value.push(studioValue);
}
parentStudioCriterion.modifier = GQL.CriterionModifier.IncludesAll;
} else {
// overwrite
parentStudioCriterion = new ParentStudiosCriterion();
parentStudioCriterion.value = [studioValue];
filter.criteria.push(parentStudioCriterion);
}
return filter;
}
return <StudioList fromParent filterHook={filterHook} />;
};

View File

@@ -5,9 +5,19 @@ import { ListFilterModel } from "src/models/list-filter/filter";
import { DisplayMode } from "src/models/list-filter/types"; import { DisplayMode } from "src/models/list-filter/types";
import { StudioCard } from "./StudioCard"; import { StudioCard } from "./StudioCard";
export const StudioList: React.FC = () => { interface IStudioList {
fromParent?: boolean;
filterHook?: (filter: ListFilterModel) => ListFilterModel;
}
export const StudioList: React.FC<IStudioList> = ({
fromParent,
filterHook,
}) => {
const listData = useStudiosList({ const listData = useStudiosList({
renderContent, renderContent,
subComponent: fromParent,
filterHook,
}); });
function renderContent( function renderContent(
@@ -20,7 +30,11 @@ export const StudioList: React.FC = () => {
return ( return (
<div className="row px-xl-5 justify-content-center"> <div className="row px-xl-5 justify-content-center">
{result.data.findStudios.studios.map((studio) => ( {result.data.findStudios.studios.map((studio) => (
<StudioCard key={studio.id} studio={studio} /> <StudioCard
key={studio.id}
studio={studio}
hideParent={fromParent}
/>
))} ))}
</div> </div>
); );

View File

@@ -74,6 +74,7 @@ export const useFindStudios = (filter: ListFilterModel) =>
GQL.useFindStudiosQuery({ GQL.useFindStudiosQuery({
variables: { variables: {
filter: filter.makeFindFilter(), filter: filter.makeFindFilter(),
studio_filter: filter.makeStudioFilter(),
}, },
}); });

View File

@@ -31,7 +31,8 @@ export type CriterionType =
| "tattoos" | "tattoos"
| "piercings" | "piercings"
| "aliases" | "aliases"
| "gender"; | "gender"
| "parent_studios";
type Option = string | number | IOptionType; type Option = string | number | IOptionType;
export type CriterionValue = string | number | ILabeledId[]; export type CriterionValue = string | number | ILabeledId[];
@@ -93,6 +94,8 @@ export abstract class Criterion {
return "Aliases"; return "Aliases";
case "gender": case "gender":
return "Gender"; return "Gender";
case "parent_studios":
return "Parent Studios";
} }
} }

View File

@@ -24,3 +24,13 @@ export class StudiosCriterionOption implements ICriterionOption {
public label: string = Criterion.getLabel("studios"); public label: string = Criterion.getLabel("studios");
public value: CriterionType = "studios"; public value: CriterionType = "studios";
} }
export class ParentStudiosCriterion extends StudiosCriterion {
public type: CriterionType = "parent_studios";
public parameterName: string = "parents";
}
export class ParentStudiosCriterionOption implements ICriterionOption {
public label: string = Criterion.getLabel("parent_studios");
public value: CriterionType = "parent_studios";
}

View File

@@ -17,7 +17,7 @@ import { NoneCriterion } from "./none";
import { PerformersCriterion } from "./performers"; import { PerformersCriterion } from "./performers";
import { RatingCriterion } from "./rating"; import { RatingCriterion } from "./rating";
import { ResolutionCriterion } from "./resolution"; import { ResolutionCriterion } from "./resolution";
import { StudiosCriterion } from "./studios"; import { StudiosCriterion, ParentStudiosCriterion } from "./studios";
import { TagsCriterion } from "./tags"; import { TagsCriterion } from "./tags";
import { GenderCriterion } from "./gender"; import { GenderCriterion } from "./gender";
import { MoviesCriterion } from "./movies"; import { MoviesCriterion } from "./movies";
@@ -50,6 +50,8 @@ export function makeCriteria(type: CriterionType = "none") {
return new PerformersCriterion(); return new PerformersCriterion();
case "studios": case "studios":
return new StudiosCriterion(); return new StudiosCriterion();
case "parent_studios":
return new ParentStudiosCriterion();
case "movies": case "movies":
return new MoviesCriterion(); return new MoviesCriterion();
case "birth_year": case "birth_year":

View File

@@ -7,6 +7,7 @@ import {
SceneMarkerFilterType, SceneMarkerFilterType,
SortDirectionEnum, SortDirectionEnum,
MovieFilterType, MovieFilterType,
StudioFilterType,
} from "src/core/generated-graphql"; } from "src/core/generated-graphql";
import { stringToGender } from "src/core/StashService"; import { stringToGender } from "src/core/StashService";
import { import {
@@ -41,7 +42,12 @@ import {
ResolutionCriterion, ResolutionCriterion,
ResolutionCriterionOption, ResolutionCriterionOption,
} from "./criteria/resolution"; } from "./criteria/resolution";
import { StudiosCriterion, StudiosCriterionOption } from "./criteria/studios"; import {
StudiosCriterion,
StudiosCriterionOption,
ParentStudiosCriterion,
ParentStudiosCriterionOption,
} from "./criteria/studios";
import { import {
SceneTagsCriterionOption, SceneTagsCriterionOption,
TagsCriterion, TagsCriterion,
@@ -159,7 +165,10 @@ export class ListFilterModel {
this.sortBy = "name"; this.sortBy = "name";
this.sortByOptions = ["name", "scenes_count"]; this.sortByOptions = ["name", "scenes_count"];
this.displayModeOptions = [DisplayMode.Grid]; this.displayModeOptions = [DisplayMode.Grid];
this.criterionOptions = [new NoneCriterionOption()]; this.criterionOptions = [
new NoneCriterionOption(),
new ParentStudiosCriterionOption(),
];
break; break;
case FilterMode.Movies: case FilterMode.Movies:
this.sortBy = "name"; this.sortBy = "name";
@@ -576,4 +585,22 @@ export class ListFilterModel {
}); });
return result; return result;
} }
public makeStudioFilter(): StudioFilterType {
const result: StudioFilterType = {};
this.criteria.forEach((criterion) => {
switch (criterion.type) {
case "parent_studios": {
const studCrit = criterion as ParentStudiosCriterion;
result.parents = {
value: studCrit.value.map((studio) => studio.id),
modifier: studCrit.modifier,
};
break;
}
// no default
}
});
return result;
}
} }

View File

@@ -1,6 +1,9 @@
import * as GQL from "src/core/generated-graphql"; import * as GQL from "src/core/generated-graphql";
import { PerformersCriterion } from "src/models/list-filter/criteria/performers"; import { PerformersCriterion } from "src/models/list-filter/criteria/performers";
import { StudiosCriterion } from "src/models/list-filter/criteria/studios"; import {
StudiosCriterion,
ParentStudiosCriterion,
} from "src/models/list-filter/criteria/studios";
import { TagsCriterion } from "src/models/list-filter/criteria/tags"; import { TagsCriterion } from "src/models/list-filter/criteria/tags";
import { ListFilterModel } from "src/models/list-filter/filter"; import { ListFilterModel } from "src/models/list-filter/filter";
import { FilterMode } from "src/models/list-filter/types"; import { FilterMode } from "src/models/list-filter/types";
@@ -30,6 +33,17 @@ const makeStudioScenesUrl = (studio: Partial<GQL.StudioDataFragment>) => {
return `/scenes?${filter.makeQueryParameters()}`; return `/scenes?${filter.makeQueryParameters()}`;
}; };
const makeChildStudiosUrl = (studio: Partial<GQL.StudioDataFragment>) => {
if (!studio.id) return "#";
const filter = new ListFilterModel(FilterMode.Studios);
const criterion = new ParentStudiosCriterion();
criterion.value = [
{ id: studio.id, label: studio.name || `Studio ${studio.id}` },
];
filter.criteria.push(criterion);
return `/studios?${filter.makeQueryParameters()}`;
};
const makeMovieScenesUrl = (movie: Partial<GQL.MovieDataFragment>) => { const makeMovieScenesUrl = (movie: Partial<GQL.MovieDataFragment>) => {
if (!movie.id) return "#"; if (!movie.id) return "#";
const filter = new ListFilterModel(FilterMode.Scenes); const filter = new ListFilterModel(FilterMode.Scenes);
@@ -73,4 +87,5 @@ export default {
makeTagScenesUrl, makeTagScenesUrl,
makeSceneMarkerUrl, makeSceneMarkerUrl,
makeMovieScenesUrl, makeMovieScenesUrl,
makeChildStudiosUrl,
}; };