diff --git a/graphql/documents/data/studio.graphql b/graphql/documents/data/studio.graphql index 9d6af7945..890f9c43f 100644 --- a/graphql/documents/data/studio.graphql +++ b/graphql/documents/data/studio.graphql @@ -3,6 +3,22 @@ fragment StudioData on Studio { checksum name url + parent_studio { + id + checksum + name + url + image_path + scene_count + } + child_studios { + id + checksum + name + url + image_path + scene_count + } image_path scene_count } diff --git a/graphql/documents/mutations/studio.graphql b/graphql/documents/mutations/studio.graphql index 19ece5bb7..539b96358 100644 --- a/graphql/documents/mutations/studio.graphql +++ b/graphql/documents/mutations/studio.graphql @@ -1,9 +1,10 @@ mutation StudioCreate( $name: String!, $url: String, - $image: String) { + $image: String + $parent_id: ID) { - studioCreate(input: { name: $name, url: $url, image: $image }) { + studioCreate(input: { name: $name, url: $url, image: $image, parent_id: $parent_id }) { ...StudioData } } @@ -12,9 +13,10 @@ mutation StudioUpdate( $id: ID! $name: String, $url: String, - $image: String) { + $image: String + $parent_id: ID) { - studioUpdate(input: { id: $id, name: $name, url: $url, image: $image }) { + studioUpdate(input: { id: $id, name: $name, url: $url, image: $image, parent_id: $parent_id }) { ...StudioData } } diff --git a/graphql/documents/queries/studio.graphql b/graphql/documents/queries/studio.graphql index a4898c210..f18f67e93 100644 --- a/graphql/documents/queries/studio.graphql +++ b/graphql/documents/queries/studio.graphql @@ -1,5 +1,5 @@ -query FindStudios($filter: FindFilterType) { - findStudios(filter: $filter) { +query FindStudios($filter: FindFilterType, $studio_filter: StudioFilterType ) { + findStudios(filter: $filter, studio_filter: $studio_filter) { count studios { ...StudioData diff --git a/graphql/schema/schema.graphql b/graphql/schema/schema.graphql index 5f9d9878a..6760dec8e 100644 --- a/graphql/schema/schema.graphql +++ b/graphql/schema/schema.graphql @@ -20,7 +20,7 @@ type Query { """Find a studio by ID""" findStudio(id: ID!): Studio """A function which queries Studio objects""" - findStudios(filter: FindFilterType): FindStudiosResultType! + findStudios(studio_filter: StudioFilterType, filter: FindFilterType): FindStudiosResultType! """Find a movie by ID""" findMovie(id: ID!): Movie diff --git a/graphql/schema/types/filters.graphql b/graphql/schema/types/filters.graphql index 92ac58e2c..80250181d 100644 --- a/graphql/schema/types/filters.graphql +++ b/graphql/schema/types/filters.graphql @@ -91,6 +91,11 @@ input MovieFilterType { studios: MultiCriterionInput } +input StudioFilterType { + """Filter to only include studios with this parent studio""" + parents: MultiCriterionInput +} + enum CriterionModifier { """=""" EQUALS, diff --git a/graphql/schema/types/studio.graphql b/graphql/schema/types/studio.graphql index fc6579794..90309b8c5 100644 --- a/graphql/schema/types/studio.graphql +++ b/graphql/schema/types/studio.graphql @@ -3,7 +3,8 @@ type Studio { checksum: String! name: String! url: String - + parent_studio: Studio + child_studios: [Studio!]! image_path: String # Resolver scene_count: Int # Resolver } @@ -11,6 +12,7 @@ type Studio { input StudioCreateInput { name: String! url: String + parent_id: ID """This should be base64 encoded""" image: String } @@ -19,6 +21,7 @@ input StudioUpdateInput { id: ID! name: String url: String + parent_id: ID, """This should be base64 encoded""" image: String } diff --git a/pkg/api/resolver_model_studio.go b/pkg/api/resolver_model_studio.go index fee539c6c..f068fa12c 100644 --- a/pkg/api/resolver_model_studio.go +++ b/pkg/api/resolver_model_studio.go @@ -2,6 +2,7 @@ package api import ( "context" + "github.com/stashapp/stash/pkg/api/urlbuilders" "github.com/stashapp/stash/pkg/models" ) @@ -31,3 +32,17 @@ func (r *studioResolver) SceneCount(ctx context.Context, obj *models.Studio) (*i res, err := qb.CountByStudioID(obj.ID) return &res, err } + +func (r *studioResolver) ParentStudio(ctx context.Context, obj *models.Studio) (*models.Studio, error) { + if !obj.ParentID.Valid { + return nil, nil + } + + qb := models.NewStudioQueryBuilder() + return qb.Find(int(obj.ParentID.Int64), nil) +} + +func (r *studioResolver) ChildStudios(ctx context.Context, obj *models.Studio) ([]*models.Studio, error) { + qb := models.NewStudioQueryBuilder() + return qb.FindChildren(obj.ID, nil) +} diff --git a/pkg/api/resolver_mutation_studio.go b/pkg/api/resolver_mutation_studio.go index 1d07a809c..1a0fcff47 100644 --- a/pkg/api/resolver_mutation_studio.go +++ b/pkg/api/resolver_mutation_studio.go @@ -7,6 +7,7 @@ import ( "time" "github.com/stashapp/stash/pkg/database" + "github.com/stashapp/stash/pkg/manager" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/utils" ) @@ -40,6 +41,10 @@ func (r *mutationResolver) StudioCreate(ctx context.Context, input models.Studio if input.URL != nil { newStudio.URL = sql.NullString{String: *input.URL, Valid: true} } + if input.ParentID != nil { + parentID, _ := strconv.ParseInt(*input.ParentID, 10, 64) + newStudio.ParentID = sql.NullInt64{Int64: parentID, Valid: true} + } // Start the transaction and save the studio tx := database.DB.MustBeginTx(ctx, nil) @@ -61,33 +66,48 @@ func (r *mutationResolver) StudioCreate(ctx context.Context, input models.Studio func (r *mutationResolver) StudioUpdate(ctx context.Context, input models.StudioUpdateInput) (*models.Studio, error) { // Populate studio from the input studioID, _ := strconv.Atoi(input.ID) - updatedStudio := models.Studio{ + + updatedStudio := models.StudioPartial{ ID: studioID, - UpdatedAt: models.SQLiteTimestamp{Timestamp: time.Now()}, + UpdatedAt: &models.SQLiteTimestamp{Timestamp: time.Now()}, } if input.Image != nil { _, imageData, err := utils.ProcessBase64Image(*input.Image) if err != nil { return nil, err } - updatedStudio.Image = imageData + updatedStudio.Image = &imageData } if input.Name != nil { // generate checksum from studio name rather than image checksum := utils.MD5FromString(*input.Name) - updatedStudio.Name = sql.NullString{String: *input.Name, Valid: true} - updatedStudio.Checksum = checksum + updatedStudio.Name = &sql.NullString{String: *input.Name, Valid: true} + updatedStudio.Checksum = &checksum } if input.URL != nil { - updatedStudio.URL = sql.NullString{String: *input.URL, Valid: true} + updatedStudio.URL = &sql.NullString{String: *input.URL, Valid: true} + } + + if input.ParentID != nil { + parentID, _ := strconv.ParseInt(*input.ParentID, 10, 64) + updatedStudio.ParentID = &sql.NullInt64{Int64: parentID, Valid: true} + } else { + // parent studio must be nullable + updatedStudio.ParentID = &sql.NullInt64{Valid: false} } // Start the transaction and save the studio tx := database.DB.MustBeginTx(ctx, nil) qb := models.NewStudioQueryBuilder() + + if err := manager.ValidateModifyStudio(updatedStudio, tx); err != nil { + tx.Rollback() + return nil, err + } + studio, err := qb.Update(updatedStudio, tx) if err != nil { - _ = tx.Rollback() + tx.Rollback() return nil, err } diff --git a/pkg/api/resolver_query_find_studio.go b/pkg/api/resolver_query_find_studio.go index 4b39130f4..25528c517 100644 --- a/pkg/api/resolver_query_find_studio.go +++ b/pkg/api/resolver_query_find_studio.go @@ -12,9 +12,9 @@ func (r *queryResolver) FindStudio(ctx context.Context, id string) (*models.Stud return qb.Find(idInt, nil) } -func (r *queryResolver) FindStudios(ctx context.Context, filter *models.FindFilterType) (*models.FindStudiosResultType, error) { +func (r *queryResolver) FindStudios(ctx context.Context, studioFilter *models.StudioFilterType, filter *models.FindFilterType) (*models.FindStudiosResultType, error) { qb := models.NewStudioQueryBuilder() - studios, total := qb.Query(filter) + studios, total := qb.Query(studioFilter, filter) return &models.FindStudiosResultType{ Count: total, Studios: studios, diff --git a/pkg/database/database.go b/pkg/database/database.go index 63f6d9cb1..11cdb0085 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -19,7 +19,7 @@ import ( var DB *sqlx.DB var dbPath string -var appSchemaVersion uint = 8 +var appSchemaVersion uint = 9 var databaseSchemaVersion uint const sqlite3Driver = "sqlite3ex" diff --git a/pkg/database/migrations/9_studios_parent_studio.up.sql b/pkg/database/migrations/9_studios_parent_studio.up.sql new file mode 100644 index 000000000..38ca11350 --- /dev/null +++ b/pkg/database/migrations/9_studios_parent_studio.up.sql @@ -0,0 +1,3 @@ +ALTER TABLE studios + ADD COLUMN parent_id INTEGER DEFAULT NULL CHECK ( id IS NOT parent_id ) REFERENCES studios(id) on delete set null; + CREATE INDEX index_studios_on_parent_id on studios (parent_id); \ No newline at end of file diff --git a/pkg/manager/jsonschema/studio.go b/pkg/manager/jsonschema/studio.go index 246c36050..ed1f7dea0 100644 --- a/pkg/manager/jsonschema/studio.go +++ b/pkg/manager/jsonschema/studio.go @@ -2,17 +2,19 @@ package jsonschema import ( "fmt" - "github.com/json-iterator/go" - "github.com/stashapp/stash/pkg/models" "os" + + jsoniter "github.com/json-iterator/go" + "github.com/stashapp/stash/pkg/models" ) type Studio struct { - Name string `json:"name,omitempty"` - URL string `json:"url,omitempty"` - Image string `json:"image,omitempty"` - CreatedAt models.JSONTime `json:"created_at,omitempty"` - UpdatedAt models.JSONTime `json:"updated_at,omitempty"` + Name string `json:"name,omitempty"` + URL string `json:"url,omitempty"` + ParentStudio string `json:"parent_studio,omitempty"` + Image string `json:"image,omitempty"` + CreatedAt models.JSONTime `json:"created_at,omitempty"` + UpdatedAt models.JSONTime `json:"updated_at,omitempty"` } func LoadStudioFile(filePath string) (*Studio, error) { diff --git a/pkg/manager/studio.go b/pkg/manager/studio.go new file mode 100644 index 000000000..6de1ee572 --- /dev/null +++ b/pkg/manager/studio.go @@ -0,0 +1,36 @@ +package manager + +import ( + "errors" + "fmt" + + "github.com/jmoiron/sqlx" + "github.com/stashapp/stash/pkg/models" +) + +func ValidateModifyStudio(studio models.StudioPartial, tx *sqlx.Tx) error { + if studio.ParentID == nil || !studio.ParentID.Valid { + return nil + } + + // ensure there is no cyclic dependency + thisID := studio.ID + qb := models.NewStudioQueryBuilder() + + currentParentID := *studio.ParentID + + for currentParentID.Valid { + if currentParentID.Int64 == int64(thisID) { + return errors.New("studio cannot be an ancestor of itself") + } + + currentStudio, err := qb.Find(int(currentParentID.Int64), tx) + if err != nil { + return fmt.Errorf("error finding parent studio: %s", err.Error()) + } + + currentParentID = currentStudio.ParentID + } + + return nil +} diff --git a/pkg/manager/task_export.go b/pkg/manager/task_export.go index 27d8bb6cc..58339ade6 100644 --- a/pkg/manager/task_export.go +++ b/pkg/manager/task_export.go @@ -3,6 +3,12 @@ package manager import ( "context" "fmt" + "math" + "runtime" + "strconv" + "sync" + "time" + "github.com/jmoiron/sqlx" "github.com/stashapp/stash/pkg/database" "github.com/stashapp/stash/pkg/logger" @@ -10,11 +16,6 @@ import ( "github.com/stashapp/stash/pkg/manager/paths" "github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/utils" - "math" - "runtime" - "strconv" - "sync" - "time" ) type ExportTask struct { @@ -395,6 +396,8 @@ func (t *ExportTask) ExportStudios(ctx context.Context, workers int) { func exportStudio(wg *sync.WaitGroup, jobChan <-chan *models.Studio) { defer wg.Done() + studioQB := models.NewStudioQueryBuilder() + for studio := range jobChan { newStudioJSON := jsonschema.Studio{ @@ -408,6 +411,12 @@ func exportStudio(wg *sync.WaitGroup, jobChan <-chan *models.Studio) { if studio.URL.Valid { newStudioJSON.URL = studio.URL.String } + if studio.ParentID.Valid { + parent, _ := studioQB.Find(int(studio.ParentID.Int64), nil) + if parent != nil { + newStudioJSON.ParentStudio = parent.Name.String + } + } newStudioJSON.Image = utils.GetBase64StringFromData(studio.Image) diff --git a/pkg/manager/task_import.go b/pkg/manager/task_import.go index efc2d62c2..aa04ce0ab 100644 --- a/pkg/manager/task_import.go +++ b/pkg/manager/task_import.go @@ -3,6 +3,7 @@ package manager import ( "context" "database/sql" + "fmt" "strconv" "sync" "time" @@ -157,7 +158,8 @@ func (t *ImportTask) ImportPerformers(ctx context.Context) { func (t *ImportTask) ImportStudios(ctx context.Context) { tx := database.DB.MustBeginTx(ctx, nil) - qb := models.NewStudioQueryBuilder() + + pendingParent := make(map[string][]*jsonschema.Studio) for i, mappingJSON := range t.Mappings.Studios { index := i + 1 @@ -172,35 +174,28 @@ func (t *ImportTask) ImportStudios(ctx context.Context) { logger.Progressf("[studios] %d of %d", index, len(t.Mappings.Studios)) - // generate checksum from studio name rather than image - checksum := utils.MD5FromString(studioJSON.Name) - - // Process the base 64 encoded image string - _, imageData, err := utils.ProcessBase64Image(studioJSON.Image) - if err != nil { - _ = tx.Rollback() - logger.Errorf("[studios] <%s> invalid image: %s", mappingJSON.Checksum, err.Error()) - return - } - - // Populate a new studio from the input - newStudio := models.Studio{ - Image: imageData, - Checksum: checksum, - Name: sql.NullString{String: studioJSON.Name, Valid: true}, - URL: sql.NullString{String: studioJSON.URL, Valid: true}, - CreatedAt: models.SQLiteTimestamp{Timestamp: t.getTimeFromJSONTime(studioJSON.CreatedAt)}, - UpdatedAt: models.SQLiteTimestamp{Timestamp: t.getTimeFromJSONTime(studioJSON.UpdatedAt)}, - } - - _, err = qb.Create(newStudio, tx) - if err != nil { - _ = tx.Rollback() + if err := t.ImportStudio(studioJSON, pendingParent, tx); err != nil { + tx.Rollback() logger.Errorf("[studios] <%s> failed to create: %s", mappingJSON.Checksum, err.Error()) return } } + // create the leftover studios, warning for missing parents + if len(pendingParent) > 0 { + logger.Warnf("[studios] importing studios with missing parents") + + for _, s := range pendingParent { + for _, orphanStudioJSON := range s { + if err := t.ImportStudio(orphanStudioJSON, nil, tx); err != nil { + tx.Rollback() + logger.Errorf("[studios] <%s> failed to create: %s", orphanStudioJSON.Name, err.Error()) + return + } + } + } + } + logger.Info("[studios] importing") if err := tx.Commit(); err != nil { logger.Errorf("[studios] import failed to commit: %s", err.Error()) @@ -208,6 +203,74 @@ func (t *ImportTask) ImportStudios(ctx context.Context) { logger.Info("[studios] import complete") } +func (t *ImportTask) ImportStudio(studioJSON *jsonschema.Studio, pendingParent map[string][]*jsonschema.Studio, tx *sqlx.Tx) error { + qb := models.NewStudioQueryBuilder() + + // generate checksum from studio name rather than image + checksum := utils.MD5FromString(studioJSON.Name) + + // Process the base 64 encoded image string + _, imageData, err := utils.ProcessBase64Image(studioJSON.Image) + if err != nil { + return fmt.Errorf("invalid image: %s", err.Error()) + } + + // Populate a new studio from the input + newStudio := models.Studio{ + Image: imageData, + Checksum: checksum, + Name: sql.NullString{String: studioJSON.Name, Valid: true}, + URL: sql.NullString{String: studioJSON.URL, Valid: true}, + CreatedAt: models.SQLiteTimestamp{Timestamp: t.getTimeFromJSONTime(studioJSON.CreatedAt)}, + UpdatedAt: models.SQLiteTimestamp{Timestamp: t.getTimeFromJSONTime(studioJSON.UpdatedAt)}, + } + + // Populate the parent ID + if studioJSON.ParentStudio != "" { + studio, err := qb.FindByName(studioJSON.ParentStudio, tx, false) + if err != nil { + return fmt.Errorf("error finding studio by name <%s>: %s", studioJSON.ParentStudio, err.Error()) + } + + if studio == nil { + // its possible that the parent hasn't been created yet + // do it after it is created + if pendingParent == nil { + logger.Warnf("[studios] studio <%s> does not exist", studioJSON.ParentStudio) + } else { + // add to the pending parent list so that it is created after the parent + s := pendingParent[studioJSON.ParentStudio] + s = append(s, studioJSON) + pendingParent[studioJSON.ParentStudio] = s + + // skip + return nil + } + } else { + newStudio.ParentID = sql.NullInt64{Int64: int64(studio.ID), Valid: true} + } + } + + _, err = qb.Create(newStudio, tx) + if err != nil { + return err + } + + // now create the studios pending this studios creation + s := pendingParent[studioJSON.Name] + for _, childStudioJSON := range s { + // map is nil since we're not checking parent studios at this point + if err := t.ImportStudio(childStudioJSON, nil, tx); err != nil { + return fmt.Errorf("failed to create child studio <%s>: %s", childStudioJSON.Name, err.Error()) + } + } + + // delete the entry from the map so that we know its not left over + delete(pendingParent, studioJSON.Name) + + return err +} + func (t *ImportTask) ImportMovies(ctx context.Context) { tx := database.DB.MustBeginTx(ctx, nil) qb := models.NewMovieQueryBuilder() diff --git a/pkg/models/model_studio.go b/pkg/models/model_studio.go index e978bb439..d880ce159 100644 --- a/pkg/models/model_studio.go +++ b/pkg/models/model_studio.go @@ -10,8 +10,20 @@ type Studio struct { Checksum string `db:"checksum" json:"checksum"` Name sql.NullString `db:"name" json:"name"` URL sql.NullString `db:"url" json:"url"` + ParentID sql.NullInt64 `db:"parent_id,omitempty" json:"parent_id"` CreatedAt SQLiteTimestamp `db:"created_at" json:"created_at"` UpdatedAt SQLiteTimestamp `db:"updated_at" json:"updated_at"` } +type StudioPartial struct { + ID int `db:"id" json:"id"` + Image *[]byte `db:"image" json:"image"` + Checksum *string `db:"checksum" json:"checksum"` + Name *sql.NullString `db:"name" json:"name"` + URL *sql.NullString `db:"url" json:"url"` + ParentID *sql.NullInt64 `db:"parent_id,omitempty" json:"parent_id"` + CreatedAt *SQLiteTimestamp `db:"created_at" json:"created_at"` + UpdatedAt *SQLiteTimestamp `db:"updated_at" json:"updated_at"` +} + var DefaultStudioImage = "" diff --git a/pkg/models/querybuilder_movies.go b/pkg/models/querybuilder_movies.go index cccd4d180..3462b027d 100644 --- a/pkg/models/querybuilder_movies.go +++ b/pkg/models/querybuilder_movies.go @@ -2,7 +2,6 @@ package models import ( "database/sql" - "strconv" "github.com/jmoiron/sqlx" "github.com/stashapp/stash/pkg/database" @@ -148,7 +147,7 @@ func (qb *MovieQueryBuilder) Query(movieFilter *MovieFilterType, findFilter *Fin args = append(args, studioID) } - whereClause, havingClause := qb.getMultiCriterionClause("studio", "", "studio_id", studiosFilter) + whereClause, havingClause := getMultiCriterionClause("movies", "studio", "", "", "studio_id", studiosFilter) whereClauses = appendClause(whereClauses, whereClause) havingClauses = appendClause(havingClauses, havingClause) } @@ -165,29 +164,6 @@ func (qb *MovieQueryBuilder) Query(movieFilter *MovieFilterType, findFilter *Fin return movies, countResult } -// returns where clause and having clause -func (qb *MovieQueryBuilder) getMultiCriterionClause(table string, joinTable string, joinTableField string, criterion *MultiCriterionInput) (string, string) { - whereClause := "" - havingClause := "" - if criterion.Modifier == CriterionModifierIncludes { - // includes any of the provided ids - whereClause = table + ".id IN " + getInBinding(len(criterion.Value)) - } else if criterion.Modifier == CriterionModifierIncludesAll { - // includes all of the provided ids - whereClause = table + ".id IN " + getInBinding(len(criterion.Value)) - havingClause = "count(distinct " + table + ".id) IS " + strconv.Itoa(len(criterion.Value)) - } else if criterion.Modifier == CriterionModifierExcludes { - // excludes all of the provided ids - if joinTable != "" { - whereClause = "not exists (select " + joinTable + ".movie_id from " + joinTable + " where " + joinTable + ".movie_id = movies.id and " + joinTable + "." + joinTableField + " in " + getInBinding(len(criterion.Value)) + ")" - } else { - whereClause = "not exists (select m.id from movies as m where m.id = movies.id and m." + joinTableField + " in " + getInBinding(len(criterion.Value)) + ")" - } - } - - return whereClause, havingClause -} - func (qb *MovieQueryBuilder) getMovieSort(findFilter *FindFilterType) string { var sort string var direction string diff --git a/pkg/models/querybuilder_movies_test.go b/pkg/models/querybuilder_movies_test.go index a11f94f18..99f7ece8b 100644 --- a/pkg/models/querybuilder_movies_test.go +++ b/pkg/models/querybuilder_movies_test.go @@ -3,6 +3,7 @@ package models_test import ( + "strconv" "strings" "testing" @@ -86,6 +87,42 @@ func TestMovieFindByNames(t *testing.T) { assert.Equal(t, strings.ToLower(movieNames[movieIdxWithScene]), strings.ToLower(movies[1].Name.String)) } +func TestMovieQueryStudio(t *testing.T) { + mqb := models.NewMovieQueryBuilder() + studioCriterion := models.MultiCriterionInput{ + Value: []string{ + strconv.Itoa(studioIDs[studioIdxWithMovie]), + }, + Modifier: models.CriterionModifierIncludes, + } + + movieFilter := models.MovieFilterType{ + Studios: &studioCriterion, + } + + movies, _ := mqb.Query(&movieFilter, nil) + + assert.Len(t, movies, 1) + + // ensure id is correct + assert.Equal(t, movieIDs[movieIdxWithStudio], movies[0].ID) + + studioCriterion = models.MultiCriterionInput{ + Value: []string{ + strconv.Itoa(studioIDs[studioIdxWithMovie]), + }, + Modifier: models.CriterionModifierExcludes, + } + + q := getMovieStringValue(movieIdxWithStudio, titleField) + findFilter := models.FindFilterType{ + Q: &q, + } + + movies, _ = mqb.Query(&movieFilter, &findFilter) + assert.Len(t, movies, 0) +} + // TODO Update // TODO Destroy // TODO Find diff --git a/pkg/models/querybuilder_scene.go b/pkg/models/querybuilder_scene.go index 2a8d162c4..854b824f6 100644 --- a/pkg/models/querybuilder_scene.go +++ b/pkg/models/querybuilder_scene.go @@ -2,7 +2,6 @@ package models import ( "database/sql" - "strconv" "strings" "github.com/jmoiron/sqlx" @@ -329,7 +328,7 @@ func (qb *SceneQueryBuilder) Query(sceneFilter *SceneFilterType, findFilter *Fin } query.body += " LEFT JOIN tags on tags_join.tag_id = tags.id" - whereClause, havingClause := getMultiCriterionClause("tags", "scenes_tags", "tag_id", tagsFilter) + whereClause, havingClause := getMultiCriterionClause("scenes", "tags", "scenes_tags", "scene_id", "tag_id", tagsFilter) query.addWhere(whereClause) query.addHaving(havingClause) } @@ -340,7 +339,7 @@ func (qb *SceneQueryBuilder) Query(sceneFilter *SceneFilterType, findFilter *Fin } query.body += " LEFT JOIN performers ON performers_join.performer_id = performers.id" - whereClause, havingClause := getMultiCriterionClause("performers", "performers_scenes", "performer_id", performersFilter) + whereClause, havingClause := getMultiCriterionClause("scenes", "performers", "performers_scenes", "scene_id", "performer_id", performersFilter) query.addWhere(whereClause) query.addHaving(havingClause) } @@ -350,7 +349,7 @@ func (qb *SceneQueryBuilder) Query(sceneFilter *SceneFilterType, findFilter *Fin query.addArg(studioID) } - whereClause, havingClause := getMultiCriterionClause("studio", "", "studio_id", studiosFilter) + whereClause, havingClause := getMultiCriterionClause("scenes", "studio", "", "", "studio_id", studiosFilter) query.addWhere(whereClause) query.addHaving(havingClause) } @@ -361,7 +360,7 @@ func (qb *SceneQueryBuilder) Query(sceneFilter *SceneFilterType, findFilter *Fin } query.body += " LEFT JOIN movies ON movies_join.movie_id = movies.id" - whereClause, havingClause := getMultiCriterionClause("movies", "movies_scenes", "movie_id", moviesFilter) + whereClause, havingClause := getMultiCriterionClause("scenes", "movies", "movies_scenes", "scene_id", "movie_id", moviesFilter) query.addWhere(whereClause) query.addHaving(havingClause) } @@ -414,29 +413,6 @@ func getDurationWhereClause(durationFilter IntCriterionInput) (string, []interfa return clause, args } -// returns where clause and having clause -func getMultiCriterionClause(table string, joinTable string, joinTableField string, criterion *MultiCriterionInput) (string, string) { - whereClause := "" - havingClause := "" - if criterion.Modifier == CriterionModifierIncludes { - // includes any of the provided ids - whereClause = table + ".id IN " + getInBinding(len(criterion.Value)) - } else if criterion.Modifier == CriterionModifierIncludesAll { - // includes all of the provided ids - whereClause = table + ".id IN " + getInBinding(len(criterion.Value)) - havingClause = "count(distinct " + table + ".id) IS " + strconv.Itoa(len(criterion.Value)) - } else if criterion.Modifier == CriterionModifierExcludes { - // excludes all of the provided ids - if joinTable != "" { - whereClause = "not exists (select " + joinTable + ".scene_id from " + joinTable + " where " + joinTable + ".scene_id = scenes.id and " + joinTable + "." + joinTableField + " in " + getInBinding(len(criterion.Value)) + ")" - } else { - whereClause = "not exists (select s.id from scenes as s where s.id = scenes.id and s." + joinTableField + " in " + getInBinding(len(criterion.Value)) + ")" - } - } - - return whereClause, havingClause -} - func (qb *SceneQueryBuilder) QueryAllByPathRegex(regex string) ([]*Scene, error) { var args []interface{} body := selectDistinctIDs("scenes") + " WHERE scenes.path regexp ?" diff --git a/pkg/models/querybuilder_scene_test.go b/pkg/models/querybuilder_scene_test.go index df6b0e817..161ec49f3 100644 --- a/pkg/models/querybuilder_scene_test.go +++ b/pkg/models/querybuilder_scene_test.go @@ -676,6 +676,42 @@ func TestSceneQueryStudio(t *testing.T) { assert.Len(t, scenes, 0) } +func TestSceneQueryMovies(t *testing.T) { + sqb := models.NewSceneQueryBuilder() + movieCriterion := models.MultiCriterionInput{ + Value: []string{ + strconv.Itoa(movieIDs[movieIdxWithScene]), + }, + Modifier: models.CriterionModifierIncludes, + } + + sceneFilter := models.SceneFilterType{ + Movies: &movieCriterion, + } + + scenes, _ := sqb.Query(&sceneFilter, nil) + + assert.Len(t, scenes, 1) + + // ensure id is correct + assert.Equal(t, sceneIDs[sceneIdxWithMovie], scenes[0].ID) + + movieCriterion = models.MultiCriterionInput{ + Value: []string{ + strconv.Itoa(movieIDs[movieIdxWithScene]), + }, + Modifier: models.CriterionModifierExcludes, + } + + q := getSceneStringValue(sceneIdxWithMovie, titleField) + findFilter := models.FindFilterType{ + Q: &q, + } + + scenes, _ = sqb.Query(&sceneFilter, &findFilter) + assert.Len(t, scenes, 0) +} + func TestSceneQuerySorting(t *testing.T) { sort := titleField direction := models.SortDirectionEnumAsc diff --git a/pkg/models/querybuilder_sql.go b/pkg/models/querybuilder_sql.go index 2944462c2..063ac44c4 100644 --- a/pkg/models/querybuilder_sql.go +++ b/pkg/models/querybuilder_sql.go @@ -239,6 +239,29 @@ func getIntCriterionWhereClause(column string, input IntCriterionInput) (string, return column + " " + binding, count } +// returns where clause and having clause +func getMultiCriterionClause(primaryTable, foreignTable, joinTable, primaryFK, foreignFK string, criterion *MultiCriterionInput) (string, string) { + whereClause := "" + havingClause := "" + if criterion.Modifier == CriterionModifierIncludes { + // includes any of the provided ids + whereClause = foreignTable + ".id IN " + getInBinding(len(criterion.Value)) + } else if criterion.Modifier == CriterionModifierIncludesAll { + // includes all of the provided ids + whereClause = foreignTable + ".id IN " + getInBinding(len(criterion.Value)) + havingClause = "count(distinct " + foreignTable + ".id) IS " + strconv.Itoa(len(criterion.Value)) + } else if criterion.Modifier == CriterionModifierExcludes { + // excludes all of the provided ids + if joinTable != "" { + whereClause = "not exists (select " + joinTable + "." + primaryFK + " from " + joinTable + " where " + joinTable + "." + primaryFK + " = " + primaryTable + ".id and " + joinTable + "." + foreignFK + " in " + getInBinding(len(criterion.Value)) + ")" + } else { + whereClause = "not exists (select s.id from " + primaryTable + " as s where s.id = " + primaryTable + ".id and s." + foreignFK + " in " + getInBinding(len(criterion.Value)) + ")" + } + } + + return whereClause, havingClause +} + func runIdsQuery(query string, args []interface{}) ([]int, error) { var result []struct { Int int `db:"id"` diff --git a/pkg/models/querybuilder_studio.go b/pkg/models/querybuilder_studio.go index 2b65bba2a..90812e956 100644 --- a/pkg/models/querybuilder_studio.go +++ b/pkg/models/querybuilder_studio.go @@ -16,8 +16,8 @@ func NewStudioQueryBuilder() StudioQueryBuilder { func (qb *StudioQueryBuilder) Create(newStudio Studio, tx *sqlx.Tx) (*Studio, error) { ensureTx(tx) result, err := tx.NamedExec( - `INSERT INTO studios (image, checksum, name, url, created_at, updated_at) - VALUES (:image, :checksum, :name, :url, :created_at, :updated_at) + `INSERT INTO studios (image, checksum, name, url, parent_id, created_at, updated_at) + VALUES (:image, :checksum, :name, :url, :parent_id, :created_at, :updated_at) `, newStudio, ) @@ -35,20 +35,21 @@ func (qb *StudioQueryBuilder) Create(newStudio Studio, tx *sqlx.Tx) (*Studio, er return &newStudio, nil } -func (qb *StudioQueryBuilder) Update(updatedStudio Studio, tx *sqlx.Tx) (*Studio, error) { +func (qb *StudioQueryBuilder) Update(updatedStudio StudioPartial, tx *sqlx.Tx) (*Studio, error) { ensureTx(tx) _, err := tx.NamedExec( - `UPDATE studios SET `+SQLGenKeys(updatedStudio)+` WHERE studios.id = :id`, + `UPDATE studios SET `+SQLGenKeysPartial(updatedStudio)+` WHERE studios.id = :id`, updatedStudio, ) if err != nil { return nil, err } - if err := tx.Get(&updatedStudio, `SELECT * FROM studios WHERE id = ? LIMIT 1`, updatedStudio.ID); err != nil { + var ret Studio + if err := tx.Get(&ret, `SELECT * FROM studios WHERE id = ? LIMIT 1`, updatedStudio.ID); err != nil { return nil, err } - return &updatedStudio, nil + return &ret, nil } func (qb *StudioQueryBuilder) Destroy(id string, tx *sqlx.Tx) error { @@ -73,6 +74,12 @@ func (qb *StudioQueryBuilder) Find(id int, tx *sqlx.Tx) (*Studio, error) { return qb.queryStudio(query, args, tx) } +func (qb *StudioQueryBuilder) FindChildren(id int, tx *sqlx.Tx) ([]*Studio, error) { + query := "SELECT studios.* FROM studios WHERE studios.parent_id = ?" + args := []interface{}{id} + return qb.queryStudios(query, args, tx) +} + func (qb *StudioQueryBuilder) FindBySceneID(sceneID int) (*Studio, error) { query := "SELECT studios.* FROM studios JOIN scenes ON studios.id = scenes.studio_id WHERE scenes.id = ? LIMIT 1" args := []interface{}{sceneID} @@ -101,7 +108,10 @@ func (qb *StudioQueryBuilder) AllSlim() ([]*Studio, error) { return qb.queryStudios("SELECT studios.id, studios.name FROM studios "+qb.getStudioSort(nil), nil, nil) } -func (qb *StudioQueryBuilder) Query(findFilter *FindFilterType) ([]*Studio, int) { +func (qb *StudioQueryBuilder) Query(studioFilter *StudioFilterType, findFilter *FindFilterType) ([]*Studio, int) { + if studioFilter == nil { + studioFilter = &StudioFilterType{} + } if findFilter == nil { findFilter = &FindFilterType{} } @@ -116,11 +126,26 @@ func (qb *StudioQueryBuilder) Query(findFilter *FindFilterType) ([]*Studio, int) if q := findFilter.Q; q != nil && *q != "" { searchColumns := []string{"studios.name"} + clause, thisArgs := getSearchBinding(searchColumns, *q, false) whereClauses = append(whereClauses, clause) args = append(args, thisArgs...) } + if parentsFilter := studioFilter.Parents; parentsFilter != nil && len(parentsFilter.Value) > 0 { + body += ` + left join studios as parent_studio on parent_studio.id = studios.parent_id + ` + + for _, studioID := range parentsFilter.Value { + args = append(args, studioID) + } + + whereClause, havingClause := getMultiCriterionClause("studios", "parent_studio", "", "", "parent_id", parentsFilter) + whereClauses = appendClause(whereClauses, whereClause) + havingClauses = appendClause(havingClauses, havingClause) + } + sortAndPagination := qb.getStudioSort(findFilter) + getPagination(findFilter) idsResult, countResult := executeFindQuery("studios", body, args, sortAndPagination, whereClauses, havingClauses) diff --git a/pkg/models/querybuilder_studio_test.go b/pkg/models/querybuilder_studio_test.go index 5f7460844..f5dea9172 100644 --- a/pkg/models/querybuilder_studio_test.go +++ b/pkg/models/querybuilder_studio_test.go @@ -3,9 +3,13 @@ package models_test import ( + "context" + "database/sql" + "strconv" "strings" "testing" + "github.com/stashapp/stash/pkg/database" "github.com/stashapp/stash/pkg/models" "github.com/stretchr/testify/assert" ) @@ -39,6 +43,173 @@ func TestStudioFindByName(t *testing.T) { } +func TestStudioQueryParent(t *testing.T) { + sqb := models.NewStudioQueryBuilder() + studioCriterion := models.MultiCriterionInput{ + Value: []string{ + strconv.Itoa(studioIDs[studioIdxWithChildStudio]), + }, + Modifier: models.CriterionModifierIncludes, + } + + studioFilter := models.StudioFilterType{ + Parents: &studioCriterion, + } + + studios, _ := sqb.Query(&studioFilter, nil) + + assert.Len(t, studios, 1) + + // ensure id is correct + assert.Equal(t, sceneIDs[studioIdxWithParentStudio], studios[0].ID) + + studioCriterion = models.MultiCriterionInput{ + Value: []string{ + strconv.Itoa(studioIDs[studioIdxWithChildStudio]), + }, + Modifier: models.CriterionModifierExcludes, + } + + q := getStudioStringValue(studioIdxWithParentStudio, titleField) + findFilter := models.FindFilterType{ + Q: &q, + } + + studios, _ = sqb.Query(&studioFilter, &findFilter) + assert.Len(t, studios, 0) +} + +func TestStudioDestroyParent(t *testing.T) { + const parentName = "parent" + const childName = "child" + + // create parent and child studios + ctx := context.TODO() + tx := database.DB.MustBeginTx(ctx, nil) + + createdParent, err := createStudio(tx, parentName, nil) + if err != nil { + tx.Rollback() + t.Fatalf("Error creating parent studio: %s", err.Error()) + } + + parentID := int64(createdParent.ID) + createdChild, err := createStudio(tx, childName, &parentID) + if err != nil { + tx.Rollback() + t.Fatalf("Error creating child studio: %s", err.Error()) + } + + if err := tx.Commit(); err != nil { + tx.Rollback() + t.Fatalf("Error committing: %s", err.Error()) + } + + sqb := models.NewStudioQueryBuilder() + + // destroy the parent + tx = database.DB.MustBeginTx(ctx, nil) + + err = sqb.Destroy(strconv.Itoa(createdParent.ID), tx) + if err != nil { + tx.Rollback() + t.Fatalf("Error destroying parent studio: %s", err.Error()) + } + + if err := tx.Commit(); err != nil { + tx.Rollback() + t.Fatalf("Error committing: %s", err.Error()) + } + + // destroy the child + tx = database.DB.MustBeginTx(ctx, nil) + + err = sqb.Destroy(strconv.Itoa(createdChild.ID), tx) + if err != nil { + tx.Rollback() + t.Fatalf("Error destroying child studio: %s", err.Error()) + } + + if err := tx.Commit(); err != nil { + tx.Rollback() + t.Fatalf("Error committing: %s", err.Error()) + } +} + +func TestStudioFindChildren(t *testing.T) { + sqb := models.NewStudioQueryBuilder() + + studios, err := sqb.FindChildren(studioIDs[studioIdxWithChildStudio], nil) + + if err != nil { + t.Fatalf("error calling FindChildren: %s", err.Error()) + } + + assert.Len(t, studios, 1) + assert.Equal(t, studioIDs[studioIdxWithParentStudio], studios[0].ID) + + studios, err = sqb.FindChildren(0, nil) + + if err != nil { + t.Fatalf("error calling FindChildren: %s", err.Error()) + } + + assert.Len(t, studios, 0) +} + +func TestStudioUpdateClearParent(t *testing.T) { + const parentName = "clearParent_parent" + const childName = "clearParent_child" + + // create parent and child studios + ctx := context.TODO() + tx := database.DB.MustBeginTx(ctx, nil) + + createdParent, err := createStudio(tx, parentName, nil) + if err != nil { + tx.Rollback() + t.Fatalf("Error creating parent studio: %s", err.Error()) + } + + parentID := int64(createdParent.ID) + createdChild, err := createStudio(tx, childName, &parentID) + if err != nil { + tx.Rollback() + t.Fatalf("Error creating child studio: %s", err.Error()) + } + + if err := tx.Commit(); err != nil { + tx.Rollback() + t.Fatalf("Error committing: %s", err.Error()) + } + + sqb := models.NewStudioQueryBuilder() + + // clear the parent id from the child + tx = database.DB.MustBeginTx(ctx, nil) + + updatePartial := models.StudioPartial{ + ID: createdChild.ID, + ParentID: &sql.NullInt64{Valid: false}, + } + + updatedStudio, err := sqb.Update(updatePartial, tx) + + if err != nil { + tx.Rollback() + t.Fatalf("Error updated studio: %s", err.Error()) + } + + if updatedStudio.ParentID.Valid { + t.Error("updated studio has parent ID set") + } + + if err := tx.Commit(); err != nil { + tx.Rollback() + t.Fatalf("Error committing: %s", err.Error()) + } +} + // TODO Create // TODO Update // TODO Destroy diff --git a/pkg/models/setup_test.go b/pkg/models/setup_test.go index 0cfb25328..95e0a1230 100644 --- a/pkg/models/setup_test.go +++ b/pkg/models/setup_test.go @@ -22,12 +22,12 @@ import ( const totalScenes = 12 const performersNameCase = 3 const performersNameNoCase = 2 -const moviesNameCase = 1 +const moviesNameCase = 2 const moviesNameNoCase = 1 const totalGalleries = 1 const tagsNameNoCase = 2 const tagsNameCase = 5 -const studiosNameCase = 1 +const studiosNameCase = 4 const studiosNameNoCase = 1 var sceneIDs []int @@ -61,9 +61,10 @@ const performerIdx1WithDupName = 3 const performerIdxWithDupName = 4 const movieIdxWithScene = 0 +const movieIdxWithStudio = 1 // movies with dup names start from the end -const movieIdxWithDupName = 1 +const movieIdxWithDupName = 2 const galleryIdxWithScene = 0 @@ -78,9 +79,12 @@ const tagIdx1WithDupName = 5 const tagIdxWithDupName = 6 const studioIdxWithScene = 0 +const studioIdxWithMovie = 1 +const studioIdxWithChildStudio = 2 +const studioIdxWithParentStudio = 3 // studios with dup names start from the end -const studioIdxWithDupName = 1 +const studioIdxWithDupName = 4 const markerIdxWithScene = 0 @@ -196,6 +200,16 @@ func populateDB() error { return err } + if err := linkMovieStudio(tx, movieIdxWithStudio, studioIdxWithMovie); err != nil { + tx.Rollback() + return err + } + + if err := linkStudioParent(tx, studioIdxWithChildStudio, studioIdxWithParentStudio); err != nil { + tx.Rollback() + return err + } + if err := createMarker(tx, sceneIdxWithMarker, tagIdxWithPrimaryMarker, []int{tagIdxWithMarker}); err != nil { tx.Rollback() return err @@ -312,10 +326,9 @@ func createMovies(tx *sqlx.Tx, n int, o int) error { const namePlain = "Name" const nameNoCase = "NaMe" - name := namePlain - for i := 0; i < n+o; i++ { index := i + name := namePlain if i >= n { // i=n movies get dup names if case is not checked @@ -323,8 +336,9 @@ func createMovies(tx *sqlx.Tx, n int, o int) error { } // so count backwards to 0 as needed // movies [ i ] and [ n + o - i - 1 ] should have similar names with only the Name!=NaMe part different + name = getMovieStringValue(index, name) movie := models.Movie{ - Name: sql.NullString{String: getMovieStringValue(index, name), Valid: true}, + Name: sql.NullString{String: name, Valid: true}, FrontImage: []byte(models.DefaultMovieImage), Checksum: utils.MD5FromString(name), } @@ -332,7 +346,7 @@ func createMovies(tx *sqlx.Tx, n int, o int) error { created, err := mqb.Create(movie, tx) if err != nil { - return fmt.Errorf("Error creating movie %v+: %s", movie, err.Error()) + return fmt.Errorf("Error creating movie [%d] %v+: %s", i, movie, err.Error()) } movieIDs = append(movieIDs, created.ID) @@ -432,16 +446,35 @@ func getStudioStringValue(index int, field string) string { return "studio_" + strconv.FormatInt(int64(index), 10) + "_" + field } +func createStudio(tx *sqlx.Tx, name string, parentID *int64) (*models.Studio, error) { + sqb := models.NewStudioQueryBuilder() + studio := models.Studio{ + Name: sql.NullString{String: name, Valid: true}, + Image: []byte(models.DefaultStudioImage), + Checksum: utils.MD5FromString(name), + } + + if parentID != nil { + studio.ParentID = sql.NullInt64{Int64: *parentID, Valid: true} + } + + created, err := sqb.Create(studio, tx) + + if err != nil { + return nil, fmt.Errorf("Error creating studio %v+: %s", studio, err.Error()) + } + + return created, nil +} + //createStudios creates n studios with plain Name and o studios with camel cased NaMe included func createStudios(tx *sqlx.Tx, n int, o int) error { - sqb := models.NewStudioQueryBuilder() const namePlain = "Name" const nameNoCase = "NaMe" - name := namePlain - for i := 0; i < n+o; i++ { index := i + name := namePlain if i >= n { // i=n studios get dup names if case is not checked @@ -449,16 +482,11 @@ func createStudios(tx *sqlx.Tx, n int, o int) error { } // so count backwards to 0 as needed // studios [ i ] and [ n + o - i - 1 ] should have similar names with only the Name!=NaMe part different - tag := models.Studio{ - Name: sql.NullString{String: getStudioStringValue(index, name), Valid: true}, - Image: []byte(models.DefaultStudioImage), - Checksum: utils.MD5FromString(name), - } - - created, err := sqb.Create(tag, tx) + name = getStudioStringValue(index, name) + created, err := createStudio(tx, name, nil) if err != nil { - return fmt.Errorf("Error creating studio %v+: %s", tag, err.Error()) + return err } studioIDs = append(studioIDs, created.ID) @@ -582,3 +610,27 @@ func linkSceneStudio(tx *sqlx.Tx, sceneIndex, studioIndex int) error { return err } + +func linkMovieStudio(tx *sqlx.Tx, movieIndex, studioIndex int) error { + mqb := models.NewMovieQueryBuilder() + + movie := models.MoviePartial{ + ID: movieIDs[movieIndex], + StudioID: &sql.NullInt64{Int64: int64(studioIDs[studioIndex]), Valid: true}, + } + _, err := mqb.Update(movie, tx) + + return err +} + +func linkStudioParent(tx *sqlx.Tx, parentIndex, childIndex int) error { + sqb := models.NewStudioQueryBuilder() + + studio := models.StudioPartial{ + ID: studioIDs[childIndex], + ParentID: &sql.NullInt64{Int64: int64(studioIDs[parentIndex]), Valid: true}, + } + _, err := sqb.Update(studio, tx) + + return err +} diff --git a/ui/v2.5/src/components/Changelog/versions/v030.tsx b/ui/v2.5/src/components/Changelog/versions/v030.tsx index 780a84047..94cf6a922 100644 --- a/ui/v2.5/src/components/Changelog/versions/v030.tsx +++ b/ui/v2.5/src/components/Changelog/versions/v030.tsx @@ -2,6 +2,9 @@ import React from "react"; import ReactMarkdown from "react-markdown"; const markup = ` +### ✨ New Features +* Add support for parent/child studios. + ### 🎨 Improvements * Show rating as stars in scene page. * Add reload scrapers button. diff --git a/ui/v2.5/src/components/List/AddFilter.tsx b/ui/v2.5/src/components/List/AddFilter.tsx index 5ed6f1b37..3b7b61bad 100644 --- a/ui/v2.5/src/components/List/AddFilter.tsx +++ b/ui/v2.5/src/components/List/AddFilter.tsx @@ -140,6 +140,7 @@ export const AddFilter: React.FC = ( if ( criterion.type !== "performers" && criterion.type !== "studios" && + criterion.type !== "parent_studios" && criterion.type !== "tags" && criterion.type !== "sceneTags" && criterion.type !== "movies" diff --git a/ui/v2.5/src/components/Shared/Select.tsx b/ui/v2.5/src/components/Shared/Select.tsx index 737705dcc..9b75d2ea7 100644 --- a/ui/v2.5/src/components/Shared/Select.tsx +++ b/ui/v2.5/src/components/Shared/Select.tsx @@ -23,7 +23,13 @@ type ValidTypes = type Option = { value: string; label: string }; interface ITypeProps { - type?: "performers" | "studios" | "tags" | "sceneTags" | "movies"; + type?: + | "performers" + | "studios" + | "parent_studios" + | "tags" + | "sceneTags" + | "movies"; } interface IFilterProps { ids?: string[]; @@ -175,7 +181,7 @@ export const MarkerTitleSuggest: React.FC = (props) => { export const FilterSelect: React.FC = (props) => props.type === "performers" ? ( - ) : props.type === "studios" ? ( + ) : props.type === "studios" || props.type === "parent_studios" ? ( ) : props.type === "movies" ? ( diff --git a/ui/v2.5/src/components/Studios/StudioCard.tsx b/ui/v2.5/src/components/Studios/StudioCard.tsx index 4b1f8f86e..8a1bc4b7b 100644 --- a/ui/v2.5/src/components/Studios/StudioCard.tsx +++ b/ui/v2.5/src/components/Studios/StudioCard.tsx @@ -3,12 +3,45 @@ import React from "react"; import { Link } from "react-router-dom"; import * as GQL from "src/core/generated-graphql"; import { FormattedPlural } from "react-intl"; +import { NavUtils } from "src/utils"; interface IProps { studio: GQL.StudioDataFragment; + hideParent?: boolean; } -export const StudioCard: React.FC = ({ studio }) => { +function maybeRenderParent( + studio: GQL.StudioDataFragment, + hideParent?: boolean +) { + if (!hideParent && studio.parent_studio) { + return ( +
+ Part of  + + {studio.parent_studio.name} + + . +
+ ); + } +} + +function maybeRenderChildren(studio: GQL.StudioDataFragment) { + if (studio.child_studios.length > 0) { + return ( +
+ Parent of  + + {studio.child_studios.length} studios + + . +
+ ); + } +} + +export const StudioCard: React.FC = ({ studio, hideParent }) => { return ( @@ -29,6 +62,8 @@ export const StudioCard: React.FC = ({ studio }) => { /> . + {maybeRenderParent(studio, hideParent)} + {maybeRenderChildren(studio)} ); diff --git a/ui/v2.5/src/components/Studios/StudioDetails/Studio.tsx b/ui/v2.5/src/components/Studios/StudioDetails/Studio.tsx index 80fef1d59..84d105978 100644 --- a/ui/v2.5/src/components/Studios/StudioDetails/Studio.tsx +++ b/ui/v2.5/src/components/Studios/StudioDetails/Studio.tsx @@ -1,6 +1,6 @@ /* eslint-disable react/no-this-in-sfc */ -import { Table } from "react-bootstrap"; +import { Table, Tabs, Tab } from "react-bootstrap"; import React, { useEffect, useState } from "react"; import { useParams, useHistory } from "react-router-dom"; import cx from "classnames"; @@ -18,9 +18,11 @@ import { DetailsEditNavbar, Modal, LoadingIndicator, + StudioSelect, } from "src/components/Shared"; import { useToast } from "src/hooks"; import { StudioScenesPanel } from "./StudioScenesPanel"; +import { StudioChildrenPanel } from "./StudioChildrenPanel"; export const Studio: React.FC = () => { const history = useHistory(); @@ -36,6 +38,7 @@ export const Studio: React.FC = () => { const [image, setImage] = useState(); const [name, setName] = useState(); const [url, setUrl] = useState(); + const [parentStudioId, setParentStudioId] = useState(); // Studio state const [studio, setStudio] = useState>({}); @@ -55,6 +58,7 @@ export const Studio: React.FC = () => { function updateStudioEditState(state: Partial) { setName(state.name); setUrl(state.url ?? undefined); + setParentStudioId(state?.parent_studio?.id ?? undefined); } function updateStudioData(studioData: Partial) { @@ -89,6 +93,7 @@ export const Studio: React.FC = () => { const input: Partial = { name, url, + parent_id: parentStudioId, image, }; @@ -189,6 +194,20 @@ export const Studio: React.FC = () => { isEditing: !!isEditing, onChange: setUrl, })} + + Parent Studio + + + setParentStudioId( + items.length > 0 ? items[0]?.id : undefined + ) + } + ids={parentStudioId ? [parentStudioId] : []} + isDisabled={!isEditing} + /> + + { {!isNew && (
- + + + + + + + +
)} {renderDeleteAlert()} diff --git a/ui/v2.5/src/components/Studios/StudioDetails/StudioChildrenPanel.tsx b/ui/v2.5/src/components/Studios/StudioDetails/StudioChildrenPanel.tsx new file mode 100644 index 000000000..44c7ebf39 --- /dev/null +++ b/ui/v2.5/src/components/Studios/StudioDetails/StudioChildrenPanel.tsx @@ -0,0 +1,47 @@ +import React from "react"; +import * as GQL from "src/core/generated-graphql"; +import { ParentStudiosCriterion } from "src/models/list-filter/criteria/studios"; +import { ListFilterModel } from "src/models/list-filter/filter"; +import { StudioList } from "../StudioList"; + +interface IStudioChildrenPanel { + studio: Partial; +} + +export const StudioChildrenPanel: React.FC = ({ + studio, +}) => { + function filterHook(filter: ListFilterModel) { + const studioValue = { id: studio.id!, label: studio.name! }; + // if studio is already present, then we modify it, otherwise add + let parentStudioCriterion = filter.criteria.find((c) => { + return c.type === "parent_studios"; + }) as ParentStudiosCriterion; + + if ( + parentStudioCriterion && + (parentStudioCriterion.modifier === GQL.CriterionModifier.IncludesAll || + parentStudioCriterion.modifier === GQL.CriterionModifier.Includes) + ) { + // add the studio if not present + if ( + !parentStudioCriterion.value.find((p) => { + return p.id === studio.id; + }) + ) { + parentStudioCriterion.value.push(studioValue); + } + + parentStudioCriterion.modifier = GQL.CriterionModifier.IncludesAll; + } else { + // overwrite + parentStudioCriterion = new ParentStudiosCriterion(); + parentStudioCriterion.value = [studioValue]; + filter.criteria.push(parentStudioCriterion); + } + + return filter; + } + + return ; +}; diff --git a/ui/v2.5/src/components/Studios/StudioList.tsx b/ui/v2.5/src/components/Studios/StudioList.tsx index 960743af4..10d2dde7f 100644 --- a/ui/v2.5/src/components/Studios/StudioList.tsx +++ b/ui/v2.5/src/components/Studios/StudioList.tsx @@ -5,9 +5,19 @@ import { ListFilterModel } from "src/models/list-filter/filter"; import { DisplayMode } from "src/models/list-filter/types"; import { StudioCard } from "./StudioCard"; -export const StudioList: React.FC = () => { +interface IStudioList { + fromParent?: boolean; + filterHook?: (filter: ListFilterModel) => ListFilterModel; +} + +export const StudioList: React.FC = ({ + fromParent, + filterHook, +}) => { const listData = useStudiosList({ renderContent, + subComponent: fromParent, + filterHook, }); function renderContent( @@ -20,7 +30,11 @@ export const StudioList: React.FC = () => { return (
{result.data.findStudios.studios.map((studio) => ( - + ))}
); diff --git a/ui/v2.5/src/core/StashService.ts b/ui/v2.5/src/core/StashService.ts index ffe6e6d11..103dd5a57 100644 --- a/ui/v2.5/src/core/StashService.ts +++ b/ui/v2.5/src/core/StashService.ts @@ -74,6 +74,7 @@ export const useFindStudios = (filter: ListFilterModel) => GQL.useFindStudiosQuery({ variables: { filter: filter.makeFindFilter(), + studio_filter: filter.makeStudioFilter(), }, }); diff --git a/ui/v2.5/src/models/list-filter/criteria/criterion.ts b/ui/v2.5/src/models/list-filter/criteria/criterion.ts index 8a142eff3..d94c65ad8 100644 --- a/ui/v2.5/src/models/list-filter/criteria/criterion.ts +++ b/ui/v2.5/src/models/list-filter/criteria/criterion.ts @@ -31,7 +31,8 @@ export type CriterionType = | "tattoos" | "piercings" | "aliases" - | "gender"; + | "gender" + | "parent_studios"; type Option = string | number | IOptionType; export type CriterionValue = string | number | ILabeledId[]; @@ -93,6 +94,8 @@ export abstract class Criterion { return "Aliases"; case "gender": return "Gender"; + case "parent_studios": + return "Parent Studios"; } } diff --git a/ui/v2.5/src/models/list-filter/criteria/studios.ts b/ui/v2.5/src/models/list-filter/criteria/studios.ts index 0e734fb6b..dad7fa7a6 100644 --- a/ui/v2.5/src/models/list-filter/criteria/studios.ts +++ b/ui/v2.5/src/models/list-filter/criteria/studios.ts @@ -24,3 +24,13 @@ export class StudiosCriterionOption implements ICriterionOption { public label: string = Criterion.getLabel("studios"); public value: CriterionType = "studios"; } + +export class ParentStudiosCriterion extends StudiosCriterion { + public type: CriterionType = "parent_studios"; + public parameterName: string = "parents"; +} + +export class ParentStudiosCriterionOption implements ICriterionOption { + public label: string = Criterion.getLabel("parent_studios"); + public value: CriterionType = "parent_studios"; +} diff --git a/ui/v2.5/src/models/list-filter/criteria/utils.ts b/ui/v2.5/src/models/list-filter/criteria/utils.ts index 785d5f1dd..5907b03b7 100644 --- a/ui/v2.5/src/models/list-filter/criteria/utils.ts +++ b/ui/v2.5/src/models/list-filter/criteria/utils.ts @@ -17,7 +17,7 @@ import { NoneCriterion } from "./none"; import { PerformersCriterion } from "./performers"; import { RatingCriterion } from "./rating"; import { ResolutionCriterion } from "./resolution"; -import { StudiosCriterion } from "./studios"; +import { StudiosCriterion, ParentStudiosCriterion } from "./studios"; import { TagsCriterion } from "./tags"; import { GenderCriterion } from "./gender"; import { MoviesCriterion } from "./movies"; @@ -50,6 +50,8 @@ export function makeCriteria(type: CriterionType = "none") { return new PerformersCriterion(); case "studios": return new StudiosCriterion(); + case "parent_studios": + return new ParentStudiosCriterion(); case "movies": return new MoviesCriterion(); case "birth_year": diff --git a/ui/v2.5/src/models/list-filter/filter.ts b/ui/v2.5/src/models/list-filter/filter.ts index 045893ec6..57dd106d3 100644 --- a/ui/v2.5/src/models/list-filter/filter.ts +++ b/ui/v2.5/src/models/list-filter/filter.ts @@ -7,6 +7,7 @@ import { SceneMarkerFilterType, SortDirectionEnum, MovieFilterType, + StudioFilterType, } from "src/core/generated-graphql"; import { stringToGender } from "src/core/StashService"; import { @@ -41,7 +42,12 @@ import { ResolutionCriterion, ResolutionCriterionOption, } from "./criteria/resolution"; -import { StudiosCriterion, StudiosCriterionOption } from "./criteria/studios"; +import { + StudiosCriterion, + StudiosCriterionOption, + ParentStudiosCriterion, + ParentStudiosCriterionOption, +} from "./criteria/studios"; import { SceneTagsCriterionOption, TagsCriterion, @@ -159,7 +165,10 @@ export class ListFilterModel { this.sortBy = "name"; this.sortByOptions = ["name", "scenes_count"]; this.displayModeOptions = [DisplayMode.Grid]; - this.criterionOptions = [new NoneCriterionOption()]; + this.criterionOptions = [ + new NoneCriterionOption(), + new ParentStudiosCriterionOption(), + ]; break; case FilterMode.Movies: this.sortBy = "name"; @@ -576,4 +585,22 @@ export class ListFilterModel { }); return result; } + + public makeStudioFilter(): StudioFilterType { + const result: StudioFilterType = {}; + this.criteria.forEach((criterion) => { + switch (criterion.type) { + case "parent_studios": { + const studCrit = criterion as ParentStudiosCriterion; + result.parents = { + value: studCrit.value.map((studio) => studio.id), + modifier: studCrit.modifier, + }; + break; + } + // no default + } + }); + return result; + } } diff --git a/ui/v2.5/src/utils/navigation.ts b/ui/v2.5/src/utils/navigation.ts index da1fa9d04..a4867bf6c 100644 --- a/ui/v2.5/src/utils/navigation.ts +++ b/ui/v2.5/src/utils/navigation.ts @@ -1,6 +1,9 @@ import * as GQL from "src/core/generated-graphql"; import { PerformersCriterion } from "src/models/list-filter/criteria/performers"; -import { StudiosCriterion } from "src/models/list-filter/criteria/studios"; +import { + StudiosCriterion, + ParentStudiosCriterion, +} from "src/models/list-filter/criteria/studios"; import { TagsCriterion } from "src/models/list-filter/criteria/tags"; import { ListFilterModel } from "src/models/list-filter/filter"; import { FilterMode } from "src/models/list-filter/types"; @@ -30,6 +33,17 @@ const makeStudioScenesUrl = (studio: Partial) => { return `/scenes?${filter.makeQueryParameters()}`; }; +const makeChildStudiosUrl = (studio: Partial) => { + if (!studio.id) return "#"; + const filter = new ListFilterModel(FilterMode.Studios); + const criterion = new ParentStudiosCriterion(); + criterion.value = [ + { id: studio.id, label: studio.name || `Studio ${studio.id}` }, + ]; + filter.criteria.push(criterion); + return `/studios?${filter.makeQueryParameters()}`; +}; + const makeMovieScenesUrl = (movie: Partial) => { if (!movie.id) return "#"; const filter = new ListFilterModel(FilterMode.Scenes); @@ -73,4 +87,5 @@ export default { makeTagScenesUrl, makeSceneMarkerUrl, makeMovieScenesUrl, + makeChildStudiosUrl, };