diff --git a/pkg/api/resolver_mutation_configure.go b/pkg/api/resolver_mutation_configure.go index ee038ae77..8273cc7e9 100644 --- a/pkg/api/resolver_mutation_configure.go +++ b/pkg/api/resolver_mutation_configure.go @@ -13,6 +13,8 @@ import ( "github.com/stashapp/stash/pkg/utils" ) +var ErrOverriddenConfig = errors.New("cannot set overridden value") + func (r *mutationResolver) Setup(ctx context.Context, input models.SetupInput) (bool, error) { err := manager.GetInstance().Setup(ctx, input) return err == nil, err @@ -25,6 +27,7 @@ func (r *mutationResolver) Migrate(ctx context.Context, input models.MigrateInpu func (r *mutationResolver) ConfigureGeneral(ctx context.Context, input models.ConfigGeneralInput) (*models.ConfigGeneralResult, error) { c := config.GetInstance() + existingPaths := c.GetStashPaths() if len(input.Stashes) > 0 { for _, s := range input.Stashes { @@ -46,7 +49,20 @@ func (r *mutationResolver) ConfigureGeneral(ctx context.Context, input models.Co c.Set(config.Stash, input.Stashes) } - if input.DatabasePath != nil { + checkConfigOverride := func(key string) error { + if c.HasOverride(key) { + return fmt.Errorf("%w: %s", ErrOverriddenConfig, key) + } + + return nil + } + + existingDBPath := c.GetDatabasePath() + if input.DatabasePath != nil && existingDBPath != *input.DatabasePath { + if err := checkConfigOverride(config.Database); err != nil { + return makeConfigGeneralResult(), err + } + ext := filepath.Ext(*input.DatabasePath) if ext != ".db" && ext != ".sqlite" && ext != ".sqlite3" { return makeConfigGeneralResult(), fmt.Errorf("invalid database path, use extension db, sqlite, or sqlite3") @@ -54,14 +70,24 @@ func (r *mutationResolver) ConfigureGeneral(ctx context.Context, input models.Co c.Set(config.Database, input.DatabasePath) } - if input.GeneratedPath != nil { + existingGeneratedPath := c.GetGeneratedPath() + if input.GeneratedPath != nil && existingGeneratedPath != *input.GeneratedPath { + if err := checkConfigOverride(config.Generated); err != nil { + return makeConfigGeneralResult(), err + } + if err := utils.EnsureDir(*input.GeneratedPath); err != nil { return makeConfigGeneralResult(), err } c.Set(config.Generated, input.GeneratedPath) } - if input.MetadataPath != nil { + existingMetadataPath := c.GetMetadataPath() + if input.MetadataPath != nil && existingMetadataPath != *input.MetadataPath { + if err := checkConfigOverride(config.Metadata); err != nil { + return makeConfigGeneralResult(), err + } + if *input.MetadataPath != "" { if err := utils.EnsureDir(*input.MetadataPath); err != nil { return makeConfigGeneralResult(), err @@ -70,7 +96,12 @@ func (r *mutationResolver) ConfigureGeneral(ctx context.Context, input models.Co c.Set(config.Metadata, input.MetadataPath) } - if input.CachePath != nil { + existingCachePath := c.GetCachePath() + if input.CachePath != nil && existingCachePath != *input.CachePath { + if err := checkConfigOverride(config.Metadata); err != nil { + return makeConfigGeneralResult(), err + } + if *input.CachePath != "" { if err := utils.EnsureDir(*input.CachePath); err != nil { return makeConfigGeneralResult(), err diff --git a/pkg/api/resolver_query_configuration.go b/pkg/api/resolver_query_configuration.go index 413aba88e..25a43275b 100644 --- a/pkg/api/resolver_query_configuration.go +++ b/pkg/api/resolver_query_configuration.go @@ -64,7 +64,7 @@ func makeConfigGeneralResult() *models.ConfigGeneralResult { DatabasePath: config.GetDatabasePath(), GeneratedPath: config.GetGeneratedPath(), MetadataPath: config.GetMetadataPath(), - ConfigFilePath: config.GetConfigFilePath(), + ConfigFilePath: config.GetConfigFile(), ScrapersPath: config.GetScrapersPath(), CachePath: config.GetCachePath(), CalculateMd5: config.IsCalculateMD5(), @@ -108,7 +108,7 @@ func makeConfigInterfaceResult() *models.ConfigInterfaceResult { soundOnPreview := config.GetSoundOnPreview() wallShowTitle := config.GetWallShowTitle() wallPlayback := config.GetWallPlayback() - noBrowser := config.GetNoBrowserFlag() + noBrowser := config.GetNoBrowser() maximumLoopDuration := config.GetMaximumLoopDuration() autostartVideo := config.GetAutostartVideo() autostartVideoOnPlaySelected := config.GetAutostartVideoOnPlaySelected() diff --git a/pkg/api/server.go b/pkg/api/server.go index e9fd23efe..b9a48e60c 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -252,7 +252,7 @@ func Start(uiBox embed.FS, loginUIBox embed.FS) { // This can be done before actually starting the server, as modern browsers will // automatically reload the page if a local port is closed at page load and then opened. - if !c.GetNoBrowserFlag() && manager.GetInstance().IsDesktop() { + if !c.GetNoBrowser() && manager.GetInstance().IsDesktop() { err = browser.OpenURL(displayAddress) if err != nil { logger.Error("Could not open browser: " + err.Error()) diff --git a/pkg/manager/config/config.go b/pkg/manager/config/config.go index 27dd82ad7..5b378623b 100644 --- a/pkg/manager/config/config.go +++ b/pkg/manager/config/config.go @@ -21,164 +21,173 @@ import ( "github.com/stashapp/stash/pkg/utils" ) -const Stash = "stash" -const Cache = "cache" -const Generated = "generated" -const Metadata = "metadata" -const Downloads = "downloads" -const ApiKey = "api_key" -const Username = "username" -const Password = "password" -const MaxSessionAge = "max_session_age" - -const DefaultMaxSessionAge = 60 * 60 * 1 // 1 hours - -const Database = "database" - -const Exclude = "exclude" -const ImageExclude = "image_exclude" - -const VideoExtensions = "video_extensions" - -var defaultVideoExtensions = []string{"m4v", "mp4", "mov", "wmv", "avi", "mpg", "mpeg", "rmvb", "rm", "flv", "asf", "mkv", "webm"} - -const ImageExtensions = "image_extensions" - -var defaultImageExtensions = []string{"png", "jpg", "jpeg", "gif", "webp"} - -const GalleryExtensions = "gallery_extensions" - -var defaultGalleryExtensions = []string{"zip", "cbz"} - -const CreateGalleriesFromFolders = "create_galleries_from_folders" - -// CalculateMD5 is the config key used to determine if MD5 should be calculated -// for video files. -const CalculateMD5 = "calculate_md5" - -// VideoFileNamingAlgorithm is the config key used to determine what hash -// should be used when generating and using generated files for scenes. -const VideoFileNamingAlgorithm = "video_file_naming_algorithm" - -const MaxTranscodeSize = "max_transcode_size" -const MaxStreamingTranscodeSize = "max_streaming_transcode_size" - -const ParallelTasks = "parallel_tasks" -const parallelTasksDefault = 1 - -const PreviewPreset = "preview_preset" - -const PreviewAudio = "preview_audio" -const previewAudioDefault = true - -const PreviewSegmentDuration = "preview_segment_duration" -const previewSegmentDurationDefault = 0.75 - -const PreviewSegments = "preview_segments" -const previewSegmentsDefault = 12 - -const PreviewExcludeStart = "preview_exclude_start" -const previewExcludeStartDefault = "0" - -const PreviewExcludeEnd = "preview_exclude_end" -const previewExcludeEndDefault = "0" - -const WriteImageThumbnails = "write_image_thumbnails" -const writeImageThumbnailsDefault = true - -const Host = "host" -const Port = "port" -const ExternalHost = "external_host" - -// key used to sign JWT tokens -const JWTSignKey = "jwt_secret_key" - -// key used for session store -const SessionStoreKey = "session_store_key" - -// scraping options -const ScrapersPath = "scrapers_path" -const ScraperUserAgent = "scraper_user_agent" -const ScraperCertCheck = "scraper_cert_check" -const ScraperCDPPath = "scraper_cdp_path" -const ScraperExcludeTagPatterns = "scraper_exclude_tag_patterns" - -// stash-box options -const StashBoxes = "stash_boxes" - -// plugin options -const PluginsPath = "plugins_path" - -// i18n -const Language = "language" - -// served directories -// this should be manually configured only -const CustomServedFolders = "custom_served_folders" - -// UI directory. Overrides to serve the UI from a specific location -// rather than use the embedded UI. -const CustomUILocation = "custom_ui_location" - -// Interface options -const MenuItems = "menu_items" - -var defaultMenuItems = []string{"scenes", "images", "movies", "markers", "galleries", "performers", "studios", "tags"} - -const SoundOnPreview = "sound_on_preview" -const WallShowTitle = "wall_show_title" -const CustomPerformerImageLocation = "custom_performer_image_location" -const MaximumLoopDuration = "maximum_loop_duration" -const AutostartVideo = "autostart_video" -const AutostartVideoOnPlaySelected = "autostart_video_on_play_selected" -const ContinuePlaylistDefault = "continue_playlist_default" -const ShowStudioAsText = "show_studio_as_text" -const CSSEnabled = "cssEnabled" -const WallPlayback = "wall_playback" -const SlideshowDelay = "slideshow_delay" - const ( + Stash = "stash" + Cache = "cache" + Generated = "generated" + Metadata = "metadata" + Downloads = "downloads" + ApiKey = "api_key" + Username = "username" + Password = "password" + MaxSessionAge = "max_session_age" + + DefaultMaxSessionAge = 60 * 60 * 1 // 1 hours + + Database = "database" + + Exclude = "exclude" + ImageExclude = "image_exclude" + + VideoExtensions = "video_extensions" + ImageExtensions = "image_extensions" + GalleryExtensions = "gallery_extensions" + CreateGalleriesFromFolders = "create_galleries_from_folders" + + // CalculateMD5 is the config key used to determine if MD5 should be calculated + // for video files. + CalculateMD5 = "calculate_md5" + + // VideoFileNamingAlgorithm is the config key used to determine what hash + // should be used when generating and using generated files for scenes. + VideoFileNamingAlgorithm = "video_file_naming_algorithm" + + MaxTranscodeSize = "max_transcode_size" + MaxStreamingTranscodeSize = "max_streaming_transcode_size" + + ParallelTasks = "parallel_tasks" + parallelTasksDefault = 1 + + PreviewPreset = "preview_preset" + + PreviewAudio = "preview_audio" + previewAudioDefault = true + + PreviewSegmentDuration = "preview_segment_duration" + previewSegmentDurationDefault = 0.75 + + PreviewSegments = "preview_segments" + previewSegmentsDefault = 12 + + PreviewExcludeStart = "preview_exclude_start" + previewExcludeStartDefault = "0" + + PreviewExcludeEnd = "preview_exclude_end" + previewExcludeEndDefault = "0" + + WriteImageThumbnails = "write_image_thumbnails" + writeImageThumbnailsDefault = true + + Host = "host" + hostDefault = "0.0.0.0" + + Port = "port" + portDefault = 9999 + + ExternalHost = "external_host" + + // key used to sign JWT tokens + JWTSignKey = "jwt_secret_key" + + // key used for session store + SessionStoreKey = "session_store_key" + + // scraping options + ScrapersPath = "scrapers_path" + ScraperUserAgent = "scraper_user_agent" + ScraperCertCheck = "scraper_cert_check" + ScraperCDPPath = "scraper_cdp_path" + ScraperExcludeTagPatterns = "scraper_exclude_tag_patterns" + + // stash-box options + StashBoxes = "stash_boxes" + + // plugin options + PluginsPath = "plugins_path" + + // i18n + Language = "language" + + // served directories + // this should be manually configured only + CustomServedFolders = "custom_served_folders" + + // UI directory. Overrides to serve the UI from a specific location + // rather than use the embedded UI. + CustomUILocation = "custom_ui_location" + + // Interface options + MenuItems = "menu_items" + + SoundOnPreview = "sound_on_preview" + + WallShowTitle = "wall_show_title" + defaultWallShowTitle = true + + CustomPerformerImageLocation = "custom_performer_image_location" + MaximumLoopDuration = "maximum_loop_duration" + AutostartVideo = "autostart_video" + AutostartVideoOnPlaySelected = "autostart_video_on_play_selected" + ContinuePlaylistDefault = "continue_playlist_default" + ShowStudioAsText = "show_studio_as_text" + CSSEnabled = "cssEnabled" + + WallPlayback = "wall_playback" + defaultWallPlayback = "video" + + SlideshowDelay = "slideshow_delay" + defaultSlideshowDelay = 5000 + DisableDropdownCreatePerformer = "disable_dropdown_create.performer" DisableDropdownCreateStudio = "disable_dropdown_create.studio" DisableDropdownCreateTag = "disable_dropdown_create.tag" -) -const HandyKey = "handy_key" -const FunscriptOffset = "funscript_offset" + HandyKey = "handy_key" + FunscriptOffset = "funscript_offset" + + // Security + TrustedProxies = "trusted_proxies" + dangerousAllowPublicWithoutAuth = "dangerous_allow_public_without_auth" + dangerousAllowPublicWithoutAuthDefault = "false" + SecurityTripwireAccessedFromPublicInternet = "security_tripwire_accessed_from_public_internet" + securityTripwireAccessedFromPublicInternetDefault = "" + + // DLNA options + DLNAServerName = "dlna.server_name" + DLNADefaultEnabled = "dlna.default_enabled" + DLNADefaultIPWhitelist = "dlna.default_whitelist" + DLNAInterfaces = "dlna.interfaces" + + // Logging options + LogFile = "logFile" + LogOut = "logOut" + defaultLogOut = true + LogLevel = "logLevel" + defaultLogLevel = "Info" + LogAccess = "logAccess" + defaultLogAccess = true -// Default settings -const ( DefaultIdentifySettings = "defaults.identify_task" - DeleteFileDefault = "defaults.delete_file" - DeleteGeneratedDefault = "defaults.delete_generated" + DeleteFileDefault = "defaults.delete_file" + DeleteGeneratedDefault = "defaults.delete_generated" + deleteGeneratedDefaultDefault = true + + // Desktop Integration Options + NoBrowser = "noBrowser" + NoBrowserDefault = false + + // File upload options + MaxUploadSize = "max_upload_size" ) -// Security -const TrustedProxies = "trusted_proxies" -const dangerousAllowPublicWithoutAuth = "dangerous_allow_public_without_auth" -const dangerousAllowPublicWithoutAuthDefault = "false" -const SecurityTripwireAccessedFromPublicInternet = "security_tripwire_accessed_from_public_internet" -const securityTripwireAccessedFromPublicInternetDefault = "" - -// DLNA options -const DLNAServerName = "dlna.server_name" -const DLNADefaultEnabled = "dlna.default_enabled" -const DLNADefaultIPWhitelist = "dlna.default_whitelist" -const DLNAInterfaces = "dlna.interfaces" - -// Desktop Integration Options -const NoBrowser = "noBrowser" -const NoBrowserDefault = false - -// Logging options -const LogFile = "logFile" -const LogOut = "logOut" -const LogLevel = "logLevel" -const LogAccess = "logAccess" - -// File upload options -const MaxUploadSize = "max_upload_size" +// slice default values +var ( + defaultVideoExtensions = []string{"m4v", "mp4", "mov", "wmv", "avi", "mpg", "mpeg", "rmvb", "rm", "flv", "asf", "mkv", "webm"} + defaultImageExtensions = []string{"png", "jpg", "jpeg", "gif", "webp"} + defaultGalleryExtensions = []string{"zip", "cbz"} + defaultMenuItems = []string{"scenes", "images", "movies", "markers", "galleries", "performers", "studios", "tags"} +) type MissingConfigError struct { missingFields []string @@ -199,6 +208,13 @@ func (s *StashBoxError) Error() string { } type Instance struct { + // main instance - backed by config file + main *viper.Viper + + // override instance - populated from flags/environment + // not written to config file + overrides *viper.Viper + cpuProfilePath string isNewSystem bool certFile string @@ -209,13 +225,6 @@ type Instance struct { var instance *Instance -func GetInstance() *Instance { - if instance == nil { - instance = &Instance{} - } - return instance -} - func (i *Instance) IsNewSystem() bool { return i.isNewSystem } @@ -223,7 +232,7 @@ func (i *Instance) IsNewSystem() bool { func (i *Instance) SetConfigFile(fn string) { i.Lock() defer i.Unlock() - viper.SetConfigFile(fn) + i.main.SetConfigFile(fn) } func (i *Instance) InitTLS() { @@ -253,16 +262,14 @@ func (i *Instance) GetCPUProfilePath() string { return i.cpuProfilePath } -func (i *Instance) GetNoBrowserFlag() bool { - i.Lock() - defer i.Unlock() - return viper.GetBool(NoBrowser) +func (i *Instance) GetNoBrowser() bool { + return i.getBool(NoBrowser) } func (i *Instance) Set(key string, value interface{}) { i.Lock() defer i.Unlock() - viper.Set(key, value) + i.main.Set(key, value) } func (i *Instance) SetPassword(value string) { @@ -277,14 +284,20 @@ func (i *Instance) SetPassword(value string) { func (i *Instance) Write() error { i.Lock() defer i.Unlock() - return viper.WriteConfig() + return i.main.WriteConfig() +} + +// FileEnvSet returns true if the configuration file environment parameter +// is set. +func FileEnvSet() bool { + return os.Getenv("STASH_CONFIG_FILE") != "" } // GetConfigFile returns the full path to the used configuration file. func (i *Instance) GetConfigFile() string { i.RLock() defer i.RUnlock() - return viper.ConfigFileUsed() + return i.main.ConfigFileUsed() } // GetConfigPath returns the path of the directory containing the used @@ -299,13 +312,94 @@ func (i *Instance) GetDefaultDatabaseFilePath() string { return filepath.Join(i.GetConfigPath(), "stash-go.sqlite") } +// viper returns the viper instance that should be used to get the provided +// key. Returns the overrides instance if the key exists there, otherwise it +// returns the main instance. Assumes read lock held. +func (i *Instance) viper(key string) *viper.Viper { + v := i.main + if i.overrides.IsSet(key) { + v = i.overrides + } + + return v +} + +func (i *Instance) HasOverride(key string) bool { + i.RLock() + defer i.RUnlock() + + return i.overrides.IsSet(key) +} + +// These functions wrap the equivalent viper functions, checking the override +// instance first, then the main instance. + +func (i *Instance) unmarshalKey(key string, rawVal interface{}) error { + i.RLock() + defer i.RUnlock() + + return i.viper(key).UnmarshalKey(key, rawVal) +} + +func (i *Instance) getStringSlice(key string) []string { + i.RLock() + defer i.RUnlock() + + return i.viper(key).GetStringSlice(key) +} + +func (i *Instance) getString(key string) string { + i.RLock() + defer i.RUnlock() + + return i.viper(key).GetString(key) +} + +func (i *Instance) getBool(key string) bool { + i.RLock() + defer i.RUnlock() + + return i.viper(key).GetBool(key) +} + +func (i *Instance) getInt(key string) int { + i.RLock() + defer i.RUnlock() + + return i.viper(key).GetInt(key) +} + +func (i *Instance) getFloat64(key string) float64 { + i.RLock() + defer i.RUnlock() + + return i.viper(key).GetFloat64(key) +} + +func (i *Instance) getStringMapString(key string) map[string]string { + i.RLock() + defer i.RUnlock() + + return i.viper(key).GetStringMapString(key) +} + +// GetStathPaths returns the configured stash library paths. +// Works opposite to the usual case - it will return the override +// value only if the main value is not set. func (i *Instance) GetStashPaths() []*models.StashConfig { i.RLock() defer i.RUnlock() + var ret []*models.StashConfig - if err := viper.UnmarshalKey(Stash, &ret); err != nil || len(ret) == 0 { + + v := i.main + if !v.IsSet(Stash) { + v = i.overrides + } + + if err := v.UnmarshalKey(Stash, &ret); err != nil || len(ret) == 0 { // fallback to legacy format - ss := viper.GetStringSlice(Stash) + ss := v.GetStringSlice(Stash) ret = nil for _, path := range ss { toAdd := &models.StashConfig{ @@ -318,72 +412,47 @@ func (i *Instance) GetStashPaths() []*models.StashConfig { return ret } -func (i *Instance) GetConfigFilePath() string { - i.RLock() - defer i.RUnlock() - return viper.ConfigFileUsed() -} - func (i *Instance) GetCachePath() string { - i.RLock() - defer i.RUnlock() - return viper.GetString(Cache) + return i.getString(Cache) } func (i *Instance) GetGeneratedPath() string { - i.RLock() - defer i.RUnlock() - return viper.GetString(Generated) + return i.getString(Generated) } func (i *Instance) GetMetadataPath() string { - i.RLock() - defer i.RUnlock() - return viper.GetString(Metadata) + return i.getString(Metadata) } func (i *Instance) GetDatabasePath() string { - i.RLock() - defer i.RUnlock() - return viper.GetString(Database) + return i.getString(Database) } func (i *Instance) GetJWTSignKey() []byte { - i.RLock() - defer i.RUnlock() - return []byte(viper.GetString(JWTSignKey)) + return []byte(i.getString(JWTSignKey)) } func (i *Instance) GetSessionStoreKey() []byte { - i.RLock() - defer i.RUnlock() - return []byte(viper.GetString(SessionStoreKey)) + return []byte(i.getString(SessionStoreKey)) } func (i *Instance) GetDefaultScrapersPath() string { // default to the same directory as the config file - fn := filepath.Join(i.GetConfigPath(), "scrapers") return fn } func (i *Instance) GetExcludes() []string { - i.RLock() - defer i.RUnlock() - return viper.GetStringSlice(Exclude) + return i.getStringSlice(Exclude) } func (i *Instance) GetImageExcludes() []string { - i.RLock() - defer i.RUnlock() - return viper.GetStringSlice(ImageExclude) + return i.getStringSlice(ImageExclude) } func (i *Instance) GetVideoExtensions() []string { - i.RLock() - defer i.RUnlock() - ret := viper.GetStringSlice(VideoExtensions) + ret := i.getStringSlice(VideoExtensions) if ret == nil { ret = defaultVideoExtensions } @@ -391,9 +460,7 @@ func (i *Instance) GetVideoExtensions() []string { } func (i *Instance) GetImageExtensions() []string { - i.RLock() - defer i.RUnlock() - ret := viper.GetStringSlice(ImageExtensions) + ret := i.getStringSlice(ImageExtensions) if ret == nil { ret = defaultImageExtensions } @@ -401,9 +468,7 @@ func (i *Instance) GetImageExtensions() []string { } func (i *Instance) GetGalleryExtensions() []string { - i.RLock() - defer i.RUnlock() - ret := viper.GetStringSlice(GalleryExtensions) + ret := i.getStringSlice(GalleryExtensions) if ret == nil { ret = defaultGalleryExtensions } @@ -411,15 +476,11 @@ func (i *Instance) GetGalleryExtensions() []string { } func (i *Instance) GetCreateGalleriesFromFolders() bool { - i.RLock() - defer i.RUnlock() - return viper.GetBool(CreateGalleriesFromFolders) + return i.getBool(CreateGalleriesFromFolders) } func (i *Instance) GetLanguage() string { - i.RLock() - defer i.RUnlock() - ret := viper.GetString(Language) + ret := i.getString(Language) // default to English if ret == "" { @@ -432,17 +493,13 @@ func (i *Instance) GetLanguage() string { // IsCalculateMD5 returns true if MD5 checksums should be generated for // scene video files. func (i *Instance) IsCalculateMD5() bool { - i.RLock() - defer i.RUnlock() - return viper.GetBool(CalculateMD5) + return i.getBool(CalculateMD5) } // GetVideoFileNamingAlgorithm returns what hash algorithm should be used for // naming generated scene video files. func (i *Instance) GetVideoFileNamingAlgorithm() models.HashAlgorithm { - i.RLock() - defer i.RUnlock() - ret := viper.GetString(VideoFileNamingAlgorithm) + ret := i.getString(VideoFileNamingAlgorithm) // default to oshash if ret == "" { @@ -453,54 +510,41 @@ func (i *Instance) GetVideoFileNamingAlgorithm() models.HashAlgorithm { } func (i *Instance) GetScrapersPath() string { - i.RLock() - defer i.RUnlock() - return viper.GetString(ScrapersPath) + return i.getString(ScrapersPath) } func (i *Instance) GetScraperUserAgent() string { - i.RLock() - defer i.RUnlock() - return viper.GetString(ScraperUserAgent) + return i.getString(ScraperUserAgent) } // GetScraperCDPPath gets the path to the Chrome executable or remote address // to an instance of Chrome. func (i *Instance) GetScraperCDPPath() string { - i.RLock() - defer i.RUnlock() - return viper.GetString(ScraperCDPPath) + return i.getString(ScraperCDPPath) } // GetScraperCertCheck returns true if the scraper should check for insecure // certificates when fetching an image or a page. func (i *Instance) GetScraperCertCheck() bool { + ret := true i.RLock() defer i.RUnlock() - ret := true - if viper.IsSet(ScraperCertCheck) { - ret = viper.GetBool(ScraperCertCheck) + + v := i.viper(ScraperCertCheck) + if v.IsSet(ScraperCertCheck) { + ret = v.GetBool(ScraperCertCheck) } return ret } func (i *Instance) GetScraperExcludeTagPatterns() []string { - i.RLock() - defer i.RUnlock() - var ret []string - if viper.IsSet(ScraperExcludeTagPatterns) { - ret = viper.GetStringSlice(ScraperExcludeTagPatterns) - } - - return ret + return i.getStringSlice(ScraperExcludeTagPatterns) } func (i *Instance) GetStashBoxes() models.StashBoxes { - i.RLock() - defer i.RUnlock() var boxes models.StashBoxes - if err := viper.UnmarshalKey(StashBoxes, &boxes); err != nil { + if err := i.unmarshalKey(StashBoxes, &boxes); err != nil { logger.Warnf("error in unmarshalkey: %v", err) } @@ -515,49 +559,45 @@ func (i *Instance) GetDefaultPluginsPath() string { } func (i *Instance) GetPluginsPath() string { - i.RLock() - defer i.RUnlock() - return viper.GetString(PluginsPath) + return i.getString(PluginsPath) } func (i *Instance) GetHost() string { - i.RLock() - defer i.RUnlock() - return viper.GetString(Host) + ret := i.getString(Host) + if ret == "" { + ret = hostDefault + } + + return ret } func (i *Instance) GetPort() int { - i.RLock() - defer i.RUnlock() - return viper.GetInt(Port) + ret := i.getInt(Port) + if ret == 0 { + ret = portDefault + } + + return ret } func (i *Instance) GetExternalHost() string { - i.RLock() - defer i.RUnlock() - return viper.GetString(ExternalHost) + return i.getString(ExternalHost) } // GetPreviewSegmentDuration returns the duration of a single segment in a // scene preview file, in seconds. func (i *Instance) GetPreviewSegmentDuration() float64 { - i.RLock() - defer i.RUnlock() - return viper.GetFloat64(PreviewSegmentDuration) + return i.getFloat64(PreviewSegmentDuration) } // GetParallelTasks returns the number of parallel tasks that should be started // by scan or generate task. func (i *Instance) GetParallelTasks() int { - i.RLock() - defer i.RUnlock() - return viper.GetInt(ParallelTasks) + return i.getInt(ParallelTasks) } func (i *Instance) GetParallelTasksWithAutoDetection() int { - i.RLock() - defer i.RUnlock() - parallelTasks := viper.GetInt(ParallelTasks) + parallelTasks := i.getInt(ParallelTasks) if parallelTasks <= 0 { parallelTasks = (runtime.NumCPU() / 4) + 1 } @@ -565,16 +605,12 @@ func (i *Instance) GetParallelTasksWithAutoDetection() int { } func (i *Instance) GetPreviewAudio() bool { - i.RLock() - defer i.RUnlock() - return viper.GetBool(PreviewAudio) + return i.getBool(PreviewAudio) } // GetPreviewSegments returns the amount of segments in a scene preview file. func (i *Instance) GetPreviewSegments() int { - i.RLock() - defer i.RUnlock() - return viper.GetInt(PreviewSegments) + return i.getInt(PreviewSegments) } // GetPreviewExcludeStart returns the configuration setting string for @@ -584,9 +620,7 @@ func (i *Instance) GetPreviewSegments() int { // in the preview. If the value is suffixed with a '%' character (for example // '2%'), then it is interpreted as a proportion of the total video duration. func (i *Instance) GetPreviewExcludeStart() string { - i.RLock() - defer i.RUnlock() - return viper.GetString(PreviewExcludeStart) + return i.getString(PreviewExcludeStart) } // GetPreviewExcludeEnd returns the configuration setting string for @@ -595,17 +629,13 @@ func (i *Instance) GetPreviewExcludeStart() string { // when generating previews. If the value is suffixed with a '%' character, // then it is interpreted as a proportion of the total video duration. func (i *Instance) GetPreviewExcludeEnd() string { - i.RLock() - defer i.RUnlock() - return viper.GetString(PreviewExcludeEnd) + return i.getString(PreviewExcludeEnd) } // GetPreviewPreset returns the preset when generating previews. Defaults to // Slow. func (i *Instance) GetPreviewPreset() models.PreviewPreset { - i.RLock() - defer i.RUnlock() - ret := viper.GetString(PreviewPreset) + ret := i.getString(PreviewPreset) // default to slow if ret == "" { @@ -616,9 +646,7 @@ func (i *Instance) GetPreviewPreset() models.PreviewPreset { } func (i *Instance) GetMaxTranscodeSize() models.StreamingResolutionEnum { - i.RLock() - defer i.RUnlock() - ret := viper.GetString(MaxTranscodeSize) + ret := i.getString(MaxTranscodeSize) // default to original if ret == "" { @@ -629,9 +657,7 @@ func (i *Instance) GetMaxTranscodeSize() models.StreamingResolutionEnum { } func (i *Instance) GetMaxStreamingTranscodeSize() models.StreamingResolutionEnum { - i.RLock() - defer i.RUnlock() - ret := viper.GetString(MaxStreamingTranscodeSize) + ret := i.getString(MaxStreamingTranscodeSize) // default to original if ret == "" { @@ -644,48 +670,32 @@ func (i *Instance) GetMaxStreamingTranscodeSize() models.StreamingResolutionEnum // IsWriteImageThumbnails returns true if image thumbnails should be written // to disk after generating on the fly. func (i *Instance) IsWriteImageThumbnails() bool { - i.RLock() - defer i.RUnlock() - return viper.GetBool(WriteImageThumbnails) + return i.getBool(WriteImageThumbnails) } func (i *Instance) GetAPIKey() string { - i.RLock() - defer i.RUnlock() - return viper.GetString(ApiKey) + return i.getString(ApiKey) } func (i *Instance) GetUsername() string { - i.RLock() - defer i.RUnlock() - return viper.GetString(Username) + return i.getString(Username) } func (i *Instance) GetPasswordHash() string { - i.RLock() - defer i.RUnlock() - return viper.GetString(Password) + return i.getString(Password) } func (i *Instance) GetCredentials() (string, string) { if i.HasCredentials() { - i.RLock() - defer i.RUnlock() - return viper.GetString(Username), viper.GetString(Password) + return i.getString(Username), i.getString(Password) } return "", "" } func (i *Instance) HasCredentials() bool { - i.RLock() - defer i.RUnlock() - if !viper.IsSet(Username) || !viper.IsSet(Password) { - return false - } - - username := viper.GetString(Username) - pwHash := viper.GetString(Password) + username := i.getString(Username) + pwHash := i.getString(Password) return username != "" && pwHash != "" } @@ -739,75 +749,78 @@ func (i *Instance) ValidateStashBoxes(boxes []*models.StashBoxInput) error { // GetMaxSessionAge gets the maximum age for session cookies, in seconds. // Session cookie expiry times are refreshed every request. func (i *Instance) GetMaxSessionAge() int { - i.Lock() - defer i.Unlock() - viper.SetDefault(MaxSessionAge, DefaultMaxSessionAge) - return viper.GetInt(MaxSessionAge) + i.RLock() + defer i.RUnlock() + + ret := DefaultMaxSessionAge + v := i.viper(MaxSessionAge) + if v.IsSet(MaxSessionAge) { + ret = v.GetInt(MaxSessionAge) + } + + return ret } // GetCustomServedFolders gets the map of custom paths to their applicable // filesystem locations func (i *Instance) GetCustomServedFolders() URLMap { - i.RLock() - defer i.RUnlock() - return viper.GetStringMapString(CustomServedFolders) + return i.getStringMapString(CustomServedFolders) } func (i *Instance) GetCustomUILocation() string { - i.RLock() - defer i.RUnlock() - return viper.GetString(CustomUILocation) + return i.getString(CustomUILocation) } // Interface options func (i *Instance) GetMenuItems() []string { i.RLock() defer i.RUnlock() - if viper.IsSet(MenuItems) { - return viper.GetStringSlice(MenuItems) + v := i.viper(MenuItems) + if v.IsSet(MenuItems) { + return v.GetStringSlice(MenuItems) } return defaultMenuItems } func (i *Instance) GetSoundOnPreview() bool { - i.RLock() - defer i.RUnlock() - return viper.GetBool(SoundOnPreview) + return i.getBool(SoundOnPreview) } func (i *Instance) GetWallShowTitle() bool { - i.Lock() - defer i.Unlock() - viper.SetDefault(WallShowTitle, true) - return viper.GetBool(WallShowTitle) + i.RLock() + defer i.RUnlock() + + ret := defaultWallShowTitle + v := i.viper(WallShowTitle) + if v.IsSet(WallShowTitle) { + ret = v.GetBool(WallShowTitle) + } + return ret } func (i *Instance) GetCustomPerformerImageLocation() string { - i.Lock() - defer i.Unlock() - viper.SetDefault(CustomPerformerImageLocation, "") - return viper.GetString(CustomPerformerImageLocation) + return i.getString(CustomPerformerImageLocation) } func (i *Instance) GetWallPlayback() string { - i.Lock() - defer i.Unlock() - viper.SetDefault(WallPlayback, "video") - return viper.GetString(WallPlayback) + i.RLock() + defer i.RUnlock() + + ret := defaultWallPlayback + v := i.viper(WallPlayback) + if v.IsSet(WallPlayback) { + ret = v.GetString(WallPlayback) + } + + return ret } func (i *Instance) GetMaximumLoopDuration() int { - i.Lock() - defer i.Unlock() - viper.SetDefault(MaximumLoopDuration, 0) - return viper.GetInt(MaximumLoopDuration) + return i.getInt(MaximumLoopDuration) } func (i *Instance) GetAutostartVideo() bool { - i.Lock() - defer i.Unlock() - viper.SetDefault(AutostartVideo, false) - return viper.GetBool(AutostartVideo) + return i.getBool(AutostartVideo) } func (i *Instance) GetAutostartVideoOnPlaySelected() bool { @@ -825,35 +838,33 @@ func (i *Instance) GetContinuePlaylistDefault() bool { } func (i *Instance) GetShowStudioAsText() bool { - i.Lock() - defer i.Unlock() - viper.SetDefault(ShowStudioAsText, false) - return viper.GetBool(ShowStudioAsText) + return i.getBool(ShowStudioAsText) } func (i *Instance) GetSlideshowDelay() int { - i.Lock() - defer i.Unlock() - viper.SetDefault(SlideshowDelay, 5000) - return viper.GetInt(SlideshowDelay) + i.RLock() + defer i.RUnlock() + + ret := defaultSlideshowDelay + v := i.viper(SlideshowDelay) + if v.IsSet(SlideshowDelay) { + ret = v.GetInt(SlideshowDelay) + } + + return ret } func (i *Instance) GetDisableDropdownCreate() *models.ConfigDisableDropdownCreate { - i.Lock() - defer i.Unlock() - return &models.ConfigDisableDropdownCreate{ - Performer: viper.GetBool(DisableDropdownCreatePerformer), - Studio: viper.GetBool(DisableDropdownCreateStudio), - Tag: viper.GetBool(DisableDropdownCreateTag), + Performer: i.getBool(DisableDropdownCreatePerformer), + Studio: i.getBool(DisableDropdownCreateStudio), + Tag: i.getBool(DisableDropdownCreateTag), } } func (i *Instance) GetCSSPath() string { - i.RLock() - defer i.RUnlock() // use custom.css in the same directory as the config file - configFileUsed := viper.ConfigFileUsed() + configFileUsed := i.GetConfigFile() configDir := filepath.Dir(configFileUsed) fn := filepath.Join(configDir, "custom.css") @@ -879,9 +890,9 @@ func (i *Instance) GetCSS() string { } func (i *Instance) SetCSS(css string) { - i.RLock() - defer i.RUnlock() fn := i.GetCSSPath() + i.Lock() + defer i.Unlock() buf := []byte(css) @@ -891,36 +902,32 @@ func (i *Instance) SetCSS(css string) { } func (i *Instance) GetCSSEnabled() bool { - i.RLock() - defer i.RUnlock() - return viper.GetBool(CSSEnabled) + return i.getBool(CSSEnabled) } func (i *Instance) GetHandyKey() string { - i.RLock() - defer i.RUnlock() - return viper.GetString(HandyKey) + return i.getString(HandyKey) } func (i *Instance) GetFunscriptOffset() int { - i.Lock() - defer i.Unlock() - viper.SetDefault(FunscriptOffset, 0) - return viper.GetInt(FunscriptOffset) + return i.getInt(FunscriptOffset) } func (i *Instance) GetDeleteFileDefault() bool { - i.Lock() - defer i.Unlock() - viper.SetDefault(DeleteFileDefault, false) - return viper.GetBool(DeleteFileDefault) + return i.getBool(DeleteFileDefault) } func (i *Instance) GetDeleteGeneratedDefault() bool { - i.Lock() - defer i.Unlock() - viper.SetDefault(DeleteGeneratedDefault, true) - return viper.GetBool(DeleteGeneratedDefault) + i.RLock() + defer i.RUnlock() + ret := deleteGeneratedDefaultDefault + + v := i.viper(DeleteGeneratedDefault) + if v.IsSet(DeleteGeneratedDefault) { + ret = v.GetBool(DeleteGeneratedDefault) + } + + return ret } // GetDefaultIdentifySettings returns the default Identify task settings. @@ -944,65 +951,49 @@ func (i *Instance) GetDefaultIdentifySettings() *models.IdentifyMetadataTaskOpti // GetTrustedProxies returns a comma separated list of ip addresses that should allow proxying. // When empty, allow from any private network func (i *Instance) GetTrustedProxies() []string { - i.RLock() - defer i.RUnlock() - return viper.GetStringSlice(TrustedProxies) + return i.getStringSlice(TrustedProxies) } // GetDangerousAllowPublicWithoutAuth determines if the security feature is enabled. // See https://github.com/stashapp/stash/wiki/Authentication-Required-When-Accessing-Stash-From-the-Internet func (i *Instance) GetDangerousAllowPublicWithoutAuth() bool { - i.RLock() - defer i.RUnlock() - return viper.GetBool(dangerousAllowPublicWithoutAuth) + return i.getBool(dangerousAllowPublicWithoutAuth) } // GetSecurityTripwireAccessedFromPublicInternet returns a public IP address if stash // has been accessed from the public internet, with no auth enabled, and // DangerousAllowPublicWithoutAuth disabled. Returns an empty string otherwise. func (i *Instance) GetSecurityTripwireAccessedFromPublicInternet() string { - i.RLock() - defer i.RUnlock() - return viper.GetString(SecurityTripwireAccessedFromPublicInternet) + return i.getString(SecurityTripwireAccessedFromPublicInternet) } // GetDLNAServerName returns the visible name of the DLNA server. If empty, // "stash" will be used. func (i *Instance) GetDLNAServerName() string { - i.RLock() - defer i.RUnlock() - return viper.GetString(DLNAServerName) + return i.getString(DLNAServerName) } // GetDLNADefaultEnabled returns true if the DLNA is enabled by default. func (i *Instance) GetDLNADefaultEnabled() bool { - i.RLock() - defer i.RUnlock() - return viper.GetBool(DLNADefaultEnabled) + return i.getBool(DLNADefaultEnabled) } // GetDLNADefaultIPWhitelist returns a list of IP addresses/wildcards that // are allowed to use the DLNA service. func (i *Instance) GetDLNADefaultIPWhitelist() []string { - i.RLock() - defer i.RUnlock() - return viper.GetStringSlice(DLNADefaultIPWhitelist) + return i.getStringSlice(DLNADefaultIPWhitelist) } // GetDLNAInterfaces returns a list of interface names to expose DLNA on. If // empty, runs on all interfaces. func (i *Instance) GetDLNAInterfaces() []string { - i.RLock() - defer i.RUnlock() - return viper.GetStringSlice(DLNAInterfaces) + return i.getStringSlice(DLNAInterfaces) } // GetLogFile returns the filename of the file to output logs to. // An empty string means that file logging will be disabled. func (i *Instance) GetLogFile() string { - i.RLock() - defer i.RUnlock() - return viper.GetString(LogFile) + return i.getString(LogFile) } // GetLogOut returns true if logging should be output to the terminal @@ -1011,9 +1002,12 @@ func (i *Instance) GetLogFile() string { func (i *Instance) GetLogOut() bool { i.RLock() defer i.RUnlock() - ret := true - if viper.IsSet(LogOut) { - ret = viper.GetBool(LogOut) + + ret := defaultLogOut + v := i.viper(LogOut) + + if v.IsSet(LogOut) { + ret = v.GetBool(LogOut) } return ret @@ -1022,13 +1016,9 @@ func (i *Instance) GetLogOut() bool { // GetLogLevel returns the lowest log level to write to the log. // Should be one of "Debug", "Info", "Warning", "Error" func (i *Instance) GetLogLevel() string { - i.RLock() - defer i.RUnlock() - const defaultValue = "Info" - - value := viper.GetString(LogLevel) + value := i.getString(LogLevel) if value != "Debug" && value != "Info" && value != "Warning" && value != "Error" && value != "Trace" { - value = defaultValue + value = defaultLogLevel } return value @@ -1039,9 +1029,11 @@ func (i *Instance) GetLogLevel() string { func (i *Instance) GetLogAccess() bool { i.RLock() defer i.RUnlock() - ret := true - if viper.IsSet(LogAccess) { - ret = viper.GetBool(LogAccess) + ret := defaultLogAccess + + v := i.viper(LogAccess) + if v.IsSet(LogAccess) { + ret = v.GetBool(LogAccess) } return ret @@ -1052,8 +1044,10 @@ func (i *Instance) GetMaxUploadSize() int64 { i.RLock() defer i.RUnlock() ret := int64(1024) - if viper.IsSet(MaxUploadSize) { - ret = viper.GetInt64(MaxUploadSize) + + v := i.viper(MaxUploadSize) + if v.IsSet(MaxUploadSize) { + ret = v.GetInt64(MaxUploadSize) } return ret << 20 } @@ -1077,7 +1071,7 @@ func (i *Instance) Validate() error { var missingFields []string for _, p := range mandatoryPaths { - if !viper.IsSet(p) || viper.GetString(p) == "" { + if !i.viper(p).IsSet(p) || i.viper(p).GetString(p) == "" { missingFields = append(missingFields, p) } } @@ -1094,8 +1088,8 @@ func (i *Instance) Validate() error { func (i *Instance) SetChecksumDefaultValues(defaultAlgorithm models.HashAlgorithm, usingMD5 bool) { i.Lock() defer i.Unlock() - viper.SetDefault(VideoFileNamingAlgorithm, defaultAlgorithm) - viper.SetDefault(CalculateMD5, usingMD5) + i.main.SetDefault(VideoFileNamingAlgorithm, defaultAlgorithm) + i.main.SetDefault(CalculateMD5, usingMD5) } func (i *Instance) setDefaultValues(write bool) error { @@ -1106,31 +1100,60 @@ func (i *Instance) setDefaultValues(write bool) error { i.Lock() defer i.Unlock() - viper.SetDefault(ParallelTasks, parallelTasksDefault) - viper.SetDefault(PreviewSegmentDuration, previewSegmentDurationDefault) - viper.SetDefault(PreviewSegments, previewSegmentsDefault) - viper.SetDefault(PreviewExcludeStart, previewExcludeStartDefault) - viper.SetDefault(PreviewExcludeEnd, previewExcludeEndDefault) - viper.SetDefault(PreviewAudio, previewAudioDefault) - viper.SetDefault(SoundOnPreview, false) - viper.SetDefault(WriteImageThumbnails, writeImageThumbnailsDefault) + // set the default host and port so that these are written to the config + // file + i.main.SetDefault(Host, hostDefault) + i.main.SetDefault(Port, portDefault) - viper.SetDefault(Database, defaultDatabaseFilePath) + i.main.SetDefault(ParallelTasks, parallelTasksDefault) + i.main.SetDefault(PreviewSegmentDuration, previewSegmentDurationDefault) + i.main.SetDefault(PreviewSegments, previewSegmentsDefault) + i.main.SetDefault(PreviewExcludeStart, previewExcludeStartDefault) + i.main.SetDefault(PreviewExcludeEnd, previewExcludeEndDefault) + i.main.SetDefault(PreviewAudio, previewAudioDefault) + i.main.SetDefault(SoundOnPreview, false) - viper.SetDefault(dangerousAllowPublicWithoutAuth, dangerousAllowPublicWithoutAuthDefault) - viper.SetDefault(SecurityTripwireAccessedFromPublicInternet, securityTripwireAccessedFromPublicInternetDefault) + i.main.SetDefault(WriteImageThumbnails, writeImageThumbnailsDefault) + + i.main.SetDefault(Database, defaultDatabaseFilePath) + + i.main.SetDefault(dangerousAllowPublicWithoutAuth, dangerousAllowPublicWithoutAuthDefault) + i.main.SetDefault(SecurityTripwireAccessedFromPublicInternet, securityTripwireAccessedFromPublicInternetDefault) // Set generated to the metadata path for backwards compat - viper.SetDefault(Generated, viper.GetString(Metadata)) + i.main.SetDefault(Generated, i.main.GetString(Metadata)) viper.SetDefault(NoBrowser, NoBrowserDefault) // Set default scrapers and plugins paths - viper.SetDefault(ScrapersPath, defaultScrapersPath) - viper.SetDefault(PluginsPath, defaultPluginsPath) + i.main.SetDefault(ScrapersPath, defaultScrapersPath) + i.main.SetDefault(PluginsPath, defaultPluginsPath) if write { - return viper.WriteConfig() + return i.main.WriteConfig() + } + + return nil +} + +// setExistingSystemDefaults sets config options that are new and unset in an existing install, +// but should have a separate default than for brand-new systems, to maintain behavior. +func (i *Instance) setExistingSystemDefaults() error { + i.Lock() + defer i.Unlock() + if !i.isNewSystem { + configDirtied := false + + // Existing systems as of the introduction of auto-browser open should retain existing + // behavior and not start the browser automatically. + if !i.main.InConfig(NoBrowser) { + configDirtied = true + i.main.Set(NoBrowser, true) + } + + if configDirtied { + return i.main.WriteConfig() + } } return nil diff --git a/pkg/manager/config/config_concurrency_test.go b/pkg/manager/config/config_concurrency_test.go index 8558f0214..4b940cb2e 100644 --- a/pkg/manager/config/config_concurrency_test.go +++ b/pkg/manager/config/config_concurrency_test.go @@ -21,12 +21,15 @@ func TestConcurrentConfigAccess(t *testing.T) { } i.HasCredentials() + i.ValidateCredentials("", "") i.GetCPUProfilePath() i.GetConfigFile() i.GetConfigPath() i.GetDefaultDatabaseFilePath() i.GetStashPaths() - i.GetConfigFilePath() + _ = i.ValidateStashBoxes(nil) + _ = i.Validate() + _ = i.ActivatePublicAccessTripwire("") i.Set(Cache, i.GetCachePath()) i.Set(Generated, i.GetGeneratedPath()) i.Set(Metadata, i.GetMetadataPath()) @@ -94,6 +97,15 @@ func TestConcurrentConfigAccess(t *testing.T) { i.Set(MaxUploadSize, i.GetMaxUploadSize()) i.Set(FunscriptOffset, i.GetFunscriptOffset()) i.Set(DefaultIdentifySettings, i.GetDefaultIdentifySettings()) + i.Set(DeleteGeneratedDefault, i.GetDeleteGeneratedDefault()) + i.Set(DeleteFileDefault, i.GetDeleteFileDefault()) + i.Set(TrustedProxies, i.GetTrustedProxies()) + i.Set(dangerousAllowPublicWithoutAuth, i.GetDangerousAllowPublicWithoutAuth()) + i.Set(SecurityTripwireAccessedFromPublicInternet, i.GetSecurityTripwireAccessedFromPublicInternet()) + i.Set(DisableDropdownCreatePerformer, i.GetDisableDropdownCreate().Performer) + i.Set(DisableDropdownCreateStudio, i.GetDisableDropdownCreate().Studio) + i.Set(DisableDropdownCreateTag, i.GetDisableDropdownCreate().Tag) + i.SetChecksumDefaultValues(i.GetVideoFileNamingAlgorithm(), i.IsCalculateMD5()) i.Set(AutostartVideoOnPlaySelected, i.GetAutostartVideoOnPlaySelected()) i.Set(ContinuePlaylistDefault, i.GetContinuePlaylistDefault()) } diff --git a/pkg/manager/config/init.go b/pkg/manager/config/init.go index 76e588ef4..3b0d00a6a 100644 --- a/pkg/manager/config/init.go +++ b/pkg/manager/config/init.go @@ -14,7 +14,10 @@ import ( "github.com/stashapp/stash/pkg/utils" ) -var once sync.Once +var ( + initOnce sync.Once + instanceOnce sync.Once +) type flagStruct struct { configFilePath string @@ -22,18 +25,29 @@ type flagStruct struct { nobrowser bool } +func GetInstance() *Instance { + instanceOnce.Do(func() { + instance = &Instance{ + main: viper.New(), + overrides: viper.New(), + } + }) + return instance +} + func Initialize() (*Instance, error) { var err error - once.Do(func() { + initOnce.Do(func() { flags := initFlags() - instance = &Instance{ - cpuProfilePath: flags.cpuProfilePath, - } + overrides := makeOverrideConfig() - if err = initConfig(flags); err != nil { + _ = GetInstance() + instance.overrides = overrides + instance.cpuProfilePath = flags.cpuProfilePath + + if err = initConfig(instance, flags); err != nil { return } - initEnvs() if instance.isNewSystem { if instance.Validate() == nil { @@ -43,20 +57,23 @@ func Initialize() (*Instance, error) { } if !instance.isNewSystem { - setExistingSystemDefaults(instance) - err = instance.SetInitialConfig() + err = instance.setExistingSystemDefaults() + if err == nil { + err = instance.SetInitialConfig() + } } }) return instance, err } -func initConfig(flags flagStruct) error { +func initConfig(instance *Instance, flags flagStruct) error { + v := instance.main // The config file is called config. Leave off the file extension. - viper.SetConfigName("config") + v.SetConfigName("config") - viper.AddConfigPath(".") // Look for config in the working directory - viper.AddConfigPath("$HOME/.stash") // Look for the config in the home directory + v.AddConfigPath(".") // Look for config in the working directory + v.AddConfigPath("$HOME/.stash") // Look for the config in the home directory configFile := "" envConfigFile := os.Getenv("STASH_CONFIG_FILE") @@ -68,7 +85,7 @@ func initConfig(flags flagStruct) error { } if configFile != "" { - viper.SetConfigFile(configFile) + v.SetConfigFile(configFile) // if file does not exist, assume it is a new system if exists, _ := utils.FileExists(configFile); !exists { @@ -86,7 +103,7 @@ func initConfig(flags flagStruct) error { } } - err := viper.ReadInConfig() // Find and read the config file + err := v.ReadInConfig() // Find and read the config file // if not found, assume its a new system var notFoundErr viper.ConfigFileNotFoundError if errors.As(err, ¬FoundErr) { @@ -99,28 +116,6 @@ func initConfig(flags flagStruct) error { return nil } -// setExistingSystemDefaults sets config options that are new and unset in an existing install, -// but should have a separate default than for brand-new systems, to maintain behavior. -func setExistingSystemDefaults(instance *Instance) { - if !instance.isNewSystem { - configDirtied := false - - // Existing systems as of the introduction of auto-browser open should retain existing - // behavior and not start the browser automatically. - if !viper.InConfig("nobrowser") { - configDirtied = true - viper.Set("nobrowser", "true") - } - - if configDirtied { - err := viper.WriteConfig() - if err != nil { - logger.Errorf("Could not save existing system defaults: %s", err.Error()) - } - } - } -} - func initFlags() flagStruct { flags := flagStruct{} @@ -131,30 +126,35 @@ func initFlags() flagStruct { pflag.BoolVar(&flags.nobrowser, "nobrowser", false, "Don't open a browser window after launch") pflag.Parse() - if err := viper.BindPFlags(pflag.CommandLine); err != nil { - logger.Infof("failed to bind flags: %s", err.Error()) - } return flags } -func initEnvs() { - viper.SetEnvPrefix("stash") // will be uppercased automatically - bindEnv("host") // STASH_HOST - bindEnv("port") // STASH_PORT - bindEnv("external_host") // STASH_EXTERNAL_HOST - bindEnv("generated") // STASH_GENERATED - bindEnv("metadata") // STASH_METADATA - bindEnv("cache") // STASH_CACHE - - // only set stash config flag if not already set - if instance.GetStashPaths() == nil { - bindEnv("stash") // STASH_STASH - } +func initEnvs(viper *viper.Viper) { + viper.SetEnvPrefix("stash") // will be uppercased automatically + bindEnv(viper, "host") // STASH_HOST + bindEnv(viper, "port") // STASH_PORT + bindEnv(viper, "external_host") // STASH_EXTERNAL_HOST + bindEnv(viper, "generated") // STASH_GENERATED + bindEnv(viper, "metadata") // STASH_METADATA + bindEnv(viper, "cache") // STASH_CACHE + bindEnv(viper, "stash") // STASH_STASH } -func bindEnv(key string) { +func bindEnv(viper *viper.Viper, key string) { if err := viper.BindEnv(key); err != nil { panic(fmt.Sprintf("unable to set environment key (%v): %v", key, err)) } } + +func makeOverrideConfig() *viper.Viper { + viper := viper.New() + + if err := viper.BindPFlags(pflag.CommandLine); err != nil { + logger.Infof("failed to bind flags: %s", err.Error()) + } + + initEnvs(viper) + + return viper +} diff --git a/pkg/manager/manager.go b/pkg/manager/manager.go index f4acc5505..4d7a44efe 100644 --- a/pkg/manager/manager.go +++ b/pkg/manager/manager.go @@ -302,31 +302,41 @@ func setSetupDefaults(input *models.SetupInput) { func (s *singleton) Setup(ctx context.Context, input models.SetupInput) error { setSetupDefaults(&input) + c := s.Config // create the config directory if it does not exist - configDir := filepath.Dir(input.ConfigLocation) - if exists, _ := utils.DirExists(configDir); !exists { - if err := os.Mkdir(configDir, 0755); err != nil { - return fmt.Errorf("abc: %v", err) + // don't do anything if config is already set in the environment + if !config.FileEnvSet() { + configDir := filepath.Dir(input.ConfigLocation) + if exists, _ := utils.DirExists(configDir); !exists { + if err := os.Mkdir(configDir, 0755); err != nil { + return fmt.Errorf("error creating config directory: %v", err) + } } + + if err := utils.Touch(input.ConfigLocation); err != nil { + return fmt.Errorf("error creating config file: %v", err) + } + + s.Config.SetConfigFile(input.ConfigLocation) } // create the generated directory if it does not exist - if exists, _ := utils.DirExists(input.GeneratedLocation); !exists { - if err := os.Mkdir(input.GeneratedLocation, 0755); err != nil { - return fmt.Errorf("error creating generated directory: %v", err) + if !c.HasOverride(config.Generated) { + if exists, _ := utils.DirExists(input.GeneratedLocation); !exists { + if err := os.Mkdir(input.GeneratedLocation, 0755); err != nil { + return fmt.Errorf("error creating generated directory: %v", err) + } } - } - if err := utils.Touch(input.ConfigLocation); err != nil { - return fmt.Errorf("error creating config file: %v", err) + s.Config.Set(config.Generated, input.GeneratedLocation) } - s.Config.SetConfigFile(input.ConfigLocation) - // set the configuration - s.Config.Set(config.Generated, input.GeneratedLocation) - s.Config.Set(config.Database, input.DatabaseFile) + if !c.HasOverride(config.Database) { + s.Config.Set(config.Database, input.DatabaseFile) + } + s.Config.Set(config.Stash, input.Stashes) if err := s.Config.Write(); err != nil { return fmt.Errorf("error writing configuration file: %v", err) diff --git a/ui/v2.5/src/components/Setup/Setup.tsx b/ui/v2.5/src/components/Setup/Setup.tsx index 4cc677d48..cc0932828 100644 --- a/ui/v2.5/src/components/Setup/Setup.tsx +++ b/ui/v2.5/src/components/Setup/Setup.tsx @@ -1,4 +1,4 @@ -import React, { useEffect, useState } from "react"; +import React, { useEffect, useState, useContext } from "react"; import { FormattedMessage, useIntl } from "react-intl"; import { Alert, @@ -11,11 +11,16 @@ import { import * as GQL from "src/core/generated-graphql"; import { mutateSetup, useSystemStatus } from "src/core/StashService"; import { Link } from "react-router-dom"; +import { ConfigurationContext } from "src/hooks/Config"; import StashConfiguration from "../Settings/StashConfiguration"; import { Icon, LoadingIndicator } from "../Shared"; import { FolderSelectDialog } from "../Shared/FolderSelect/FolderSelectDialog"; export const Setup: React.FC = () => { + const { configuration, loading: configLoading } = useContext( + ConfigurationContext + ); + const [step, setStep] = useState(0); const [configLocation, setConfigLocation] = useState(""); const [stashes, setStashes] = useState([]); @@ -36,6 +41,23 @@ export const Setup: React.FC = () => { } }, [systemStatus]); + useEffect(() => { + if (configuration) { + const { stashes: configStashes, generatedPath } = configuration.general; + if (configStashes.length > 0) { + setStashes( + configStashes.map((s) => { + const { __typename, ...withoutTypename } = s; + return withoutTypename; + }) + ); + } + if (generatedPath) { + setGeneratedLocation(generatedPath); + } + } + }, [configuration]); + const discordLink = ( Discord @@ -179,6 +201,47 @@ export const Setup: React.FC = () => { return ; } + function maybeRenderGenerated() { + if (!configuration?.general.generatedPath) { + return ( + +

+ +

+

+ {chunks}, + }} + /> +

+ + ) => + setGeneratedLocation(e.currentTarget.value) + } + /> + + + + +
+ ); + } + } + function renderSetPaths() { return ( <> @@ -228,41 +291,7 @@ export const Setup: React.FC = () => { } /> - -

- -

-

- {chunks}, - }} - /> -

- - ) => - setGeneratedLocation(e.currentTarget.value) - } - /> - - - - -
+ {maybeRenderGenerated()}
@@ -524,7 +553,7 @@ export const Setup: React.FC = () => { } // only display setup wizard if system is not setup - if (statusLoading) { + if (statusLoading || configLoading) { return ; }