sftpd: add support for upload resume

We support resume only if the client sets the correct offset when resuming
the upload.
According to the specs the offset is optional for resume, but all the tested
clients set the right offset.
If an invalid offset is given we interrupt the transfer with the error
"Invalid write offset ..."

See https://github.com/pkg/sftp/issues/295
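
A minimal sketch of the offset check, assuming a hypothetical transfer type
(the actual sftpd field and method names may differ):

package sketch

import (
	"fmt"
	"os"
)

// resumedUpload is an illustrative stand-in for the server side transfer:
// minWriteOffset holds the size of the partial file found on disk when the
// upload is resumed.
type resumedUpload struct {
	file           *os.File
	minWriteOffset int64
}

// WriteAt rejects writes that start before the resume point, mirroring the
// "Invalid write offset ..." error mentioned above.
func (u *resumedUpload) WriteAt(p []byte, off int64) (int, error) {
	if off < u.minWriteOffset {
		return 0, fmt.Errorf("invalid write offset %v, expected at least %v", off, u.minWriteOffset)
	}
	return u.file.WriteAt(p, off)
}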

This commit adds a new upload mode: "atomic with resume support". It acts
like atomic mode, but if an upload error occurs the temporary file is renamed
to the requested path instead of being deleted, so a client can reconnect
and resume the upload.
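
Roughly, the difference between the two atomic modes on an upload error looks
like this (a sketch with illustrative names, not the actual implementation):

package sketch

import "os"

// handleAtomicUploadError shows, schematically, what happens to the
// temporary file when an upload fails.
func handleAtomicUploadError(tmpPath, requestedPath string, resumeSupported bool) error {
	if resumeSupported {
		// atomic with resume support: keep the partial data under the
		// requested name so the client can reconnect and resume
		return os.Rename(tmpPath, requestedPath)
	}
	// plain atomic mode: discard the partial temporary file
	return os.Remove(tmpPath)
}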
Nicola Murino
2019-10-09 17:33:30 +02:00
parent 4f36c1de06
commit 1d917561fe
9 changed files with 331 additions and 137 deletions

@@ -1,7 +1,9 @@
package sftpd_test

import (
	"bytes"
	"crypto/rand"
	"crypto/sha256"
	"fmt"
	"io"
	"io/ioutil"
@@ -112,10 +114,10 @@ func TestMain(m *testing.M) {
	sftpdConf.LoginBannerFile = loginBannerFileName
	// we need to test SCP support
	sftpdConf.IsSCPEnabled = true
-	// we run the test cases with UploadMode atomic. The non atomic code path
+	// we run the test cases with UploadMode atomic and resume support. The non atomic code path
	// simply does not execute some code so if it works in atomic mode will
	// work in non atomic mode too
-	sftpdConf.UploadMode = 1
+	sftpdConf.UploadMode = 2
	if runtime.GOOS == "windows" {
		homeBasePath = "C:\\"
	} else {
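
For context, this is the UploadMode mapping assumed by the test configuration
above (the constant names are illustrative; SFTPGo configures this as a plain
integer):

// assumed mapping, for readability only
const (
	uploadModeStandard         = 0 // files are written directly to the requested path
	uploadModeAtomic           = 1 // files go to a temporary path and are renamed on close
	uploadModeAtomicWithResume = 2 // like atomic, but a failed upload keeps its partial file for resume
)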
@@ -187,6 +189,7 @@ func TestBasicSFTPHandling(t *testing.T) {
	if err != nil {
		t.Errorf("unable to add user: %v", err)
	}
	os.RemoveAll(user.GetHomeDir())
	client, err := getSftpClient(user, usePubKey)
	if err != nil {
		t.Errorf("unable to create sftp client: %v", err)
@@ -246,6 +249,67 @@ func TestBasicSFTPHandling(t *testing.T) {
	os.RemoveAll(user.GetHomeDir())
}

func TestUploadResume(t *testing.T) {
	usePubKey := false
	u := getTestUser(usePubKey)
	user, _, err := httpd.AddUser(u, http.StatusOK)
	if err != nil {
		t.Errorf("unable to add user: %v", err)
	}
	os.RemoveAll(user.GetHomeDir())
	client, err := getSftpClient(user, usePubKey)
	if err != nil {
		t.Errorf("unable to create sftp client: %v", err)
	} else {
		defer client.Close()
		testFileName := "test_file.dat"
		testFilePath := filepath.Join(homeBasePath, testFileName)
		testFileSize := int64(65535)
		appendDataSize := int64(65535)
		err = createTestFile(testFilePath, testFileSize)
		if err != nil {
			t.Errorf("unable to create test file: %v", err)
		}
		err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
		if err != nil {
			t.Errorf("file upload error: %v", err)
		}
		err = appendToTestFile(testFilePath, appendDataSize)
		if err != nil {
			t.Errorf("unable to append to test file: %v", err)
		}
		err = sftpUploadResumeFile(testFilePath, testFileName, testFileSize+appendDataSize, false, client)
		if err != nil {
			t.Errorf("file upload resume error: %v", err)
		}
		localDownloadPath := filepath.Join(homeBasePath, "test_download.dat")
		err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize+appendDataSize, client)
		if err != nil {
			t.Errorf("file download error: %v", err)
		}
		// compare the local source file (original data plus the appended part)
		// against the downloaded file
		initialHash, err := computeFileHash(testFilePath)
		if err != nil {
			t.Errorf("error computing file hash: %v", err)
		}
		downloadedFileHash, err := computeFileHash(localDownloadPath)
		if err != nil {
			t.Errorf("error computing downloaded file hash: %v", err)
		}
		if downloadedFileHash != initialHash {
			t.Errorf("resume failed: file hash does not match")
		}
		err = sftpUploadResumeFile(testFilePath, testFileName, testFileSize+appendDataSize, true, client)
		if err == nil {
			t.Errorf("file upload resume with invalid offset must fail")
		}
	}
	_, err = httpd.RemoveUser(user, http.StatusOK)
	if err != nil {
		t.Errorf("unable to remove user: %v", err)
	}
	os.RemoveAll(user.GetHomeDir())
}

func TestDirCommands(t *testing.T) {
	usePubKey := false
	user, _, err := httpd.AddUser(getTestUser(usePubKey), http.StatusOK)
@@ -2301,6 +2365,26 @@ func createTestFile(path string, size int64) error {
	return ioutil.WriteFile(path, content, 0666)
}

func appendToTestFile(path string, size int64) error {
	content := make([]byte, size)
	_, err := rand.Read(content)
	if err != nil {
		return err
	}
	f, err := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, 0666)
	if err != nil {
		return err
	}
	defer f.Close()
	written, err := io.Copy(f, bytes.NewReader(content))
	if err != nil {
		return err
	}
	if written != size {
		return fmt.Errorf("write error, written: %v/%v", written, size)
	}
	return nil
}

func sftpUploadFile(localSourcePath string, remoteDestPath string, expectedSize int64, client *sftp.Client) error {
	srcFile, err := os.Open(localSourcePath)
	if err != nil {
@@ -2331,6 +2415,53 @@ func sftpUploadFile(localSourcePath string, remoteDestPath string, expectedSize
	return err
}

func sftpUploadResumeFile(localSourcePath string, remoteDestPath string, expectedSize int64, invalidOffset bool,
	client *sftp.Client) error {
	srcFile, err := os.Open(localSourcePath)
	if err != nil {
		return err
	}
	defer srcFile.Close()
	fi, err := client.Lstat(remoteDestPath)
	if err != nil {
		return err
	}
	if !invalidOffset {
		_, err = srcFile.Seek(fi.Size(), 0)
		if err != nil {
			return err
		}
	}
	destFile, err := client.OpenFile(remoteDestPath, os.O_WRONLY|os.O_APPEND)
	if err != nil {
		return err
	}
	if !invalidOffset {
		_, err = destFile.Seek(fi.Size(), 0)
		if err != nil {
			return err
		}
	}
	_, err = io.Copy(destFile, srcFile)
	if err != nil {
		destFile.Close()
		return err
	}
	// we need to close the file to trigger the close method on server
	// we cannot defer closing or Lstat will fail for upload atomic mode
	destFile.Close()
	if expectedSize > 0 {
		fi, err := client.Lstat(remoteDestPath)
		if err != nil {
			return err
		}
		if fi.Size() != expectedSize {
			return fmt.Errorf("uploaded file size does not match, actual: %v, expected: %v", fi.Size(), expectedSize)
		}
	}
	return err
}

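The close-before-Lstat comment above relates to how atomic mode works on the
server: the uploaded data lives at a temporary path and only appears at the
requested path once the handle is closed. A rough sketch with illustrative
names (not the actual sftpd code):

package sketch

import "os"

// atomicUpload illustrates why Lstat on the requested path only succeeds
// after the SFTP handle is closed when the server uses an atomic upload mode.
type atomicUpload struct {
	file          *os.File
	tmpPath       string
	requestedPath string
}

func (u *atomicUpload) Close() error {
	if err := u.file.Close(); err != nil {
		return err
	}
	// the file becomes visible at the requested path only at this point
	return os.Rename(u.tmpPath, u.requestedPath)
}
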
func sftpDownloadFile(remoteSourcePath string, localDestPath string, expectedSize int64, client *sftp.Client) error {
	downloadDest, err := os.Create(localDestPath)
	if err != nil {
@@ -2432,6 +2563,21 @@ func getScpUploadCommand(localPath, remotePath string, preserveTime, remoteToRem
	return exec.Command(scpPath, args...)
}

func computeFileHash(path string) (string, error) {
	hash := ""
	f, err := os.Open(path)
	if err != nil {
		return hash, err
	}
	defer f.Close()
	h := sha256.New()
	if _, err := io.Copy(h, f); err != nil {
		return hash, err
	}
	hash = fmt.Sprintf("%x", h.Sum(nil))
	return hash, err
}

func waitForNoActiveTransfer() {
	for len(sftpd.GetConnectionsStats()) > 0 {
		time.Sleep(100 * time.Millisecond)