S3: fix quota update after an upload error

S3 uploads are atomic, if the upload fails we have no partial file so we
have to update the user quota only if the upload succeed
This commit is contained in:
Nicola Murino
2020-01-23 10:19:56 +01:00
parent 7ebbbe5c29
commit d481294519
5 changed files with 74 additions and 13 deletions

View File

@@ -479,11 +479,16 @@ func (c Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, re
return nil, vfs.GetSFTPError(c.fs, err) return nil, vfs.GetSFTPError(c.fs, err)
} }
initialSize := int64(0)
if pflags.Append && osFlags&os.O_TRUNC == 0 { if pflags.Append && osFlags&os.O_TRUNC == 0 {
c.Log(logger.LevelDebug, logSender, "upload resume requested, file path: %#v initial size: %v", filePath, fileSize) c.Log(logger.LevelDebug, logSender, "upload resume requested, file path: %#v initial size: %v", filePath, fileSize)
minWriteOffset = fileSize minWriteOffset = fileSize
} else { } else {
dataprovider.UpdateUserQuota(dataProvider, c.User, 0, -fileSize, false) if vfs.IsLocalOsFs(c.fs) {
dataprovider.UpdateUserQuota(dataProvider, c.User, 0, -fileSize, false)
} else {
initialSize = fileSize
}
} }
vfs.SetPathPermissions(c.fs, filePath, c.User.GetUID(), c.User.GetGID()) vfs.SetPathPermissions(c.fs, filePath, c.User.GetUID(), c.User.GetGID())
@@ -506,6 +511,7 @@ func (c Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, re
transferError: nil, transferError: nil,
isFinished: false, isFinished: false,
minWriteOffset: minWriteOffset, minWriteOffset: minWriteOffset,
initialSize: initialSize,
lock: new(sync.Mutex), lock: new(sync.Mutex),
} }
addTransfer(&transfer) addTransfer(&transfer)

View File

@@ -425,6 +425,28 @@ func TestUploadFiles(t *testing.T) {
if err == nil { if err == nil {
t.Errorf("upload new file in missing path must fail") t.Errorf("upload new file in missing path must fail")
} }
c.fs = newMockOsFs(nil, nil, false, "123", os.TempDir())
f, _ := ioutil.TempFile("", "temp")
f.Close()
_, err = c.handleSFTPUploadToExistingFile(flags, f.Name(), f.Name(), 123)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if len(activeTransfers) != 1 {
t.Errorf("unexpected number of transfer, expected 1, current: %v", len(activeTransfers))
}
transfer := activeTransfers[0]
if transfer.initialSize != 123 {
t.Errorf("unexpected initial size: %v", transfer.initialSize)
}
err = transfer.Close()
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if len(activeTransfers) != 0 {
t.Errorf("unexpected number of transfer, expected 0, current: %v", len(activeTransfers))
}
os.Remove(f.Name())
uploadMode = oldUploadMode uploadMode = oldUploadMode
} }
@@ -899,6 +921,17 @@ func TestSystemCommandErrors(t *testing.T) {
} }
} }
func TestTransferUpdateQuota(t *testing.T) {
transfer := Transfer{
transferType: transferUpload,
bytesReceived: 123,
lock: new(sync.Mutex)}
transfer.TransferError(errors.New("fake error"))
if transfer.updateQuota(1) {
t.Errorf("update quota must fail, there is a error and this is a remote upload")
}
}
func TestGetConnectionInfo(t *testing.T) { func TestGetConnectionInfo(t *testing.T) {
c := ConnectionStatus{ c := ConnectionStatus{
Username: "test_user", Username: "test_user",
@@ -1222,6 +1255,10 @@ func TestSCPErrorsMockFs(t *testing.T) {
if err != errFake { if err != errFake {
t.Errorf("unexpected error: %v", err) t.Errorf("unexpected error: %v", err)
} }
err = scpCommand.handleUploadFile(testfile, testfile, 0, false, 4)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
os.Remove(testfile) os.Remove(testfile)
} }

View File

@@ -181,7 +181,7 @@ func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *Transfer) err
return c.sendConfirmationMessage() return c.sendConfirmationMessage()
} }
func (c *scpCommand) handleUploadFile(requestPath, filePath string, sizeToRead int64, isNewFile bool) error { func (c *scpCommand) handleUploadFile(requestPath, filePath string, sizeToRead int64, isNewFile bool, fileSize int64) error {
if !c.connection.hasSpace(true) { if !c.connection.hasSpace(true) {
err := fmt.Errorf("denying file write due to space limit") err := fmt.Errorf("denying file write due to space limit")
c.connection.Log(logger.LevelWarn, logSenderSCP, "error uploading file: %#v, err: %v", filePath, err) c.connection.Log(logger.LevelWarn, logSenderSCP, "error uploading file: %#v, err: %v", filePath, err)
@@ -189,6 +189,14 @@ func (c *scpCommand) handleUploadFile(requestPath, filePath string, sizeToRead i
return err return err
} }
initialSize := int64(0)
if !isNewFile {
if vfs.IsLocalOsFs(c.connection.fs) {
dataprovider.UpdateUserQuota(dataProvider, c.connection.User, 0, -fileSize, false)
} else {
initialSize = fileSize
}
}
file, w, cancelFn, err := c.connection.fs.Create(filePath, 0) file, w, cancelFn, err := c.connection.fs.Create(filePath, 0)
if err != nil { if err != nil {
c.connection.Log(logger.LevelError, logSenderSCP, "error creating file %#v: %v", requestPath, err) c.connection.Log(logger.LevelError, logSenderSCP, "error creating file %#v: %v", requestPath, err)
@@ -216,6 +224,7 @@ func (c *scpCommand) handleUploadFile(requestPath, filePath string, sizeToRead i
transferError: nil, transferError: nil,
isFinished: false, isFinished: false,
minWriteOffset: 0, minWriteOffset: 0,
initialSize: initialSize,
lock: new(sync.Mutex), lock: new(sync.Mutex),
} }
addTransfer(&transfer) addTransfer(&transfer)
@@ -246,7 +255,7 @@ func (c *scpCommand) handleUpload(uploadFilePath string, sizeToRead int64) error
c.sendErrorMessage(err.Error()) c.sendErrorMessage(err.Error())
return err return err
} }
return c.handleUploadFile(p, filePath, sizeToRead, true) return c.handleUploadFile(p, filePath, sizeToRead, true, 0)
} }
if statErr != nil { if statErr != nil {
@@ -279,9 +288,7 @@ func (c *scpCommand) handleUpload(uploadFilePath string, sizeToRead int64) error
} }
} }
dataprovider.UpdateUserQuota(dataProvider, c.connection.User, 0, -stat.Size(), false) return c.handleUploadFile(p, filePath, sizeToRead, false, stat.Size())
return c.handleUploadFile(p, filePath, sizeToRead, false)
} }
func (c *scpCommand) sendDownloadProtocolMessages(dirPath string, stat os.FileInfo) error { func (c *scpCommand) sendDownloadProtocolMessages(dirPath string, stat os.FileInfo) error {

View File

@@ -3123,7 +3123,7 @@ func TestResolvePaths(t *testing.T) {
} }
path = "../test/sub" path = "../test/sub"
resolved, err = fs.ResolvePath(filepath.ToSlash(path)) resolved, err = fs.ResolvePath(filepath.ToSlash(path))
if fs.Name() == "osfs" { if vfs.IsLocalOsFs(fs) {
if err == nil { if err == nil {
t.Errorf("Unexpected resolved path: %v for: %v, fs: %v", resolved, path, fs.Name()) t.Errorf("Unexpected resolved path: %v for: %v, fs: %v", resolved, path, fs.Name())
} }
@@ -3134,7 +3134,7 @@ func TestResolvePaths(t *testing.T) {
} }
path = "../../../test/../sub" path = "../../../test/../sub"
resolved, err = fs.ResolvePath(filepath.ToSlash(path)) resolved, err = fs.ResolvePath(filepath.ToSlash(path))
if fs.Name() == "osfs" { if vfs.IsLocalOsFs(fs) {
if err == nil { if err == nil {
t.Errorf("Unexpected resolved path: %v for: %v, fs: %v", resolved, path, fs.Name()) t.Errorf("Unexpected resolved path: %v for: %v, fs: %v", resolved, path, fs.Name())
} }
@@ -4624,7 +4624,7 @@ func getKeyboardInteractiveScriptContent(questions []string, sleepTime int, nonJ
content := []byte("#!/bin/sh\n\n") content := []byte("#!/bin/sh\n\n")
q, _ := json.Marshal(questions) q, _ := json.Marshal(questions)
echos := []bool{} echos := []bool{}
for index, _ := range questions { for index := range questions {
echos = append(echos, index%2 == 0) echos = append(echos, index%2 == 0)
} }
e, _ := json.Marshal(echos) e, _ := json.Marshal(echos)
@@ -4633,7 +4633,7 @@ func getKeyboardInteractiveScriptContent(questions []string, sleepTime int, nonJ
} else { } else {
content = append(content, []byte(fmt.Sprintf("echo '{\"questions\":%v,\"echos\":%v}'\n", string(q), string(e)))...) content = append(content, []byte(fmt.Sprintf("echo '{\"questions\":%v,\"echos\":%v}'\n", string(q), string(e)))...)
} }
for index, _ := range questions { for index := range questions {
content = append(content, []byte(fmt.Sprintf("read ANSWER%v\n", index))...) content = append(content, []byte(fmt.Sprintf("read ANSWER%v\n", index))...)
} }
if sleepTime > 0 { if sleepTime > 0 {

View File

@@ -44,6 +44,7 @@ type Transfer struct {
isFinished bool isFinished bool
minWriteOffset int64 minWriteOffset int64
expectedSize int64 expectedSize int64
initialSize int64
lock *sync.Mutex lock *sync.Mutex
} }
@@ -163,9 +164,7 @@ func (t *Transfer) Close() error {
} }
metrics.TransferCompleted(t.bytesSent, t.bytesReceived, t.transferType, t.transferError) metrics.TransferCompleted(t.bytesSent, t.bytesReceived, t.transferType, t.transferError)
removeTransfer(t) removeTransfer(t)
if t.transferType == transferUpload && (numFiles != 0 || t.bytesReceived > 0) { t.updateQuota(numFiles)
dataprovider.UpdateUserQuota(dataProvider, t.user, numFiles, t.bytesReceived, false)
}
return err return err
} }
@@ -181,6 +180,18 @@ func (t *Transfer) closeIO() error {
return err return err
} }
func (t *Transfer) updateQuota(numFiles int) bool {
// S3 uploads are atomic, if there is an error nothing is uploaded
if t.file == nil && t.transferError != nil {
return false
}
if t.transferType == transferUpload && (numFiles != 0 || t.bytesReceived > 0) {
dataprovider.UpdateUserQuota(dataProvider, t.user, numFiles, t.bytesReceived-t.initialSize, false)
return true
}
return false
}
func (t *Transfer) checkDownloadSize() { func (t *Transfer) checkDownloadSize() {
if t.transferType == transferDownload && t.transferError == nil && t.bytesSent < t.expectedSize { if t.transferType == transferDownload && t.transferError == nil && t.bytesSent < t.expectedSize {
t.transferError = fmt.Errorf("incomplete download: %v/%v bytes transferred", t.bytesSent, t.expectedSize) t.transferError = fmt.Errorf("incomplete download: %v/%v bytes transferred", t.bytesSent, t.expectedSize)