diff --git a/common/protocol_test.go b/common/protocol_test.go index a381757b..c61d593e 100644 --- a/common/protocol_test.go +++ b/common/protocol_test.go @@ -1956,6 +1956,93 @@ func TestResolvePathError(t *testing.T) { assert.NoError(t, err) } +func TestDelayedQuotaUpdater(t *testing.T) { + err := dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + providerConf.DelayedQuotaUpdate = 120 + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + + u := getTestUser() + u.QuotaFiles = 100 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + err = dataprovider.UpdateUserQuota(&user, 10, 6000, false) + assert.NoError(t, err) + files, size, err := dataprovider.GetUsedQuota(user.Username) + assert.NoError(t, err) + assert.Equal(t, 10, files) + assert.Equal(t, int64(6000), size) + + userGet, err := dataprovider.UserExists(user.Username) + assert.NoError(t, err) + assert.Equal(t, 0, userGet.UsedQuotaFiles) + assert.Equal(t, int64(0), userGet.UsedQuotaSize) + + err = dataprovider.UpdateUserQuota(&user, 10, 6000, true) + assert.NoError(t, err) + files, size, err = dataprovider.GetUsedQuota(user.Username) + assert.NoError(t, err) + assert.Equal(t, 10, files) + assert.Equal(t, int64(6000), size) + + userGet, err = dataprovider.UserExists(user.Username) + assert.NoError(t, err) + assert.Equal(t, 10, userGet.UsedQuotaFiles) + assert.Equal(t, int64(6000), userGet.UsedQuotaSize) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + folder := vfs.BaseVirtualFolder{ + Name: "folder", + MappedPath: filepath.Join(os.TempDir(), "p"), + } + err = dataprovider.AddFolder(&folder) + assert.NoError(t, err) + + err = dataprovider.UpdateVirtualFolderQuota(&folder, 10, 6000, false) + assert.NoError(t, err) + files, size, err = dataprovider.GetUsedVirtualFolderQuota(folder.Name) + assert.NoError(t, err) + assert.Equal(t, 10, files) + assert.Equal(t, int64(6000), size) + + folderGet, err := dataprovider.GetFolderByName(folder.Name) + assert.NoError(t, err) + assert.Equal(t, 0, folderGet.UsedQuotaFiles) + assert.Equal(t, int64(0), folderGet.UsedQuotaSize) + + err = dataprovider.UpdateVirtualFolderQuota(&folder, 10, 6000, true) + assert.NoError(t, err) + files, size, err = dataprovider.GetUsedVirtualFolderQuota(folder.Name) + assert.NoError(t, err) + assert.Equal(t, 10, files) + assert.Equal(t, int64(6000), size) + + folderGet, err = dataprovider.GetFolderByName(folder.Name) + assert.NoError(t, err) + assert.Equal(t, 10, folderGet.UsedQuotaFiles) + assert.Equal(t, int64(6000), folderGet.UsedQuotaSize) + + err = dataprovider.DeleteFolder(folder.Name) + assert.NoError(t, err) + + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) +} + func TestQuotaTrackDisabled(t *testing.T) { err := dataprovider.Close() assert.NoError(t, err) diff --git a/config/config.go b/config/config.go index 9cc1c6c5..2eb46d60 100644 --- a/config/config.go +++ b/config/config.go @@ -181,7 +181,7 @@ func Init() { Driver: "sqlite", Name: "sftpgo.db", Host: "", - Port: 5432, + Port: 0, Username: "", Password: "", ConnectionString: "", @@ -212,6 +212,7 @@ func Init() { UpdateMode: 0, PreferDatabaseCredentials: false, SkipNaturalKeysValidation: false, + DelayedQuotaUpdate: 0, }, HTTPDConfig: httpd.Conf{ Bindings: []httpd.Binding{defaultHTTPDBinding}, @@ -857,6 +858,7 @@ func setViperDefaults() { viper.SetDefault("data_provider.password_hashing.argon2_options.parallelism", globalConf.ProviderConf.PasswordHashing.Argon2Options.Parallelism) viper.SetDefault("data_provider.update_mode", globalConf.ProviderConf.UpdateMode) viper.SetDefault("data_provider.skip_natural_keys_validation", globalConf.ProviderConf.SkipNaturalKeysValidation) + viper.SetDefault("data_provider.delayed_quota_update", globalConf.ProviderConf.DelayedQuotaUpdate) viper.SetDefault("httpd.templates_path", globalConf.HTTPDConfig.TemplatesPath) viper.SetDefault("httpd.static_files_path", globalConf.HTTPDConfig.StaticFilesPath) viper.SetDefault("httpd.backups_path", globalConf.HTTPDConfig.BackupsPath) diff --git a/dataprovider/dataprovider.go b/dataprovider/dataprovider.go index 4915c80b..73c3914b 100644 --- a/dataprovider/dataprovider.go +++ b/dataprovider/dataprovider.go @@ -279,6 +279,14 @@ type Config struct { // folder name. These keys are used in URIs for REST API and Web admin. By default only unreserved URI // characters are allowed: ALPHA / DIGIT / "-" / "." / "_" / "~". SkipNaturalKeysValidation bool `json:"skip_natural_keys_validation" mapstructure:"skip_natural_keys_validation"` + // DelayedQuotaUpdate defines the number of seconds to accumulate quota updates. + // If there are a lot of close uploads, accumulating quota updates can save you many + // queries to the data provider. + // If you want to track quotas, a scheduled quota update is recommended in any case, the stored + // quota size may be incorrect for several reasons, such as an unexpected shutdown, temporary provider + // failures, file copied outside of SFTPGo, and so on. + // 0 means immediate quota update. + DelayedQuotaUpdate int `json:"delayed_quota_update" mapstructure:"delayed_quota_update"` } // BackupData defines the structure for the backup/restore files @@ -469,6 +477,7 @@ func Initialize(cnf Config, basePath string, checkAdmins bool) error { providerLog(logger.LevelInfo, "database initialization/migration skipped, manual mode is configured") } startAvailabilityTimer() + delayedQuotaUpdater.start() return nil } @@ -767,7 +776,14 @@ func UpdateUserQuota(user *User, filesAdd int, sizeAdd int64, reset bool) error if filesAdd == 0 && sizeAdd == 0 && !reset { return nil } - return provider.updateQuota(user.Username, filesAdd, sizeAdd, reset) + if config.DelayedQuotaUpdate == 0 || reset { + if reset { + delayedQuotaUpdater.resetUserQuota(user.Username) + } + return provider.updateQuota(user.Username, filesAdd, sizeAdd, reset) + } + delayedQuotaUpdater.updateUserQuota(user.Username, filesAdd, sizeAdd) + return nil } // UpdateVirtualFolderQuota updates the quota for the given virtual folder adding filesAdd and sizeAdd. @@ -779,7 +795,14 @@ func UpdateVirtualFolderQuota(vfolder *vfs.BaseVirtualFolder, filesAdd int, size if filesAdd == 0 && sizeAdd == 0 && !reset { return nil } - return provider.updateFolderQuota(vfolder.Name, filesAdd, sizeAdd, reset) + if config.DelayedQuotaUpdate == 0 || reset { + if reset { + delayedQuotaUpdater.resetFolderQuota(vfolder.Name) + } + return provider.updateFolderQuota(vfolder.Name, filesAdd, sizeAdd, reset) + } + delayedQuotaUpdater.updateFolderQuota(vfolder.Name, filesAdd, sizeAdd) + return nil } // GetUsedQuota returns the used quota for the given SFTP user. @@ -787,7 +810,12 @@ func GetUsedQuota(username string) (int, int64, error) { if config.TrackQuota == 0 { return 0, 0, &MethodDisabledError{err: trackQuotaDisabledError} } - return provider.getUsedQuota(username) + files, size, err := provider.getUsedQuota(username) + if err != nil { + return files, size, err + } + delayedFiles, delayedSize := delayedQuotaUpdater.getUserPendingQuota(username) + return files + delayedFiles, size + delayedSize, err } // GetUsedVirtualFolderQuota returns the used quota for the given virtual folder. @@ -795,7 +823,12 @@ func GetUsedVirtualFolderQuota(name string) (int, int64, error) { if config.TrackQuota == 0 { return 0, 0, &MethodDisabledError{err: trackQuotaDisabledError} } - return provider.getUsedFolderQuota(name) + files, size, err := provider.getUsedFolderQuota(name) + if err != nil { + return files, size, err + } + delayedFiles, delayedSize := delayedQuotaUpdater.getFolderPendingQuota(name) + return files + delayedFiles, size + delayedSize, err } // AddAdmin adds a new SFTPGo admin @@ -855,6 +888,7 @@ func DeleteUser(username string) error { err = provider.deleteUser(&user) if err == nil { RemoveCachedWebDAVUser(user.Username) + delayedQuotaUpdater.resetUserQuota(username) executeAction(operationDelete, &user) } return err @@ -904,6 +938,7 @@ func DeleteFolder(folderName string) error { for _, user := range folder.Users { RemoveCachedWebDAVUser(user) } + delayedQuotaUpdater.resetFolderQuota(folderName) } return err } diff --git a/dataprovider/quotaupdater.go b/dataprovider/quotaupdater.go new file mode 100644 index 00000000..8645e9e7 --- /dev/null +++ b/dataprovider/quotaupdater.go @@ -0,0 +1,183 @@ +package dataprovider + +import ( + "sync" + "time" + + "github.com/drakkan/sftpgo/logger" +) + +var delayedQuotaUpdater *quotaUpdater + +func init() { + delayedQuotaUpdater = newQuotaUpdater() +} + +type quotaObject struct { + size int64 + files int +} + +type quotaUpdater struct { + paramsMutex sync.RWMutex + waitTime time.Duration + sync.RWMutex + pendingUserQuotaUpdates map[string]quotaObject + pendingFolderQuotaUpdates map[string]quotaObject +} + +func newQuotaUpdater() *quotaUpdater { + return "aUpdater{ + pendingUserQuotaUpdates: make(map[string]quotaObject), + pendingFolderQuotaUpdates: make(map[string]quotaObject), + } +} + +func (q *quotaUpdater) start() { + q.setWaitTime(config.DelayedQuotaUpdate) + + go q.loop() +} + +func (q *quotaUpdater) loop() { + waitTime := q.getWaitTime() + providerLog(logger.LevelDebug, "delayed quota update loop started, wait time: %v", waitTime) + for waitTime > 0 { + // We do this with a time.Sleep instead of a time.Ticker because we don't know + // how long each quota processing cycle will take, and we want to make + // sure we wait the configured seconds between each iteration + time.Sleep(waitTime) + providerLog(logger.LevelDebug, "delayed quota update check start") + q.storeUsersQuota() + q.storeFoldersQuota() + providerLog(logger.LevelDebug, "delayed quota update check end") + waitTime = q.getWaitTime() + } + providerLog(logger.LevelDebug, "delayed quota update loop ended, wait time: %v", waitTime) +} + +func (q *quotaUpdater) setWaitTime(secs int) { + q.paramsMutex.Lock() + defer q.paramsMutex.Unlock() + + q.waitTime = time.Duration(secs) * time.Second +} + +func (q *quotaUpdater) getWaitTime() time.Duration { + q.paramsMutex.RLock() + defer q.paramsMutex.RUnlock() + + return q.waitTime +} + +func (q *quotaUpdater) resetUserQuota(username string) { + q.Lock() + defer q.Unlock() + + delete(q.pendingUserQuotaUpdates, username) +} + +func (q *quotaUpdater) updateUserQuota(username string, files int, size int64) { + q.Lock() + defer q.Unlock() + + obj := q.pendingUserQuotaUpdates[username] + obj.size += size + obj.files += files + if obj.files == 0 && obj.size == 0 { + delete(q.pendingUserQuotaUpdates, username) + return + } + q.pendingUserQuotaUpdates[username] = obj +} + +func (q *quotaUpdater) getUserPendingQuota(username string) (int, int64) { + q.RLock() + defer q.RUnlock() + + obj := q.pendingUserQuotaUpdates[username] + + return obj.files, obj.size +} + +func (q *quotaUpdater) resetFolderQuota(name string) { + q.Lock() + defer q.Unlock() + + delete(q.pendingFolderQuotaUpdates, name) +} + +func (q *quotaUpdater) updateFolderQuota(name string, files int, size int64) { + q.Lock() + defer q.Unlock() + + obj := q.pendingFolderQuotaUpdates[name] + obj.size += size + obj.files += files + if obj.files == 0 && obj.size == 0 { + delete(q.pendingFolderQuotaUpdates, name) + return + } + q.pendingFolderQuotaUpdates[name] = obj +} + +func (q *quotaUpdater) getFolderPendingQuota(name string) (int, int64) { + q.RLock() + defer q.RUnlock() + + obj := q.pendingFolderQuotaUpdates[name] + + return obj.files, obj.size +} + +func (q *quotaUpdater) getUsernames() []string { + q.RLock() + defer q.RUnlock() + + result := make([]string, 0, len(q.pendingUserQuotaUpdates)) + for username := range q.pendingUserQuotaUpdates { + result = append(result, username) + } + + return result +} + +func (q *quotaUpdater) getFoldernames() []string { + q.RLock() + defer q.RUnlock() + + result := make([]string, 0, len(q.pendingFolderQuotaUpdates)) + for name := range q.pendingFolderQuotaUpdates { + result = append(result, name) + } + + return result +} + +func (q *quotaUpdater) storeUsersQuota() { + for _, username := range q.getUsernames() { + files, size := q.getUserPendingQuota(username) + if size != 0 || files != 0 { + err := provider.updateQuota(username, files, size, false) + if err != nil { + providerLog(logger.LevelWarn, "unable to update quota delayed for user %#v: %v", username, err) + continue + } + q.updateUserQuota(username, -files, -size) + } + } +} + +func (q *quotaUpdater) storeFoldersQuota() { + for _, name := range q.getFoldernames() { + files, size := q.getFolderPendingQuota(name) + if size != 0 || files != 0 { + err := provider.updateFolderQuota(name, files, size, false) + if err != nil { + providerLog(logger.LevelWarn, "unable to update quota delayed for folder %#v: %v", name, err) + continue + } + q.updateFolderQuota(name, -files, -size) + } + } +} diff --git a/docs/full-configuration.md b/docs/full-configuration.md index 75ef9797..6216c749 100644 --- a/docs/full-configuration.md +++ b/docs/full-configuration.md @@ -191,6 +191,7 @@ The configuration file contains the following sections: - `parallelism`. unsigned 8 bit integer. The number of threads (or lanes) used by the algorithm. Default: 2. - `update_mode`, integer. Defines how the database will be initialized/updated. 0 means automatically. 1 means manually using the initprovider sub-command. - `skip_natural_keys_validation`, boolean. If `true` you can use any UTF-8 character for natural keys as username, admin name, folder name. These keys are used in URIs for REST API and Web admin. If `false` only unreserved URI characters are allowed: ALPHA / DIGIT / "-" / "." / "_" / "~". Default: `false`. + - `delayed_quota_update`, integer. This configuration parameter defines the number of seconds to accumulate quota updates. If there are a lot of close uploads, accumulating quota updates can save you many queries to the data provider. If you want to track quotas, a scheduled quota update is recommended in any case, the stored quota size may be incorrect for several reasons, such as an unexpected shutdown, temporary provider failures, file copied outside of SFTPGo, and so on. 0 means immediate quota update. - **"httpd"**, the configuration for the HTTP server used to serve REST API and to expose the built-in web interface - `bindings`, list of structs. Each struct has the following fields: - `port`, integer. The port used for serving HTTP requests. Default: 8080. diff --git a/httpd/httpd.go b/httpd/httpd.go index 4471397c..30a3285d 100644 --- a/httpd/httpd.go +++ b/httpd/httpd.go @@ -17,6 +17,7 @@ import ( "github.com/go-chi/chi/v5" "github.com/go-chi/jwtauth/v5" + "github.com/lestrrat-go/jwx/jwa" "github.com/drakkan/sftpgo/common" "github.com/drakkan/sftpgo/dataprovider" @@ -250,7 +251,7 @@ func (c *Conf) Initialize(configDir string) error { certMgr = mgr } - csrfTokenAuth = jwtauth.New("HS256", utils.GenerateRandomBytes(32), nil) + csrfTokenAuth = jwtauth.New(jwa.HS256.String(), utils.GenerateRandomBytes(32), nil) exitChannel := make(chan error, 1) diff --git a/httpd/internal_test.go b/httpd/internal_test.go index bfa8455a..baef4b1a 100644 --- a/httpd/internal_test.go +++ b/httpd/internal_test.go @@ -22,6 +22,7 @@ import ( "github.com/go-chi/chi/v5" "github.com/go-chi/jwtauth/v5" + "github.com/lestrrat-go/jwx/jwa" "github.com/lestrrat-go/jwx/jwt" "github.com/rs/xid" "github.com/stretchr/testify/assert" @@ -409,7 +410,7 @@ func TestCSRFToken(t *testing.T) { tokenString = createCSRFToken() assert.Empty(t, tokenString) - csrfTokenAuth = jwtauth.New("HS256", utils.GenerateRandomBytes(32), nil) + csrfTokenAuth = jwtauth.New(jwa.HS256.String(), utils.GenerateRandomBytes(32), nil) } func TestCreateTokenError(t *testing.T) { @@ -460,7 +461,7 @@ func TestCreateTokenError(t *testing.T) { } func TestJWTTokenValidation(t *testing.T) { - tokenAuth := jwtauth.New("HS256", utils.GenerateRandomBytes(32), nil) + tokenAuth := jwtauth.New(jwa.HS256.String(), utils.GenerateRandomBytes(32), nil) claims := make(map[string]interface{}) claims["username"] = "admin" claims[jwt.ExpirationKey] = time.Now().UTC().Add(-1 * time.Hour) @@ -520,7 +521,7 @@ func TestAdminAllowListConnAddr(t *testing.T) { func TestUpdateContextFromCookie(t *testing.T) { server := httpdServer{ - tokenAuth: jwtauth.New("HS256", utils.GenerateRandomBytes(32), nil), + tokenAuth: jwtauth.New(jwa.HS256.String(), utils.GenerateRandomBytes(32), nil), } req, _ := http.NewRequest(http.MethodGet, tokenPath, nil) claims := make(map[string]interface{}) @@ -534,7 +535,7 @@ func TestUpdateContextFromCookie(t *testing.T) { func TestCookieExpiration(t *testing.T) { server := httpdServer{ - tokenAuth: jwtauth.New("HS256", utils.GenerateRandomBytes(32), nil), + tokenAuth: jwtauth.New(jwa.HS256.String(), utils.GenerateRandomBytes(32), nil), } err := errors.New("test error") rr := httptest.NewRecorder() @@ -842,7 +843,7 @@ func TestGetUserFromTemplate(t *testing.T) { func TestJWTTokenCleanup(t *testing.T) { server := httpdServer{ - tokenAuth: jwtauth.New("HS256", utils.GenerateRandomBytes(32), nil), + tokenAuth: jwtauth.New(jwa.HS256.String(), utils.GenerateRandomBytes(32), nil), } admin := dataprovider.Admin{ Username: "newtestadmin", diff --git a/httpd/server.go b/httpd/server.go index f4ddbb11..48a24d9d 100644 --- a/httpd/server.go +++ b/httpd/server.go @@ -13,6 +13,7 @@ import ( "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/jwtauth/v5" "github.com/go-chi/render" + "github.com/lestrrat-go/jwx/jwa" "github.com/drakkan/sftpgo/common" "github.com/drakkan/sftpgo/dataprovider" @@ -252,7 +253,7 @@ func (s *httpdServer) updateContextFromCookie(r *http.Request) *http.Request { } func (s *httpdServer) initializeRouter() { - s.tokenAuth = jwtauth.New("HS256", utils.GenerateRandomBytes(32), nil) + s.tokenAuth = jwtauth.New(jwa.HS256.String(), utils.GenerateRandomBytes(32), nil) s.router = chi.NewRouter() s.router.Use(saveConnectionAddress) diff --git a/sftpgo.json b/sftpgo.json index 05903b8f..e26cdab5 100644 --- a/sftpgo.json +++ b/sftpgo.json @@ -126,6 +126,7 @@ "connection_string": "", "sql_tables_prefix": "", "track_quota": 2, + "delayed_quota_update": 0, "pool_size": 0, "users_base_dir": "", "actions": {