connections: close the ssh channel before the network connection

This way if pkg/sftp is stuck in Serve() method should be unlocked.
This commit is contained in:
Nicola Murino
2019-09-11 16:29:56 +02:00
parent 9794ca7ee0
commit 3d13fe15c3
4 changed files with 70 additions and 56 deletions

View File

@@ -13,6 +13,7 @@ import (
"github.com/drakkan/sftpgo/utils" "github.com/drakkan/sftpgo/utils"
"github.com/rs/xid" "github.com/rs/xid"
"golang.org/x/crypto/ssh"
"github.com/drakkan/sftpgo/dataprovider" "github.com/drakkan/sftpgo/dataprovider"
"github.com/drakkan/sftpgo/logger" "github.com/drakkan/sftpgo/logger"
@@ -37,6 +38,7 @@ type Connection struct {
protocol string protocol string
lock *sync.Mutex lock *sync.Mutex
netConn net.Conn netConn net.Conn
channel ssh.Channel
} }
// Log outputs a log entry to the configured logger // Log outputs a log entry to the configured logger
@@ -580,6 +582,10 @@ func (c Connection) createMissingDirs(filePath string) error {
} }
func (c Connection) close() error { func (c Connection) close() error {
if c.channel != nil {
err := c.channel.Close()
c.Log(logger.LevelInfo, logSender, "channel close, err: %v", err)
}
return c.netConn.Close() return c.netConn.Close()
} }

View File

@@ -252,7 +252,6 @@ func TestSCPGetNonExistingDirContent(t *testing.T) {
} }
func TestSCPParseUploadMessage(t *testing.T) { func TestSCPParseUploadMessage(t *testing.T) {
connection := Connection{}
buf := make([]byte, 65535) buf := make([]byte, 65535)
stdErrBuf := make([]byte, 65535) stdErrBuf := make([]byte, 65535)
mockSSHChannel := MockChannel{ mockSSHChannel := MockChannel{
@@ -260,10 +259,12 @@ func TestSCPParseUploadMessage(t *testing.T) {
StdErrBuffer: bytes.NewBuffer(stdErrBuf), StdErrBuffer: bytes.NewBuffer(stdErrBuf),
ReadError: nil, ReadError: nil,
} }
connection := Connection{
channel: &mockSSHChannel,
}
scpCommand := scpCommand{ scpCommand := scpCommand{
connection: connection, connection: connection,
args: []string{"-t", "/tmp"}, args: []string{"-t", "/tmp"},
channel: &mockSSHChannel,
} }
_, _, err := scpCommand.parseUploadMessage("invalid") _, _, err := scpCommand.parseUploadMessage("invalid")
if err == nil { if err == nil {
@@ -284,7 +285,6 @@ func TestSCPParseUploadMessage(t *testing.T) {
} }
func TestSCPProtocolMessages(t *testing.T) { func TestSCPProtocolMessages(t *testing.T) {
connection := Connection{}
buf := make([]byte, 65535) buf := make([]byte, 65535)
stdErrBuf := make([]byte, 65535) stdErrBuf := make([]byte, 65535)
readErr := fmt.Errorf("test read error") readErr := fmt.Errorf("test read error")
@@ -295,10 +295,12 @@ func TestSCPProtocolMessages(t *testing.T) {
ReadError: readErr, ReadError: readErr,
WriteError: writeErr, WriteError: writeErr,
} }
connection := Connection{
channel: &mockSSHChannel,
}
scpCommand := scpCommand{ scpCommand := scpCommand{
connection: connection, connection: connection,
args: []string{"-t", "/tmp"}, args: []string{"-t", "/tmp"},
channel: &mockSSHChannel,
} }
_, err := scpCommand.readProtocolMessage() _, err := scpCommand.readProtocolMessage()
if err == nil || err != readErr { if err == nil || err != readErr {
@@ -322,7 +324,7 @@ func TestSCPProtocolMessages(t *testing.T) {
ReadError: nil, ReadError: nil,
WriteError: writeErr, WriteError: writeErr,
} }
scpCommand.channel = &mockSSHChannel scpCommand.connection.channel = &mockSSHChannel
_, err = scpCommand.getNextUploadProtocolMessage() _, err = scpCommand.getNextUploadProtocolMessage()
if err == nil || err != writeErr { if err == nil || err != writeErr {
t.Errorf("read next upload protocol message must fail, we are sending a fake write error") t.Errorf("read next upload protocol message must fail, we are sending a fake write error")
@@ -337,7 +339,7 @@ func TestSCPProtocolMessages(t *testing.T) {
ReadError: nil, ReadError: nil,
WriteError: nil, WriteError: nil,
} }
scpCommand.channel = &mockSSHChannel scpCommand.connection.channel = &mockSSHChannel
err = scpCommand.readConfirmationMessage() err = scpCommand.readConfirmationMessage()
if err == nil || err.Error() != protocolErrorMsg { if err == nil || err.Error() != protocolErrorMsg {
t.Errorf("read confirmation message must return the expected protocol error, actual err: %v", err) t.Errorf("read confirmation message must return the expected protocol error, actual err: %v", err)
@@ -345,7 +347,6 @@ func TestSCPProtocolMessages(t *testing.T) {
} }
func TestSCPTestDownloadProtocolMessages(t *testing.T) { func TestSCPTestDownloadProtocolMessages(t *testing.T) {
connection := Connection{}
buf := make([]byte, 65535) buf := make([]byte, 65535)
stdErrBuf := make([]byte, 65535) stdErrBuf := make([]byte, 65535)
readErr := fmt.Errorf("test read error") readErr := fmt.Errorf("test read error")
@@ -356,10 +357,12 @@ func TestSCPTestDownloadProtocolMessages(t *testing.T) {
ReadError: readErr, ReadError: readErr,
WriteError: writeErr, WriteError: writeErr,
} }
connection := Connection{
channel: &mockSSHChannel,
}
scpCommand := scpCommand{ scpCommand := scpCommand{
connection: connection, connection: connection,
args: []string{"-f", "-p", "/tmp"}, args: []string{"-f", "-p", "/tmp"},
channel: &mockSSHChannel,
} }
path := "testDir" path := "testDir"
os.Mkdir(path, 0777) os.Mkdir(path, 0777)
@@ -388,7 +391,7 @@ func TestSCPTestDownloadProtocolMessages(t *testing.T) {
WriteError: writeErr, WriteError: writeErr,
} }
scpCommand.args = []string{"-f", "/tmp"} scpCommand.args = []string{"-f", "/tmp"}
scpCommand.channel = &mockSSHChannel scpCommand.connection.channel = &mockSSHChannel
err = scpCommand.sendDownloadProtocolMessages(path, stat) err = scpCommand.sendDownloadProtocolMessages(path, stat)
if err != writeErr { if err != writeErr {
t.Errorf("sendDownloadProtocolMessages must return the expected error: %v", err) t.Errorf("sendDownloadProtocolMessages must return the expected error: %v", err)
@@ -400,7 +403,7 @@ func TestSCPTestDownloadProtocolMessages(t *testing.T) {
ReadError: readErr, ReadError: readErr,
WriteError: nil, WriteError: nil,
} }
scpCommand.channel = &mockSSHChannel scpCommand.connection.channel = &mockSSHChannel
err = scpCommand.sendDownloadProtocolMessages(path, stat) err = scpCommand.sendDownloadProtocolMessages(path, stat)
if err != readErr { if err != readErr {
t.Errorf("sendDownloadProtocolMessages must return the expected error: %v", err) t.Errorf("sendDownloadProtocolMessages must return the expected error: %v", err)
@@ -409,7 +412,6 @@ func TestSCPTestDownloadProtocolMessages(t *testing.T) {
} }
func TestSCPCommandHandleErrors(t *testing.T) { func TestSCPCommandHandleErrors(t *testing.T) {
connection := Connection{}
buf := make([]byte, 65535) buf := make([]byte, 65535)
stdErrBuf := make([]byte, 65535) stdErrBuf := make([]byte, 65535)
readErr := fmt.Errorf("test read error") readErr := fmt.Errorf("test read error")
@@ -420,10 +422,12 @@ func TestSCPCommandHandleErrors(t *testing.T) {
ReadError: readErr, ReadError: readErr,
WriteError: writeErr, WriteError: writeErr,
} }
connection := Connection{
channel: &mockSSHChannel,
}
scpCommand := scpCommand{ scpCommand := scpCommand{
connection: connection, connection: connection,
args: []string{"-f", "/tmp"}, args: []string{"-f", "/tmp"},
channel: &mockSSHChannel,
} }
err := scpCommand.handle() err := scpCommand.handle()
if err == nil || err != readErr { if err == nil || err != readErr {
@@ -437,7 +441,6 @@ func TestSCPCommandHandleErrors(t *testing.T) {
} }
func TestSCPRecursiveDownloadErrors(t *testing.T) { func TestSCPRecursiveDownloadErrors(t *testing.T) {
connection := Connection{}
buf := make([]byte, 65535) buf := make([]byte, 65535)
stdErrBuf := make([]byte, 65535) stdErrBuf := make([]byte, 65535)
readErr := fmt.Errorf("test read error") readErr := fmt.Errorf("test read error")
@@ -448,10 +451,12 @@ func TestSCPRecursiveDownloadErrors(t *testing.T) {
ReadError: readErr, ReadError: readErr,
WriteError: writeErr, WriteError: writeErr,
} }
connection := Connection{
channel: &mockSSHChannel,
}
scpCommand := scpCommand{ scpCommand := scpCommand{
connection: connection, connection: connection,
args: []string{"-r", "-f", "/tmp"}, args: []string{"-r", "-f", "/tmp"},
channel: &mockSSHChannel,
} }
path := "testDir" path := "testDir"
os.Mkdir(path, 0777) os.Mkdir(path, 0777)
@@ -466,7 +471,7 @@ func TestSCPRecursiveDownloadErrors(t *testing.T) {
ReadError: nil, ReadError: nil,
WriteError: nil, WriteError: nil,
} }
scpCommand.channel = &mockSSHChannel scpCommand.connection.channel = &mockSSHChannel
err = scpCommand.handleRecursiveDownload("invalid_dir", stat) err = scpCommand.handleRecursiveDownload("invalid_dir", stat)
if err == nil { if err == nil {
t.Errorf("recursive upload download must fail for a non existing dir") t.Errorf("recursive upload download must fail for a non existing dir")
@@ -476,7 +481,6 @@ func TestSCPRecursiveDownloadErrors(t *testing.T) {
} }
func TestSCPRecursiveUploadErrors(t *testing.T) { func TestSCPRecursiveUploadErrors(t *testing.T) {
connection := Connection{}
buf := make([]byte, 65535) buf := make([]byte, 65535)
stdErrBuf := make([]byte, 65535) stdErrBuf := make([]byte, 65535)
readErr := fmt.Errorf("test read error") readErr := fmt.Errorf("test read error")
@@ -487,10 +491,12 @@ func TestSCPRecursiveUploadErrors(t *testing.T) {
ReadError: readErr, ReadError: readErr,
WriteError: writeErr, WriteError: writeErr,
} }
connection := Connection{
channel: &mockSSHChannel,
}
scpCommand := scpCommand{ scpCommand := scpCommand{
connection: connection, connection: connection,
args: []string{"-r", "-t", "/tmp"}, args: []string{"-r", "-t", "/tmp"},
channel: &mockSSHChannel,
} }
err := scpCommand.handleRecursiveUpload() err := scpCommand.handleRecursiveUpload()
if err == nil { if err == nil {
@@ -502,7 +508,7 @@ func TestSCPRecursiveUploadErrors(t *testing.T) {
ReadError: readErr, ReadError: readErr,
WriteError: nil, WriteError: nil,
} }
scpCommand.channel = &mockSSHChannel scpCommand.connection.channel = &mockSSHChannel
err = scpCommand.handleRecursiveUpload() err = scpCommand.handleRecursiveUpload()
if err == nil { if err == nil {
t.Errorf("recursive upload must fail, we send a fake error message") t.Errorf("recursive upload must fail, we send a fake error message")
@@ -516,19 +522,19 @@ func TestSCPCreateDirs(t *testing.T) {
u.HomeDir = "home_rel_path" u.HomeDir = "home_rel_path"
u.Username = "test" u.Username = "test"
u.Permissions = []string{"*"} u.Permissions = []string{"*"}
connection := Connection{
User: u,
}
mockSSHChannel := MockChannel{ mockSSHChannel := MockChannel{
Buffer: bytes.NewBuffer(buf), Buffer: bytes.NewBuffer(buf),
StdErrBuffer: bytes.NewBuffer(stdErrBuf), StdErrBuffer: bytes.NewBuffer(stdErrBuf),
ReadError: nil, ReadError: nil,
WriteError: nil, WriteError: nil,
} }
connection := Connection{
User: u,
channel: &mockSSHChannel,
}
scpCommand := scpCommand{ scpCommand := scpCommand{
connection: connection, connection: connection,
args: []string{"-r", "-t", "/tmp"}, args: []string{"-r", "-t", "/tmp"},
channel: &mockSSHChannel,
} }
err := scpCommand.handleCreateDir("invalid_dir") err := scpCommand.handleCreateDir("invalid_dir")
if err == nil { if err == nil {
@@ -542,7 +548,6 @@ func TestSCPDownloadFileData(t *testing.T) {
readErr := fmt.Errorf("test read error") readErr := fmt.Errorf("test read error")
writeErr := fmt.Errorf("test write error") writeErr := fmt.Errorf("test write error")
stdErrBuf := make([]byte, 65535) stdErrBuf := make([]byte, 65535)
connection := Connection{}
mockSSHChannelReadErr := MockChannel{ mockSSHChannelReadErr := MockChannel{
Buffer: bytes.NewBuffer(buf), Buffer: bytes.NewBuffer(buf),
StdErrBuffer: bytes.NewBuffer(stdErrBuf), StdErrBuffer: bytes.NewBuffer(stdErrBuf),
@@ -555,10 +560,12 @@ func TestSCPDownloadFileData(t *testing.T) {
ReadError: nil, ReadError: nil,
WriteError: writeErr, WriteError: writeErr,
} }
connection := Connection{
channel: &mockSSHChannelReadErr,
}
scpCommand := scpCommand{ scpCommand := scpCommand{
connection: connection, connection: connection,
args: []string{"-r", "-f", "/tmp"}, args: []string{"-r", "-f", "/tmp"},
channel: &mockSSHChannelReadErr,
} }
ioutil.WriteFile(testfile, []byte("test"), 0666) ioutil.WriteFile(testfile, []byte("test"), 0666)
stat, _ := os.Stat(testfile) stat, _ := os.Stat(testfile)
@@ -566,7 +573,7 @@ func TestSCPDownloadFileData(t *testing.T) {
if err != readErr { if err != readErr {
t.Errorf("send download file data must fail with the expected error: %v", err) t.Errorf("send download file data must fail with the expected error: %v", err)
} }
scpCommand.channel = &mockSSHChannelWriteErr scpCommand.connection.channel = &mockSSHChannelWriteErr
err = scpCommand.sendDownloadFileData(testfile, stat, nil) err = scpCommand.sendDownloadFileData(testfile, stat, nil)
if err != writeErr { if err != writeErr {
t.Errorf("send download file data must fail with the expected error: %v", err) t.Errorf("send download file data must fail with the expected error: %v", err)
@@ -576,7 +583,7 @@ func TestSCPDownloadFileData(t *testing.T) {
if err != writeErr { if err != writeErr {
t.Errorf("send download file data must fail with the expected error: %v", err) t.Errorf("send download file data must fail with the expected error: %v", err)
} }
scpCommand.channel = &mockSSHChannelReadErr scpCommand.connection.channel = &mockSSHChannelReadErr
err = scpCommand.sendDownloadFileData(testfile, stat, nil) err = scpCommand.sendDownloadFileData(testfile, stat, nil)
if err != readErr { if err != readErr {
t.Errorf("send download file data must fail with the expected error: %v", err) t.Errorf("send download file data must fail with the expected error: %v", err)
@@ -586,12 +593,6 @@ func TestSCPDownloadFileData(t *testing.T) {
func TestSCPUploadFiledata(t *testing.T) { func TestSCPUploadFiledata(t *testing.T) {
testfile := "testfile" testfile := "testfile"
connection := Connection{
User: dataprovider.User{
Username: "testuser",
},
protocol: protocolSCP,
}
buf := make([]byte, 65535) buf := make([]byte, 65535)
stdErrBuf := make([]byte, 65535) stdErrBuf := make([]byte, 65535)
readErr := fmt.Errorf("test read error") readErr := fmt.Errorf("test read error")
@@ -602,10 +603,16 @@ func TestSCPUploadFiledata(t *testing.T) {
ReadError: readErr, ReadError: readErr,
WriteError: writeErr, WriteError: writeErr,
} }
connection := Connection{
User: dataprovider.User{
Username: "testuser",
},
protocol: protocolSCP,
channel: &mockSSHChannel,
}
scpCommand := scpCommand{ scpCommand := scpCommand{
connection: connection, connection: connection,
args: []string{"-r", "-t", "/tmp"}, args: []string{"-r", "-t", "/tmp"},
channel: &mockSSHChannel,
} }
file, _ := os.Create(testfile) file, _ := os.Create(testfile)
transfer := Transfer{ transfer := Transfer{
@@ -634,7 +641,7 @@ func TestSCPUploadFiledata(t *testing.T) {
ReadError: readErr, ReadError: readErr,
WriteError: nil, WriteError: nil,
} }
scpCommand.channel = &mockSSHChannel scpCommand.connection.channel = &mockSSHChannel
file, _ = os.Create(testfile) file, _ = os.Create(testfile)
transfer.file = file transfer.file = file
addTransfer(&transfer) addTransfer(&transfer)
@@ -651,7 +658,7 @@ func TestSCPUploadFiledata(t *testing.T) {
ReadError: nil, ReadError: nil,
WriteError: nil, WriteError: nil,
} }
scpCommand.channel = &mockSSHChannel scpCommand.connection.channel = &mockSSHChannel
file, _ = os.Create(testfile) file, _ = os.Create(testfile)
transfer.file = file transfer.file = file
addTransfer(&transfer) addTransfer(&transfer)

View File

@@ -35,7 +35,6 @@ type exitStatusMsg struct {
type scpCommand struct { type scpCommand struct {
connection Connection connection Connection
args []string args []string
channel ssh.Channel
} }
func (c *scpCommand) handle() error { func (c *scpCommand) handle() error {
@@ -160,7 +159,7 @@ func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *Transfer) err
remaining := sizeToRead remaining := sizeToRead
buf := make([]byte, int64(math.Min(32768, float64(sizeToRead)))) buf := make([]byte, int64(math.Min(32768, float64(sizeToRead))))
for { for {
n, err := c.channel.Read(buf) n, err := c.connection.channel.Read(buf)
if err != nil { if err != nil {
c.sendErrorMessage(err.Error()) c.sendErrorMessage(err.Error())
transfer.TransferError(err) transfer.TransferError(err)
@@ -403,7 +402,7 @@ func (c *scpCommand) sendDownloadFileData(filePath string, stat os.FileInfo, tra
n, err := transfer.ReadAt(buf, readed) n, err := transfer.ReadAt(buf, readed)
if err == nil || err == io.EOF { if err == nil || err == io.EOF {
if n > 0 { if n > 0 {
_, err = c.channel.Write(buf[:n]) _, err = c.connection.channel.Write(buf[:n])
} }
} }
readed += int64(n) readed += int64(n)
@@ -517,15 +516,15 @@ func (c *scpCommand) isRecursive() bool {
func (c *scpCommand) readConfirmationMessage() error { func (c *scpCommand) readConfirmationMessage() error {
var msg strings.Builder var msg strings.Builder
buf := make([]byte, 1) buf := make([]byte, 1)
n, err := c.channel.Read(buf) n, err := c.connection.channel.Read(buf)
if err != nil { if err != nil {
c.channel.Close() c.connection.channel.Close()
return err return err
} }
if n == 1 && (buf[0] == warnMsg[0] || buf[0] == errMsg[0]) { if n == 1 && (buf[0] == warnMsg[0] || buf[0] == errMsg[0]) {
isError := buf[0] == errMsg[0] isError := buf[0] == errMsg[0]
for { for {
n, err = c.channel.Read(buf) n, err = c.connection.channel.Read(buf)
readed := buf[:n] readed := buf[:n]
if err != nil || (n == 1 && readed[0] == newLine[0]) { if err != nil || (n == 1 && readed[0] == newLine[0]) {
break break
@@ -536,7 +535,7 @@ func (c *scpCommand) readConfirmationMessage() error {
} }
c.connection.Log(logger.LevelInfo, logSenderSCP, "scp error message received: %v is error: %v", msg.String(), isError) c.connection.Log(logger.LevelInfo, logSenderSCP, "scp error message received: %v is error: %v", msg.String(), isError)
err = fmt.Errorf("%v", msg.String()) err = fmt.Errorf("%v", msg.String())
c.channel.Close() c.connection.channel.Close()
} }
return err return err
} }
@@ -548,7 +547,7 @@ func (c *scpCommand) readProtocolMessage() (string, error) {
buf := make([]byte, 1) buf := make([]byte, 1)
for { for {
var n int var n int
n, err = c.channel.Read(buf) n, err = c.connection.channel.Read(buf)
if err != nil { if err != nil {
break break
} }
@@ -561,34 +560,34 @@ func (c *scpCommand) readProtocolMessage() (string, error) {
} }
} }
if err != nil { if err != nil {
c.channel.Close() c.connection.channel.Close()
} }
return command.String(), err return command.String(), err
} }
// send an error message and close the channel // send an error message and close the channel
func (c *scpCommand) sendErrorMessage(error string) { func (c *scpCommand) sendErrorMessage(error string) {
c.channel.Write(errMsg) c.connection.channel.Write(errMsg)
c.channel.Write([]byte(error)) c.connection.channel.Write([]byte(error))
c.channel.Write(newLine) c.connection.channel.Write(newLine)
c.channel.Close() c.connection.channel.Close()
} }
// send scp confirmation message and close the channel if an error happen // send scp confirmation message and close the channel if an error happen
func (c *scpCommand) sendConfirmationMessage() error { func (c *scpCommand) sendConfirmationMessage() error {
_, err := c.channel.Write(okMsg) _, err := c.connection.channel.Write(okMsg)
if err != nil { if err != nil {
c.channel.Close() c.connection.channel.Close()
} }
return err return err
} }
// sends a protocol message and close the channel on error // sends a protocol message and close the channel on error
func (c *scpCommand) sendProtocolMessage(message string) error { func (c *scpCommand) sendProtocolMessage(message string) error {
_, err := c.channel.Write([]byte(message)) _, err := c.connection.channel.Write([]byte(message))
if err != nil { if err != nil {
c.connection.Log(logger.LevelWarn, logSenderSCP, "error sending protocol message: %v, err: %v", message, err) c.connection.Log(logger.LevelWarn, logSenderSCP, "error sending protocol message: %v, err: %v", message, err)
c.channel.Close() c.connection.channel.Close()
} }
return err return err
} }
@@ -604,8 +603,8 @@ func (c *scpCommand) sendExitStatus(err error) {
} }
c.connection.Log(logger.LevelDebug, logSenderSCP, "send exit status for command with args: %v user: %v err: %v", c.connection.Log(logger.LevelDebug, logSenderSCP, "send exit status for command with args: %v user: %v err: %v",
c.args, c.connection.User.Username, err) c.args, c.connection.User.Username, err)
c.channel.SendRequest("exit-status", false, ssh.Marshal(&ex)) c.connection.channel.SendRequest("exit-status", false, ssh.Marshal(&ex))
c.channel.Close() c.connection.channel.Close()
} }
// get the next upload protocol message ignoring T command if any // get the next upload protocol message ignoring T command if any

View File

@@ -229,6 +229,7 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
lastActivity: time.Now(), lastActivity: time.Now(),
lock: new(sync.Mutex), lock: new(sync.Mutex),
netConn: conn, netConn: conn,
channel: nil,
} }
connection.Log(logger.LevelInfo, logSender, "User id: %d, logged in with: %#v, username: %#v, home_dir: %#v", connection.Log(logger.LevelInfo, logSender, "User id: %d, logged in with: %#v, username: %#v, home_dir: %#v",
user.ID, loginType, user.Username, user.HomeDir) user.ID, loginType, user.Username, user.HomeDir)
@@ -261,6 +262,7 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
if string(req.Payload[4:]) == "sftp" { if string(req.Payload[4:]) == "sftp" {
ok = true ok = true
connection.protocol = protocolSFTP connection.protocol = protocolSFTP
connection.channel = channel
go c.handleSftpConnection(channel, connection) go c.handleSftpConnection(channel, connection)
} }
case "exec": case "exec":
@@ -274,10 +276,10 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
if err == nil && name == "scp" && len(scpArgs) >= 2 { if err == nil && name == "scp" && len(scpArgs) >= 2 {
ok = true ok = true
connection.protocol = protocolSCP connection.protocol = protocolSCP
connection.channel = channel
scpCommand := scpCommand{ scpCommand := scpCommand{
connection: connection, connection: connection,
args: scpArgs, args: scpArgs,
channel: channel,
} }
go scpCommand.handle() go scpCommand.handle()
} }
@@ -290,7 +292,7 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
} }
} }
func (c Configuration) handleSftpConnection(channel io.ReadWriteCloser, connection Connection) { func (c Configuration) handleSftpConnection(channel ssh.Channel, connection Connection) {
addConnection(connection.ID, connection) addConnection(connection.ID, connection)
// Create a new handler for the currently logged in user's server. // Create a new handler for the currently logged in user's server.
handler := c.createHandler(connection) handler := c.createHandler(connection)