fix some lint issues

Nicola Murino
2020-04-30 14:23:55 +02:00
parent 67c6f27064
commit d70959c34c
33 changed files with 236 additions and 225 deletions

View File

@@ -13,8 +13,8 @@ Fully featured and highly configurable SFTP server, written in Go
- Keyboard interactive authentication. You can easily setup a customizable multi-factor authentication.
- Partial authentication. You can configure multi-step authentication requiring, for example, the user password after successful public key authentication.
- Per user authentication methods. You can, for example, deny one or more authentication methods to one or more users.
- - Custom authentication via external programs is supported.
+ - Custom authentication via external programs/HTTP API is supported.
- - Dynamic user modification before login via external programs is supported.
+ - Dynamic user modification before login via external programs/HTTP API is supported.
- Quota support: accounts can have individual quota expressed as max total size and/or max number of files.
- Bandwidth throttling is supported, with distinct settings for upload and download.
- Per user maximum concurrent sessions.
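For context on the external auth hook mentioned above, here is a rough, self-contained sketch of what such a program could look like. It assumes the hook receives the credentials via the SFTPGO_AUTHD_USERNAME and SFTPGO_AUTHD_PASSWORD environment variables and prints a JSON serialized user on stdout (an empty user to deny the login); check the project documentation for the exact contract, the names and values below are illustrative only.

package main

import (
	"encoding/json"
	"fmt"
	"os"
)

// minimalUser mirrors only the fields needed for the example; the real user object has many more.
type minimalUser struct {
	Status      int                 `json:"status"`
	Username    string              `json:"username"`
	HomeDir     string              `json:"home_dir"`
	Permissions map[string][]string `json:"permissions"`
}

func main() {
	username := os.Getenv("SFTPGO_AUTHD_USERNAME")
	password := os.Getenv("SFTPGO_AUTHD_PASSWORD")
	user := minimalUser{}
	// hard coded check for illustration only: never store credentials like this
	if username == "test_user" && password == "test_password" {
		user = minimalUser{
			Status:      1,
			Username:    username,
			HomeDir:     "/tmp/test_user",
			Permissions: map[string][]string{"/": {"*"}},
		}
	}
	resp, err := json.Marshal(user)
	if err != nil {
		os.Exit(1)
	}
	fmt.Print(string(resp))
}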

View File

@@ -31,10 +31,13 @@ Please take a look at the usage below to customize the options.`,
logger.DisableLogger()
logger.EnableConsoleLogger(zerolog.DebugLevel)
configDir = utils.CleanDirInput(configDir)
- config.LoadConfig(configDir, configFile)
+ err := config.LoadConfig(configDir, configFile)
+ if err != nil {
+ logger.WarnToConsole("Unable to initialize data provider, config load error: %v", err)
+ }
providerConf := config.GetProviderConf()
logger.DebugToConsole("Initializing provider: %#v config file: %#v", providerConf.Driver, viper.ConfigFileUsed())
- err := dataprovider.InitializeDatabase(providerConf, configDir)
+ err = dataprovider.InitializeDatabase(providerConf, configDir)
if err == nil {
logger.DebugToConsole("Data provider successfully initialized")
} else {

View File

@@ -2,6 +2,7 @@ package cmd
import (
"fmt"
+ "strconv"
"github.com/drakkan/sftpgo/service"
"github.com/drakkan/sftpgo/utils"
@@ -51,3 +52,42 @@ func init() {
serviceCmd.AddCommand(installCmd)
addServeFlags(installCmd)
}
+ func getCustomServeFlags() []string {
+ result := []string{}
+ if configDir != defaultConfigDir {
+ configDir = utils.CleanDirInput(configDir)
+ result = append(result, "--"+configDirFlag)
+ result = append(result, configDir)
+ }
+ if configFile != defaultConfigName {
+ result = append(result, "--"+configFileFlag)
+ result = append(result, configFile)
+ }
+ if logFilePath != defaultLogFile {
+ result = append(result, "--"+logFilePathFlag)
+ result = append(result, logFilePath)
+ }
+ if logMaxSize != defaultLogMaxSize {
+ result = append(result, "--"+logMaxSizeFlag)
+ result = append(result, strconv.Itoa(logMaxSize))
+ }
+ if logMaxBackups != defaultLogMaxBackup {
+ result = append(result, "--"+logMaxBackupFlag)
+ result = append(result, strconv.Itoa(logMaxBackups))
+ }
+ if logMaxAge != defaultLogMaxAge {
+ result = append(result, "--"+logMaxAgeFlag)
+ result = append(result, strconv.Itoa(logMaxAge))
+ }
+ if logVerbose != defaultLogVerbose {
+ result = append(result, "--"+logVerboseFlag+"=false")
+ }
+ if logCompress != defaultLogCompress {
+ result = append(result, "--"+logCompressFlag+"=true")
+ }
+ if profiler != defaultProfiler {
+ result = append(result, "--"+profilerFlag+"=true")
+ }
+ return result
+ }

View File

@@ -4,7 +4,6 @@ package cmd
import (
"fmt"
"os"
- "strconv"
"github.com/drakkan/sftpgo/config"
"github.com/drakkan/sftpgo/utils"
@@ -13,7 +12,6 @@ import (
)
const (
- logSender = "cmd"
configDirFlag = "config-dir"
configDirKey = "config_dir"
configFileFlag = "config-file"
@@ -79,110 +77,71 @@ func Execute() {
func addConfigFlags(cmd *cobra.Command) {
viper.SetDefault(configDirKey, defaultConfigDir)
- viper.BindEnv(configDirKey, "SFTPGO_CONFIG_DIR")
+ viper.BindEnv(configDirKey, "SFTPGO_CONFIG_DIR") //nolint: errcheck // err is not nil only if the key to bind is missing
cmd.Flags().StringVarP(&configDir, configDirFlag, "c", viper.GetString(configDirKey),
"Location for SFTPGo config dir. This directory should contain the \"sftpgo\" configuration file or the configured "+
"config-file and it is used as the base for files with a relative path (eg. the private keys for the SFTP server, "+
"the SQLite database if you use SQLite as data provider). This flag can be set using SFTPGO_CONFIG_DIR env var too.")
- viper.BindPFlag(configDirKey, cmd.Flags().Lookup(configDirFlag))
+ viper.BindPFlag(configDirKey, cmd.Flags().Lookup(configDirFlag)) //nolint: errcheck
viper.SetDefault(configFileKey, defaultConfigName)
- viper.BindEnv(configFileKey, "SFTPGO_CONFIG_FILE")
+ viper.BindEnv(configFileKey, "SFTPGO_CONFIG_FILE") //nolint: errcheck
cmd.Flags().StringVarP(&configFile, configFileFlag, "f", viper.GetString(configFileKey),
"Name for SFTPGo configuration file. It must be the name of a file stored in config-dir not the absolute path to the "+
"configuration file. The specified file name must have no extension we automatically load JSON, YAML, TOML, HCL and "+
"Java properties. Therefore if you set \"sftpgo\" then \"sftpgo.json\", \"sftpgo.yaml\" and so on are searched. "+
"This flag can be set using SFTPGO_CONFIG_FILE env var too.")
- viper.BindPFlag(configFileKey, cmd.Flags().Lookup(configFileFlag))
+ viper.BindPFlag(configFileKey, cmd.Flags().Lookup(configFileFlag)) //nolint: errcheck
}
func addServeFlags(cmd *cobra.Command) {
addConfigFlags(cmd)
viper.SetDefault(logFilePathKey, defaultLogFile)
- viper.BindEnv(logFilePathKey, "SFTPGO_LOG_FILE_PATH")
+ viper.BindEnv(logFilePathKey, "SFTPGO_LOG_FILE_PATH") //nolint: errcheck
cmd.Flags().StringVarP(&logFilePath, logFilePathFlag, "l", viper.GetString(logFilePathKey),
"Location for the log file. Leave empty to write logs to the standard output. This flag can be set using SFTPGO_LOG_FILE_PATH "+
"env var too.")
- viper.BindPFlag(logFilePathKey, cmd.Flags().Lookup(logFilePathFlag))
+ viper.BindPFlag(logFilePathKey, cmd.Flags().Lookup(logFilePathFlag)) //nolint: errcheck
viper.SetDefault(logMaxSizeKey, defaultLogMaxSize)
- viper.BindEnv(logMaxSizeKey, "SFTPGO_LOG_MAX_SIZE")
+ viper.BindEnv(logMaxSizeKey, "SFTPGO_LOG_MAX_SIZE") //nolint: errcheck
cmd.Flags().IntVarP(&logMaxSize, logMaxSizeFlag, "s", viper.GetInt(logMaxSizeKey),
"Maximum size in megabytes of the log file before it gets rotated. This flag can be set using SFTPGO_LOG_MAX_SIZE "+
"env var too. It is unused if log-file-path is empty.")
- viper.BindPFlag(logMaxSizeKey, cmd.Flags().Lookup(logMaxSizeFlag))
+ viper.BindPFlag(logMaxSizeKey, cmd.Flags().Lookup(logMaxSizeFlag)) //nolint: errcheck
viper.SetDefault(logMaxBackupKey, defaultLogMaxBackup)
- viper.BindEnv(logMaxBackupKey, "SFTPGO_LOG_MAX_BACKUPS")
+ viper.BindEnv(logMaxBackupKey, "SFTPGO_LOG_MAX_BACKUPS") //nolint: errcheck
cmd.Flags().IntVarP(&logMaxBackups, "log-max-backups", "b", viper.GetInt(logMaxBackupKey),
"Maximum number of old log files to retain. This flag can be set using SFTPGO_LOG_MAX_BACKUPS env var too. "+
"It is unused if log-file-path is empty.")
- viper.BindPFlag(logMaxBackupKey, cmd.Flags().Lookup(logMaxBackupFlag))
+ viper.BindPFlag(logMaxBackupKey, cmd.Flags().Lookup(logMaxBackupFlag)) //nolint: errcheck
viper.SetDefault(logMaxAgeKey, defaultLogMaxAge)
- viper.BindEnv(logMaxAgeKey, "SFTPGO_LOG_MAX_AGE")
+ viper.BindEnv(logMaxAgeKey, "SFTPGO_LOG_MAX_AGE") //nolint: errcheck
cmd.Flags().IntVarP(&logMaxAge, "log-max-age", "a", viper.GetInt(logMaxAgeKey),
"Maximum number of days to retain old log files. This flag can be set using SFTPGO_LOG_MAX_AGE env var too. "+
"It is unused if log-file-path is empty.")
- viper.BindPFlag(logMaxAgeKey, cmd.Flags().Lookup(logMaxAgeFlag))
+ viper.BindPFlag(logMaxAgeKey, cmd.Flags().Lookup(logMaxAgeFlag)) //nolint: errcheck
viper.SetDefault(logCompressKey, defaultLogCompress)
- viper.BindEnv(logCompressKey, "SFTPGO_LOG_COMPRESS")
+ viper.BindEnv(logCompressKey, "SFTPGO_LOG_COMPRESS") //nolint: errcheck
cmd.Flags().BoolVarP(&logCompress, logCompressFlag, "z", viper.GetBool(logCompressKey), "Determine if the rotated "+
"log files should be compressed using gzip. This flag can be set using SFTPGO_LOG_COMPRESS env var too. "+
"It is unused if log-file-path is empty.")
- viper.BindPFlag(logCompressKey, cmd.Flags().Lookup(logCompressFlag))
+ viper.BindPFlag(logCompressKey, cmd.Flags().Lookup(logCompressFlag)) //nolint: errcheck
viper.SetDefault(logVerboseKey, defaultLogVerbose)
- viper.BindEnv(logVerboseKey, "SFTPGO_LOG_VERBOSE")
+ viper.BindEnv(logVerboseKey, "SFTPGO_LOG_VERBOSE") //nolint: errcheck
cmd.Flags().BoolVarP(&logVerbose, logVerboseFlag, "v", viper.GetBool(logVerboseKey), "Enable verbose logs. "+
"This flag can be set using SFTPGO_LOG_VERBOSE env var too.")
- viper.BindPFlag(logVerboseKey, cmd.Flags().Lookup(logVerboseFlag))
+ viper.BindPFlag(logVerboseKey, cmd.Flags().Lookup(logVerboseFlag)) //nolint: errcheck
viper.SetDefault(profilerKey, defaultProfiler)
- viper.BindEnv(profilerKey, "SFTPGO_PROFILER")
+ viper.BindEnv(profilerKey, "SFTPGO_PROFILER") //nolint: errcheck
cmd.Flags().BoolVarP(&profiler, profilerFlag, "p", viper.GetBool(profilerKey), "Enable the built-in profiler. "+
"The profiler will be accessible via HTTP/HTTPS using the base URL \"/debug/pprof/\". "+
"This flag can be set using SFTPGO_PROFILER env var too.")
- viper.BindPFlag(profilerKey, cmd.Flags().Lookup(profilerFlag))
+ viper.BindPFlag(profilerKey, cmd.Flags().Lookup(profilerFlag)) //nolint: errcheck
- }
- func getCustomServeFlags() []string {
- result := []string{}
- if configDir != defaultConfigDir {
- configDir = utils.CleanDirInput(configDir)
- result = append(result, "--"+configDirFlag)
- result = append(result, configDir)
- }
- if configFile != defaultConfigName {
- result = append(result, "--"+configFileFlag)
- result = append(result, configFile)
- }
- if logFilePath != defaultLogFile {
- result = append(result, "--"+logFilePathFlag)
- result = append(result, logFilePath)
- }
- if logMaxSize != defaultLogMaxSize {
- result = append(result, "--"+logMaxSizeFlag)
- result = append(result, strconv.Itoa(logMaxSize))
- }
- if logMaxBackups != defaultLogMaxBackup {
- result = append(result, "--"+logMaxBackupFlag)
- result = append(result, strconv.Itoa(logMaxBackups))
- }
- if logMaxAge != defaultLogMaxAge {
- result = append(result, "--"+logMaxAgeFlag)
- result = append(result, strconv.Itoa(logMaxAge))
- }
- if logVerbose != defaultLogVerbose {
- result = append(result, "--"+logVerboseFlag+"=false")
- }
- if logCompress != defaultLogCompress {
- result = append(result, "--"+logCompressFlag+"=true")
- }
- if profiler != defaultProfiler {
- result = append(result, "--"+profilerFlag+"=true")
- }
- return result
- }
}

View File

@@ -220,19 +220,20 @@ func LoadConfig(configDir, configName string) error {
}
func checkHooksCompatibility() {
- if len(globalConf.ProviderConf.ExternalAuthProgram) > 0 && len(globalConf.ProviderConf.ExternalAuthHook) == 0 {
+ // we copy deprecated fields to new ones to keep backward compatibility so lint is disabled
+ if len(globalConf.ProviderConf.ExternalAuthProgram) > 0 && len(globalConf.ProviderConf.ExternalAuthHook) == 0 { //nolint: staticcheck
logger.Warn(logSender, "", "external_auth_program is deprecated, please use external_auth_hook")
logger.WarnToConsole("external_auth_program is deprecated, please use external_auth_hook")
- globalConf.ProviderConf.ExternalAuthHook = globalConf.ProviderConf.ExternalAuthProgram
+ globalConf.ProviderConf.ExternalAuthHook = globalConf.ProviderConf.ExternalAuthProgram //nolint: staticcheck
}
- if len(globalConf.ProviderConf.PreLoginProgram) > 0 && len(globalConf.ProviderConf.PreLoginHook) == 0 {
+ if len(globalConf.ProviderConf.PreLoginProgram) > 0 && len(globalConf.ProviderConf.PreLoginHook) == 0 { //nolint: staticcheck
logger.Warn(logSender, "", "pre_login_program is deprecated, please use pre_login_hook")
logger.WarnToConsole("pre_login_program is deprecated, please use pre_login_hook")
- globalConf.ProviderConf.PreLoginHook = globalConf.ProviderConf.PreLoginProgram
+ globalConf.ProviderConf.PreLoginHook = globalConf.ProviderConf.PreLoginProgram //nolint: staticcheck
}
- if len(globalConf.SFTPD.KeyboardInteractiveProgram) > 0 && len(globalConf.SFTPD.KeyboardInteractiveHook) == 0 {
+ if len(globalConf.SFTPD.KeyboardInteractiveProgram) > 0 && len(globalConf.SFTPD.KeyboardInteractiveHook) == 0 { //nolint: staticcheck
logger.Warn(logSender, "", "keyboard_interactive_auth_program is deprecated, please use keyboard_interactive_auth_hook")
logger.WarnToConsole("keyboard_interactive_auth_program is deprecated, please use keyboard_interactive_auth_hook")
- globalConf.SFTPD.KeyboardInteractiveHook = globalConf.SFTPD.KeyboardInteractiveProgram
+ globalConf.SFTPD.KeyboardInteractiveHook = globalConf.SFTPD.KeyboardInteractiveProgram //nolint: staticcheck
}
}

View File

@@ -171,7 +171,13 @@ func (p BoltProvider) updateLastLogin(username string) error {
if err != nil {
return err
}
- return bucket.Put([]byte(username), buf)
+ err = bucket.Put([]byte(username), buf)
+ if err == nil {
+ providerLog(logger.LevelDebug, "last login updated for user %#v", username)
+ } else {
+ providerLog(logger.LevelWarn, "error updating last login for user %#v: %v", username, err)
+ }
+ return err
})
}
@@ -304,8 +310,7 @@ func (p BoltProvider) deleteUser(user User) error {
func (p BoltProvider) dumpUsers() ([]User, error) {
users := []User{}
- var err error
- err = p.dbHandle.View(func(tx *bolt.Tx) error {
+ err := p.dbHandle.View(func(tx *bolt.Tx) error {
bucket, _, err := getBuckets(tx)
if err != nil {
return err

View File

@@ -504,7 +504,7 @@ func AddUser(p Provider, user User) error {
}
err := p.addUser(user)
if err == nil {
- go executeAction(operationAdd, user)
+ go executeAction(operationAdd, user) //nolint:errcheck // the error is used in test cases only
}
return err
}
@@ -517,7 +517,7 @@ func UpdateUser(p Provider, user User) error {
}
err := p.updateUser(user)
if err == nil {
- go executeAction(operationUpdate, user)
+ go executeAction(operationUpdate, user) //nolint:errcheck // the error is used in test cases only
}
return err
}
@@ -530,7 +530,7 @@ func DeleteUser(p Provider, user User) error {
}
err := p.deleteUser(user)
if err == nil {
- go executeAction(operationDelete, user)
+ go executeAction(operationDelete, user) //nolint:errcheck // the error is used in test cases only
}
return err
}
@@ -1127,7 +1127,10 @@ func terminateInteractiveAuthProgram(cmd *exec.Cmd, isFinished bool) {
return
}
providerLog(logger.LevelInfo, "kill interactive auth program after an unexpected error")
- cmd.Process.Kill()
+ err := cmd.Process.Kill()
+ if err != nil {
+ providerLog(logger.LevelDebug, "error killing interactive auth program: %v", err)
+ }
}
func validateKeyboardAuthResponse(response keyboardAuthHookResponse) error {
@@ -1298,7 +1301,12 @@ func executeKeyboardInteractiveProgram(user User, authHook string, client ssh.Ke
}
stdin.Close()
once.Do(func() { terminateInteractiveAuthProgram(cmd, true) })
- go cmd.Process.Wait()
+ go func() {
+ _, err := cmd.Process.Wait()
+ if err != nil {
+ providerLog(logger.LevelWarn, "error waiting for #%v process to exit: %v", authHook, err)
+ }
+ }()
return authResult, err
}
@@ -1461,7 +1469,7 @@ func doExternalAuth(username, password string, pubKey []byte, keyboardInteractiv
var user User
pkey := ""
if len(pubKey) > 0 {
- k, err := ssh.ParsePublicKey([]byte(pubKey))
+ k, err := ssh.ParsePublicKey(pubKey)
if err != nil {
return user, err
}
@@ -1535,9 +1543,9 @@ func executeAction(operation string, user User) {
// we are in a goroutine but if we have to send an HTTP notification we don't want to wait for the
// end of the command
if len(config.Actions.HTTPNotificationURL) > 0 {
- go executeNotificationCommand(operation, user)
+ go executeNotificationCommand(operation, user) //nolint:errcheck // the error is used in test cases only
} else {
- executeNotificationCommand(operation, user)
+ executeNotificationCommand(operation, user) //nolint:errcheck // the error is used in test cases only
}
}
if len(config.Actions.HTTPNotificationURL) > 0 {

View File

@@ -128,17 +128,17 @@ func (p MySQLProvider) initializeDatabase() error {
}
_, err = tx.Exec(sqlUsers)
if err != nil {
- tx.Rollback()
+ sqlCommonRollbackTransaction(tx)
return err
}
_, err = tx.Exec(mysqlSchemaTableSQL)
if err != nil {
- tx.Rollback()
+ sqlCommonRollbackTransaction(tx)
return err
}
_, err = tx.Exec(initialDBVersionSQL)
if err != nil {
- tx.Rollback()
+ sqlCommonRollbackTransaction(tx)
return err
}
return tx.Commit()
@@ -186,12 +186,12 @@ func updateMySQLDatabase(dbHandle *sql.DB, sql string, newVersion int) error {
}
_, err = tx.Exec(sql)
if err != nil {
- tx.Rollback()
+ sqlCommonRollbackTransaction(tx)
return err
}
err = sqlCommonUpdateDatabaseVersionWithTX(tx, newVersion)
if err != nil {
- tx.Rollback()
+ sqlCommonRollbackTransaction(tx)
return err
}
return tx.Commit()

View File

@@ -126,17 +126,17 @@ func (p PGSQLProvider) initializeDatabase() error {
}
_, err = tx.Exec(sqlUsers)
if err != nil {
- tx.Rollback()
+ sqlCommonRollbackTransaction(tx)
return err
}
_, err = tx.Exec(pgsqlSchemaTableSQL)
if err != nil {
- tx.Rollback()
+ sqlCommonRollbackTransaction(tx)
return err
}
_, err = tx.Exec(initialDBVersionSQL)
if err != nil {
- tx.Rollback()
+ sqlCommonRollbackTransaction(tx)
return err
}
return tx.Commit()
@@ -184,12 +184,12 @@ func updatePGSQLDatabase(dbHandle *sql.DB, sql string, newVersion int) error {
}
_, err = tx.Exec(sql)
if err != nil {
- tx.Rollback()
+ sqlCommonRollbackTransaction(tx)
return err
}
err = sqlCommonUpdateDatabaseVersionWithTX(tx, newVersion)
if err != nil {
- tx.Rollback()
+ sqlCommonRollbackTransaction(tx)
return err
}
return tx.Commit()

View File

@@ -323,7 +323,6 @@ func getUserFromDbRow(row *sql.Row, rows *sql.Rows) (User, error) {
&user.QuotaSize, &user.QuotaFiles, &permissions, &user.UsedQuotaSize, &user.UsedQuotaFiles, &user.LastQuotaUpdate,
&user.UploadBandwidth, &user.DownloadBandwidth, &user.ExpirationDate, &user.LastLogin, &user.Status, &filters, &fsConfig,
&virtualFolders)
} else {
err = rows.Scan(&user.ID, &user.Username, &password, &publicKey, &user.HomeDir, &user.UID, &user.GID, &user.MaxSessions,
&user.QuotaSize, &user.QuotaFiles, &permissions, &user.UsedQuotaSize, &user.UsedQuotaFiles, &user.LastQuotaUpdate,
@@ -379,6 +378,13 @@ func getUserFromDbRow(row *sql.Row, rows *sql.Rows) (User, error) {
return user, err
}
+ func sqlCommonRollbackTransaction(tx *sql.Tx) {
+ err := tx.Rollback()
+ if err != nil {
+ providerLog(logger.LevelWarn, "error rolling back transaction: %v", err)
+ }
+ }
func sqlCommonGetDatabaseVersion(dbHandle *sql.DB) (schemaVersion, error) {
var result schemaVersion
q := getDatabaseVersionQuery()
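The sqlCommonRollbackTransaction helper added above keeps the error paths short while still satisfying errcheck, since tx.Rollback() itself returns an error. A rough, self-contained sketch of the same pattern; the function names, schema handling, and logging below are illustrative and not taken from the project:

package example

import (
	"database/sql"
	"log"
)

// rollbackTransaction logs a failed rollback instead of silently discarding the error.
func rollbackTransaction(tx *sql.Tx) {
	if err := tx.Rollback(); err != nil {
		log.Printf("error rolling back transaction: %v", err)
	}
}

// initializeSchema runs the given statements in a single transaction and
// rolls back (with logging) as soon as one of them fails.
func initializeSchema(db *sql.DB, statements []string) error {
	tx, err := db.Begin()
	if err != nil {
		return err
	}
	for _, stmt := range statements {
		if _, err := tx.Exec(stmt); err != nil {
			rollbackTransaction(tx)
			return err
		}
	}
	return tx.Commit()
}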

View File

@@ -637,7 +637,7 @@ func (u *User) getNotificationFieldsAsSlice(action string) []string {
return []string{action, u.Username,
strconv.FormatInt(u.ID, 10),
strconv.FormatInt(int64(u.Status), 10),
- strconv.FormatInt(int64(u.ExpirationDate), 10),
+ strconv.FormatInt(u.ExpirationDate, 10),
u.HomeDir,
strconv.FormatInt(int64(u.UID), 10),
strconv.FormatInt(int64(u.GID), 10),

View File

@@ -37,6 +37,12 @@ func dumpData(w http.ResponseWriter, r *http.Request) {
return
}
outputFile = filepath.Join(backupsPath, outputFile)
+ err := os.MkdirAll(filepath.Dir(outputFile), 0700)
+ if err != nil {
+ logger.Warn(logSender, "", "dumping data error: %v, output file: %#v", err, outputFile)
+ sendAPIResponse(w, r, err, "", getRespStatus(err))
+ return
+ }
logger.Debug(logSender, "", "dumping data to: %#v", outputFile)
users, err := dataprovider.DumpUsers(dataProvider)
@@ -56,7 +62,6 @@ func dumpData(w http.ResponseWriter, r *http.Request) {
})
}
if err == nil {
- os.MkdirAll(filepath.Dir(outputFile), 0700)
err = ioutil.WriteFile(outputFile, dump, 0600)
}
if err != nil {
@@ -127,10 +132,10 @@ func loadData(w http.ResponseWriter, r *http.Request) {
sendAPIResponse(w, r, err, "", getRespStatus(err))
return
}
- if needQuotaScan(scanQuota, &user) {
+ if scanQuota == 1 || (scanQuota == 2 && user.HasQuotaRestrictions()) {
if sftpd.AddQuotaScan(user.Username) {
logger.Debug(logSender, "", "starting quota scan for restored user: %#v", user.Username)
- go doQuotaScan(user)
+ go doQuotaScan(user) //nolint:errcheck
}
}
}
@@ -138,10 +143,6 @@ func loadData(w http.ResponseWriter, r *http.Request) {
sendAPIResponse(w, r, err, "Data restored", http.StatusOK)
}
- func needQuotaScan(scanQuota int, user *dataprovider.User) bool {
- return scanQuota == 1 || (scanQuota == 2 && user.HasQuotaRestrictions())
- }
func getLoaddataOptions(r *http.Request) (string, int, int, error) {
var inputFile string
var err error

View File

@@ -27,7 +27,7 @@ func startQuotaScan(w http.ResponseWriter, r *http.Request) {
return
}
if sftpd.AddQuotaScan(user.Username) {
- go doQuotaScan(user)
+ go doQuotaScan(user) //nolint:errcheck
sendAPIResponse(w, r, err, "Scan started", http.StatusCreated)
} else {
sendAPIResponse(w, r, err, "Another scan is already in progress", http.StatusConflict)
@@ -35,7 +35,7 @@ func startQuotaScan(w http.ResponseWriter, r *http.Request) {
}
func doQuotaScan(user dataprovider.User) error {
- defer sftpd.RemoveQuotaScan(user.Username)
+ defer sftpd.RemoveQuotaScan(user.Username) //nolint:errcheck
fs, err := user.GetFilesystem("")
if err != nil {
logger.Warn(logSender, "", "unable scan quota for user %#v error creating filesystem: %v", user.Username, err)

View File

@@ -128,10 +128,11 @@ func (c Conf) Initialize(configDir string, profiler bool) error {
}
// ReloadTLSCertificate reloads the TLS certificate and key from the configured paths
- func ReloadTLSCertificate() {
+ func ReloadTLSCertificate() error {
if certMgr != nil {
- certMgr.loadCertificate()
+ return certMgr.loadCertificate()
}
+ return nil
}
func getConfigPath(name, configDir string) string {

View File

@@ -181,6 +181,10 @@ func TestInitialization(t *testing.T) {
if err == nil {
t.Error("Inizialize must fail")
}
+ err = httpd.ReloadTLSCertificate()
+ if err != nil {
+ t.Error("realoding TLS Certificate must return nil error if no certificate is configured")
+ }
}
func TestBasicUserHandling(t *testing.T) {
@@ -1105,6 +1109,11 @@ func TestDumpdata(t *testing.T) {
if err != nil {
t.Errorf("unexpected error: %v", err)
}
+ // subdir cannot be created
+ _, _, err = httpd.Dumpdata(filepath.Join("subdir", "bck.json"), "", http.StatusInternalServerError)
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ }
os.Chmod(backupsPath, 0755)
}
dataProvider = dataprovider.GetProvider()

View File

@@ -152,7 +152,7 @@ func renderMessagePage(w http.ResponseWriter, title, body string, statusCode int
}
func renderInternalServerErrorPage(w http.ResponseWriter, err error) {
- renderMessagePage(w, page500Title, page400Title, http.StatusInternalServerError, err, "")
+ renderMessagePage(w, page500Title, page500Body, http.StatusInternalServerError, err, "")
}
func renderBadRequestPage(w http.ResponseWriter, err error) {

View File

@@ -93,7 +93,6 @@ func Log(level LogLevel, sender string, connectionID string, format string, v ..
default:
Error(sender, connectionID, format, v...)
}
}
// Debug logs at debug level for the specified sender

View File

@@ -6,6 +6,13 @@ import (
"github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/client_golang/prometheus/promauto"
) )
const (
loginMethodPublicKey = "publickey"
loginMethodKeyboardInteractive = "keyboard-interactive"
loginMethodKeyAndPassword = "publickey+password"
loginMethodKeyAndKeyboardInt = "publickey+keyboard-interactive"
)
var ( var (
// dataproviderAvailability is the metric that reports the availability for the configured data provider // dataproviderAvailability is the metric that reports the availability for the configured data provider
dataproviderAvailability = promauto.NewGauge(prometheus.GaugeOpts{ dataproviderAvailability = promauto.NewGauge(prometheus.GaugeOpts{
@@ -536,13 +543,13 @@ func UpdateDataProviderAvailability(err error) {
func AddLoginAttempt(authMethod string) { func AddLoginAttempt(authMethod string) {
totalLoginAttempts.Inc() totalLoginAttempts.Inc()
switch authMethod { switch authMethod {
case "publickey": case loginMethodPublicKey:
totalKeyLoginAttempts.Inc() totalKeyLoginAttempts.Inc()
case "keyboard-interactive": case loginMethodKeyboardInteractive:
totalInteractiveLoginAttempts.Inc() totalInteractiveLoginAttempts.Inc()
case "publickey+password": case loginMethodKeyAndPassword:
totalKeyAndPasswordLoginAttempts.Inc() totalKeyAndPasswordLoginAttempts.Inc()
case "publickey+keyboard-interactive": case loginMethodKeyAndKeyboardInt:
totalKeyAndKeyIntLoginAttempts.Inc() totalKeyAndKeyIntLoginAttempts.Inc()
default: default:
totalPasswordLoginAttempts.Inc() totalPasswordLoginAttempts.Inc()
@@ -554,13 +561,13 @@ func AddLoginResult(authMethod string, err error) {
if err == nil { if err == nil {
totalLoginOK.Inc() totalLoginOK.Inc()
switch authMethod { switch authMethod {
case "publickey": case loginMethodPublicKey:
totalKeyLoginOK.Inc() totalKeyLoginOK.Inc()
case "keyboard-interactive": case loginMethodKeyboardInteractive:
totalInteractiveLoginOK.Inc() totalInteractiveLoginOK.Inc()
case "publickey+password": case loginMethodKeyAndPassword:
totalKeyAndPasswordLoginOK.Inc() totalKeyAndPasswordLoginOK.Inc()
case "publickey+keyboard-interactive": case loginMethodKeyAndKeyboardInt:
totalKeyAndKeyIntLoginOK.Inc() totalKeyAndKeyIntLoginOK.Inc()
default: default:
totalPasswordLoginOK.Inc() totalPasswordLoginOK.Inc()
@@ -568,13 +575,13 @@ func AddLoginResult(authMethod string, err error) {
} else { } else {
totalLoginFailed.Inc() totalLoginFailed.Inc()
switch authMethod { switch authMethod {
case "publickey": case loginMethodPublicKey:
totalKeyLoginFailed.Inc() totalKeyLoginFailed.Inc()
case "keyboard-interactive": case loginMethodKeyboardInteractive:
totalInteractiveLoginFailed.Inc() totalInteractiveLoginFailed.Inc()
case "publickey+password": case loginMethodKeyAndPassword:
totalKeyAndPasswordLoginFailed.Inc() totalKeyAndPasswordLoginFailed.Inc()
case "publickey+keyboard-interactive": case loginMethodKeyAndKeyboardInt:
totalKeyAndKeyIntLoginFailed.Inc() totalKeyAndKeyIntLoginFailed.Inc()
default: default:
totalPasswordLoginFailed.Inc() totalPasswordLoginFailed.Inc()

View File

@@ -37,10 +37,10 @@ type Service struct {
LogMaxSize int
LogMaxBackups int
LogMaxAge int
- LogCompress bool
- LogVerbose bool
PortableMode int
PortableUser dataprovider.User
+ LogCompress bool
+ LogVerbose bool
Profiler bool
Shutdown chan bool
}
@@ -67,7 +67,10 @@ func (s *Service) Start() error {
s.LogMaxSize, s.LogMaxBackups, s.LogMaxAge, s.LogVerbose, s.LogCompress, s.Profiler)
// in portable mode we don't read configuration from file
if s.PortableMode != 1 {
- config.LoadConfig(s.ConfigDir, s.ConfigFile)
+ err := config.LoadConfig(s.ConfigDir, s.ConfigFile)
+ if err != nil {
+ logger.Error(logSender, "", "error loading configuration: %v", err)
+ }
}
providerConf := config.GetProviderConf()
@@ -211,7 +214,6 @@ func (s *Service) StartPortableMode(sftpdPort int, enabledSSHCommands []string,
} else {
logger.InfoToConsole("SFTP service advertised via multicast DNS")
}
- }
sig := make(chan os.Signal, 1)
signal.Notify(sig, os.Interrupt, syscall.SIGTERM)

View File

@@ -83,8 +83,14 @@ loop:
break loop
case svc.ParamChange:
logger.Debug(logSender, "", "Received reload request")
- dataprovider.ReloadConfig()
- httpd.ReloadTLSCertificate()
+ err := dataprovider.ReloadConfig()
+ if err != nil {
+ logger.Warn(logSender, "", "error reloading dataprovider configuration: %v", err)
+ }
+ err = httpd.ReloadTLSCertificate()
+ if err != nil {
+ logger.Warn(logSender, "", "error reloading TLS certificate: %v", err)
+ }
default:
continue loop
}

View File

@@ -18,8 +18,14 @@ func registerSigHup() {
go func() {
for range sig {
logger.Debug(logSender, "", "Received reload request")
- dataprovider.ReloadConfig()
- httpd.ReloadTLSCertificate()
+ err := dataprovider.ReloadConfig()
+ if err != nil {
+ logger.Warn(logSender, "", "error reloading dataprovider configuration: %v", err)
+ }
+ err = httpd.ReloadTLSCertificate()
+ if err != nil {
+ logger.Warn(logSender, "", "error reloading TLS certificate: %v", err)
+ }
}
}()
}

View File

@@ -168,24 +168,19 @@ func (c Connection) Filecmd(request *sftp.Request) error {
if err = c.handleSFTPRename(p, target, request); err != nil {
return err
}
- break
case "Rmdir":
return c.handleSFTPRmdir(p, request)
case "Mkdir":
err = c.handleSFTPMkdir(p, request)
if err != nil {
return err
}
- break
case "Symlink":
if err = c.handleSFTPSymlink(p, target, request); err != nil {
return err
}
- break
case "Remove":
return c.handleSFTPRemove(p, request)
default:
return sftp.ErrSSHFxOpUnsupported
}
@@ -335,7 +330,8 @@ func (c Connection) handleSFTPRename(sourcePath string, targetPath string, reque
return vfs.GetSFTPError(c.fs, err)
}
logger.CommandLog(renameLogSender, sourcePath, targetPath, c.User.Username, "", c.ID, c.protocol, -1, -1, "", "", "")
- go executeAction(newActionNotification(c.User, operationRename, sourcePath, targetPath, "", 0, nil))
+ // the returned error is used in test cases only, we already log the error inside executeAction
+ go executeAction(newActionNotification(c.User, operationRename, sourcePath, targetPath, "", 0, nil)) //nolint:errcheck
return nil
}
@@ -441,9 +437,9 @@ func (c Connection) handleSFTPRemove(filePath string, request *sftp.Request) err
logger.CommandLog(removeLogSender, filePath, "", c.User.Username, "", c.ID, c.protocol, -1, -1, "", "", "")
if fi.Mode()&os.ModeSymlink != os.ModeSymlink {
- dataprovider.UpdateUserQuota(dataProvider, c.User, -1, -size, false)
+ dataprovider.UpdateUserQuota(dataProvider, c.User, -1, -size, false) //nolint:errcheck
}
- go executeAction(newActionNotification(c.User, operationDelete, filePath, "", "", fi.Size(), nil))
+ go executeAction(newActionNotification(c.User, operationDelete, filePath, "", "", fi.Size(), nil)) //nolint:errcheck
return sftp.ErrSSHFxOk
}
@@ -524,7 +520,7 @@ func (c Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, re
minWriteOffset = fileSize
} else {
if vfs.IsLocalOsFs(c.fs) {
- dataprovider.UpdateUserQuota(dataProvider, c.User, 0, -fileSize, false)
+ dataprovider.UpdateUserQuota(dataProvider, c.User, 0, -fileSize, false) //nolint:errcheck
} else {
initialSize = fileSize
}

View File

@@ -1820,21 +1820,6 @@ func TestConnectionStatusStruct(t *testing.T) {
}
}
- func TestSFTPExtensions(t *testing.T) {
- initialSFTPExtensions := sftpExtensions
- c := Configuration{}
- err := c.configureSFTPExtensions()
- if err != nil {
- t.Errorf("error configuring SFTP extensions")
- }
- sftpExtensions = append(sftpExtensions, "invalid@example.com")
- err = c.configureSFTPExtensions()
- if err == nil {
- t.Errorf("configuring invalid SFTP extensions must fail")
- }
- sftpExtensions = initialSFTPExtensions
- }
func TestProxyProtocolVersion(t *testing.T) {
c := Configuration{
ProxyProtocol: 1,

View File

@@ -45,7 +45,6 @@ func (c *scpCommand) handle() error {
if err != nil {
return err
}
} else if commandType == "-f" {
// -f means "from" so download
err = c.readConfirmationMessage()
@@ -199,7 +198,7 @@ func (c *scpCommand) handleUploadFile(requestPath, filePath string, sizeToRead i
initialSize := int64(0)
if !isNewFile {
if vfs.IsLocalOsFs(c.connection.fs) {
- dataprovider.UpdateUserQuota(dataProvider, c.connection.User, 0, -fileSize, false)
+ dataprovider.UpdateUserQuota(dataProvider, c.connection.User, 0, -fileSize, false) //nolint:errcheck
} else {
initialSize = fileSize
}
@@ -593,6 +592,7 @@ func (c *scpCommand) readProtocolMessage() (string, error) {
}
// send an error message and close the channel
+ //nolint:errcheck // we don't check write errors here, we have to close the channel anyway
func (c *scpCommand) sendErrorMessage(err error) {
c.connection.channel.Write(errMsg)
c.connection.channel.Write([]byte(c.getMappedError(err).Error()))
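Golangci-lint reads these directives positionally: a trailing //nolint:<linter> applies only to that line, while a //nolint comment placed on its own line directly above a declaration (as done for sendErrorMessage above) suppresses the named linters for the whole declaration. A rough, standalone sketch of both placements, with made-up names:

package example

import "os"

// line-level: only this call is exempted from errcheck
func writeBanner(f *os.File) {
	f.WriteString("banner\n") //nolint:errcheck
}

//nolint:errcheck // declaration-level: every unchecked error inside this function is exempted
func closeQuietly(f *os.File) {
	f.Sync()
	f.Close()
}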

View File

@@ -3,7 +3,6 @@ package sftpd
import (
"encoding/hex"
"encoding/json"
- "errors"
"fmt"
"io"
"io/ioutil"
@@ -29,8 +28,7 @@ const (
)
var (
sftpExtensions = []string{"posix-rename@openssh.com"}
- errWrongProxyProtoVersion = errors.New("unacceptable proxy protocol version")
)
// Configuration for the SFTP server
@@ -184,10 +182,11 @@ func (c Configuration) Initialize(configDir string) error {
return err
}
+ sftp.SetSFTPExtensions(sftpExtensions...) //nolint:errcheck // we configure valid SFTP Extensions so we cannot get an error
c.configureSecurityOptions(serverConfig)
c.configureKeyboardInteractiveAuth(serverConfig)
c.configureLoginBanner(serverConfig, configDir)
- c.configureSFTPExtensions()
c.checkSSHCommands()
listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", c.BindAddress, c.BindPort))
@@ -268,26 +267,23 @@ func (c Configuration) configureSecurityOptions(serverConfig *ssh.ServerConfig)
}
}
- func (c Configuration) configureLoginBanner(serverConfig *ssh.ServerConfig, configDir string) error {
+ func (c Configuration) configureLoginBanner(serverConfig *ssh.ServerConfig, configDir string) {
- var err error
if len(c.LoginBannerFile) > 0 {
bannerFilePath := c.LoginBannerFile
if !filepath.IsAbs(bannerFilePath) {
bannerFilePath = filepath.Join(configDir, bannerFilePath)
}
- var bannerContent []byte
- bannerContent, err = ioutil.ReadFile(bannerFilePath)
+ bannerContent, err := ioutil.ReadFile(bannerFilePath)
if err == nil {
banner := string(bannerContent)
serverConfig.BannerCallback = func(conn ssh.ConnMetadata) string {
- return string(banner)
+ return banner
}
} else {
logger.WarnToConsole("unable to read login banner file: %v", err)
logger.Warn(logSender, "", "unable to read login banner file: %v", err)
}
}
- return err
}
func (c Configuration) configureKeyboardInteractiveAuth(serverConfig *ssh.ServerConfig) {
@@ -319,21 +315,11 @@ func (c Configuration) configureKeyboardInteractiveAuth(serverConfig *ssh.Server
}
}
- func (c Configuration) configureSFTPExtensions() error {
- err := sftp.SetSFTPExtensions(sftpExtensions...)
- if err != nil {
- logger.WarnToConsole("unable to configure SFTP extensions: %v", err)
- logger.Warn(logSender, "", "unable to configure SFTP extensions: %v", err)
- }
- return err
- }
// AcceptInboundConnection handles an inbound connection to the server instance and determines if the request should be served or not.
func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.ServerConfig) {
// Before beginning a handshake must be performed on the incoming net.Conn
// we'll set a Deadline for handshake to complete, the default is 2 minutes as OpenSSH
- conn.SetDeadline(time.Now().Add(handshakeTimeout))
+ conn.SetDeadline(time.Now().Add(handshakeTimeout)) //nolint:errcheck
remoteAddr := conn.RemoteAddr()
sconn, chans, reqs, err := ssh.NewServerConn(conn, config)
if err != nil {
@@ -344,12 +330,12 @@
return
}
// handshake completed so remove the deadline, we'll use IdleTimeout configuration from now on
- conn.SetDeadline(time.Time{})
+ conn.SetDeadline(time.Time{}) //nolint:errcheck
var user dataprovider.User
// Unmarshal cannot fails here and even if it fails we'll have a user with no permissions
- json.Unmarshal([]byte(sconn.Permissions.Extensions["user"]), &user)
+ json.Unmarshal([]byte(sconn.Permissions.Extensions["user"]), &user) //nolint:errcheck
loginType := sconn.Permissions.Extensions["login_method"]
connectionID := hex.EncodeToString(sconn.SessionID())
@@ -378,7 +364,7 @@
connection.Log(logger.LevelInfo, logSender, "User id: %d, logged in with: %#v, username: %#v, home_dir: %#v remote addr: %#v",
user.ID, loginType, user.Username, user.HomeDir, remoteAddr.String())
- dataprovider.UpdateLastLogin(dataProvider, user)
+ dataprovider.UpdateLastLogin(dataProvider, user) //nolint:errcheck
go ssh.DiscardRequests(reqs)
@@ -387,7 +373,7 @@
// know how to handle at this point.
if newChannel.ChannelType() != "session" {
connection.Log(logger.LevelDebug, logSender, "received an unknown channel type: %v", newChannel.ChannelType())
- newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
+ newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") //nolint:errcheck
continue
}
@@ -414,7 +400,7 @@
case "exec":
ok = processSSHCommand(req.Payload, &connection, channel, c.EnabledSSHCommands)
}
- req.Reply(ok, nil)
+ req.Reply(ok, nil) //nolint:errcheck
}
}(requests)
}
@@ -441,7 +427,6 @@ func (c Configuration) handleSftpConnection(channel ssh.Channel, connection Conn
}
func (c Configuration) createHandler(connection Connection) sftp.Handlers {
return sftp.Handlers{
FileGet: connection,
FilePut: connection,

View File

@@ -49,7 +49,7 @@ const (
)
const (
- uploadModeStandard = iota
+ uploadModeStandard = iota //nolint:varcheck,deadcode
uploadModeAtomic
uploadModeAtomicWithResume
)
@@ -429,7 +429,7 @@ func removeConnection(c Connection) {
// We only need to ensure that a connection will not remain indefinitely open and so the
// underlying file descriptor is not released.
// This should protect us against buggy clients and edge cases.
- c.netConn.SetDeadline(time.Now().Add(2 * time.Minute))
+ c.netConn.SetDeadline(time.Now().Add(2 * time.Minute)) //nolint:errcheck
c.Log(logger.LevelDebug, logSender, "connection removed, num open connections: %v", len(openConnections))
}
@@ -495,9 +495,9 @@ func executeAction(a actionNotification) error {
// we are in a goroutine but if we have to send an HTTP notification we don't want to wait for the
// end of the command
if len(actions.HTTPNotificationURL) > 0 {
- go executeNotificationCommand(a)
+ go executeNotificationCommand(a) //nolint:errcheck
} else {
- err = executeNotificationCommand(a)
+ err = executeNotificationCommand(a) //nolint:errcheck
}
}
if len(actions.HTTPNotificationURL) > 0 {

View File

@@ -24,6 +24,8 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
const scpCmdName = "scp"
var ( var (
errQuotaExceeded = errors.New("denying write due to space limit") errQuotaExceeded = errors.New("denying write due to space limit")
errPermissionDenied = errors.New("Permission denied. You don't have the permissions to execute this command") errPermissionDenied = errors.New("Permission denied. You don't have the permissions to execute this command")
@@ -49,7 +51,7 @@ func processSSHCommand(payload []byte, connection *Connection, channel ssh.Chann
name, args, len(args), connection.User.Username, err) name, args, len(args), connection.User.Username, err)
if err == nil && utils.IsStringInSlice(name, enabledSSHCommands) { if err == nil && utils.IsStringInSlice(name, enabledSSHCommands) {
connection.command = msg.Command connection.command = msg.Command
if name == "scp" && len(args) >= 2 { if name == scpCmdName && len(args) >= 2 {
connection.protocol = protocolSCP connection.protocol = protocolSCP
connection.channel = channel connection.channel = channel
scpCommand := scpCommand{ scpCommand := scpCommand{
@@ -58,10 +60,10 @@ func processSSHCommand(payload []byte, connection *Connection, channel ssh.Chann
connection: *connection, connection: *connection,
args: args}, args: args},
} }
go scpCommand.handle() go scpCommand.handle() //nolint:errcheck
return true return true
} }
if name != "scp" { if name != scpCmdName {
connection.protocol = protocolSSH connection.protocol = protocolSSH
connection.channel = channel connection.channel = channel
sshCommand := sshCommand{ sshCommand := sshCommand{
@@ -69,7 +71,7 @@ func processSSHCommand(payload []byte, connection *Connection, channel ssh.Chann
connection: *connection, connection: *connection,
args: args, args: args,
} }
go sshCommand.handle() go sshCommand.handle() //nolint:errcheck
return true return true
} }
} else { } else {
@@ -95,7 +97,7 @@ func (c *sshCommand) handle() error {
c.sendExitStatus(nil) c.sendExitStatus(nil)
} else if c.command == "pwd" { } else if c.command == "pwd" {
// hard coded response to "/" // hard coded response to "/"
c.connection.channel.Write([]byte("/\n")) c.connection.channel.Write([]byte("/\n")) //nolint:errcheck
c.sendExitStatus(nil) c.sendExitStatus(nil)
} }
return nil return nil
@@ -125,7 +127,7 @@ func (c *sshCommand) handleHashCommands() error {
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
return c.sendErrorResponse(err) return c.sendErrorResponse(err)
} }
h.Write(buf[:n]) h.Write(buf[:n]) //nolint:errcheck
response = fmt.Sprintf("%x -\n", h.Sum(nil)) response = fmt.Sprintf("%x -\n", h.Sum(nil))
} else { } else {
sshPath := c.getDestPath() sshPath := c.getDestPath()
@@ -146,7 +148,7 @@ func (c *sshCommand) handleHashCommands() error {
} }
response = fmt.Sprintf("%v %v\n", hash, sshPath) response = fmt.Sprintf("%v %v\n", hash, sshPath)
} }
c.connection.channel.Write([]byte(response)) c.connection.channel.Write([]byte(response)) //nolint:errcheck
c.sendExitStatus(nil) c.sendExitStatus(nil)
return nil return nil
} }
@@ -184,8 +186,9 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error {
closeCmdOnError := func() { closeCmdOnError := func() {
c.connection.Log(logger.LevelDebug, logSenderSSH, "kill cmd: %#v and close ssh channel after read or write error", c.connection.Log(logger.LevelDebug, logSenderSSH, "kill cmd: %#v and close ssh channel after read or write error",
c.connection.command) c.connection.command)
command.cmd.Process.Kill() killerr := command.cmd.Process.Kill()
c.connection.channel.Close() closerr := c.connection.channel.Close()
c.connection.Log(logger.LevelDebug, logSenderSSH, "kill cmd error: %v close channel error: %v", killerr, closerr)
} }
var once sync.Once var once sync.Once
commandResponse := make(chan bool) commandResponse := make(chan bool)
@@ -214,7 +217,7 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error {
lock: new(sync.Mutex), lock: new(sync.Mutex),
} }
addTransfer(&transfer) addTransfer(&transfer)
defer removeTransfer(&transfer) defer removeTransfer(&transfer) //nolint:errcheck
w, e := transfer.copyFromReaderToWriter(stdin, c.connection.channel, remainingQuotaSize) w, e := transfer.copyFromReaderToWriter(stdin, c.connection.channel, remainingQuotaSize)
c.connection.Log(logger.LevelDebug, logSenderSSH, "command: %#v, copy from remote command to sdtin ended, written: %v, "+ c.connection.Log(logger.LevelDebug, logSenderSSH, "command: %#v, copy from remote command to sdtin ended, written: %v, "+
"initial remaining quota: %v, err: %v", c.connection.command, w, remainingQuotaSize, e) "initial remaining quota: %v, err: %v", c.connection.command, w, remainingQuotaSize, e)
@@ -242,7 +245,7 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error {
 lock: new(sync.Mutex),
 }
 addTransfer(&transfer)
-defer removeTransfer(&transfer)
+defer removeTransfer(&transfer) //nolint:errcheck
 w, e := transfer.copyFromReaderToWriter(c.connection.channel, stdout, 0)
 c.connection.Log(logger.LevelDebug, logSenderSSH, "command: %#v, copy from sdtout to remote command ended, written: %v err: %v",
 c.connection.command, w, e)
@@ -271,7 +274,7 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error {
 lock: new(sync.Mutex),
 }
 addTransfer(&transfer)
-defer removeTransfer(&transfer)
+defer removeTransfer(&transfer) //nolint:errcheck
 w, e := transfer.copyFromReaderToWriter(c.connection.channel.Stderr(), stderr, 0)
 c.connection.Log(logger.LevelDebug, logSenderSSH, "command: %#v, copy from sdterr to remote command ended, written: %v err: %v",
 c.connection.command, w, e)
@@ -284,7 +287,7 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error {
 <-commandResponse
 err = command.cmd.Wait()
 c.sendExitStatus(err)
-c.rescanHomeDir()
+c.rescanHomeDir() //nolint:errcheck
 return err
 }
@@ -403,7 +406,7 @@ func (c *sshCommand) rescanHomeDir() error {
c.connection.Log(logger.LevelDebug, logSenderSSH, "user home dir scanned, user: %#v, dir: %#v, error: %v", c.connection.Log(logger.LevelDebug, logSenderSSH, "user home dir scanned, user: %#v, dir: %#v, error: %v",
c.connection.User.Username, c.connection.User.HomeDir, err) c.connection.User.Username, c.connection.User.HomeDir, err)
} }
RemoveQuotaScan(c.connection.User.Username) RemoveQuotaScan(c.connection.User.Username) //nolint:errcheck
} }
return err return err
} }
@@ -435,7 +438,7 @@ func (c *sshCommand) getMappedError(err error) error {
 func (c *sshCommand) sendErrorResponse(err error) error {
 errorString := fmt.Sprintf("%v: %v %v\n", c.command, c.getDestPath(), c.getMappedError(err))
-c.connection.channel.Write([]byte(errorString))
+c.connection.channel.Write([]byte(errorString)) //nolint:errcheck
 c.sendExitStatus(err)
 return err
 }
@@ -453,10 +456,10 @@ func (c *sshCommand) sendExitStatus(err error) {
 exitStatus := sshSubsystemExitStatus{
 Status: status,
 }
-c.connection.channel.SendRequest("exit-status", false, ssh.Marshal(&exitStatus))
+c.connection.channel.SendRequest("exit-status", false, ssh.Marshal(&exitStatus)) //nolint:errcheck
 c.connection.channel.Close()
 // for scp we notify single uploads/downloads
-if c.command != "scp" {
+if c.command != scpCmdName {
 metrics.SSHCommandCompleted(err)
 realPath := c.getDestPath()
 if len(realPath) > 0 {
@@ -465,7 +468,7 @@ func (c *sshCommand) sendExitStatus(err error) {
 realPath = p
 }
 }
-go executeAction(newActionNotification(c.connection.User, operationSSHCmd, realPath, "", c.command, 0, err))
+go executeAction(newActionNotification(c.connection.User, operationSSHCmd, realPath, "", c.command, 0, err)) //nolint:errcheck
 }
 }

View File

@@ -38,14 +38,14 @@ type Transfer struct {
 connectionID string
 transferType int
 lastActivity time.Time
-isNewFile bool
 protocol string
 transferError error
-isFinished bool
 minWriteOffset int64
 expectedSize int64
 initialSize int64
 lock *sync.Mutex
+isNewFile bool
+isFinished bool
 }

 // TransferError is called if there is an unexpected error.
@@ -126,6 +126,7 @@ func (t *Transfer) Close() error {
 return errTransferClosed
 }
 err := t.closeIO()
+defer removeTransfer(t) //nolint:errcheck
 t.isFinished = true
 numFiles := 0
 if t.isNewFile {
@@ -150,10 +151,10 @@ func (t *Transfer) Close() error {
 elapsed := time.Since(t.start).Nanoseconds() / 1000000
 if t.transferType == transferDownload {
 logger.TransferLog(downloadLogSender, t.path, elapsed, t.bytesSent, t.user.Username, t.connectionID, t.protocol)
-go executeAction(newActionNotification(t.user, operationDownload, t.path, "", "", t.bytesSent, t.transferError))
+go executeAction(newActionNotification(t.user, operationDownload, t.path, "", "", t.bytesSent, t.transferError)) //nolint:errcheck
 } else {
 logger.TransferLog(uploadLogSender, t.path, elapsed, t.bytesReceived, t.user.Username, t.connectionID, t.protocol)
-go executeAction(newActionNotification(t.user, operationUpload, t.path, "", "", t.bytesReceived+t.minWriteOffset,
+go executeAction(newActionNotification(t.user, operationUpload, t.path, "", "", t.bytesReceived+t.minWriteOffset, //nolint:errcheck
 t.transferError))
 }
 if t.transferError != nil {
@@ -162,7 +163,6 @@ func (t *Transfer) Close() error {
 err = t.transferError
 }
 }
-removeTransfer(t)
 t.updateQuota(numFiles)
 return err
 }
@@ -185,7 +185,7 @@ func (t *Transfer) updateQuota(numFiles int) bool {
 return false
 }
 if t.transferType == transferUpload && (numFiles != 0 || t.bytesReceived > 0) {
-dataprovider.UpdateUserQuota(dataProvider, t.user, numFiles, t.bytesReceived-t.initialSize, false)
+dataprovider.UpdateUserQuota(dataProvider, t.user, numFiles, t.bytesReceived-t.initialSize, false) //nolint:errcheck
 return true
 }
 return false

View File

@@ -11,7 +11,6 @@ type FileInfo struct {
 sizeInBytes int64
 modTime time.Time
 mode os.FileMode
-sys interface{}
 }

 // NewFileInfo creates file info.

View File

@@ -2,9 +2,10 @@
 package vfs

-import "syscall"
-
-import "os"
+import (
+"os"
+"syscall"
+)

 var (
 defaultUID, defaultGID int

View File

@@ -159,7 +159,7 @@ func (fs GCSFs) Open(name string) (*os.File, *pipeat.PipeReaderAt, func(), error
 defer cancelFn()
 defer objectReader.Close()
 n, err := io.Copy(w, objectReader)
-w.CloseWithError(err)
+w.CloseWithError(err) //nolint:errcheck // the returned error is always null
 fsLog(fs, logger.LevelDebug, "download completed, path: %#v size: %v, err: %v", name, n, err)
 metrics.GCSTransferCompleted(n, 1, err)
 }()
@@ -183,7 +183,7 @@ func (fs GCSFs) Create(name string, flag int) (*os.File, *pipeat.PipeWriterAt, f
 defer cancelFn()
 defer objectWriter.Close()
 n, err := io.Copy(objectWriter, r)
-r.CloseWithError(err)
+r.CloseWithError(err) //nolint:errcheck // the returned error is always null
 fsLog(fs, logger.LevelDebug, "upload completed, path: %#v, readed bytes: %v, err: %v", name, n, err)
 metrics.GCSTransferCompleted(n, 0, err)
 }()

View File

@@ -71,7 +71,6 @@ func (OsFs) Create(name string, flag int) (*os.File, *pipeat.PipeWriterAt, func(
 f, err = os.Create(name)
 } else {
 f, err = os.OpenFile(name, flag, 0666)
 }
-
 return f, nil, nil, err
 }

View File

@@ -195,7 +195,7 @@ func (fs S3Fs) Open(name string) (*os.File, *pipeat.PipeReaderAt, func(), error)
 Bucket: aws.String(fs.config.Bucket),
 Key: aws.String(key),
 })
-w.CloseWithError(err)
+w.CloseWithError(err) //nolint:errcheck // the returned error is always null
 fsLog(fs, logger.LevelDebug, "download completed, path: %#v size: %v, err: %v", name, n, err)
 metrics.S3TransferCompleted(n, 1, err)
 }()
@@ -222,7 +222,7 @@ func (fs S3Fs) Create(name string, flag int) (*os.File, *pipeat.PipeWriterAt, fu
 u.Concurrency = fs.config.UploadConcurrency
 u.PartSize = fs.config.UploadPartSize
 })
-r.CloseWithError(err)
+r.CloseWithError(err) //nolint:errcheck // the returned error is always null
 fsLog(fs, logger.LevelDebug, "upload completed, path: %#v, response: %v, readed bytes: %v, err: %+v",
 name, response, r.GetReadedBytes(), err)
 metrics.S3TransferCompleted(r.GetReadedBytes(), 0, err)
@@ -529,13 +529,3 @@ func (fs *S3Fs) checkIfBucketExists() error {
 metrics.S3HeadBucketCompleted(err)
 return err
 }
-
-func (fs *S3Fs) getObjectDetails(key string) (*s3.HeadObjectOutput, error) {
-ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout))
-defer cancelFn()
-input := &s3.HeadObjectInput{
-Bucket: aws.String(fs.config.Bucket),
-Key: aws.String(key),
-}
-return fs.svc.HeadObjectWithContext(ctx, input)
-}