mirror of
https://github.com/drakkan/sftpgo.git
synced 2025-12-07 14:50:55 +03:00
@@ -1069,6 +1069,78 @@ func TestSCPFileMode(t *testing.T) {
|
|||||||
assert.Equal(t, "1044", mode)
|
assert.Equal(t, "1044", mode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSCPUploadError(t *testing.T) {
|
||||||
|
buf := make([]byte, 65535)
|
||||||
|
stdErrBuf := make([]byte, 65535)
|
||||||
|
writeErr := fmt.Errorf("test write error")
|
||||||
|
mockSSHChannel := MockChannel{
|
||||||
|
Buffer: bytes.NewBuffer(buf),
|
||||||
|
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
|
||||||
|
ReadError: nil,
|
||||||
|
WriteError: writeErr,
|
||||||
|
}
|
||||||
|
user := dataprovider.User{
|
||||||
|
HomeDir: filepath.Join(os.TempDir()),
|
||||||
|
Permissions: make(map[string][]string),
|
||||||
|
}
|
||||||
|
user.Permissions["/"] = []string{dataprovider.PermAny}
|
||||||
|
fs := vfs.NewOsFs("", user.HomeDir, nil)
|
||||||
|
|
||||||
|
connection := &Connection{
|
||||||
|
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
|
||||||
|
channel: &mockSSHChannel,
|
||||||
|
}
|
||||||
|
scpCommand := scpCommand{
|
||||||
|
sshCommand: sshCommand{
|
||||||
|
command: "scp",
|
||||||
|
connection: connection,
|
||||||
|
args: []string{"-t", "/"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err := scpCommand.handle()
|
||||||
|
assert.EqualError(t, err, writeErr.Error())
|
||||||
|
|
||||||
|
mockSSHChannel = MockChannel{
|
||||||
|
Buffer: bytes.NewBuffer([]byte("D0755 0 testdir\n")),
|
||||||
|
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
|
||||||
|
ReadError: nil,
|
||||||
|
WriteError: writeErr,
|
||||||
|
}
|
||||||
|
err = scpCommand.handleRecursiveUpload()
|
||||||
|
assert.EqualError(t, err, writeErr.Error())
|
||||||
|
|
||||||
|
mockSSHChannel = MockChannel{
|
||||||
|
Buffer: bytes.NewBuffer([]byte("D0755 a testdir\n")),
|
||||||
|
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
|
||||||
|
ReadError: nil,
|
||||||
|
WriteError: nil,
|
||||||
|
}
|
||||||
|
err = scpCommand.handleRecursiveUpload()
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSCPInvalidEndDir(t *testing.T) {
|
||||||
|
stdErrBuf := make([]byte, 65535)
|
||||||
|
mockSSHChannel := MockChannel{
|
||||||
|
Buffer: bytes.NewBuffer([]byte("E\n")),
|
||||||
|
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
|
||||||
|
}
|
||||||
|
fs := vfs.NewOsFs("", os.TempDir(), nil)
|
||||||
|
connection := &Connection{
|
||||||
|
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, dataprovider.User{}, fs),
|
||||||
|
channel: &mockSSHChannel,
|
||||||
|
}
|
||||||
|
scpCommand := scpCommand{
|
||||||
|
sshCommand: sshCommand{
|
||||||
|
command: "scp",
|
||||||
|
connection: connection,
|
||||||
|
args: []string{"-t", "/tmp"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err := scpCommand.handleRecursiveUpload()
|
||||||
|
assert.EqualError(t, err, "unacceptable end dir command")
|
||||||
|
}
|
||||||
|
|
||||||
func TestSCPParseUploadMessage(t *testing.T) {
|
func TestSCPParseUploadMessage(t *testing.T) {
|
||||||
buf := make([]byte, 65535)
|
buf := make([]byte, 65535)
|
||||||
stdErrBuf := make([]byte, 65535)
|
stdErrBuf := make([]byte, 65535)
|
||||||
|
|||||||
34
sftpd/scp.go
34
sftpd/scp.go
@@ -1,6 +1,7 @@
|
|||||||
package sftpd
|
package sftpd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math"
|
"math"
|
||||||
@@ -45,6 +46,10 @@ func (c *scpCommand) handle() (err error) {
|
|||||||
c.args, c.connection.User.Username, commandType, destPath)
|
c.args, c.connection.User.Username, commandType, destPath)
|
||||||
if commandType == "-t" {
|
if commandType == "-t" {
|
||||||
// -t means "to", so upload
|
// -t means "to", so upload
|
||||||
|
err = c.sendConfirmationMessage()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
err = c.handleRecursiveUpload()
|
err = c.handleRecursiveUpload()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -68,31 +73,24 @@ func (c *scpCommand) handle() (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *scpCommand) handleRecursiveUpload() error {
|
func (c *scpCommand) handleRecursiveUpload() error {
|
||||||
var err error
|
|
||||||
numDirs := 0
|
numDirs := 0
|
||||||
destPath := c.getDestPath()
|
destPath := c.getDestPath()
|
||||||
for {
|
for {
|
||||||
err = c.sendConfirmationMessage()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
command, err := c.getNextUploadProtocolMessage()
|
command, err := c.getNextUploadProtocolMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(command, "E") {
|
if strings.HasPrefix(command, "E") {
|
||||||
numDirs--
|
numDirs--
|
||||||
c.connection.Log(logger.LevelDebug, "received end dir command, num dirs: %v", numDirs)
|
c.connection.Log(logger.LevelDebug, "received end dir command, num dirs: %v", numDirs)
|
||||||
if numDirs == 0 {
|
if numDirs < 0 {
|
||||||
// upload is now complete send confirmation message
|
return errors.New("unacceptable end dir command")
|
||||||
err = c.sendConfirmationMessage()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
// the destination dir is now the parent directory
|
// the destination dir is now the parent directory
|
||||||
destPath = path.Join(destPath, "..")
|
destPath = path.Join(destPath, "..")
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
sizeToRead, name, err := c.parseUploadMessage(command)
|
sizeToRead, name, err := c.parseUploadMessage(command)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -113,11 +111,11 @@ func (c *scpCommand) handleRecursiveUpload() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err != nil || numDirs == 0 {
|
err = c.sendConfirmationMessage()
|
||||||
break
|
if err != nil {
|
||||||
}
|
|
||||||
}
|
|
||||||
return err
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *scpCommand) handleCreateDir(dirPath string) error {
|
func (c *scpCommand) handleCreateDir(dirPath string) error {
|
||||||
@@ -189,7 +187,7 @@ func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *transfer) err
|
|||||||
c.sendErrorMessage(err)
|
c.sendErrorMessage(err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return c.sendConfirmationMessage()
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *scpCommand) handleUploadFile(resolvedPath, filePath string, sizeToRead int64, isNewFile bool, fileSize int64, requestPath string) error {
|
func (c *scpCommand) handleUploadFile(resolvedPath, filePath string, sizeToRead int64, isNewFile bool, fileSize int64, requestPath string) error {
|
||||||
@@ -572,7 +570,7 @@ func (c *scpCommand) readProtocolMessage() (string, error) {
|
|||||||
command.Write(readed)
|
command.Write(readed)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil && !errors.Is(err, io.EOF) {
|
||||||
c.connection.channel.Close()
|
c.connection.channel.Close()
|
||||||
}
|
}
|
||||||
return command.String(), err
|
return command.String(), err
|
||||||
|
|||||||
@@ -712,7 +712,8 @@ func (c *sshCommand) sendExitStatus(err error) {
|
|||||||
exitStatus := sshSubsystemExitStatus{
|
exitStatus := sshSubsystemExitStatus{
|
||||||
Status: status,
|
Status: status,
|
||||||
}
|
}
|
||||||
c.connection.channel.(ssh.Channel).SendRequest("exit-status", false, ssh.Marshal(&exitStatus)) //nolint:errcheck
|
_, err = c.connection.channel.(ssh.Channel).SendRequest("exit-status", false, ssh.Marshal(&exitStatus))
|
||||||
|
c.connection.Log(logger.LevelDebug, "exit status sent, error: %v", err)
|
||||||
c.connection.channel.Close()
|
c.connection.channel.Close()
|
||||||
// for scp we notify single uploads/downloads
|
// for scp we notify single uploads/downloads
|
||||||
if c.command != scpCmdName {
|
if c.command != scpCmdName {
|
||||||
|
|||||||
Reference in New Issue
Block a user