diff --git a/config/config_test.go b/config/config_test.go index 11c30ff7..c8231f7a 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -13,239 +13,207 @@ import ( "github.com/drakkan/sftpgo/httpclient" "github.com/drakkan/sftpgo/httpd" "github.com/drakkan/sftpgo/sftpd" + "github.com/stretchr/testify/assert" ) const ( tempConfigName = "temp" + configName = "sftpgo" ) func TestLoadConfigTest(t *testing.T) { configDir := ".." - err := config.LoadConfig(configDir, "") - if err != nil { - t.Errorf("error loading config") - } - emptyHTTPDConf := httpd.Conf{} - if config.GetHTTPDConfig() == emptyHTTPDConf { - t.Errorf("error loading httpd conf") - } - emptyProviderConf := dataprovider.Config{} - if config.GetProviderConf().Driver == emptyProviderConf.Driver { - t.Errorf("error loading provider conf") - } - emptySFTPDConf := sftpd.Configuration{} - if config.GetSFTPDConfig().BindPort == emptySFTPDConf.BindPort { - t.Errorf("error loading SFTPD conf") - } - emptyHTTPConfig := httpclient.Config{} - if config.GetHTTPConfig().Timeout == emptyHTTPConfig.Timeout { - t.Errorf("error loading HTTP conf") - } + err := config.LoadConfig(configDir, configName) + assert.Nil(t, err) + assert.NotEqual(t, httpd.Conf{}, config.GetHTTPConfig()) + assert.NotEqual(t, dataprovider.Config{}, config.GetProviderConf()) + assert.NotEqual(t, sftpd.Configuration{}, config.GetSFTPDConfig()) + assert.NotEqual(t, httpclient.Config{}, config.GetHTTPConfig()) confName := tempConfigName + ".json" configFilePath := filepath.Join(configDir, confName) err = config.LoadConfig(configDir, tempConfigName) - if err == nil { - t.Errorf("loading a non existent config file must fail") - } - ioutil.WriteFile(configFilePath, []byte("{invalid json}"), 0666) + assert.NotNil(t, err) + err = ioutil.WriteFile(configFilePath, []byte("{invalid json}"), 0666) + assert.Nil(t, err) err = config.LoadConfig(configDir, tempConfigName) - if err == nil { - t.Errorf("loading an invalid config file must fail") - } - ioutil.WriteFile(configFilePath, []byte("{\"sftpd\": {\"bind_port\": \"a\"}}"), 0666) + assert.NotNil(t, err) + err = ioutil.WriteFile(configFilePath, []byte("{\"sftpd\": {\"bind_port\": \"a\"}}"), 0666) + assert.Nil(t, err) err = config.LoadConfig(configDir, tempConfigName) - if err == nil { - t.Errorf("loading a config with an invalid bond_port must fail") - } - os.Remove(configFilePath) + assert.NotNil(t, err) + err = os.Remove(configFilePath) + assert.Nil(t, err) } func TestEmptyBanner(t *testing.T) { configDir := ".." confName := tempConfigName + ".json" configFilePath := filepath.Join(configDir, confName) - config.LoadConfig(configDir, "") + err := config.LoadConfig(configDir, configName) + assert.Nil(t, err) sftpdConf := config.GetSFTPDConfig() sftpdConf.Banner = " " c := make(map[string]sftpd.Configuration) c["sftpd"] = sftpdConf jsonConf, _ := json.Marshal(c) - err := ioutil.WriteFile(configFilePath, jsonConf, 0666) - if err != nil { - t.Errorf("error saving temporary configuration") - } - config.LoadConfig(configDir, tempConfigName) + err = ioutil.WriteFile(configFilePath, jsonConf, 0666) + assert.Nil(t, err) + err = config.LoadConfig(configDir, tempConfigName) + assert.Nil(t, err) sftpdConf = config.GetSFTPDConfig() - if strings.TrimSpace(sftpdConf.Banner) == "" { - t.Errorf("SFTPD banner cannot be empty") - } - os.Remove(configFilePath) + assert.NotEmpty(t, strings.TrimSpace(sftpdConf.Banner)) + err = os.Remove(configFilePath) + assert.Nil(t, err) } func TestInvalidUploadMode(t *testing.T) { configDir := ".." confName := tempConfigName + ".json" configFilePath := filepath.Join(configDir, confName) - config.LoadConfig(configDir, "") + err := config.LoadConfig(configDir, configName) + assert.Nil(t, err) sftpdConf := config.GetSFTPDConfig() sftpdConf.UploadMode = 10 c := make(map[string]sftpd.Configuration) c["sftpd"] = sftpdConf - jsonConf, _ := json.Marshal(c) - err := ioutil.WriteFile(configFilePath, jsonConf, 0666) - if err != nil { - t.Errorf("error saving temporary configuration") - } + jsonConf, err := json.Marshal(c) + assert.Nil(t, err) + err = ioutil.WriteFile(configFilePath, jsonConf, 0666) + assert.Nil(t, err) err = config.LoadConfig(configDir, tempConfigName) - if err == nil { - t.Errorf("Loading configuration with invalid upload_mode must fail") - } - os.Remove(configFilePath) + assert.NotNil(t, err) + err = os.Remove(configFilePath) + assert.Nil(t, err) } func TestInvalidExternalAuthScope(t *testing.T) { configDir := ".." confName := tempConfigName + ".json" configFilePath := filepath.Join(configDir, confName) - config.LoadConfig(configDir, "") + err := config.LoadConfig(configDir, configName) + assert.Nil(t, err) providerConf := config.GetProviderConf() providerConf.ExternalAuthScope = 10 c := make(map[string]dataprovider.Config) c["data_provider"] = providerConf - jsonConf, _ := json.Marshal(c) - err := ioutil.WriteFile(configFilePath, jsonConf, 0666) - if err != nil { - t.Errorf("error saving temporary configuration") - } + jsonConf, err := json.Marshal(c) + assert.Nil(t, err) + err = ioutil.WriteFile(configFilePath, jsonConf, 0666) + assert.Nil(t, err) err = config.LoadConfig(configDir, tempConfigName) - if err == nil { - t.Errorf("Loading configuration with invalid external_auth_scope must fail") - } - os.Remove(configFilePath) + assert.NotNil(t, err) + err = os.Remove(configFilePath) + assert.Nil(t, err) } func TestInvalidCredentialsPath(t *testing.T) { configDir := ".." confName := tempConfigName + ".json" configFilePath := filepath.Join(configDir, confName) - config.LoadConfig(configDir, "") + err := config.LoadConfig(configDir, configName) + assert.Nil(t, err) providerConf := config.GetProviderConf() providerConf.CredentialsPath = "" c := make(map[string]dataprovider.Config) c["data_provider"] = providerConf - jsonConf, _ := json.Marshal(c) - err := ioutil.WriteFile(configFilePath, jsonConf, 0666) - if err != nil { - t.Errorf("error saving temporary configuration") - } + jsonConf, err := json.Marshal(c) + assert.Nil(t, err) + err = ioutil.WriteFile(configFilePath, jsonConf, 0666) + assert.Nil(t, err) err = config.LoadConfig(configDir, tempConfigName) - if err == nil { - t.Errorf("Loading configuration with credentials path must fail") - } - os.Remove(configFilePath) + assert.NotNil(t, err) + err = os.Remove(configFilePath) + assert.Nil(t, err) } func TestInvalidProxyProtocol(t *testing.T) { configDir := ".." confName := tempConfigName + ".json" configFilePath := filepath.Join(configDir, confName) - config.LoadConfig(configDir, "") + err := config.LoadConfig(configDir, configName) + assert.Nil(t, err) sftpdConf := config.GetSFTPDConfig() sftpdConf.ProxyProtocol = 10 c := make(map[string]sftpd.Configuration) c["sftpd"] = sftpdConf - jsonConf, _ := json.Marshal(c) - err := ioutil.WriteFile(configFilePath, jsonConf, 0666) - if err != nil { - t.Errorf("error saving temporary configuration") - } + jsonConf, err := json.Marshal(c) + assert.Nil(t, err) + err = ioutil.WriteFile(configFilePath, jsonConf, 0666) + assert.Nil(t, err) err = config.LoadConfig(configDir, tempConfigName) - if err == nil { - t.Errorf("Loading configuration with invalid proxy_protocol must fail") - } - os.Remove(configFilePath) + assert.NotNil(t, err) + err = os.Remove(configFilePath) + assert.Nil(t, err) } func TestInvalidUsersBaseDir(t *testing.T) { configDir := ".." confName := tempConfigName + ".json" configFilePath := filepath.Join(configDir, confName) - config.LoadConfig(configDir, "") + err := config.LoadConfig(configDir, configName) + assert.Nil(t, err) providerConf := config.GetProviderConf() providerConf.UsersBaseDir = "." c := make(map[string]dataprovider.Config) c["data_provider"] = providerConf - jsonConf, _ := json.Marshal(c) - err := ioutil.WriteFile(configFilePath, jsonConf, 0666) - if err != nil { - t.Errorf("error saving temporary configuration") - } + jsonConf, err := json.Marshal(c) + assert.Nil(t, err) + err = ioutil.WriteFile(configFilePath, jsonConf, 0666) + assert.Nil(t, err) err = config.LoadConfig(configDir, tempConfigName) - if err == nil { - t.Errorf("Loading configuration with invalid users base dir must fail") - } - os.Remove(configFilePath) + assert.NotNil(t, err) + err = os.Remove(configFilePath) + assert.Nil(t, err) } func TestHookCompatibity(t *testing.T) { configDir := ".." confName := tempConfigName + ".json" configFilePath := filepath.Join(configDir, confName) - config.LoadConfig(configDir, "") + err := config.LoadConfig(configDir, configName) + assert.Nil(t, err) providerConf := config.GetProviderConf() providerConf.ExternalAuthProgram = "ext_auth_program" providerConf.PreLoginProgram = "pre_login_program" c := make(map[string]dataprovider.Config) c["data_provider"] = providerConf - jsonConf, _ := json.Marshal(c) - err := ioutil.WriteFile(configFilePath, jsonConf, 0666) - if err != nil { - t.Errorf("error saving temporary configuration") - } - config.LoadConfig(configDir, tempConfigName) + jsonConf, err := json.Marshal(c) + assert.Nil(t, err) + err = ioutil.WriteFile(configFilePath, jsonConf, 0666) + assert.Nil(t, err) + err = config.LoadConfig(configDir, tempConfigName) providerConf = config.GetProviderConf() - if providerConf.ExternalAuthHook != "ext_auth_program" { - t.Error("unexpected external auth hook") - } - if providerConf.PreLoginHook != "pre_login_program" { - t.Error("unexpected pre-login hook") - } - os.Remove(configFilePath) - + assert.Equal(t, "ext_auth_program", providerConf.ExternalAuthHook) + assert.Equal(t, "pre_login_program", providerConf.PreLoginHook) + err = os.Remove(configFilePath) + assert.Nil(t, err) sftpdConf := config.GetSFTPDConfig() sftpdConf.KeyboardInteractiveProgram = "key_int_program" cnf := make(map[string]sftpd.Configuration) cnf["sftpd"] = sftpdConf - jsonConf, _ = json.Marshal(cnf) + jsonConf, err = json.Marshal(cnf) + assert.Nil(t, err) err = ioutil.WriteFile(configFilePath, jsonConf, 0666) - if err != nil { - t.Errorf("error saving temporary configuration") - } - config.LoadConfig(configDir, tempConfigName) + assert.Nil(t, err) + err = config.LoadConfig(configDir, tempConfigName) + assert.Nil(t, err) sftpdConf = config.GetSFTPDConfig() - if sftpdConf.KeyboardInteractiveHook != "key_int_program" { - t.Error("unexpected keyboard interactive hook") - } - os.Remove(configFilePath) + assert.Equal(t, "key_int_program", sftpdConf.KeyboardInteractiveHook) + err = os.Remove(configFilePath) + assert.Nil(t, err) } func TestSetGetConfig(t *testing.T) { sftpdConf := config.GetSFTPDConfig() sftpdConf.IdleTimeout = 3 config.SetSFTPDConfig(sftpdConf) - if config.GetSFTPDConfig().IdleTimeout != sftpdConf.IdleTimeout { - t.Errorf("set sftpd conf failed") - } + assert.Equal(t, sftpdConf.IdleTimeout, config.GetSFTPDConfig().IdleTimeout) dataProviderConf := config.GetProviderConf() dataProviderConf.Host = "test host" config.SetProviderConf(dataProviderConf) - if config.GetProviderConf().Host != dataProviderConf.Host { - t.Errorf("set data provider conf failed") - } + assert.Equal(t, dataProviderConf.Host, config.GetProviderConf().Host) httpdConf := config.GetHTTPDConfig() httpdConf.BindAddress = "0.0.0.0" config.SetHTTPDConfig(httpdConf) - if config.GetHTTPDConfig().BindAddress != httpdConf.BindAddress { - t.Errorf("set httpd conf failed") - } + assert.Equal(t, httpdConf.BindAddress, config.GetHTTPDConfig().BindAddress) } diff --git a/go.mod b/go.mod index d93211f3..9270658f 100644 --- a/go.mod +++ b/go.mod @@ -32,6 +32,7 @@ require ( github.com/spf13/jwalterweatherman v1.1.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/viper v1.6.3 + github.com/stretchr/testify v1.5.1 go.etcd.io/bbolt v1.3.4 golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59 golang.org/x/sys v0.0.0-20200331124033-c3d80250170d diff --git a/go.sum b/go.sum index 6ee3821f..4f87b71b 100644 --- a/go.sum +++ b/go.sum @@ -59,6 +59,7 @@ github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7 github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= @@ -194,6 +195,7 @@ github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDft0ttaMvbicHlPoso= @@ -253,6 +255,7 @@ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/subosito/gotenv v1.2.0 h1:Slr1R9HxAlEKefgq5jn9U+DnETlIUa6HfgEzj0g5d7s= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw=