update data transfer quota only if the current IP has some limits

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino
2022-01-31 19:30:25 +01:00
parent 02db00d008
commit d51adb041e
5 changed files with 50 additions and 18 deletions

View File

@@ -1002,6 +1002,12 @@ func TestVirtualFoldersQuotaRenameOverwrite(t *testing.T) {
func TestQuotaRenameOverwrite(t *testing.T) { func TestQuotaRenameOverwrite(t *testing.T) {
u := getTestUser() u := getTestUser()
u.QuotaFiles = 100 u.QuotaFiles = 100
u.Filters.DataTransferLimits = []sdk.DataTransferLimit{
{
Sources: []string{"10.8.0.0/8"},
TotalDataTransfer: 1,
},
}
user, _, err := httpdtest.AddUser(u, http.StatusCreated) user, _, err := httpdtest.AddUser(u, http.StatusCreated)
assert.NoError(t, err) assert.NoError(t, err)
conn, client, err := getSftpClient(user) conn, client, err := getSftpClient(user)
@@ -1013,16 +1019,28 @@ func TestQuotaRenameOverwrite(t *testing.T) {
testFileName1 := "test_file1.dat" testFileName1 := "test_file1.dat"
err = writeSFTPFile(testFileName, testFileSize, client) err = writeSFTPFile(testFileName, testFileSize, client)
assert.NoError(t, err) assert.NoError(t, err)
f, err := client.Open(testFileName)
assert.NoError(t, err)
contents := make([]byte, testFileSize)
n, err := io.ReadFull(f, contents)
assert.NoError(t, err)
assert.Equal(t, int(testFileSize), n)
err = f.Close()
assert.NoError(t, err)
err = writeSFTPFile(testFileName1, testFileSize1, client) err = writeSFTPFile(testFileName1, testFileSize1, client)
assert.NoError(t, err) assert.NoError(t, err)
user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, int64(0), user.UsedDownloadDataTransfer)
assert.Equal(t, int64(0), user.UsedUploadDataTransfer)
assert.Equal(t, 2, user.UsedQuotaFiles) assert.Equal(t, 2, user.UsedQuotaFiles)
assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize) assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
err = client.Rename(testFileName, testFileName1) err = client.Rename(testFileName, testFileName1)
assert.NoError(t, err) assert.NoError(t, err)
user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, int64(0), user.UsedDownloadDataTransfer)
assert.Equal(t, int64(0), user.UsedUploadDataTransfer)
assert.Equal(t, 1, user.UsedQuotaFiles) assert.Equal(t, 1, user.UsedQuotaFiles)
assert.Equal(t, testFileSize, user.UsedQuotaSize) assert.Equal(t, testFileSize, user.UsedQuotaSize)
err = client.Remove(testFileName1) err = client.Remove(testFileName1)

View File

@@ -153,8 +153,7 @@ func (t *BaseTransfer) HasSizeLimit() bool {
if t.MaxWriteSize > 0 { if t.MaxWriteSize > 0 {
return true return true
} }
if t.transferQuota.AllowedDLSize > 0 || t.transferQuota.AllowedULSize > 0 || if t.transferQuota.HasSizeLimits() {
t.transferQuota.AllowedTotalSize > 0 {
return true return true
} }
@@ -249,10 +248,11 @@ func (t *BaseTransfer) Truncate(fsPath string, size int64) (int64, error) {
sizeDiff := initialSize - size sizeDiff := initialSize - size
t.MaxWriteSize += sizeDiff t.MaxWriteSize += sizeDiff
metric.TransferCompleted(atomic.LoadInt64(&t.BytesSent), atomic.LoadInt64(&t.BytesReceived), t.transferType, t.ErrTransfer) metric.TransferCompleted(atomic.LoadInt64(&t.BytesSent), atomic.LoadInt64(&t.BytesReceived), t.transferType, t.ErrTransfer)
go func(ulSize, dlSize int64, user dataprovider.User) { if t.transferQuota.HasSizeLimits() {
dataprovider.UpdateUserTransferQuota(&user, ulSize, dlSize, false) //nolint:errcheck go func(ulSize, dlSize int64, user dataprovider.User) {
}(atomic.LoadInt64(&t.BytesReceived), atomic.LoadInt64(&t.BytesSent), t.Connection.User) dataprovider.UpdateUserTransferQuota(&user, ulSize, dlSize, false) //nolint:errcheck
}(atomic.LoadInt64(&t.BytesReceived), atomic.LoadInt64(&t.BytesSent), t.Connection.User)
}
atomic.StoreInt64(&t.BytesReceived, 0) atomic.StoreInt64(&t.BytesReceived, 0)
} }
t.Unlock() t.Unlock()
@@ -321,8 +321,10 @@ func (t *BaseTransfer) Close() error {
} }
metric.TransferCompleted(atomic.LoadInt64(&t.BytesSent), atomic.LoadInt64(&t.BytesReceived), metric.TransferCompleted(atomic.LoadInt64(&t.BytesSent), atomic.LoadInt64(&t.BytesReceived),
t.transferType, t.ErrTransfer) t.transferType, t.ErrTransfer)
dataprovider.UpdateUserTransferQuota(&t.Connection.User, atomic.LoadInt64(&t.BytesReceived), //nolint:errcheck if t.transferQuota.HasSizeLimits() {
atomic.LoadInt64(&t.BytesSent), false) dataprovider.UpdateUserTransferQuota(&t.Connection.User, atomic.LoadInt64(&t.BytesReceived), //nolint:errcheck
atomic.LoadInt64(&t.BytesSent), false)
}
if t.File != nil && t.Connection.IsQuotaExceededError(t.ErrTransfer) { if t.File != nil && t.Connection.IsQuotaExceededError(t.ErrTransfer) {
// if quota is exceeded we try to remove the partial file for uploads to local filesystem // if quota is exceeded we try to remove the partial file for uploads to local filesystem
err = t.Fs.Remove(t.File.Name(), false) err = t.Fs.Remove(t.File.Name(), false)

View File

@@ -461,6 +461,11 @@ type TransferQuota struct {
AllowedTotalSize int64 AllowedTotalSize int64
} }
// HasSizeLimits returns true if any size limit is set
func (q *TransferQuota) HasSizeLimits() bool {
return q.AllowedDLSize > 0 || q.AllowedULSize > 0 || q.AllowedTotalSize > 0
}
// HasUploadSpace returns true if there is transfer upload space available // HasUploadSpace returns true if there is transfer upload space available
func (q *TransferQuota) HasUploadSpace() bool { func (q *TransferQuota) HasUploadSpace() bool {
if q.TotalSize <= 0 && q.ULSize <= 0 { if q.TotalSize <= 0 && q.ULSize <= 0 {

View File

@@ -233,6 +233,15 @@ func (f *webDavFile) updateStatInfo() error {
return nil return nil
} }
func (f *webDavFile) updateTransferQuotaOnSeek() {
transferQuota := f.GetTransferQuota()
if transferQuota.HasSizeLimits() {
go func(ulSize, dlSize int64, user dataprovider.User) {
dataprovider.UpdateUserTransferQuota(&user, ulSize, dlSize, false) //nolint:errcheck
}(atomic.LoadInt64(&f.BytesReceived), atomic.LoadInt64(&f.BytesSent), f.Connection.User)
}
}
// Seek sets the offset for the next Read or Write on the writer to offset, // Seek sets the offset for the next Read or Write on the writer to offset,
// interpreted according to whence: 0 means relative to the origin of the file, // interpreted according to whence: 0 means relative to the origin of the file,
// 1 means relative to the current offset, and 2 means relative to the end. // 1 means relative to the current offset, and 2 means relative to the end.
@@ -267,9 +276,7 @@ func (f *webDavFile) Seek(offset int64, whence int) (int64, error) {
startByte := int64(0) startByte := int64(0)
atomic.StoreInt64(&f.BytesReceived, 0) atomic.StoreInt64(&f.BytesReceived, 0)
atomic.StoreInt64(&f.BytesSent, 0) atomic.StoreInt64(&f.BytesSent, 0)
go func(ulSize, dlSize int64, user dataprovider.User) { f.updateTransferQuotaOnSeek()
dataprovider.UpdateUserTransferQuota(&user, ulSize, dlSize, false) //nolint:errcheck
}(atomic.LoadInt64(&f.BytesReceived), atomic.LoadInt64(&f.BytesSent), f.Connection.User)
switch whence { switch whence {
case io.SeekStart: case io.SeekStart:

View File

@@ -841,7 +841,7 @@ func TestTransferSeek(t *testing.T) {
testFilePath := filepath.Join(user.HomeDir, testFile) testFilePath := filepath.Join(user.HomeDir, testFile)
testFileContents := []byte("content") testFileContents := []byte("content")
baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
common.TransferUpload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) common.TransferUpload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100})
davFile := newWebDavFile(baseTransfer, nil, nil) davFile := newWebDavFile(baseTransfer, nil, nil)
_, err := davFile.Seek(0, io.SeekStart) _, err := davFile.Seek(0, io.SeekStart)
assert.EqualError(t, err, common.ErrOpUnsupported.Error()) assert.EqualError(t, err, common.ErrOpUnsupported.Error())
@@ -849,7 +849,7 @@ func TestTransferSeek(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100})
davFile = newWebDavFile(baseTransfer, nil, nil) davFile = newWebDavFile(baseTransfer, nil, nil)
_, err = davFile.Seek(0, io.SeekCurrent) _, err = davFile.Seek(0, io.SeekCurrent)
assert.True(t, os.IsNotExist(err)) assert.True(t, os.IsNotExist(err))
@@ -863,14 +863,14 @@ func TestTransferSeek(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
} }
baseTransfer = common.NewBaseTransfer(f, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, baseTransfer = common.NewBaseTransfer(f, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100})
davFile = newWebDavFile(baseTransfer, nil, nil) davFile = newWebDavFile(baseTransfer, nil, nil)
_, err = davFile.Seek(0, io.SeekStart) _, err = davFile.Seek(0, io.SeekStart)
assert.Error(t, err) assert.Error(t, err)
davFile.Connection.RemoveTransfer(davFile.BaseTransfer) davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100})
davFile = newWebDavFile(baseTransfer, nil, nil) davFile = newWebDavFile(baseTransfer, nil, nil)
res, err := davFile.Seek(0, io.SeekStart) res, err := davFile.Seek(0, io.SeekStart)
assert.NoError(t, err) assert.NoError(t, err)
@@ -885,14 +885,14 @@ func TestTransferSeek(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath+"1", testFilePath+"1", testFile, baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath+"1", testFilePath+"1", testFile,
common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100})
davFile = newWebDavFile(baseTransfer, nil, nil) davFile = newWebDavFile(baseTransfer, nil, nil)
_, err = davFile.Seek(0, io.SeekEnd) _, err = davFile.Seek(0, io.SeekEnd)
assert.True(t, os.IsNotExist(err)) assert.True(t, os.IsNotExist(err))
davFile.Connection.RemoveTransfer(davFile.BaseTransfer) davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile,
common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100})
davFile = newWebDavFile(baseTransfer, nil, nil) davFile = newWebDavFile(baseTransfer, nil, nil)
davFile.reader = f davFile.reader = f
davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir(), nil) davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir(), nil)
@@ -907,7 +907,7 @@ func TestTransferSeek(t *testing.T) {
assert.Equal(t, int64(5), res) assert.Equal(t, int64(5), res)
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath+"1", testFilePath+"1", testFile, baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath+"1", testFilePath+"1", testFile,
common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100})
davFile = newWebDavFile(baseTransfer, nil, nil) davFile = newWebDavFile(baseTransfer, nil, nil)
davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir(), nil) davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir(), nil)