connections: close the ssh channel before the network connection

This way if pkg/sftp is stuck in Serve() method should be unlocked.
This commit is contained in:
Nicola Murino
2019-09-11 16:29:56 +02:00
parent 9794ca7ee0
commit 3d13fe15c3
4 changed files with 70 additions and 56 deletions

View File

@@ -35,7 +35,6 @@ type exitStatusMsg struct {
type scpCommand struct {
connection Connection
args []string
channel ssh.Channel
}
func (c *scpCommand) handle() error {
@@ -160,7 +159,7 @@ func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *Transfer) err
remaining := sizeToRead
buf := make([]byte, int64(math.Min(32768, float64(sizeToRead))))
for {
n, err := c.channel.Read(buf)
n, err := c.connection.channel.Read(buf)
if err != nil {
c.sendErrorMessage(err.Error())
transfer.TransferError(err)
@@ -403,7 +402,7 @@ func (c *scpCommand) sendDownloadFileData(filePath string, stat os.FileInfo, tra
n, err := transfer.ReadAt(buf, readed)
if err == nil || err == io.EOF {
if n > 0 {
_, err = c.channel.Write(buf[:n])
_, err = c.connection.channel.Write(buf[:n])
}
}
readed += int64(n)
@@ -517,15 +516,15 @@ func (c *scpCommand) isRecursive() bool {
func (c *scpCommand) readConfirmationMessage() error {
var msg strings.Builder
buf := make([]byte, 1)
n, err := c.channel.Read(buf)
n, err := c.connection.channel.Read(buf)
if err != nil {
c.channel.Close()
c.connection.channel.Close()
return err
}
if n == 1 && (buf[0] == warnMsg[0] || buf[0] == errMsg[0]) {
isError := buf[0] == errMsg[0]
for {
n, err = c.channel.Read(buf)
n, err = c.connection.channel.Read(buf)
readed := buf[:n]
if err != nil || (n == 1 && readed[0] == newLine[0]) {
break
@@ -536,7 +535,7 @@ func (c *scpCommand) readConfirmationMessage() error {
}
c.connection.Log(logger.LevelInfo, logSenderSCP, "scp error message received: %v is error: %v", msg.String(), isError)
err = fmt.Errorf("%v", msg.String())
c.channel.Close()
c.connection.channel.Close()
}
return err
}
@@ -548,7 +547,7 @@ func (c *scpCommand) readProtocolMessage() (string, error) {
buf := make([]byte, 1)
for {
var n int
n, err = c.channel.Read(buf)
n, err = c.connection.channel.Read(buf)
if err != nil {
break
}
@@ -561,34 +560,34 @@ func (c *scpCommand) readProtocolMessage() (string, error) {
}
}
if err != nil {
c.channel.Close()
c.connection.channel.Close()
}
return command.String(), err
}
// send an error message and close the channel
func (c *scpCommand) sendErrorMessage(error string) {
c.channel.Write(errMsg)
c.channel.Write([]byte(error))
c.channel.Write(newLine)
c.channel.Close()
c.connection.channel.Write(errMsg)
c.connection.channel.Write([]byte(error))
c.connection.channel.Write(newLine)
c.connection.channel.Close()
}
// send scp confirmation message and close the channel if an error happen
func (c *scpCommand) sendConfirmationMessage() error {
_, err := c.channel.Write(okMsg)
_, err := c.connection.channel.Write(okMsg)
if err != nil {
c.channel.Close()
c.connection.channel.Close()
}
return err
}
// sends a protocol message and close the channel on error
func (c *scpCommand) sendProtocolMessage(message string) error {
_, err := c.channel.Write([]byte(message))
_, err := c.connection.channel.Write([]byte(message))
if err != nil {
c.connection.Log(logger.LevelWarn, logSenderSCP, "error sending protocol message: %v, err: %v", message, err)
c.channel.Close()
c.connection.channel.Close()
}
return err
}
@@ -604,8 +603,8 @@ func (c *scpCommand) sendExitStatus(err error) {
}
c.connection.Log(logger.LevelDebug, logSenderSCP, "send exit status for command with args: %v user: %v err: %v",
c.args, c.connection.User.Username, err)
c.channel.SendRequest("exit-status", false, ssh.Marshal(&ex))
c.channel.Close()
c.connection.channel.SendRequest("exit-status", false, ssh.Marshal(&ex))
c.connection.channel.Close()
}
// get the next upload protocol message ignoring T command if any