diff --git a/sftpd/internal_test.go b/sftpd/internal_test.go index 53efaa2d..7b4bf911 100644 --- a/sftpd/internal_test.go +++ b/sftpd/internal_test.go @@ -436,7 +436,7 @@ func TestSCPCommandHandleErrors(t *testing.T) { } } -func TestRecursiveDownloadErrors(t *testing.T) { +func TestSCPRecursiveDownloadErrors(t *testing.T) { connection := Connection{} buf := make([]byte, 65535) stdErrBuf := make([]byte, 65535) @@ -475,7 +475,7 @@ func TestRecursiveDownloadErrors(t *testing.T) { os.Remove(path) } -func TestRecursiveUploadErrors(t *testing.T) { +func TestSCPRecursiveUploadErrors(t *testing.T) { connection := Connection{} buf := make([]byte, 65535) stdErrBuf := make([]byte, 65535) diff --git a/sftpd/scp.go b/sftpd/scp.go index 4b6c7071..94073d80 100644 --- a/sftpd/scp.go +++ b/sftpd/scp.go @@ -101,32 +101,16 @@ func (c *scpCommand) handleRecursiveUpload() error { if err != nil { return err } - objPath := path.Join(destPath, name) if strings.HasPrefix(command, "D") { numDirs++ - err = c.handleCreateDir(objPath) + destPath = path.Join(destPath, name) + err = c.handleCreateDir(destPath) if err != nil { return err } - destPath = objPath logger.Debug(logSenderSCP, "received start dir command, num dirs: %v destPath: %v", numDirs, destPath) } else if strings.HasPrefix(command, "C") { - // if the upload is not recursive and the destination path does not end with "/" - // then this is the wanted filename ... - if !c.isRecursive() { - if !strings.HasSuffix(destPath, "/") { - objPath = destPath - // ... but if the requested path is an existing directory then put the uploaded file inside that directory - if p, err := c.connection.buildPath(objPath); err == nil { - if stat, err := os.Stat(p); err == nil { - if stat.IsDir() { - objPath = path.Join(destPath, name) - } - } - } - } - } - err = c.handleUpload(objPath, sizeToRead) + err = c.handleUpload(c.getFileUploadDestPath(destPath, name), sizeToRead) if err != nil { return err } @@ -690,6 +674,31 @@ func (c *scpCommand) parseUploadMessage(command string) (int64, string, error) { return size, name, err } +func (c *scpCommand) getFileUploadDestPath(scpDestPath, fileName string) string { + if !c.isRecursive() { + // if the upload is not recursive and the destination path does not end with "/" + // then scpDestPath is the wanted filename, for example: + // scp fileName.txt user@127.0.0.1:/newFileName.txt + // or + // scp fileName.txt user@127.0.0.1:/fileName.txt + if !strings.HasSuffix(scpDestPath, "/") { + // but if scpDestPath is an existing directory then we put the uploaded file + // inside that directory this is as scp command works, for example: + // scp fileName.txt user@127.0.0.1:/existing_dir + if p, err := c.connection.buildPath(scpDestPath); err == nil { + if stat, err := os.Stat(p); err == nil { + if stat.IsDir() { + return path.Join(scpDestPath, fileName) + } + } + } + return scpDestPath + } + } + // if the upload is recursive then the destination file is relative to the current scpDestPath + return path.Join(scpDestPath, fileName) +} + func getFileModeAsString(fileMode os.FileMode, isDir bool) string { var defaultMode string if isDir { diff --git a/sftpd/sftpd_test.go b/sftpd/sftpd_test.go index e9e00d36..db228f29 100644 --- a/sftpd/sftpd_test.go +++ b/sftpd/sftpd_test.go @@ -275,6 +275,10 @@ func TestDirCommands(t *testing.T) { if err != nil { t.Errorf("error mkdir all: %v", err) } + _, err = client.ReadDir("/this/dir/does/not/exist") + if err == nil { + t.Errorf("reading a missing dir must fail") + } testFileName := "/test_file.dat" testFilePath := filepath.Join(homeBasePath, testFileName) testFileSize := int64(65535) @@ -334,6 +338,10 @@ func TestSymlink(t *testing.T) { if err != nil { t.Errorf("error creating symlink: %v", err) } + _, err = client.ReadLink(testFileName + ".link") + if err == nil { + t.Errorf("readlink is currently not implemented so must fail") + } err = client.Symlink(testFileName, testFileName+".link") if err == nil { t.Errorf("creating a symlink to an existing one must fail")