Plugin hooks (#1452)

* Refactor session and plugin code
* Add context to job tasks
* Show hooks in plugins page
* Refactor session management
This commit is contained in:
WithoutPants
2021-06-11 17:24:58 +10:00
committed by GitHub
parent dde361f9f3
commit 46bbede9a0
48 changed files with 1289 additions and 338 deletions

View File

@@ -10,6 +10,12 @@ query Plugins {
name name
description description
} }
hooks {
name
description
hooks
}
} }
} }

View File

@@ -7,6 +7,7 @@ type Plugin {
version: String version: String
tasks: [PluginTask!] tasks: [PluginTask!]
hooks: [PluginHook!]
} }
type PluginTask { type PluginTask {
@@ -15,6 +16,13 @@ type PluginTask {
plugin: Plugin! plugin: Plugin!
} }
type PluginHook {
name: String!
description: String
hooks: [String!]
plugin: Plugin!
}
type PluginResult { type PluginResult {
error: String error: String
result: String result: String

View File

@@ -65,6 +65,15 @@ func (t changesetTranslator) hasField(field string) bool {
return found return found
} }
func (t changesetTranslator) getFields() []string {
var ret []string
for k := range t.inputMap {
ret = append(ret, k)
}
return ret
}
func (t changesetTranslator) nullString(value *string, field string) *sql.NullString { func (t changesetTranslator) nullString(value *string, field string) *sql.NullString {
if !t.hasField(field) { if !t.hasField(field) {
return nil return nil

View File

@@ -5,13 +5,12 @@ package api
type key int type key int
const ( const (
galleryKey key = 0 galleryKey key = iota
performerKey key = 1 performerKey
sceneKey key = 2 sceneKey
studioKey key = 3 studioKey
movieKey key = 4 movieKey
ContextUser key = 5 tagKey
tagKey key = 6 downloadKey
downloadKey key = 7 imageKey
imageKey key = 8
) )

View File

@@ -7,10 +7,16 @@ import (
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin"
) )
type hookExecutor interface {
ExecutePostHooks(ctx context.Context, id int, hookType plugin.HookTriggerEnum, input interface{}, inputFields []string)
}
type Resolver struct { type Resolver struct {
txnManager models.TransactionManager txnManager models.TransactionManager
hookExecutor hookExecutor
} }
func (r *Resolver) Gallery() models.GalleryResolver { func (r *Resolver) Gallery() models.GalleryResolver {

View File

@@ -10,9 +10,21 @@ import (
"github.com/stashapp/stash/pkg/manager" "github.com/stashapp/stash/pkg/manager"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
func (r *mutationResolver) getGallery(ctx context.Context, id int) (ret *models.Gallery, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Gallery().Find(id)
return err
}); err != nil {
return nil, err
}
return ret, nil
}
func (r *mutationResolver) GalleryCreate(ctx context.Context, input models.GalleryCreateInput) (*models.Gallery, error) { func (r *mutationResolver) GalleryCreate(ctx context.Context, input models.GalleryCreateInput) (*models.Gallery, error) {
// name must be provided // name must be provided
if input.Title == "" { if input.Title == "" {
@@ -90,7 +102,8 @@ func (r *mutationResolver) GalleryCreate(ctx context.Context, input models.Galle
return nil, err return nil, err
} }
return gallery, nil r.hookExecutor.ExecutePostHooks(ctx, gallery.ID, plugin.GalleryCreatePost, input, nil)
return r.getGallery(ctx, gallery.ID)
} }
func (r *mutationResolver) updateGalleryPerformers(qb models.GalleryReaderWriter, galleryID int, performerIDs []string) error { func (r *mutationResolver) updateGalleryPerformers(qb models.GalleryReaderWriter, galleryID int, performerIDs []string) error {
@@ -130,7 +143,9 @@ func (r *mutationResolver) GalleryUpdate(ctx context.Context, input models.Galle
return nil, err return nil, err
} }
return ret, nil // execute post hooks outside txn
r.hookExecutor.ExecutePostHooks(ctx, ret.ID, plugin.GalleryUpdatePost, input, translator.getFields())
return r.getGallery(ctx, ret.ID)
} }
func (r *mutationResolver) GalleriesUpdate(ctx context.Context, input []*models.GalleryUpdateInput) (ret []*models.Gallery, err error) { func (r *mutationResolver) GalleriesUpdate(ctx context.Context, input []*models.GalleryUpdateInput) (ret []*models.Gallery, err error) {
@@ -156,7 +171,23 @@ func (r *mutationResolver) GalleriesUpdate(ctx context.Context, input []*models.
return nil, err return nil, err
} }
return ret, nil // execute post hooks outside txn
var newRet []*models.Gallery
for i, gallery := range ret {
translator := changesetTranslator{
inputMap: inputMaps[i],
}
r.hookExecutor.ExecutePostHooks(ctx, gallery.ID, plugin.GalleryUpdatePost, input, translator.getFields())
gallery, err = r.getGallery(ctx, gallery.ID)
if err != nil {
return nil, err
}
newRet = append(newRet, gallery)
}
return newRet, nil
} }
func (r *mutationResolver) galleryUpdate(input models.GalleryUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Gallery, error) { func (r *mutationResolver) galleryUpdate(input models.GalleryUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Gallery, error) {
@@ -314,7 +345,20 @@ func (r *mutationResolver) BulkGalleryUpdate(ctx context.Context, input models.B
return nil, err return nil, err
} }
return ret, nil // execute post hooks outside of txn
var newRet []*models.Gallery
for _, gallery := range ret {
r.hookExecutor.ExecutePostHooks(ctx, gallery.ID, plugin.GalleryUpdatePost, input, translator.getFields())
gallery, err := r.getGallery(ctx, gallery.ID)
if err != nil {
return nil, err
}
newRet = append(newRet, gallery)
}
return newRet, nil
} }
func adjustGalleryPerformerIDs(qb models.GalleryReader, galleryID int, ids models.BulkUpdateIds) (ret []int, err error) { func adjustGalleryPerformerIDs(qb models.GalleryReader, galleryID int, ids models.BulkUpdateIds) (ret []int, err error) {
@@ -438,6 +482,16 @@ func (r *mutationResolver) GalleryDestroy(ctx context.Context, input models.Gall
} }
} }
// call post hook after performing the other actions
for _, gallery := range galleries {
r.hookExecutor.ExecutePostHooks(ctx, gallery.ID, plugin.GalleryDestroyPost, input, nil)
}
// call image destroy post hook as well
for _, img := range imgsToDelete {
r.hookExecutor.ExecutePostHooks(ctx, img.ID, plugin.ImageDestroyPost, nil, nil)
}
return true, nil return true, nil
} }

View File

@@ -8,9 +8,21 @@ import (
"github.com/stashapp/stash/pkg/manager" "github.com/stashapp/stash/pkg/manager"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
func (r *mutationResolver) getImage(ctx context.Context, id int) (ret *models.Image, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Image().Find(id)
return err
}); err != nil {
return nil, err
}
return ret, nil
}
func (r *mutationResolver) ImageUpdate(ctx context.Context, input models.ImageUpdateInput) (ret *models.Image, err error) { func (r *mutationResolver) ImageUpdate(ctx context.Context, input models.ImageUpdateInput) (ret *models.Image, err error) {
translator := changesetTranslator{ translator := changesetTranslator{
inputMap: getUpdateInputMap(ctx), inputMap: getUpdateInputMap(ctx),
@@ -24,7 +36,9 @@ func (r *mutationResolver) ImageUpdate(ctx context.Context, input models.ImageUp
return nil, err return nil, err
} }
return ret, nil // execute post hooks outside txn
r.hookExecutor.ExecutePostHooks(ctx, ret.ID, plugin.ImageUpdatePost, input, translator.getFields())
return r.getImage(ctx, ret.ID)
} }
func (r *mutationResolver) ImagesUpdate(ctx context.Context, input []*models.ImageUpdateInput) (ret []*models.Image, err error) { func (r *mutationResolver) ImagesUpdate(ctx context.Context, input []*models.ImageUpdateInput) (ret []*models.Image, err error) {
@@ -50,7 +64,23 @@ func (r *mutationResolver) ImagesUpdate(ctx context.Context, input []*models.Ima
return nil, err return nil, err
} }
return ret, nil // execute post hooks outside txn
var newRet []*models.Image
for i, image := range ret {
translator := changesetTranslator{
inputMap: inputMaps[i],
}
r.hookExecutor.ExecutePostHooks(ctx, image.ID, plugin.ImageUpdatePost, input, translator.getFields())
image, err = r.getImage(ctx, image.ID)
if err != nil {
return nil, err
}
newRet = append(newRet, image)
}
return newRet, nil
} }
func (r *mutationResolver) imageUpdate(input models.ImageUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Image, error) { func (r *mutationResolver) imageUpdate(input models.ImageUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Image, error) {
@@ -202,7 +232,20 @@ func (r *mutationResolver) BulkImageUpdate(ctx context.Context, input models.Bul
return nil, err return nil, err
} }
return ret, nil // execute post hooks outside of txn
var newRet []*models.Image
for _, image := range ret {
r.hookExecutor.ExecutePostHooks(ctx, image.ID, plugin.ImageUpdatePost, input, translator.getFields())
image, err = r.getImage(ctx, image.ID)
if err != nil {
return nil, err
}
newRet = append(newRet, image)
}
return newRet, nil
} }
func adjustImageGalleryIDs(qb models.ImageReader, imageID int, ids models.BulkUpdateIds) (ret []int, err error) { func adjustImageGalleryIDs(qb models.ImageReader, imageID int, ids models.BulkUpdateIds) (ret []int, err error) {
@@ -268,6 +311,9 @@ func (r *mutationResolver) ImageDestroy(ctx context.Context, input models.ImageD
manager.DeleteImageFile(image) manager.DeleteImageFile(image)
} }
// call post hook after performing the other actions
r.hookExecutor.ExecutePostHooks(ctx, image.ID, plugin.ImageDestroyPost, input, nil)
return true, nil return true, nil
} }
@@ -315,6 +361,9 @@ func (r *mutationResolver) ImagesDestroy(ctx context.Context, input models.Image
if input.DeleteFile != nil && *input.DeleteFile { if input.DeleteFile != nil && *input.DeleteFile {
manager.DeleteImageFile(image) manager.DeleteImageFile(image)
} }
// call post hook after performing the other actions
r.hookExecutor.ExecutePostHooks(ctx, image.ID, plugin.ImageDestroyPost, input, nil)
} }
return true, nil return true, nil

View File

@@ -17,7 +17,7 @@ import (
) )
func (r *mutationResolver) MetadataScan(ctx context.Context, input models.ScanMetadataInput) (string, error) { func (r *mutationResolver) MetadataScan(ctx context.Context, input models.ScanMetadataInput) (string, error) {
jobID, err := manager.GetInstance().Scan(input) jobID, err := manager.GetInstance().Scan(ctx, input)
if err != nil { if err != nil {
return "", err return "", err
@@ -27,7 +27,7 @@ func (r *mutationResolver) MetadataScan(ctx context.Context, input models.ScanMe
} }
func (r *mutationResolver) MetadataImport(ctx context.Context) (string, error) { func (r *mutationResolver) MetadataImport(ctx context.Context) (string, error) {
jobID, err := manager.GetInstance().Import() jobID, err := manager.GetInstance().Import(ctx)
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -41,13 +41,13 @@ func (r *mutationResolver) ImportObjects(ctx context.Context, input models.Impor
return "", err return "", err
} }
jobID := manager.GetInstance().RunSingleTask(t) jobID := manager.GetInstance().RunSingleTask(ctx, t)
return strconv.Itoa(jobID), nil return strconv.Itoa(jobID), nil
} }
func (r *mutationResolver) MetadataExport(ctx context.Context) (string, error) { func (r *mutationResolver) MetadataExport(ctx context.Context) (string, error) {
jobID, err := manager.GetInstance().Export() jobID, err := manager.GetInstance().Export(ctx)
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -75,7 +75,7 @@ func (r *mutationResolver) ExportObjects(ctx context.Context, input models.Expor
} }
func (r *mutationResolver) MetadataGenerate(ctx context.Context, input models.GenerateMetadataInput) (string, error) { func (r *mutationResolver) MetadataGenerate(ctx context.Context, input models.GenerateMetadataInput) (string, error) {
jobID, err := manager.GetInstance().Generate(input) jobID, err := manager.GetInstance().Generate(ctx, input)
if err != nil { if err != nil {
return "", err return "", err
@@ -85,17 +85,17 @@ func (r *mutationResolver) MetadataGenerate(ctx context.Context, input models.Ge
} }
func (r *mutationResolver) MetadataAutoTag(ctx context.Context, input models.AutoTagMetadataInput) (string, error) { func (r *mutationResolver) MetadataAutoTag(ctx context.Context, input models.AutoTagMetadataInput) (string, error) {
jobID := manager.GetInstance().AutoTag(input) jobID := manager.GetInstance().AutoTag(ctx, input)
return strconv.Itoa(jobID), nil return strconv.Itoa(jobID), nil
} }
func (r *mutationResolver) MetadataClean(ctx context.Context, input models.CleanMetadataInput) (string, error) { func (r *mutationResolver) MetadataClean(ctx context.Context, input models.CleanMetadataInput) (string, error) {
jobID := manager.GetInstance().Clean(input) jobID := manager.GetInstance().Clean(ctx, input)
return strconv.Itoa(jobID), nil return strconv.Itoa(jobID), nil
} }
func (r *mutationResolver) MigrateHashNaming(ctx context.Context) (string, error) { func (r *mutationResolver) MigrateHashNaming(ctx context.Context) (string, error) {
jobID := manager.GetInstance().MigrateHash() jobID := manager.GetInstance().MigrateHash(ctx)
return strconv.Itoa(jobID), nil return strconv.Itoa(jobID), nil
} }

View File

@@ -7,9 +7,21 @@ import (
"time" "time"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
func (r *mutationResolver) getMovie(ctx context.Context, id int) (ret *models.Movie, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Movie().Find(id)
return err
}); err != nil {
return nil, err
}
return ret, nil
}
func (r *mutationResolver) MovieCreate(ctx context.Context, input models.MovieCreateInput) (*models.Movie, error) { func (r *mutationResolver) MovieCreate(ctx context.Context, input models.MovieCreateInput) (*models.Movie, error) {
// generate checksum from movie name rather than image // generate checksum from movie name rather than image
checksum := utils.MD5FromString(input.Name) checksum := utils.MD5FromString(input.Name)
@@ -104,7 +116,8 @@ func (r *mutationResolver) MovieCreate(ctx context.Context, input models.MovieCr
return nil, err return nil, err
} }
return movie, nil r.hookExecutor.ExecutePostHooks(ctx, movie.ID, plugin.MovieCreatePost, input, nil)
return r.getMovie(ctx, movie.ID)
} }
func (r *mutationResolver) MovieUpdate(ctx context.Context, input models.MovieUpdateInput) (*models.Movie, error) { func (r *mutationResolver) MovieUpdate(ctx context.Context, input models.MovieUpdateInput) (*models.Movie, error) {
@@ -203,7 +216,8 @@ func (r *mutationResolver) MovieUpdate(ctx context.Context, input models.MovieUp
return nil, err return nil, err
} }
return movie, nil r.hookExecutor.ExecutePostHooks(ctx, movie.ID, plugin.MovieUpdatePost, input, translator.getFields())
return r.getMovie(ctx, movie.ID)
} }
func (r *mutationResolver) MovieDestroy(ctx context.Context, input models.MovieDestroyInput) (bool, error) { func (r *mutationResolver) MovieDestroy(ctx context.Context, input models.MovieDestroyInput) (bool, error) {
@@ -217,6 +231,9 @@ func (r *mutationResolver) MovieDestroy(ctx context.Context, input models.MovieD
}); err != nil { }); err != nil {
return false, err return false, err
} }
r.hookExecutor.ExecutePostHooks(ctx, id, plugin.MovieDestroyPost, input, nil)
return true, nil return true, nil
} }
@@ -238,5 +255,10 @@ func (r *mutationResolver) MoviesDestroy(ctx context.Context, movieIDs []string)
}); err != nil { }); err != nil {
return false, err return false, err
} }
for _, id := range ids {
r.hookExecutor.ExecutePostHooks(ctx, id, plugin.MovieDestroyPost, movieIDs, nil)
}
return true, nil return true, nil
} }

View File

@@ -9,9 +9,21 @@ import (
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/performer" "github.com/stashapp/stash/pkg/performer"
"github.com/stashapp/stash/pkg/plugin"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
func (r *mutationResolver) getPerformer(ctx context.Context, id int) (ret *models.Performer, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Performer().Find(id)
return err
}); err != nil {
return nil, err
}
return ret, nil
}
func (r *mutationResolver) PerformerCreate(ctx context.Context, input models.PerformerCreateInput) (*models.Performer, error) { func (r *mutationResolver) PerformerCreate(ctx context.Context, input models.PerformerCreateInput) (*models.Performer, error) {
// generate checksum from performer name rather than image // generate checksum from performer name rather than image
checksum := utils.MD5FromString(input.Name) checksum := utils.MD5FromString(input.Name)
@@ -146,7 +158,8 @@ func (r *mutationResolver) PerformerCreate(ctx context.Context, input models.Per
return nil, err return nil, err
} }
return performer, nil r.hookExecutor.ExecutePostHooks(ctx, performer.ID, plugin.PerformerCreatePost, input, nil)
return r.getPerformer(ctx, performer.ID)
} }
func (r *mutationResolver) PerformerUpdate(ctx context.Context, input models.PerformerUpdateInput) (*models.Performer, error) { func (r *mutationResolver) PerformerUpdate(ctx context.Context, input models.PerformerUpdateInput) (*models.Performer, error) {
@@ -267,7 +280,8 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input models.Per
return nil, err return nil, err
} }
return p, nil r.hookExecutor.ExecutePostHooks(ctx, p.ID, plugin.PerformerUpdatePost, input, translator.getFields())
return r.getPerformer(ctx, p.ID)
} }
func (r *mutationResolver) updatePerformerTags(qb models.PerformerReaderWriter, performerID int, tagsIDs []string) error { func (r *mutationResolver) updatePerformerTags(qb models.PerformerReaderWriter, performerID int, tagsIDs []string) error {
@@ -372,7 +386,20 @@ func (r *mutationResolver) BulkPerformerUpdate(ctx context.Context, input models
return nil, err return nil, err
} }
return ret, nil // execute post hooks outside of txn
var newRet []*models.Performer
for _, performer := range ret {
r.hookExecutor.ExecutePostHooks(ctx, performer.ID, plugin.ImageUpdatePost, input, translator.getFields())
performer, err = r.getPerformer(ctx, performer.ID)
if err != nil {
return nil, err
}
newRet = append(newRet, performer)
}
return newRet, nil
} }
func (r *mutationResolver) PerformerDestroy(ctx context.Context, input models.PerformerDestroyInput) (bool, error) { func (r *mutationResolver) PerformerDestroy(ctx context.Context, input models.PerformerDestroyInput) (bool, error) {
@@ -386,6 +413,9 @@ func (r *mutationResolver) PerformerDestroy(ctx context.Context, input models.Pe
}); err != nil { }); err != nil {
return false, err return false, err
} }
r.hookExecutor.ExecutePostHooks(ctx, id, plugin.PerformerDestroyPost, input, nil)
return true, nil return true, nil
} }
@@ -407,5 +437,10 @@ func (r *mutationResolver) PerformersDestroy(ctx context.Context, performerIDs [
}); err != nil { }); err != nil {
return false, err return false, err
} }
for _, id := range ids {
r.hookExecutor.ExecutePostHooks(ctx, id, plugin.PerformerDestroyPost, performerIDs, nil)
}
return true, nil return true, nil
} }

View File

@@ -2,40 +2,15 @@ package api
import ( import (
"context" "context"
"net/http"
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/manager" "github.com/stashapp/stash/pkg/manager"
"github.com/stashapp/stash/pkg/manager/config"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin/common"
) )
func (r *mutationResolver) RunPluginTask(ctx context.Context, pluginID string, taskName string, args []*models.PluginArgInput) (string, error) { func (r *mutationResolver) RunPluginTask(ctx context.Context, pluginID string, taskName string, args []*models.PluginArgInput) (string, error) {
currentUser := getCurrentUserID(ctx) m := manager.GetInstance()
m.RunPluginTask(ctx, pluginID, taskName, args)
var cookie *http.Cookie
var err error
if currentUser != nil {
cookie, err = createSessionCookie(*currentUser)
if err != nil {
return "", err
}
}
config := config.GetInstance()
serverConnection := common.StashServerConnection{
Scheme: "http",
Port: config.GetPort(),
SessionCookie: cookie,
Dir: config.GetConfigPath(),
}
if HasTLSConfig() {
serverConnection.Scheme = "https"
}
manager.GetInstance().RunPluginTask(pluginID, taskName, args, serverConnection)
return "todo", nil return "todo", nil
} }

View File

@@ -10,9 +10,21 @@ import (
"github.com/stashapp/stash/pkg/manager" "github.com/stashapp/stash/pkg/manager"
"github.com/stashapp/stash/pkg/manager/config" "github.com/stashapp/stash/pkg/manager/config"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
func (r *mutationResolver) getScene(ctx context.Context, id int) (ret *models.Scene, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Scene().Find(id)
return err
}); err != nil {
return nil, err
}
return ret, nil
}
func (r *mutationResolver) SceneUpdate(ctx context.Context, input models.SceneUpdateInput) (ret *models.Scene, err error) { func (r *mutationResolver) SceneUpdate(ctx context.Context, input models.SceneUpdateInput) (ret *models.Scene, err error) {
translator := changesetTranslator{ translator := changesetTranslator{
inputMap: getUpdateInputMap(ctx), inputMap: getUpdateInputMap(ctx),
@@ -26,7 +38,8 @@ func (r *mutationResolver) SceneUpdate(ctx context.Context, input models.SceneUp
return nil, err return nil, err
} }
return ret, nil r.hookExecutor.ExecutePostHooks(ctx, ret.ID, plugin.SceneUpdatePost, input, translator.getFields())
return r.getScene(ctx, ret.ID)
} }
func (r *mutationResolver) ScenesUpdate(ctx context.Context, input []*models.SceneUpdateInput) (ret []*models.Scene, err error) { func (r *mutationResolver) ScenesUpdate(ctx context.Context, input []*models.SceneUpdateInput) (ret []*models.Scene, err error) {
@@ -52,7 +65,24 @@ func (r *mutationResolver) ScenesUpdate(ctx context.Context, input []*models.Sce
return nil, err return nil, err
} }
return ret, nil // execute post hooks outside of txn
var newRet []*models.Scene
for i, scene := range ret {
translator := changesetTranslator{
inputMap: inputMaps[i],
}
r.hookExecutor.ExecutePostHooks(ctx, scene.ID, plugin.SceneUpdatePost, input, translator.getFields())
scene, err = r.getScene(ctx, scene.ID)
if err != nil {
return nil, err
}
newRet = append(newRet, scene)
}
return newRet, nil
} }
func (r *mutationResolver) sceneUpdate(input models.SceneUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Scene, error) { func (r *mutationResolver) sceneUpdate(input models.SceneUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Scene, error) {
@@ -281,7 +311,20 @@ func (r *mutationResolver) BulkSceneUpdate(ctx context.Context, input models.Bul
return nil, err return nil, err
} }
return ret, nil // execute post hooks outside of txn
var newRet []*models.Scene
for _, scene := range ret {
r.hookExecutor.ExecutePostHooks(ctx, scene.ID, plugin.SceneUpdatePost, input, translator.getFields())
scene, err = r.getScene(ctx, scene.ID)
if err != nil {
return nil, err
}
newRet = append(newRet, scene)
}
return newRet, nil
} }
func adjustIDs(existingIDs []int, updateIDs models.BulkUpdateIds) []int { func adjustIDs(existingIDs []int, updateIDs models.BulkUpdateIds) []int {
@@ -393,6 +436,9 @@ func (r *mutationResolver) SceneDestroy(ctx context.Context, input models.SceneD
manager.DeleteSceneFile(scene) manager.DeleteSceneFile(scene)
} }
// call post hook after performing the other actions
r.hookExecutor.ExecutePostHooks(ctx, scene.ID, plugin.SceneDestroyPost, input, nil)
return true, nil return true, nil
} }
@@ -442,11 +488,25 @@ func (r *mutationResolver) ScenesDestroy(ctx context.Context, input models.Scene
if input.DeleteFile != nil && *input.DeleteFile { if input.DeleteFile != nil && *input.DeleteFile {
manager.DeleteSceneFile(scene) manager.DeleteSceneFile(scene)
} }
// call post hook after performing the other actions
r.hookExecutor.ExecutePostHooks(ctx, scene.ID, plugin.SceneDestroyPost, input, nil)
} }
return true, nil return true, nil
} }
func (r *mutationResolver) getSceneMarker(ctx context.Context, id int) (ret *models.SceneMarker, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.SceneMarker().Find(id)
return err
}); err != nil {
return nil, err
}
return ret, nil
}
func (r *mutationResolver) SceneMarkerCreate(ctx context.Context, input models.SceneMarkerCreateInput) (*models.SceneMarker, error) { func (r *mutationResolver) SceneMarkerCreate(ctx context.Context, input models.SceneMarkerCreateInput) (*models.SceneMarker, error) {
primaryTagID, err := strconv.Atoi(input.PrimaryTagID) primaryTagID, err := strconv.Atoi(input.PrimaryTagID)
if err != nil { if err != nil {
@@ -473,7 +533,13 @@ func (r *mutationResolver) SceneMarkerCreate(ctx context.Context, input models.S
return nil, err return nil, err
} }
return r.changeMarker(ctx, create, newSceneMarker, tagIDs) ret, err := r.changeMarker(ctx, create, newSceneMarker, tagIDs)
if err != nil {
return nil, err
}
r.hookExecutor.ExecutePostHooks(ctx, ret.ID, plugin.SceneMarkerCreatePost, input, nil)
return r.getSceneMarker(ctx, ret.ID)
} }
func (r *mutationResolver) SceneMarkerUpdate(ctx context.Context, input models.SceneMarkerUpdateInput) (*models.SceneMarker, error) { func (r *mutationResolver) SceneMarkerUpdate(ctx context.Context, input models.SceneMarkerUpdateInput) (*models.SceneMarker, error) {
@@ -507,7 +573,16 @@ func (r *mutationResolver) SceneMarkerUpdate(ctx context.Context, input models.S
return nil, err return nil, err
} }
return r.changeMarker(ctx, update, updatedSceneMarker, tagIDs) ret, err := r.changeMarker(ctx, update, updatedSceneMarker, tagIDs)
if err != nil {
return nil, err
}
translator := changesetTranslator{
inputMap: getUpdateInputMap(ctx),
}
r.hookExecutor.ExecutePostHooks(ctx, ret.ID, plugin.SceneMarkerUpdatePost, input, translator.getFields())
return r.getSceneMarker(ctx, ret.ID)
} }
func (r *mutationResolver) SceneMarkerDestroy(ctx context.Context, id string) (bool, error) { func (r *mutationResolver) SceneMarkerDestroy(ctx context.Context, id string) (bool, error) {
@@ -544,6 +619,8 @@ func (r *mutationResolver) SceneMarkerDestroy(ctx context.Context, id string) (b
postCommitFunc() postCommitFunc()
r.hookExecutor.ExecutePostHooks(ctx, markerID, plugin.SceneMarkerDestroyPost, id, nil)
return true, nil return true, nil
} }
@@ -651,9 +728,9 @@ func (r *mutationResolver) SceneResetO(ctx context.Context, id string) (ret int,
func (r *mutationResolver) SceneGenerateScreenshot(ctx context.Context, id string, at *float64) (string, error) { func (r *mutationResolver) SceneGenerateScreenshot(ctx context.Context, id string, at *float64) (string, error) {
if at != nil { if at != nil {
manager.GetInstance().GenerateScreenshot(id, *at) manager.GetInstance().GenerateScreenshot(ctx, id, *at)
} else { } else {
manager.GetInstance().GenerateDefaultScreenshot(id) manager.GetInstance().GenerateDefaultScreenshot(ctx, id)
} }
return "todo", nil return "todo", nil

View File

@@ -24,6 +24,6 @@ func (r *mutationResolver) SubmitStashBoxFingerprints(ctx context.Context, input
} }
func (r *mutationResolver) StashBoxBatchPerformerTag(ctx context.Context, input models.StashBoxBatchPerformerTagInput) (string, error) { func (r *mutationResolver) StashBoxBatchPerformerTag(ctx context.Context, input models.StashBoxBatchPerformerTagInput) (string, error) {
jobID := manager.GetInstance().StashBoxBatchPerformerTag(input) jobID := manager.GetInstance().StashBoxBatchPerformerTag(ctx, input)
return strconv.Itoa(jobID), nil return strconv.Itoa(jobID), nil
} }

View File

@@ -8,9 +8,21 @@ import (
"github.com/stashapp/stash/pkg/manager" "github.com/stashapp/stash/pkg/manager"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
func (r *mutationResolver) getStudio(ctx context.Context, id int) (ret *models.Studio, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Studio().Find(id)
return err
}); err != nil {
return nil, err
}
return ret, nil
}
func (r *mutationResolver) StudioCreate(ctx context.Context, input models.StudioCreateInput) (*models.Studio, error) { func (r *mutationResolver) StudioCreate(ctx context.Context, input models.StudioCreateInput) (*models.Studio, error) {
// generate checksum from studio name rather than image // generate checksum from studio name rather than image
checksum := utils.MD5FromString(input.Name) checksum := utils.MD5FromString(input.Name)
@@ -82,7 +94,8 @@ func (r *mutationResolver) StudioCreate(ctx context.Context, input models.Studio
return nil, err return nil, err
} }
return studio, nil r.hookExecutor.ExecutePostHooks(ctx, studio.ID, plugin.StudioCreatePost, input, nil)
return r.getStudio(ctx, studio.ID)
} }
func (r *mutationResolver) StudioUpdate(ctx context.Context, input models.StudioUpdateInput) (*models.Studio, error) { func (r *mutationResolver) StudioUpdate(ctx context.Context, input models.StudioUpdateInput) (*models.Studio, error) {
@@ -162,7 +175,8 @@ func (r *mutationResolver) StudioUpdate(ctx context.Context, input models.Studio
return nil, err return nil, err
} }
return studio, nil r.hookExecutor.ExecutePostHooks(ctx, studio.ID, plugin.StudioUpdatePost, input, translator.getFields())
return r.getStudio(ctx, studio.ID)
} }
func (r *mutationResolver) StudioDestroy(ctx context.Context, input models.StudioDestroyInput) (bool, error) { func (r *mutationResolver) StudioDestroy(ctx context.Context, input models.StudioDestroyInput) (bool, error) {
@@ -176,6 +190,9 @@ func (r *mutationResolver) StudioDestroy(ctx context.Context, input models.Studi
}); err != nil { }); err != nil {
return false, err return false, err
} }
r.hookExecutor.ExecutePostHooks(ctx, id, plugin.StudioDestroyPost, input, nil)
return true, nil return true, nil
} }
@@ -197,5 +214,10 @@ func (r *mutationResolver) StudiosDestroy(ctx context.Context, studioIDs []strin
}); err != nil { }); err != nil {
return false, err return false, err
} }
for _, id := range ids {
r.hookExecutor.ExecutePostHooks(ctx, id, plugin.StudioDestroyPost, studioIDs, nil)
}
return true, nil return true, nil
} }

View File

@@ -7,10 +7,22 @@ import (
"time" "time"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin"
"github.com/stashapp/stash/pkg/tag" "github.com/stashapp/stash/pkg/tag"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
func (r *mutationResolver) getTag(ctx context.Context, id int) (ret *models.Tag, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Tag().Find(id)
return err
}); err != nil {
return nil, err
}
return ret, nil
}
func (r *mutationResolver) TagCreate(ctx context.Context, input models.TagCreateInput) (*models.Tag, error) { func (r *mutationResolver) TagCreate(ctx context.Context, input models.TagCreateInput) (*models.Tag, error) {
// Populate a new tag from the input // Populate a new tag from the input
currentTime := time.Now() currentTime := time.Now()
@@ -68,7 +80,8 @@ func (r *mutationResolver) TagCreate(ctx context.Context, input models.TagCreate
return nil, err return nil, err
} }
return t, nil r.hookExecutor.ExecutePostHooks(ctx, t.ID, plugin.TagCreatePost, input, nil)
return r.getTag(ctx, t.ID)
} }
func (r *mutationResolver) TagUpdate(ctx context.Context, input models.TagUpdateInput) (*models.Tag, error) { func (r *mutationResolver) TagUpdate(ctx context.Context, input models.TagUpdateInput) (*models.Tag, error) {
@@ -153,7 +166,8 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input models.TagUpdate
return nil, err return nil, err
} }
return t, nil r.hookExecutor.ExecutePostHooks(ctx, t.ID, plugin.TagUpdatePost, input, translator.getFields())
return r.getTag(ctx, t.ID)
} }
func (r *mutationResolver) TagDestroy(ctx context.Context, input models.TagDestroyInput) (bool, error) { func (r *mutationResolver) TagDestroy(ctx context.Context, input models.TagDestroyInput) (bool, error) {
@@ -167,6 +181,9 @@ func (r *mutationResolver) TagDestroy(ctx context.Context, input models.TagDestr
}); err != nil { }); err != nil {
return false, err return false, err
} }
r.hookExecutor.ExecutePostHooks(ctx, tagID, plugin.TagDestroyPost, input, nil)
return true, nil return true, nil
} }
@@ -188,5 +205,10 @@ func (r *mutationResolver) TagsDestroy(ctx context.Context, tagIDs []string) (bo
}); err != nil { }); err != nil {
return false, err return false, err
} }
for _, id := range ids {
r.hookExecutor.ExecutePostHooks(ctx, id, plugin.TagDestroyPost, tagIDs, nil)
}
return true, nil return true, nil
} }

View File

@@ -7,6 +7,7 @@ import (
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/mocks" "github.com/stashapp/stash/pkg/models/mocks"
"github.com/stashapp/stash/pkg/plugin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
@@ -15,7 +16,8 @@ import (
// TODO - move this into a common area // TODO - move this into a common area
func newResolver() *Resolver { func newResolver() *Resolver {
return &Resolver{ return &Resolver{
txnManager: mocks.NewTransactionManager(), txnManager: mocks.NewTransactionManager(),
hookExecutor: &mockHookExecutor{},
} }
} }
@@ -26,6 +28,11 @@ const existingTagID = 1
const existingTagName = "existingTagName" const existingTagName = "existingTagName"
const newTagID = 2 const newTagID = 2
type mockHookExecutor struct{}
func (*mockHookExecutor) ExecutePostHooks(ctx context.Context, id int, hookType plugin.HookTriggerEnum, input interface{}, inputFields []string) {
}
func TestTagCreate(t *testing.T) { func TestTagCreate(t *testing.T) {
r := newResolver() r := newResolver()
@@ -84,10 +91,12 @@ func TestTagCreate(t *testing.T) {
tagRW.On("Query", tagFilterForName(tagName), findFilter).Return(nil, 0, nil).Once() tagRW.On("Query", tagFilterForName(tagName), findFilter).Return(nil, 0, nil).Once()
tagRW.On("Query", tagFilterForAlias(tagName), findFilter).Return(nil, 0, nil).Once() tagRW.On("Query", tagFilterForAlias(tagName), findFilter).Return(nil, 0, nil).Once()
tagRW.On("Create", mock.AnythingOfType("models.Tag")).Return(&models.Tag{ newTag := &models.Tag{
ID: newTagID, ID: newTagID,
Name: tagName, Name: tagName,
}, nil) }
tagRW.On("Create", mock.AnythingOfType("models.Tag")).Return(newTag, nil)
tagRW.On("Find", newTagID).Return(newTag, nil)
tag, err := r.Mutation().TagCreate(context.TODO(), models.TagCreateInput{ tag, err := r.Mutation().TagCreate(context.TODO(), models.TagCreateInput{
Name: tagName, Name: tagName,

View File

@@ -29,6 +29,7 @@ import (
"github.com/stashapp/stash/pkg/manager/config" "github.com/stashapp/stash/pkg/manager/config"
"github.com/stashapp/stash/pkg/manager/paths" "github.com/stashapp/stash/pkg/manager/paths"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/session"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
@@ -41,11 +42,6 @@ var uiBox *packr.Box
//var legacyUiBox *packr.Box //var legacyUiBox *packr.Box
var loginUIBox *packr.Box var loginUIBox *packr.Box
const (
ApiKeyHeader = "ApiKey"
ApiKeyParameter = "apikey"
)
func allowUnauthenticated(r *http.Request) bool { func allowUnauthenticated(r *http.Request) bool {
return strings.HasPrefix(r.URL.Path, "/login") || r.URL.Path == "/css" return strings.HasPrefix(r.URL.Path, "/login") || r.URL.Path == "/css"
} }
@@ -53,44 +49,26 @@ func allowUnauthenticated(r *http.Request) bool {
func authenticateHandler() func(http.Handler) http.Handler { func authenticateHandler() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c := config.GetInstance() userID, err := manager.GetInstance().SessionStore.Authenticate(w, r)
ctx := r.Context() if err != nil {
if err != session.ErrUnauthorized {
// translate api key into current user, if present w.WriteHeader(http.StatusInternalServerError)
userID := "" _, err = w.Write([]byte(err.Error()))
apiKey := r.Header.Get(ApiKeyHeader) if err != nil {
var err error logger.Error(err)
}
// try getting the api key as a query parameter
if apiKey == "" {
apiKey = r.URL.Query().Get(ApiKeyParameter)
}
if apiKey != "" {
// match against configured API and set userID to the
// configured username. In future, we'll want to
// get the username from the key.
if c.GetAPIKey() != apiKey {
w.Header().Add("WWW-Authenticate", `FormBased`)
w.WriteHeader(http.StatusUnauthorized)
return return
} }
userID = c.GetUsername() // unauthorized error
} else { w.Header().Add("WWW-Authenticate", `FormBased`)
// handle session w.WriteHeader(http.StatusUnauthorized)
userID, err = getSessionUserID(w, r)
}
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
_, err = w.Write([]byte(err.Error()))
if err != nil {
logger.Error(err)
}
return return
} }
c := config.GetInstance()
ctx := r.Context()
// handle redirect if no user and user is required // handle redirect if no user and user is required
if userID == "" && c.HasCredentials() && !allowUnauthenticated(r) { if userID == "" && c.HasCredentials() && !allowUnauthenticated(r) {
// if we don't have a userID, then redirect // if we don't have a userID, then redirect
@@ -112,7 +90,7 @@ func authenticateHandler() func(http.Handler) http.Handler {
return return
} }
ctx = context.WithValue(ctx, ContextUser, userID) ctx = session.SetCurrentUserID(ctx, userID)
r = r.WithContext(ctx) r = r.WithContext(ctx)
@@ -121,6 +99,16 @@ func authenticateHandler() func(http.Handler) http.Handler {
} }
} }
func visitedPluginHandler() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// get the visited plugins and set them in the context
next.ServeHTTP(w, r)
})
}
}
const loginEndPoint = "/login" const loginEndPoint = "/login"
func Start() { func Start() {
@@ -128,13 +116,15 @@ func Start() {
//legacyUiBox = packr.New("UI Box", "../../ui/v1/dist/stash-frontend") //legacyUiBox = packr.New("UI Box", "../../ui/v1/dist/stash-frontend")
loginUIBox = packr.New("Login UI Box", "../../ui/login") loginUIBox = packr.New("Login UI Box", "../../ui/login")
initSessionStore()
initialiseImages() initialiseImages()
r := chi.NewRouter() r := chi.NewRouter()
r.Use(middleware.Heartbeat("/healthz")) r.Use(middleware.Heartbeat("/healthz"))
r.Use(authenticateHandler()) r.Use(authenticateHandler())
visitedPluginHandler := manager.GetInstance().SessionStore.VisitedPluginHandler()
r.Use(visitedPluginHandler)
r.Use(middleware.Recoverer) r.Use(middleware.Recoverer)
c := config.GetInstance() c := config.GetInstance()
@@ -155,8 +145,10 @@ func Start() {
} }
txnManager := manager.GetInstance().TxnManager txnManager := manager.GetInstance().TxnManager
pluginCache := manager.GetInstance().PluginCache
resolver := &Resolver{ resolver := &Resolver{
txnManager: txnManager, txnManager: txnManager,
hookExecutor: pluginCache,
} }
gqlSrv := gqlHandler.New(models.NewExecutableSchema(models.Config{Resolvers: resolver})) gqlSrv := gqlHandler.New(models.NewExecutableSchema(models.Config{Resolvers: resolver}))
@@ -184,7 +176,8 @@ func Start() {
} }
// register GQL handler with plugin cache // register GQL handler with plugin cache
manager.GetInstance().PluginCache.RegisterGQLHandler(gqlHandlerFunc) // chain the visited plugin handler
manager.GetInstance().PluginCache.RegisterGQLHandler(visitedPluginHandler(http.HandlerFunc(gqlHandlerFunc)))
r.HandleFunc("/graphql", gqlHandlerFunc) r.HandleFunc("/graphql", gqlHandlerFunc)
r.HandleFunc("/playground", gqlPlayground.Handler("GraphQL playground", "/graphql")) r.HandleFunc("/playground", gqlPlayground.Handler("GraphQL playground", "/graphql"))
@@ -358,15 +351,6 @@ func makeTLSConfig() *tls.Config {
return tlsConfig return tlsConfig
} }
func HasTLSConfig() bool {
ret, _ := utils.FileExists(paths.GetSSLCert())
if ret {
ret, _ = utils.FileExists(paths.GetSSLKey())
}
return ret
}
type contextKey struct { type contextKey struct {
name string name string
} }

View File

@@ -1,15 +1,13 @@
package api package api
import ( import (
"context"
"fmt" "fmt"
"html/template" "html/template"
"net/http" "net/http"
"github.com/stashapp/stash/pkg/manager"
"github.com/stashapp/stash/pkg/manager/config" "github.com/stashapp/stash/pkg/manager/config"
"github.com/stashapp/stash/pkg/session"
"github.com/gorilla/securecookie"
"github.com/gorilla/sessions"
) )
const cookieName = "session" const cookieName = "session"
@@ -19,17 +17,11 @@ const userIDKey = "userID"
const returnURLParam = "returnURL" const returnURLParam = "returnURL"
var sessionStore = sessions.NewCookieStore(config.GetInstance().GetSessionStoreKey())
type loginTemplateData struct { type loginTemplateData struct {
URL string URL string
Error string Error string
} }
func initSessionStore() {
sessionStore.MaxAge(config.GetInstance().GetMaxSessionAge())
}
func redirectToLogin(w http.ResponseWriter, returnURL string, loginError string) { func redirectToLogin(w http.ResponseWriter, returnURL string, loginError string) {
data, _ := loginUIBox.Find("login.html") data, _ := loginUIBox.Find("login.html")
templ, err := template.New("Login").Parse(string(data)) templ, err := template.New("Login").Parse(string(data))
@@ -59,22 +51,13 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
url = "/" url = "/"
} }
// ignore error - we want a new session regardless err := manager.GetInstance().SessionStore.Login(w, r)
newSession, _ := sessionStore.Get(r, cookieName) if err == session.ErrInvalidCredentials {
username := r.FormValue("username")
password := r.FormValue("password")
// authenticate the user
if !config.GetInstance().ValidateCredentials(username, password) {
// redirect back to the login page with an error // redirect back to the login page with an error
redirectToLogin(w, url, "Username or password is invalid") redirectToLogin(w, url, "Username or password is invalid")
return return
} }
newSession.Values[userIDKey] = username
err := newSession.Save(r, w)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
@@ -84,17 +67,7 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
} }
func handleLogout(w http.ResponseWriter, r *http.Request) { func handleLogout(w http.ResponseWriter, r *http.Request) {
session, err := sessionStore.Get(r, cookieName) if err := manager.GetInstance().SessionStore.Logout(w, r); err != nil {
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
delete(session.Values, userIDKey)
session.Options.MaxAge = -1
err = session.Save(r, w)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
} }
@@ -102,51 +75,3 @@ func handleLogout(w http.ResponseWriter, r *http.Request) {
// redirect to the login page if credentials are required // redirect to the login page if credentials are required
getLoginHandler(w, r) getLoginHandler(w, r)
} }
func getSessionUserID(w http.ResponseWriter, r *http.Request) (string, error) {
session, err := sessionStore.Get(r, cookieName)
// ignore errors and treat as an empty user id, so that we handle expired
// cookie
if err != nil {
return "", nil
}
if !session.IsNew {
val := session.Values[userIDKey]
// refresh the cookie
err = session.Save(r, w)
if err != nil {
return "", err
}
ret, _ := val.(string)
return ret, nil
}
return "", nil
}
func getCurrentUserID(ctx context.Context) *string {
userCtxVal := ctx.Value(ContextUser)
if userCtxVal != nil {
currentUser := userCtxVal.(string)
return &currentUser
}
return nil
}
func createSessionCookie(username string) (*http.Cookie, error) {
session := sessions.NewSession(sessionStore, cookieName)
session.Values[userIDKey] = username
encoded, err := securecookie.EncodeMulti(session.Name(), session.Values,
sessionStore.Codecs...)
if err != nil {
return nil, err
}
return sessions.NewCookie(session.Name(), encoded, session.Options), nil
}

View File

@@ -55,6 +55,7 @@ type Job struct {
EndTime *time.Time EndTime *time.Time
AddTime time.Time AddTime time.Time
outerCtx context.Context
exec JobExec exec JobExec
cancelFunc context.CancelFunc cancelFunc context.CancelFunc
} }

View File

@@ -4,6 +4,8 @@ import (
"context" "context"
"sync" "sync"
"time" "time"
"github.com/stashapp/stash/pkg/utils"
) )
const maxGraveyardSize = 10 const maxGraveyardSize = 10
@@ -46,7 +48,7 @@ func (m *Manager) Stop() {
} }
// Add queues a job. // Add queues a job.
func (m *Manager) Add(description string, e JobExec) int { func (m *Manager) Add(ctx context.Context, description string, e JobExec) int {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -58,6 +60,7 @@ func (m *Manager) Add(description string, e JobExec) int {
Description: description, Description: description,
AddTime: t, AddTime: t,
exec: e, exec: e,
outerCtx: ctx,
} }
m.queue = append(m.queue, &j) m.queue = append(m.queue, &j)
@@ -74,7 +77,7 @@ func (m *Manager) Add(description string, e JobExec) int {
// Start adds a job and starts it immediately, concurrently with any other // Start adds a job and starts it immediately, concurrently with any other
// jobs. // jobs.
func (m *Manager) Start(description string, e JobExec) int { func (m *Manager) Start(ctx context.Context, description string, e JobExec) int {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -86,6 +89,7 @@ func (m *Manager) Start(description string, e JobExec) int {
Description: description, Description: description,
AddTime: t, AddTime: t,
exec: e, exec: e,
outerCtx: ctx,
} }
m.queue = append(m.queue, &j) m.queue = append(m.queue, &j)
@@ -173,7 +177,7 @@ func (m *Manager) dispatch(j *Job) (done chan struct{}) {
j.StartTime = &t j.StartTime = &t
j.Status = StatusRunning j.Status = StatusRunning
ctx, cancelFunc := context.WithCancel(context.Background()) ctx, cancelFunc := context.WithCancel(utils.ValueOnlyContext(j.outerCtx))
j.cancelFunc = cancelFunc j.cancelFunc = cancelFunc
done = make(chan struct{}) done = make(chan struct{})

View File

@@ -45,7 +45,7 @@ func TestAdd(t *testing.T) {
const jobName = "test job" const jobName = "test job"
exec1 := newTestExec(make(chan struct{})) exec1 := newTestExec(make(chan struct{}))
jobID := m.Add(jobName, exec1) jobID := m.Add(context.Background(), jobName, exec1)
// expect jobID to be the first ID // expect jobID to be the first ID
assert := assert.New(t) assert := assert.New(t)
@@ -80,7 +80,7 @@ func TestAdd(t *testing.T) {
// add another job to the queue // add another job to the queue
const otherJobName = "other job name" const otherJobName = "other job name"
exec2 := newTestExec(make(chan struct{})) exec2 := newTestExec(make(chan struct{}))
job2ID := m.Add(otherJobName, exec2) job2ID := m.Add(context.Background(), otherJobName, exec2)
// expect status to be ready // expect status to be ready
j2 := m.GetJob(job2ID) j2 := m.GetJob(job2ID)
@@ -130,11 +130,11 @@ func TestCancel(t *testing.T) {
// add two jobs // add two jobs
const jobName = "test job" const jobName = "test job"
exec1 := newTestExec(make(chan struct{})) exec1 := newTestExec(make(chan struct{}))
jobID := m.Add(jobName, exec1) jobID := m.Add(context.Background(), jobName, exec1)
const otherJobName = "other job" const otherJobName = "other job"
exec2 := newTestExec(make(chan struct{})) exec2 := newTestExec(make(chan struct{}))
job2ID := m.Add(otherJobName, exec2) job2ID := m.Add(context.Background(), otherJobName, exec2)
// wait a tiny bit // wait a tiny bit
time.Sleep(sleepTime) time.Sleep(sleepTime)
@@ -198,11 +198,11 @@ func TestCancelAll(t *testing.T) {
// add two jobs // add two jobs
const jobName = "test job" const jobName = "test job"
exec1 := newTestExec(make(chan struct{})) exec1 := newTestExec(make(chan struct{}))
jobID := m.Add(jobName, exec1) jobID := m.Add(context.Background(), jobName, exec1)
const otherJobName = "other job" const otherJobName = "other job"
exec2 := newTestExec(make(chan struct{})) exec2 := newTestExec(make(chan struct{}))
job2ID := m.Add(otherJobName, exec2) job2ID := m.Add(context.Background(), otherJobName, exec2)
// wait a tiny bit // wait a tiny bit
time.Sleep(sleepTime) time.Sleep(sleepTime)
@@ -246,7 +246,7 @@ func TestSubscribe(t *testing.T) {
// add a job // add a job
const jobName = "test job" const jobName = "test job"
exec1 := newTestExec(make(chan struct{})) exec1 := newTestExec(make(chan struct{}))
jobID := m.Add(jobName, exec1) jobID := m.Add(context.Background(), jobName, exec1)
assert := assert.New(t) assert := assert.New(t)
@@ -326,7 +326,7 @@ func TestSubscribe(t *testing.T) {
// add another job and cancel it // add another job and cancel it
exec2 := newTestExec(make(chan struct{})) exec2 := newTestExec(make(chan struct{}))
jobID = m.Add(jobName, exec2) jobID = m.Add(context.Background(), jobName, exec2)
m.CancelJob(jobID) m.CancelJob(jobID)

View File

@@ -13,6 +13,7 @@ import (
"github.com/spf13/viper" "github.com/spf13/viper"
"github.com/stashapp/stash/pkg/manager/paths"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
@@ -150,6 +151,15 @@ func (e MissingConfigError) Error() string {
return fmt.Sprintf("missing the following mandatory settings: %s", strings.Join(e.missingFields, ", ")) return fmt.Sprintf("missing the following mandatory settings: %s", strings.Join(e.missingFields, ", "))
} }
func HasTLSConfig() bool {
ret, _ := utils.FileExists(paths.GetSSLCert())
if ret {
ret, _ = utils.FileExists(paths.GetSSLKey())
}
return ret
}
type Instance struct { type Instance struct {
cpuProfilePath string cpuProfilePath string
isNewSystem bool isNewSystem bool

View File

@@ -19,6 +19,7 @@ import (
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin" "github.com/stashapp/stash/pkg/plugin"
"github.com/stashapp/stash/pkg/scraper" "github.com/stashapp/stash/pkg/scraper"
"github.com/stashapp/stash/pkg/session"
"github.com/stashapp/stash/pkg/sqlite" "github.com/stashapp/stash/pkg/sqlite"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
@@ -31,6 +32,8 @@ type singleton struct {
FFMPEGPath string FFMPEGPath string
FFProbePath string FFProbePath string
SessionStore *session.Store
JobManager *job.Manager JobManager *job.Manager
PluginCache *plugin.Cache PluginCache *plugin.Cache
@@ -100,6 +103,11 @@ func Initialize() *singleton {
if cfgFile != "" { if cfgFile != "" {
cfgFile = cfgFile + " " cfgFile = cfgFile + " "
} }
// create temporary session store - this will be re-initialised
// after config is complete
instance.SessionStore = session.NewStore(cfg)
logger.Warnf("config file %snot found. Assuming new system...", cfgFile) logger.Warnf("config file %snot found. Assuming new system...", cfgFile)
} }
@@ -179,6 +187,8 @@ func (s *singleton) PostInit() error {
s.Paths = paths.NewPaths(s.Config.GetGeneratedPath()) s.Paths = paths.NewPaths(s.Config.GetGeneratedPath())
s.RefreshConfig() s.RefreshConfig()
s.SessionStore = session.NewStore(s.Config)
s.PluginCache.RegisterSessionStore(s.SessionStore)
if err := s.PluginCache.LoadPlugins(); err != nil { if err := s.PluginCache.LoadPlugins(); err != nil {
logger.Errorf("Error reading plugin configs: %s", err.Error()) logger.Errorf("Error reading plugin configs: %s", err.Error())

View File

@@ -60,7 +60,7 @@ func (s *singleton) ScanSubscribe(ctx context.Context) <-chan bool {
return s.scanSubs.subscribe(ctx) return s.scanSubs.subscribe(ctx)
} }
func (s *singleton) Scan(input models.ScanMetadataInput) (int, error) { func (s *singleton) Scan(ctx context.Context, input models.ScanMetadataInput) (int, error) {
if err := s.validateFFMPEG(); err != nil { if err := s.validateFFMPEG(); err != nil {
return 0, err return 0, err
} }
@@ -71,10 +71,10 @@ func (s *singleton) Scan(input models.ScanMetadataInput) (int, error) {
subscriptions: s.scanSubs, subscriptions: s.scanSubs,
} }
return s.JobManager.Add("Scanning...", &scanJob), nil return s.JobManager.Add(ctx, "Scanning...", &scanJob), nil
} }
func (s *singleton) Import() (int, error) { func (s *singleton) Import(ctx context.Context) (int, error) {
config := config.GetInstance() config := config.GetInstance()
metadataPath := config.GetMetadataPath() metadataPath := config.GetMetadataPath()
if metadataPath == "" { if metadataPath == "" {
@@ -96,10 +96,10 @@ func (s *singleton) Import() (int, error) {
task.Start(&wg) task.Start(&wg)
}) })
return s.JobManager.Add("Importing...", j), nil return s.JobManager.Add(ctx, "Importing...", j), nil
} }
func (s *singleton) Export() (int, error) { func (s *singleton) Export(ctx context.Context) (int, error) {
config := config.GetInstance() config := config.GetInstance()
metadataPath := config.GetMetadataPath() metadataPath := config.GetMetadataPath()
if metadataPath == "" { if metadataPath == "" {
@@ -117,10 +117,10 @@ func (s *singleton) Export() (int, error) {
task.Start(&wg) task.Start(&wg)
}) })
return s.JobManager.Add("Exporting...", j), nil return s.JobManager.Add(ctx, "Exporting...", j), nil
} }
func (s *singleton) RunSingleTask(t Task) int { func (s *singleton) RunSingleTask(ctx context.Context, t Task) int {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
@@ -128,7 +128,7 @@ func (s *singleton) RunSingleTask(t Task) int {
t.Start(&wg) t.Start(&wg)
}) })
return s.JobManager.Add(t.GetDescription(), j) return s.JobManager.Add(ctx, t.GetDescription(), j)
} }
func setGeneratePreviewOptionsInput(optionsInput *models.GeneratePreviewOptionsInput) { func setGeneratePreviewOptionsInput(optionsInput *models.GeneratePreviewOptionsInput) {
@@ -159,7 +159,7 @@ func setGeneratePreviewOptionsInput(optionsInput *models.GeneratePreviewOptionsI
} }
} }
func (s *singleton) Generate(input models.GenerateMetadataInput) (int, error) { func (s *singleton) Generate(ctx context.Context, input models.GenerateMetadataInput) (int, error) {
if err := s.validateFFMPEG(); err != nil { if err := s.validateFFMPEG(); err != nil {
return 0, err return 0, err
} }
@@ -367,19 +367,19 @@ func (s *singleton) Generate(input models.GenerateMetadataInput) (int, error) {
logger.Info(fmt.Sprintf("Generate finished (%s)", elapsed)) logger.Info(fmt.Sprintf("Generate finished (%s)", elapsed))
}) })
return s.JobManager.Add("Generating...", j), nil return s.JobManager.Add(ctx, "Generating...", j), nil
} }
func (s *singleton) GenerateDefaultScreenshot(sceneId string) int { func (s *singleton) GenerateDefaultScreenshot(ctx context.Context, sceneId string) int {
return s.generateScreenshot(sceneId, nil) return s.generateScreenshot(ctx, sceneId, nil)
} }
func (s *singleton) GenerateScreenshot(sceneId string, at float64) int { func (s *singleton) GenerateScreenshot(ctx context.Context, sceneId string, at float64) int {
return s.generateScreenshot(sceneId, &at) return s.generateScreenshot(ctx, sceneId, &at)
} }
// generate default screenshot if at is nil // generate default screenshot if at is nil
func (s *singleton) generateScreenshot(sceneId string, at *float64) int { func (s *singleton) generateScreenshot(ctx context.Context, sceneId string, at *float64) int {
instance.Paths.Generated.EnsureTmpDir() instance.Paths.Generated.EnsureTmpDir()
j := job.MakeJobExec(func(ctx context.Context, progress *job.Progress) { j := job.MakeJobExec(func(ctx context.Context, progress *job.Progress) {
@@ -413,19 +413,19 @@ func (s *singleton) generateScreenshot(sceneId string, at *float64) int {
logger.Infof("Generate screenshot finished") logger.Infof("Generate screenshot finished")
}) })
return s.JobManager.Add(fmt.Sprintf("Generating screenshot for scene id %s", sceneId), j) return s.JobManager.Add(ctx, fmt.Sprintf("Generating screenshot for scene id %s", sceneId), j)
} }
func (s *singleton) AutoTag(input models.AutoTagMetadataInput) int { func (s *singleton) AutoTag(ctx context.Context, input models.AutoTagMetadataInput) int {
j := autoTagJob{ j := autoTagJob{
txnManager: s.TxnManager, txnManager: s.TxnManager,
input: input, input: input,
} }
return s.JobManager.Add("Auto-tagging...", &j) return s.JobManager.Add(ctx, "Auto-tagging...", &j)
} }
func (s *singleton) Clean(input models.CleanMetadataInput) int { func (s *singleton) Clean(ctx context.Context, input models.CleanMetadataInput) int {
j := job.MakeJobExec(func(ctx context.Context, progress *job.Progress) { j := job.MakeJobExec(func(ctx context.Context, progress *job.Progress) {
var scenes []*models.Scene var scenes []*models.Scene
var images []*models.Image var images []*models.Image
@@ -488,6 +488,7 @@ func (s *singleton) Clean(input models.CleanMetadataInput) int {
wg.Add(1) wg.Add(1)
task := CleanTask{ task := CleanTask{
ctx: ctx,
TxnManager: s.TxnManager, TxnManager: s.TxnManager,
Scene: scene, Scene: scene,
fileNamingAlgorithm: fileNamingAlgo, fileNamingAlgorithm: fileNamingAlgo,
@@ -514,6 +515,7 @@ func (s *singleton) Clean(input models.CleanMetadataInput) int {
wg.Add(1) wg.Add(1)
task := CleanTask{ task := CleanTask{
ctx: ctx,
TxnManager: s.TxnManager, TxnManager: s.TxnManager,
Image: img, Image: img,
} }
@@ -538,6 +540,7 @@ func (s *singleton) Clean(input models.CleanMetadataInput) int {
wg.Add(1) wg.Add(1)
task := CleanTask{ task := CleanTask{
ctx: ctx,
TxnManager: s.TxnManager, TxnManager: s.TxnManager,
Gallery: gallery, Gallery: gallery,
} }
@@ -552,10 +555,10 @@ func (s *singleton) Clean(input models.CleanMetadataInput) int {
s.scanSubs.notify() s.scanSubs.notify()
}) })
return s.JobManager.Add("Cleaning...", j) return s.JobManager.Add(ctx, "Cleaning...", j)
} }
func (s *singleton) MigrateHash() int { func (s *singleton) MigrateHash(ctx context.Context) int {
j := job.MakeJobExec(func(ctx context.Context, progress *job.Progress) { j := job.MakeJobExec(func(ctx context.Context, progress *job.Progress) {
fileNamingAlgo := config.GetInstance().GetVideoFileNamingAlgorithm() fileNamingAlgo := config.GetInstance().GetVideoFileNamingAlgorithm()
logger.Infof("Migrating generated files for %s naming hash", fileNamingAlgo.String()) logger.Infof("Migrating generated files for %s naming hash", fileNamingAlgo.String())
@@ -596,7 +599,7 @@ func (s *singleton) MigrateHash() int {
logger.Info("Finished migrating") logger.Info("Finished migrating")
}) })
return s.JobManager.Add("Migrating scene hashes...", j) return s.JobManager.Add(ctx, "Migrating scene hashes...", j)
} }
type totalsGenerate struct { type totalsGenerate struct {
@@ -702,7 +705,7 @@ func (s *singleton) neededGenerate(scenes []*models.Scene, input models.Generate
return &totals return &totals
} }
func (s *singleton) StashBoxBatchPerformerTag(input models.StashBoxBatchPerformerTagInput) int { func (s *singleton) StashBoxBatchPerformerTag(ctx context.Context, input models.StashBoxBatchPerformerTagInput) int {
j := job.MakeJobExec(func(ctx context.Context, progress *job.Progress) { j := job.MakeJobExec(func(ctx context.Context, progress *job.Progress) {
logger.Infof("Initiating stash-box batch performer tag") logger.Infof("Initiating stash-box batch performer tag")
@@ -800,5 +803,5 @@ func (s *singleton) StashBoxBatchPerformerTag(input models.StashBoxBatchPerforme
} }
}) })
return s.JobManager.Add("Batch stash-box performer tag...", j) return s.JobManager.Add(ctx, "Batch stash-box performer tag...", j)
} }

View File

@@ -10,10 +10,12 @@ import (
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/manager/config" "github.com/stashapp/stash/pkg/manager/config"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
type CleanTask struct { type CleanTask struct {
ctx context.Context
TxnManager models.TransactionManager TxnManager models.TransactionManager
Scene *models.Scene Scene *models.Scene
Gallery *models.Gallery Gallery *models.Gallery
@@ -158,6 +160,8 @@ func (t *CleanTask) deleteScene(sceneID int) {
postCommitFunc() postCommitFunc()
DeleteGeneratedSceneFiles(scene, t.fileNamingAlgorithm) DeleteGeneratedSceneFiles(scene, t.fileNamingAlgorithm)
GetInstance().PluginCache.ExecutePostHooks(t.ctx, sceneID, plugin.SceneDestroyPost, nil, nil)
} }
func (t *CleanTask) deleteGallery(galleryID int) { func (t *CleanTask) deleteGallery(galleryID int) {
@@ -168,6 +172,8 @@ func (t *CleanTask) deleteGallery(galleryID int) {
logger.Errorf("Error deleting gallery from database: %s", err.Error()) logger.Errorf("Error deleting gallery from database: %s", err.Error())
return return
} }
GetInstance().PluginCache.ExecutePostHooks(t.ctx, galleryID, plugin.GalleryDestroyPost, nil, nil)
} }
func (t *CleanTask) deleteImage(imageID int) { func (t *CleanTask) deleteImage(imageID int) {
@@ -185,20 +191,8 @@ func (t *CleanTask) deleteImage(imageID int) {
if pathErr != nil { if pathErr != nil {
logger.Errorf("Error deleting thumbnail image from cache: %s", pathErr) logger.Errorf("Error deleting thumbnail image from cache: %s", pathErr)
} }
}
func (t *CleanTask) fileExists(filename string) (bool, error) { GetInstance().PluginCache.ExecutePostHooks(t.ctx, imageID, plugin.ImageDestroyPost, nil, nil)
info, err := os.Stat(filename)
if os.IsNotExist(err) {
return false, nil
}
// handle if error is something else
if err != nil {
return false, err
}
return !info.IsDir(), nil
} }
func getStashFromPath(pathToCheck string) *models.StashConfig { func getStashFromPath(pathToCheck string) *models.StashConfig {

View File

@@ -7,13 +7,12 @@ import (
"github.com/stashapp/stash/pkg/job" "github.com/stashapp/stash/pkg/job"
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin/common"
) )
func (s *singleton) RunPluginTask(pluginID string, taskName string, args []*models.PluginArgInput, serverConnection common.StashServerConnection) int { func (s *singleton) RunPluginTask(ctx context.Context, pluginID string, taskName string, args []*models.PluginArgInput) int {
j := job.MakeJobExec(func(ctx context.Context, progress *job.Progress) { j := job.MakeJobExec(func(jobCtx context.Context, progress *job.Progress) {
pluginProgress := make(chan float64) pluginProgress := make(chan float64)
task, err := s.PluginCache.CreateTask(pluginID, taskName, serverConnection, args, pluginProgress) task, err := s.PluginCache.CreateTask(ctx, pluginID, taskName, args, pluginProgress)
if err != nil { if err != nil {
logger.Errorf("Error creating plugin task: %s", err.Error()) logger.Errorf("Error creating plugin task: %s", err.Error())
return return
@@ -48,7 +47,7 @@ func (s *singleton) RunPluginTask(pluginID string, taskName string, args []*mode
return return
case p := <-pluginProgress: case p := <-pluginProgress:
progress.SetPercent(p) progress.SetPercent(p)
case <-ctx.Done(): case <-jobCtx.Done():
if err := task.Stop(); err != nil { if err := task.Stop(); err != nil {
logger.Errorf("Error stopping plugin operation: %s", err.Error()) logger.Errorf("Error stopping plugin operation: %s", err.Error())
} }
@@ -57,5 +56,5 @@ func (s *singleton) RunPluginTask(pluginID string, taskName string, args []*mode
} }
}) })
return s.JobManager.Add(fmt.Sprintf("Running plugin task: %s", taskName), j) return s.JobManager.Add(ctx, fmt.Sprintf("Running plugin task: %s", taskName), j)
} }

View File

@@ -21,6 +21,7 @@ import (
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/manager/config" "github.com/stashapp/stash/pkg/manager/config"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin"
"github.com/stashapp/stash/pkg/scene" "github.com/stashapp/stash/pkg/scene"
"github.com/stashapp/stash/pkg/utils" "github.com/stashapp/stash/pkg/utils"
) )
@@ -102,6 +103,7 @@ func (j *ScanJob) Execute(ctx context.Context, progress *job.Progress) {
GeneratePhash: utils.IsTrue(input.ScanGeneratePhashes), GeneratePhash: utils.IsTrue(input.ScanGeneratePhashes),
progress: progress, progress: progress,
CaseSensitiveFs: csFs, CaseSensitiveFs: csFs,
ctx: ctx,
} }
go func() { go func() {
@@ -201,6 +203,7 @@ func (j *ScanJob) neededScan(ctx context.Context, paths []*models.StashConfig) (
} }
type ScanTask struct { type ScanTask struct {
ctx context.Context
TxnManager models.TransactionManager TxnManager models.TransactionManager
FilePath string FilePath string
UseFileMetadata bool UseFileMetadata bool
@@ -424,6 +427,8 @@ func (t *ScanTask) scanGallery() {
if err != nil { if err != nil {
return err return err
} }
GetInstance().PluginCache.ExecutePostHooks(t.ctx, g.ID, plugin.GalleryUpdatePost, nil, nil)
} }
} else { } else {
currentTime := time.Now() currentTime := time.Now()
@@ -461,6 +466,8 @@ func (t *ScanTask) scanGallery() {
return err return err
} }
scanImages = true scanImages = true
GetInstance().PluginCache.ExecutePostHooks(t.ctx, g.ID, plugin.GalleryCreatePost, nil, nil)
} }
} }
@@ -787,6 +794,8 @@ func (t *ScanTask) scanScene() *models.Scene {
}); err != nil { }); err != nil {
return logError(err) return logError(err)
} }
GetInstance().PluginCache.ExecutePostHooks(t.ctx, s.ID, plugin.SceneUpdatePost, nil, nil)
} }
} else { } else {
logger.Infof("%s doesn't exist. Creating new item...", t.FilePath) logger.Infof("%s doesn't exist. Creating new item...", t.FilePath)
@@ -826,6 +835,8 @@ func (t *ScanTask) scanScene() *models.Scene {
}); err != nil { }); err != nil {
return logError(err) return logError(err)
} }
GetInstance().PluginCache.ExecutePostHooks(t.ctx, retScene.ID, plugin.SceneCreatePost, nil, nil)
} }
return retScene return retScene
@@ -895,6 +906,8 @@ func (t *ScanTask) rescanScene(s *models.Scene, fileModTime time.Time) (*models.
return nil, err return nil, err
} }
GetInstance().PluginCache.ExecutePostHooks(t.ctx, ret.ID, plugin.SceneUpdatePost, nil, nil)
// leave the generated files as is - the scene file may have been moved // leave the generated files as is - the scene file may have been moved
// elsewhere // elsewhere
@@ -1081,6 +1094,8 @@ func (t *ScanTask) scanImage() {
logger.Error(err.Error()) logger.Error(err.Error())
return return
} }
GetInstance().PluginCache.ExecutePostHooks(t.ctx, i.ID, plugin.ImageUpdatePost, nil, nil)
} }
} else { } else {
logger.Infof("%s doesn't exist. Creating new item...", image.PathDisplayName(t.FilePath)) logger.Infof("%s doesn't exist. Creating new item...", image.PathDisplayName(t.FilePath))
@@ -1111,6 +1126,8 @@ func (t *ScanTask) scanImage() {
logger.Error(err.Error()) logger.Error(err.Error())
return return
} }
GetInstance().PluginCache.ExecutePostHooks(t.ctx, i.ID, plugin.ImageCreatePost, nil, nil)
} }
if t.zipGallery != nil { if t.zipGallery != nil {
@@ -1186,6 +1203,8 @@ func (t *ScanTask) rescanImage(i *models.Image, fileModTime time.Time) (*models.
} }
} }
GetInstance().PluginCache.ExecutePostHooks(t.ctx, ret.ID, plugin.ImageUpdatePost, nil, nil)
return ret, nil return ret, nil
} }

View File

@@ -2,6 +2,10 @@ package common
import "net/http" import "net/http"
const (
HookContextKey = "hookContext"
)
// StashServerConnection represents the connection details needed for a // StashServerConnection represents the connection details needed for a
// plugin instance to connect to its parent stash server. // plugin instance to connect to its parent stash server.
type StashServerConnection struct { type StashServerConnection struct {
@@ -97,3 +101,12 @@ func (o *PluginOutput) SetError(err error) {
errStr := err.Error() errStr := err.Error()
o.Error = &errStr o.Error = &errStr
} }
// HookContext is passed as a PluginArgValue and indicates what hook triggered
// this plugin task.
type HookContext struct {
ID int `json:"id,omitempty"`
Type string `json:"type"`
Input interface{} `json:"input"`
InputFields []string `json:"inputFields,omitempty"`
}

View File

@@ -54,6 +54,9 @@ type Config struct {
// The task configurations for tasks provided by this plugin. // The task configurations for tasks provided by this plugin.
Tasks []*OperationConfig `yaml:"tasks"` Tasks []*OperationConfig `yaml:"tasks"`
// The hooks configurations for hooks registered by this plugin.
Hooks []*HookConfig `yaml:"hooks"`
} }
func (c Config) getPluginTasks(includePlugin bool) []*models.PluginTask { func (c Config) getPluginTasks(includePlugin bool) []*models.PluginTask {
@@ -74,6 +77,34 @@ func (c Config) getPluginTasks(includePlugin bool) []*models.PluginTask {
return ret return ret
} }
func (c Config) getPluginHooks(includePlugin bool) []*models.PluginHook {
var ret []*models.PluginHook
for _, o := range c.Hooks {
hook := &models.PluginHook{
Name: o.Name,
Description: &o.Description,
Hooks: convertHooks(o.TriggeredBy),
}
if includePlugin {
hook.Plugin = c.toPlugin()
}
ret = append(ret, hook)
}
return ret
}
func convertHooks(hooks []HookTriggerEnum) []string {
var ret []string
for _, h := range hooks {
ret = append(ret, h.String())
}
return ret
}
func (c Config) getName() string { func (c Config) getName() string {
if c.Name != "" { if c.Name != "" {
return c.Name return c.Name
@@ -90,6 +121,7 @@ func (c Config) toPlugin() *models.Plugin {
URL: c.URL, URL: c.URL,
Version: c.Version, Version: c.Version,
Tasks: c.getPluginTasks(false), Tasks: c.getPluginTasks(false),
Hooks: c.getPluginHooks(false),
} }
} }
@@ -103,6 +135,19 @@ func (c Config) getTask(name string) *OperationConfig {
return nil return nil
} }
func (c Config) getHooks(hookType HookTriggerEnum) []*HookConfig {
var ret []*HookConfig
for _, h := range c.Hooks {
for _, t := range h.TriggeredBy {
if hookType == t {
ret = append(ret, h)
}
}
}
return ret
}
func (c Config) getConfigPath() string { func (c Config) getConfigPath() string {
return filepath.Dir(c.path) return filepath.Dir(c.path)
} }
@@ -194,6 +239,13 @@ type OperationConfig struct {
DefaultArgs map[string]string `yaml:"defaultArgs"` DefaultArgs map[string]string `yaml:"defaultArgs"`
} }
type HookConfig struct {
OperationConfig `yaml:",inline"`
// A list of stash operations that will be used to trigger this hook operation.
TriggeredBy []HookTriggerEnum `yaml:"triggeredBy"`
}
func loadPluginFromYAML(reader io.Reader) (*Config, error) { func loadPluginFromYAML(reader io.Reader) (*Config, error) {
ret := &Config{} ret := &Config{}

View File

@@ -11,6 +11,8 @@ function main() {
doLongTask(); doLongTask();
} else if (modeArg == "indef") { } else if (modeArg == "indef") {
doIndefiniteTask(); doIndefiniteTask();
} else if (modeArg == "hook") {
doHookTask();
} }
} catch (err) { } catch (err) {
return { return {
@@ -207,4 +209,9 @@ function doIndefiniteTask() {
} }
} }
function doHookTask() {
log.Info("JS Hook called!");
log.Info(input.Args);
}
main(); main();

View File

@@ -24,4 +24,35 @@ tasks:
description: Sleeps for 100 seconds - interruptable description: Sleeps for 100 seconds - interruptable
defaultArgs: defaultArgs:
mode: long mode: long
hooks:
- name: Log scene marker create/update
description: Logs some stuff when creating/updating scene marker.
triggeredBy:
- SceneMarker.Create.Post
- SceneMarker.Update.Post
- SceneMarker.Delete.Post
- Scene.Create.Post
- Scene.Update.Post
- Scene.Destroy.Post
- Image.Create.Post
- Image.Update.Post
- Image.Destroy.Post
- Gallery.Create.Post
- Gallery.Update.Post
- Gallery.Destroy.Post
- Movie.Create.Post
- Movie.Update.Post
- Movie.Destroy.Post
- Performer.Create.Post
- Performer.Update.Post
- Performer.Destroy.Post
- Studio.Create.Post
- Studio.Update.Post
- Studio.Destroy.Post
- Tag.Create.Post
- Tag.Update.Post
- Tag.Destroy.Post
defaultArgs:
mode: hook

View File

@@ -17,7 +17,10 @@ class StashInterface:
self.url = scheme + "://localhost:" + str(self.port) + "/graphql" self.url = scheme + "://localhost:" + str(self.port) + "/graphql"
# TODO - cookies # Session cookie for authentication
self.cookies = {
'session': conn.get('SessionCookie').get('Value')
}
def __callGraphQL(self, query, variables = None): def __callGraphQL(self, query, variables = None):
json = {} json = {}
@@ -26,7 +29,7 @@ class StashInterface:
json['variables'] = variables json['variables'] = variables
# handle cookies # handle cookies
response = requests.post(self.url, json=json, headers=self.headers) response = requests.post(self.url, json=json, headers=self.headers, cookies=self.cookies)
if response.status_code == 200: if response.status_code == 200:
result = response.json() result = response.json()

125
pkg/plugin/hooks.go Normal file
View File

@@ -0,0 +1,125 @@
package plugin
import (
"github.com/stashapp/stash/pkg/plugin/common"
)
type HookTriggerEnum string
// Scan-related hooks are current disabled until post-hook execution is
// integrated.
const (
SceneMarkerCreatePost HookTriggerEnum = "SceneMarker.Create.Post"
SceneMarkerUpdatePost HookTriggerEnum = "SceneMarker.Update.Post"
SceneMarkerDestroyPost HookTriggerEnum = "SceneMarker.Destroy.Post"
SceneCreatePost HookTriggerEnum = "Scene.Create.Post"
SceneUpdatePost HookTriggerEnum = "Scene.Update.Post"
SceneDestroyPost HookTriggerEnum = "Scene.Destroy.Post"
ImageCreatePost HookTriggerEnum = "Image.Create.Post"
ImageUpdatePost HookTriggerEnum = "Image.Update.Post"
ImageDestroyPost HookTriggerEnum = "Image.Destroy.Post"
GalleryCreatePost HookTriggerEnum = "Gallery.Create.Post"
GalleryUpdatePost HookTriggerEnum = "Gallery.Update.Post"
GalleryDestroyPost HookTriggerEnum = "Gallery.Destroy.Post"
MovieCreatePost HookTriggerEnum = "Movie.Create.Post"
MovieUpdatePost HookTriggerEnum = "Movie.Update.Post"
MovieDestroyPost HookTriggerEnum = "Movie.Destroy.Post"
PerformerCreatePost HookTriggerEnum = "Performer.Create.Post"
PerformerUpdatePost HookTriggerEnum = "Performer.Update.Post"
PerformerDestroyPost HookTriggerEnum = "Performer.Destroy.Post"
StudioCreatePost HookTriggerEnum = "Studio.Create.Post"
StudioUpdatePost HookTriggerEnum = "Studio.Update.Post"
StudioDestroyPost HookTriggerEnum = "Studio.Destroy.Post"
TagCreatePost HookTriggerEnum = "Tag.Create.Post"
TagUpdatePost HookTriggerEnum = "Tag.Update.Post"
TagDestroyPost HookTriggerEnum = "Tag.Destroy.Post"
)
var AllHookTriggerEnum = []HookTriggerEnum{
SceneMarkerCreatePost,
SceneMarkerUpdatePost,
SceneMarkerDestroyPost,
SceneCreatePost,
SceneUpdatePost,
SceneDestroyPost,
ImageCreatePost,
ImageUpdatePost,
ImageDestroyPost,
GalleryCreatePost,
GalleryUpdatePost,
GalleryDestroyPost,
MovieCreatePost,
MovieUpdatePost,
MovieDestroyPost,
PerformerCreatePost,
PerformerUpdatePost,
PerformerDestroyPost,
StudioCreatePost,
StudioUpdatePost,
StudioDestroyPost,
TagCreatePost,
TagUpdatePost,
TagDestroyPost,
}
func (e HookTriggerEnum) IsValid() bool {
switch e {
case SceneMarkerCreatePost,
SceneMarkerUpdatePost,
SceneMarkerDestroyPost,
SceneCreatePost,
SceneUpdatePost,
SceneDestroyPost,
ImageCreatePost,
ImageUpdatePost,
ImageDestroyPost,
GalleryCreatePost,
GalleryUpdatePost,
GalleryDestroyPost,
MovieCreatePost,
MovieUpdatePost,
MovieDestroyPost,
PerformerCreatePost,
PerformerUpdatePost,
PerformerDestroyPost,
StudioCreatePost,
StudioUpdatePost,
StudioDestroyPost,
TagCreatePost,
TagUpdatePost,
TagDestroyPost:
return true
}
return false
}
func (e HookTriggerEnum) String() string {
return string(e)
}
func addHookContext(argsMap common.ArgsMap, hookContext common.HookContext) {
argsMap[common.HookContextKey] = hookContext
}

View File

@@ -28,11 +28,6 @@ type jsPluginTask struct {
vm *otto.Otto vm *otto.Otto
} }
func throw(vm *otto.Otto, str string) {
value, _ := vm.Call("new Error", nil, str)
panic(value)
}
func (t *jsPluginTask) onError(err error) { func (t *jsPluginTask) onError(err error) {
errString := err.Error() errString := err.Error()
t.result = &common.PluginOutput{ t.result = &common.PluginOutput{
@@ -76,12 +71,10 @@ func (t *jsPluginTask) Start() error {
return err return err
} }
input := t.buildPluginInput() t.vm.Set("input", t.input)
t.vm.Set("input", input)
js.AddLogAPI(t.vm, t.progress) js.AddLogAPI(t.vm, t.progress)
js.AddUtilAPI(t.vm) js.AddUtilAPI(t.vm)
js.AddGQLAPI(t.vm, t.gqlHandler) js.AddGQLAPI(t.vm, t.input.ServerConnection.SessionCookie, t.gqlHandler)
t.vm.Interrupt = make(chan func(), 1) t.vm.Interrupt = make(chan func(), 1)

View File

@@ -33,7 +33,7 @@ func throw(vm *otto.Otto, str string) {
panic(value) panic(value)
} }
func gqlRequestFunc(vm *otto.Otto, gqlHandler http.HandlerFunc) func(call otto.FunctionCall) otto.Value { func gqlRequestFunc(vm *otto.Otto, cookie *http.Cookie, gqlHandler http.Handler) func(call otto.FunctionCall) otto.Value {
return func(call otto.FunctionCall) otto.Value { return func(call otto.FunctionCall) otto.Value {
if len(call.ArgumentList) == 0 { if len(call.ArgumentList) == 0 {
throw(vm, "missing argument") throw(vm, "missing argument")
@@ -67,11 +67,15 @@ func gqlRequestFunc(vm *otto.Otto, gqlHandler http.HandlerFunc) func(call otto.F
} }
r.Header.Set("Content-Type", "application/json") r.Header.Set("Content-Type", "application/json")
if cookie != nil {
r.AddCookie(cookie)
}
w := &responseWriter{ w := &responseWriter{
header: make(http.Header), header: make(http.Header),
} }
gqlHandler(w, r) gqlHandler.ServeHTTP(w, r)
if w.statusCode != http.StatusOK && w.statusCode != 0 { if w.statusCode != http.StatusOK && w.statusCode != 0 {
throw(vm, fmt.Sprintf("graphQL query failed: %d - %s. Query: %s. Variables: %v", w.statusCode, w.r.String(), in.Query, in.Variables)) throw(vm, fmt.Sprintf("graphQL query failed: %d - %s. Query: %s. Variables: %v", w.statusCode, w.r.String(), in.Query, in.Variables))
@@ -99,9 +103,9 @@ func gqlRequestFunc(vm *otto.Otto, gqlHandler http.HandlerFunc) func(call otto.F
} }
} }
func AddGQLAPI(vm *otto.Otto, gqlHandler http.HandlerFunc) { func AddGQLAPI(vm *otto.Otto, cookie *http.Cookie, gqlHandler http.Handler) {
gql, _ := vm.Object("({})") gql, _ := vm.Object("({})")
gql.Set("Do", gqlRequestFunc(vm, gqlHandler)) gql.Set("Do", gqlRequestFunc(vm, cookie, gqlHandler))
vm.Set("gql", gql) vm.Set("gql", gql)
} }

View File

@@ -8,6 +8,8 @@ import (
"github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/logger"
) )
const pluginPrefix = "[Plugin] "
func argToString(call otto.FunctionCall) string { func argToString(call otto.FunctionCall) string {
arg := call.Argument(0) arg := call.Argument(0)
if arg.IsObject() { if arg.IsObject() {
@@ -20,27 +22,27 @@ func argToString(call otto.FunctionCall) string {
} }
func logTrace(call otto.FunctionCall) otto.Value { func logTrace(call otto.FunctionCall) otto.Value {
logger.Trace(argToString(call)) logger.Trace(pluginPrefix + argToString(call))
return otto.UndefinedValue() return otto.UndefinedValue()
} }
func logDebug(call otto.FunctionCall) otto.Value { func logDebug(call otto.FunctionCall) otto.Value {
logger.Debug(argToString(call)) logger.Debug(pluginPrefix + argToString(call))
return otto.UndefinedValue() return otto.UndefinedValue()
} }
func logInfo(call otto.FunctionCall) otto.Value { func logInfo(call otto.FunctionCall) otto.Value {
logger.Info(argToString(call)) logger.Info(pluginPrefix + argToString(call))
return otto.UndefinedValue() return otto.UndefinedValue()
} }
func logWarn(call otto.FunctionCall) otto.Value { func logWarn(call otto.FunctionCall) otto.Value {
logger.Warn(argToString(call)) logger.Warn(pluginPrefix + argToString(call))
return otto.UndefinedValue() return otto.UndefinedValue()
} }
func logError(call otto.FunctionCall) otto.Value { func logError(call otto.FunctionCall) otto.Value {
logger.Error(argToString(call)) logger.Error(pluginPrefix + argToString(call))
return otto.UndefinedValue() return otto.UndefinedValue()
} }

View File

@@ -8,6 +8,7 @@
package plugin package plugin
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"os" "os"
@@ -17,13 +18,16 @@ import (
"github.com/stashapp/stash/pkg/manager/config" "github.com/stashapp/stash/pkg/manager/config"
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin/common" "github.com/stashapp/stash/pkg/plugin/common"
"github.com/stashapp/stash/pkg/session"
"github.com/stashapp/stash/pkg/utils"
) )
// Cache stores plugin details. // Cache stores plugin details.
type Cache struct { type Cache struct {
config *config.Instance config *config.Instance
plugins []Config plugins []Config
gqlHandler http.HandlerFunc sessionStore *session.Store
gqlHandler http.Handler
} }
// NewCache returns a new Cache. // NewCache returns a new Cache.
@@ -39,10 +43,14 @@ func NewCache(config *config.Instance) *Cache {
} }
} }
func (c *Cache) RegisterGQLHandler(handler http.HandlerFunc) { func (c *Cache) RegisterGQLHandler(handler http.Handler) {
c.gqlHandler = handler c.gqlHandler = handler
} }
func (c *Cache) RegisterSessionStore(sessionStore *session.Store) {
c.sessionStore = sessionStore
}
// LoadPlugins clears the plugin cache and loads from the plugin path. // LoadPlugins clears the plugin cache and loads from the plugin path.
// In the event of an error during loading, the cache will be left empty. // In the event of an error during loading, the cache will be left empty.
func (c *Cache) LoadPlugins() error { func (c *Cache) LoadPlugins() error {
@@ -105,10 +113,38 @@ func (c Cache) ListPluginTasks() []*models.PluginTask {
return ret return ret
} }
func buildPluginInput(plugin *Config, operation *OperationConfig, serverConnection common.StashServerConnection, args []*models.PluginArgInput) common.PluginInput {
args = applyDefaultArgs(args, operation.DefaultArgs)
serverConnection.PluginDir = plugin.getConfigPath()
return common.PluginInput{
ServerConnection: serverConnection,
Args: toPluginArgs(args),
}
}
func (c Cache) makeServerConnection(ctx context.Context) common.StashServerConnection {
cookie := c.sessionStore.MakePluginCookie(ctx)
serverConnection := common.StashServerConnection{
Scheme: "http",
Port: c.config.GetPort(),
SessionCookie: cookie,
Dir: c.config.GetConfigPath(),
}
if config.HasTLSConfig() {
serverConnection.Scheme = "https"
}
return serverConnection
}
// CreateTask runs the plugin operation for the pluginID and operation // CreateTask runs the plugin operation for the pluginID and operation
// name provided. Returns an error if the plugin or the operation could not be // name provided. Returns an error if the plugin or the operation could not be
// resolved. // resolved.
func (c Cache) CreateTask(pluginID string, operationName string, serverConnection common.StashServerConnection, args []*models.PluginArgInput, progress chan float64) (Task, error) { func (c Cache) CreateTask(ctx context.Context, pluginID string, operationName string, args []*models.PluginArgInput, progress chan float64) (Task, error) {
serverConnection := c.makeServerConnection(ctx)
// find the plugin and operation // find the plugin and operation
plugin := c.getPlugin(pluginID) plugin := c.getPlugin(pluginID)
@@ -122,16 +158,88 @@ func (c Cache) CreateTask(pluginID string, operationName string, serverConnectio
} }
task := pluginTask{ task := pluginTask{
plugin: plugin, plugin: plugin,
operation: operation, operation: operation,
serverConnection: serverConnection, input: buildPluginInput(plugin, operation, serverConnection, args),
args: args, progress: progress,
progress: progress, gqlHandler: c.gqlHandler,
gqlHandler: c.gqlHandler,
} }
return task.createTask(), nil return task.createTask(), nil
} }
func (c Cache) ExecutePostHooks(ctx context.Context, id int, hookType HookTriggerEnum, input interface{}, inputFields []string) {
if err := c.executePostHooks(ctx, hookType, common.HookContext{
ID: id,
Type: hookType.String(),
Input: input,
InputFields: inputFields,
}); err != nil {
logger.Errorf("error executing post hooks: %s", err.Error())
}
}
func (c Cache) executePostHooks(ctx context.Context, hookType HookTriggerEnum, hookContext common.HookContext) error {
visitedPlugins := session.GetVisitedPlugins(ctx)
for _, p := range c.plugins {
hooks := p.getHooks(hookType)
// don't revisit a plugin we've already visited
// only log if there's hooks that we're skipping
if len(hooks) > 0 && utils.StrInclude(visitedPlugins, p.id) {
logger.Debugf("plugin ID '%s' already triggered, not re-triggering", p.id)
continue
}
for _, h := range hooks {
newCtx := session.AddVisitedPlugin(ctx, p.id)
serverConnection := c.makeServerConnection(newCtx)
pluginInput := buildPluginInput(&p, &h.OperationConfig, serverConnection, nil)
addHookContext(pluginInput.Args, hookContext)
pt := pluginTask{
plugin: &p,
operation: &h.OperationConfig,
input: pluginInput,
gqlHandler: c.gqlHandler,
}
task := pt.createTask()
if err := task.Start(); err != nil {
return err
}
// handle cancel from context
c := make(chan struct{})
go func() {
task.Wait()
close(c)
}()
select {
case <-ctx.Done():
task.Stop()
return fmt.Errorf("operation cancelled")
case <-c:
// task finished normally
}
output := task.GetResult()
if output == nil {
logger.Debugf("%s [%s]: returned no result", hookType.String(), p.Name)
} else {
if output.Error != nil {
logger.Errorf("%s [%s]: returned error: %s", hookType.String(), p.Name, *output.Error)
} else if output.Output != nil {
logger.Debugf("%s [%s]: returned: %v", hookType.String(), p.Name, output.Output)
}
}
}
}
return nil
}
func (c Cache) getPlugin(pluginID string) *Config { func (c Cache) getPlugin(pluginID string) *Config {
for _, s := range c.plugins { for _, s := range c.plugins {
if s.id == pluginID { if s.id == pluginID {

View File

@@ -50,8 +50,7 @@ func (t *rawPluginTask) Start() error {
go func() { go func() {
defer stdin.Close() defer stdin.Close()
input := t.buildPluginInput() inBytes, _ := json.Marshal(t.input)
inBytes, _ := json.Marshal(input)
io.WriteString(stdin, string(inBytes)) io.WriteString(stdin, string(inBytes))
}() }()

View File

@@ -70,12 +70,10 @@ func (t *rpcPluginTask) Start() error {
Client: t.client, Client: t.client,
} }
input := t.buildPluginInput()
t.done = make(chan *rpc.Call, 1) t.done = make(chan *rpc.Call, 1)
result := common.PluginOutput{} result := common.PluginOutput{}
t.waitGroup.Add(1) t.waitGroup.Add(1)
iface.RunAsync(input, &result, t.done) iface.RunAsync(t.input, &result, t.done)
go t.waitToFinish(&result) go t.waitToFinish(&result)
t.started = true t.started = true

View File

@@ -3,7 +3,6 @@ package plugin
import ( import (
"net/http" "net/http"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin/common" "github.com/stashapp/stash/pkg/plugin/common"
) )
@@ -31,11 +30,10 @@ type taskBuilder interface {
} }
type pluginTask struct { type pluginTask struct {
plugin *Config plugin *Config
operation *OperationConfig operation *OperationConfig
serverConnection common.StashServerConnection input common.PluginInput
args []*models.PluginArgInput gqlHandler http.Handler
gqlHandler http.HandlerFunc
progress chan float64 progress chan float64
result *common.PluginOutput result *common.PluginOutput
@@ -48,12 +46,3 @@ func (t *pluginTask) GetResult() *common.PluginOutput {
func (t *pluginTask) createTask() Task { func (t *pluginTask) createTask() Task {
return t.plugin.Interface.getTaskBuilder().build(*t) return t.plugin.Interface.getTaskBuilder().build(*t)
} }
func (t *pluginTask) buildPluginInput() common.PluginInput {
args := applyDefaultArgs(t.args, t.operation.DefaultArgs)
t.serverConnection.PluginDir = t.plugin.getConfigPath()
return common.PluginInput{
ServerConnection: t.serverConnection,
Args: toPluginArgs(args),
}
}

240
pkg/session/session.go Normal file
View File

@@ -0,0 +1,240 @@
package session
import (
"context"
"errors"
"net/http"
"github.com/gorilla/securecookie"
"github.com/gorilla/sessions"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/manager/config"
"github.com/stashapp/stash/pkg/utils"
)
type key int
const (
contextUser key = iota
contextVisitedPlugins
)
const (
userIDKey = "userID"
visitedPluginsKey = "visitedPlugins"
)
const (
ApiKeyHeader = "ApiKey"
ApiKeyParameter = "apikey"
)
const (
cookieName = "session"
usernameFormKey = "username"
passwordFormKey = "password"
)
var ErrInvalidCredentials = errors.New("invalid username or password")
var ErrUnauthorized = errors.New("unauthorized")
type Store struct {
sessionStore *sessions.CookieStore
config *config.Instance
}
func NewStore(c *config.Instance) *Store {
ret := &Store{
sessionStore: sessions.NewCookieStore(config.GetInstance().GetSessionStoreKey()),
config: c,
}
ret.sessionStore.MaxAge(config.GetInstance().GetMaxSessionAge())
return ret
}
func (s *Store) Login(w http.ResponseWriter, r *http.Request) error {
// ignore error - we want a new session regardless
newSession, _ := s.sessionStore.Get(r, cookieName)
username := r.FormValue(usernameFormKey)
password := r.FormValue(passwordFormKey)
// authenticate the user
if !config.GetInstance().ValidateCredentials(username, password) {
return ErrInvalidCredentials
}
newSession.Values[userIDKey] = username
err := newSession.Save(r, w)
if err != nil {
return err
}
return nil
}
func (s *Store) Logout(w http.ResponseWriter, r *http.Request) error {
session, err := s.sessionStore.Get(r, cookieName)
if err != nil {
return err
}
delete(session.Values, userIDKey)
session.Options.MaxAge = -1
err = session.Save(r, w)
if err != nil {
return err
}
return nil
}
func (s *Store) GetSessionUserID(w http.ResponseWriter, r *http.Request) (string, error) {
session, err := s.sessionStore.Get(r, cookieName)
// ignore errors and treat as an empty user id, so that we handle expired
// cookie
if err != nil {
return "", nil
}
if !session.IsNew {
val := session.Values[userIDKey]
// refresh the cookie
err = session.Save(r, w)
if err != nil {
return "", err
}
ret, _ := val.(string)
return ret, nil
}
return "", nil
}
func SetCurrentUserID(ctx context.Context, userID string) context.Context {
return context.WithValue(ctx, contextUser, userID)
}
// GetCurrentUserID gets the current user id from the provided context
func GetCurrentUserID(ctx context.Context) *string {
userCtxVal := ctx.Value(contextUser)
if userCtxVal != nil {
currentUser := userCtxVal.(string)
return &currentUser
}
return nil
}
func (s *Store) VisitedPluginHandler() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// get the visited plugins from the cookie and set in the context
session, err := s.sessionStore.Get(r, cookieName)
// ignore errors
if err == nil {
val := session.Values[visitedPluginsKey]
visitedPlugins, _ := val.([]string)
ctx := setVisitedPlugins(r.Context(), visitedPlugins)
r = r.WithContext(ctx)
}
next.ServeHTTP(w, r)
})
}
}
func GetVisitedPlugins(ctx context.Context) []string {
ctxVal := ctx.Value(contextVisitedPlugins)
if ctxVal != nil {
return ctxVal.([]string)
}
return nil
}
func AddVisitedPlugin(ctx context.Context, pluginID string) context.Context {
curVal := GetVisitedPlugins(ctx)
curVal = utils.StrAppendUnique(curVal, pluginID)
return setVisitedPlugins(ctx, curVal)
}
func setVisitedPlugins(ctx context.Context, visitedPlugins []string) context.Context {
return context.WithValue(ctx, contextVisitedPlugins, visitedPlugins)
}
func (s *Store) createSessionCookie(username string) (*http.Cookie, error) {
session := sessions.NewSession(s.sessionStore, cookieName)
session.Values[userIDKey] = username
encoded, err := securecookie.EncodeMulti(session.Name(), session.Values,
s.sessionStore.Codecs...)
if err != nil {
return nil, err
}
return sessions.NewCookie(session.Name(), encoded, session.Options), nil
}
func (s *Store) MakePluginCookie(ctx context.Context) *http.Cookie {
currentUser := GetCurrentUserID(ctx)
visitedPlugins := GetVisitedPlugins(ctx)
session := sessions.NewSession(s.sessionStore, cookieName)
if currentUser != nil {
session.Values[userIDKey] = *currentUser
}
session.Values[visitedPluginsKey] = visitedPlugins
encoded, err := securecookie.EncodeMulti(session.Name(), session.Values,
s.sessionStore.Codecs...)
if err != nil {
logger.Errorf("error creating session cookie: %s", err.Error())
return nil
}
return sessions.NewCookie(session.Name(), encoded, session.Options)
}
func (s *Store) Authenticate(w http.ResponseWriter, r *http.Request) (userID string, err error) {
c := s.config
// translate api key into current user, if present
apiKey := r.Header.Get(ApiKeyHeader)
// try getting the api key as a query parameter
if apiKey == "" {
apiKey = r.URL.Query().Get(ApiKeyParameter)
}
if apiKey != "" {
// match against configured API and set userID to the
// configured username. In future, we'll want to
// get the username from the key.
if c.GetAPIKey() != apiKey {
return "", ErrUnauthorized
}
userID = c.GetUsername()
} else {
// handle session
userID, err = s.GetSessionUserID(w, r)
}
if err != nil {
return "", err
}
return
}

28
pkg/utils/context.go Normal file
View File

@@ -0,0 +1,28 @@
package utils
import (
"context"
"time"
)
type valueOnlyContext struct {
context.Context
}
func (valueOnlyContext) Deadline() (deadline time.Time, ok bool) {
return
}
func (valueOnlyContext) Done() <-chan struct{} {
return nil
}
func (valueOnlyContext) Err() error {
return nil
}
func ValueOnlyContext(ctx context.Context) context.Context {
return valueOnlyContext{
ctx,
}
}

View File

@@ -23,7 +23,7 @@ func TestOshashEmpty(t *testing.T) {
func TestOshashCollisions(t *testing.T) { func TestOshashCollisions(t *testing.T) {
buf1 := []byte("this is dumb") buf1 := []byte("this is dumb")
buf2 := []byte("dumb is this") buf2 := []byte("dumb is this")
var size = int64(len(buf1)) size := int64(len(buf1))
head := make([]byte, chunkSize) head := make([]byte, chunkSize)
tail1 := make([]byte, chunkSize) tail1 := make([]byte, chunkSize)

View File

@@ -1,4 +1,5 @@
### ✨ New Features ### ✨ New Features
* Added support for triggering plugin tasks during operations. ([#1452](https://github.com/stashapp/stash/pull/1452))
* Support Studio filter including child studios. ([#1397](https://github.com/stashapp/stash/pull/1397)) * Support Studio filter including child studios. ([#1397](https://github.com/stashapp/stash/pull/1397))
* Added support for tag aliases. ([#1412](https://github.com/stashapp/stash/pull/1412)) * Added support for tag aliases. ([#1412](https://github.com/stashapp/stash/pull/1412))
* Support embedded Javascript plugins. ([#1393](https://github.com/stashapp/stash/pull/1393)) * Support embedded Javascript plugins. ([#1393](https://github.com/stashapp/stash/pull/1393))

View File

@@ -1,9 +1,10 @@
import React from "react"; import React from "react";
import { Button } from "react-bootstrap"; import { Button } from "react-bootstrap";
import * as GQL from "src/core/generated-graphql";
import { mutateReloadPlugins, usePlugins } from "src/core/StashService"; import { mutateReloadPlugins, usePlugins } from "src/core/StashService";
import { useToast } from "src/hooks"; import { useToast } from "src/hooks";
import { TextUtils } from "src/utils"; import { TextUtils } from "src/utils";
import { Icon, LoadingIndicator } from "src/components/Shared"; import { CollapseButton, Icon, LoadingIndicator } from "src/components/Shared";
export const SettingsPluginsPanel: React.FC = () => { export const SettingsPluginsPanel: React.FC = () => {
const Toast = useToast(); const Toast = useToast();
@@ -33,13 +34,14 @@ export const SettingsPluginsPanel: React.FC = () => {
function renderPlugins() { function renderPlugins() {
const elements = (data?.plugins ?? []).map((plugin) => ( const elements = (data?.plugins ?? []).map((plugin) => (
<div key={plugin.id}> <div key={plugin.id}>
<h5> <h4>
{plugin.name} {plugin.version ? `(${plugin.version})` : undefined}{" "} {plugin.name} {plugin.version ? `(${plugin.version})` : undefined}{" "}
{renderLink(plugin.url ?? undefined)} {renderLink(plugin.url ?? undefined)}
</h5> </h4>
{plugin.description ? ( {plugin.description ? (
<small className="text-muted">{plugin.description}</small> <small className="text-muted">{plugin.description}</small>
) : undefined} ) : undefined}
{renderPluginHooks(plugin.hooks ?? undefined)}
<hr /> <hr />
</div> </div>
)); ));
@@ -47,11 +49,40 @@ export const SettingsPluginsPanel: React.FC = () => {
return <div>{elements}</div>; return <div>{elements}</div>;
} }
function renderPluginHooks(
hooks?: Pick<GQL.PluginHook, "name" | "description" | "hooks">[]
) {
if (!hooks || hooks.length === 0) {
return;
}
return (
<div className="mt-2">
<h5>Hooks</h5>
{hooks.map((h) => (
<div key={`${h.name}`} className="mb-3">
<h6>{h.name}</h6>
<CollapseButton text="Triggers on">
<ul>
{h.hooks?.map((hh) => (
<li>
<code>{hh}</code>
</li>
))}
</ul>
</CollapseButton>
<small className="text-muted">{h.description}</small>
</div>
))}
</div>
);
}
if (loading) return <LoadingIndicator />; if (loading) return <LoadingIndicator />;
return ( return (
<> <>
<h4>Plugins</h4> <h3>Plugins</h3>
<hr /> <hr />
{renderPlugins()} {renderPlugins()}
<Button onClick={() => onReloadPlugins()}> <Button onClick={() => onReloadPlugins()}>

View File

@@ -78,22 +78,6 @@ For embedded plugins, the `exec` field is a list with the first element being th
For embedded plugins, the `interface` field must be set to one of the following values: For embedded plugins, the `interface` field must be set to one of the following values:
* `js` * `js`
# Task configuration
Tasks are configured using the following structure:
```
tasks:
- name: <operation name>
description: <optional description>
defaultArgs:
argKey: argValue
```
A plugin configuration may contain multiple tasks.
The `defaultArgs` field is used to add inputs to the plugin input sent to the plugin.
# Javascript API # Javascript API
## Logging ## Logging

View File

@@ -89,20 +89,14 @@ The `errLog` field tells stash what the default log level should be when the plu
# Task configuration # Task configuration
Tasks are configured using the following structure: In addition to the standard task configuration, external tags may be configured with an optional `execArgs` field to add extra parameters to the execution arguments for the task.
For example:
``` ```
tasks: tasks:
- name: <operation name> - name: <operation name>
description: <optional description> description: <optional description>
defaultArgs:
argKey: argValue
execArgs: execArgs:
- <arg to add to the exec line> - <arg to add to the exec line>
``` ```
A plugin configuration may contain multiple tasks.
The `defaultArgs` field is used to add inputs to the plugin input sent to the plugin.
The `execArgs` field allows adding extra parameters to the execution arguments for this task.

View File

@@ -2,6 +2,8 @@
Stash supports the running tasks via plugins. Plugins can be implemented using embedded Javascript, or by calling an external binary. Stash supports the running tasks via plugins. Plugins can be implemented using embedded Javascript, or by calling an external binary.
Stash also supports triggering of plugin hooks from specific stash operations.
> **⚠️ Note:** Plugin support is still experimental and is likely to change. > **⚠️ Note:** Plugin support is still experimental and is likely to change.
# Adding plugins # Adding plugins
@@ -67,3 +69,108 @@ Plugin output is expected in the following structure (presented here as JSON for
``` ```
The `error` field is logged in stash at the `error` log level if present. The `output` is written at the `debug` log level. The `error` field is logged in stash at the `error` log level if present. The `output` is written at the `debug` log level.
## Task configuration
Tasks are configured using the following structure:
```
tasks:
- name: <operation name>
description: <optional description>
defaultArgs:
argKey: argValue
```
A plugin configuration may contain multiple tasks.
The `defaultArgs` field is used to add inputs to the plugin input sent to the plugin.
## Hook configuration
Stash supports executing plugin operations via triggering of a hook during a stash operation.
Hooks are configured using a similar structure to tasks:
```
hooks:
- name: <operation name>
description: <optional description>
triggeredBy:
- <trigger types>...
defaultArgs:
argKey: argValue
```
**Note:** it is possible for hooks to trigger eachother or themselves if they perform mutations. For safety, hooks will not be triggered if they have already been triggered in the context of the operation. Stash uses cookies to track this context, so it's important for plugins to send cookies when performing operations.
### Trigger types
Trigger types use the following format:
`<object type>.<operation>.<hook type>`
For example, a post-hook on a scene create operation will be `Scene.Create.Post`.
The following object types are supported:
* `Scene`
* `SceneMarker`
* `Image`
* `Gallery`
* `Movie`
* `Performer`
* `Studio`
* `Tag`
The following operations are supported:
* `Create`
* `Update`
* `Destroy`
Currently, only `Post` hook types are supported. These are executed after the operation has completed and the transaction is committed.
### Hook input
Plugin tasks triggered by a hook include an argument named `hookContext` in the `args` object structure. The `hookContext` is structured as follows:
```
{
"id": <object id>,
"type": <trigger type>,
"input": <operation input>,
"inputFields": <fields included in input>
}
```
The `input` field contains the JSON graphql input passed to the original operation. This will differ between operations. For hooks triggered by operations in a scan or clean, the input will be nil. `inputFields` is populated in update operations to indicate which fields were passed to the operation, to differentiate between missing and empty fields.
For example, here is the `args` values for a Scene update operation:
```
{
"hookContext": {
"type":"Scene.Update.Post",
"id":45,
"input":{
"clientMutationId":null,
"id":"45",
"title":null,
"details":null,
"url":null,
"date":null,
"rating":null,
"organized":null,
"studio_id":null,
"gallery_ids":null,
"performer_ids":null,
"movies":null,
"tag_ids":["21"],
"cover_image":null,
"stash_ids":null
},
"inputFields":[
"tag_ids",
"id"
]
}
}
```