mirror of
https://github.com/stashapp/stash.git
synced 2025-12-17 20:34:37 +03:00
Plugin hooks (#1452)
* Refactor session and plugin code * Add context to job tasks * Show hooks in plugins page * Refactor session management
This commit is contained in:
@@ -65,6 +65,15 @@ func (t changesetTranslator) hasField(field string) bool {
|
||||
return found
|
||||
}
|
||||
|
||||
func (t changesetTranslator) getFields() []string {
|
||||
var ret []string
|
||||
for k := range t.inputMap {
|
||||
ret = append(ret, k)
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func (t changesetTranslator) nullString(value *string, field string) *sql.NullString {
|
||||
if !t.hasField(field) {
|
||||
return nil
|
||||
|
||||
@@ -5,13 +5,12 @@ package api
|
||||
type key int
|
||||
|
||||
const (
|
||||
galleryKey key = 0
|
||||
performerKey key = 1
|
||||
sceneKey key = 2
|
||||
studioKey key = 3
|
||||
movieKey key = 4
|
||||
ContextUser key = 5
|
||||
tagKey key = 6
|
||||
downloadKey key = 7
|
||||
imageKey key = 8
|
||||
galleryKey key = iota
|
||||
performerKey
|
||||
sceneKey
|
||||
studioKey
|
||||
movieKey
|
||||
tagKey
|
||||
downloadKey
|
||||
imageKey
|
||||
)
|
||||
|
||||
@@ -7,10 +7,16 @@ import (
|
||||
|
||||
"github.com/stashapp/stash/pkg/logger"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/plugin"
|
||||
)
|
||||
|
||||
type hookExecutor interface {
|
||||
ExecutePostHooks(ctx context.Context, id int, hookType plugin.HookTriggerEnum, input interface{}, inputFields []string)
|
||||
}
|
||||
|
||||
type Resolver struct {
|
||||
txnManager models.TransactionManager
|
||||
txnManager models.TransactionManager
|
||||
hookExecutor hookExecutor
|
||||
}
|
||||
|
||||
func (r *Resolver) Gallery() models.GalleryResolver {
|
||||
|
||||
@@ -10,9 +10,21 @@ import (
|
||||
|
||||
"github.com/stashapp/stash/pkg/manager"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/plugin"
|
||||
"github.com/stashapp/stash/pkg/utils"
|
||||
)
|
||||
|
||||
func (r *mutationResolver) getGallery(ctx context.Context, id int) (ret *models.Gallery, err error) {
|
||||
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
|
||||
ret, err = repo.Gallery().Find(id)
|
||||
return err
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (r *mutationResolver) GalleryCreate(ctx context.Context, input models.GalleryCreateInput) (*models.Gallery, error) {
|
||||
// name must be provided
|
||||
if input.Title == "" {
|
||||
@@ -90,7 +102,8 @@ func (r *mutationResolver) GalleryCreate(ctx context.Context, input models.Galle
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return gallery, nil
|
||||
r.hookExecutor.ExecutePostHooks(ctx, gallery.ID, plugin.GalleryCreatePost, input, nil)
|
||||
return r.getGallery(ctx, gallery.ID)
|
||||
}
|
||||
|
||||
func (r *mutationResolver) updateGalleryPerformers(qb models.GalleryReaderWriter, galleryID int, performerIDs []string) error {
|
||||
@@ -130,7 +143,9 @@ func (r *mutationResolver) GalleryUpdate(ctx context.Context, input models.Galle
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
// execute post hooks outside txn
|
||||
r.hookExecutor.ExecutePostHooks(ctx, ret.ID, plugin.GalleryUpdatePost, input, translator.getFields())
|
||||
return r.getGallery(ctx, ret.ID)
|
||||
}
|
||||
|
||||
func (r *mutationResolver) GalleriesUpdate(ctx context.Context, input []*models.GalleryUpdateInput) (ret []*models.Gallery, err error) {
|
||||
@@ -156,7 +171,23 @@ func (r *mutationResolver) GalleriesUpdate(ctx context.Context, input []*models.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
// execute post hooks outside txn
|
||||
var newRet []*models.Gallery
|
||||
for i, gallery := range ret {
|
||||
translator := changesetTranslator{
|
||||
inputMap: inputMaps[i],
|
||||
}
|
||||
|
||||
r.hookExecutor.ExecutePostHooks(ctx, gallery.ID, plugin.GalleryUpdatePost, input, translator.getFields())
|
||||
gallery, err = r.getGallery(ctx, gallery.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newRet = append(newRet, gallery)
|
||||
}
|
||||
|
||||
return newRet, nil
|
||||
}
|
||||
|
||||
func (r *mutationResolver) galleryUpdate(input models.GalleryUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Gallery, error) {
|
||||
@@ -314,7 +345,20 @@ func (r *mutationResolver) BulkGalleryUpdate(ctx context.Context, input models.B
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
// execute post hooks outside of txn
|
||||
var newRet []*models.Gallery
|
||||
for _, gallery := range ret {
|
||||
r.hookExecutor.ExecutePostHooks(ctx, gallery.ID, plugin.GalleryUpdatePost, input, translator.getFields())
|
||||
|
||||
gallery, err := r.getGallery(ctx, gallery.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newRet = append(newRet, gallery)
|
||||
}
|
||||
|
||||
return newRet, nil
|
||||
}
|
||||
|
||||
func adjustGalleryPerformerIDs(qb models.GalleryReader, galleryID int, ids models.BulkUpdateIds) (ret []int, err error) {
|
||||
@@ -438,6 +482,16 @@ func (r *mutationResolver) GalleryDestroy(ctx context.Context, input models.Gall
|
||||
}
|
||||
}
|
||||
|
||||
// call post hook after performing the other actions
|
||||
for _, gallery := range galleries {
|
||||
r.hookExecutor.ExecutePostHooks(ctx, gallery.ID, plugin.GalleryDestroyPost, input, nil)
|
||||
}
|
||||
|
||||
// call image destroy post hook as well
|
||||
for _, img := range imgsToDelete {
|
||||
r.hookExecutor.ExecutePostHooks(ctx, img.ID, plugin.ImageDestroyPost, nil, nil)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -8,9 +8,21 @@ import (
|
||||
|
||||
"github.com/stashapp/stash/pkg/manager"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/plugin"
|
||||
"github.com/stashapp/stash/pkg/utils"
|
||||
)
|
||||
|
||||
func (r *mutationResolver) getImage(ctx context.Context, id int) (ret *models.Image, err error) {
|
||||
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
|
||||
ret, err = repo.Image().Find(id)
|
||||
return err
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (r *mutationResolver) ImageUpdate(ctx context.Context, input models.ImageUpdateInput) (ret *models.Image, err error) {
|
||||
translator := changesetTranslator{
|
||||
inputMap: getUpdateInputMap(ctx),
|
||||
@@ -24,7 +36,9 @@ func (r *mutationResolver) ImageUpdate(ctx context.Context, input models.ImageUp
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
// execute post hooks outside txn
|
||||
r.hookExecutor.ExecutePostHooks(ctx, ret.ID, plugin.ImageUpdatePost, input, translator.getFields())
|
||||
return r.getImage(ctx, ret.ID)
|
||||
}
|
||||
|
||||
func (r *mutationResolver) ImagesUpdate(ctx context.Context, input []*models.ImageUpdateInput) (ret []*models.Image, err error) {
|
||||
@@ -50,7 +64,23 @@ func (r *mutationResolver) ImagesUpdate(ctx context.Context, input []*models.Ima
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
// execute post hooks outside txn
|
||||
var newRet []*models.Image
|
||||
for i, image := range ret {
|
||||
translator := changesetTranslator{
|
||||
inputMap: inputMaps[i],
|
||||
}
|
||||
|
||||
r.hookExecutor.ExecutePostHooks(ctx, image.ID, plugin.ImageUpdatePost, input, translator.getFields())
|
||||
image, err = r.getImage(ctx, image.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newRet = append(newRet, image)
|
||||
}
|
||||
|
||||
return newRet, nil
|
||||
}
|
||||
|
||||
func (r *mutationResolver) imageUpdate(input models.ImageUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Image, error) {
|
||||
@@ -202,7 +232,20 @@ func (r *mutationResolver) BulkImageUpdate(ctx context.Context, input models.Bul
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
// execute post hooks outside of txn
|
||||
var newRet []*models.Image
|
||||
for _, image := range ret {
|
||||
r.hookExecutor.ExecutePostHooks(ctx, image.ID, plugin.ImageUpdatePost, input, translator.getFields())
|
||||
|
||||
image, err = r.getImage(ctx, image.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newRet = append(newRet, image)
|
||||
}
|
||||
|
||||
return newRet, nil
|
||||
}
|
||||
|
||||
func adjustImageGalleryIDs(qb models.ImageReader, imageID int, ids models.BulkUpdateIds) (ret []int, err error) {
|
||||
@@ -268,6 +311,9 @@ func (r *mutationResolver) ImageDestroy(ctx context.Context, input models.ImageD
|
||||
manager.DeleteImageFile(image)
|
||||
}
|
||||
|
||||
// call post hook after performing the other actions
|
||||
r.hookExecutor.ExecutePostHooks(ctx, image.ID, plugin.ImageDestroyPost, input, nil)
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
@@ -315,6 +361,9 @@ func (r *mutationResolver) ImagesDestroy(ctx context.Context, input models.Image
|
||||
if input.DeleteFile != nil && *input.DeleteFile {
|
||||
manager.DeleteImageFile(image)
|
||||
}
|
||||
|
||||
// call post hook after performing the other actions
|
||||
r.hookExecutor.ExecutePostHooks(ctx, image.ID, plugin.ImageDestroyPost, input, nil)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
)
|
||||
|
||||
func (r *mutationResolver) MetadataScan(ctx context.Context, input models.ScanMetadataInput) (string, error) {
|
||||
jobID, err := manager.GetInstance().Scan(input)
|
||||
jobID, err := manager.GetInstance().Scan(ctx, input)
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -27,7 +27,7 @@ func (r *mutationResolver) MetadataScan(ctx context.Context, input models.ScanMe
|
||||
}
|
||||
|
||||
func (r *mutationResolver) MetadataImport(ctx context.Context) (string, error) {
|
||||
jobID, err := manager.GetInstance().Import()
|
||||
jobID, err := manager.GetInstance().Import(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -41,13 +41,13 @@ func (r *mutationResolver) ImportObjects(ctx context.Context, input models.Impor
|
||||
return "", err
|
||||
}
|
||||
|
||||
jobID := manager.GetInstance().RunSingleTask(t)
|
||||
jobID := manager.GetInstance().RunSingleTask(ctx, t)
|
||||
|
||||
return strconv.Itoa(jobID), nil
|
||||
}
|
||||
|
||||
func (r *mutationResolver) MetadataExport(ctx context.Context) (string, error) {
|
||||
jobID, err := manager.GetInstance().Export()
|
||||
jobID, err := manager.GetInstance().Export(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -75,7 +75,7 @@ func (r *mutationResolver) ExportObjects(ctx context.Context, input models.Expor
|
||||
}
|
||||
|
||||
func (r *mutationResolver) MetadataGenerate(ctx context.Context, input models.GenerateMetadataInput) (string, error) {
|
||||
jobID, err := manager.GetInstance().Generate(input)
|
||||
jobID, err := manager.GetInstance().Generate(ctx, input)
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -85,17 +85,17 @@ func (r *mutationResolver) MetadataGenerate(ctx context.Context, input models.Ge
|
||||
}
|
||||
|
||||
func (r *mutationResolver) MetadataAutoTag(ctx context.Context, input models.AutoTagMetadataInput) (string, error) {
|
||||
jobID := manager.GetInstance().AutoTag(input)
|
||||
jobID := manager.GetInstance().AutoTag(ctx, input)
|
||||
return strconv.Itoa(jobID), nil
|
||||
}
|
||||
|
||||
func (r *mutationResolver) MetadataClean(ctx context.Context, input models.CleanMetadataInput) (string, error) {
|
||||
jobID := manager.GetInstance().Clean(input)
|
||||
jobID := manager.GetInstance().Clean(ctx, input)
|
||||
return strconv.Itoa(jobID), nil
|
||||
}
|
||||
|
||||
func (r *mutationResolver) MigrateHashNaming(ctx context.Context) (string, error) {
|
||||
jobID := manager.GetInstance().MigrateHash()
|
||||
jobID := manager.GetInstance().MigrateHash(ctx)
|
||||
return strconv.Itoa(jobID), nil
|
||||
}
|
||||
|
||||
|
||||
@@ -7,9 +7,21 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/plugin"
|
||||
"github.com/stashapp/stash/pkg/utils"
|
||||
)
|
||||
|
||||
func (r *mutationResolver) getMovie(ctx context.Context, id int) (ret *models.Movie, err error) {
|
||||
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
|
||||
ret, err = repo.Movie().Find(id)
|
||||
return err
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (r *mutationResolver) MovieCreate(ctx context.Context, input models.MovieCreateInput) (*models.Movie, error) {
|
||||
// generate checksum from movie name rather than image
|
||||
checksum := utils.MD5FromString(input.Name)
|
||||
@@ -104,7 +116,8 @@ func (r *mutationResolver) MovieCreate(ctx context.Context, input models.MovieCr
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return movie, nil
|
||||
r.hookExecutor.ExecutePostHooks(ctx, movie.ID, plugin.MovieCreatePost, input, nil)
|
||||
return r.getMovie(ctx, movie.ID)
|
||||
}
|
||||
|
||||
func (r *mutationResolver) MovieUpdate(ctx context.Context, input models.MovieUpdateInput) (*models.Movie, error) {
|
||||
@@ -203,7 +216,8 @@ func (r *mutationResolver) MovieUpdate(ctx context.Context, input models.MovieUp
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return movie, nil
|
||||
r.hookExecutor.ExecutePostHooks(ctx, movie.ID, plugin.MovieUpdatePost, input, translator.getFields())
|
||||
return r.getMovie(ctx, movie.ID)
|
||||
}
|
||||
|
||||
func (r *mutationResolver) MovieDestroy(ctx context.Context, input models.MovieDestroyInput) (bool, error) {
|
||||
@@ -217,6 +231,9 @@ func (r *mutationResolver) MovieDestroy(ctx context.Context, input models.MovieD
|
||||
}); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
r.hookExecutor.ExecutePostHooks(ctx, id, plugin.MovieDestroyPost, input, nil)
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
@@ -238,5 +255,10 @@ func (r *mutationResolver) MoviesDestroy(ctx context.Context, movieIDs []string)
|
||||
}); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, id := range ids {
|
||||
r.hookExecutor.ExecutePostHooks(ctx, id, plugin.MovieDestroyPost, movieIDs, nil)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
@@ -9,9 +9,21 @@ import (
|
||||
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/performer"
|
||||
"github.com/stashapp/stash/pkg/plugin"
|
||||
"github.com/stashapp/stash/pkg/utils"
|
||||
)
|
||||
|
||||
func (r *mutationResolver) getPerformer(ctx context.Context, id int) (ret *models.Performer, err error) {
|
||||
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
|
||||
ret, err = repo.Performer().Find(id)
|
||||
return err
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (r *mutationResolver) PerformerCreate(ctx context.Context, input models.PerformerCreateInput) (*models.Performer, error) {
|
||||
// generate checksum from performer name rather than image
|
||||
checksum := utils.MD5FromString(input.Name)
|
||||
@@ -146,7 +158,8 @@ func (r *mutationResolver) PerformerCreate(ctx context.Context, input models.Per
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return performer, nil
|
||||
r.hookExecutor.ExecutePostHooks(ctx, performer.ID, plugin.PerformerCreatePost, input, nil)
|
||||
return r.getPerformer(ctx, performer.ID)
|
||||
}
|
||||
|
||||
func (r *mutationResolver) PerformerUpdate(ctx context.Context, input models.PerformerUpdateInput) (*models.Performer, error) {
|
||||
@@ -267,7 +280,8 @@ func (r *mutationResolver) PerformerUpdate(ctx context.Context, input models.Per
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p, nil
|
||||
r.hookExecutor.ExecutePostHooks(ctx, p.ID, plugin.PerformerUpdatePost, input, translator.getFields())
|
||||
return r.getPerformer(ctx, p.ID)
|
||||
}
|
||||
|
||||
func (r *mutationResolver) updatePerformerTags(qb models.PerformerReaderWriter, performerID int, tagsIDs []string) error {
|
||||
@@ -372,7 +386,20 @@ func (r *mutationResolver) BulkPerformerUpdate(ctx context.Context, input models
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
// execute post hooks outside of txn
|
||||
var newRet []*models.Performer
|
||||
for _, performer := range ret {
|
||||
r.hookExecutor.ExecutePostHooks(ctx, performer.ID, plugin.ImageUpdatePost, input, translator.getFields())
|
||||
|
||||
performer, err = r.getPerformer(ctx, performer.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newRet = append(newRet, performer)
|
||||
}
|
||||
|
||||
return newRet, nil
|
||||
}
|
||||
|
||||
func (r *mutationResolver) PerformerDestroy(ctx context.Context, input models.PerformerDestroyInput) (bool, error) {
|
||||
@@ -386,6 +413,9 @@ func (r *mutationResolver) PerformerDestroy(ctx context.Context, input models.Pe
|
||||
}); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
r.hookExecutor.ExecutePostHooks(ctx, id, plugin.PerformerDestroyPost, input, nil)
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
@@ -407,5 +437,10 @@ func (r *mutationResolver) PerformersDestroy(ctx context.Context, performerIDs [
|
||||
}); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, id := range ids {
|
||||
r.hookExecutor.ExecutePostHooks(ctx, id, plugin.PerformerDestroyPost, performerIDs, nil)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
@@ -2,40 +2,15 @@ package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/stashapp/stash/pkg/logger"
|
||||
"github.com/stashapp/stash/pkg/manager"
|
||||
"github.com/stashapp/stash/pkg/manager/config"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/plugin/common"
|
||||
)
|
||||
|
||||
func (r *mutationResolver) RunPluginTask(ctx context.Context, pluginID string, taskName string, args []*models.PluginArgInput) (string, error) {
|
||||
currentUser := getCurrentUserID(ctx)
|
||||
|
||||
var cookie *http.Cookie
|
||||
var err error
|
||||
if currentUser != nil {
|
||||
cookie, err = createSessionCookie(*currentUser)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
config := config.GetInstance()
|
||||
serverConnection := common.StashServerConnection{
|
||||
Scheme: "http",
|
||||
Port: config.GetPort(),
|
||||
SessionCookie: cookie,
|
||||
Dir: config.GetConfigPath(),
|
||||
}
|
||||
|
||||
if HasTLSConfig() {
|
||||
serverConnection.Scheme = "https"
|
||||
}
|
||||
|
||||
manager.GetInstance().RunPluginTask(pluginID, taskName, args, serverConnection)
|
||||
m := manager.GetInstance()
|
||||
m.RunPluginTask(ctx, pluginID, taskName, args)
|
||||
return "todo", nil
|
||||
}
|
||||
|
||||
|
||||
@@ -10,9 +10,21 @@ import (
|
||||
"github.com/stashapp/stash/pkg/manager"
|
||||
"github.com/stashapp/stash/pkg/manager/config"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/plugin"
|
||||
"github.com/stashapp/stash/pkg/utils"
|
||||
)
|
||||
|
||||
func (r *mutationResolver) getScene(ctx context.Context, id int) (ret *models.Scene, err error) {
|
||||
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
|
||||
ret, err = repo.Scene().Find(id)
|
||||
return err
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (r *mutationResolver) SceneUpdate(ctx context.Context, input models.SceneUpdateInput) (ret *models.Scene, err error) {
|
||||
translator := changesetTranslator{
|
||||
inputMap: getUpdateInputMap(ctx),
|
||||
@@ -26,7 +38,8 @@ func (r *mutationResolver) SceneUpdate(ctx context.Context, input models.SceneUp
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
r.hookExecutor.ExecutePostHooks(ctx, ret.ID, plugin.SceneUpdatePost, input, translator.getFields())
|
||||
return r.getScene(ctx, ret.ID)
|
||||
}
|
||||
|
||||
func (r *mutationResolver) ScenesUpdate(ctx context.Context, input []*models.SceneUpdateInput) (ret []*models.Scene, err error) {
|
||||
@@ -52,7 +65,24 @@ func (r *mutationResolver) ScenesUpdate(ctx context.Context, input []*models.Sce
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
// execute post hooks outside of txn
|
||||
var newRet []*models.Scene
|
||||
for i, scene := range ret {
|
||||
translator := changesetTranslator{
|
||||
inputMap: inputMaps[i],
|
||||
}
|
||||
|
||||
r.hookExecutor.ExecutePostHooks(ctx, scene.ID, plugin.SceneUpdatePost, input, translator.getFields())
|
||||
|
||||
scene, err = r.getScene(ctx, scene.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newRet = append(newRet, scene)
|
||||
}
|
||||
|
||||
return newRet, nil
|
||||
}
|
||||
|
||||
func (r *mutationResolver) sceneUpdate(input models.SceneUpdateInput, translator changesetTranslator, repo models.Repository) (*models.Scene, error) {
|
||||
@@ -281,7 +311,20 @@ func (r *mutationResolver) BulkSceneUpdate(ctx context.Context, input models.Bul
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
// execute post hooks outside of txn
|
||||
var newRet []*models.Scene
|
||||
for _, scene := range ret {
|
||||
r.hookExecutor.ExecutePostHooks(ctx, scene.ID, plugin.SceneUpdatePost, input, translator.getFields())
|
||||
|
||||
scene, err = r.getScene(ctx, scene.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newRet = append(newRet, scene)
|
||||
}
|
||||
|
||||
return newRet, nil
|
||||
}
|
||||
|
||||
func adjustIDs(existingIDs []int, updateIDs models.BulkUpdateIds) []int {
|
||||
@@ -393,6 +436,9 @@ func (r *mutationResolver) SceneDestroy(ctx context.Context, input models.SceneD
|
||||
manager.DeleteSceneFile(scene)
|
||||
}
|
||||
|
||||
// call post hook after performing the other actions
|
||||
r.hookExecutor.ExecutePostHooks(ctx, scene.ID, plugin.SceneDestroyPost, input, nil)
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
@@ -442,11 +488,25 @@ func (r *mutationResolver) ScenesDestroy(ctx context.Context, input models.Scene
|
||||
if input.DeleteFile != nil && *input.DeleteFile {
|
||||
manager.DeleteSceneFile(scene)
|
||||
}
|
||||
|
||||
// call post hook after performing the other actions
|
||||
r.hookExecutor.ExecutePostHooks(ctx, scene.ID, plugin.SceneDestroyPost, input, nil)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (r *mutationResolver) getSceneMarker(ctx context.Context, id int) (ret *models.SceneMarker, err error) {
|
||||
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
|
||||
ret, err = repo.SceneMarker().Find(id)
|
||||
return err
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (r *mutationResolver) SceneMarkerCreate(ctx context.Context, input models.SceneMarkerCreateInput) (*models.SceneMarker, error) {
|
||||
primaryTagID, err := strconv.Atoi(input.PrimaryTagID)
|
||||
if err != nil {
|
||||
@@ -473,7 +533,13 @@ func (r *mutationResolver) SceneMarkerCreate(ctx context.Context, input models.S
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return r.changeMarker(ctx, create, newSceneMarker, tagIDs)
|
||||
ret, err := r.changeMarker(ctx, create, newSceneMarker, tagIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r.hookExecutor.ExecutePostHooks(ctx, ret.ID, plugin.SceneMarkerCreatePost, input, nil)
|
||||
return r.getSceneMarker(ctx, ret.ID)
|
||||
}
|
||||
|
||||
func (r *mutationResolver) SceneMarkerUpdate(ctx context.Context, input models.SceneMarkerUpdateInput) (*models.SceneMarker, error) {
|
||||
@@ -507,7 +573,16 @@ func (r *mutationResolver) SceneMarkerUpdate(ctx context.Context, input models.S
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return r.changeMarker(ctx, update, updatedSceneMarker, tagIDs)
|
||||
ret, err := r.changeMarker(ctx, update, updatedSceneMarker, tagIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
translator := changesetTranslator{
|
||||
inputMap: getUpdateInputMap(ctx),
|
||||
}
|
||||
r.hookExecutor.ExecutePostHooks(ctx, ret.ID, plugin.SceneMarkerUpdatePost, input, translator.getFields())
|
||||
return r.getSceneMarker(ctx, ret.ID)
|
||||
}
|
||||
|
||||
func (r *mutationResolver) SceneMarkerDestroy(ctx context.Context, id string) (bool, error) {
|
||||
@@ -544,6 +619,8 @@ func (r *mutationResolver) SceneMarkerDestroy(ctx context.Context, id string) (b
|
||||
|
||||
postCommitFunc()
|
||||
|
||||
r.hookExecutor.ExecutePostHooks(ctx, markerID, plugin.SceneMarkerDestroyPost, id, nil)
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
@@ -651,9 +728,9 @@ func (r *mutationResolver) SceneResetO(ctx context.Context, id string) (ret int,
|
||||
|
||||
func (r *mutationResolver) SceneGenerateScreenshot(ctx context.Context, id string, at *float64) (string, error) {
|
||||
if at != nil {
|
||||
manager.GetInstance().GenerateScreenshot(id, *at)
|
||||
manager.GetInstance().GenerateScreenshot(ctx, id, *at)
|
||||
} else {
|
||||
manager.GetInstance().GenerateDefaultScreenshot(id)
|
||||
manager.GetInstance().GenerateDefaultScreenshot(ctx, id)
|
||||
}
|
||||
|
||||
return "todo", nil
|
||||
|
||||
@@ -24,6 +24,6 @@ func (r *mutationResolver) SubmitStashBoxFingerprints(ctx context.Context, input
|
||||
}
|
||||
|
||||
func (r *mutationResolver) StashBoxBatchPerformerTag(ctx context.Context, input models.StashBoxBatchPerformerTagInput) (string, error) {
|
||||
jobID := manager.GetInstance().StashBoxBatchPerformerTag(input)
|
||||
jobID := manager.GetInstance().StashBoxBatchPerformerTag(ctx, input)
|
||||
return strconv.Itoa(jobID), nil
|
||||
}
|
||||
|
||||
@@ -8,9 +8,21 @@ import (
|
||||
|
||||
"github.com/stashapp/stash/pkg/manager"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/plugin"
|
||||
"github.com/stashapp/stash/pkg/utils"
|
||||
)
|
||||
|
||||
func (r *mutationResolver) getStudio(ctx context.Context, id int) (ret *models.Studio, err error) {
|
||||
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
|
||||
ret, err = repo.Studio().Find(id)
|
||||
return err
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (r *mutationResolver) StudioCreate(ctx context.Context, input models.StudioCreateInput) (*models.Studio, error) {
|
||||
// generate checksum from studio name rather than image
|
||||
checksum := utils.MD5FromString(input.Name)
|
||||
@@ -82,7 +94,8 @@ func (r *mutationResolver) StudioCreate(ctx context.Context, input models.Studio
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return studio, nil
|
||||
r.hookExecutor.ExecutePostHooks(ctx, studio.ID, plugin.StudioCreatePost, input, nil)
|
||||
return r.getStudio(ctx, studio.ID)
|
||||
}
|
||||
|
||||
func (r *mutationResolver) StudioUpdate(ctx context.Context, input models.StudioUpdateInput) (*models.Studio, error) {
|
||||
@@ -162,7 +175,8 @@ func (r *mutationResolver) StudioUpdate(ctx context.Context, input models.Studio
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return studio, nil
|
||||
r.hookExecutor.ExecutePostHooks(ctx, studio.ID, plugin.StudioUpdatePost, input, translator.getFields())
|
||||
return r.getStudio(ctx, studio.ID)
|
||||
}
|
||||
|
||||
func (r *mutationResolver) StudioDestroy(ctx context.Context, input models.StudioDestroyInput) (bool, error) {
|
||||
@@ -176,6 +190,9 @@ func (r *mutationResolver) StudioDestroy(ctx context.Context, input models.Studi
|
||||
}); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
r.hookExecutor.ExecutePostHooks(ctx, id, plugin.StudioDestroyPost, input, nil)
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
@@ -197,5 +214,10 @@ func (r *mutationResolver) StudiosDestroy(ctx context.Context, studioIDs []strin
|
||||
}); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, id := range ids {
|
||||
r.hookExecutor.ExecutePostHooks(ctx, id, plugin.StudioDestroyPost, studioIDs, nil)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
@@ -7,10 +7,22 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/plugin"
|
||||
"github.com/stashapp/stash/pkg/tag"
|
||||
"github.com/stashapp/stash/pkg/utils"
|
||||
)
|
||||
|
||||
func (r *mutationResolver) getTag(ctx context.Context, id int) (ret *models.Tag, err error) {
|
||||
if err := r.withReadTxn(ctx, func(repo models.ReaderRepository) error {
|
||||
ret, err = repo.Tag().Find(id)
|
||||
return err
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (r *mutationResolver) TagCreate(ctx context.Context, input models.TagCreateInput) (*models.Tag, error) {
|
||||
// Populate a new tag from the input
|
||||
currentTime := time.Now()
|
||||
@@ -68,7 +80,8 @@ func (r *mutationResolver) TagCreate(ctx context.Context, input models.TagCreate
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return t, nil
|
||||
r.hookExecutor.ExecutePostHooks(ctx, t.ID, plugin.TagCreatePost, input, nil)
|
||||
return r.getTag(ctx, t.ID)
|
||||
}
|
||||
|
||||
func (r *mutationResolver) TagUpdate(ctx context.Context, input models.TagUpdateInput) (*models.Tag, error) {
|
||||
@@ -153,7 +166,8 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input models.TagUpdate
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return t, nil
|
||||
r.hookExecutor.ExecutePostHooks(ctx, t.ID, plugin.TagUpdatePost, input, translator.getFields())
|
||||
return r.getTag(ctx, t.ID)
|
||||
}
|
||||
|
||||
func (r *mutationResolver) TagDestroy(ctx context.Context, input models.TagDestroyInput) (bool, error) {
|
||||
@@ -167,6 +181,9 @@ func (r *mutationResolver) TagDestroy(ctx context.Context, input models.TagDestr
|
||||
}); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
r.hookExecutor.ExecutePostHooks(ctx, tagID, plugin.TagDestroyPost, input, nil)
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
@@ -188,5 +205,10 @@ func (r *mutationResolver) TagsDestroy(ctx context.Context, tagIDs []string) (bo
|
||||
}); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, id := range ids {
|
||||
r.hookExecutor.ExecutePostHooks(ctx, id, plugin.TagDestroyPost, tagIDs, nil)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/models/mocks"
|
||||
"github.com/stashapp/stash/pkg/plugin"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
@@ -15,7 +16,8 @@ import (
|
||||
// TODO - move this into a common area
|
||||
func newResolver() *Resolver {
|
||||
return &Resolver{
|
||||
txnManager: mocks.NewTransactionManager(),
|
||||
txnManager: mocks.NewTransactionManager(),
|
||||
hookExecutor: &mockHookExecutor{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,6 +28,11 @@ const existingTagID = 1
|
||||
const existingTagName = "existingTagName"
|
||||
const newTagID = 2
|
||||
|
||||
type mockHookExecutor struct{}
|
||||
|
||||
func (*mockHookExecutor) ExecutePostHooks(ctx context.Context, id int, hookType plugin.HookTriggerEnum, input interface{}, inputFields []string) {
|
||||
}
|
||||
|
||||
func TestTagCreate(t *testing.T) {
|
||||
r := newResolver()
|
||||
|
||||
@@ -84,10 +91,12 @@ func TestTagCreate(t *testing.T) {
|
||||
|
||||
tagRW.On("Query", tagFilterForName(tagName), findFilter).Return(nil, 0, nil).Once()
|
||||
tagRW.On("Query", tagFilterForAlias(tagName), findFilter).Return(nil, 0, nil).Once()
|
||||
tagRW.On("Create", mock.AnythingOfType("models.Tag")).Return(&models.Tag{
|
||||
newTag := &models.Tag{
|
||||
ID: newTagID,
|
||||
Name: tagName,
|
||||
}, nil)
|
||||
}
|
||||
tagRW.On("Create", mock.AnythingOfType("models.Tag")).Return(newTag, nil)
|
||||
tagRW.On("Find", newTagID).Return(newTag, nil)
|
||||
|
||||
tag, err := r.Mutation().TagCreate(context.TODO(), models.TagCreateInput{
|
||||
Name: tagName,
|
||||
|
||||
@@ -29,6 +29,7 @@ import (
|
||||
"github.com/stashapp/stash/pkg/manager/config"
|
||||
"github.com/stashapp/stash/pkg/manager/paths"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/session"
|
||||
"github.com/stashapp/stash/pkg/utils"
|
||||
)
|
||||
|
||||
@@ -41,11 +42,6 @@ var uiBox *packr.Box
|
||||
//var legacyUiBox *packr.Box
|
||||
var loginUIBox *packr.Box
|
||||
|
||||
const (
|
||||
ApiKeyHeader = "ApiKey"
|
||||
ApiKeyParameter = "apikey"
|
||||
)
|
||||
|
||||
func allowUnauthenticated(r *http.Request) bool {
|
||||
return strings.HasPrefix(r.URL.Path, "/login") || r.URL.Path == "/css"
|
||||
}
|
||||
@@ -53,44 +49,26 @@ func allowUnauthenticated(r *http.Request) bool {
|
||||
func authenticateHandler() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
c := config.GetInstance()
|
||||
ctx := r.Context()
|
||||
|
||||
// translate api key into current user, if present
|
||||
userID := ""
|
||||
apiKey := r.Header.Get(ApiKeyHeader)
|
||||
var err error
|
||||
|
||||
// try getting the api key as a query parameter
|
||||
if apiKey == "" {
|
||||
apiKey = r.URL.Query().Get(ApiKeyParameter)
|
||||
}
|
||||
|
||||
if apiKey != "" {
|
||||
// match against configured API and set userID to the
|
||||
// configured username. In future, we'll want to
|
||||
// get the username from the key.
|
||||
if c.GetAPIKey() != apiKey {
|
||||
w.Header().Add("WWW-Authenticate", `FormBased`)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
userID, err := manager.GetInstance().SessionStore.Authenticate(w, r)
|
||||
if err != nil {
|
||||
if err != session.ErrUnauthorized {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, err = w.Write([]byte(err.Error()))
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
userID = c.GetUsername()
|
||||
} else {
|
||||
// handle session
|
||||
userID, err = getSessionUserID(w, r)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, err = w.Write([]byte(err.Error()))
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
// unauthorized error
|
||||
w.Header().Add("WWW-Authenticate", `FormBased`)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
c := config.GetInstance()
|
||||
ctx := r.Context()
|
||||
|
||||
// handle redirect if no user and user is required
|
||||
if userID == "" && c.HasCredentials() && !allowUnauthenticated(r) {
|
||||
// if we don't have a userID, then redirect
|
||||
@@ -112,7 +90,7 @@ func authenticateHandler() func(http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, ContextUser, userID)
|
||||
ctx = session.SetCurrentUserID(ctx, userID)
|
||||
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
@@ -121,6 +99,16 @@ func authenticateHandler() func(http.Handler) http.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
func visitedPluginHandler() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// get the visited plugins and set them in the context
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const loginEndPoint = "/login"
|
||||
|
||||
func Start() {
|
||||
@@ -128,13 +116,15 @@ func Start() {
|
||||
//legacyUiBox = packr.New("UI Box", "../../ui/v1/dist/stash-frontend")
|
||||
loginUIBox = packr.New("Login UI Box", "../../ui/login")
|
||||
|
||||
initSessionStore()
|
||||
initialiseImages()
|
||||
|
||||
r := chi.NewRouter()
|
||||
|
||||
r.Use(middleware.Heartbeat("/healthz"))
|
||||
r.Use(authenticateHandler())
|
||||
visitedPluginHandler := manager.GetInstance().SessionStore.VisitedPluginHandler()
|
||||
r.Use(visitedPluginHandler)
|
||||
|
||||
r.Use(middleware.Recoverer)
|
||||
|
||||
c := config.GetInstance()
|
||||
@@ -155,8 +145,10 @@ func Start() {
|
||||
}
|
||||
|
||||
txnManager := manager.GetInstance().TxnManager
|
||||
pluginCache := manager.GetInstance().PluginCache
|
||||
resolver := &Resolver{
|
||||
txnManager: txnManager,
|
||||
txnManager: txnManager,
|
||||
hookExecutor: pluginCache,
|
||||
}
|
||||
|
||||
gqlSrv := gqlHandler.New(models.NewExecutableSchema(models.Config{Resolvers: resolver}))
|
||||
@@ -184,7 +176,8 @@ func Start() {
|
||||
}
|
||||
|
||||
// register GQL handler with plugin cache
|
||||
manager.GetInstance().PluginCache.RegisterGQLHandler(gqlHandlerFunc)
|
||||
// chain the visited plugin handler
|
||||
manager.GetInstance().PluginCache.RegisterGQLHandler(visitedPluginHandler(http.HandlerFunc(gqlHandlerFunc)))
|
||||
|
||||
r.HandleFunc("/graphql", gqlHandlerFunc)
|
||||
r.HandleFunc("/playground", gqlPlayground.Handler("GraphQL playground", "/graphql"))
|
||||
@@ -358,15 +351,6 @@ func makeTLSConfig() *tls.Config {
|
||||
return tlsConfig
|
||||
}
|
||||
|
||||
func HasTLSConfig() bool {
|
||||
ret, _ := utils.FileExists(paths.GetSSLCert())
|
||||
if ret {
|
||||
ret, _ = utils.FileExists(paths.GetSSLKey())
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
type contextKey struct {
|
||||
name string
|
||||
}
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net/http"
|
||||
|
||||
"github.com/stashapp/stash/pkg/manager"
|
||||
"github.com/stashapp/stash/pkg/manager/config"
|
||||
|
||||
"github.com/gorilla/securecookie"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/stashapp/stash/pkg/session"
|
||||
)
|
||||
|
||||
const cookieName = "session"
|
||||
@@ -19,17 +17,11 @@ const userIDKey = "userID"
|
||||
|
||||
const returnURLParam = "returnURL"
|
||||
|
||||
var sessionStore = sessions.NewCookieStore(config.GetInstance().GetSessionStoreKey())
|
||||
|
||||
type loginTemplateData struct {
|
||||
URL string
|
||||
Error string
|
||||
}
|
||||
|
||||
func initSessionStore() {
|
||||
sessionStore.MaxAge(config.GetInstance().GetMaxSessionAge())
|
||||
}
|
||||
|
||||
func redirectToLogin(w http.ResponseWriter, returnURL string, loginError string) {
|
||||
data, _ := loginUIBox.Find("login.html")
|
||||
templ, err := template.New("Login").Parse(string(data))
|
||||
@@ -59,22 +51,13 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
url = "/"
|
||||
}
|
||||
|
||||
// ignore error - we want a new session regardless
|
||||
newSession, _ := sessionStore.Get(r, cookieName)
|
||||
|
||||
username := r.FormValue("username")
|
||||
password := r.FormValue("password")
|
||||
|
||||
// authenticate the user
|
||||
if !config.GetInstance().ValidateCredentials(username, password) {
|
||||
err := manager.GetInstance().SessionStore.Login(w, r)
|
||||
if err == session.ErrInvalidCredentials {
|
||||
// redirect back to the login page with an error
|
||||
redirectToLogin(w, url, "Username or password is invalid")
|
||||
return
|
||||
}
|
||||
|
||||
newSession.Values[userIDKey] = username
|
||||
|
||||
err := newSession.Save(r, w)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
@@ -84,17 +67,7 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func handleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
session, err := sessionStore.Get(r, cookieName)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
delete(session.Values, userIDKey)
|
||||
session.Options.MaxAge = -1
|
||||
|
||||
err = session.Save(r, w)
|
||||
if err != nil {
|
||||
if err := manager.GetInstance().SessionStore.Logout(w, r); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -102,51 +75,3 @@ func handleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
// redirect to the login page if credentials are required
|
||||
getLoginHandler(w, r)
|
||||
}
|
||||
|
||||
func getSessionUserID(w http.ResponseWriter, r *http.Request) (string, error) {
|
||||
session, err := sessionStore.Get(r, cookieName)
|
||||
// ignore errors and treat as an empty user id, so that we handle expired
|
||||
// cookie
|
||||
if err != nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
if !session.IsNew {
|
||||
val := session.Values[userIDKey]
|
||||
|
||||
// refresh the cookie
|
||||
err = session.Save(r, w)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ret, _ := val.(string)
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func getCurrentUserID(ctx context.Context) *string {
|
||||
userCtxVal := ctx.Value(ContextUser)
|
||||
if userCtxVal != nil {
|
||||
currentUser := userCtxVal.(string)
|
||||
return ¤tUser
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func createSessionCookie(username string) (*http.Cookie, error) {
|
||||
session := sessions.NewSession(sessionStore, cookieName)
|
||||
session.Values[userIDKey] = username
|
||||
|
||||
encoded, err := securecookie.EncodeMulti(session.Name(), session.Values,
|
||||
sessionStore.Codecs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return sessions.NewCookie(session.Name(), encoded, session.Options), nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user