Auto tag rewrite (#1324)

This commit is contained in:
WithoutPants
2021-04-26 12:51:31 +10:00
committed by GitHub
parent f66010a367
commit 2eb2d865dc
26 changed files with 1469 additions and 370 deletions

View File

@@ -5,12 +5,15 @@ import (
"errors"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
"github.com/remeh/sizedwaitgroup"
"github.com/stashapp/stash/pkg/autotag"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/manager/config"
"github.com/stashapp/stash/pkg/models"
@@ -626,6 +629,15 @@ func (s *singleton) generateScreenshot(sceneId string, at *float64) {
}()
}
func (s *singleton) isFileBasedAutoTag(input models.AutoTagMetadataInput) bool {
const wildcard = "*"
performerIds := input.Performers
studioIds := input.Studios
tagIds := input.Tags
return (len(performerIds) == 0 || performerIds[0] == wildcard) && (len(studioIds) == 0 || studioIds[0] == wildcard) && (len(tagIds) == 0 || tagIds[0] == wildcard)
}
func (s *singleton) AutoTag(input models.AutoTagMetadataInput) {
if s.Status.Status != Idle {
return
@@ -636,58 +648,160 @@ func (s *singleton) AutoTag(input models.AutoTagMetadataInput) {
go func() {
defer s.returnToIdleState()
performerIds := input.Performers
studioIds := input.Studios
tagIds := input.Tags
// calculate work load
performerCount := len(performerIds)
studioCount := len(studioIds)
tagCount := len(tagIds)
if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
performerQuery := r.Performer()
studioQuery := r.Studio()
tagQuery := r.Tag()
const wildcard = "*"
var err error
if performerCount == 1 && performerIds[0] == wildcard {
performerCount, err = performerQuery.Count()
if err != nil {
return fmt.Errorf("Error getting performer count: %s", err.Error())
}
}
if studioCount == 1 && studioIds[0] == wildcard {
studioCount, err = studioQuery.Count()
if err != nil {
return fmt.Errorf("Error getting studio count: %s", err.Error())
}
}
if tagCount == 1 && tagIds[0] == wildcard {
tagCount, err = tagQuery.Count()
if err != nil {
return fmt.Errorf("Error getting tag count: %s", err.Error())
}
}
return nil
}); err != nil {
logger.Error(err.Error())
return
if s.isFileBasedAutoTag(input) {
// doing file-based auto-tag
s.autoTagScenes(input.Paths, len(input.Performers) > 0, len(input.Studios) > 0, len(input.Tags) > 0)
} else {
// doing specific performer/studio/tag auto-tag
s.autoTagSpecific(input)
}
total := performerCount + studioCount + tagCount
s.Status.setProgress(0, total)
s.autoTagPerformers(input.Paths, performerIds)
s.autoTagStudios(input.Paths, studioIds)
s.autoTagTags(input.Paths, tagIds)
}()
}
func (s *singleton) autoTagScenes(paths []string, performers, studios, tags bool) {
if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
ret := &models.SceneFilterType{}
or := ret
sep := string(filepath.Separator)
for _, p := range paths {
if !strings.HasSuffix(p, sep) {
p = p + sep
}
if ret.Path == nil {
or = ret
} else {
newOr := &models.SceneFilterType{}
or.Or = newOr
or = newOr
}
or.Path = &models.StringCriterionInput{
Modifier: models.CriterionModifierEquals,
Value: p + "%",
}
}
organized := false
ret.Organized = &organized
// batch process scenes
batchSize := 1000
page := 1
findFilter := &models.FindFilterType{
PerPage: &batchSize,
Page: &page,
}
more := true
processed := 0
for more {
scenes, total, err := r.Scene().Query(ret, findFilter)
if err != nil {
return err
}
if processed == 0 {
logger.Infof("Starting autotag of %d scenes", total)
}
for _, ss := range scenes {
if s.Status.stopping {
logger.Info("Stopping due to user request")
return nil
}
t := autoTagSceneTask{
txnManager: s.TxnManager,
scene: ss,
performers: performers,
studios: studios,
tags: tags,
}
var wg sync.WaitGroup
wg.Add(1)
go t.Start(&wg)
wg.Wait()
processed++
s.Status.setProgress(processed, total)
}
if len(scenes) != batchSize {
more = false
} else {
page++
}
}
return nil
}); err != nil {
logger.Error(err.Error())
}
logger.Info("Finished autotag")
}
func (s *singleton) autoTagSpecific(input models.AutoTagMetadataInput) {
performerIds := input.Performers
studioIds := input.Studios
tagIds := input.Tags
performerCount := len(performerIds)
studioCount := len(studioIds)
tagCount := len(tagIds)
if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
performerQuery := r.Performer()
studioQuery := r.Studio()
tagQuery := r.Tag()
const wildcard = "*"
var err error
if performerCount == 1 && performerIds[0] == wildcard {
performerCount, err = performerQuery.Count()
if err != nil {
return fmt.Errorf("error getting performer count: %s", err.Error())
}
}
if studioCount == 1 && studioIds[0] == wildcard {
studioCount, err = studioQuery.Count()
if err != nil {
return fmt.Errorf("error getting studio count: %s", err.Error())
}
}
if tagCount == 1 && tagIds[0] == wildcard {
tagCount, err = tagQuery.Count()
if err != nil {
return fmt.Errorf("error getting tag count: %s", err.Error())
}
}
return nil
}); err != nil {
logger.Error(err.Error())
return
}
total := performerCount + studioCount + tagCount
s.Status.setProgress(0, total)
logger.Infof("Starting autotag of %d performers, %d studios, %d tags", performerCount, studioCount, tagCount)
s.autoTagPerformers(input.Paths, performerIds)
s.autoTagStudios(input.Paths, studioIds)
s.autoTagTags(input.Paths, tagIds)
logger.Info("Finished autotag")
}
func (s *singleton) autoTagPerformers(paths []string, performerIds []string) {
var wg sync.WaitGroup
if s.Status.stopping {
return
}
for _, performerId := range performerIds {
var performers []*models.Performer
@@ -698,46 +812,53 @@ func (s *singleton) autoTagPerformers(paths []string, performerIds []string) {
var err error
performers, err = performerQuery.All()
if err != nil {
return fmt.Errorf("Error querying performers: %s", err.Error())
return fmt.Errorf("error querying performers: %s", err.Error())
}
} else {
performerIdInt, err := strconv.Atoi(performerId)
if err != nil {
return fmt.Errorf("Error parsing performer id %s: %s", performerId, err.Error())
return fmt.Errorf("error parsing performer id %s: %s", performerId, err.Error())
}
performer, err := performerQuery.Find(performerIdInt)
if err != nil {
return fmt.Errorf("Error finding performer id %s: %s", performerId, err.Error())
return fmt.Errorf("error finding performer id %s: %s", performerId, err.Error())
}
if performer == nil {
return fmt.Errorf("performer with id %s not found", performerId)
}
performers = append(performers, performer)
}
for _, performer := range performers {
if s.Status.stopping {
logger.Info("Stopping due to user request")
return nil
}
if err := s.TxnManager.WithTxn(context.TODO(), func(r models.Repository) error {
return autotag.PerformerScenes(performer, paths, r.Scene())
}); err != nil {
return fmt.Errorf("error auto-tagging performer '%s': %s", performer.Name.String, err.Error())
}
s.Status.incrementProgress()
}
return nil
}); err != nil {
logger.Error(err.Error())
continue
}
for _, performer := range performers {
wg.Add(1)
task := AutoTagPerformerTask{
AutoTagTask: AutoTagTask{
txnManager: s.TxnManager,
paths: paths,
},
performer: performer,
}
go task.Start(&wg)
wg.Wait()
s.Status.incrementProgress()
}
}
}
func (s *singleton) autoTagStudios(paths []string, studioIds []string) {
var wg sync.WaitGroup
if s.Status.stopping {
return
}
for _, studioId := range studioIds {
var studios []*models.Studio
@@ -747,46 +868,54 @@ func (s *singleton) autoTagStudios(paths []string, studioIds []string) {
var err error
studios, err = studioQuery.All()
if err != nil {
return fmt.Errorf("Error querying studios: %s", err.Error())
return fmt.Errorf("error querying studios: %s", err.Error())
}
} else {
studioIdInt, err := strconv.Atoi(studioId)
if err != nil {
return fmt.Errorf("Error parsing studio id %s: %s", studioId, err.Error())
return fmt.Errorf("error parsing studio id %s: %s", studioId, err.Error())
}
studio, err := studioQuery.Find(studioIdInt)
if err != nil {
return fmt.Errorf("Error finding studio id %s: %s", studioId, err.Error())
return fmt.Errorf("error finding studio id %s: %s", studioId, err.Error())
}
if studio == nil {
return fmt.Errorf("studio with id %s not found", studioId)
}
studios = append(studios, studio)
}
for _, studio := range studios {
if s.Status.stopping {
logger.Info("Stopping due to user request")
return nil
}
if err := s.TxnManager.WithTxn(context.TODO(), func(r models.Repository) error {
return autotag.StudioScenes(studio, paths, r.Scene())
}); err != nil {
return fmt.Errorf("error auto-tagging studio '%s': %s", studio.Name.String, err.Error())
}
s.Status.incrementProgress()
}
return nil
}); err != nil {
logger.Error(err.Error())
continue
}
for _, studio := range studios {
wg.Add(1)
task := AutoTagStudioTask{
AutoTagTask: AutoTagTask{
txnManager: s.TxnManager,
paths: paths,
},
studio: studio,
}
go task.Start(&wg)
wg.Wait()
s.Status.incrementProgress()
}
}
}
func (s *singleton) autoTagTags(paths []string, tagIds []string) {
var wg sync.WaitGroup
if s.Status.stopping {
return
}
for _, tagId := range tagIds {
var tags []*models.Tag
if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
@@ -795,41 +924,41 @@ func (s *singleton) autoTagTags(paths []string, tagIds []string) {
var err error
tags, err = tagQuery.All()
if err != nil {
return fmt.Errorf("Error querying tags: %s", err.Error())
return fmt.Errorf("error querying tags: %s", err.Error())
}
} else {
tagIdInt, err := strconv.Atoi(tagId)
if err != nil {
return fmt.Errorf("Error parsing tag id %s: %s", tagId, err.Error())
return fmt.Errorf("error parsing tag id %s: %s", tagId, err.Error())
}
tag, err := tagQuery.Find(tagIdInt)
if err != nil {
return fmt.Errorf("Error finding tag id %s: %s", tagId, err.Error())
return fmt.Errorf("error finding tag id %s: %s", tagId, err.Error())
}
tags = append(tags, tag)
}
for _, tag := range tags {
if s.Status.stopping {
logger.Info("Stopping due to user request")
return nil
}
if err := s.TxnManager.WithTxn(context.TODO(), func(r models.Repository) error {
return autotag.TagScenes(tag, paths, r.Scene())
}); err != nil {
return fmt.Errorf("error auto-tagging tag '%s': %s", tag.Name, err.Error())
}
s.Status.incrementProgress()
}
return nil
}); err != nil {
logger.Error(err.Error())
continue
}
for _, tag := range tags {
wg.Add(1)
task := AutoTagTagTask{
AutoTagTask: AutoTagTask{
txnManager: s.TxnManager,
paths: paths,
},
tag: tag,
}
go task.Start(&wg)
wg.Wait()
s.Status.incrementProgress()
}
}
}