scp: fix wildcard uploads

Fixes #285
This commit is contained in:
Nicola Murino
2021-01-20 22:37:59 +01:00
parent 57976b4085
commit c0e09374a8
3 changed files with 91 additions and 20 deletions

View File

@@ -1069,6 +1069,78 @@ func TestSCPFileMode(t *testing.T) {
assert.Equal(t, "1044", mode) assert.Equal(t, "1044", mode)
} }
func TestSCPUploadError(t *testing.T) {
buf := make([]byte, 65535)
stdErrBuf := make([]byte, 65535)
writeErr := fmt.Errorf("test write error")
mockSSHChannel := MockChannel{
Buffer: bytes.NewBuffer(buf),
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
ReadError: nil,
WriteError: writeErr,
}
user := dataprovider.User{
HomeDir: filepath.Join(os.TempDir()),
Permissions: make(map[string][]string),
}
user.Permissions["/"] = []string{dataprovider.PermAny}
fs := vfs.NewOsFs("", user.HomeDir, nil)
connection := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
channel: &mockSSHChannel,
}
scpCommand := scpCommand{
sshCommand: sshCommand{
command: "scp",
connection: connection,
args: []string{"-t", "/"},
},
}
err := scpCommand.handle()
assert.EqualError(t, err, writeErr.Error())
mockSSHChannel = MockChannel{
Buffer: bytes.NewBuffer([]byte("D0755 0 testdir\n")),
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
ReadError: nil,
WriteError: writeErr,
}
err = scpCommand.handleRecursiveUpload()
assert.EqualError(t, err, writeErr.Error())
mockSSHChannel = MockChannel{
Buffer: bytes.NewBuffer([]byte("D0755 a testdir\n")),
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
ReadError: nil,
WriteError: nil,
}
err = scpCommand.handleRecursiveUpload()
assert.Error(t, err)
}
func TestSCPInvalidEndDir(t *testing.T) {
stdErrBuf := make([]byte, 65535)
mockSSHChannel := MockChannel{
Buffer: bytes.NewBuffer([]byte("E\n")),
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
}
fs := vfs.NewOsFs("", os.TempDir(), nil)
connection := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, dataprovider.User{}, fs),
channel: &mockSSHChannel,
}
scpCommand := scpCommand{
sshCommand: sshCommand{
command: "scp",
connection: connection,
args: []string{"-t", "/tmp"},
},
}
err := scpCommand.handleRecursiveUpload()
assert.EqualError(t, err, "unacceptable end dir command")
}
func TestSCPParseUploadMessage(t *testing.T) { func TestSCPParseUploadMessage(t *testing.T) {
buf := make([]byte, 65535) buf := make([]byte, 65535)
stdErrBuf := make([]byte, 65535) stdErrBuf := make([]byte, 65535)

View File

@@ -1,6 +1,7 @@
package sftpd package sftpd
import ( import (
"errors"
"fmt" "fmt"
"io" "io"
"math" "math"
@@ -45,6 +46,10 @@ func (c *scpCommand) handle() (err error) {
c.args, c.connection.User.Username, commandType, destPath) c.args, c.connection.User.Username, commandType, destPath)
if commandType == "-t" { if commandType == "-t" {
// -t means "to", so upload // -t means "to", so upload
err = c.sendConfirmationMessage()
if err != nil {
return err
}
err = c.handleRecursiveUpload() err = c.handleRecursiveUpload()
if err != nil { if err != nil {
return err return err
@@ -68,31 +73,24 @@ func (c *scpCommand) handle() (err error) {
} }
func (c *scpCommand) handleRecursiveUpload() error { func (c *scpCommand) handleRecursiveUpload() error {
var err error
numDirs := 0 numDirs := 0
destPath := c.getDestPath() destPath := c.getDestPath()
for { for {
err = c.sendConfirmationMessage()
if err != nil {
return err
}
command, err := c.getNextUploadProtocolMessage() command, err := c.getNextUploadProtocolMessage()
if err != nil { if err != nil {
if errors.Is(err, io.EOF) {
return nil
}
return err return err
} }
if strings.HasPrefix(command, "E") { if strings.HasPrefix(command, "E") {
numDirs-- numDirs--
c.connection.Log(logger.LevelDebug, "received end dir command, num dirs: %v", numDirs) c.connection.Log(logger.LevelDebug, "received end dir command, num dirs: %v", numDirs)
if numDirs == 0 { if numDirs < 0 {
// upload is now complete send confirmation message return errors.New("unacceptable end dir command")
err = c.sendConfirmationMessage()
if err != nil {
return err
} }
} else {
// the destination dir is now the parent directory // the destination dir is now the parent directory
destPath = path.Join(destPath, "..") destPath = path.Join(destPath, "..")
}
} else { } else {
sizeToRead, name, err := c.parseUploadMessage(command) sizeToRead, name, err := c.parseUploadMessage(command)
if err != nil { if err != nil {
@@ -113,12 +111,12 @@ func (c *scpCommand) handleRecursiveUpload() error {
} }
} }
} }
if err != nil || numDirs == 0 { err = c.sendConfirmationMessage()
break if err != nil {
}
}
return err return err
} }
}
}
func (c *scpCommand) handleCreateDir(dirPath string) error { func (c *scpCommand) handleCreateDir(dirPath string) error {
c.connection.UpdateLastActivity() c.connection.UpdateLastActivity()
@@ -189,7 +187,7 @@ func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *transfer) err
c.sendErrorMessage(err) c.sendErrorMessage(err)
return err return err
} }
return c.sendConfirmationMessage() return nil
} }
func (c *scpCommand) handleUploadFile(resolvedPath, filePath string, sizeToRead int64, isNewFile bool, fileSize int64, requestPath string) error { func (c *scpCommand) handleUploadFile(resolvedPath, filePath string, sizeToRead int64, isNewFile bool, fileSize int64, requestPath string) error {
@@ -572,7 +570,7 @@ func (c *scpCommand) readProtocolMessage() (string, error) {
command.Write(readed) command.Write(readed)
} }
} }
if err != nil { if err != nil && !errors.Is(err, io.EOF) {
c.connection.channel.Close() c.connection.channel.Close()
} }
return command.String(), err return command.String(), err

View File

@@ -712,7 +712,8 @@ func (c *sshCommand) sendExitStatus(err error) {
exitStatus := sshSubsystemExitStatus{ exitStatus := sshSubsystemExitStatus{
Status: status, Status: status,
} }
c.connection.channel.(ssh.Channel).SendRequest("exit-status", false, ssh.Marshal(&exitStatus)) //nolint:errcheck _, err = c.connection.channel.(ssh.Channel).SendRequest("exit-status", false, ssh.Marshal(&exitStatus))
c.connection.Log(logger.LevelDebug, "exit status sent, error: %v", err)
c.connection.channel.Close() c.connection.channel.Close()
// for scp we notify single uploads/downloads // for scp we notify single uploads/downloads
if c.command != scpCmdName { if c.command != scpCmdName {