From c0e09374a8ce7d1b09628a277cd0f2a3cbab8013 Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Wed, 20 Jan 2021 22:37:59 +0100 Subject: [PATCH] scp: fix wildcard uploads Fixes #285 --- sftpd/internal_test.go | 72 ++++++++++++++++++++++++++++++++++++++++++ sftpd/scp.go | 36 ++++++++++----------- sftpd/ssh_cmd.go | 3 +- 3 files changed, 91 insertions(+), 20 deletions(-) diff --git a/sftpd/internal_test.go b/sftpd/internal_test.go index 676c27a6..9093847d 100644 --- a/sftpd/internal_test.go +++ b/sftpd/internal_test.go @@ -1069,6 +1069,78 @@ func TestSCPFileMode(t *testing.T) { assert.Equal(t, "1044", mode) } +func TestSCPUploadError(t *testing.T) { + buf := make([]byte, 65535) + stdErrBuf := make([]byte, 65535) + writeErr := fmt.Errorf("test write error") + mockSSHChannel := MockChannel{ + Buffer: bytes.NewBuffer(buf), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + ReadError: nil, + WriteError: writeErr, + } + user := dataprovider.User{ + HomeDir: filepath.Join(os.TempDir()), + Permissions: make(map[string][]string), + } + user.Permissions["/"] = []string{dataprovider.PermAny} + fs := vfs.NewOsFs("", user.HomeDir, nil) + + connection := &Connection{ + BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs), + channel: &mockSSHChannel, + } + scpCommand := scpCommand{ + sshCommand: sshCommand{ + command: "scp", + connection: connection, + args: []string{"-t", "/"}, + }, + } + err := scpCommand.handle() + assert.EqualError(t, err, writeErr.Error()) + + mockSSHChannel = MockChannel{ + Buffer: bytes.NewBuffer([]byte("D0755 0 testdir\n")), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + ReadError: nil, + WriteError: writeErr, + } + err = scpCommand.handleRecursiveUpload() + assert.EqualError(t, err, writeErr.Error()) + + mockSSHChannel = MockChannel{ + Buffer: bytes.NewBuffer([]byte("D0755 a testdir\n")), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + ReadError: nil, + WriteError: nil, + } + err = scpCommand.handleRecursiveUpload() + assert.Error(t, err) +} + +func TestSCPInvalidEndDir(t *testing.T) { + stdErrBuf := make([]byte, 65535) + mockSSHChannel := MockChannel{ + Buffer: bytes.NewBuffer([]byte("E\n")), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + } + fs := vfs.NewOsFs("", os.TempDir(), nil) + connection := &Connection{ + BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, dataprovider.User{}, fs), + channel: &mockSSHChannel, + } + scpCommand := scpCommand{ + sshCommand: sshCommand{ + command: "scp", + connection: connection, + args: []string{"-t", "/tmp"}, + }, + } + err := scpCommand.handleRecursiveUpload() + assert.EqualError(t, err, "unacceptable end dir command") +} + func TestSCPParseUploadMessage(t *testing.T) { buf := make([]byte, 65535) stdErrBuf := make([]byte, 65535) diff --git a/sftpd/scp.go b/sftpd/scp.go index b7980fcd..7346efa9 100644 --- a/sftpd/scp.go +++ b/sftpd/scp.go @@ -1,6 +1,7 @@ package sftpd import ( + "errors" "fmt" "io" "math" @@ -45,6 +46,10 @@ func (c *scpCommand) handle() (err error) { c.args, c.connection.User.Username, commandType, destPath) if commandType == "-t" { // -t means "to", so upload + err = c.sendConfirmationMessage() + if err != nil { + return err + } err = c.handleRecursiveUpload() if err != nil { return err @@ -68,31 +73,24 @@ func (c *scpCommand) handle() (err error) { } func (c *scpCommand) handleRecursiveUpload() error { - var err error numDirs := 0 destPath := c.getDestPath() for { - err = c.sendConfirmationMessage() - if err != nil { - return err - } command, err := c.getNextUploadProtocolMessage() if err != nil { + if errors.Is(err, io.EOF) { + return nil + } return err } if strings.HasPrefix(command, "E") { numDirs-- c.connection.Log(logger.LevelDebug, "received end dir command, num dirs: %v", numDirs) - if numDirs == 0 { - // upload is now complete send confirmation message - err = c.sendConfirmationMessage() - if err != nil { - return err - } - } else { - // the destination dir is now the parent directory - destPath = path.Join(destPath, "..") + if numDirs < 0 { + return errors.New("unacceptable end dir command") } + // the destination dir is now the parent directory + destPath = path.Join(destPath, "..") } else { sizeToRead, name, err := c.parseUploadMessage(command) if err != nil { @@ -113,11 +111,11 @@ func (c *scpCommand) handleRecursiveUpload() error { } } } - if err != nil || numDirs == 0 { - break + err = c.sendConfirmationMessage() + if err != nil { + return err } } - return err } func (c *scpCommand) handleCreateDir(dirPath string) error { @@ -189,7 +187,7 @@ func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *transfer) err c.sendErrorMessage(err) return err } - return c.sendConfirmationMessage() + return nil } func (c *scpCommand) handleUploadFile(resolvedPath, filePath string, sizeToRead int64, isNewFile bool, fileSize int64, requestPath string) error { @@ -572,7 +570,7 @@ func (c *scpCommand) readProtocolMessage() (string, error) { command.Write(readed) } } - if err != nil { + if err != nil && !errors.Is(err, io.EOF) { c.connection.channel.Close() } return command.String(), err diff --git a/sftpd/ssh_cmd.go b/sftpd/ssh_cmd.go index 892dfa8e..56fef9f2 100644 --- a/sftpd/ssh_cmd.go +++ b/sftpd/ssh_cmd.go @@ -712,7 +712,8 @@ func (c *sshCommand) sendExitStatus(err error) { exitStatus := sshSubsystemExitStatus{ Status: status, } - c.connection.channel.(ssh.Channel).SendRequest("exit-status", false, ssh.Marshal(&exitStatus)) //nolint:errcheck + _, err = c.connection.channel.(ssh.Channel).SendRequest("exit-status", false, ssh.Marshal(&exitStatus)) + c.connection.Log(logger.LevelDebug, "exit status sent, error: %v", err) c.connection.channel.Close() // for scp we notify single uploads/downloads if c.command != scpCmdName {