Saved filters (#1474)

* Refactor list filter
* Filter/criterion refactor
* Rename option value to type
* Remove None from options
* Add saved filter button
* Integrate default filters
This commit is contained in:
WithoutPants
2021-06-16 14:53:32 +10:00
committed by GitHub
parent 4fe4da6c01
commit dc7584d77e
74 changed files with 2583 additions and 1263 deletions

109
pkg/sqlite/saved_filter.go Normal file
View File

@@ -0,0 +1,109 @@
package sqlite
import (
"database/sql"
"fmt"
"github.com/stashapp/stash/pkg/models"
)
const savedFilterTable = "saved_filters"
const savedFilterDefaultName = ""
type savedFilterQueryBuilder struct {
repository
}
func NewSavedFilterReaderWriter(tx dbi) *savedFilterQueryBuilder {
return &savedFilterQueryBuilder{
repository{
tx: tx,
tableName: savedFilterTable,
idColumn: idColumn,
},
}
}
func (qb *savedFilterQueryBuilder) Create(newObject models.SavedFilter) (*models.SavedFilter, error) {
var ret models.SavedFilter
if err := qb.insertObject(newObject, &ret); err != nil {
return nil, err
}
return &ret, nil
}
func (qb *savedFilterQueryBuilder) Update(updatedObject models.SavedFilter) (*models.SavedFilter, error) {
const partial = false
if err := qb.update(updatedObject.ID, updatedObject, partial); err != nil {
return nil, err
}
var ret models.SavedFilter
if err := qb.get(updatedObject.ID, &ret); err != nil {
return nil, err
}
return &ret, nil
}
func (qb *savedFilterQueryBuilder) SetDefault(obj models.SavedFilter) (*models.SavedFilter, error) {
// find the existing default
existing, err := qb.FindDefault(obj.Mode)
if err != nil {
return nil, err
}
obj.Name = savedFilterDefaultName
if existing != nil {
obj.ID = existing.ID
return qb.Update(obj)
}
return qb.Create(obj)
}
func (qb *savedFilterQueryBuilder) Destroy(id int) error {
return qb.destroyExisting([]int{id})
}
func (qb *savedFilterQueryBuilder) Find(id int) (*models.SavedFilter, error) {
var ret models.SavedFilter
if err := qb.get(id, &ret); err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
return &ret, nil
}
func (qb *savedFilterQueryBuilder) FindByMode(mode models.FilterMode) ([]*models.SavedFilter, error) {
// exclude empty-named filters - these are the internal default filters
query := fmt.Sprintf(`SELECT * FROM %s WHERE mode = ? AND name != ?`, savedFilterTable)
var ret models.SavedFilters
if err := qb.query(query, []interface{}{mode, savedFilterDefaultName}, &ret); err != nil {
return nil, err
}
return []*models.SavedFilter(ret), nil
}
func (qb *savedFilterQueryBuilder) FindDefault(mode models.FilterMode) (*models.SavedFilter, error) {
query := fmt.Sprintf(`SELECT * FROM %s WHERE mode = ? AND name = ?`, savedFilterTable)
var ret models.SavedFilters
if err := qb.query(query, []interface{}{mode, savedFilterDefaultName}, &ret); err != nil {
return nil, err
}
if len(ret) > 0 {
return ret[0], nil
}
return nil, nil
}

View File

@@ -0,0 +1,123 @@
// +build integration
package sqlite_test
import (
"testing"
"github.com/stashapp/stash/pkg/models"
"github.com/stretchr/testify/assert"
)
func TestSavedFilterFind(t *testing.T) {
withTxn(func(r models.Repository) error {
savedFilter, err := r.SavedFilter().Find(savedFilterIDs[savedFilterIdxImage])
if err != nil {
t.Errorf("Error finding saved filter: %s", err.Error())
}
assert.Equal(t, savedFilterIDs[savedFilterIdxImage], savedFilter.ID)
return nil
})
}
func TestSavedFilterFindByMode(t *testing.T) {
withTxn(func(r models.Repository) error {
savedFilters, err := r.SavedFilter().FindByMode(models.FilterModeScenes)
if err != nil {
t.Errorf("Error finding saved filters: %s", err.Error())
}
assert.Len(t, savedFilters, 1)
assert.Equal(t, savedFilterIDs[savedFilterIdxScene], savedFilters[0].ID)
return nil
})
}
func TestSavedFilterDestroy(t *testing.T) {
const filterName = "filterToDestroy"
const testFilter = "{}"
var id int
// create the saved filter to destroy
withTxn(func(r models.Repository) error {
created, err := r.SavedFilter().Create(models.SavedFilter{
Name: filterName,
Mode: models.FilterModeScenes,
Filter: testFilter,
})
if err == nil {
id = created.ID
}
return err
})
withTxn(func(r models.Repository) error {
qb := r.SavedFilter()
return qb.Destroy(id)
})
// now try to find it
withTxn(func(r models.Repository) error {
found, err := r.SavedFilter().Find(id)
if err == nil {
assert.Nil(t, found)
}
return err
})
}
func TestSavedFilterFindDefault(t *testing.T) {
withTxn(func(r models.Repository) error {
def, err := r.SavedFilter().FindDefault(models.FilterModeScenes)
if err == nil {
assert.Equal(t, savedFilterIDs[savedFilterIdxDefaultScene], def.ID)
}
return err
})
}
func TestSavedFilterSetDefault(t *testing.T) {
const newFilter = "foo"
withTxn(func(r models.Repository) error {
_, err := r.SavedFilter().SetDefault(models.SavedFilter{
Mode: models.FilterModeMovies,
Filter: newFilter,
})
return err
})
var defID int
withTxn(func(r models.Repository) error {
def, err := r.SavedFilter().FindDefault(models.FilterModeMovies)
if err == nil {
defID = def.ID
assert.Equal(t, newFilter, def.Filter)
}
return err
})
// destroy it again
withTxn(func(r models.Repository) error {
return r.SavedFilter().Destroy(defID)
})
}
// TODO Update
// TODO Destroy
// TODO Find
// TODO GetMarkerStrings
// TODO Wall
// TODO Query

View File

@@ -189,22 +189,34 @@ const (
)
const (
pathField = "Path"
checksumField = "Checksum"
titleField = "Title"
urlField = "URL"
zipPath = "zipPath.zip"
savedFilterIdxDefaultScene = iota
savedFilterIdxDefaultImage
savedFilterIdxScene
savedFilterIdxImage
// new indexes above
totalSavedFilters
)
const (
pathField = "Path"
checksumField = "Checksum"
titleField = "Title"
urlField = "URL"
zipPath = "zipPath.zip"
firstSavedFilterName = "firstSavedFilterName"
)
var (
sceneIDs []int
imageIDs []int
performerIDs []int
movieIDs []int
galleryIDs []int
tagIDs []int
studioIDs []int
markerIDs []int
sceneIDs []int
imageIDs []int
performerIDs []int
movieIDs []int
galleryIDs []int
tagIDs []int
studioIDs []int
markerIDs []int
savedFilterIDs []int
tagNames []string
studioNames []string
@@ -423,6 +435,10 @@ func populateDB() error {
return fmt.Errorf("error creating studios: %s", err.Error())
}
if err := createSavedFilters(r.SavedFilter(), totalSavedFilters); err != nil {
return fmt.Errorf("error creating saved filters: %s", err.Error())
}
if err := linkPerformerTags(r.Performer()); err != nil {
return fmt.Errorf("error linking performer tags: %s", err.Error())
}
@@ -979,6 +995,51 @@ func createMarker(mqb models.SceneMarkerReaderWriter, sceneIdx, primaryTagIdx in
return nil
}
func getSavedFilterMode(index int) models.FilterMode {
switch index {
case savedFilterIdxScene, savedFilterIdxDefaultScene:
return models.FilterModeScenes
case savedFilterIdxImage, savedFilterIdxDefaultImage:
return models.FilterModeImages
default:
return models.FilterModeScenes
}
}
func getSavedFilterName(index int) string {
if index <= savedFilterIdxDefaultImage {
// empty string for default filters
return ""
}
if index <= savedFilterIdxImage {
// use the same name for the first two - should be possible
return firstSavedFilterName
}
return getPrefixedStringValue("savedFilter", index, "Name")
}
func createSavedFilters(qb models.SavedFilterReaderWriter, n int) error {
for i := 0; i < n; i++ {
savedFilter := models.SavedFilter{
Mode: getSavedFilterMode(i),
Name: getSavedFilterName(i),
Filter: getPrefixedStringValue("savedFilter", i, "Filter"),
}
created, err := qb.Create(savedFilter)
if err != nil {
return fmt.Errorf("Error creating saved filter %v+: %s", savedFilter, err.Error())
}
savedFilterIDs = append(savedFilterIDs, created.ID)
}
return nil
}
func doLinks(links [][2]int, fn func(idx1, idx2 int) error) error {
for _, l := range links {
if err := fn(l[0], l[1]); err != nil {

View File

@@ -125,6 +125,11 @@ func (t *transaction) Tag() models.TagReaderWriter {
return NewTagReaderWriter(t.tx)
}
func (t *transaction) SavedFilter() models.SavedFilterReaderWriter {
t.ensureTx()
return NewSavedFilterReaderWriter(t.tx)
}
type ReadTransaction struct{}
func (t *ReadTransaction) Begin() error {
@@ -183,6 +188,10 @@ func (t *ReadTransaction) Tag() models.TagReader {
return NewTagReaderWriter(database.DB)
}
func (t *ReadTransaction) SavedFilter() models.SavedFilterReader {
return NewSavedFilterReaderWriter(database.DB)
}
type TransactionManager struct {
}