diff --git a/cmd/root.go b/cmd/root.go index f53a4b97..01c6e2f4 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -149,7 +149,7 @@ func getCustomServeFlags() []string { result = append(result, "--"+configFileFlag) result = append(result, configFile) } - if logFilePath != defaultLogFile && len(logFilePath) > 0 && logFilePath != "." { + if logFilePath != defaultLogFile && utils.IsFileInputValid(logFilePath) { if !filepath.IsAbs(logFilePath) { logFilePath = filepath.Join(configDir, logFilePath) } diff --git a/cmd/start_windows.go b/cmd/start_windows.go index 74d3768e..e68f3599 100644 --- a/cmd/start_windows.go +++ b/cmd/start_windows.go @@ -5,6 +5,7 @@ import ( "path/filepath" "github.com/drakkan/sftpgo/service" + "github.com/drakkan/sftpgo/utils" "github.com/spf13/cobra" ) @@ -14,7 +15,7 @@ var ( Short: "Start SFTPGo Windows Service", Run: func(cmd *cobra.Command, args []string) { configDir = filepath.Clean(configDir) - if !filepath.IsAbs(logFilePath) && len(logFilePath) > 0 && logFilePath != "." { + if !filepath.IsAbs(logFilePath) && utils.IsFileInputValid(logFilePath) { logFilePath = filepath.Join(configDir, logFilePath) } s := service.Service{ diff --git a/config/config.go b/config/config.go index d0ab411c..5cbd5a6c 100644 --- a/config/config.go +++ b/config/config.go @@ -172,6 +172,12 @@ func LoadConfig(configDir, configName string) error { if strings.TrimSpace(globalConf.SFTPD.Banner) == "" { globalConf.SFTPD.Banner = defaultBanner } + if len(globalConf.ProviderConf.UsersBaseDir) > 0 && !utils.IsFileInputValid(globalConf.ProviderConf.UsersBaseDir) { + err = fmt.Errorf("invalid users base dir %#v will be ignored", globalConf.ProviderConf.UsersBaseDir) + globalConf.ProviderConf.UsersBaseDir = "" + logger.Warn(logSender, "", "Configuration error: %v", err) + logger.WarnToConsole("Configuration error: %v", err) + } if globalConf.SFTPD.UploadMode < 0 || globalConf.SFTPD.UploadMode > 2 { err = fmt.Errorf("invalid upload_mode 0, 1 and 2 are supported, configured: %v reset upload_mode to 0", globalConf.SFTPD.UploadMode) @@ -198,6 +204,6 @@ func LoadConfig(configDir, configName string) error { logger.Warn(logSender, "", "Configuration error: %v", err) logger.WarnToConsole("Configuration error: %v", err) } - logger.Debug(logSender, "", "config file used: '%v', config loaded: %+v", viper.ConfigFileUsed(), getRedactedGlobalConf()) + logger.Debug(logSender, "", "config file used: '%#v', config loaded: %+v", viper.ConfigFileUsed(), getRedactedGlobalConf()) return err } diff --git a/config/config_test.go b/config/config_test.go index ff154a10..cd4cd420 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -161,6 +161,27 @@ func TestInvalidProxyProtocol(t *testing.T) { os.Remove(configFilePath) } +func TestInvalidUsersBaseDir(t *testing.T) { + configDir := ".." + confName := tempConfigName + ".json" + configFilePath := filepath.Join(configDir, confName) + config.LoadConfig(configDir, "") + providerConf := config.GetProviderConf() + providerConf.UsersBaseDir = "." + c := make(map[string]dataprovider.Config) + c["data_provider"] = providerConf + jsonConf, _ := json.Marshal(c) + err := ioutil.WriteFile(configFilePath, jsonConf, 0666) + if err != nil { + t.Errorf("error saving temporary configuration") + } + err = config.LoadConfig(configDir, tempConfigName) + if err == nil { + t.Errorf("Loading configuration with invalid users base dir must fail") + } + os.Remove(configFilePath) +} + func TestSetGetConfig(t *testing.T) { sftpdConf := config.GetSFTPDConfig() sftpdConf.IdleTimeout = 3 diff --git a/dataprovider/bolt.go b/dataprovider/bolt.go index e70e6fd4..47bd61c5 100644 --- a/dataprovider/bolt.go +++ b/dataprovider/bolt.go @@ -55,7 +55,7 @@ func initializeBoltProvider(basePath string) error { var err error logSender = BoltDataProviderName dbPath := config.Name - if dbPath == "." { + if !utils.IsFileInputValid(dbPath) { return fmt.Errorf("Invalid database path: %#v", dbPath) } if !filepath.IsAbs(dbPath) { diff --git a/dataprovider/memory.go b/dataprovider/memory.go index a517d668..868be5bb 100644 --- a/dataprovider/memory.go +++ b/dataprovider/memory.go @@ -39,7 +39,7 @@ type MemoryProvider struct { func initializeMemoryProvider(basePath string) error { configFile := "" - if len(config.Name) > 0 && config.Name != "." { + if utils.IsFileInputValid(config.Name) { configFile = config.Name if !filepath.IsAbs(configFile) { configFile = filepath.Join(basePath, configFile) diff --git a/dataprovider/sqlite.go b/dataprovider/sqlite.go index 8d8043fb..82b94260 100644 --- a/dataprovider/sqlite.go +++ b/dataprovider/sqlite.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/drakkan/sftpgo/logger" + "github.com/drakkan/sftpgo/utils" ) const ( @@ -32,7 +33,7 @@ func initializeSQLiteProvider(basePath string) error { logSender = SQLiteDataProviderName if len(config.ConnectionString) == 0 { dbPath := config.Name - if dbPath == "." { + if !utils.IsFileInputValid(dbPath) { return fmt.Errorf("Invalid database path: %#v", dbPath) } if !filepath.IsAbs(dbPath) { diff --git a/go.mod b/go.mod index 2759255c..f1b2caaf 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,6 @@ module github.com/drakkan/sftpgo go 1.13 require ( - cloud.google.com/go v0.53.0 // indirect cloud.google.com/go/storage v1.6.0 github.com/alexedwards/argon2id v0.0.0-20190612080829-01a59b2b8802 github.com/aws/aws-sdk-go v1.29.14 diff --git a/httpd/httpd.go b/httpd/httpd.go index 1d863d53..d2326369 100644 --- a/httpd/httpd.go +++ b/httpd/httpd.go @@ -15,6 +15,7 @@ import ( "github.com/drakkan/sftpgo/dataprovider" "github.com/drakkan/sftpgo/logger" + "github.com/drakkan/sftpgo/utils" "github.com/go-chi/chi" ) @@ -90,6 +91,10 @@ func (c Conf) Initialize(configDir string) error { backupsPath = getConfigPath(c.BackupsPath, configDir) staticFilesPath := getConfigPath(c.StaticFilesPath, configDir) templatesPath := getConfigPath(c.TemplatesPath, configDir) + if len(backupsPath) == 0 || len(staticFilesPath) == 0 || len(templatesPath) == 0 { + return fmt.Errorf("Required directory is invalid, backup path %#v, static file path: %#v template path: %#v", + backupsPath, staticFilesPath, templatesPath) + } authUserFile := getConfigPath(c.AuthUserFile, configDir) httpAuth, err = newBasicAuthProvider(authUserFile) if err != nil { @@ -129,7 +134,10 @@ func ReloadTLSCertificate() { } func getConfigPath(name, configDir string) string { - if len(name) > 0 && !filepath.IsAbs(name) && name != "." { + if !utils.IsFileInputValid(name) { + return "" + } + if len(name) > 0 && !filepath.IsAbs(name) { return filepath.Join(configDir, name) } return name diff --git a/httpd/httpd_test.go b/httpd/httpd_test.go index cd5a82db..8c1c2670 100644 --- a/httpd/httpd_test.go +++ b/httpd/httpd_test.go @@ -169,6 +169,13 @@ func TestInitialization(t *testing.T) { if err == nil { t.Error("Inizialize must fail") } + httpdConf.CertificateFile = "" + httpdConf.CertificateKeyFile = "" + httpdConf.TemplatesPath = "." + err = httpdConf.Initialize(configDir) + if err == nil { + t.Error("Inizialize must fail") + } } func TestBasicUserHandling(t *testing.T) { diff --git a/logger/logger.go b/logger/logger.go index 03af6ec7..cab07185 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -47,7 +47,7 @@ func GetLogger() *zerolog.Logger { // InitLogger configures the logger using the given parameters func InitLogger(logFilePath string, logMaxSize int, logMaxBackups int, logMaxAge int, logCompress bool, level zerolog.Level) { zerolog.TimeFieldFormat = dateFormat - if len(logFilePath) > 0 && filepath.Clean(logFilePath) != "." { + if isLogFilePathValid(logFilePath) { logger = zerolog.New(&lumberjack.Logger{ Filename: logFilePath, MaxSize: logMaxSize, @@ -183,3 +183,11 @@ func ConnectionFailedLog(user, ip, loginType, errorString string) { Str("error", errorString). Msg("") } + +func isLogFilePathValid(logFilePath string) bool { + cleanInput := filepath.Clean(logFilePath) + if cleanInput == "." || cleanInput == ".." { + return false + } + return true +} diff --git a/service/service.go b/service/service.go index 8ee3a0e2..8e2676bc 100644 --- a/service/service.go +++ b/service/service.go @@ -50,7 +50,7 @@ func (s *Service) Start() error { if !s.LogVerbose { logLevel = zerolog.InfoLevel } - if !filepath.IsAbs(s.LogFilePath) && len(s.LogFilePath) > 0 && s.LogFilePath != "." { + if !filepath.IsAbs(s.LogFilePath) && utils.IsFileInputValid(s.LogFilePath) { s.LogFilePath = filepath.Join(s.ConfigDir, s.LogFilePath) } logger.InitLogger(s.LogFilePath, s.LogMaxSize, s.LogMaxBackups, s.LogMaxAge, s.LogCompress, logLevel) diff --git a/utils/utils.go b/utils/utils.go index 90620abc..4b7c547d 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -288,3 +288,14 @@ func LoadTemplate(t *template.Template, err error) *template.Template { } return t } + +// IsFileInputValid returns true this is a valid file name. +// This method must be used before joining a file name, generally provided as +// user input, with a directory +func IsFileInputValid(fileInput string) bool { + cleanInput := filepath.Clean(fileInput) + if cleanInput == "." || cleanInput == ".." { + return false + } + return true +}