Plugin hooks (#1452)

* Refactor session and plugin code
* Add context to job tasks
* Show hooks in plugins page
* Refactor session management
This commit is contained in:
WithoutPants
2021-06-11 17:24:58 +10:00
committed by GitHub
parent dde361f9f3
commit 46bbede9a0
48 changed files with 1289 additions and 338 deletions

View File

@@ -65,6 +65,15 @@ func (t changesetTranslator) hasField(field string) bool {
return found
}
func (t changesetTranslator) getFields() []string {
var ret []string
for k := range t.inputMap {
ret = append(ret, k)
}
return ret
}
func (t changesetTranslator) nullString(value *string, field string) *sql.NullString {
if !t.hasField(field) {
return nil

View File

@@ -5,13 +5,12 @@ package api
type key int
const (
galleryKey key = 0
performerKey key = 1
sceneKey key = 2
studioKey key = 3
movieKey key = 4
ContextUser key = 5
tagKey key = 6
downloadKey key = 7
imageKey key = 8
galleryKey key = iota
performerKey
sceneKey
studioKey
movieKey
tagKey
downloadKey
imageKey
)

View File

@@ -7,10 +7,16 @@ import (
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin"
)
type hookExecutor interface {
ExecutePostHooks(ctx context.Context, id int, hookType plugin.HookTriggerEnum, input interface{}, inputFields []string)
}
type Resolver struct {
txnManager models.TransactionManager
txnManager models.TransactionManager
hookExecutor hookExecutor
}
func (r *Resolver) Gallery() models.GalleryResolver {

View File

@@ -10,9 +10,21 @@ import (
"github.com/stashapp/stash/pkg/manager"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin"
"github.com/stashapp/stash/pkg/utils"
)
func (r *mutationResolver) getGallery(ctx context.Context, id int) (ret *models.Gallery, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Gallery().Find(id)
return err
}); err != nil {
return nil, err
}
return ret, nil
}
func (r *mutationResolver) GalleryCreate(ctx context.Context, input models.GalleryCreateInput) (*models.Gallery, error) {
// name must be provided
if input.Title == "" {
@@ -90,7 +102,8 @@ func (r *mutationResolver) GalleryCreate(ctx context.Context, input models.Galle
return nil, err
}
return gallery, nil
r.hookExecutor.ExecutePostHooks(ctx, gallery.ID, plugin.GalleryCreatePost, input, nil)
return r.getGallery(ctx, gallery.ID)
}
func (r *mutationResolver) updateGalleryPerformers(qb models.GalleryReaderWriter, galleryID int, performerIDs []string) error {
@@ -130,7 +143,9 @@ func (r *mutationResolver) GalleryUpdate(ctx context.Context, input models.Galle
return nil, err
}
return ret, nil
// execute post hooks outside txn
r.hookExecutor.ExecutePostHooks(ctx, ret.ID, plugin.GalleryUpdatePost, input, translator.getFields())
return r.getGallery(ctx, ret.ID)
}
func (r *mutationResolver) GalleriesUpdate(ctx context.Context, input []*models.GalleryUpdateInput) (ret []*models.Gallery, err error) {
@@ -156,7 +171,23 @@ func (r *mutationResolver) GalleriesUpdate(ctx context.Context, input []*models.
return nil, err
}
return ret, nil
// execute post hooks outside txn
var newRet []*models.Gallery
for i, gallery := range ret {
translator := changesetTranslator{
inputMap: inputMaps[i],
}
r.hookExecutor.ExecutePostHooks(ctx, gallery.ID, plugin.GalleryUpdatePost, input, translator.getFields())
gallery, err = r.getGallery(ctx, gallery.ID)
if err != nil {
return nil, err
}
newRet = append(newRet, gallery)
}
return newRet, nil
}
func (r *mutationResolver) galleryUpdate(input models.GalleryUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Gallery, error) {
@@ -314,7 +345,20 @@ func (r *mutationResolver) BulkGalleryUpdate(ctx context.Context, input models.B
return nil, err
}
return ret, nil
// execute post hooks outside of txn
var newRet []*models.Gallery
for _, gallery := range ret {
r.hookExecutor.ExecutePostHooks(ctx, gallery.ID, plugin.GalleryUpdatePost, input, translator.getFields())
gallery, err := r.getGallery(ctx, gallery.ID)
if err != nil {
return nil, err
}
newRet = append(newRet, gallery)
}
return newRet, nil
}
func adjustGalleryPerformerIDs(qb models.GalleryReader, galleryID int, ids models.BulkUpdateIds) (ret []int, err error) {
@@ -438,6 +482,16 @@ func (r *mutationResolver) GalleryDestroy(ctx context.Context, input models.Gall
}
}
// call post hook after performing the other actions
for _, gallery := range galleries {
r.hookExecutor.ExecutePostHooks(ctx, gallery.ID, plugin.GalleryDestroyPost, input, nil)
}
// call image destroy post hook as well
for _, img := range imgsToDelete {
r.hookExecutor.ExecutePostHooks(ctx, img.ID, plugin.ImageDestroyPost, nil, nil)
}
return true, nil
}

View File

@@ -8,9 +8,21 @@ import (
"github.com/stashapp/stash/pkg/manager"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin"
"github.com/stashapp/stash/pkg/utils"
)
func (r *mutationResolver) getImage(ctx context.Context, id int) (ret *models.Image, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Image().Find(id)
return err
}); err != nil {
return nil, err
}
return ret, nil
}
func (r *mutationResolver) ImageUpdate(ctx context.Context, input models.ImageUpdateInput) (ret *models.Image, err error) {
translator := changesetTranslator{
inputMap: getUpdateInputMap(ctx),
@@ -24,7 +36,9 @@ func (r *mutationResolver) ImageUpdate(ctx context.Context, input models.ImageUp
return nil, err
}
return ret, nil
// execute post hooks outside txn
r.hookExecutor.ExecutePostHooks(ctx, ret.ID, plugin.ImageUpdatePost, input, translator.getFields())
return r.getImage(ctx, ret.ID)
}
func (r *mutationResolver) ImagesUpdate(ctx context.Context, input []*models.ImageUpdateInput) (ret []*models.Image, err error) {
@@ -50,7 +64,23 @@ func (r *mutationResolver) ImagesUpdate(ctx context.Context, input []*models.Ima
return nil, err
}
return ret, nil
// execute post hooks outside txn
var newRet []*models.Image
for i, image := range ret {
translator := changesetTranslator{
inputMap: inputMaps[i],
}
r.hookExecutor.ExecutePostHooks(ctx, image.ID, plugin.ImageUpdatePost, input, translator.getFields())
image, err = r.getImage(ctx, image.ID)
if err != nil {
return nil, err
}
newRet = append(newRet, image)
}
return newRet, nil
}
func (r *mutationResolver) imageUpdate(input models.ImageUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Image, error) {
@@ -202,7 +232,20 @@ func (r *mutationResolver) BulkImageUpdate(ctx context.Context, input models.Bul
return nil, err
}
return ret, nil
// execute post hooks outside of txn
var newRet []*models.Image
for _, image := range ret {
r.hookExecutor.ExecutePostHooks(ctx, image.ID, plugin.ImageUpdatePost, input, translator.getFields())
image, err = r.getImage(ctx, image.ID)
if err != nil {
return nil, err
}
newRet = append(newRet, image)
}
return newRet, nil
}
func adjustImageGalleryIDs(qb models.ImageReader, imageID int, ids models.BulkUpdateIds) (ret []int, err error) {
@@ -268,6 +311,9 @@ func (r *mutationResolver) ImageDestroy(ctx context.Context, input models.ImageD
manager.DeleteImageFile(image)
}
// call post hook after performing the other actions
r.hookExecutor.ExecutePostHooks(ctx, image.ID, plugin.ImageDestroyPost, input, nil)
return true, nil
}
@@ -315,6 +361,9 @@ func (r *mutationResolver) ImagesDestroy(ctx context.Context, input models.Image
if input.DeleteFile != nil && *input.DeleteFile {
manager.DeleteImageFile(image)
}
// call post hook after performing the other actions
r.hookExecutor.ExecutePostHooks(ctx, image.ID, plugin.ImageDestroyPost, input, nil)
}
return true, nil

View File

@@ -17,7 +17,7 @@ import (
)
func (r *mutationResolver) MetadataScan(ctx context.Context, input models.ScanMetadataInput) (string, error) {
jobID, err := manager.GetInstance().Scan(input)
jobID, err := manager.GetInstance().Scan(ctx, input)
if err != nil {
return "", err
@@ -27,7 +27,7 @@ func (r *mutationResolver) MetadataScan(ctx context.Context, input models.ScanMe
}
func (r *mutationResolver) MetadataImport(ctx context.Context) (string, error) {
jobID, err := manager.GetInstance().Import()
jobID, err := manager.GetInstance().Import(ctx)
if err != nil {
return "", err
}
@@ -41,13 +41,13 @@ func (r *mutationResolver) ImportObjects(ctx context.Context, input models.Impor
return "", err
}
jobID := manager.GetInstance().RunSingleTask(t)
jobID := manager.GetInstance().RunSingleTask(ctx, t)
return strconv.Itoa(jobID), nil
}
func (r *mutationResolver) MetadataExport(ctx context.Context) (string, error) {
jobID, err := manager.GetInstance().Export()
jobID, err := manager.GetInstance().Export(ctx)
if err != nil {
return "", err
}
@@ -75,7 +75,7 @@ func (r *mutationResolver) ExportObjects(ctx context.Context, input models.Expor
}
func (r *mutationResolver) MetadataGenerate(ctx context.Context, input models.GenerateMetadataInput) (string, error) {
jobID, err := manager.GetInstance().Generate(input)
jobID, err := manager.GetInstance().Generate(ctx, input)
if err != nil {
return "", err
@@ -85,17 +85,17 @@ func (r *mutationResolver) MetadataGenerate(ctx context.Context, input models.Ge
}
func (r *mutationResolver) MetadataAutoTag(ctx context.Context, input models.AutoTagMetadataInput) (string, error) {
jobID := manager.GetInstance().AutoTag(input)
jobID := manager.GetInstance().AutoTag(ctx, input)
return strconv.Itoa(jobID), nil
}
func (r *mutationResolver) MetadataClean(ctx context.Context, input models.CleanMetadataInput) (string, error) {
jobID := manager.GetInstance().Clean(input)
jobID := manager.GetInstance().Clean(ctx, input)
return strconv.Itoa(jobID), nil
}
func (r *mutationResolver) MigrateHashNaming(ctx context.Context) (string, error) {
jobID := manager.GetInstance().MigrateHash()
jobID := manager.GetInstance().MigrateHash(ctx)
return strconv.Itoa(jobID), nil
}

View File

@@ -7,9 +7,21 @@ import (
"time"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin"
"github.com/stashapp/stash/pkg/utils"
)
func (r *mutationResolver) getMovie(ctx context.Context, id int) (ret *models.Movie, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Movie().Find(id)
return err
}); err != nil {
return nil, err
}
return ret, nil
}
func (r *mutationResolver) MovieCreate(ctx context.Context, input models.MovieCreateInput) (*models.Movie, error) {
// generate checksum from movie name rather than image
checksum := utils.MD5FromString(input.Name)
@@ -104,7 +116,8 @@ func (r *mutationResolver) MovieCreate(ctx context.Context, input models.MovieCr
return nil, err
}
return movie, nil
r.hookExecutor.ExecutePostHooks(ctx, movie.ID, plugin.MovieCreatePost, input, nil)
return r.getMovie(ctx, movie.ID)
}
func (r *mutationResolver) MovieUpdate(ctx context.Context, input models.MovieUpdateInput) (*models.Movie, error) {
@@ -203,7 +216,8 @@ func (r *mutationResolver) MovieUpdate(ctx context.Context, input models.MovieUp
return nil, err
}
return movie, nil
r.hookExecutor.ExecutePostHooks(ctx, movie.ID, plugin.MovieUpdatePost, input, translator.getFields())
return r.getMovie(ctx, movie.ID)
}
func (r *mutationResolver) MovieDestroy(ctx context.Context, input models.MovieDestroyInput) (bool, error) {
@@ -217,6 +231,9 @@ func (r *mutationResolver) MovieDestroy(ctx context.Context, input models.MovieD
}); err != nil {
return false, err
}
r.hookExecutor.ExecutePostHooks(ctx, id, plugin.MovieDestroyPost, input, nil)
return true, nil
}
@@ -238,5 +255,10 @@ func (r *mutationResolver) MoviesDestroy(ctx context.Context, movieIDs []string)
}); err != nil {
return false, err
}
for _, id := range ids {
r.hookExecutor.ExecutePostHooks(ctx, id, plugin.MovieDestroyPost, movieIDs, nil)
}
return true, nil
}

View File

@@ -9,9 +9,21 @@ import (
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/performer"
"github.com/stashapp/stash/pkg/plugin"
"github.com/stashapp/stash/pkg/utils"
)
func (r *mutationResolver) getPerformer(ctx context.Context, id int) (ret *models.Performer, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Performer().Find(id)
return err
}); err != nil {
return nil, err
}
return ret, nil
}
func (r *mutationResolver) PerformerCreate(ctx context.Context, input models.PerformerCreateInput) (*models.Performer, error) {
// generate checksum from performer name rather than image
checksum := utils.MD5FromString(input.Name)
@@ -146,7 +158,8 @@ func (r *mutationResolver) PerformerCreate(ctx context.Context, input models.Per
return nil, err
}
return performer, nil
r.hookExecutor.ExecutePostHooks(ctx, performer.ID, plugin.PerformerCreatePost, input, nil)
return r.getPerformer(ctx, performer.ID)
}
func (r *mutationResolver) PerformerUpdate(ctx context.Context, input models.PerformerUpdateInput) (*models.Performer, error) {
@@ -267,7 +280,8 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input models.Per
return nil, err
}
return p, nil
r.hookExecutor.ExecutePostHooks(ctx, p.ID, plugin.PerformerUpdatePost, input, translator.getFields())
return r.getPerformer(ctx, p.ID)
}
func (r *mutationResolver) updatePerformerTags(qb models.PerformerReaderWriter, performerID int, tagsIDs []string) error {
@@ -372,7 +386,20 @@ func (r *mutationResolver) BulkPerformerUpdate(ctx context.Context, input models
return nil, err
}
return ret, nil
// execute post hooks outside of txn
var newRet []*models.Performer
for _, performer := range ret {
r.hookExecutor.ExecutePostHooks(ctx, performer.ID, plugin.ImageUpdatePost, input, translator.getFields())
performer, err = r.getPerformer(ctx, performer.ID)
if err != nil {
return nil, err
}
newRet = append(newRet, performer)
}
return newRet, nil
}
func (r *mutationResolver) PerformerDestroy(ctx context.Context, input models.PerformerDestroyInput) (bool, error) {
@@ -386,6 +413,9 @@ func (r *mutationResolver) PerformerDestroy(ctx context.Context, input models.Pe
}); err != nil {
return false, err
}
r.hookExecutor.ExecutePostHooks(ctx, id, plugin.PerformerDestroyPost, input, nil)
return true, nil
}
@@ -407,5 +437,10 @@ func (r *mutationResolver) PerformersDestroy(ctx context.Context, performerIDs [
}); err != nil {
return false, err
}
for _, id := range ids {
r.hookExecutor.ExecutePostHooks(ctx, id, plugin.PerformerDestroyPost, performerIDs, nil)
}
return true, nil
}

View File

@@ -2,40 +2,15 @@ package api
import (
"context"
"net/http"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/manager"
"github.com/stashapp/stash/pkg/manager/config"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin/common"
)
func (r *mutationResolver) RunPluginTask(ctx context.Context, pluginID string, taskName string, args []*models.PluginArgInput) (string, error) {
currentUser := getCurrentUserID(ctx)
var cookie *http.Cookie
var err error
if currentUser != nil {
cookie, err = createSessionCookie(*currentUser)
if err != nil {
return "", err
}
}
config := config.GetInstance()
serverConnection := common.StashServerConnection{
Scheme: "http",
Port: config.GetPort(),
SessionCookie: cookie,
Dir: config.GetConfigPath(),
}
if HasTLSConfig() {
serverConnection.Scheme = "https"
}
manager.GetInstance().RunPluginTask(pluginID, taskName, args, serverConnection)
m := manager.GetInstance()
m.RunPluginTask(ctx, pluginID, taskName, args)
return "todo", nil
}

View File

@@ -10,9 +10,21 @@ import (
"github.com/stashapp/stash/pkg/manager"
"github.com/stashapp/stash/pkg/manager/config"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin"
"github.com/stashapp/stash/pkg/utils"
)
func (r *mutationResolver) getScene(ctx context.Context, id int) (ret *models.Scene, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Scene().Find(id)
return err
}); err != nil {
return nil, err
}
return ret, nil
}
func (r *mutationResolver) SceneUpdate(ctx context.Context, input models.SceneUpdateInput) (ret *models.Scene, err error) {
translator := changesetTranslator{
inputMap: getUpdateInputMap(ctx),
@@ -26,7 +38,8 @@ func (r *mutationResolver) SceneUpdate(ctx context.Context, input models.SceneUp
return nil, err
}
return ret, nil
r.hookExecutor.ExecutePostHooks(ctx, ret.ID, plugin.SceneUpdatePost, input, translator.getFields())
return r.getScene(ctx, ret.ID)
}
func (r *mutationResolver) ScenesUpdate(ctx context.Context, input []*models.SceneUpdateInput) (ret []*models.Scene, err error) {
@@ -52,7 +65,24 @@ func (r *mutationResolver) ScenesUpdate(ctx context.Context, input []*models.Sce
return nil, err
}
return ret, nil
// execute post hooks outside of txn
var newRet []*models.Scene
for i, scene := range ret {
translator := changesetTranslator{
inputMap: inputMaps[i],
}
r.hookExecutor.ExecutePostHooks(ctx, scene.ID, plugin.SceneUpdatePost, input, translator.getFields())
scene, err = r.getScene(ctx, scene.ID)
if err != nil {
return nil, err
}
newRet = append(newRet, scene)
}
return newRet, nil
}
func (r *mutationResolver) sceneUpdate(input models.SceneUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Scene, error) {
@@ -281,7 +311,20 @@ func (r *mutationResolver) BulkSceneUpdate(ctx context.Context, input models.Bul
return nil, err
}
return ret, nil
// execute post hooks outside of txn
var newRet []*models.Scene
for _, scene := range ret {
r.hookExecutor.ExecutePostHooks(ctx, scene.ID, plugin.SceneUpdatePost, input, translator.getFields())
scene, err = r.getScene(ctx, scene.ID)
if err != nil {
return nil, err
}
newRet = append(newRet, scene)
}
return newRet, nil
}
func adjustIDs(existingIDs []int, updateIDs models.BulkUpdateIds) []int {
@@ -393,6 +436,9 @@ func (r *mutationResolver) SceneDestroy(ctx context.Context, input models.SceneD
manager.DeleteSceneFile(scene)
}
// call post hook after performing the other actions
r.hookExecutor.ExecutePostHooks(ctx, scene.ID, plugin.SceneDestroyPost, input, nil)
return true, nil
}
@@ -442,11 +488,25 @@ func (r *mutationResolver) ScenesDestroy(ctx context.Context, input models.Scene
if input.DeleteFile != nil && *input.DeleteFile {
manager.DeleteSceneFile(scene)
}
// call post hook after performing the other actions
r.hookExecutor.ExecutePostHooks(ctx, scene.ID, plugin.SceneDestroyPost, input, nil)
}
return true, nil
}
func (r *mutationResolver) getSceneMarker(ctx context.Context, id int) (ret *models.SceneMarker, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.SceneMarker().Find(id)
return err
}); err != nil {
return nil, err
}
return ret, nil
}
func (r *mutationResolver) SceneMarkerCreate(ctx context.Context, input models.SceneMarkerCreateInput) (*models.SceneMarker, error) {
primaryTagID, err := strconv.Atoi(input.PrimaryTagID)
if err != nil {
@@ -473,7 +533,13 @@ func (r *mutationResolver) SceneMarkerCreate(ctx context.Context, input models.S
return nil, err
}
return r.changeMarker(ctx, create, newSceneMarker, tagIDs)
ret, err := r.changeMarker(ctx, create, newSceneMarker, tagIDs)
if err != nil {
return nil, err
}
r.hookExecutor.ExecutePostHooks(ctx, ret.ID, plugin.SceneMarkerCreatePost, input, nil)
return r.getSceneMarker(ctx, ret.ID)
}
func (r *mutationResolver) SceneMarkerUpdate(ctx context.Context, input models.SceneMarkerUpdateInput) (*models.SceneMarker, error) {
@@ -507,7 +573,16 @@ func (r *mutationResolver) SceneMarkerUpdate(ctx context.Context, input models.S
return nil, err
}
return r.changeMarker(ctx, update, updatedSceneMarker, tagIDs)
ret, err := r.changeMarker(ctx, update, updatedSceneMarker, tagIDs)
if err != nil {
return nil, err
}
translator := changesetTranslator{
inputMap: getUpdateInputMap(ctx),
}
r.hookExecutor.ExecutePostHooks(ctx, ret.ID, plugin.SceneMarkerUpdatePost, input, translator.getFields())
return r.getSceneMarker(ctx, ret.ID)
}
func (r *mutationResolver) SceneMarkerDestroy(ctx context.Context, id string) (bool, error) {
@@ -544,6 +619,8 @@ func (r *mutationResolver) SceneMarkerDestroy(ctx context.Context, id string) (b
postCommitFunc()
r.hookExecutor.ExecutePostHooks(ctx, markerID, plugin.SceneMarkerDestroyPost, id, nil)
return true, nil
}
@@ -651,9 +728,9 @@ func (r *mutationResolver) SceneResetO(ctx context.Context, id string) (ret int,
func (r *mutationResolver) SceneGenerateScreenshot(ctx context.Context, id string, at *float64) (string, error) {
if at != nil {
manager.GetInstance().GenerateScreenshot(id, *at)
manager.GetInstance().GenerateScreenshot(ctx, id, *at)
} else {
manager.GetInstance().GenerateDefaultScreenshot(id)
manager.GetInstance().GenerateDefaultScreenshot(ctx, id)
}
return "todo", nil

View File

@@ -24,6 +24,6 @@ func (r *mutationResolver) SubmitStashBoxFingerprints(ctx context.Context, input
}
func (r *mutationResolver) StashBoxBatchPerformerTag(ctx context.Context, input models.StashBoxBatchPerformerTagInput) (string, error) {
jobID := manager.GetInstance().StashBoxBatchPerformerTag(input)
jobID := manager.GetInstance().StashBoxBatchPerformerTag(ctx, input)
return strconv.Itoa(jobID), nil
}

View File

@@ -8,9 +8,21 @@ import (
"github.com/stashapp/stash/pkg/manager"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin"
"github.com/stashapp/stash/pkg/utils"
)
func (r *mutationResolver) getStudio(ctx context.Context, id int) (ret *models.Studio, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Studio().Find(id)
return err
}); err != nil {
return nil, err
}
return ret, nil
}
func (r *mutationResolver) StudioCreate(ctx context.Context, input models.StudioCreateInput) (*models.Studio, error) {
// generate checksum from studio name rather than image
checksum := utils.MD5FromString(input.Name)
@@ -82,7 +94,8 @@ func (r *mutationResolver) StudioCreate(ctx context.Context, input models.Studio
return nil, err
}
return studio, nil
r.hookExecutor.ExecutePostHooks(ctx, studio.ID, plugin.StudioCreatePost, input, nil)
return r.getStudio(ctx, studio.ID)
}
func (r *mutationResolver) StudioUpdate(ctx context.Context, input models.StudioUpdateInput) (*models.Studio, error) {
@@ -162,7 +175,8 @@ func (r *mutationResolver) StudioUpdate(ctx context.Context, input models.Studio
return nil, err
}
return studio, nil
r.hookExecutor.ExecutePostHooks(ctx, studio.ID, plugin.StudioUpdatePost, input, translator.getFields())
return r.getStudio(ctx, studio.ID)
}
func (r *mutationResolver) StudioDestroy(ctx context.Context, input models.StudioDestroyInput) (bool, error) {
@@ -176,6 +190,9 @@ func (r *mutationResolver) StudioDestroy(ctx context.Context, input models.Studi
}); err != nil {
return false, err
}
r.hookExecutor.ExecutePostHooks(ctx, id, plugin.StudioDestroyPost, input, nil)
return true, nil
}
@@ -197,5 +214,10 @@ func (r *mutationResolver) StudiosDestroy(ctx context.Context, studioIDs []strin
}); err != nil {
return false, err
}
for _, id := range ids {
r.hookExecutor.ExecutePostHooks(ctx, id, plugin.StudioDestroyPost, studioIDs, nil)
}
return true, nil
}

View File

@@ -7,10 +7,22 @@ import (
"time"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/plugin"
"github.com/stashapp/stash/pkg/tag"
"github.com/stashapp/stash/pkg/utils"
)
func (r *mutationResolver) getTag(ctx context.Context, id int) (ret *models.Tag, err error) {
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
ret, err = repo.Tag().Find(id)
return err
}); err != nil {
return nil, err
}
return ret, nil
}
func (r *mutationResolver) TagCreate(ctx context.Context, input models.TagCreateInput) (*models.Tag, error) {
// Populate a new tag from the input
currentTime := time.Now()
@@ -68,7 +80,8 @@ func (r *mutationResolver) TagCreate(ctx context.Context, input models.TagCreate
return nil, err
}
return t, nil
r.hookExecutor.ExecutePostHooks(ctx, t.ID, plugin.TagCreatePost, input, nil)
return r.getTag(ctx, t.ID)
}
func (r *mutationResolver) TagUpdate(ctx context.Context, input models.TagUpdateInput) (*models.Tag, error) {
@@ -153,7 +166,8 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input models.TagUpdate
return nil, err
}
return t, nil
r.hookExecutor.ExecutePostHooks(ctx, t.ID, plugin.TagUpdatePost, input, translator.getFields())
return r.getTag(ctx, t.ID)
}
func (r *mutationResolver) TagDestroy(ctx context.Context, input models.TagDestroyInput) (bool, error) {
@@ -167,6 +181,9 @@ func (r *mutationResolver) TagDestroy(ctx context.Context, input models.TagDestr
}); err != nil {
return false, err
}
r.hookExecutor.ExecutePostHooks(ctx, tagID, plugin.TagDestroyPost, input, nil)
return true, nil
}
@@ -188,5 +205,10 @@ func (r *mutationResolver) TagsDestroy(ctx context.Context, tagIDs []string) (bo
}); err != nil {
return false, err
}
for _, id := range ids {
r.hookExecutor.ExecutePostHooks(ctx, id, plugin.TagDestroyPost, tagIDs, nil)
}
return true, nil
}

View File

@@ -7,6 +7,7 @@ import (
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/mocks"
"github.com/stashapp/stash/pkg/plugin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
@@ -15,7 +16,8 @@ import (
// TODO - move this into a common area
func newResolver() *Resolver {
return &Resolver{
txnManager: mocks.NewTransactionManager(),
txnManager: mocks.NewTransactionManager(),
hookExecutor: &mockHookExecutor{},
}
}
@@ -26,6 +28,11 @@ const existingTagID = 1
const existingTagName = "existingTagName"
const newTagID = 2
type mockHookExecutor struct{}
func (*mockHookExecutor) ExecutePostHooks(ctx context.Context, id int, hookType plugin.HookTriggerEnum, input interface{}, inputFields []string) {
}
func TestTagCreate(t *testing.T) {
r := newResolver()
@@ -84,10 +91,12 @@ func TestTagCreate(t *testing.T) {
tagRW.On("Query", tagFilterForName(tagName), findFilter).Return(nil, 0, nil).Once()
tagRW.On("Query", tagFilterForAlias(tagName), findFilter).Return(nil, 0, nil).Once()
tagRW.On("Create", mock.AnythingOfType("models.Tag")).Return(&models.Tag{
newTag := &models.Tag{
ID: newTagID,
Name: tagName,
}, nil)
}
tagRW.On("Create", mock.AnythingOfType("models.Tag")).Return(newTag, nil)
tagRW.On("Find", newTagID).Return(newTag, nil)
tag, err := r.Mutation().TagCreate(context.TODO(), models.TagCreateInput{
Name: tagName,

View File

@@ -29,6 +29,7 @@ import (
"github.com/stashapp/stash/pkg/manager/config"
"github.com/stashapp/stash/pkg/manager/paths"
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/session"
"github.com/stashapp/stash/pkg/utils"
)
@@ -41,11 +42,6 @@ var uiBox *packr.Box
//var legacyUiBox *packr.Box
var loginUIBox *packr.Box
const (
ApiKeyHeader = "ApiKey"
ApiKeyParameter = "apikey"
)
func allowUnauthenticated(r *http.Request) bool {
return strings.HasPrefix(r.URL.Path, "/login") || r.URL.Path == "/css"
}
@@ -53,44 +49,26 @@ func allowUnauthenticated(r *http.Request) bool {
func authenticateHandler() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c := config.GetInstance()
ctx := r.Context()
// translate api key into current user, if present
userID := ""
apiKey := r.Header.Get(ApiKeyHeader)
var err error
// try getting the api key as a query parameter
if apiKey == "" {
apiKey = r.URL.Query().Get(ApiKeyParameter)
}
if apiKey != "" {
// match against configured API and set userID to the
// configured username. In future, we'll want to
// get the username from the key.
if c.GetAPIKey() != apiKey {
w.Header().Add("WWW-Authenticate", `FormBased`)
w.WriteHeader(http.StatusUnauthorized)
userID, err := manager.GetInstance().SessionStore.Authenticate(w, r)
if err != nil {
if err != session.ErrUnauthorized {
w.WriteHeader(http.StatusInternalServerError)
_, err = w.Write([]byte(err.Error()))
if err != nil {
logger.Error(err)
}
return
}
userID = c.GetUsername()
} else {
// handle session
userID, err = getSessionUserID(w, r)
}
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
_, err = w.Write([]byte(err.Error()))
if err != nil {
logger.Error(err)
}
// unauthorized error
w.Header().Add("WWW-Authenticate", `FormBased`)
w.WriteHeader(http.StatusUnauthorized)
return
}
c := config.GetInstance()
ctx := r.Context()
// handle redirect if no user and user is required
if userID == "" && c.HasCredentials() && !allowUnauthenticated(r) {
// if we don't have a userID, then redirect
@@ -112,7 +90,7 @@ func authenticateHandler() func(http.Handler) http.Handler {
return
}
ctx = context.WithValue(ctx, ContextUser, userID)
ctx = session.SetCurrentUserID(ctx, userID)
r = r.WithContext(ctx)
@@ -121,6 +99,16 @@ func authenticateHandler() func(http.Handler) http.Handler {
}
}
func visitedPluginHandler() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// get the visited plugins and set them in the context
next.ServeHTTP(w, r)
})
}
}
const loginEndPoint = "/login"
func Start() {
@@ -128,13 +116,15 @@ func Start() {
//legacyUiBox = packr.New("UI Box", "../../ui/v1/dist/stash-frontend")
loginUIBox = packr.New("Login UI Box", "../../ui/login")
initSessionStore()
initialiseImages()
r := chi.NewRouter()
r.Use(middleware.Heartbeat("/healthz"))
r.Use(authenticateHandler())
visitedPluginHandler := manager.GetInstance().SessionStore.VisitedPluginHandler()
r.Use(visitedPluginHandler)
r.Use(middleware.Recoverer)
c := config.GetInstance()
@@ -155,8 +145,10 @@ func Start() {
}
txnManager := manager.GetInstance().TxnManager
pluginCache := manager.GetInstance().PluginCache
resolver := &Resolver{
txnManager: txnManager,
txnManager: txnManager,
hookExecutor: pluginCache,
}
gqlSrv := gqlHandler.New(models.NewExecutableSchema(models.Config{Resolvers: resolver}))
@@ -184,7 +176,8 @@ func Start() {
}
// register GQL handler with plugin cache
manager.GetInstance().PluginCache.RegisterGQLHandler(gqlHandlerFunc)
// chain the visited plugin handler
manager.GetInstance().PluginCache.RegisterGQLHandler(visitedPluginHandler(http.HandlerFunc(gqlHandlerFunc)))
r.HandleFunc("/graphql", gqlHandlerFunc)
r.HandleFunc("/playground", gqlPlayground.Handler("GraphQL playground", "/graphql"))
@@ -358,15 +351,6 @@ func makeTLSConfig() *tls.Config {
return tlsConfig
}
func HasTLSConfig() bool {
ret, _ := utils.FileExists(paths.GetSSLCert())
if ret {
ret, _ = utils.FileExists(paths.GetSSLKey())
}
return ret
}
type contextKey struct {
name string
}

View File

@@ -1,15 +1,13 @@
package api
import (
"context"
"fmt"
"html/template"
"net/http"
"github.com/stashapp/stash/pkg/manager"
"github.com/stashapp/stash/pkg/manager/config"
"github.com/gorilla/securecookie"
"github.com/gorilla/sessions"
"github.com/stashapp/stash/pkg/session"
)
const cookieName = "session"
@@ -19,17 +17,11 @@ const userIDKey = "userID"
const returnURLParam = "returnURL"
var sessionStore = sessions.NewCookieStore(config.GetInstance().GetSessionStoreKey())
type loginTemplateData struct {
URL string
Error string
}
func initSessionStore() {
sessionStore.MaxAge(config.GetInstance().GetMaxSessionAge())
}
func redirectToLogin(w http.ResponseWriter, returnURL string, loginError string) {
data, _ := loginUIBox.Find("login.html")
templ, err := template.New("Login").Parse(string(data))
@@ -59,22 +51,13 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
url = "/"
}
// ignore error - we want a new session regardless
newSession, _ := sessionStore.Get(r, cookieName)
username := r.FormValue("username")
password := r.FormValue("password")
// authenticate the user
if !config.GetInstance().ValidateCredentials(username, password) {
err := manager.GetInstance().SessionStore.Login(w, r)
if err == session.ErrInvalidCredentials {
// redirect back to the login page with an error
redirectToLogin(w, url, "Username or password is invalid")
return
}
newSession.Values[userIDKey] = username
err := newSession.Save(r, w)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
@@ -84,17 +67,7 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
}
func handleLogout(w http.ResponseWriter, r *http.Request) {
session, err := sessionStore.Get(r, cookieName)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
delete(session.Values, userIDKey)
session.Options.MaxAge = -1
err = session.Save(r, w)
if err != nil {
if err := manager.GetInstance().SessionStore.Logout(w, r); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -102,51 +75,3 @@ func handleLogout(w http.ResponseWriter, r *http.Request) {
// redirect to the login page if credentials are required
getLoginHandler(w, r)
}
func getSessionUserID(w http.ResponseWriter, r *http.Request) (string, error) {
session, err := sessionStore.Get(r, cookieName)
// ignore errors and treat as an empty user id, so that we handle expired
// cookie
if err != nil {
return "", nil
}
if !session.IsNew {
val := session.Values[userIDKey]
// refresh the cookie
err = session.Save(r, w)
if err != nil {
return "", err
}
ret, _ := val.(string)
return ret, nil
}
return "", nil
}
func getCurrentUserID(ctx context.Context) *string {
userCtxVal := ctx.Value(ContextUser)
if userCtxVal != nil {
currentUser := userCtxVal.(string)
return &currentUser
}
return nil
}
func createSessionCookie(username string) (*http.Cookie, error) {
session := sessions.NewSession(sessionStore, cookieName)
session.Values[userIDKey] = username
encoded, err := securecookie.EncodeMulti(session.Name(), session.Values,
sessionStore.Codecs...)
if err != nil {
return nil, err
}
return sessions.NewCookie(session.Name(), encoded, session.Options), nil
}