diff --git a/README.md b/README.md index 7433be43..dd6dcd2c 100644 --- a/README.md +++ b/README.md @@ -154,6 +154,8 @@ You can also reset your provider by using the `resetprovider` sub-command. Take sftpgo resetprovider --help ``` +:warning: Please note that some data providers (e.g. MySQL and CockroachDB) do not support schema changes within a transaction, this means that you may end up with an inconsistent schema if migrations are forcibly aborted or if they are run concurrently by multiple instances. + ## Create the first admin To start using SFTPGo you need to create an admin user, you can do it in several ways: diff --git a/cmd/startsubsys.go b/cmd/startsubsys.go index 21037899..1eda94c0 100644 --- a/cmd/startsubsys.go +++ b/cmd/startsubsys.go @@ -99,7 +99,6 @@ Command-line flags should be specified in the Subsystem declaration. dataProviderConf.Driver, dataprovider.MemoryDataProviderName) dataProviderConf.Driver = dataprovider.MemoryDataProviderName dataProviderConf.Name = "" - dataProviderConf.PreferDatabaseCredentials = true } config.SetProviderConf(dataProviderConf) err = dataprovider.Initialize(dataProviderConf, configDir, false) diff --git a/config/config.go b/config/config.go index de901ae7..d0dff8b1 100644 --- a/config/config.go +++ b/config/config.go @@ -302,14 +302,13 @@ func Init() { MinEntropy: 0, }, }, - PasswordCaching: true, - UpdateMode: 0, - PreferDatabaseCredentials: true, - DelayedQuotaUpdate: 0, - CreateDefaultAdmin: false, - NamingRules: 0, - IsShared: 0, - BackupsPath: "backups", + PasswordCaching: true, + UpdateMode: 0, + DelayedQuotaUpdate: 0, + CreateDefaultAdmin: false, + NamingRules: 0, + IsShared: 0, + BackupsPath: "backups", AutoBackup: dataprovider.AutoBackup{ Enabled: true, Hour: "0", @@ -1610,7 +1609,6 @@ func setViperDefaults() { viper.SetDefault("data_provider.external_auth_hook", globalConf.ProviderConf.ExternalAuthHook) viper.SetDefault("data_provider.external_auth_scope", globalConf.ProviderConf.ExternalAuthScope) viper.SetDefault("data_provider.credentials_path", globalConf.ProviderConf.CredentialsPath) - viper.SetDefault("data_provider.prefer_database_credentials", globalConf.ProviderConf.PreferDatabaseCredentials) viper.SetDefault("data_provider.pre_login_hook", globalConf.ProviderConf.PreLoginHook) viper.SetDefault("data_provider.post_login_hook", globalConf.ProviderConf.PostLoginHook) viper.SetDefault("data_provider.post_login_scope", globalConf.ProviderConf.PostLoginScope) diff --git a/dataprovider/bolt.go b/dataprovider/bolt.go index cb5b12b6..3bcdc5c1 100644 --- a/dataprovider/bolt.go +++ b/dataprovider/bolt.go @@ -20,7 +20,7 @@ import ( ) const ( - boltDatabaseVersion = 17 + boltDatabaseVersion = 18 ) var ( @@ -703,10 +703,6 @@ func (p *BoltProvider) dumpUsers() ([]User, error) { if err != nil { return err } - err = addCredentialsToUser(&user) - if err != nil { - return err - } users = append(users, user) } return err @@ -1888,8 +1884,13 @@ func (p *BoltProvider) migrateDatabase() error { providerLog(logger.LevelError, "%v", err) logger.ErrorToConsole("%v", err) return err - case version == 15, version == 16: - return updateBoltDatabaseVersion(p.dbHandle, 17) + case version == 15, version == 16, version == 17: + logger.InfoToConsole(fmt.Sprintf("updating database version: %v -> 18", version)) + providerLog(logger.LevelInfo, "updating database version: %v -> 18", version) + if err = importGCSCredentials(); err != nil { + return err + } + return updateBoltDatabaseVersion(p.dbHandle, 18) default: if version > boltDatabaseVersion { providerLog(logger.LevelError, "database version %v is newer than the supported one: %v", version, @@ -1911,7 +1912,7 @@ func (p *BoltProvider) revertDatabase(targetVersion int) error { return errors.New("current version match target version, nothing to do") } switch dbVersion.Version { - case 16, 17: + case 16, 17, 18: return updateBoltDatabaseVersion(p.dbHandle, 15) default: return fmt.Errorf("database version not handled: %v", dbVersion.Version) diff --git a/dataprovider/dataprovider.go b/dataprovider/dataprovider.go index 2c628fbb..9ee61bd4 100644 --- a/dataprovider/dataprovider.go +++ b/dataprovider/dataprovider.go @@ -364,10 +364,6 @@ type Config struct { UpdateMode int `json:"update_mode" mapstructure:"update_mode"` // PasswordHashing defines the configuration for password hashing PasswordHashing PasswordHashing `json:"password_hashing" mapstructure:"password_hashing"` - // PreferDatabaseCredentials indicates whether credential files (currently used for Google - // Cloud Storage) should be stored in the database instead of in the directory specified by - // CredentialsPath. - PreferDatabaseCredentials bool `json:"prefer_database_credentials" mapstructure:"prefer_database_credentials"` // PasswordValidation defines the password validation rules PasswordValidation PasswordValidation `json:"password_validation" mapstructure:"password_validation"` // Verifying argon2 passwords has a high memory and computational cost, @@ -2350,47 +2346,6 @@ func validateBaseFilters(filters *sdk.BaseUserFilters) error { return validateFiltersPatternExtensions(filters) } -func saveGCSCredentials(fsConfig *vfs.Filesystem, helper vfs.ValidatorHelper) error { - if fsConfig.Provider != sdk.GCSFilesystemProvider { - return nil - } - if fsConfig.GCSConfig.Credentials.GetPayload() == "" { - return nil - } - if config.PreferDatabaseCredentials { - if fsConfig.GCSConfig.Credentials.IsPlain() { - fsConfig.GCSConfig.Credentials.SetAdditionalData(helper.GetEncryptionAdditionalData()) - err := fsConfig.GCSConfig.Credentials.Encrypt() - if err != nil { - return err - } - } - return nil - } - if fsConfig.GCSConfig.Credentials.IsPlain() { - fsConfig.GCSConfig.Credentials.SetAdditionalData(helper.GetEncryptionAdditionalData()) - err := fsConfig.GCSConfig.Credentials.Encrypt() - if err != nil { - return util.NewValidationError(fmt.Sprintf("could not encrypt GCS credentials: %v", err)) - } - } - creds, err := json.Marshal(fsConfig.GCSConfig.Credentials) - if err != nil { - return util.NewValidationError(fmt.Sprintf("could not marshal GCS credentials: %v", err)) - } - credentialsFilePath := helper.GetGCSCredentialsFilePath() - err = os.MkdirAll(filepath.Dir(credentialsFilePath), 0700) - if err != nil { - return util.NewValidationError(fmt.Sprintf("could not create GCS credentials dir: %v", err)) - } - err = os.WriteFile(credentialsFilePath, creds, 0600) - if err != nil { - return util.NewValidationError(fmt.Sprintf("could not save GCS credentials: %v", err)) - } - fsConfig.GCSConfig.Credentials = kms.NewEmptySecret() - return nil -} - func validateBaseParams(user *User) error { if user.Username == "" { return util.NewValidationError("username is mandatory") @@ -2426,7 +2381,7 @@ func validateBaseParams(user *User) error { user.UploadDataTransfer = 0 user.DownloadDataTransfer = 0 } - return user.FsConfig.Validate(user) + return user.FsConfig.Validate(user.GetEncryptionAdditionalData()) } func hashPlainPassword(plainPwd string) (string, error) { @@ -2482,10 +2437,7 @@ func ValidateFolder(folder *vfs.BaseVirtualFolder) error { if folder.HasRedactedSecret() { return errors.New("cannot save a folder with a redacted secret") } - if err := folder.FsConfig.Validate(folder); err != nil { - return err - } - return saveGCSCredentials(&folder.FsConfig, folder) + return folder.FsConfig.Validate(folder.GetEncryptionAdditionalData()) } // ValidateUser returns an error if the user is not valid @@ -2532,7 +2484,7 @@ func ValidateUser(user *User) error { if user.Filters.TOTPConfig.Enabled && util.IsStringInSlice(sdk.WebClientMFADisabled, user.Filters.WebClient) { return util.NewValidationError("two-factor authentication cannot be disabled for a user with an active configuration") } - return saveGCSCredentials(&user.FsConfig, user) + return nil } func isPasswordOK(user *User, password string) (bool, error) { @@ -2781,54 +2733,6 @@ func comparePbkdf2PasswordAndHash(password, hashedPassword string) (bool, error) return subtle.ConstantTimeCompare(df, expected) == 1, nil } -func addCredentialsToUser(user *User) error { - if err := addFolderCredentialsToUser(user); err != nil { - return err - } - if user.FsConfig.Provider != sdk.GCSFilesystemProvider { - return nil - } - if user.FsConfig.GCSConfig.AutomaticCredentials > 0 { - return nil - } - - // Don't read from file if credentials have already been set - if user.FsConfig.GCSConfig.Credentials.IsValid() { - return nil - } - - cred, err := os.ReadFile(user.GetGCSCredentialsFilePath()) - if err != nil { - return err - } - return json.Unmarshal(cred, &user.FsConfig.GCSConfig.Credentials) -} - -func addFolderCredentialsToUser(user *User) error { - for idx := range user.VirtualFolders { - f := &user.VirtualFolders[idx] - if f.FsConfig.Provider != sdk.GCSFilesystemProvider { - continue - } - if f.FsConfig.GCSConfig.AutomaticCredentials > 0 { - continue - } - // Don't read from file if credentials have already been set - if f.FsConfig.GCSConfig.Credentials.IsValid() { - continue - } - cred, err := os.ReadFile(f.GetGCSCredentialsFilePath()) - if err != nil { - return err - } - err = json.Unmarshal(cred, f.FsConfig.GCSConfig.Credentials) - if err != nil { - return err - } - } - return nil -} - func getSSLMode() string { if config.Driver == PGSQLDataProviderName || config.Driver == CockroachDataProviderName { switch config.SSLMode { @@ -3674,6 +3578,89 @@ func isLastActivityRecent(lastActivity int64, minDelay time.Duration) bool { return diff < minDelay } +func addGCSCredentialsToFolder(folder *vfs.BaseVirtualFolder) (bool, error) { + if folder.FsConfig.Provider != sdk.GCSFilesystemProvider { + return false, nil + } + if folder.FsConfig.GCSConfig.AutomaticCredentials > 0 { + return false, nil + } + if folder.FsConfig.GCSConfig.Credentials.IsValid() { + return false, nil + } + cred, err := os.ReadFile(folder.GetGCSCredentialsFilePath()) + if err != nil { + return false, err + } + err = json.Unmarshal(cred, &folder.FsConfig.GCSConfig.Credentials) + if err != nil { + return false, err + } + return true, nil +} + +func addGCSCredentialsToUser(user *User) (bool, error) { + if user.FsConfig.Provider != sdk.GCSFilesystemProvider { + return false, nil + } + if user.FsConfig.GCSConfig.AutomaticCredentials > 0 { + return false, nil + } + if user.FsConfig.GCSConfig.Credentials.IsValid() { + return false, nil + } + cred, err := os.ReadFile(user.GetGCSCredentialsFilePath()) + if err != nil { + return false, err + } + err = json.Unmarshal(cred, &user.FsConfig.GCSConfig.Credentials) + if err != nil { + return false, err + } + return true, nil +} + +func importGCSCredentials() error { + folders, err := provider.dumpFolders() + if err != nil { + return fmt.Errorf("unable to get folders: %w", err) + } + for idx := range folders { + folder := &folders[idx] + added, err := addGCSCredentialsToFolder(folder) + if err != nil { + return fmt.Errorf("unable to add GCS credentials to folder %#v: %w", folder.Name, err) + } + if added { + logger.InfoToConsole("importing GCS credentials for folder %#v", folder.Name) + providerLog(logger.LevelInfo, "importing GCS credentials for folder %#v", folder.Name) + if err = provider.updateFolder(folder); err != nil { + return fmt.Errorf("unable to update folder %#v: %w", folder.Name, err) + } + } + } + + users, err := provider.dumpUsers() + if err != nil { + return fmt.Errorf("unable to get users: %w", err) + } + for idx := range users { + user := &users[idx] + added, err := addGCSCredentialsToUser(user) + if err != nil { + return fmt.Errorf("unable to add GCS credentials to user %#v: %w", user.Username, err) + } + if added { + logger.InfoToConsole("importing GCS credentials for user %#v", user.Username) + providerLog(logger.LevelInfo, "importing GCS credentials for user %#v", user.Username) + if err = provider.updateUser(user); err != nil { + return fmt.Errorf("unable to update user %#v: %w", user.Username, err) + } + } + } + return nil +} + func getConfigPath(name, configDir string) string { if !util.IsFileInputValid(name) { return "" diff --git a/dataprovider/group.go b/dataprovider/group.go index 6e7b7c44..20223fb8 100644 --- a/dataprovider/group.go +++ b/dataprovider/group.go @@ -106,11 +106,6 @@ func (g *Group) GetEncryptionAdditionalData() string { return fmt.Sprintf("group_%v", g.Name) } -// GetGCSCredentialsFilePath returns the path for GCS credentials -func (g *Group) GetGCSCredentialsFilePath() string { - return filepath.Join(credentialsDirPath, "groups", fmt.Sprintf("%v_gcs_credentials.json", g.Name)) -} - // HasRedactedSecret returns true if the user has a redacted secret func (g *Group) hasRedactedSecret() bool { for idx := range g.VirtualFolders { @@ -150,10 +145,7 @@ func (g *Group) validateUserSettings() error { g.UserSettings.HomeDir)) } } - if err := g.UserSettings.FsConfig.Validate(g); err != nil { - return err - } - if err := saveGCSCredentials(&g.UserSettings.FsConfig, g); err != nil { + if err := g.UserSettings.FsConfig.Validate(g.GetEncryptionAdditionalData()); err != nil { return err } if g.UserSettings.TotalDataTransfer > 0 { diff --git a/dataprovider/memory.go b/dataprovider/memory.go index ca413601..883c2e3a 100644 --- a/dataprovider/memory.go +++ b/dataprovider/memory.go @@ -421,10 +421,6 @@ func (p *MemoryProvider) dumpUsers() ([]User, error) { u := p.dbHandle.users[username] user := u.getACopy() p.addVirtualFoldersToUser(&user) - err = addCredentialsToUser(&user) - if err != nil { - return users, err - } users = append(users, user) } return users, err diff --git a/dataprovider/mysql.go b/dataprovider/mysql.go index 8759c0f2..c6c69744 100644 --- a/dataprovider/mysql.go +++ b/dataprovider/mysql.go @@ -572,6 +572,8 @@ func (p *MySQLProvider) migrateDatabase() error { return updateMySQLDatabaseFromV15(p.dbHandle) case version == 16: return updateMySQLDatabaseFromV16(p.dbHandle) + case version == 17: + return updateMySQLDatabaseFromV17(p.dbHandle) default: if version > sqlDatabaseVersion { providerLog(logger.LevelError, "database version %v is newer than the supported one: %v", version, @@ -598,6 +600,8 @@ func (p *MySQLProvider) revertDatabase(targetVersion int) error { return downgradeMySQLDatabaseFromV16(p.dbHandle) case 17: return downgradeMySQLDatabaseFromV17(p.dbHandle) + case 18: + return downgradeMySQLDatabaseFromV18(p.dbHandle) default: return fmt.Errorf("database version not handled: %v", dbVersion.Version) } @@ -616,7 +620,18 @@ func updateMySQLDatabaseFromV15(dbHandle *sql.DB) error { } func updateMySQLDatabaseFromV16(dbHandle *sql.DB) error { - return updateMySQLDatabaseFrom16To17(dbHandle) + if err := updateMySQLDatabaseFrom16To17(dbHandle); err != nil { + return err + } + return updateMySQLDatabaseFromV17(dbHandle) +} + +func updateMySQLDatabaseFromV17(dbHandle *sql.DB) error { + return updateMySQLDatabaseFrom17To18(dbHandle) +} + +func downgradeMySQLDatabaseFromV16(dbHandle *sql.DB) error { + return downgradeMySQLDatabaseFrom16To15(dbHandle) } func downgradeMySQLDatabaseFromV17(dbHandle *sql.DB) error { @@ -626,8 +641,11 @@ func downgradeMySQLDatabaseFromV17(dbHandle *sql.DB) error { return downgradeMySQLDatabaseFromV16(dbHandle) } -func downgradeMySQLDatabaseFromV16(dbHandle *sql.DB) error { - return downgradeMySQLDatabaseFrom16To15(dbHandle) +func downgradeMySQLDatabaseFromV18(dbHandle *sql.DB) error { + if err := downgradeMySQLDatabaseFrom18To17(dbHandle); err != nil { + return err + } + return downgradeMySQLDatabaseFromV17(dbHandle) } func updateMySQLDatabaseFrom15To16(dbHandle *sql.DB) error { @@ -653,6 +671,15 @@ func updateMySQLDatabaseFrom16To17(dbHandle *sql.DB) error { return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 17) } +func updateMySQLDatabaseFrom17To18(dbHandle *sql.DB) error { + logger.InfoToConsole("updating database version: 17 -> 18") + providerLog(logger.LevelInfo, "updating database version: 17 -> 18") + if err := importGCSCredentials(); err != nil { + return err + } + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, nil, 18) +} + func downgradeMySQLDatabaseFrom16To15(dbHandle *sql.DB) error { logger.InfoToConsole("downgrading database version: 16 -> 15") providerLog(logger.LevelInfo, "downgrading database version: 16 -> 15") @@ -674,3 +701,9 @@ func downgradeMySQLDatabaseFrom17To16(dbHandle *sql.DB) error { sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix) return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 16) } + +func downgradeMySQLDatabaseFrom18To17(dbHandle *sql.DB) error { + logger.InfoToConsole("downgrading database version: 18 -> 17") + providerLog(logger.LevelInfo, "downgrading database version: 18 -> 17") + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, nil, 17) +} diff --git a/dataprovider/pgsql.go b/dataprovider/pgsql.go index ae280a33..026e5c0c 100644 --- a/dataprovider/pgsql.go +++ b/dataprovider/pgsql.go @@ -547,6 +547,8 @@ func (p *PGSQLProvider) migrateDatabase() error { return updatePGSQLDatabaseFromV15(p.dbHandle) case version == 16: return updatePGSQLDatabaseFromV16(p.dbHandle) + case version == 17: + return updatePGSQLDatabaseFromV17(p.dbHandle) default: if version > sqlDatabaseVersion { providerLog(logger.LevelError, "database version %v is newer than the supported one: %v", version, @@ -573,6 +575,8 @@ func (p *PGSQLProvider) revertDatabase(targetVersion int) error { return downgradePGSQLDatabaseFromV16(p.dbHandle) case 17: return downgradePGSQLDatabaseFromV17(p.dbHandle) + case 18: + return downgradePGSQLDatabaseFromV18(p.dbHandle) default: return fmt.Errorf("database version not handled: %v", dbVersion.Version) } @@ -591,7 +595,18 @@ func updatePGSQLDatabaseFromV15(dbHandle *sql.DB) error { } func updatePGSQLDatabaseFromV16(dbHandle *sql.DB) error { - return updatePGSQLDatabaseFrom16To17(dbHandle) + if err := updatePGSQLDatabaseFrom16To17(dbHandle); err != nil { + return err + } + return updatePGSQLDatabaseFromV17(dbHandle) +} + +func updatePGSQLDatabaseFromV17(dbHandle *sql.DB) error { + return updatePGSQLDatabaseFrom17To18(dbHandle) +} + +func downgradePGSQLDatabaseFromV16(dbHandle *sql.DB) error { + return downgradePGSQLDatabaseFrom16To15(dbHandle) } func downgradePGSQLDatabaseFromV17(dbHandle *sql.DB) error { @@ -601,8 +616,11 @@ func downgradePGSQLDatabaseFromV17(dbHandle *sql.DB) error { return downgradePGSQLDatabaseFromV16(dbHandle) } -func downgradePGSQLDatabaseFromV16(dbHandle *sql.DB) error { - return downgradePGSQLDatabaseFrom16To15(dbHandle) +func downgradePGSQLDatabaseFromV18(dbHandle *sql.DB) error { + if err := downgradePGSQLDatabaseFrom18To17(dbHandle); err != nil { + return err + } + return downgradePGSQLDatabaseFromV17(dbHandle) } func updatePGSQLDatabaseFrom15To16(dbHandle *sql.DB) error { @@ -649,6 +667,15 @@ func updatePGSQLDatabaseFrom16To17(dbHandle *sql.DB) error { return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 17) } +func updatePGSQLDatabaseFrom17To18(dbHandle *sql.DB) error { + logger.InfoToConsole("updating database version: 17 -> 18") + providerLog(logger.LevelInfo, "updating database version: 17 -> 18") + if err := importGCSCredentials(); err != nil { + return err + } + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, nil, 18) +} + func downgradePGSQLDatabaseFrom16To15(dbHandle *sql.DB) error { logger.InfoToConsole("downgrading database version: 16 -> 15") providerLog(logger.LevelInfo, "downgrading database version: 16 -> 15") @@ -675,3 +702,9 @@ func downgradePGSQLDatabaseFrom17To16(dbHandle *sql.DB) error { sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix) return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 16) } + +func downgradePGSQLDatabaseFrom18To17(dbHandle *sql.DB) error { + logger.InfoToConsole("downgrading database version: 18 -> 17") + providerLog(logger.LevelInfo, "downgrading database version: 18 -> 17") + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, nil, 17) +} diff --git a/dataprovider/sqlcommon.go b/dataprovider/sqlcommon.go index 317ffacd..d3e712a6 100644 --- a/dataprovider/sqlcommon.go +++ b/dataprovider/sqlcommon.go @@ -7,6 +7,7 @@ import ( "encoding/json" "errors" "fmt" + "runtime/debug" "strings" "time" @@ -19,7 +20,7 @@ import ( ) const ( - sqlDatabaseVersion = 17 + sqlDatabaseVersion = 18 defaultSQLQueryTimeout = 10 * time.Second longSQLQueryTimeout = 60 * time.Second ) @@ -913,10 +914,19 @@ func sqlCommonValidateUserAndPubKey(username string, pubKey []byte, isSSHCert bo return checkUserAndPubKey(&user, pubKey, isSSHCert) } -func sqlCommonCheckAvailability(dbHandle *sql.DB) error { +func sqlCommonCheckAvailability(dbHandle *sql.DB) (err error) { + defer func() { + if r := recover(); r != nil { + providerLog(logger.LevelError, "panic in check provider availability, stack trace: %v", string(debug.Stack())) + err = errors.New("unable to check provider status") + } + }() + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() - return dbHandle.PingContext(ctx) + + err = dbHandle.PingContext(ctx) + return } func sqlCommonUpdateTransferQuota(username string, uploadSize, downloadSize int64, reset bool, dbHandle *sql.DB) error { @@ -1219,10 +1229,6 @@ func sqlCommonDumpUsers(dbHandle sqlQuerier) ([]User, error) { if err != nil { return users, err } - err = addCredentialsToUser(&u) - if err != nil { - return users, err - } users = append(users, u) } err = rows.Err() diff --git a/dataprovider/sqlite.go b/dataprovider/sqlite.go index b7f44fd3..554388fa 100644 --- a/dataprovider/sqlite.go +++ b/dataprovider/sqlite.go @@ -519,6 +519,8 @@ func (p *SQLiteProvider) migrateDatabase() error { return updateSQLiteDatabaseFromV15(p.dbHandle) case version == 16: return updateSQLiteDatabaseFromV16(p.dbHandle) + case version == 17: + return updateSQLiteDatabaseFromV17(p.dbHandle) default: if version > sqlDatabaseVersion { providerLog(logger.LevelError, "database version %v is newer than the supported one: %v", version, @@ -545,6 +547,8 @@ func (p *SQLiteProvider) revertDatabase(targetVersion int) error { return downgradeSQLiteDatabaseFromV16(p.dbHandle) case 17: return downgradeSQLiteDatabaseFromV17(p.dbHandle) + case 18: + return downgradeSQLiteDatabaseFromV18(p.dbHandle) default: return fmt.Errorf("database version not handled: %v", dbVersion.Version) } @@ -563,7 +567,14 @@ func updateSQLiteDatabaseFromV15(dbHandle *sql.DB) error { } func updateSQLiteDatabaseFromV16(dbHandle *sql.DB) error { - return updateSQLiteDatabaseFrom16To17(dbHandle) + if err := updateSQLiteDatabaseFrom16To17(dbHandle); err != nil { + return err + } + return updateSQLiteDatabaseFromV17(dbHandle) +} + +func updateSQLiteDatabaseFromV17(dbHandle *sql.DB) error { + return updateSQLiteDatabaseFrom17To18(dbHandle) } func downgradeSQLiteDatabaseFromV16(dbHandle *sql.DB) error { @@ -577,6 +588,13 @@ func downgradeSQLiteDatabaseFromV17(dbHandle *sql.DB) error { return downgradeSQLiteDatabaseFromV16(dbHandle) } +func downgradeSQLiteDatabaseFromV18(dbHandle *sql.DB) error { + if err := downgradeSQLiteDatabaseFrom18To17(dbHandle); err != nil { + return err + } + return downgradeSQLiteDatabaseFromV17(dbHandle) +} + func updateSQLiteDatabaseFrom15To16(dbHandle *sql.DB) error { logger.InfoToConsole("updating database version: 15 -> 16") providerLog(logger.LevelInfo, "updating database version: 15 -> 16") @@ -606,6 +624,15 @@ func updateSQLiteDatabaseFrom16To17(dbHandle *sql.DB) error { return setPragmaFK(dbHandle, "ON") } +func updateSQLiteDatabaseFrom17To18(dbHandle *sql.DB) error { + logger.InfoToConsole("updating database version: 17 -> 18") + providerLog(logger.LevelInfo, "updating database version: 17 -> 18") + if err := importGCSCredentials(); err != nil { + return err + } + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, nil, 18) +} + func downgradeSQLiteDatabaseFrom16To15(dbHandle *sql.DB) error { logger.InfoToConsole("downgrading database version: 16 -> 15") providerLog(logger.LevelInfo, "downgrading database version: 16 -> 15") @@ -634,6 +661,12 @@ func downgradeSQLiteDatabaseFrom17To16(dbHandle *sql.DB) error { return setPragmaFK(dbHandle, "ON") } +func downgradeSQLiteDatabaseFrom18To17(dbHandle *sql.DB) error { + logger.InfoToConsole("downgrading database version: 18 -> 17") + providerLog(logger.LevelInfo, "downgrading database version: 18 -> 17") + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, nil, 17) +} + func setPragmaFK(dbHandle *sql.DB, value string) error { ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout) defer cancel() diff --git a/dataprovider/user.go b/dataprovider/user.go index 11ed53ef..294e334e 100644 --- a/dataprovider/user.go +++ b/dataprovider/user.go @@ -141,9 +141,7 @@ func (u *User) getRootFs(connectionID string) (fs vfs.Fs, err error) { case sdk.S3FilesystemProvider: return vfs.NewS3Fs(connectionID, u.GetHomeDir(), "", u.FsConfig.S3Config) case sdk.GCSFilesystemProvider: - config := u.FsConfig.GCSConfig - config.CredentialFile = u.GetGCSCredentialsFilePath() - return vfs.NewGCSFs(connectionID, u.GetHomeDir(), "", config) + return vfs.NewGCSFs(connectionID, u.GetHomeDir(), "", u.FsConfig.GCSConfig) case sdk.AzureBlobFilesystemProvider: return vfs.NewAzBlobFs(connectionID, u.GetHomeDir(), "", u.FsConfig.AzBlobConfig) case sdk.CryptedFilesystemProvider: diff --git a/docs/full-configuration.md b/docs/full-configuration.md index 0adbd360..9ba92f27 100644 --- a/docs/full-configuration.md +++ b/docs/full-configuration.md @@ -202,7 +202,6 @@ The configuration file contains the following sections: - `external_auth_hook`, string. Absolute path to an external program or an HTTP URL to invoke for users authentication. See [External Authentication](./external-auth.md) for more details. Leave empty to disable. - `external_auth_scope`, integer. 0 means all supported authentication scopes (passwords, public keys and keyboard interactive). 1 means passwords only. 2 means public keys only. 4 means key keyboard interactive only. 8 means TLS certificate. The flags can be combined, for example 6 means public keys and keyboard interactive - `credentials_path`, string. It defines the directory for storing user provided credential files such as Google Cloud Storage credentials. This can be an absolute path or a path relative to the config dir - - `prefer_database_credentials`, boolean. When `true`, users' Google Cloud Storage credentials will be written to the data provider instead of disk, though pre-existing credentials on disk will be used as a fallback. When `false`, they will be written to the directory specified by `credentials_path`. :warning: Deprecation warning: this setting is deprecated and it will be removed in future versions, we'll use `true` as default and will remove `prefer_database_credentials` and `credentials_path`. - `pre_login_hook`, string. Absolute path to an external program or an HTTP URL to invoke to modify user details just before the login. See [Dynamic user modification](./dynamic-user-mod.md) for more details. Leave empty to disable. - `post_login_hook`, string. Absolute path to an external program or an HTTP URL to invoke to notify a successful or failed login. See [Post-login hook](./post-login-hook.md) for more details. Leave empty to disable. - `post_login_scope`, defines the scope for the post-login hook. 0 means notify both failed and successful logins. 1 means notify failed logins. 2 means notify successful logins. diff --git a/ftpd/ftpd_test.go b/ftpd/ftpd_test.go index 89cc4fe6..75c237ce 100644 --- a/ftpd/ftpd_test.go +++ b/ftpd/ftpd_test.go @@ -1813,24 +1813,6 @@ func TestLoginWithDatabaseCredentials(t *testing.T) { u.FsConfig.Provider = sdk.GCSFilesystemProvider u.FsConfig.GCSConfig.Bucket = "test" u.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret(`{ "type": "service_account" }`) - - providerConf := config.GetProviderConf() - providerConf.PreferDatabaseCredentials = true - credentialsFile := filepath.Join(providerConf.CredentialsPath, fmt.Sprintf("%v_gcs_credentials.json", u.Username)) - if !filepath.IsAbs(credentialsFile) { - credentialsFile = filepath.Join(configDir, credentialsFile) - } - - assert.NoError(t, dataprovider.Close()) - - err := dataprovider.Initialize(providerConf, configDir, true) - assert.NoError(t, err) - - if _, err = os.Stat(credentialsFile); err == nil { - // remove the credentials file - assert.NoError(t, os.Remove(credentialsFile)) - } - user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.GCSConfig.Credentials.GetStatus()) @@ -1838,8 +1820,6 @@ func TestLoginWithDatabaseCredentials(t *testing.T) { assert.Empty(t, user.FsConfig.GCSConfig.Credentials.GetAdditionalData()) assert.Empty(t, user.FsConfig.GCSConfig.Credentials.GetKey()) - assert.NoFileExists(t, credentialsFile) - client, err := getFTPClient(user, false, nil) if assert.NoError(t, err) { err = client.Quit() @@ -1850,23 +1830,9 @@ func TestLoginWithDatabaseCredentials(t *testing.T) { assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) - - assert.NoError(t, dataprovider.Close()) - assert.NoError(t, config.LoadConfig(configDir, "")) - providerConf = config.GetProviderConf() - assert.NoError(t, dataprovider.Initialize(providerConf, configDir, true)) } func TestLoginInvalidFs(t *testing.T) { - err := dataprovider.Close() - assert.NoError(t, err) - err = config.LoadConfig(configDir, "") - assert.NoError(t, err) - providerConf := config.GetProviderConf() - providerConf.PreferDatabaseCredentials = false - err = dataprovider.Initialize(providerConf, configDir, true) - assert.NoError(t, err) - u := getTestUser() u.FsConfig.Provider = sdk.GCSFilesystemProvider u.FsConfig.GCSConfig.Bucket = "test" @@ -1874,16 +1840,6 @@ func TestLoginInvalidFs(t *testing.T) { user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) - providerConf = config.GetProviderConf() - credentialsFile := filepath.Join(providerConf.CredentialsPath, fmt.Sprintf("%v_gcs_credentials.json", u.Username)) - if !filepath.IsAbs(credentialsFile) { - credentialsFile = filepath.Join(configDir, credentialsFile) - } - - // now remove the credentials file so the filesystem creation will fail - err = os.Remove(credentialsFile) - assert.NoError(t, err) - client, err := getFTPClient(user, false, nil) if !assert.Error(t, err) { err = client.Quit() @@ -1893,14 +1849,6 @@ func TestLoginInvalidFs(t *testing.T) { assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) - - err = dataprovider.Close() - assert.NoError(t, err) - err = config.LoadConfig(configDir, "") - assert.NoError(t, err) - providerConf = config.GetProviderConf() - err = dataprovider.Initialize(providerConf, configDir, true) - assert.NoError(t, err) } func TestClientClose(t *testing.T) { diff --git a/go.mod b/go.mod index 911469f9..aa78be9f 100644 --- a/go.mod +++ b/go.mod @@ -12,9 +12,9 @@ require ( github.com/aws/aws-sdk-go-v2/config v1.15.4 github.com/aws/aws-sdk-go-v2/credentials v1.12.0 github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.4 - github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.6 + github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.7 github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.13.4 - github.com/aws/aws-sdk-go-v2/service/s3 v1.26.6 + github.com/aws/aws-sdk-go-v2/service/s3 v1.26.7 github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.15.6 github.com/aws/aws-sdk-go-v2/service/sts v1.16.4 github.com/cockroachdb/cockroach-go/v2 v2.2.8 @@ -84,7 +84,7 @@ require ( github.com/aws/aws-sdk-go-v2/internal/ini v1.3.11 // indirect github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.1 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.1 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.5 // indirect github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.4 // indirect github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.13.4 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.11.4 // indirect @@ -130,7 +130,7 @@ require ( github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/oklog/run v1.1.0 // indirect github.com/pelletier/go-toml v1.9.5 // indirect - github.com/pelletier/go-toml/v2 v2.0.0-beta.8 // indirect + github.com/pelletier/go-toml/v2 v2.0.0 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/power-devops/perfstat v0.0.0-20220216144756-c35f1ee13d7c // indirect diff --git a/go.sum b/go.sum index 15e5a502..ef384c99 100644 --- a/go.sum +++ b/go.sum @@ -150,8 +150,8 @@ github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.3/go.mod h1:uk1vhHHERfSVCUnq github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.4 h1:FP8gquGeGHHdfY6G5llaMQDF+HAf20VKc8opRwmjf04= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.4/go.mod h1:u/s5/Z+ohUQOPXl00m2yJVyioWDECsbpXTQlaqSlufc= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.3/go.mod h1:0dHuD2HZZSiwfJSy1FO5bX1hQ1TxVV1QXXjpn3XUE44= -github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.6 h1:6Q/ITRl/PoOAcbLkT3EOpch/6w9n/YNN6a/v+dfuBY8= -github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.6/go.mod h1:sj1vB2ZjQ1PQOWc89SyhEJs838UIpDcsa3HylyczQO0= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.7 h1:h1y9Jn+1VTEjB4TG5prxtQjW9DM5o8y7Cu9ZdNmkXWA= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.7/go.mod h1:NZ2wPktB/I111CyzF3ezVf8jrAg/PqKeYkdR11oBWeU= github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.9/go.mod h1:AnVH5pvai0pAF4lXRq0bmhbes1u9R8wTE+g+183bZNM= github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.10 h1:uFWgo6mGJI1n17nbcvSc6fxVuR3xLNqvXt12JCnEcT8= github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.10/go.mod h1:F+EZtuIwjlv35kRJPyBGcsA4f7bnSoz15zOQ2lJq1Z4= @@ -166,8 +166,8 @@ github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.1/go.mod h1:l/BbcfqDCT3hePawhy4ZR github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.1 h1:T4pFel53bkHjL2mMo+4DKE6r6AuoZnM0fg7k1/ratr4= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.1/go.mod h1:GeUru+8VzrTXV/83XyMJ80KpH8xO89VPoUileyNQ+tc= github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.3/go.mod h1:Seb8KNmD6kVTjwRjVEgOT5hPin6sq+v4C2ycJQDwuH8= -github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.4 h1:HXy33+dHXT6WYvnAtIvcQ7Zh4ppeAccz8ofi5bzsQ/A= -github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.4/go.mod h1:S8TVP66AAkMMdYYCNZGvrdEq9YRm+qLXjio4FqRnrEE= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.5 h1:9LSZqt4v1JiehyZTrQnRFf2mY/awmyYNNY/b7zqtduU= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.5/go.mod h1:S8TVP66AAkMMdYYCNZGvrdEq9YRm+qLXjio4FqRnrEE= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.3/go.mod h1:wlY6SVjuwvh3TVRpTqdy4I1JpBFLX4UGeKZdWntaocw= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.4 h1:b16QW0XWl0jWjLABFc1A+uh145Oqv+xDcObNk0iQgUk= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.4/go.mod h1:uKkN7qmSIsNJVyMtxNQoCEYMvFEXbOg9fwCJPdfp2u8= @@ -178,8 +178,8 @@ github.com/aws/aws-sdk-go-v2/service/kms v1.16.3/go.mod h1:QuiHPBqlOFCi4LqdSskYY github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.13.4 h1:qmHavnjRtgdH54nyG4iEk6ZCde9m2S++32INurhaNTk= github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.13.4/go.mod h1:CloMDruFIVZJ8qv2OsY5ENIqzg5c0eeTciVVW3KHdvE= github.com/aws/aws-sdk-go-v2/service/s3 v1.26.3/go.mod h1:g1qvDuRsJY+XghsV6zg00Z4KJ7DtFFCx8fJD2a491Ak= -github.com/aws/aws-sdk-go-v2/service/s3 v1.26.6 h1:RyI53C9+8xxZ3zrllJwzZjI6/FePzxNv3pvh59Ir0aE= -github.com/aws/aws-sdk-go-v2/service/s3 v1.26.6/go.mod h1:FMXuMpmEOLQUDnQLMjsJ2jJGN7jpji1pQ59Kii+IM4U= +github.com/aws/aws-sdk-go-v2/service/s3 v1.26.7 h1:ZEPH6aBywdyn5LGr7hSNEwuPaKpKZodX0R9AjPj5A7c= +github.com/aws/aws-sdk-go-v2/service/s3 v1.26.7/go.mod h1:iMYipLPXlWpBJ0KFX7QJHZ84rBydHBY8as2aQICTPWk= github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.15.4/go.mod h1:PJc8s+lxyU8rrre0/4a0pn2wgwiDvOEzoOjcJUBr67o= github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.15.6 h1:m+mxqLIrGq7GJo5qw4rHn8BbUqHrvxvwFx54N1Pglvw= github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.15.6/go.mod h1:Z+i6uqZgCOBXhNoEGoRm/ZaLsaJA9rGUAmkVKM/3+g4= @@ -654,8 +654,8 @@ github.com/otiai10/mint v1.3.3 h1:7JgpsBaN0uMkyju4tbYHu0mnM55hNKVYLsXmwr15NQI= github.com/otiai10/mint v1.3.3/go.mod h1:/yxELlJQ0ufhjUwhshSj+wFjZ78CnZ48/1wtmBH1OTc= github.com/pelletier/go-toml v1.9.5 h1:4yBQzkHv+7BHq2PQUZF3Mx0IYxG7LsP222s7Agd3ve8= github.com/pelletier/go-toml v1.9.5/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= -github.com/pelletier/go-toml/v2 v2.0.0-beta.8 h1:dy81yyLYJDwMTifq24Oi/IslOslRrDSb3jwDggjz3Z0= -github.com/pelletier/go-toml/v2 v2.0.0-beta.8/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo= +github.com/pelletier/go-toml/v2 v2.0.0 h1:P7Bq0SaI8nsexyay5UAyDo+ICWy5MQPgEZ5+l8JQTKo= +github.com/pelletier/go-toml/v2 v2.0.0/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo= github.com/pires/go-proxyproto v0.6.2 h1:KAZ7UteSOt6urjme6ZldyFm4wDe/z0ZUP0Yv0Dos0d8= github.com/pires/go-proxyproto v0.6.2/go.mod h1:Odh9VFOZJCf9G8cLW5o435Xf1J95Jw9Gw5rnCjcwzAY= github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4/go.mod h1:4OwLy04Bl9Ef3GJJCoec+30X3LQs/0/m4HFRt/2LUSA= diff --git a/httpd/httpd_test.go b/httpd/httpd_test.go index a629a9d9..429a8b8b 100644 --- a/httpd/httpd_test.go +++ b/httpd/httpd_test.go @@ -3227,103 +3227,6 @@ func TestUserS3Config(t *testing.T) { assert.NoError(t, err) } -func TestUserGCSConfig(t *testing.T) { - err := dataprovider.Close() - assert.NoError(t, err) - err = config.LoadConfig(configDir, "") - assert.NoError(t, err) - providerConf := config.GetProviderConf() - providerConf.PreferDatabaseCredentials = false - providerConf.CredentialsPath = credentialsPath - err = dataprovider.Initialize(providerConf, configDir, true) - assert.NoError(t, err) - - user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) - assert.NoError(t, err) - err = os.RemoveAll(credentialsPath) - assert.NoError(t, err) - err = os.MkdirAll(credentialsPath, 0700) - assert.NoError(t, err) - user.FsConfig.Provider = sdk.GCSFilesystemProvider - user.FsConfig.GCSConfig.Bucket = "test" - user.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret("fake credentials") //nolint:goconst - user, bb, err := httpdtest.UpdateUser(user, http.StatusOK, "") - assert.NoError(t, err, string(bb)) - credentialFile := filepath.Join(credentialsPath, fmt.Sprintf("%v_gcs_credentials.json", user.Username)) - assert.FileExists(t, credentialFile) - creds, err := os.ReadFile(credentialFile) - assert.NoError(t, err) - secret := kms.NewEmptySecret() - err = json.Unmarshal(creds, secret) - assert.NoError(t, err) - err = secret.Decrypt() - assert.NoError(t, err) - assert.Equal(t, "fake credentials", secret.GetPayload()) - user.FsConfig.GCSConfig.Credentials = kms.NewSecret(sdkkms.SecretStatusSecretBox, "fake encrypted credentials", "", "") - user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") - assert.NoError(t, err) - assert.FileExists(t, credentialFile) - creds, err = os.ReadFile(credentialFile) - assert.NoError(t, err) - secret = kms.NewEmptySecret() - err = json.Unmarshal(creds, secret) - assert.NoError(t, err) - err = secret.Decrypt() - assert.NoError(t, err) - assert.Equal(t, "fake credentials", secret.GetPayload()) - _, err = httpdtest.RemoveUser(user, http.StatusOK) - assert.NoError(t, err) - user.Password = defaultPassword - user.ID = 0 - user.CreatedAt = 0 - user.FsConfig.GCSConfig.Credentials = kms.NewSecret(sdkkms.SecretStatusSecretBox, "fake credentials", "", "") - _, _, err = httpdtest.AddUser(user, http.StatusCreated) - assert.Error(t, err) - user.FsConfig.GCSConfig.Credentials.SetStatus(sdkkms.SecretStatusPlain) - user, body, err := httpdtest.AddUser(user, http.StatusCreated) - assert.NoError(t, err, string(body)) - err = os.RemoveAll(credentialsPath) - assert.NoError(t, err) - err = os.MkdirAll(credentialsPath, 0700) - assert.NoError(t, err) - user.FsConfig.GCSConfig.Credentials = kms.NewEmptySecret() - user.FsConfig.GCSConfig.AutomaticCredentials = 1 - user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") - assert.NoError(t, err) - assert.NoFileExists(t, credentialFile) - user.FsConfig.GCSConfig = vfs.GCSFsConfig{} - user.FsConfig.Provider = sdk.S3FilesystemProvider - user.FsConfig.S3Config.Bucket = "test1" - user.FsConfig.S3Config.Region = "us-east-1" - user.FsConfig.S3Config.AccessKey = "Server-Access-Key1" - user.FsConfig.S3Config.AccessSecret = kms.NewPlainSecret("secret") - user.FsConfig.S3Config.Endpoint = "http://localhost:9000" - user.FsConfig.S3Config.KeyPrefix = "somedir/subdir" - user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") - assert.NoError(t, err) - user.FsConfig.S3Config = vfs.S3FsConfig{} - user.FsConfig.Provider = sdk.GCSFilesystemProvider - user.FsConfig.GCSConfig.Bucket = "test1" - user.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret("fake credentials") - user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") - assert.NoError(t, err) - - _, err = httpdtest.RemoveUser(user, http.StatusOK) - assert.NoError(t, err) - - err = dataprovider.Close() - assert.NoError(t, err) - err = config.LoadConfig(configDir, "") - assert.NoError(t, err) - providerConf = config.GetProviderConf() - providerConf.BackupsPath = backupsPath - providerConf.CredentialsPath = credentialsPath - err = os.RemoveAll(credentialsPath) - assert.NoError(t, err) - err = dataprovider.Initialize(providerConf, configDir, true) - assert.NoError(t, err) -} - func TestUserAzureBlobConfig(t *testing.T) { user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) @@ -3570,15 +3473,6 @@ func TestUserSFTPFs(t *testing.T) { } func TestUserHiddenFields(t *testing.T) { - err := dataprovider.Close() - assert.NoError(t, err) - err = config.LoadConfig(configDir, "") - assert.NoError(t, err) - providerConf := config.GetProviderConf() - providerConf.PreferDatabaseCredentials = true - err = dataprovider.Initialize(providerConf, configDir, true) - assert.NoError(t, err) - // sensitive data must be hidden but not deleted from the dataprovider usernames := []string{"user1", "user2", "user3", "user4", "user5"} u1 := getTestUser() @@ -3786,18 +3680,6 @@ func TestUserHiddenFields(t *testing.T) { assert.NoError(t, err) _, err = httpdtest.RemoveUser(user5, http.StatusOK) assert.NoError(t, err) - - err = dataprovider.Close() - assert.NoError(t, err) - err = config.LoadConfig(configDir, "") - assert.NoError(t, err) - providerConf = config.GetProviderConf() - providerConf.BackupsPath = backupsPath - providerConf.CredentialsPath = credentialsPath - err = os.RemoveAll(credentialsPath) - assert.NoError(t, err) - err = dataprovider.Initialize(providerConf, configDir, true) - assert.NoError(t, err) } func TestSecretObject(t *testing.T) { @@ -9734,16 +9616,6 @@ func TestSFTPLoopError(t *testing.T) { } func TestLoginInvalidFs(t *testing.T) { - err := dataprovider.Close() - assert.NoError(t, err) - err = config.LoadConfig(configDir, "") - assert.NoError(t, err) - providerConf := config.GetProviderConf() - providerConf.PreferDatabaseCredentials = false - providerConf.CredentialsPath = credentialsPath - err = dataprovider.Initialize(providerConf, configDir, true) - assert.NoError(t, err) - u := getTestUser() u.Filters.AllowAPIKeyAuth = true u.FsConfig.Provider = sdk.GCSFilesystemProvider @@ -9758,15 +9630,6 @@ func TestLoginInvalidFs(t *testing.T) { }, http.StatusCreated) assert.NoError(t, err) - credentialsFile := filepath.Join(credentialsPath, fmt.Sprintf("%v_gcs_credentials.json", u.Username)) - if !filepath.IsAbs(credentialsFile) { - credentialsFile = filepath.Join(configDir, credentialsFile) - } - - // now remove the credentials file so the filesystem creation will fail - err = os.Remove(credentialsFile) - assert.NoError(t, err) - _, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) assert.Error(t, err) @@ -9783,18 +9646,6 @@ func TestLoginInvalidFs(t *testing.T) { assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) - - err = dataprovider.Close() - assert.NoError(t, err) - err = config.LoadConfig(configDir, "") - assert.NoError(t, err) - providerConf = config.GetProviderConf() - providerConf.BackupsPath = backupsPath - providerConf.CredentialsPath = credentialsPath - err = os.RemoveAll(credentialsPath) - assert.NoError(t, err) - err = dataprovider.Initialize(providerConf, configDir, true) - assert.NoError(t, err) } func TestWebClientChangePwd(t *testing.T) { diff --git a/service/service_portable.go b/service/service_portable.go index 4ec15f7b..98787a8f 100644 --- a/service/service_portable.go +++ b/service/service_portable.go @@ -46,7 +46,6 @@ func (s *Service) StartPortableMode(sftpdPort, ftpPort, webdavPort int, enabledS dataProviderConf := config.GetProviderConf() dataProviderConf.Driver = dataprovider.MemoryDataProviderName dataProviderConf.Name = "" - dataProviderConf.PreferDatabaseCredentials = true config.SetProviderConf(dataProviderConf) httpdConf := config.GetHTTPDConfig() httpdConf.Bindings = nil diff --git a/sftpd/scp.go b/sftpd/scp.go index 211b3b4c..63fde638 100644 --- a/sftpd/scp.go +++ b/sftpd/scp.go @@ -33,7 +33,7 @@ type scpCommand struct { func (c *scpCommand) handle() (err error) { defer func() { if r := recover(); r != nil { - logger.Error(logSender, "", "panic in handle scp command: %#v stack strace: %v", r, string(debug.Stack())) + logger.Error(logSender, "", "panic in handle scp command: %#v stack trace: %v", r, string(debug.Stack())) err = common.ErrGenericFailure } }() diff --git a/sftpd/server.go b/sftpd/server.go index b1f8fa7b..d0ad805d 100644 --- a/sftpd/server.go +++ b/sftpd/server.go @@ -477,7 +477,7 @@ func canAcceptConnection(ip string) bool { func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.ServerConfig) { defer func() { if r := recover(); r != nil { - logger.Error(logSender, "", "panic in AcceptInboundConnection: %#v stack strace: %v", r, string(debug.Stack())) + logger.Error(logSender, "", "panic in AcceptInboundConnection: %#v stack trace: %v", r, string(debug.Stack())) } }() @@ -597,7 +597,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve func (c *Configuration) handleSftpConnection(channel ssh.Channel, connection *Connection) { defer func() { if r := recover(); r != nil { - logger.Error(logSender, "", "panic in handleSftpConnection: %#v stack strace: %v", r, string(debug.Stack())) + logger.Error(logSender, "", "panic in handleSftpConnection: %#v stack trace: %v", r, string(debug.Stack())) } }() if err := common.Connections.Add(connection); err != nil { diff --git a/sftpd/sftpd_test.go b/sftpd/sftpd_test.go index 10c218f9..1050a1ce 100644 --- a/sftpd/sftpd_test.go +++ b/sftpd/sftpd_test.go @@ -2313,24 +2313,6 @@ func TestLoginWithDatabaseCredentials(t *testing.T) { u.FsConfig.Provider = sdk.GCSFilesystemProvider u.FsConfig.GCSConfig.Bucket = "testbucket" u.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret(`{ "type": "service_account" }`) - - providerConf := config.GetProviderConf() - providerConf.PreferDatabaseCredentials = true - credentialsFile := filepath.Join(providerConf.CredentialsPath, fmt.Sprintf("%v_gcs_credentials.json", u.Username)) - if !filepath.IsAbs(credentialsFile) { - credentialsFile = filepath.Join(configDir, credentialsFile) - } - - assert.NoError(t, dataprovider.Close()) - - err := dataprovider.Initialize(providerConf, configDir, true) - assert.NoError(t, err) - - if _, err = os.Stat(credentialsFile); err == nil { - // remove the credentials file - assert.NoError(t, os.Remove(credentialsFile)) - } - user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.GCSConfig.Credentials.GetStatus()) @@ -2338,8 +2320,6 @@ func TestLoginWithDatabaseCredentials(t *testing.T) { assert.Empty(t, user.FsConfig.GCSConfig.Credentials.GetAdditionalData()) assert.Empty(t, user.FsConfig.GCSConfig.Credentials.GetKey()) - assert.NoFileExists(t, credentialsFile) - conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { defer conn.Close() @@ -2350,23 +2330,9 @@ func TestLoginWithDatabaseCredentials(t *testing.T) { assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) - - assert.NoError(t, dataprovider.Close()) - assert.NoError(t, config.LoadConfig(configDir, "")) - providerConf = config.GetProviderConf() - assert.NoError(t, dataprovider.Initialize(providerConf, configDir, true)) } func TestLoginInvalidFs(t *testing.T) { - err := dataprovider.Close() - assert.NoError(t, err) - err = config.LoadConfig(configDir, "") - assert.NoError(t, err) - providerConf := config.GetProviderConf() - providerConf.PreferDatabaseCredentials = false - err = dataprovider.Initialize(providerConf, configDir, true) - assert.NoError(t, err) - usePubKey := true u := getTestUser(usePubKey) u.FsConfig.Provider = sdk.GCSFilesystemProvider @@ -2375,16 +2341,6 @@ func TestLoginInvalidFs(t *testing.T) { user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) - providerConf = config.GetProviderConf() - credentialsFile := filepath.Join(providerConf.CredentialsPath, fmt.Sprintf("%v_gcs_credentials.json", u.Username)) - if !filepath.IsAbs(credentialsFile) { - credentialsFile = filepath.Join(configDir, credentialsFile) - } - - // now remove the credentials file so the filesystem creation will fail - err = os.Remove(credentialsFile) - assert.NoError(t, err) - conn, client, err := getSftpClient(user, usePubKey) if !assert.Error(t, err, "login must fail, the user has an invalid filesystem config") { client.Close() @@ -2395,14 +2351,6 @@ func TestLoginInvalidFs(t *testing.T) { assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) - - err = dataprovider.Close() - assert.NoError(t, err) - err = config.LoadConfig(configDir, "") - assert.NoError(t, err) - providerConf = config.GetProviderConf() - err = dataprovider.Initialize(providerConf, configDir, true) - assert.NoError(t, err) } func TestDeniedProtocols(t *testing.T) { diff --git a/sftpd/ssh_cmd.go b/sftpd/ssh_cmd.go index 40254427..9d2146a6 100644 --- a/sftpd/ssh_cmd.go +++ b/sftpd/ssh_cmd.go @@ -111,7 +111,7 @@ func processSSHCommand(payload []byte, connection *Connection, enabledSSHCommand func (c *sshCommand) handle() (err error) { defer func() { if r := recover(); r != nil { - logger.Error(logSender, "", "panic in handle ssh command: %#v stack strace: %v", r, string(debug.Stack())) + logger.Error(logSender, "", "panic in handle ssh command: %#v stack trace: %v", r, string(debug.Stack())) err = common.ErrGenericFailure } }() diff --git a/sftpgo.json b/sftpgo.json index 373100e8..29f3911d 100644 --- a/sftpgo.json +++ b/sftpgo.json @@ -179,7 +179,6 @@ "external_auth_hook": "", "external_auth_scope": 0, "credentials_path": "credentials", - "prefer_database_credentials": true, "pre_login_hook": "", "post_login_hook": "", "post_login_scope": 0, diff --git a/vfs/azblobfs.go b/vfs/azblobfs.go index 106b7bd2..2cd55cbb 100644 --- a/vfs/azblobfs.go +++ b/vfs/azblobfs.go @@ -70,7 +70,7 @@ func NewAzBlobFs(connectionID, localTempDir, mountPath string, config AzBlobFsCo ctxTimeout: 30 * time.Second, ctxLongTimeout: 90 * time.Second, } - if err := fs.config.Validate(); err != nil { + if err := fs.config.validate(); err != nil { return fs, err } diff --git a/vfs/cryptfs.go b/vfs/cryptfs.go index 5c69c3df..3cdc2fc2 100644 --- a/vfs/cryptfs.go +++ b/vfs/cryptfs.go @@ -33,7 +33,7 @@ type CryptFs struct { // NewCryptFs returns a CryptFs object func NewCryptFs(connectionID, rootDir, mountPath string, config CryptFsConfig) (Fs, error) { - if err := config.Validate(); err != nil { + if err := config.validate(); err != nil { return nil, err } if err := config.Passphrase.TryDecrypt(); err != nil { diff --git a/vfs/filesystem.go b/vfs/filesystem.go index 950bda80..2e3d5f77 100644 --- a/vfs/filesystem.go +++ b/vfs/filesystem.go @@ -1,21 +1,11 @@ package vfs import ( - "fmt" - "github.com/sftpgo/sdk" "github.com/drakkan/sftpgo/v2/kms" - "github.com/drakkan/sftpgo/v2/util" ) -// ValidatorHelper implements methods we need for Filesystem.ValidateConfig. -// It is implemented by vfs.Folder and dataprovider.User -type ValidatorHelper interface { - GetGCSCredentialsFilePath() string - GetEncryptionAdditionalData() string -} - // Filesystem defines filesystem details type Filesystem struct { RedactedSecret string `json:"-"` @@ -113,14 +103,11 @@ func (f *Filesystem) IsEqual(other *Filesystem) bool { // Validate verifies the FsConfig matching the configured provider and sets all other // Filesystem.*Config to their zero value if successful -func (f *Filesystem) Validate(helper ValidatorHelper) error { +func (f *Filesystem) Validate(additionalData string) error { switch f.Provider { case sdk.S3FilesystemProvider: - if err := f.S3Config.Validate(); err != nil { - return util.NewValidationError(fmt.Sprintf("could not validate s3config: %v", err)) - } - if err := f.S3Config.EncryptCredentials(helper.GetEncryptionAdditionalData()); err != nil { - return util.NewValidationError(fmt.Sprintf("could not encrypt s3 access secret: %v", err)) + if err := f.S3Config.ValidateAndEncryptCredentials(additionalData); err != nil { + return err } f.GCSConfig = GCSFsConfig{} f.AzBlobConfig = AzBlobFsConfig{} @@ -128,8 +115,8 @@ func (f *Filesystem) Validate(helper ValidatorHelper) error { f.SFTPConfig = SFTPFsConfig{} return nil case sdk.GCSFilesystemProvider: - if err := f.GCSConfig.Validate(helper.GetGCSCredentialsFilePath()); err != nil { - return util.NewValidationError(fmt.Sprintf("could not validate GCS config: %v", err)) + if err := f.GCSConfig.ValidateAndEncryptCredentials(additionalData); err != nil { + return err } f.S3Config = S3FsConfig{} f.AzBlobConfig = AzBlobFsConfig{} @@ -137,11 +124,8 @@ func (f *Filesystem) Validate(helper ValidatorHelper) error { f.SFTPConfig = SFTPFsConfig{} return nil case sdk.AzureBlobFilesystemProvider: - if err := f.AzBlobConfig.Validate(); err != nil { - return util.NewValidationError(fmt.Sprintf("could not validate Azure Blob config: %v", err)) - } - if err := f.AzBlobConfig.EncryptCredentials(helper.GetEncryptionAdditionalData()); err != nil { - return util.NewValidationError(fmt.Sprintf("could not encrypt Azure blob account key: %v", err)) + if err := f.AzBlobConfig.ValidateAndEncryptCredentials(additionalData); err != nil { + return err } f.S3Config = S3FsConfig{} f.GCSConfig = GCSFsConfig{} @@ -149,11 +133,8 @@ func (f *Filesystem) Validate(helper ValidatorHelper) error { f.SFTPConfig = SFTPFsConfig{} return nil case sdk.CryptedFilesystemProvider: - if err := f.CryptConfig.Validate(); err != nil { - return util.NewValidationError(fmt.Sprintf("could not validate Crypt fs config: %v", err)) - } - if err := f.CryptConfig.EncryptCredentials(helper.GetEncryptionAdditionalData()); err != nil { - return util.NewValidationError(fmt.Sprintf("could not encrypt Crypt fs passphrase: %v", err)) + if err := f.CryptConfig.ValidateAndEncryptCredentials(additionalData); err != nil { + return err } f.S3Config = S3FsConfig{} f.GCSConfig = GCSFsConfig{} @@ -161,11 +142,8 @@ func (f *Filesystem) Validate(helper ValidatorHelper) error { f.SFTPConfig = SFTPFsConfig{} return nil case sdk.SFTPFilesystemProvider: - if err := f.SFTPConfig.Validate(); err != nil { - return util.NewValidationError(fmt.Sprintf("could not validate SFTP fs config: %v", err)) - } - if err := f.SFTPConfig.EncryptCredentials(helper.GetEncryptionAdditionalData()); err != nil { - return util.NewValidationError(fmt.Sprintf("could not encrypt SFTP fs credentials: %v", err)) + if err := f.SFTPConfig.ValidateAndEncryptCredentials(additionalData); err != nil { + return err } f.S3Config = S3FsConfig{} f.GCSConfig = GCSFsConfig{} @@ -262,7 +240,6 @@ func (f *Filesystem) GetACopy() Filesystem { GCSConfig: GCSFsConfig{ BaseGCSFsConfig: sdk.BaseGCSFsConfig{ Bucket: f.GCSConfig.Bucket, - CredentialFile: f.GCSConfig.CredentialFile, AutomaticCredentials: f.GCSConfig.AutomaticCredentials, StorageClass: f.GCSConfig.StorageClass, ACL: f.GCSConfig.ACL, diff --git a/vfs/folder.go b/vfs/folder.go index a726d488..f6f74d83 100644 --- a/vfs/folder.go +++ b/vfs/folder.go @@ -184,9 +184,7 @@ func (v *VirtualFolder) GetFilesystem(connectionID string, forbiddenSelfUsers [] case sdk.S3FilesystemProvider: return NewS3Fs(connectionID, v.MappedPath, v.VirtualPath, v.FsConfig.S3Config) case sdk.GCSFilesystemProvider: - config := v.FsConfig.GCSConfig - config.CredentialFile = v.GetGCSCredentialsFilePath() - return NewGCSFs(connectionID, v.MappedPath, v.VirtualPath, config) + return NewGCSFs(connectionID, v.MappedPath, v.VirtualPath, v.FsConfig.GCSConfig) case sdk.AzureBlobFilesystemProvider: return NewAzBlobFs(connectionID, v.MappedPath, v.VirtualPath, v.FsConfig.AzBlobConfig) case sdk.CryptedFilesystemProvider: diff --git a/vfs/gcsfs.go b/vfs/gcsfs.go index bb090176..9ab17f64 100644 --- a/vfs/gcsfs.go +++ b/vfs/gcsfs.go @@ -5,7 +5,6 @@ package vfs import ( "context" - "encoding/json" "fmt" "io" "mime" @@ -23,7 +22,6 @@ import ( "google.golang.org/api/iterator" "google.golang.org/api/option" - "github.com/drakkan/sftpgo/v2/kms" "github.com/drakkan/sftpgo/v2/logger" "github.com/drakkan/sftpgo/v2/metric" "github.com/drakkan/sftpgo/v2/plugin" @@ -74,34 +72,18 @@ func NewGCSFs(connectionID, localTempDir, mountPath string, config GCSFsConfig) ctxTimeout: 30 * time.Second, ctxLongTimeout: 300 * time.Second, } - if err = fs.config.Validate(fs.config.CredentialFile); err != nil { + if err = fs.config.validate(); err != nil { return fs, err } ctx := context.Background() if fs.config.AutomaticCredentials > 0 { fs.svc, err = storage.NewClient(ctx) - } else if !fs.config.Credentials.IsEmpty() { + } else { err = fs.config.Credentials.TryDecrypt() if err != nil { return fs, err } fs.svc, err = storage.NewClient(ctx, option.WithCredentialsJSON([]byte(fs.config.Credentials.GetPayload()))) - } else { - var creds []byte - creds, err = os.ReadFile(fs.config.CredentialFile) - if err != nil { - return fs, err - } - secret := kms.NewEmptySecret() - err = json.Unmarshal(creds, secret) - if err != nil { - return fs, err - } - err = secret.Decrypt() - if err != nil { - return fs, err - } - fs.svc, err = storage.NewClient(ctx, option.WithCredentialsJSON([]byte(secret.GetPayload()))) } return fs, err } diff --git a/vfs/s3fs.go b/vfs/s3fs.go index 396b4a14..7a8273ba 100644 --- a/vfs/s3fs.go +++ b/vfs/s3fs.go @@ -74,7 +74,7 @@ func NewS3Fs(connectionID, localTempDir, mountPath string, s3Config S3FsConfig) config: &s3Config, ctxTimeout: 30 * time.Second, } - if err := fs.config.Validate(); err != nil { + if err := fs.config.validate(); err != nil { return fs, err } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) diff --git a/vfs/sftpfs.go b/vfs/sftpfs.go index 8ef72f6d..970b2266 100644 --- a/vfs/sftpfs.go +++ b/vfs/sftpfs.go @@ -93,8 +93,8 @@ func (c *SFTPFsConfig) setEmptyCredentialsIfNil() { } } -// Validate returns an error if the configuration is not valid -func (c *SFTPFsConfig) Validate() error { +// validate returns an error if the configuration is not valid +func (c *SFTPFsConfig) validate() error { c.setEmptyCredentialsIfNil() if c.Endpoint == "" { return errors.New("endpoint cannot be empty") @@ -139,18 +139,21 @@ func (c *SFTPFsConfig) validateCredentials() error { return nil } -// EncryptCredentials encrypts password and/or private key if they are in plain text -func (c *SFTPFsConfig) EncryptCredentials(additionalData string) error { +// ValidateAndEncryptCredentials encrypts password and/or private key if they are in plain text +func (c *SFTPFsConfig) ValidateAndEncryptCredentials(additionalData string) error { + if err := c.validate(); err != nil { + return util.NewValidationError(fmt.Sprintf("could not validate SFTP fs config: %v", err)) + } if c.Password.IsPlain() { c.Password.SetAdditionalData(additionalData) if err := c.Password.Encrypt(); err != nil { - return err + return util.NewValidationError(fmt.Sprintf("could not encrypt SFTP fs password: %v", err)) } } if c.PrivateKey.IsPlain() { c.PrivateKey.SetAdditionalData(additionalData) if err := c.PrivateKey.Encrypt(); err != nil { - return err + return util.NewValidationError(fmt.Sprintf("could not encrypt SFTP fs private key: %v", err)) } } return nil @@ -178,7 +181,7 @@ func NewSFTPFs(connectionID, mountPath, localTempDir string, forbiddenSelfUserna localTempDir = filepath.Clean(os.TempDir()) } } - if err := config.Validate(); err != nil { + if err := config.validate(); err != nil { return nil, err } if !config.Password.IsEmpty() { diff --git a/vfs/vfs.go b/vfs/vfs.go index 7017d972..c60a8bcf 100644 --- a/vfs/vfs.go +++ b/vfs/vfs.go @@ -243,13 +243,16 @@ func (c *S3FsConfig) checkCredentials() error { return nil } -// EncryptCredentials encrypts access secret if it is in plain text -func (c *S3FsConfig) EncryptCredentials(additionalData string) error { +// ValidateAndEncryptCredentials validates the configuration and encrypts access secret if it is in plain text +func (c *S3FsConfig) ValidateAndEncryptCredentials(additionalData string) error { + if err := c.validate(); err != nil { + return util.NewValidationError(fmt.Sprintf("could not validate s3config: %v", err)) + } if c.AccessSecret.IsPlain() { c.AccessSecret.SetAdditionalData(additionalData) err := c.AccessSecret.Encrypt() if err != nil { - return err + return util.NewValidationError(fmt.Sprintf("could not encrypt s3 access secret: %v", err)) } } return nil @@ -271,8 +274,8 @@ func (c *S3FsConfig) checkPartSizeAndConcurrency() error { return nil } -// Validate returns an error if the configuration is not valid -func (c *S3FsConfig) Validate() error { +// validate returns an error if the configuration is not valid +func (c *S3FsConfig) validate() error { if c.AccessSecret == nil { c.AccessSecret = kms.NewEmptySecret() } @@ -312,6 +315,21 @@ func (c *GCSFsConfig) HideConfidentialData() { } } +// ValidateAndEncryptCredentials validates the configuration and encrypts credentials if they are in plain text +func (c *GCSFsConfig) ValidateAndEncryptCredentials(additionalData string) error { + if err := c.validate(); err != nil { + return util.NewValidationError(fmt.Sprintf("could not validate GCS config: %v", err)) + } + if c.Credentials.IsPlain() { + c.Credentials.SetAdditionalData(additionalData) + err := c.Credentials.Encrypt() + if err != nil { + return util.NewValidationError(fmt.Sprintf("could not encrypt GCS credentials: %v", err)) + } + } + return nil +} + func (c *GCSFsConfig) isEqual(other *GCSFsConfig) bool { if c.Bucket != other.Bucket { return false @@ -337,8 +355,8 @@ func (c *GCSFsConfig) isEqual(other *GCSFsConfig) bool { return c.Credentials.IsEqual(other.Credentials) } -// Validate returns an error if the configuration is not valid -func (c *GCSFsConfig) Validate(credentialsFilePath string) error { +// validate returns an error if the configuration is not valid +func (c *GCSFsConfig) validate() error { if c.Credentials == nil || c.AutomaticCredentials == 1 { c.Credentials = kms.NewEmptySecret() } @@ -358,13 +376,7 @@ func (c *GCSFsConfig) Validate(credentialsFilePath string) error { return errors.New("invalid encrypted credentials") } if c.AutomaticCredentials == 0 && !c.Credentials.IsValidInput() { - fi, err := os.Stat(credentialsFilePath) - if err != nil { - return fmt.Errorf("invalid credentials %v", err) - } - if fi.Size() == 0 { - return errors.New("credentials cannot be empty") - } + return errors.New("invalid credentials") } c.StorageClass = strings.TrimSpace(c.StorageClass) c.ACL = strings.TrimSpace(c.ACL) @@ -444,18 +456,21 @@ func (c *AzBlobFsConfig) isSecretEqual(other *AzBlobFsConfig) bool { return c.AccountKey.IsEqual(other.AccountKey) } -// EncryptCredentials encrypts access secret if it is in plain text -func (c *AzBlobFsConfig) EncryptCredentials(additionalData string) error { +// ValidateAndEncryptCredentials validates the configuration and encrypts access secret if it is in plain text +func (c *AzBlobFsConfig) ValidateAndEncryptCredentials(additionalData string) error { + if err := c.validate(); err != nil { + return util.NewValidationError(fmt.Sprintf("could not validate Azure Blob config: %v", err)) + } if c.AccountKey.IsPlain() { c.AccountKey.SetAdditionalData(additionalData) if err := c.AccountKey.Encrypt(); err != nil { - return err + return util.NewValidationError(fmt.Sprintf("could not encrypt Azure blob account key: %v", err)) } } if c.SASURL.IsPlain() { c.SASURL.SetAdditionalData(additionalData) if err := c.SASURL.Encrypt(); err != nil { - return err + return util.NewValidationError(fmt.Sprintf("could not encrypt Azure blob SAS URL: %v", err)) } } return nil @@ -507,8 +522,8 @@ func (c *AzBlobFsConfig) tryDecrypt() error { return nil } -// Validate returns an error if the configuration is not valid -func (c *AzBlobFsConfig) Validate() error { +// validate returns an error if the configuration is not valid +func (c *AzBlobFsConfig) validate() error { if c.AccountKey == nil { c.AccountKey = kms.NewEmptySecret() } @@ -562,19 +577,22 @@ func (c *CryptFsConfig) isEqual(other *CryptFsConfig) bool { return c.Passphrase.IsEqual(other.Passphrase) } -// EncryptCredentials encrypts access secret if it is in plain text -func (c *CryptFsConfig) EncryptCredentials(additionalData string) error { +// ValidateAndEncryptCredentials validates the configuration and encrypts the passphrase if it is in plain text +func (c *CryptFsConfig) ValidateAndEncryptCredentials(additionalData string) error { + if err := c.validate(); err != nil { + return util.NewValidationError(fmt.Sprintf("could not validate Crypt fs config: %v", err)) + } if c.Passphrase.IsPlain() { c.Passphrase.SetAdditionalData(additionalData) if err := c.Passphrase.Encrypt(); err != nil { - return err + return util.NewValidationError(fmt.Sprintf("could not encrypt Crypt fs passphrase: %v", err)) } } return nil } -// Validate returns an error if the configuration is not valid -func (c *CryptFsConfig) Validate() error { +// validate returns an error if the configuration is not valid +func (c *CryptFsConfig) validate() error { if c.Passphrase == nil || c.Passphrase.IsEmpty() { return errors.New("invalid passphrase") } diff --git a/webdavd/server.go b/webdavd/server.go index 55ee4c37..9beccc61 100644 --- a/webdavd/server.go +++ b/webdavd/server.go @@ -135,7 +135,7 @@ func (s *webDavServer) checkRequestMethod(ctx context.Context, r *http.Request, func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer func() { if r := recover(); r != nil { - logger.Error(logSender, "", "panic in ServeHTTP: %#v stack strace: %v", r, string(debug.Stack())) + logger.Error(logSender, "", "panic in ServeHTTP: %#v stack trace: %v", r, string(debug.Stack())) http.Error(w, common.ErrGenericFailure.Error(), http.StatusInternalServerError) } }() diff --git a/webdavd/webdavd_test.go b/webdavd/webdavd_test.go index a6f03eb4..ffcc88f0 100644 --- a/webdavd/webdavd_test.go +++ b/webdavd/webdavd_test.go @@ -1704,24 +1704,6 @@ func TestLoginWithDatabaseCredentials(t *testing.T) { u.FsConfig.Provider = sdk.GCSFilesystemProvider u.FsConfig.GCSConfig.Bucket = "test" u.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret(`{ "type": "service_account" }`) - - providerConf := config.GetProviderConf() - providerConf.PreferDatabaseCredentials = true - credentialsFile := filepath.Join(providerConf.CredentialsPath, fmt.Sprintf("%v_gcs_credentials.json", u.Username)) - if !filepath.IsAbs(credentialsFile) { - credentialsFile = filepath.Join(configDir, credentialsFile) - } - - assert.NoError(t, dataprovider.Close()) - - err := dataprovider.Initialize(providerConf, configDir, true) - assert.NoError(t, err) - - if _, err = os.Stat(credentialsFile); err == nil { - // remove the credentials file - assert.NoError(t, os.Remove(credentialsFile)) - } - user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.GCSConfig.Credentials.GetStatus()) @@ -1729,8 +1711,6 @@ func TestLoginWithDatabaseCredentials(t *testing.T) { assert.Empty(t, user.FsConfig.GCSConfig.Credentials.GetAdditionalData()) assert.Empty(t, user.FsConfig.GCSConfig.Credentials.GetKey()) - assert.NoFileExists(t, credentialsFile) - client := getWebDavClient(user, false, nil) err = client.Connect() @@ -1740,23 +1720,9 @@ func TestLoginWithDatabaseCredentials(t *testing.T) { assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) - - assert.NoError(t, dataprovider.Close()) - assert.NoError(t, config.LoadConfig(configDir, "")) - providerConf = config.GetProviderConf() - assert.NoError(t, dataprovider.Initialize(providerConf, configDir, true)) } func TestLoginInvalidFs(t *testing.T) { - err := dataprovider.Close() - assert.NoError(t, err) - err = config.LoadConfig(configDir, "") - assert.NoError(t, err) - providerConf := config.GetProviderConf() - providerConf.PreferDatabaseCredentials = false - err = dataprovider.Initialize(providerConf, configDir, true) - assert.NoError(t, err) - u := getTestUser() u.FsConfig.Provider = sdk.GCSFilesystemProvider u.FsConfig.GCSConfig.Bucket = "test" @@ -1764,16 +1730,6 @@ func TestLoginInvalidFs(t *testing.T) { user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) - providerConf = config.GetProviderConf() - credentialsFile := filepath.Join(providerConf.CredentialsPath, fmt.Sprintf("%v_gcs_credentials.json", u.Username)) - if !filepath.IsAbs(credentialsFile) { - credentialsFile = filepath.Join(configDir, credentialsFile) - } - - // now remove the credentials file so the filesystem creation will fail - err = os.Remove(credentialsFile) - assert.NoError(t, err) - client := getWebDavClient(user, true, nil) assert.Error(t, checkBasicFunc(client)) @@ -1781,14 +1737,6 @@ func TestLoginInvalidFs(t *testing.T) { assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) - - err = dataprovider.Close() - assert.NoError(t, err) - err = config.LoadConfig(configDir, "") - assert.NoError(t, err) - providerConf = config.GetProviderConf() - err = dataprovider.Initialize(providerConf, configDir, true) - assert.NoError(t, err) } func TestSFTPBuffered(t *testing.T) {