mirror of
https://github.com/drakkan/sftpgo.git
synced 2025-12-07 14:50:55 +03:00
fix a potential race condition for pre-login and ext auth
hooks doing something like this: err = provider.updateUser(u) ... return provider.userExists(username) could be racy if another update happen before provider.userExists(username) also pass a pointer to updateUser so if the user is modified inside "validateUser" we can just return the modified user without do a new query
This commit is contained in:
@@ -48,7 +48,7 @@ func getUserByUsername(username string, dbHandle sqlQuerier) (User, error) {
|
||||
|
||||
func sqlCommonValidateUserAndPass(username, password, ip, protocol string, dbHandle *sql.DB) (User, error) {
|
||||
var user User
|
||||
if len(password) == 0 {
|
||||
if password == "" {
|
||||
return user, errors.New("Credentials cannot be null or empty")
|
||||
}
|
||||
user, err := getUserByUsername(username, dbHandle)
|
||||
@@ -177,8 +177,8 @@ func sqlCommonCheckUserExists(username string, dbHandle *sql.DB) (User, error) {
|
||||
return getUserWithVirtualFolders(user, dbHandle)
|
||||
}
|
||||
|
||||
func sqlCommonAddUser(user User, dbHandle *sql.DB) error {
|
||||
err := validateUser(&user)
|
||||
func sqlCommonAddUser(user *User, dbHandle *sql.DB) error {
|
||||
err := validateUser(user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -231,8 +231,8 @@ func sqlCommonAddUser(user User, dbHandle *sql.DB) error {
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func sqlCommonUpdateUser(user User, dbHandle *sql.DB) error {
|
||||
err := validateUser(&user)
|
||||
func sqlCommonUpdateUser(user *User, dbHandle *sql.DB) error {
|
||||
err := validateUser(user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -285,7 +285,7 @@ func sqlCommonUpdateUser(user User, dbHandle *sql.DB) error {
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func sqlCommonDeleteUser(user User, dbHandle *sql.DB) error {
|
||||
func sqlCommonDeleteUser(user *User, dbHandle *sql.DB) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
q := getDeleteUserQuery()
|
||||
@@ -470,7 +470,7 @@ func sqlCommonCheckFolderExists(ctx context.Context, name string, dbHandle sqlQu
|
||||
func sqlCommonAddOrGetFolder(ctx context.Context, name string, usedQuotaSize int64, usedQuotaFiles int, lastQuotaUpdate int64, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) {
|
||||
folder, err := sqlCommonCheckFolderExists(ctx, name, dbHandle)
|
||||
if _, ok := err.(*RecordNotFoundError); ok {
|
||||
f := vfs.BaseVirtualFolder{
|
||||
f := &vfs.BaseVirtualFolder{
|
||||
MappedPath: name,
|
||||
UsedQuotaSize: usedQuotaSize,
|
||||
UsedQuotaFiles: usedQuotaFiles,
|
||||
@@ -485,8 +485,8 @@ func sqlCommonAddOrGetFolder(ctx context.Context, name string, usedQuotaSize int
|
||||
return folder, err
|
||||
}
|
||||
|
||||
func sqlCommonAddFolder(folder vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
|
||||
err := validateFolder(&folder)
|
||||
func sqlCommonAddFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
|
||||
err := validateFolder(folder)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -503,7 +503,7 @@ func sqlCommonAddFolder(folder vfs.BaseVirtualFolder, dbHandle sqlQuerier) error
|
||||
return err
|
||||
}
|
||||
|
||||
func sqlCommonDeleteFolder(folder vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
|
||||
func sqlCommonDeleteFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
q := getDeleteFolderQuery()
|
||||
@@ -585,7 +585,7 @@ func sqlCommonGetFolders(limit, offset int, order, folderPath string, dbHandle s
|
||||
return getVirtualFoldersWithUsers(folders, dbHandle)
|
||||
}
|
||||
|
||||
func sqlCommonClearFolderMapping(ctx context.Context, user User, dbHandle sqlQuerier) error {
|
||||
func sqlCommonClearFolderMapping(ctx context.Context, user *User, dbHandle sqlQuerier) error {
|
||||
q := getClearFolderMappingQuery()
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
@@ -597,7 +597,7 @@ func sqlCommonClearFolderMapping(ctx context.Context, user User, dbHandle sqlQue
|
||||
return err
|
||||
}
|
||||
|
||||
func sqlCommonAddFolderMapping(ctx context.Context, user User, folder vfs.VirtualFolder, dbHandle sqlQuerier) error {
|
||||
func sqlCommonAddFolderMapping(ctx context.Context, user *User, folder vfs.VirtualFolder, dbHandle sqlQuerier) error {
|
||||
q := getAddFolderMappingQuery()
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
@@ -609,7 +609,7 @@ func sqlCommonAddFolderMapping(ctx context.Context, user User, folder vfs.Virtua
|
||||
return err
|
||||
}
|
||||
|
||||
func generateVirtualFoldersMapping(ctx context.Context, user User, dbHandle sqlQuerier) error {
|
||||
func generateVirtualFoldersMapping(ctx context.Context, user *User, dbHandle sqlQuerier) error {
|
||||
err := sqlCommonClearFolderMapping(ctx, user, dbHandle)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -813,7 +813,7 @@ func sqlCommonExecSQLAndUpdateDBVersion(dbHandle *sql.DB, sql []string, newVersi
|
||||
return err
|
||||
}
|
||||
for _, q := range sql {
|
||||
if len(strings.TrimSpace(q)) == 0 {
|
||||
if strings.TrimSpace(q) == "" {
|
||||
continue
|
||||
}
|
||||
_, err = tx.ExecContext(ctx, q)
|
||||
@@ -892,7 +892,7 @@ func sqlCommonRestoreCompatVirtualFolders(ctx context.Context, users []userCompa
|
||||
QuotaSize: quotaSize,
|
||||
QuotaFiles: quotaFiles,
|
||||
}
|
||||
err = sqlCommonAddFolderMapping(ctx, u, f, dbHandle)
|
||||
err = sqlCommonAddFolderMapping(ctx, &u, f, dbHandle)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelWarn, "error adding virtual folder mapping for user %#v: %v", user.Username, err)
|
||||
return foldersToScan, err
|
||||
@@ -923,7 +923,7 @@ func sqlCommonUpdateDatabaseFrom3To4(sqlV4 string, dbHandle *sql.DB) error {
|
||||
return err
|
||||
}
|
||||
for _, q := range strings.Split(sql, ";") {
|
||||
if len(strings.TrimSpace(q)) == 0 {
|
||||
if strings.TrimSpace(q) == "" {
|
||||
continue
|
||||
}
|
||||
_, err = tx.ExecContext(ctx, q)
|
||||
|
||||
Reference in New Issue
Block a user