diff --git a/internal/dataprovider/dataprovider.go b/internal/dataprovider/dataprovider.go index a3baf9eb..7dd34d35 100644 --- a/internal/dataprovider/dataprovider.go +++ b/internal/dataprovider/dataprovider.go @@ -4205,6 +4205,8 @@ func getPreLoginHookResponse(loginMethod, ip, protocol string, userAsJSON []byte } func executePreLoginHook(username, loginMethod, ip, protocol string, oidcTokenFields *map[string]any) (User, error) { + var user User + u, mergedUser, userAsJSON, err := getUserAndJSONForHook(username, oidcTokenFields) if err != nil { return u, err @@ -4227,53 +4229,38 @@ func executePreLoginHook(username, loginMethod, ip, protocol string, oidcTokenFi } return u, nil } - - userID := u.ID - userUsedQuotaSize := u.UsedQuotaSize - userUsedQuotaFiles := u.UsedQuotaFiles - userUsedDownloadTransfer := u.UsedDownloadDataTransfer - userUsedUploadTransfer := u.UsedUploadDataTransfer - userLastQuotaUpdate := u.LastQuotaUpdate - userLastLogin := u.LastLogin - userFirstDownload := u.FirstDownload - userFirstUpload := u.FirstUpload - userLastPwdChange := u.LastPasswordChange - userCreatedAt := u.CreatedAt - totpConfig := u.Filters.TOTPConfig - recoveryCodes := u.Filters.RecoveryCodes - err = json.Unmarshal(out, &u) + err = json.Unmarshal(out, &user) if err != nil { return u, fmt.Errorf("invalid pre-login hook response %q, error: %v", out, err) } - u.ID = userID - u.UsedQuotaSize = userUsedQuotaSize - u.UsedQuotaFiles = userUsedQuotaFiles - u.UsedUploadDataTransfer = userUsedUploadTransfer - u.UsedDownloadDataTransfer = userUsedDownloadTransfer - u.LastQuotaUpdate = userLastQuotaUpdate - u.LastLogin = userLastLogin - u.LastPasswordChange = userLastPwdChange - u.FirstDownload = userFirstDownload - u.FirstUpload = userFirstUpload - u.CreatedAt = userCreatedAt - if userID == 0 { - err = provider.addUser(&u) - } else { - u.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + if u.ID > 0 { + user.ID = u.ID + user.UsedQuotaSize = u.UsedQuotaSize + user.UsedQuotaFiles = u.UsedQuotaFiles + user.UsedUploadDataTransfer = u.UsedUploadDataTransfer + user.UsedDownloadDataTransfer = u.UsedDownloadDataTransfer + user.LastQuotaUpdate = u.LastQuotaUpdate + user.LastLogin = u.LastLogin + user.LastPasswordChange = u.LastPasswordChange + user.FirstDownload = u.FirstDownload + user.FirstUpload = u.FirstUpload // preserve TOTP config and recovery codes - u.Filters.TOTPConfig = totpConfig - u.Filters.RecoveryCodes = recoveryCodes - err = provider.updateUser(&u) + user.Filters.TOTPConfig = u.Filters.TOTPConfig + user.Filters.RecoveryCodes = u.Filters.RecoveryCodes + if err := provider.updateUser(&user); err != nil { + return u, err + } + } else { + if err := provider.addUser(&user); err != nil { + return u, err + } } + user, err = provider.userExists(user.Username, "") if err != nil { return u, err } - user, err := provider.userExists(username, "") - if err != nil { - return u, err - } - providerLog(logger.LevelDebug, "user %q added/updated from pre-login hook response, id: %d", username, userID) - if userID > 0 { + providerLog(logger.LevelDebug, "user %q added/updated from pre-login hook response, id: %d", username, u.ID) + if u.ID > 0 { webDAVUsersCache.swap(&user, "") } return user, nil diff --git a/internal/sftpd/sftpd_test.go b/internal/sftpd/sftpd_test.go index 73b76985..135571a9 100644 --- a/internal/sftpd/sftpd_test.go +++ b/internal/sftpd/sftpd_test.go @@ -3000,6 +3000,15 @@ func TestPreLoginScript(t *testing.T) { } usePubKey := true u := getTestUser(usePubKey) + mappedPath := filepath.Join(os.TempDir(), "vdir") + folderName := filepath.Base(mappedPath) + folderMountPath := "/vpath" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + }, + VirtualPath: folderMountPath, + }) err := dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") @@ -3011,8 +3020,6 @@ func TestPreLoginScript(t *testing.T) { err = dataprovider.Initialize(providerConf, configDir, true) assert.NoError(t, err) - mappedPath := filepath.Join(os.TempDir(), "vdir") - folderName := filepath.Base(mappedPath) f := vfs.BaseVirtualFolder{ Name: folderName, MappedPath: mappedPath, @@ -3026,13 +3033,6 @@ func TestPreLoginScript(t *testing.T) { _, _, err = httpdtest.AddFolder(f, http.StatusCreated) assert.NoError(t, err) - folderMountPath := "/vpath" - u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ - BaseVirtualFolder: vfs.BaseVirtualFolder{ - Name: folderName, - }, - VirtualPath: folderMountPath, - }) user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(u, usePubKey) @@ -3108,6 +3108,7 @@ func TestPreLoginUserCreation(t *testing.T) { } usePubKey := false u := getTestUser(usePubKey) + u.Permissions["/list"] = []string{"list", "download"} err := dataprovider.Close() assert.NoError(t, err) err = config.LoadConfig(configDir, "") @@ -3129,6 +3130,23 @@ func TestPreLoginUserCreation(t *testing.T) { } user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) assert.NoError(t, err) + assert.Len(t, user.Permissions, 2) + assert.Empty(t, user.Description) + u.Description = "some desc" + delete(u.Permissions, "/list") + err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, false), os.ModePerm) + assert.NoError(t, err) + // The user should be updated and list permission removed + conn, client, err = getSftpClient(u, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, user.Permissions, 1) + assert.NotEmpty(t, user.Description) _, err = httpdtest.RemoveUser(user, http.StatusOK) assert.NoError(t, err) err = os.RemoveAll(user.GetHomeDir())