mirror of
https://github.com/stashapp/stash.git
synced 2025-12-18 04:44:37 +03:00
Auto tag rewrite (#1324)
This commit is contained in:
@@ -5,12 +5,15 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/remeh/sizedwaitgroup"
|
||||
|
||||
"github.com/stashapp/stash/pkg/autotag"
|
||||
"github.com/stashapp/stash/pkg/logger"
|
||||
"github.com/stashapp/stash/pkg/manager/config"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
@@ -626,6 +629,15 @@ func (s *singleton) generateScreenshot(sceneId string, at *float64) {
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *singleton) isFileBasedAutoTag(input models.AutoTagMetadataInput) bool {
|
||||
const wildcard = "*"
|
||||
performerIds := input.Performers
|
||||
studioIds := input.Studios
|
||||
tagIds := input.Tags
|
||||
|
||||
return (len(performerIds) == 0 || performerIds[0] == wildcard) && (len(studioIds) == 0 || studioIds[0] == wildcard) && (len(tagIds) == 0 || tagIds[0] == wildcard)
|
||||
}
|
||||
|
||||
func (s *singleton) AutoTag(input models.AutoTagMetadataInput) {
|
||||
if s.Status.Status != Idle {
|
||||
return
|
||||
@@ -636,58 +648,160 @@ func (s *singleton) AutoTag(input models.AutoTagMetadataInput) {
|
||||
go func() {
|
||||
defer s.returnToIdleState()
|
||||
|
||||
performerIds := input.Performers
|
||||
studioIds := input.Studios
|
||||
tagIds := input.Tags
|
||||
|
||||
// calculate work load
|
||||
performerCount := len(performerIds)
|
||||
studioCount := len(studioIds)
|
||||
tagCount := len(tagIds)
|
||||
|
||||
if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
|
||||
performerQuery := r.Performer()
|
||||
studioQuery := r.Studio()
|
||||
tagQuery := r.Tag()
|
||||
|
||||
const wildcard = "*"
|
||||
var err error
|
||||
if performerCount == 1 && performerIds[0] == wildcard {
|
||||
performerCount, err = performerQuery.Count()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error getting performer count: %s", err.Error())
|
||||
}
|
||||
}
|
||||
if studioCount == 1 && studioIds[0] == wildcard {
|
||||
studioCount, err = studioQuery.Count()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error getting studio count: %s", err.Error())
|
||||
}
|
||||
}
|
||||
if tagCount == 1 && tagIds[0] == wildcard {
|
||||
tagCount, err = tagQuery.Count()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error getting tag count: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
logger.Error(err.Error())
|
||||
return
|
||||
if s.isFileBasedAutoTag(input) {
|
||||
// doing file-based auto-tag
|
||||
s.autoTagScenes(input.Paths, len(input.Performers) > 0, len(input.Studios) > 0, len(input.Tags) > 0)
|
||||
} else {
|
||||
// doing specific performer/studio/tag auto-tag
|
||||
s.autoTagSpecific(input)
|
||||
}
|
||||
|
||||
total := performerCount + studioCount + tagCount
|
||||
s.Status.setProgress(0, total)
|
||||
|
||||
s.autoTagPerformers(input.Paths, performerIds)
|
||||
s.autoTagStudios(input.Paths, studioIds)
|
||||
s.autoTagTags(input.Paths, tagIds)
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *singleton) autoTagScenes(paths []string, performers, studios, tags bool) {
|
||||
if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
|
||||
ret := &models.SceneFilterType{}
|
||||
or := ret
|
||||
sep := string(filepath.Separator)
|
||||
|
||||
for _, p := range paths {
|
||||
if !strings.HasSuffix(p, sep) {
|
||||
p = p + sep
|
||||
}
|
||||
|
||||
if ret.Path == nil {
|
||||
or = ret
|
||||
} else {
|
||||
newOr := &models.SceneFilterType{}
|
||||
or.Or = newOr
|
||||
or = newOr
|
||||
}
|
||||
|
||||
or.Path = &models.StringCriterionInput{
|
||||
Modifier: models.CriterionModifierEquals,
|
||||
Value: p + "%",
|
||||
}
|
||||
}
|
||||
|
||||
organized := false
|
||||
ret.Organized = &organized
|
||||
|
||||
// batch process scenes
|
||||
batchSize := 1000
|
||||
page := 1
|
||||
findFilter := &models.FindFilterType{
|
||||
PerPage: &batchSize,
|
||||
Page: &page,
|
||||
}
|
||||
|
||||
more := true
|
||||
processed := 0
|
||||
for more {
|
||||
scenes, total, err := r.Scene().Query(ret, findFilter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if processed == 0 {
|
||||
logger.Infof("Starting autotag of %d scenes", total)
|
||||
}
|
||||
|
||||
for _, ss := range scenes {
|
||||
if s.Status.stopping {
|
||||
logger.Info("Stopping due to user request")
|
||||
return nil
|
||||
}
|
||||
|
||||
t := autoTagSceneTask{
|
||||
txnManager: s.TxnManager,
|
||||
scene: ss,
|
||||
performers: performers,
|
||||
studios: studios,
|
||||
tags: tags,
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go t.Start(&wg)
|
||||
wg.Wait()
|
||||
|
||||
processed++
|
||||
s.Status.setProgress(processed, total)
|
||||
}
|
||||
|
||||
if len(scenes) != batchSize {
|
||||
more = false
|
||||
} else {
|
||||
page++
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
logger.Error(err.Error())
|
||||
}
|
||||
|
||||
logger.Info("Finished autotag")
|
||||
}
|
||||
|
||||
func (s *singleton) autoTagSpecific(input models.AutoTagMetadataInput) {
|
||||
performerIds := input.Performers
|
||||
studioIds := input.Studios
|
||||
tagIds := input.Tags
|
||||
|
||||
performerCount := len(performerIds)
|
||||
studioCount := len(studioIds)
|
||||
tagCount := len(tagIds)
|
||||
|
||||
if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
|
||||
performerQuery := r.Performer()
|
||||
studioQuery := r.Studio()
|
||||
tagQuery := r.Tag()
|
||||
|
||||
const wildcard = "*"
|
||||
var err error
|
||||
if performerCount == 1 && performerIds[0] == wildcard {
|
||||
performerCount, err = performerQuery.Count()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting performer count: %s", err.Error())
|
||||
}
|
||||
}
|
||||
if studioCount == 1 && studioIds[0] == wildcard {
|
||||
studioCount, err = studioQuery.Count()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting studio count: %s", err.Error())
|
||||
}
|
||||
}
|
||||
if tagCount == 1 && tagIds[0] == wildcard {
|
||||
tagCount, err = tagQuery.Count()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting tag count: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
logger.Error(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
total := performerCount + studioCount + tagCount
|
||||
s.Status.setProgress(0, total)
|
||||
|
||||
logger.Infof("Starting autotag of %d performers, %d studios, %d tags", performerCount, studioCount, tagCount)
|
||||
|
||||
s.autoTagPerformers(input.Paths, performerIds)
|
||||
s.autoTagStudios(input.Paths, studioIds)
|
||||
s.autoTagTags(input.Paths, tagIds)
|
||||
|
||||
logger.Info("Finished autotag")
|
||||
}
|
||||
|
||||
func (s *singleton) autoTagPerformers(paths []string, performerIds []string) {
|
||||
var wg sync.WaitGroup
|
||||
if s.Status.stopping {
|
||||
return
|
||||
}
|
||||
|
||||
for _, performerId := range performerIds {
|
||||
var performers []*models.Performer
|
||||
|
||||
@@ -698,46 +812,53 @@ func (s *singleton) autoTagPerformers(paths []string, performerIds []string) {
|
||||
var err error
|
||||
performers, err = performerQuery.All()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error querying performers: %s", err.Error())
|
||||
return fmt.Errorf("error querying performers: %s", err.Error())
|
||||
}
|
||||
} else {
|
||||
performerIdInt, err := strconv.Atoi(performerId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error parsing performer id %s: %s", performerId, err.Error())
|
||||
return fmt.Errorf("error parsing performer id %s: %s", performerId, err.Error())
|
||||
}
|
||||
|
||||
performer, err := performerQuery.Find(performerIdInt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error finding performer id %s: %s", performerId, err.Error())
|
||||
return fmt.Errorf("error finding performer id %s: %s", performerId, err.Error())
|
||||
}
|
||||
|
||||
if performer == nil {
|
||||
return fmt.Errorf("performer with id %s not found", performerId)
|
||||
}
|
||||
performers = append(performers, performer)
|
||||
}
|
||||
|
||||
for _, performer := range performers {
|
||||
if s.Status.stopping {
|
||||
logger.Info("Stopping due to user request")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.TxnManager.WithTxn(context.TODO(), func(r models.Repository) error {
|
||||
return autotag.PerformerScenes(performer, paths, r.Scene())
|
||||
}); err != nil {
|
||||
return fmt.Errorf("error auto-tagging performer '%s': %s", performer.Name.String, err.Error())
|
||||
}
|
||||
|
||||
s.Status.incrementProgress()
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
logger.Error(err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
for _, performer := range performers {
|
||||
wg.Add(1)
|
||||
task := AutoTagPerformerTask{
|
||||
AutoTagTask: AutoTagTask{
|
||||
txnManager: s.TxnManager,
|
||||
paths: paths,
|
||||
},
|
||||
performer: performer,
|
||||
}
|
||||
go task.Start(&wg)
|
||||
wg.Wait()
|
||||
|
||||
s.Status.incrementProgress()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *singleton) autoTagStudios(paths []string, studioIds []string) {
|
||||
var wg sync.WaitGroup
|
||||
if s.Status.stopping {
|
||||
return
|
||||
}
|
||||
|
||||
for _, studioId := range studioIds {
|
||||
var studios []*models.Studio
|
||||
|
||||
@@ -747,46 +868,54 @@ func (s *singleton) autoTagStudios(paths []string, studioIds []string) {
|
||||
var err error
|
||||
studios, err = studioQuery.All()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error querying studios: %s", err.Error())
|
||||
return fmt.Errorf("error querying studios: %s", err.Error())
|
||||
}
|
||||
} else {
|
||||
studioIdInt, err := strconv.Atoi(studioId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error parsing studio id %s: %s", studioId, err.Error())
|
||||
return fmt.Errorf("error parsing studio id %s: %s", studioId, err.Error())
|
||||
}
|
||||
|
||||
studio, err := studioQuery.Find(studioIdInt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error finding studio id %s: %s", studioId, err.Error())
|
||||
return fmt.Errorf("error finding studio id %s: %s", studioId, err.Error())
|
||||
}
|
||||
|
||||
if studio == nil {
|
||||
return fmt.Errorf("studio with id %s not found", studioId)
|
||||
}
|
||||
|
||||
studios = append(studios, studio)
|
||||
}
|
||||
|
||||
for _, studio := range studios {
|
||||
if s.Status.stopping {
|
||||
logger.Info("Stopping due to user request")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.TxnManager.WithTxn(context.TODO(), func(r models.Repository) error {
|
||||
return autotag.StudioScenes(studio, paths, r.Scene())
|
||||
}); err != nil {
|
||||
return fmt.Errorf("error auto-tagging studio '%s': %s", studio.Name.String, err.Error())
|
||||
}
|
||||
|
||||
s.Status.incrementProgress()
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
logger.Error(err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
for _, studio := range studios {
|
||||
wg.Add(1)
|
||||
task := AutoTagStudioTask{
|
||||
AutoTagTask: AutoTagTask{
|
||||
txnManager: s.TxnManager,
|
||||
paths: paths,
|
||||
},
|
||||
studio: studio,
|
||||
}
|
||||
go task.Start(&wg)
|
||||
wg.Wait()
|
||||
|
||||
s.Status.incrementProgress()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *singleton) autoTagTags(paths []string, tagIds []string) {
|
||||
var wg sync.WaitGroup
|
||||
if s.Status.stopping {
|
||||
return
|
||||
}
|
||||
|
||||
for _, tagId := range tagIds {
|
||||
var tags []*models.Tag
|
||||
if err := s.TxnManager.WithReadTxn(context.TODO(), func(r models.ReaderRepository) error {
|
||||
@@ -795,41 +924,41 @@ func (s *singleton) autoTagTags(paths []string, tagIds []string) {
|
||||
var err error
|
||||
tags, err = tagQuery.All()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error querying tags: %s", err.Error())
|
||||
return fmt.Errorf("error querying tags: %s", err.Error())
|
||||
}
|
||||
} else {
|
||||
tagIdInt, err := strconv.Atoi(tagId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error parsing tag id %s: %s", tagId, err.Error())
|
||||
return fmt.Errorf("error parsing tag id %s: %s", tagId, err.Error())
|
||||
}
|
||||
|
||||
tag, err := tagQuery.Find(tagIdInt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error finding tag id %s: %s", tagId, err.Error())
|
||||
return fmt.Errorf("error finding tag id %s: %s", tagId, err.Error())
|
||||
}
|
||||
tags = append(tags, tag)
|
||||
}
|
||||
|
||||
for _, tag := range tags {
|
||||
if s.Status.stopping {
|
||||
logger.Info("Stopping due to user request")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.TxnManager.WithTxn(context.TODO(), func(r models.Repository) error {
|
||||
return autotag.TagScenes(tag, paths, r.Scene())
|
||||
}); err != nil {
|
||||
return fmt.Errorf("error auto-tagging tag '%s': %s", tag.Name, err.Error())
|
||||
}
|
||||
|
||||
s.Status.incrementProgress()
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
logger.Error(err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
for _, tag := range tags {
|
||||
wg.Add(1)
|
||||
task := AutoTagTagTask{
|
||||
AutoTagTask: AutoTagTask{
|
||||
txnManager: s.TxnManager,
|
||||
paths: paths,
|
||||
},
|
||||
tag: tag,
|
||||
}
|
||||
go task.Start(&wg)
|
||||
wg.Wait()
|
||||
|
||||
s.Status.incrementProgress()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user