mirror of
https://github.com/drakkan/sftpgo.git
synced 2025-12-06 14:20:55 +03:00
Accept a config file path instead of a config name
Config name is a Viper concept used for searching a specific file in various paths with various extensions. Making it configurable is usually not a useful feature as users mostly want to define a full or relative path to a config file. This change replaces config name with config file.
This commit is contained in:
@@ -22,10 +22,10 @@ import (
|
||||
|
||||
const (
|
||||
logSender = "config"
|
||||
// DefaultConfigName defines the name for the default config file.
|
||||
// configName defines the name for the default config file.
|
||||
// This is the file name without extension, we use viper and so we
|
||||
// support all the config files format supported by viper
|
||||
DefaultConfigName = "sftpgo"
|
||||
configName = "sftpgo"
|
||||
// ConfigEnvPrefix defines a prefix that ENVIRONMENT variables will use
|
||||
configEnvPrefix = "sftpgo"
|
||||
)
|
||||
@@ -48,6 +48,13 @@ type globalConfig struct {
|
||||
}
|
||||
|
||||
func init() {
|
||||
Init()
|
||||
}
|
||||
|
||||
// Init initializes the global configuration.
|
||||
// It is not supposed to be called outside of this package.
|
||||
// It is exported to minimize refactoring efforts. Will eventually disappear.
|
||||
func Init() {
|
||||
// create a default configuration to use if no config file is provided
|
||||
globalConf = globalConfig{
|
||||
Common: common.Configuration{
|
||||
@@ -177,7 +184,7 @@ func init() {
|
||||
viper.SetEnvPrefix(configEnvPrefix)
|
||||
replacer := strings.NewReplacer(".", "__")
|
||||
viper.SetEnvKeyReplacer(replacer)
|
||||
viper.SetConfigName(DefaultConfigName)
|
||||
viper.SetConfigName(configName)
|
||||
setViperDefaults()
|
||||
viper.AutomaticEnv()
|
||||
viper.AllowEmptyEnv(true)
|
||||
@@ -233,12 +240,12 @@ func SetHTTPDConfig(config httpd.Conf) {
|
||||
globalConf.HTTPDConfig = config
|
||||
}
|
||||
|
||||
//GetProviderConf returns the configuration for the data provider
|
||||
// GetProviderConf returns the configuration for the data provider
|
||||
func GetProviderConf() dataprovider.Config {
|
||||
return globalConf.ProviderConf
|
||||
}
|
||||
|
||||
//SetProviderConf sets the configuration for the data provider
|
||||
// SetProviderConf sets the configuration for the data provider
|
||||
func SetProviderConf(config dataprovider.Config) {
|
||||
globalConf.ProviderConf = config
|
||||
}
|
||||
@@ -283,13 +290,13 @@ func getRedactedGlobalConf() globalConfig {
|
||||
// configDir will be added to the configuration search paths.
|
||||
// The search path contains by default the current directory and on linux it contains
|
||||
// $HOME/.config/sftpgo and /etc/sftpgo too.
|
||||
// configName is the name of the configuration to search without extension
|
||||
func LoadConfig(configDir, configName string) error {
|
||||
// configFile is an absolute or relative path (to the working directory) to the configuration file.
|
||||
func LoadConfig(configDir, configFile string) error {
|
||||
var err error
|
||||
viper.AddConfigPath(configDir)
|
||||
setViperAdditionalConfigPaths()
|
||||
viper.AddConfigPath(".")
|
||||
viper.SetConfigName(configName)
|
||||
viper.SetConfigFile(configFile)
|
||||
if err = viper.ReadInConfig(); err != nil {
|
||||
logger.Warn(logSender, "", "error loading configuration file: %v", err)
|
||||
logger.WarnToConsole("error loading configuration file: %v", err)
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/drakkan/sftpgo/common"
|
||||
@@ -22,12 +23,18 @@ import (
|
||||
|
||||
const (
|
||||
tempConfigName = "temp"
|
||||
configName = "sftpgo"
|
||||
)
|
||||
|
||||
func reset() {
|
||||
viper.Reset()
|
||||
config.Init()
|
||||
}
|
||||
|
||||
func TestLoadConfigTest(t *testing.T) {
|
||||
reset()
|
||||
|
||||
configDir := ".."
|
||||
err := config.LoadConfig(configDir, configName)
|
||||
err := config.LoadConfig(configDir, "")
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, httpd.Conf{}, config.GetHTTPConfig())
|
||||
assert.NotEqual(t, dataprovider.Config{}, config.GetProviderConf())
|
||||
@@ -35,25 +42,27 @@ func TestLoadConfigTest(t *testing.T) {
|
||||
assert.NotEqual(t, httpclient.Config{}, config.GetHTTPConfig())
|
||||
confName := tempConfigName + ".json"
|
||||
configFilePath := filepath.Join(configDir, confName)
|
||||
err = config.LoadConfig(configDir, tempConfigName)
|
||||
err = config.LoadConfig(configDir, configFilePath)
|
||||
assert.NoError(t, err)
|
||||
err = ioutil.WriteFile(configFilePath, []byte("{invalid json}"), os.ModePerm)
|
||||
assert.NoError(t, err)
|
||||
err = config.LoadConfig(configDir, tempConfigName)
|
||||
err = config.LoadConfig(configDir, configFilePath)
|
||||
assert.NoError(t, err)
|
||||
err = ioutil.WriteFile(configFilePath, []byte("{\"sftpd\": {\"bind_port\": \"a\"}}"), os.ModePerm)
|
||||
assert.NoError(t, err)
|
||||
err = config.LoadConfig(configDir, tempConfigName)
|
||||
err = config.LoadConfig(configDir, configFilePath)
|
||||
assert.NotNil(t, err)
|
||||
err = os.Remove(configFilePath)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestEmptyBanner(t *testing.T) {
|
||||
reset()
|
||||
|
||||
configDir := ".."
|
||||
confName := tempConfigName + ".json"
|
||||
configFilePath := filepath.Join(configDir, confName)
|
||||
err := config.LoadConfig(configDir, configName)
|
||||
err := config.LoadConfig(configDir, "")
|
||||
assert.NoError(t, err)
|
||||
sftpdConf := config.GetSFTPDConfig()
|
||||
sftpdConf.Banner = " "
|
||||
@@ -62,7 +71,7 @@ func TestEmptyBanner(t *testing.T) {
|
||||
jsonConf, _ := json.Marshal(c)
|
||||
err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm)
|
||||
assert.NoError(t, err)
|
||||
err = config.LoadConfig(configDir, tempConfigName)
|
||||
err = config.LoadConfig(configDir, configFilePath)
|
||||
assert.NoError(t, err)
|
||||
sftpdConf = config.GetSFTPDConfig()
|
||||
assert.NotEmpty(t, strings.TrimSpace(sftpdConf.Banner))
|
||||
@@ -76,7 +85,7 @@ func TestEmptyBanner(t *testing.T) {
|
||||
jsonConf, _ = json.Marshal(c1)
|
||||
err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm)
|
||||
assert.NoError(t, err)
|
||||
err = config.LoadConfig(configDir, tempConfigName)
|
||||
err = config.LoadConfig(configDir, configFilePath)
|
||||
assert.NoError(t, err)
|
||||
ftpdConf = config.GetFTPDConfig()
|
||||
assert.NotEmpty(t, strings.TrimSpace(ftpdConf.Banner))
|
||||
@@ -85,10 +94,12 @@ func TestEmptyBanner(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInvalidUploadMode(t *testing.T) {
|
||||
reset()
|
||||
|
||||
configDir := ".."
|
||||
confName := tempConfigName + ".json"
|
||||
configFilePath := filepath.Join(configDir, confName)
|
||||
err := config.LoadConfig(configDir, configName)
|
||||
err := config.LoadConfig(configDir, "")
|
||||
assert.NoError(t, err)
|
||||
commonConf := config.GetCommonConfig()
|
||||
commonConf.UploadMode = 10
|
||||
@@ -98,17 +109,19 @@ func TestInvalidUploadMode(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm)
|
||||
assert.NoError(t, err)
|
||||
err = config.LoadConfig(configDir, tempConfigName)
|
||||
err = config.LoadConfig(configDir, configFilePath)
|
||||
assert.NotNil(t, err)
|
||||
err = os.Remove(configFilePath)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestInvalidExternalAuthScope(t *testing.T) {
|
||||
reset()
|
||||
|
||||
configDir := ".."
|
||||
confName := tempConfigName + ".json"
|
||||
configFilePath := filepath.Join(configDir, confName)
|
||||
err := config.LoadConfig(configDir, configName)
|
||||
err := config.LoadConfig(configDir, "")
|
||||
assert.NoError(t, err)
|
||||
providerConf := config.GetProviderConf()
|
||||
providerConf.ExternalAuthScope = 10
|
||||
@@ -118,17 +131,19 @@ func TestInvalidExternalAuthScope(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm)
|
||||
assert.NoError(t, err)
|
||||
err = config.LoadConfig(configDir, tempConfigName)
|
||||
err = config.LoadConfig(configDir, configFilePath)
|
||||
assert.NotNil(t, err)
|
||||
err = os.Remove(configFilePath)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestInvalidCredentialsPath(t *testing.T) {
|
||||
reset()
|
||||
|
||||
configDir := ".."
|
||||
confName := tempConfigName + ".json"
|
||||
configFilePath := filepath.Join(configDir, confName)
|
||||
err := config.LoadConfig(configDir, configName)
|
||||
err := config.LoadConfig(configDir, "")
|
||||
assert.NoError(t, err)
|
||||
providerConf := config.GetProviderConf()
|
||||
providerConf.CredentialsPath = ""
|
||||
@@ -138,17 +153,19 @@ func TestInvalidCredentialsPath(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm)
|
||||
assert.NoError(t, err)
|
||||
err = config.LoadConfig(configDir, tempConfigName)
|
||||
err = config.LoadConfig(configDir, configFilePath)
|
||||
assert.NotNil(t, err)
|
||||
err = os.Remove(configFilePath)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestInvalidProxyProtocol(t *testing.T) {
|
||||
reset()
|
||||
|
||||
configDir := ".."
|
||||
confName := tempConfigName + ".json"
|
||||
configFilePath := filepath.Join(configDir, confName)
|
||||
err := config.LoadConfig(configDir, configName)
|
||||
err := config.LoadConfig(configDir, "")
|
||||
assert.NoError(t, err)
|
||||
commonConf := config.GetCommonConfig()
|
||||
commonConf.ProxyProtocol = 10
|
||||
@@ -158,17 +175,19 @@ func TestInvalidProxyProtocol(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm)
|
||||
assert.NoError(t, err)
|
||||
err = config.LoadConfig(configDir, tempConfigName)
|
||||
err = config.LoadConfig(configDir, configFilePath)
|
||||
assert.NotNil(t, err)
|
||||
err = os.Remove(configFilePath)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestInvalidUsersBaseDir(t *testing.T) {
|
||||
reset()
|
||||
|
||||
configDir := ".."
|
||||
confName := tempConfigName + ".json"
|
||||
configFilePath := filepath.Join(configDir, confName)
|
||||
err := config.LoadConfig(configDir, configName)
|
||||
err := config.LoadConfig(configDir, "")
|
||||
assert.NoError(t, err)
|
||||
providerConf := config.GetProviderConf()
|
||||
providerConf.UsersBaseDir = "."
|
||||
@@ -178,17 +197,19 @@ func TestInvalidUsersBaseDir(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm)
|
||||
assert.NoError(t, err)
|
||||
err = config.LoadConfig(configDir, tempConfigName)
|
||||
err = config.LoadConfig(configDir, configFilePath)
|
||||
assert.NotNil(t, err)
|
||||
err = os.Remove(configFilePath)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCommonParamsCompatibility(t *testing.T) {
|
||||
reset()
|
||||
|
||||
configDir := ".."
|
||||
confName := tempConfigName + ".json"
|
||||
configFilePath := filepath.Join(configDir, confName)
|
||||
err := config.LoadConfig(configDir, configName)
|
||||
err := config.LoadConfig(configDir, "")
|
||||
assert.NoError(t, err)
|
||||
sftpdConf := config.GetSFTPDConfig()
|
||||
sftpdConf.IdleTimeout = 21 //nolint:staticcheck
|
||||
@@ -204,7 +225,7 @@ func TestCommonParamsCompatibility(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm)
|
||||
assert.NoError(t, err)
|
||||
err = config.LoadConfig(configDir, tempConfigName)
|
||||
err = config.LoadConfig(configDir, configFilePath)
|
||||
assert.NoError(t, err)
|
||||
commonConf := config.GetCommonConfig()
|
||||
assert.Equal(t, 21, commonConf.IdleTimeout)
|
||||
@@ -220,10 +241,12 @@ func TestCommonParamsCompatibility(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHostKeyCompatibility(t *testing.T) {
|
||||
reset()
|
||||
|
||||
configDir := ".."
|
||||
confName := tempConfigName + ".json"
|
||||
configFilePath := filepath.Join(configDir, confName)
|
||||
err := config.LoadConfig(configDir, configName)
|
||||
err := config.LoadConfig(configDir, "")
|
||||
assert.NoError(t, err)
|
||||
sftpdConf := config.GetSFTPDConfig()
|
||||
sftpdConf.Keys = []sftpd.Key{ //nolint:staticcheck
|
||||
@@ -240,7 +263,7 @@ func TestHostKeyCompatibility(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
err = ioutil.WriteFile(configFilePath, jsonConf, os.ModePerm)
|
||||
assert.NoError(t, err)
|
||||
err = config.LoadConfig(configDir, tempConfigName)
|
||||
err = config.LoadConfig(configDir, configFilePath)
|
||||
assert.NoError(t, err)
|
||||
sftpdConf = config.GetSFTPDConfig()
|
||||
assert.Equal(t, 2, len(sftpdConf.HostKeys))
|
||||
@@ -251,6 +274,8 @@ func TestHostKeyCompatibility(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSetGetConfig(t *testing.T) {
|
||||
reset()
|
||||
|
||||
sftpdConf := config.GetSFTPDConfig()
|
||||
sftpdConf.MaxAuthTries = 10
|
||||
config.SetSFTPDConfig(sftpdConf)
|
||||
@@ -288,8 +313,10 @@ func TestSetGetConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServiceToStart(t *testing.T) {
|
||||
reset()
|
||||
|
||||
configDir := ".."
|
||||
err := config.LoadConfig(configDir, configName)
|
||||
err := config.LoadConfig(configDir, "")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, config.HasServicesToStart())
|
||||
sftpdConf := config.GetSFTPDConfig()
|
||||
@@ -315,6 +342,8 @@ func TestServiceToStart(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfigFromEnv(t *testing.T) {
|
||||
reset()
|
||||
|
||||
os.Setenv("SFTPGO_SFTPD__BIND_ADDRESS", "127.0.0.1")
|
||||
os.Setenv("SFTPGO_DATA_PROVIDER__PASSWORD_HASHING__ARGON2_OPTIONS__ITERATIONS", "41")
|
||||
os.Setenv("SFTPGO_DATA_PROVIDER__POOL_SIZE", "10")
|
||||
|
||||
Reference in New Issue
Block a user