mirror of
https://github.com/drakkan/sftpgo.git
synced 2025-12-07 14:50:55 +03:00
connections: close the ssh channel before the network connection
This way if pkg/sftp is stuck in Serve() method should be unlocked.
This commit is contained in:
37
sftpd/scp.go
37
sftpd/scp.go
@@ -35,7 +35,6 @@ type exitStatusMsg struct {
|
||||
type scpCommand struct {
|
||||
connection Connection
|
||||
args []string
|
||||
channel ssh.Channel
|
||||
}
|
||||
|
||||
func (c *scpCommand) handle() error {
|
||||
@@ -160,7 +159,7 @@ func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *Transfer) err
|
||||
remaining := sizeToRead
|
||||
buf := make([]byte, int64(math.Min(32768, float64(sizeToRead))))
|
||||
for {
|
||||
n, err := c.channel.Read(buf)
|
||||
n, err := c.connection.channel.Read(buf)
|
||||
if err != nil {
|
||||
c.sendErrorMessage(err.Error())
|
||||
transfer.TransferError(err)
|
||||
@@ -403,7 +402,7 @@ func (c *scpCommand) sendDownloadFileData(filePath string, stat os.FileInfo, tra
|
||||
n, err := transfer.ReadAt(buf, readed)
|
||||
if err == nil || err == io.EOF {
|
||||
if n > 0 {
|
||||
_, err = c.channel.Write(buf[:n])
|
||||
_, err = c.connection.channel.Write(buf[:n])
|
||||
}
|
||||
}
|
||||
readed += int64(n)
|
||||
@@ -517,15 +516,15 @@ func (c *scpCommand) isRecursive() bool {
|
||||
func (c *scpCommand) readConfirmationMessage() error {
|
||||
var msg strings.Builder
|
||||
buf := make([]byte, 1)
|
||||
n, err := c.channel.Read(buf)
|
||||
n, err := c.connection.channel.Read(buf)
|
||||
if err != nil {
|
||||
c.channel.Close()
|
||||
c.connection.channel.Close()
|
||||
return err
|
||||
}
|
||||
if n == 1 && (buf[0] == warnMsg[0] || buf[0] == errMsg[0]) {
|
||||
isError := buf[0] == errMsg[0]
|
||||
for {
|
||||
n, err = c.channel.Read(buf)
|
||||
n, err = c.connection.channel.Read(buf)
|
||||
readed := buf[:n]
|
||||
if err != nil || (n == 1 && readed[0] == newLine[0]) {
|
||||
break
|
||||
@@ -536,7 +535,7 @@ func (c *scpCommand) readConfirmationMessage() error {
|
||||
}
|
||||
c.connection.Log(logger.LevelInfo, logSenderSCP, "scp error message received: %v is error: %v", msg.String(), isError)
|
||||
err = fmt.Errorf("%v", msg.String())
|
||||
c.channel.Close()
|
||||
c.connection.channel.Close()
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -548,7 +547,7 @@ func (c *scpCommand) readProtocolMessage() (string, error) {
|
||||
buf := make([]byte, 1)
|
||||
for {
|
||||
var n int
|
||||
n, err = c.channel.Read(buf)
|
||||
n, err = c.connection.channel.Read(buf)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
@@ -561,34 +560,34 @@ func (c *scpCommand) readProtocolMessage() (string, error) {
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
c.channel.Close()
|
||||
c.connection.channel.Close()
|
||||
}
|
||||
return command.String(), err
|
||||
}
|
||||
|
||||
// send an error message and close the channel
|
||||
func (c *scpCommand) sendErrorMessage(error string) {
|
||||
c.channel.Write(errMsg)
|
||||
c.channel.Write([]byte(error))
|
||||
c.channel.Write(newLine)
|
||||
c.channel.Close()
|
||||
c.connection.channel.Write(errMsg)
|
||||
c.connection.channel.Write([]byte(error))
|
||||
c.connection.channel.Write(newLine)
|
||||
c.connection.channel.Close()
|
||||
}
|
||||
|
||||
// send scp confirmation message and close the channel if an error happen
|
||||
func (c *scpCommand) sendConfirmationMessage() error {
|
||||
_, err := c.channel.Write(okMsg)
|
||||
_, err := c.connection.channel.Write(okMsg)
|
||||
if err != nil {
|
||||
c.channel.Close()
|
||||
c.connection.channel.Close()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// sends a protocol message and close the channel on error
|
||||
func (c *scpCommand) sendProtocolMessage(message string) error {
|
||||
_, err := c.channel.Write([]byte(message))
|
||||
_, err := c.connection.channel.Write([]byte(message))
|
||||
if err != nil {
|
||||
c.connection.Log(logger.LevelWarn, logSenderSCP, "error sending protocol message: %v, err: %v", message, err)
|
||||
c.channel.Close()
|
||||
c.connection.channel.Close()
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -604,8 +603,8 @@ func (c *scpCommand) sendExitStatus(err error) {
|
||||
}
|
||||
c.connection.Log(logger.LevelDebug, logSenderSCP, "send exit status for command with args: %v user: %v err: %v",
|
||||
c.args, c.connection.User.Username, err)
|
||||
c.channel.SendRequest("exit-status", false, ssh.Marshal(&ex))
|
||||
c.channel.Close()
|
||||
c.connection.channel.SendRequest("exit-status", false, ssh.Marshal(&ex))
|
||||
c.connection.channel.Close()
|
||||
}
|
||||
|
||||
// get the next upload protocol message ignoring T command if any
|
||||
|
||||
Reference in New Issue
Block a user