mirror of
https://github.com/stashapp/stash.git
synced 2025-12-17 04:14:39 +03:00
Improve plugin hook cyclic detection (#4625)
* Move and rename HookTriggerEnum into separate package * Move visited plugin hook handler code * Allow up to ten plugin hook loops
This commit is contained in:
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/stashapp/stash/pkg/logger"
|
||||
"github.com/stashapp/stash/pkg/models"
|
||||
"github.com/stashapp/stash/pkg/plugin/common"
|
||||
"github.com/stashapp/stash/pkg/plugin/hook"
|
||||
"github.com/stashapp/stash/pkg/session"
|
||||
"github.com/stashapp/stash/pkg/sliceutil"
|
||||
"github.com/stashapp/stash/pkg/txn"
|
||||
@@ -356,7 +357,7 @@ func waitForTask(ctx context.Context, task Task) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c Cache) ExecutePostHooks(ctx context.Context, id int, hookType HookTriggerEnum, input interface{}, inputFields []string) {
|
||||
func (c Cache) ExecutePostHooks(ctx context.Context, id int, hookType hook.TriggerEnum, input interface{}, inputFields []string) {
|
||||
if err := c.executePostHooks(ctx, hookType, common.HookContext{
|
||||
ID: id,
|
||||
Type: hookType.String(),
|
||||
@@ -367,7 +368,7 @@ func (c Cache) ExecutePostHooks(ctx context.Context, id int, hookType HookTrigge
|
||||
}
|
||||
}
|
||||
|
||||
func (c Cache) RegisterPostHooks(ctx context.Context, id int, hookType HookTriggerEnum, input interface{}, inputFields []string) {
|
||||
func (c Cache) RegisterPostHooks(ctx context.Context, id int, hookType hook.TriggerEnum, input interface{}, inputFields []string) {
|
||||
txn.AddPostCommitHook(ctx, func(ctx context.Context) {
|
||||
c.ExecutePostHooks(ctx, id, hookType, input, inputFields)
|
||||
})
|
||||
@@ -379,23 +380,28 @@ func (c Cache) ExecuteSceneUpdatePostHooks(ctx context.Context, input models.Sce
|
||||
logger.Errorf("error converting id in SceneUpdatePostHooks: %v", err)
|
||||
return
|
||||
}
|
||||
c.ExecutePostHooks(ctx, id, SceneUpdatePost, input, inputFields)
|
||||
c.ExecutePostHooks(ctx, id, hook.SceneUpdatePost, input, inputFields)
|
||||
}
|
||||
|
||||
func (c Cache) executePostHooks(ctx context.Context, hookType HookTriggerEnum, hookContext common.HookContext) error {
|
||||
visitedPlugins := session.GetVisitedPlugins(ctx)
|
||||
// maxCyclicLoopDepth is the maximum number of identical plugin hook calls that
|
||||
// can be made before a cyclic loop is detected. It is set to an arbitrary value
|
||||
// that should not be hit under normal circumstances.
|
||||
const maxCyclicLoopDepth = 10
|
||||
|
||||
func (c Cache) executePostHooks(ctx context.Context, hookType hook.TriggerEnum, hookContext common.HookContext) error {
|
||||
visitedPluginHookCounts := getVisitedPluginHookCounts(ctx)
|
||||
|
||||
for _, p := range c.enabledPlugins() {
|
||||
hooks := p.getHooks(hookType)
|
||||
// don't revisit a plugin we've already visited
|
||||
// only log if there's hooks that we're skipping
|
||||
if len(hooks) > 0 && sliceutil.Contains(visitedPlugins, p.id) {
|
||||
logger.Debugf("plugin ID '%s' already triggered, not re-triggering", p.id)
|
||||
if len(hooks) > 0 && visitedPluginHookCounts.For(p.id, hookType) >= maxCyclicLoopDepth {
|
||||
logger.Debugf("cyclic loop detected: plugin ID '%s' hook %s, not re-triggering", p.id, hookType)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, h := range hooks {
|
||||
newCtx := session.AddVisitedPlugin(ctx, p.id)
|
||||
newCtx := session.AddVisitedPluginHook(ctx, p.id, hookType)
|
||||
serverConnection := c.makeServerConnection(newCtx)
|
||||
|
||||
pluginInput := buildPluginInput(&p, &h.OperationConfig, serverConnection, nil)
|
||||
@@ -434,6 +440,46 @@ func (c Cache) executePostHooks(ctx context.Context, hookType HookTriggerEnum, h
|
||||
return nil
|
||||
}
|
||||
|
||||
type visitedPluginHookCount struct {
|
||||
session.VisitedPluginHook
|
||||
Count int
|
||||
}
|
||||
|
||||
type visitedPluginHookCounts []visitedPluginHookCount
|
||||
|
||||
func (v visitedPluginHookCounts) For(pluginID string, hookType hook.TriggerEnum) int {
|
||||
for _, c := range v {
|
||||
if c.VisitedPluginHook.PluginID == pluginID && c.VisitedPluginHook.HookType == hookType {
|
||||
return c.Count
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func getVisitedPluginHookCounts(ctx context.Context) visitedPluginHookCounts {
|
||||
visitedPluginHooks := session.GetVisitedPluginHooks(ctx)
|
||||
|
||||
visitedPluginHookCounts := make([]visitedPluginHookCount, 0)
|
||||
for _, p := range visitedPluginHooks {
|
||||
found := false
|
||||
for i, v := range visitedPluginHookCounts {
|
||||
if v.VisitedPluginHook == p {
|
||||
visitedPluginHookCounts[i].Count++
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
visitedPluginHookCounts = append(visitedPluginHookCounts, visitedPluginHookCount{
|
||||
VisitedPluginHook: p,
|
||||
Count: 1,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return visitedPluginHookCounts
|
||||
}
|
||||
|
||||
func (c Cache) getPlugin(pluginID string) *Config {
|
||||
for _, s := range c.plugins {
|
||||
if s.id == pluginID {
|
||||
|
||||
Reference in New Issue
Block a user