fix a potential race condition for pre-login and ext auth

hooks

doing something like this:

err = provider.updateUser(u)
...
return provider.userExists(username)

could be racy if another update happen before

provider.userExists(username)

also pass a pointer to updateUser so if the user is modified inside
"validateUser" we can just return the modified user without do a new
query
This commit is contained in:
Nicola Murino
2021-01-05 09:50:22 +01:00
parent 72b2c83392
commit daac90c4e1
26 changed files with 167 additions and 163 deletions

View File

@@ -48,7 +48,7 @@ func getUserByUsername(username string, dbHandle sqlQuerier) (User, error) {
func sqlCommonValidateUserAndPass(username, password, ip, protocol string, dbHandle *sql.DB) (User, error) {
var user User
if len(password) == 0 {
if password == "" {
return user, errors.New("Credentials cannot be null or empty")
}
user, err := getUserByUsername(username, dbHandle)
@@ -177,8 +177,8 @@ func sqlCommonCheckUserExists(username string, dbHandle *sql.DB) (User, error) {
return getUserWithVirtualFolders(user, dbHandle)
}
func sqlCommonAddUser(user User, dbHandle *sql.DB) error {
err := validateUser(&user)
func sqlCommonAddUser(user *User, dbHandle *sql.DB) error {
err := validateUser(user)
if err != nil {
return err
}
@@ -231,8 +231,8 @@ func sqlCommonAddUser(user User, dbHandle *sql.DB) error {
return tx.Commit()
}
func sqlCommonUpdateUser(user User, dbHandle *sql.DB) error {
err := validateUser(&user)
func sqlCommonUpdateUser(user *User, dbHandle *sql.DB) error {
err := validateUser(user)
if err != nil {
return err
}
@@ -285,7 +285,7 @@ func sqlCommonUpdateUser(user User, dbHandle *sql.DB) error {
return tx.Commit()
}
func sqlCommonDeleteUser(user User, dbHandle *sql.DB) error {
func sqlCommonDeleteUser(user *User, dbHandle *sql.DB) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getDeleteUserQuery()
@@ -470,7 +470,7 @@ func sqlCommonCheckFolderExists(ctx context.Context, name string, dbHandle sqlQu
func sqlCommonAddOrGetFolder(ctx context.Context, name string, usedQuotaSize int64, usedQuotaFiles int, lastQuotaUpdate int64, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) {
folder, err := sqlCommonCheckFolderExists(ctx, name, dbHandle)
if _, ok := err.(*RecordNotFoundError); ok {
f := vfs.BaseVirtualFolder{
f := &vfs.BaseVirtualFolder{
MappedPath: name,
UsedQuotaSize: usedQuotaSize,
UsedQuotaFiles: usedQuotaFiles,
@@ -485,8 +485,8 @@ func sqlCommonAddOrGetFolder(ctx context.Context, name string, usedQuotaSize int
return folder, err
}
func sqlCommonAddFolder(folder vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
err := validateFolder(&folder)
func sqlCommonAddFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
err := validateFolder(folder)
if err != nil {
return err
}
@@ -503,7 +503,7 @@ func sqlCommonAddFolder(folder vfs.BaseVirtualFolder, dbHandle sqlQuerier) error
return err
}
func sqlCommonDeleteFolder(folder vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
func sqlCommonDeleteFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getDeleteFolderQuery()
@@ -585,7 +585,7 @@ func sqlCommonGetFolders(limit, offset int, order, folderPath string, dbHandle s
return getVirtualFoldersWithUsers(folders, dbHandle)
}
func sqlCommonClearFolderMapping(ctx context.Context, user User, dbHandle sqlQuerier) error {
func sqlCommonClearFolderMapping(ctx context.Context, user *User, dbHandle sqlQuerier) error {
q := getClearFolderMappingQuery()
stmt, err := dbHandle.PrepareContext(ctx, q)
if err != nil {
@@ -597,7 +597,7 @@ func sqlCommonClearFolderMapping(ctx context.Context, user User, dbHandle sqlQue
return err
}
func sqlCommonAddFolderMapping(ctx context.Context, user User, folder vfs.VirtualFolder, dbHandle sqlQuerier) error {
func sqlCommonAddFolderMapping(ctx context.Context, user *User, folder vfs.VirtualFolder, dbHandle sqlQuerier) error {
q := getAddFolderMappingQuery()
stmt, err := dbHandle.PrepareContext(ctx, q)
if err != nil {
@@ -609,7 +609,7 @@ func sqlCommonAddFolderMapping(ctx context.Context, user User, folder vfs.Virtua
return err
}
func generateVirtualFoldersMapping(ctx context.Context, user User, dbHandle sqlQuerier) error {
func generateVirtualFoldersMapping(ctx context.Context, user *User, dbHandle sqlQuerier) error {
err := sqlCommonClearFolderMapping(ctx, user, dbHandle)
if err != nil {
return err
@@ -813,7 +813,7 @@ func sqlCommonExecSQLAndUpdateDBVersion(dbHandle *sql.DB, sql []string, newVersi
return err
}
for _, q := range sql {
if len(strings.TrimSpace(q)) == 0 {
if strings.TrimSpace(q) == "" {
continue
}
_, err = tx.ExecContext(ctx, q)
@@ -892,7 +892,7 @@ func sqlCommonRestoreCompatVirtualFolders(ctx context.Context, users []userCompa
QuotaSize: quotaSize,
QuotaFiles: quotaFiles,
}
err = sqlCommonAddFolderMapping(ctx, u, f, dbHandle)
err = sqlCommonAddFolderMapping(ctx, &u, f, dbHandle)
if err != nil {
providerLog(logger.LevelWarn, "error adding virtual folder mapping for user %#v: %v", user.Username, err)
return foldersToScan, err
@@ -923,7 +923,7 @@ func sqlCommonUpdateDatabaseFrom3To4(sqlV4 string, dbHandle *sql.DB) error {
return err
}
for _, q := range strings.Split(sql, ";") {
if len(strings.TrimSpace(q)) == 0 {
if strings.TrimSpace(q) == "" {
continue
}
_, err = tx.ExecContext(ctx, q)