diff --git a/common/protocol_test.go b/common/protocol_test.go index 01bb4728..03e6b728 100644 --- a/common/protocol_test.go +++ b/common/protocol_test.go @@ -1002,6 +1002,12 @@ func TestVirtualFoldersQuotaRenameOverwrite(t *testing.T) { func TestQuotaRenameOverwrite(t *testing.T) { u := getTestUser() u.QuotaFiles = 100 + u.Filters.DataTransferLimits = []sdk.DataTransferLimit{ + { + Sources: []string{"10.8.0.0/8"}, + TotalDataTransfer: 1, + }, + } user, _, err := httpdtest.AddUser(u, http.StatusCreated) assert.NoError(t, err) conn, client, err := getSftpClient(user) @@ -1013,16 +1019,28 @@ func TestQuotaRenameOverwrite(t *testing.T) { testFileName1 := "test_file1.dat" err = writeSFTPFile(testFileName, testFileSize, client) assert.NoError(t, err) + f, err := client.Open(testFileName) + assert.NoError(t, err) + contents := make([]byte, testFileSize) + n, err := io.ReadFull(f, contents) + assert.NoError(t, err) + assert.Equal(t, int(testFileSize), n) + err = f.Close() + assert.NoError(t, err) err = writeSFTPFile(testFileName1, testFileSize1, client) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) + assert.Equal(t, int64(0), user.UsedDownloadDataTransfer) + assert.Equal(t, int64(0), user.UsedUploadDataTransfer) assert.Equal(t, 2, user.UsedQuotaFiles) assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize) err = client.Rename(testFileName, testFileName1) assert.NoError(t, err) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) assert.NoError(t, err) + assert.Equal(t, int64(0), user.UsedDownloadDataTransfer) + assert.Equal(t, int64(0), user.UsedUploadDataTransfer) assert.Equal(t, 1, user.UsedQuotaFiles) assert.Equal(t, testFileSize, user.UsedQuotaSize) err = client.Remove(testFileName1) diff --git a/common/transfer.go b/common/transfer.go index f33f5c68..af20b611 100644 --- a/common/transfer.go +++ b/common/transfer.go @@ -153,8 +153,7 @@ func (t *BaseTransfer) HasSizeLimit() bool { if t.MaxWriteSize > 0 { return true } - if t.transferQuota.AllowedDLSize > 0 || t.transferQuota.AllowedULSize > 0 || - t.transferQuota.AllowedTotalSize > 0 { + if t.transferQuota.HasSizeLimits() { return true } @@ -249,10 +248,11 @@ func (t *BaseTransfer) Truncate(fsPath string, size int64) (int64, error) { sizeDiff := initialSize - size t.MaxWriteSize += sizeDiff metric.TransferCompleted(atomic.LoadInt64(&t.BytesSent), atomic.LoadInt64(&t.BytesReceived), t.transferType, t.ErrTransfer) - go func(ulSize, dlSize int64, user dataprovider.User) { - dataprovider.UpdateUserTransferQuota(&user, ulSize, dlSize, false) //nolint:errcheck - }(atomic.LoadInt64(&t.BytesReceived), atomic.LoadInt64(&t.BytesSent), t.Connection.User) - + if t.transferQuota.HasSizeLimits() { + go func(ulSize, dlSize int64, user dataprovider.User) { + dataprovider.UpdateUserTransferQuota(&user, ulSize, dlSize, false) //nolint:errcheck + }(atomic.LoadInt64(&t.BytesReceived), atomic.LoadInt64(&t.BytesSent), t.Connection.User) + } atomic.StoreInt64(&t.BytesReceived, 0) } t.Unlock() @@ -321,8 +321,10 @@ func (t *BaseTransfer) Close() error { } metric.TransferCompleted(atomic.LoadInt64(&t.BytesSent), atomic.LoadInt64(&t.BytesReceived), t.transferType, t.ErrTransfer) - dataprovider.UpdateUserTransferQuota(&t.Connection.User, atomic.LoadInt64(&t.BytesReceived), //nolint:errcheck - atomic.LoadInt64(&t.BytesSent), false) + if t.transferQuota.HasSizeLimits() { + dataprovider.UpdateUserTransferQuota(&t.Connection.User, atomic.LoadInt64(&t.BytesReceived), //nolint:errcheck + atomic.LoadInt64(&t.BytesSent), false) + } if t.File != nil && t.Connection.IsQuotaExceededError(t.ErrTransfer) { // if quota is exceeded we try to remove the partial file for uploads to local filesystem err = t.Fs.Remove(t.File.Name(), false) diff --git a/dataprovider/dataprovider.go b/dataprovider/dataprovider.go index 20e4358b..388c4118 100644 --- a/dataprovider/dataprovider.go +++ b/dataprovider/dataprovider.go @@ -461,6 +461,11 @@ type TransferQuota struct { AllowedTotalSize int64 } +// HasSizeLimits returns true if any size limit is set +func (q *TransferQuota) HasSizeLimits() bool { + return q.AllowedDLSize > 0 || q.AllowedULSize > 0 || q.AllowedTotalSize > 0 +} + // HasUploadSpace returns true if there is transfer upload space available func (q *TransferQuota) HasUploadSpace() bool { if q.TotalSize <= 0 && q.ULSize <= 0 { diff --git a/webdavd/file.go b/webdavd/file.go index 141bb33b..5ecd33e3 100644 --- a/webdavd/file.go +++ b/webdavd/file.go @@ -233,6 +233,15 @@ func (f *webDavFile) updateStatInfo() error { return nil } +func (f *webDavFile) updateTransferQuotaOnSeek() { + transferQuota := f.GetTransferQuota() + if transferQuota.HasSizeLimits() { + go func(ulSize, dlSize int64, user dataprovider.User) { + dataprovider.UpdateUserTransferQuota(&user, ulSize, dlSize, false) //nolint:errcheck + }(atomic.LoadInt64(&f.BytesReceived), atomic.LoadInt64(&f.BytesSent), f.Connection.User) + } +} + // Seek sets the offset for the next Read or Write on the writer to offset, // interpreted according to whence: 0 means relative to the origin of the file, // 1 means relative to the current offset, and 2 means relative to the end. @@ -267,9 +276,7 @@ func (f *webDavFile) Seek(offset int64, whence int) (int64, error) { startByte := int64(0) atomic.StoreInt64(&f.BytesReceived, 0) atomic.StoreInt64(&f.BytesSent, 0) - go func(ulSize, dlSize int64, user dataprovider.User) { - dataprovider.UpdateUserTransferQuota(&user, ulSize, dlSize, false) //nolint:errcheck - }(atomic.LoadInt64(&f.BytesReceived), atomic.LoadInt64(&f.BytesSent), f.Connection.User) + f.updateTransferQuotaOnSeek() switch whence { case io.SeekStart: diff --git a/webdavd/internal_test.go b/webdavd/internal_test.go index 21c2950f..8557e88e 100644 --- a/webdavd/internal_test.go +++ b/webdavd/internal_test.go @@ -841,7 +841,7 @@ func TestTransferSeek(t *testing.T) { testFilePath := filepath.Join(user.HomeDir, testFile) testFileContents := []byte("content") baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, - common.TransferUpload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) + common.TransferUpload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100}) davFile := newWebDavFile(baseTransfer, nil, nil) _, err := davFile.Seek(0, io.SeekStart) assert.EqualError(t, err, common.ErrOpUnsupported.Error()) @@ -849,7 +849,7 @@ func TestTransferSeek(t *testing.T) { assert.NoError(t, err) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, - common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) + common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100}) davFile = newWebDavFile(baseTransfer, nil, nil) _, err = davFile.Seek(0, io.SeekCurrent) assert.True(t, os.IsNotExist(err)) @@ -863,14 +863,14 @@ func TestTransferSeek(t *testing.T) { assert.NoError(t, err) } baseTransfer = common.NewBaseTransfer(f, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, - common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) + common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100}) davFile = newWebDavFile(baseTransfer, nil, nil) _, err = davFile.Seek(0, io.SeekStart) assert.Error(t, err) davFile.Connection.RemoveTransfer(davFile.BaseTransfer) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, - common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) + common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100}) davFile = newWebDavFile(baseTransfer, nil, nil) res, err := davFile.Seek(0, io.SeekStart) assert.NoError(t, err) @@ -885,14 +885,14 @@ func TestTransferSeek(t *testing.T) { assert.Nil(t, err) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath+"1", testFilePath+"1", testFile, - common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) + common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100}) davFile = newWebDavFile(baseTransfer, nil, nil) _, err = davFile.Seek(0, io.SeekEnd) assert.True(t, os.IsNotExist(err)) davFile.Connection.RemoveTransfer(davFile.BaseTransfer) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, - common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) + common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100}) davFile = newWebDavFile(baseTransfer, nil, nil) davFile.reader = f davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir(), nil) @@ -907,7 +907,7 @@ func TestTransferSeek(t *testing.T) { assert.Equal(t, int64(5), res) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath+"1", testFilePath+"1", testFile, - common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) + common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100}) davFile = newWebDavFile(baseTransfer, nil, nil) davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir(), nil)