diff --git a/sftpd/handler.go b/sftpd/handler.go index ac96f71d..adb29e9d 100644 --- a/sftpd/handler.go +++ b/sftpd/handler.go @@ -479,11 +479,16 @@ func (c Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, re return nil, vfs.GetSFTPError(c.fs, err) } + initialSize := int64(0) if pflags.Append && osFlags&os.O_TRUNC == 0 { c.Log(logger.LevelDebug, logSender, "upload resume requested, file path: %#v initial size: %v", filePath, fileSize) minWriteOffset = fileSize } else { - dataprovider.UpdateUserQuota(dataProvider, c.User, 0, -fileSize, false) + if vfs.IsLocalOsFs(c.fs) { + dataprovider.UpdateUserQuota(dataProvider, c.User, 0, -fileSize, false) + } else { + initialSize = fileSize + } } vfs.SetPathPermissions(c.fs, filePath, c.User.GetUID(), c.User.GetGID()) @@ -506,6 +511,7 @@ func (c Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, re transferError: nil, isFinished: false, minWriteOffset: minWriteOffset, + initialSize: initialSize, lock: new(sync.Mutex), } addTransfer(&transfer) diff --git a/sftpd/internal_test.go b/sftpd/internal_test.go index 66e44040..86a44913 100644 --- a/sftpd/internal_test.go +++ b/sftpd/internal_test.go @@ -425,6 +425,28 @@ func TestUploadFiles(t *testing.T) { if err == nil { t.Errorf("upload new file in missing path must fail") } + c.fs = newMockOsFs(nil, nil, false, "123", os.TempDir()) + f, _ := ioutil.TempFile("", "temp") + f.Close() + _, err = c.handleSFTPUploadToExistingFile(flags, f.Name(), f.Name(), 123) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if len(activeTransfers) != 1 { + t.Errorf("unexpected number of transfer, expected 1, current: %v", len(activeTransfers)) + } + transfer := activeTransfers[0] + if transfer.initialSize != 123 { + t.Errorf("unexpected initial size: %v", transfer.initialSize) + } + err = transfer.Close() + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if len(activeTransfers) != 0 { + t.Errorf("unexpected number of transfer, expected 0, current: %v", len(activeTransfers)) + } + os.Remove(f.Name()) uploadMode = oldUploadMode } @@ -899,6 +921,17 @@ func TestSystemCommandErrors(t *testing.T) { } } +func TestTransferUpdateQuota(t *testing.T) { + transfer := Transfer{ + transferType: transferUpload, + bytesReceived: 123, + lock: new(sync.Mutex)} + transfer.TransferError(errors.New("fake error")) + if transfer.updateQuota(1) { + t.Errorf("update quota must fail, there is a error and this is a remote upload") + } +} + func TestGetConnectionInfo(t *testing.T) { c := ConnectionStatus{ Username: "test_user", @@ -1222,6 +1255,10 @@ func TestSCPErrorsMockFs(t *testing.T) { if err != errFake { t.Errorf("unexpected error: %v", err) } + err = scpCommand.handleUploadFile(testfile, testfile, 0, false, 4) + if err != nil { + t.Errorf("unexpected error: %v", err) + } os.Remove(testfile) } diff --git a/sftpd/scp.go b/sftpd/scp.go index 6871861c..94552d49 100644 --- a/sftpd/scp.go +++ b/sftpd/scp.go @@ -181,7 +181,7 @@ func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *Transfer) err return c.sendConfirmationMessage() } -func (c *scpCommand) handleUploadFile(requestPath, filePath string, sizeToRead int64, isNewFile bool) error { +func (c *scpCommand) handleUploadFile(requestPath, filePath string, sizeToRead int64, isNewFile bool, fileSize int64) error { if !c.connection.hasSpace(true) { err := fmt.Errorf("denying file write due to space limit") c.connection.Log(logger.LevelWarn, logSenderSCP, "error uploading file: %#v, err: %v", filePath, err) @@ -189,6 +189,14 @@ func (c *scpCommand) handleUploadFile(requestPath, filePath string, sizeToRead i return err } + initialSize := int64(0) + if !isNewFile { + if vfs.IsLocalOsFs(c.connection.fs) { + dataprovider.UpdateUserQuota(dataProvider, c.connection.User, 0, -fileSize, false) + } else { + initialSize = fileSize + } + } file, w, cancelFn, err := c.connection.fs.Create(filePath, 0) if err != nil { c.connection.Log(logger.LevelError, logSenderSCP, "error creating file %#v: %v", requestPath, err) @@ -216,6 +224,7 @@ func (c *scpCommand) handleUploadFile(requestPath, filePath string, sizeToRead i transferError: nil, isFinished: false, minWriteOffset: 0, + initialSize: initialSize, lock: new(sync.Mutex), } addTransfer(&transfer) @@ -246,7 +255,7 @@ func (c *scpCommand) handleUpload(uploadFilePath string, sizeToRead int64) error c.sendErrorMessage(err.Error()) return err } - return c.handleUploadFile(p, filePath, sizeToRead, true) + return c.handleUploadFile(p, filePath, sizeToRead, true, 0) } if statErr != nil { @@ -279,9 +288,7 @@ func (c *scpCommand) handleUpload(uploadFilePath string, sizeToRead int64) error } } - dataprovider.UpdateUserQuota(dataProvider, c.connection.User, 0, -stat.Size(), false) - - return c.handleUploadFile(p, filePath, sizeToRead, false) + return c.handleUploadFile(p, filePath, sizeToRead, false, stat.Size()) } func (c *scpCommand) sendDownloadProtocolMessages(dirPath string, stat os.FileInfo) error { diff --git a/sftpd/sftpd_test.go b/sftpd/sftpd_test.go index 2dbd712b..d0a25d53 100644 --- a/sftpd/sftpd_test.go +++ b/sftpd/sftpd_test.go @@ -3123,7 +3123,7 @@ func TestResolvePaths(t *testing.T) { } path = "../test/sub" resolved, err = fs.ResolvePath(filepath.ToSlash(path)) - if fs.Name() == "osfs" { + if vfs.IsLocalOsFs(fs) { if err == nil { t.Errorf("Unexpected resolved path: %v for: %v, fs: %v", resolved, path, fs.Name()) } @@ -3134,7 +3134,7 @@ func TestResolvePaths(t *testing.T) { } path = "../../../test/../sub" resolved, err = fs.ResolvePath(filepath.ToSlash(path)) - if fs.Name() == "osfs" { + if vfs.IsLocalOsFs(fs) { if err == nil { t.Errorf("Unexpected resolved path: %v for: %v, fs: %v", resolved, path, fs.Name()) } @@ -4624,7 +4624,7 @@ func getKeyboardInteractiveScriptContent(questions []string, sleepTime int, nonJ content := []byte("#!/bin/sh\n\n") q, _ := json.Marshal(questions) echos := []bool{} - for index, _ := range questions { + for index := range questions { echos = append(echos, index%2 == 0) } e, _ := json.Marshal(echos) @@ -4633,7 +4633,7 @@ func getKeyboardInteractiveScriptContent(questions []string, sleepTime int, nonJ } else { content = append(content, []byte(fmt.Sprintf("echo '{\"questions\":%v,\"echos\":%v}'\n", string(q), string(e)))...) } - for index, _ := range questions { + for index := range questions { content = append(content, []byte(fmt.Sprintf("read ANSWER%v\n", index))...) } if sleepTime > 0 { diff --git a/sftpd/transfer.go b/sftpd/transfer.go index 7ec2fa55..5e691323 100644 --- a/sftpd/transfer.go +++ b/sftpd/transfer.go @@ -44,6 +44,7 @@ type Transfer struct { isFinished bool minWriteOffset int64 expectedSize int64 + initialSize int64 lock *sync.Mutex } @@ -163,9 +164,7 @@ func (t *Transfer) Close() error { } metrics.TransferCompleted(t.bytesSent, t.bytesReceived, t.transferType, t.transferError) removeTransfer(t) - if t.transferType == transferUpload && (numFiles != 0 || t.bytesReceived > 0) { - dataprovider.UpdateUserQuota(dataProvider, t.user, numFiles, t.bytesReceived, false) - } + t.updateQuota(numFiles) return err } @@ -181,6 +180,18 @@ func (t *Transfer) closeIO() error { return err } +func (t *Transfer) updateQuota(numFiles int) bool { + // S3 uploads are atomic, if there is an error nothing is uploaded + if t.file == nil && t.transferError != nil { + return false + } + if t.transferType == transferUpload && (numFiles != 0 || t.bytesReceived > 0) { + dataprovider.UpdateUserQuota(dataProvider, t.user, numFiles, t.bytesReceived-t.initialSize, false) + return true + } + return false +} + func (t *Transfer) checkDownloadSize() { if t.transferType == transferDownload && t.transferError == nil && t.bytesSent < t.expectedSize { t.transferError = fmt.Errorf("incomplete download: %v/%v bytes transferred", t.bytesSent, t.expectedSize)