Cache package list in memory instead of filesystem (#4309)

This commit is contained in:
WithoutPants
2023-11-23 14:16:13 +11:00
committed by GitHub
parent 0dcd58763f
commit a8140c11ec
4 changed files with 84 additions and 85 deletions

View File

@@ -288,7 +288,7 @@ func initialize() error {
return nil return nil
} }
func initialisePackageManager(localPath string, srcPathGetter pkg.SourcePathGetter, cachePathGetter pkg.CachePathGetter) *pkg.Manager { func initialisePackageManager(localPath string, srcPathGetter pkg.SourcePathGetter) *pkg.Manager {
const timeout = 10 * time.Second const timeout = 10 * time.Second
httpClient := &http.Client{ httpClient := &http.Client{
Transport: &http.Transport{ Transport: &http.Transport{
@@ -304,7 +304,6 @@ func initialisePackageManager(localPath string, srcPathGetter pkg.SourcePathGett
}, },
PackagePathGetter: srcPathGetter, PackagePathGetter: srcPathGetter,
Client: httpClient, Client: httpClient,
CachePathGetter: cachePathGetter,
} }
} }
@@ -595,11 +594,11 @@ func (s *Manager) RefreshStreamManager() {
} }
func (s *Manager) RefreshScraperSourceManager() { func (s *Manager) RefreshScraperSourceManager() {
s.ScraperPackageManager = initialisePackageManager(s.Config.GetScrapersPath(), s.Config.GetScraperPackagePathGetter(), s.Config) s.ScraperPackageManager = initialisePackageManager(s.Config.GetScrapersPath(), s.Config.GetScraperPackagePathGetter())
} }
func (s *Manager) RefreshPluginSourceManager() { func (s *Manager) RefreshPluginSourceManager() {
s.PluginPackageManager = initialisePackageManager(s.Config.GetPluginsPath(), s.Config.GetPluginPackagePathGetter(), s.Config) s.PluginPackageManager = initialisePackageManager(s.Config.GetPluginsPath(), s.Config.GetPluginPackagePathGetter())
} }
func setSetupDefaults(input *SetupInput) { func setSetupDefaults(input *SetupInput) {

View File

@@ -1,27 +1,23 @@
package pkg package pkg
import ( import (
"fmt"
"io"
"os"
"path/filepath"
"time" "time"
"github.com/stashapp/stash/pkg/hash/md5"
"github.com/stashapp/stash/pkg/logger"
) )
const cacheSubDir = "package_lists" type cacheEntry struct {
lastModified time.Time
type repositoryCache struct { data []RemotePackage
cachePath string
} }
func (c *repositoryCache) path(url string) string { type repositoryCache struct {
// convert the url to md5 // cache maps the URL to the last modified time and the data
hash := md5.FromString(url) cache map[string]cacheEntry
}
return filepath.Join(c.cachePath, cacheSubDir, hash) func (c *repositoryCache) ensureCache() {
if c.cache == nil {
c.cache = make(map[string]cacheEntry)
}
} }
func (c *repositoryCache) lastModified(url string) *time.Time { func (c *repositoryCache) lastModified(url string) *time.Time {
@@ -29,54 +25,35 @@ func (c *repositoryCache) lastModified(url string) *time.Time {
return nil return nil
} }
path := c.path(url) c.ensureCache()
s, err := os.Stat(path) e, found := c.cache[url]
if err != nil {
// ignore if !found {
logger.Debugf("error getting cached file %s: %v", path, err)
return nil return nil
} }
ret := s.ModTime() return &e.lastModified
return &ret
} }
func (c *repositoryCache) getPackageList(url string) (io.ReadCloser, error) { func (c *repositoryCache) getPackageList(url string) []RemotePackage {
path := c.path(url) c.ensureCache()
ret, err := os.Open(path) e, found := c.cache[url]
if err != nil {
return nil, fmt.Errorf("failed to get file %q: %w", path, err) if !found {
return nil
} }
return ret, nil return e.data
} }
func (c *repositoryCache) cacheFile(url string, data io.ReadCloser) (io.ReadCloser, error) { func (c *repositoryCache) cacheList(url string, lastModified time.Time, data []RemotePackage) {
if c == nil { if c == nil {
return data, nil return
} }
path := c.path(url) c.ensureCache()
c.cache[url] = cacheEntry{
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { lastModified: lastModified,
// ignore, just return the original file data: data,
logger.Debugf("error creating cache path %s: %v", filepath.Dir(path), err)
return data, nil
} }
f, err := os.Create(path)
if err != nil {
// ignore, just return the original file
logger.Debugf("error creating cached file %s: %v", path, err)
return data, nil
}
defer data.Close()
if _, err := io.Copy(f, data); err != nil {
_ = f.Close()
return nil, fmt.Errorf("writing to cache file %s - %w", path, err)
}
_ = f.Close()
return c.getPackageList(url)
} }

View File

@@ -14,10 +14,6 @@ import (
"github.com/stashapp/stash/pkg/models" "github.com/stashapp/stash/pkg/models"
) )
type CachePathGetter interface {
GetCachePath() string
}
// SourcePathGetter gets the source path for a given package URL. // SourcePathGetter gets the source path for a given package URL.
type SourcePathGetter interface { type SourcePathGetter interface {
// GetAllSourcePaths gets all source paths. // GetAllSourcePaths gets all source paths.
@@ -31,9 +27,18 @@ type SourcePathGetter interface {
type Manager struct { type Manager struct {
Local *Store Local *Store
PackagePathGetter SourcePathGetter PackagePathGetter SourcePathGetter
CachePathGetter CachePathGetter
Client *http.Client Client *http.Client
cache *repositoryCache
}
func (m *Manager) getCache() *repositoryCache {
if m.cache == nil {
m.cache = &repositoryCache{}
}
return m.cache
} }
func (m *Manager) remoteFromURL(path string) (*httpRepository, error) { func (m *Manager) remoteFromURL(path string) (*httpRepository, error) {
@@ -42,13 +47,7 @@ func (m *Manager) remoteFromURL(path string) (*httpRepository, error) {
return nil, fmt.Errorf("parsing path: %w", err) return nil, fmt.Errorf("parsing path: %w", err)
} }
cachePath := m.CachePathGetter.GetCachePath() return newHttpRepository(*u, m.Client, m.getCache()), nil
var cache *repositoryCache
if cachePath != "" {
cache = &repositoryCache{cachePath: cachePath}
}
return newHttpRepository(*u, m.Client, cache), nil
} }
func (m *Manager) ListInstalled(ctx context.Context) (LocalPackageIndex, error) { func (m *Manager) ListInstalled(ctx context.Context) (LocalPackageIndex, error) {

View File

@@ -16,9 +16,6 @@ import (
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )
// DefaultCacheTTL is the default time to live for the index cache.
const DefaultCacheTTL = 5 * time.Minute
// httpRepository is a HTTP based repository. // httpRepository is a HTTP based repository.
// It is configured with a package list URL. Packages are located from the Path field of the package. // It is configured with a package list URL. Packages are located from the Path field of the package.
// //
@@ -51,13 +48,28 @@ func (r *httpRepository) List(ctx context.Context) ([]RemotePackage, error) {
// the package list URL may be file://, in which case we need to use the local file system // the package list URL may be file://, in which case we need to use the local file system
var ( var (
f io.ReadCloser f io.ReadCloser
err error modTime *time.Time
err error
) )
if u.Scheme == "file" {
isLocal := u.Scheme == "file"
if isLocal {
f, err = r.getLocalFile(ctx, u.Path) f, err = r.getLocalFile(ctx, u.Path)
} else { } else {
f, err = r.getFileCached(ctx, u) // try to get the cached list first
var cachedList []RemotePackage
cachedList, err = r.getCachedList(ctx, u)
if err != nil {
return nil, fmt.Errorf("failed to get cached package list: %w", err)
}
if cachedList != nil {
return cachedList, nil
}
f, modTime, err = r.getFile(ctx, u)
} }
if err != nil { if err != nil {
@@ -76,6 +88,11 @@ func (r *httpRepository) List(ctx context.Context) ([]RemotePackage, error) {
return nil, fmt.Errorf("reading package list: %w", err) return nil, fmt.Errorf("reading package list: %w", err)
} }
// cache if not local file
if !isLocal {
r.cache.cacheList(u.String(), *modTime, index)
}
return index, nil return index, nil
} }
@@ -124,7 +141,7 @@ func (r *httpRepository) GetPackageZip(ctx context.Context, pkg RemotePackage) (
f, err = r.getLocalFile(ctx, u.Path) f, err = r.getLocalFile(ctx, u.Path)
} else { } else {
f, err = r.getFile(ctx, u) f, _, err = r.getFile(ctx, u)
} }
if err != nil { if err != nil {
@@ -135,8 +152,8 @@ func (r *httpRepository) GetPackageZip(ctx context.Context, pkg RemotePackage) (
} }
// getFileCached tries to get the list from the local cache. // getCachedList tries to get the list from the local cache.
// If it is not found or is stale, then it gets it normally. // If it is not found or is stale, then nil is returned.
func (r *httpRepository) getFileCached(ctx context.Context, u url.URL) (io.ReadCloser, error) { func (r *httpRepository) getCachedList(ctx context.Context, u url.URL) ([]RemotePackage, error) {
// check if the file is in the cache first // check if the file is in the cache first
localModTime := r.cache.lastModified(u.String()) localModTime := r.cache.lastModified(u.String())
@@ -163,34 +180,41 @@ func (r *httpRepository) getFileCached(ctx context.Context, u url.URL) (io.ReadC
if !remoteModTime.After(*localModTime) { if !remoteModTime.After(*localModTime) {
logger.Debugf("cached version of %s is equal or newer than remote", u.String()) logger.Debugf("cached version of %s is equal or newer than remote", u.String())
return r.cache.getPackageList(u.String()) return r.cache.getPackageList(u.String()), nil
} }
} }
logger.Debugf("cached version of %s is older than remote", u.String()) logger.Debugf("cached version of %s is older than remote", u.String())
} }
return r.getFile(ctx, u) return nil, nil
} }
func (r *httpRepository) getFile(ctx context.Context, u url.URL) (io.ReadCloser, error) { // getFile gets the file from the remote server. Returns the file and the last modified time.
func (r *httpRepository) getFile(ctx context.Context, u url.URL) (io.ReadCloser, *time.Time, error) {
logger.Debugf("fetching %s", u.String()) logger.Debugf("fetching %s", u.String())
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
if err != nil { if err != nil {
// shouldn't happen // shouldn't happen
return nil, err return nil, nil, err
} }
resp, err := r.client.Do(req) resp, err := r.client.Do(req)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get remote file: %w", err) return nil, nil, fmt.Errorf("failed to get remote file: %w", err)
} }
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
return nil, fmt.Errorf("failed to get remote file: %s", resp.Status) return nil, nil, fmt.Errorf("failed to get remote file: %s", resp.Status)
} }
return r.cache.cacheFile(u.String(), resp.Body) lastModified := resp.Header.Get("Last-Modified")
var remoteModTime time.Time
if lastModified != "" {
remoteModTime, _ = time.Parse(http.TimeFormat, lastModified)
}
return resp.Body, &remoteModTime, nil
} }
func (r *httpRepository) getLocalFile(ctx context.Context, path string) (fs.File, error) { func (r *httpRepository) getLocalFile(ctx context.Context, path string) (fs.File, error) {