From 2a9ed0abca8ee2cda4d1f5291c853f3a7091611c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A1rk=20S=C3=A1gi-Kaz=C3=A1r?= Date: Thu, 3 Dec 2020 16:23:33 +0100 Subject: [PATCH] Accept a config file path instead of a config name Config name is a Viper concept used for searching a specific file in various paths with various extensions. Making it configurable is usually not a useful feature as users mostly want to define a full or relative path to a config file. This change replaces config name with config file. --- cmd/install_windows.go | 2 +- cmd/portable.go | 1 - cmd/root.go | 27 +++++---------- cmd/startsubsys.go | 27 +-------------- config/config.go | 23 ++++++++----- config/config_test.go | 75 +++++++++++++++++++++++++++++------------- 6 files changed, 77 insertions(+), 78 deletions(-) diff --git a/cmd/install_windows.go b/cmd/install_windows.go index c4353f50..4a1a4606 100644 --- a/cmd/install_windows.go +++ b/cmd/install_windows.go @@ -64,7 +64,7 @@ func getCustomServeFlags() []string { result = append(result, "--"+configDirFlag) result = append(result, configDir) } - if configFile != defaultConfigName { + if configFile != "" { result = append(result, "--"+configFileFlag) result = append(result, configFile) } diff --git a/cmd/portable.go b/cmd/portable.go index bc74c297..c90b44d9 100644 --- a/cmd/portable.go +++ b/cmd/portable.go @@ -124,7 +124,6 @@ Please take a look at the usage below to customize the serving parameters`, } service := service.Service{ ConfigDir: filepath.Clean(defaultConfigDir), - ConfigFile: defaultConfigName, LogFilePath: portableLogFile, LogMaxSize: defaultLogMaxSize, LogMaxBackups: defaultLogMaxBackup, diff --git a/cmd/root.go b/cmd/root.go index f902b07d..81f26be4 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -8,7 +8,6 @@ import ( "github.com/spf13/cobra" "github.com/spf13/viper" - "github.com/drakkan/sftpgo/config" "github.com/drakkan/sftpgo/version" ) @@ -16,7 +15,6 @@ const ( configDirFlag = "config-dir" configDirKey = "config_dir" configFileFlag = "config-file" - configFileKey = "config_file" logFilePathFlag = "log-file-path" logFilePathKey = "log_file_path" logMaxSizeFlag = "log-max-size" @@ -40,7 +38,6 @@ const ( loadDataCleanFlag = "loaddata-clean" loadDataCleanKey = "loaddata_clean" defaultConfigDir = "." - defaultConfigName = config.DefaultConfigName defaultLogFile = "sftpgo.log" defaultLogMaxSize = 10 defaultLogMaxBackup = 5 @@ -96,29 +93,21 @@ func addConfigFlags(cmd *cobra.Command) { viper.BindEnv(configDirKey, "SFTPGO_CONFIG_DIR") //nolint:errcheck // err is not nil only if the key to bind is missing cmd.Flags().StringVarP(&configDir, configDirFlag, "c", viper.GetString(configDirKey), `Location for SFTPGo config dir. This directory -should contain the "sftpgo" configuration file -or the configured config-file and it is used as -the base for files with a relative path (eg. the -private keys for the SFTP server, the SQLite +should contain the "sftpgo" configuration file. +It is used as the base for files with a relative path +(eg. the private keys for the SFTP server, the SQLite database if you use SQLite as data provider). This flag can be set using SFTPGO_CONFIG_DIR env var too.`) viper.BindPFlag(configDirKey, cmd.Flags().Lookup(configDirFlag)) //nolint:errcheck - viper.SetDefault(configFileKey, defaultConfigName) - viper.BindEnv(configFileKey, "SFTPGO_CONFIG_FILE") //nolint:errcheck - cmd.Flags().StringVarP(&configFile, configFileFlag, "f", viper.GetString(configFileKey), - `Name for SFTPGo configuration file. It must be -the name of a file stored in config-dir not the -absolute path to the configuration file. The -specified file name must have no extension we -automatically load JSON, YAML, TOML, HCL and -Java properties. Therefore if you set "sftpgo" -then "sftpgo.json", "sftpgo.yaml" and so on -are searched. + cmd.Flags().StringVar(&configFile, configFileFlag, os.Getenv("SFTPGO_CONFIG_FILE"), + `Path to SFTPGo configuration file. It must be +an absolute path to a file or a path relative to the working directory. +The specified file name must have a supported extension +(JSON, YAML, TOML or Java properties). This flag can be set using SFTPGO_CONFIG_FILE env var too.`) - viper.BindPFlag(configFileKey, cmd.Flags().Lookup(configFileFlag)) //nolint:errcheck } func addServeFlags(cmd *cobra.Command) { diff --git a/cmd/startsubsys.go b/cmd/startsubsys.go index 5f9c82bc..cb446cc4 100644 --- a/cmd/startsubsys.go +++ b/cmd/startsubsys.go @@ -145,33 +145,8 @@ $ journalctl -o verbose -f To see full logs. If not set, the logs will be sent to the standard error`) - viper.SetDefault(configDirKey, defaultConfigDir) - viper.BindEnv(configDirKey, "SFTPGO_CONFIG_DIR") //nolint:errcheck // err is not nil only if the key to bind is missing - subsystemCmd.Flags().StringVarP(&configDir, configDirFlag, "c", viper.GetString(configDirKey), - `Location for SFTPGo config dir. This directory -should contain the "sftpgo" configuration file -or the configured config-file and it is used as -the base for files with a relative path (eg. the -private keys for the SFTP server, the SQLite -database if you use SQLite as data provider). -This flag can be set using SFTPGO_CONFIG_DIR -env var too.`) - viper.BindPFlag(configDirKey, subsystemCmd.Flags().Lookup(configDirFlag)) //nolint:errcheck - viper.SetDefault(configFileKey, defaultConfigName) - viper.BindEnv(configFileKey, "SFTPGO_CONFIG_FILE") //nolint:errcheck - subsystemCmd.Flags().StringVarP(&configFile, configFileFlag, "f", viper.GetString(configFileKey), - `Name for SFTPGo configuration file. It must be -the name of a file stored in config-dir not the -absolute path to the configuration file. The -specified file name must have no extension we -automatically load JSON, YAML, TOML, HCL and -Java properties. Therefore if you set "sftpgo" -then "sftpgo.json", "sftpgo.yaml" and so on -are searched. -This flag can be set using SFTPGO_CONFIG_FILE -env var too.`) - viper.BindPFlag(configFileKey, subsystemCmd.Flags().Lookup(configFileFlag)) //nolint:errcheck + addConfigFlags(subsystemCmd) viper.SetDefault(logVerboseKey, defaultLogVerbose) viper.BindEnv(logVerboseKey, "SFTPGO_LOG_VERBOSE") //nolint:errcheck diff --git a/config/config.go b/config/config.go index 7f2eab05..fceddf40 100644 --- a/config/config.go +++ b/config/config.go @@ -22,10 +22,10 @@ import ( const ( logSender = "config" - // DefaultConfigName defines the name for the default config file. + // configName defines the name for the default config file. // This is the file name without extension, we use viper and so we // support all the config files format supported by viper - DefaultConfigName = "sftpgo" + configName = "sftpgo" // ConfigEnvPrefix defines a prefix that ENVIRONMENT variables will use configEnvPrefix = "sftpgo" ) @@ -48,6 +48,13 @@ type globalConfig struct { } func init() { + Init() +} + +// Init initializes the global configuration. +// It is not supposed to be called outside of this package. +// It is exported to minimize refactoring efforts. Will eventually disappear. +func Init() { // create a default configuration to use if no config file is provided globalConf = globalConfig{ Common: common.Configuration{ @@ -177,7 +184,7 @@ func init() { viper.SetEnvPrefix(configEnvPrefix) replacer := strings.NewReplacer(".", "__") viper.SetEnvKeyReplacer(replacer) - viper.SetConfigName(DefaultConfigName) + viper.SetConfigName(configName) setViperDefaults() viper.AutomaticEnv() viper.AllowEmptyEnv(true) @@ -233,12 +240,12 @@ func SetHTTPDConfig(config httpd.Conf) { globalConf.HTTPDConfig = config } -//GetProviderConf returns the configuration for the data provider +// GetProviderConf returns the configuration for the data provider func GetProviderConf() dataprovider.Config { return globalConf.ProviderConf } -//SetProviderConf sets the configuration for the data provider +// SetProviderConf sets the configuration for the data provider func SetProviderConf(config dataprovider.Config) { globalConf.ProviderConf = config } @@ -283,13 +290,13 @@ func getRedactedGlobalConf() globalConfig { // configDir will be added to the configuration search paths. // The search path contains by default the current directory and on linux it contains // $HOME/.config/sftpgo and /etc/sftpgo too. -// configName is the name of the configuration to search without extension -func LoadConfig(configDir, configName string) error { +// configFile is an absolute or relative path (to the working directory) to the configuration file. +func LoadConfig(configDir, configFile string) error { var err error viper.AddConfigPath(configDir) setViperAdditionalConfigPaths() viper.AddConfigPath(".") - viper.SetConfigName(configName) + viper.SetConfigFile(configFile) if err = viper.ReadInConfig(); err != nil { logger.Warn(logSender, "", "error loading configuration file: %v", err) logger.WarnToConsole("error loading configuration file: %v", err) diff --git a/config/config_test.go b/config/config_test.go index cefd23df..03a203b8 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -8,6 +8,7 @@ import ( "strings" "testing" + "github.com/spf13/viper" "github.com/stretchr/testify/assert" "github.com/drakkan/sftpgo/common" @@ -22,12 +23,18 @@ import ( const ( tempConfigName = "temp" - configName = "sftpgo" ) +func reset() { + viper.Reset() + config.Init() +} + func TestLoadConfigTest(t *testing.T) { + reset() + configDir := ".." - err := config.LoadConfig(configDir, configName) + err := config.LoadConfig(configDir, "") assert.NoError(t, err) assert.NotEqual(t, httpd.Conf{}, config.GetHTTPConfig()) assert.NotEqual(t, dataprovider.Config{}, config.GetProviderConf()) @@ -35,25 +42,27 @@ func TestLoadConfigTest(t *testing.T) { assert.NotEqual(t, httpclient.Config{}, config.GetHTTPConfig()) confName := tempConfigName + ".json" configFilePath := filepath.Join(configDir, confName) - err = config.LoadConfig(configDir, tempConfigName) + err = config.LoadConfig(configDir, configFilePath) assert.NoError(t, err) err = ioutil.WriteFile(configFilePath, []byte("{invalid json}"), os.ModePerm) assert.NoError(t, err) - err = config.LoadConfig(configDir, tempConfigName) + err = config.LoadConfig(configDir, configFilePath) assert.NoError(t, err) err = ioutil.WriteFile(configFilePath, []byte("{\"sftpd\": {\"bind_port\": \"a\"}}"), os.ModePerm) assert.NoError(t, err) - err = config.LoadConfig(configDir, tempConfigName) + err = config.LoadConfig(configDir, configFilePath) assert.NotNil(t, err) err = os.Remove(configFilePath) assert.NoError(t, err) } func TestEmptyBanner(t *testing.T) { + reset() + configDir := ".." confName := tempConfigName + ".json" configFilePath := filepath.Join(configDir, confName) - err := config.LoadConfig(configDir, configName) + err := config.LoadConfig(configDir, "") assert.NoError(t, err) sftpdConf := config.GetSFTPDConfig() sftpdConf.Banner = " " @@ -62,7 +71,7 @@ func TestEmptyBanner(t *testing.T) { jsonConf, _ := json.Marshal(c) err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm) assert.NoError(t, err) - err = config.LoadConfig(configDir, tempConfigName) + err = config.LoadConfig(configDir, configFilePath) assert.NoError(t, err) sftpdConf = config.GetSFTPDConfig() assert.NotEmpty(t, strings.TrimSpace(sftpdConf.Banner)) @@ -76,7 +85,7 @@ func TestEmptyBanner(t *testing.T) { jsonConf, _ = json.Marshal(c1) err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm) assert.NoError(t, err) - err = config.LoadConfig(configDir, tempConfigName) + err = config.LoadConfig(configDir, configFilePath) assert.NoError(t, err) ftpdConf = config.GetFTPDConfig() assert.NotEmpty(t, strings.TrimSpace(ftpdConf.Banner)) @@ -85,10 +94,12 @@ func TestEmptyBanner(t *testing.T) { } func TestInvalidUploadMode(t *testing.T) { + reset() + configDir := ".." confName := tempConfigName + ".json" configFilePath := filepath.Join(configDir, confName) - err := config.LoadConfig(configDir, configName) + err := config.LoadConfig(configDir, "") assert.NoError(t, err) commonConf := config.GetCommonConfig() commonConf.UploadMode = 10 @@ -98,17 +109,19 @@ func TestInvalidUploadMode(t *testing.T) { assert.NoError(t, err) err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm) assert.NoError(t, err) - err = config.LoadConfig(configDir, tempConfigName) + err = config.LoadConfig(configDir, configFilePath) assert.NotNil(t, err) err = os.Remove(configFilePath) assert.NoError(t, err) } func TestInvalidExternalAuthScope(t *testing.T) { + reset() + configDir := ".." confName := tempConfigName + ".json" configFilePath := filepath.Join(configDir, confName) - err := config.LoadConfig(configDir, configName) + err := config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() providerConf.ExternalAuthScope = 10 @@ -118,17 +131,19 @@ func TestInvalidExternalAuthScope(t *testing.T) { assert.NoError(t, err) err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm) assert.NoError(t, err) - err = config.LoadConfig(configDir, tempConfigName) + err = config.LoadConfig(configDir, configFilePath) assert.NotNil(t, err) err = os.Remove(configFilePath) assert.NoError(t, err) } func TestInvalidCredentialsPath(t *testing.T) { + reset() + configDir := ".." confName := tempConfigName + ".json" configFilePath := filepath.Join(configDir, confName) - err := config.LoadConfig(configDir, configName) + err := config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() providerConf.CredentialsPath = "" @@ -138,17 +153,19 @@ func TestInvalidCredentialsPath(t *testing.T) { assert.NoError(t, err) err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm) assert.NoError(t, err) - err = config.LoadConfig(configDir, tempConfigName) + err = config.LoadConfig(configDir, configFilePath) assert.NotNil(t, err) err = os.Remove(configFilePath) assert.NoError(t, err) } func TestInvalidProxyProtocol(t *testing.T) { + reset() + configDir := ".." confName := tempConfigName + ".json" configFilePath := filepath.Join(configDir, confName) - err := config.LoadConfig(configDir, configName) + err := config.LoadConfig(configDir, "") assert.NoError(t, err) commonConf := config.GetCommonConfig() commonConf.ProxyProtocol = 10 @@ -158,17 +175,19 @@ func TestInvalidProxyProtocol(t *testing.T) { assert.NoError(t, err) err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm) assert.NoError(t, err) - err = config.LoadConfig(configDir, tempConfigName) + err = config.LoadConfig(configDir, configFilePath) assert.NotNil(t, err) err = os.Remove(configFilePath) assert.NoError(t, err) } func TestInvalidUsersBaseDir(t *testing.T) { + reset() + configDir := ".." confName := tempConfigName + ".json" configFilePath := filepath.Join(configDir, confName) - err := config.LoadConfig(configDir, configName) + err := config.LoadConfig(configDir, "") assert.NoError(t, err) providerConf := config.GetProviderConf() providerConf.UsersBaseDir = "." @@ -178,17 +197,19 @@ func TestInvalidUsersBaseDir(t *testing.T) { assert.NoError(t, err) err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm) assert.NoError(t, err) - err = config.LoadConfig(configDir, tempConfigName) + err = config.LoadConfig(configDir, configFilePath) assert.NotNil(t, err) err = os.Remove(configFilePath) assert.NoError(t, err) } func TestCommonParamsCompatibility(t *testing.T) { + reset() + configDir := ".." confName := tempConfigName + ".json" configFilePath := filepath.Join(configDir, confName) - err := config.LoadConfig(configDir, configName) + err := config.LoadConfig(configDir, "") assert.NoError(t, err) sftpdConf := config.GetSFTPDConfig() sftpdConf.IdleTimeout = 21 //nolint:staticcheck @@ -204,7 +225,7 @@ func TestCommonParamsCompatibility(t *testing.T) { assert.NoError(t, err) err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm) assert.NoError(t, err) - err = config.LoadConfig(configDir, tempConfigName) + err = config.LoadConfig(configDir, configFilePath) assert.NoError(t, err) commonConf := config.GetCommonConfig() assert.Equal(t, 21, commonConf.IdleTimeout) @@ -220,10 +241,12 @@ func TestCommonParamsCompatibility(t *testing.T) { } func TestHostKeyCompatibility(t *testing.T) { + reset() + configDir := ".." confName := tempConfigName + ".json" configFilePath := filepath.Join(configDir, confName) - err := config.LoadConfig(configDir, configName) + err := config.LoadConfig(configDir, "") assert.NoError(t, err) sftpdConf := config.GetSFTPDConfig() sftpdConf.Keys = []sftpd.Key{ //nolint:staticcheck @@ -240,7 +263,7 @@ func TestHostKeyCompatibility(t *testing.T) { assert.NoError(t, err) err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm) assert.NoError(t, err) - err = config.LoadConfig(configDir, tempConfigName) + err = config.LoadConfig(configDir, configFilePath) assert.NoError(t, err) sftpdConf = config.GetSFTPDConfig() assert.Equal(t, 2, len(sftpdConf.HostKeys)) @@ -251,6 +274,8 @@ func TestHostKeyCompatibility(t *testing.T) { } func TestSetGetConfig(t *testing.T) { + reset() + sftpdConf := config.GetSFTPDConfig() sftpdConf.MaxAuthTries = 10 config.SetSFTPDConfig(sftpdConf) @@ -288,8 +313,10 @@ func TestSetGetConfig(t *testing.T) { } func TestServiceToStart(t *testing.T) { + reset() + configDir := ".." - err := config.LoadConfig(configDir, configName) + err := config.LoadConfig(configDir, "") assert.NoError(t, err) assert.True(t, config.HasServicesToStart()) sftpdConf := config.GetSFTPDConfig() @@ -315,6 +342,8 @@ func TestServiceToStart(t *testing.T) { } func TestConfigFromEnv(t *testing.T) { + reset() + os.Setenv("SFTPGO_SFTPD__BIND_ADDRESS", "127.0.0.1") os.Setenv("SFTPGO_DATA_PROVIDER__PASSWORD_HASHING__ARGON2_OPTIONS__ITERATIONS", "41") os.Setenv("SFTPGO_DATA_PROVIDER__POOL_SIZE", "10")