diff --git a/sftpd/handler.go b/sftpd/handler.go index 2dd05a9b..04cdf5e3 100644 --- a/sftpd/handler.go +++ b/sftpd/handler.go @@ -13,6 +13,7 @@ import ( "github.com/drakkan/sftpgo/utils" "github.com/rs/xid" + "golang.org/x/crypto/ssh" "github.com/drakkan/sftpgo/dataprovider" "github.com/drakkan/sftpgo/logger" @@ -37,6 +38,7 @@ type Connection struct { protocol string lock *sync.Mutex netConn net.Conn + channel ssh.Channel } // Log outputs a log entry to the configured logger @@ -580,6 +582,10 @@ func (c Connection) createMissingDirs(filePath string) error { } func (c Connection) close() error { + if c.channel != nil { + err := c.channel.Close() + c.Log(logger.LevelInfo, logSender, "channel close, err: %v", err) + } return c.netConn.Close() } diff --git a/sftpd/internal_test.go b/sftpd/internal_test.go index 1ebcf14a..46374c6a 100644 --- a/sftpd/internal_test.go +++ b/sftpd/internal_test.go @@ -252,7 +252,6 @@ func TestSCPGetNonExistingDirContent(t *testing.T) { } func TestSCPParseUploadMessage(t *testing.T) { - connection := Connection{} buf := make([]byte, 65535) stdErrBuf := make([]byte, 65535) mockSSHChannel := MockChannel{ @@ -260,10 +259,12 @@ func TestSCPParseUploadMessage(t *testing.T) { StdErrBuffer: bytes.NewBuffer(stdErrBuf), ReadError: nil, } + connection := Connection{ + channel: &mockSSHChannel, + } scpCommand := scpCommand{ connection: connection, args: []string{"-t", "/tmp"}, - channel: &mockSSHChannel, } _, _, err := scpCommand.parseUploadMessage("invalid") if err == nil { @@ -284,7 +285,6 @@ func TestSCPParseUploadMessage(t *testing.T) { } func TestSCPProtocolMessages(t *testing.T) { - connection := Connection{} buf := make([]byte, 65535) stdErrBuf := make([]byte, 65535) readErr := fmt.Errorf("test read error") @@ -295,10 +295,12 @@ func TestSCPProtocolMessages(t *testing.T) { ReadError: readErr, WriteError: writeErr, } + connection := Connection{ + channel: &mockSSHChannel, + } scpCommand := scpCommand{ connection: connection, args: []string{"-t", "/tmp"}, - channel: &mockSSHChannel, } _, err := scpCommand.readProtocolMessage() if err == nil || err != readErr { @@ -322,7 +324,7 @@ func TestSCPProtocolMessages(t *testing.T) { ReadError: nil, WriteError: writeErr, } - scpCommand.channel = &mockSSHChannel + scpCommand.connection.channel = &mockSSHChannel _, err = scpCommand.getNextUploadProtocolMessage() if err == nil || err != writeErr { t.Errorf("read next upload protocol message must fail, we are sending a fake write error") @@ -337,7 +339,7 @@ func TestSCPProtocolMessages(t *testing.T) { ReadError: nil, WriteError: nil, } - scpCommand.channel = &mockSSHChannel + scpCommand.connection.channel = &mockSSHChannel err = scpCommand.readConfirmationMessage() if err == nil || err.Error() != protocolErrorMsg { t.Errorf("read confirmation message must return the expected protocol error, actual err: %v", err) @@ -345,7 +347,6 @@ func TestSCPProtocolMessages(t *testing.T) { } func TestSCPTestDownloadProtocolMessages(t *testing.T) { - connection := Connection{} buf := make([]byte, 65535) stdErrBuf := make([]byte, 65535) readErr := fmt.Errorf("test read error") @@ -356,10 +357,12 @@ func TestSCPTestDownloadProtocolMessages(t *testing.T) { ReadError: readErr, WriteError: writeErr, } + connection := Connection{ + channel: &mockSSHChannel, + } scpCommand := scpCommand{ connection: connection, args: []string{"-f", "-p", "/tmp"}, - channel: &mockSSHChannel, } path := "testDir" os.Mkdir(path, 0777) @@ -388,7 +391,7 @@ func TestSCPTestDownloadProtocolMessages(t *testing.T) { WriteError: writeErr, } scpCommand.args = []string{"-f", "/tmp"} - scpCommand.channel = &mockSSHChannel + scpCommand.connection.channel = &mockSSHChannel err = scpCommand.sendDownloadProtocolMessages(path, stat) if err != writeErr { t.Errorf("sendDownloadProtocolMessages must return the expected error: %v", err) @@ -400,7 +403,7 @@ func TestSCPTestDownloadProtocolMessages(t *testing.T) { ReadError: readErr, WriteError: nil, } - scpCommand.channel = &mockSSHChannel + scpCommand.connection.channel = &mockSSHChannel err = scpCommand.sendDownloadProtocolMessages(path, stat) if err != readErr { t.Errorf("sendDownloadProtocolMessages must return the expected error: %v", err) @@ -409,7 +412,6 @@ func TestSCPTestDownloadProtocolMessages(t *testing.T) { } func TestSCPCommandHandleErrors(t *testing.T) { - connection := Connection{} buf := make([]byte, 65535) stdErrBuf := make([]byte, 65535) readErr := fmt.Errorf("test read error") @@ -420,10 +422,12 @@ func TestSCPCommandHandleErrors(t *testing.T) { ReadError: readErr, WriteError: writeErr, } + connection := Connection{ + channel: &mockSSHChannel, + } scpCommand := scpCommand{ connection: connection, args: []string{"-f", "/tmp"}, - channel: &mockSSHChannel, } err := scpCommand.handle() if err == nil || err != readErr { @@ -437,7 +441,6 @@ func TestSCPCommandHandleErrors(t *testing.T) { } func TestSCPRecursiveDownloadErrors(t *testing.T) { - connection := Connection{} buf := make([]byte, 65535) stdErrBuf := make([]byte, 65535) readErr := fmt.Errorf("test read error") @@ -448,10 +451,12 @@ func TestSCPRecursiveDownloadErrors(t *testing.T) { ReadError: readErr, WriteError: writeErr, } + connection := Connection{ + channel: &mockSSHChannel, + } scpCommand := scpCommand{ connection: connection, args: []string{"-r", "-f", "/tmp"}, - channel: &mockSSHChannel, } path := "testDir" os.Mkdir(path, 0777) @@ -466,7 +471,7 @@ func TestSCPRecursiveDownloadErrors(t *testing.T) { ReadError: nil, WriteError: nil, } - scpCommand.channel = &mockSSHChannel + scpCommand.connection.channel = &mockSSHChannel err = scpCommand.handleRecursiveDownload("invalid_dir", stat) if err == nil { t.Errorf("recursive upload download must fail for a non existing dir") @@ -476,7 +481,6 @@ func TestSCPRecursiveDownloadErrors(t *testing.T) { } func TestSCPRecursiveUploadErrors(t *testing.T) { - connection := Connection{} buf := make([]byte, 65535) stdErrBuf := make([]byte, 65535) readErr := fmt.Errorf("test read error") @@ -487,10 +491,12 @@ func TestSCPRecursiveUploadErrors(t *testing.T) { ReadError: readErr, WriteError: writeErr, } + connection := Connection{ + channel: &mockSSHChannel, + } scpCommand := scpCommand{ connection: connection, args: []string{"-r", "-t", "/tmp"}, - channel: &mockSSHChannel, } err := scpCommand.handleRecursiveUpload() if err == nil { @@ -502,7 +508,7 @@ func TestSCPRecursiveUploadErrors(t *testing.T) { ReadError: readErr, WriteError: nil, } - scpCommand.channel = &mockSSHChannel + scpCommand.connection.channel = &mockSSHChannel err = scpCommand.handleRecursiveUpload() if err == nil { t.Errorf("recursive upload must fail, we send a fake error message") @@ -516,19 +522,19 @@ func TestSCPCreateDirs(t *testing.T) { u.HomeDir = "home_rel_path" u.Username = "test" u.Permissions = []string{"*"} - connection := Connection{ - User: u, - } mockSSHChannel := MockChannel{ Buffer: bytes.NewBuffer(buf), StdErrBuffer: bytes.NewBuffer(stdErrBuf), ReadError: nil, WriteError: nil, } + connection := Connection{ + User: u, + channel: &mockSSHChannel, + } scpCommand := scpCommand{ connection: connection, args: []string{"-r", "-t", "/tmp"}, - channel: &mockSSHChannel, } err := scpCommand.handleCreateDir("invalid_dir") if err == nil { @@ -542,7 +548,6 @@ func TestSCPDownloadFileData(t *testing.T) { readErr := fmt.Errorf("test read error") writeErr := fmt.Errorf("test write error") stdErrBuf := make([]byte, 65535) - connection := Connection{} mockSSHChannelReadErr := MockChannel{ Buffer: bytes.NewBuffer(buf), StdErrBuffer: bytes.NewBuffer(stdErrBuf), @@ -555,10 +560,12 @@ func TestSCPDownloadFileData(t *testing.T) { ReadError: nil, WriteError: writeErr, } + connection := Connection{ + channel: &mockSSHChannelReadErr, + } scpCommand := scpCommand{ connection: connection, args: []string{"-r", "-f", "/tmp"}, - channel: &mockSSHChannelReadErr, } ioutil.WriteFile(testfile, []byte("test"), 0666) stat, _ := os.Stat(testfile) @@ -566,7 +573,7 @@ func TestSCPDownloadFileData(t *testing.T) { if err != readErr { t.Errorf("send download file data must fail with the expected error: %v", err) } - scpCommand.channel = &mockSSHChannelWriteErr + scpCommand.connection.channel = &mockSSHChannelWriteErr err = scpCommand.sendDownloadFileData(testfile, stat, nil) if err != writeErr { t.Errorf("send download file data must fail with the expected error: %v", err) @@ -576,7 +583,7 @@ func TestSCPDownloadFileData(t *testing.T) { if err != writeErr { t.Errorf("send download file data must fail with the expected error: %v", err) } - scpCommand.channel = &mockSSHChannelReadErr + scpCommand.connection.channel = &mockSSHChannelReadErr err = scpCommand.sendDownloadFileData(testfile, stat, nil) if err != readErr { t.Errorf("send download file data must fail with the expected error: %v", err) @@ -586,12 +593,6 @@ func TestSCPDownloadFileData(t *testing.T) { func TestSCPUploadFiledata(t *testing.T) { testfile := "testfile" - connection := Connection{ - User: dataprovider.User{ - Username: "testuser", - }, - protocol: protocolSCP, - } buf := make([]byte, 65535) stdErrBuf := make([]byte, 65535) readErr := fmt.Errorf("test read error") @@ -602,10 +603,16 @@ func TestSCPUploadFiledata(t *testing.T) { ReadError: readErr, WriteError: writeErr, } + connection := Connection{ + User: dataprovider.User{ + Username: "testuser", + }, + protocol: protocolSCP, + channel: &mockSSHChannel, + } scpCommand := scpCommand{ connection: connection, args: []string{"-r", "-t", "/tmp"}, - channel: &mockSSHChannel, } file, _ := os.Create(testfile) transfer := Transfer{ @@ -634,7 +641,7 @@ func TestSCPUploadFiledata(t *testing.T) { ReadError: readErr, WriteError: nil, } - scpCommand.channel = &mockSSHChannel + scpCommand.connection.channel = &mockSSHChannel file, _ = os.Create(testfile) transfer.file = file addTransfer(&transfer) @@ -651,7 +658,7 @@ func TestSCPUploadFiledata(t *testing.T) { ReadError: nil, WriteError: nil, } - scpCommand.channel = &mockSSHChannel + scpCommand.connection.channel = &mockSSHChannel file, _ = os.Create(testfile) transfer.file = file addTransfer(&transfer) diff --git a/sftpd/scp.go b/sftpd/scp.go index 62c4e3a3..da4e0d3f 100644 --- a/sftpd/scp.go +++ b/sftpd/scp.go @@ -35,7 +35,6 @@ type exitStatusMsg struct { type scpCommand struct { connection Connection args []string - channel ssh.Channel } func (c *scpCommand) handle() error { @@ -160,7 +159,7 @@ func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *Transfer) err remaining := sizeToRead buf := make([]byte, int64(math.Min(32768, float64(sizeToRead)))) for { - n, err := c.channel.Read(buf) + n, err := c.connection.channel.Read(buf) if err != nil { c.sendErrorMessage(err.Error()) transfer.TransferError(err) @@ -403,7 +402,7 @@ func (c *scpCommand) sendDownloadFileData(filePath string, stat os.FileInfo, tra n, err := transfer.ReadAt(buf, readed) if err == nil || err == io.EOF { if n > 0 { - _, err = c.channel.Write(buf[:n]) + _, err = c.connection.channel.Write(buf[:n]) } } readed += int64(n) @@ -517,15 +516,15 @@ func (c *scpCommand) isRecursive() bool { func (c *scpCommand) readConfirmationMessage() error { var msg strings.Builder buf := make([]byte, 1) - n, err := c.channel.Read(buf) + n, err := c.connection.channel.Read(buf) if err != nil { - c.channel.Close() + c.connection.channel.Close() return err } if n == 1 && (buf[0] == warnMsg[0] || buf[0] == errMsg[0]) { isError := buf[0] == errMsg[0] for { - n, err = c.channel.Read(buf) + n, err = c.connection.channel.Read(buf) readed := buf[:n] if err != nil || (n == 1 && readed[0] == newLine[0]) { break @@ -536,7 +535,7 @@ func (c *scpCommand) readConfirmationMessage() error { } c.connection.Log(logger.LevelInfo, logSenderSCP, "scp error message received: %v is error: %v", msg.String(), isError) err = fmt.Errorf("%v", msg.String()) - c.channel.Close() + c.connection.channel.Close() } return err } @@ -548,7 +547,7 @@ func (c *scpCommand) readProtocolMessage() (string, error) { buf := make([]byte, 1) for { var n int - n, err = c.channel.Read(buf) + n, err = c.connection.channel.Read(buf) if err != nil { break } @@ -561,34 +560,34 @@ func (c *scpCommand) readProtocolMessage() (string, error) { } } if err != nil { - c.channel.Close() + c.connection.channel.Close() } return command.String(), err } // send an error message and close the channel func (c *scpCommand) sendErrorMessage(error string) { - c.channel.Write(errMsg) - c.channel.Write([]byte(error)) - c.channel.Write(newLine) - c.channel.Close() + c.connection.channel.Write(errMsg) + c.connection.channel.Write([]byte(error)) + c.connection.channel.Write(newLine) + c.connection.channel.Close() } // send scp confirmation message and close the channel if an error happen func (c *scpCommand) sendConfirmationMessage() error { - _, err := c.channel.Write(okMsg) + _, err := c.connection.channel.Write(okMsg) if err != nil { - c.channel.Close() + c.connection.channel.Close() } return err } // sends a protocol message and close the channel on error func (c *scpCommand) sendProtocolMessage(message string) error { - _, err := c.channel.Write([]byte(message)) + _, err := c.connection.channel.Write([]byte(message)) if err != nil { c.connection.Log(logger.LevelWarn, logSenderSCP, "error sending protocol message: %v, err: %v", message, err) - c.channel.Close() + c.connection.channel.Close() } return err } @@ -604,8 +603,8 @@ func (c *scpCommand) sendExitStatus(err error) { } c.connection.Log(logger.LevelDebug, logSenderSCP, "send exit status for command with args: %v user: %v err: %v", c.args, c.connection.User.Username, err) - c.channel.SendRequest("exit-status", false, ssh.Marshal(&ex)) - c.channel.Close() + c.connection.channel.SendRequest("exit-status", false, ssh.Marshal(&ex)) + c.connection.channel.Close() } // get the next upload protocol message ignoring T command if any diff --git a/sftpd/server.go b/sftpd/server.go index 82e2d3da..353d9a6d 100644 --- a/sftpd/server.go +++ b/sftpd/server.go @@ -229,6 +229,7 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server lastActivity: time.Now(), lock: new(sync.Mutex), netConn: conn, + channel: nil, } connection.Log(logger.LevelInfo, logSender, "User id: %d, logged in with: %#v, username: %#v, home_dir: %#v", user.ID, loginType, user.Username, user.HomeDir) @@ -261,6 +262,7 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server if string(req.Payload[4:]) == "sftp" { ok = true connection.protocol = protocolSFTP + connection.channel = channel go c.handleSftpConnection(channel, connection) } case "exec": @@ -274,10 +276,10 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server if err == nil && name == "scp" && len(scpArgs) >= 2 { ok = true connection.protocol = protocolSCP + connection.channel = channel scpCommand := scpCommand{ connection: connection, args: scpArgs, - channel: channel, } go scpCommand.handle() } @@ -290,7 +292,7 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server } } -func (c Configuration) handleSftpConnection(channel io.ReadWriteCloser, connection Connection) { +func (c Configuration) handleSftpConnection(channel ssh.Channel, connection Connection) { addConnection(connection.ID, connection) // Create a new handler for the currently logged in user's server. handler := c.createHandler(connection)