sshd: map each channel with a new connection

Fixes #169
This commit is contained in:
Nicola Murino
2020-09-18 10:52:53 +02:00
parent 98a6d138d4
commit 2df0dd1f70
10 changed files with 57 additions and 61 deletions

View File

@@ -160,7 +160,6 @@ type ActiveConnection interface {
GetLastActivity() time.Time GetLastActivity() time.Time
GetCommand() string GetCommand() string
Disconnect() error Disconnect() error
SetConnDeadline()
AddTransfer(t ActiveTransfer) AddTransfer(t ActiveTransfer)
RemoveTransfer(t ActiveTransfer) RemoveTransfer(t ActiveTransfer)
GetTransfers() []ConnectionTransfer GetTransfers() []ConnectionTransfer
@@ -405,16 +404,7 @@ func (conns *ActiveConnections) Remove(connectionID string) {
conns.connections[len(conns.connections)-1] = nil conns.connections[len(conns.connections)-1] = nil
conns.connections = conns.connections[:len(conns.connections)-1] conns.connections = conns.connections[:len(conns.connections)-1]
metrics.UpdateActiveConnectionsSize(len(conns.connections)) metrics.UpdateActiveConnectionsSize(len(conns.connections))
logger.Debug(c.GetProtocol(), c.GetID(), "connection removed, num open connections: %v", logger.Debug(c.GetProtocol(), c.GetID(), "connection removed, num open connections: %v", len(conns.connections))
len(conns.connections))
// we have finished to send data here and most of the time the underlying network connection
// is already closed. Sometime a client can still be reading the last sended data, so we set
// a deadline instead of directly closing the network connection.
// Setting a deadline on an already closed connection has no effect.
// We only need to ensure that a connection will not remain indefinitely open and so the
// underlying file descriptor is not released.
// This should protect us against buggy clients and edge cases.
c.SetConnDeadline()
} else { } else {
logger.Warn(logSender, "", "connection to remove with id %#v not found!", connectionID) logger.Warn(logSender, "", "connection to remove with id %#v not found!", connectionID)
} }

View File

@@ -68,8 +68,6 @@ func (c *fakeConnection) GetRemoteAddress() string {
return "" return ""
} }
func (c *fakeConnection) SetConnDeadline() {}
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
logfilePath := "common_test.log" logfilePath := "common_test.log"
logger.InitLogger(logfilePath, 5, 1, 28, false, zerolog.DebugLevel) logger.InitLogger(logfilePath, 5, 1, 28, false, zerolog.DebugLevel)

View File

@@ -42,9 +42,6 @@ func (c *Connection) GetRemoteAddress() string {
return c.clientContext.RemoteAddr().String() return c.clientContext.RemoteAddr().String()
} }
// SetConnDeadline does nothing
func (c *Connection) SetConnDeadline() {}
// Disconnect disconnects the client // Disconnect disconnects the client
func (c *Connection) Disconnect() error { func (c *Connection) Disconnect() error {
return c.clientContext.Close(ftpserver.StatusServiceNotAvailable, "connection closed") return c.clientContext.Close(ftpserver.StatusServiceNotAvailable, "connection closed")

View File

@@ -114,8 +114,6 @@ func (c *fakeConnection) GetRemoteAddress() string {
return "" return ""
} }
func (c *fakeConnection) SetConnDeadline() {}
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
homeBasePath = os.TempDir() homeBasePath = os.TempDir()
logfilePath := filepath.Join(configDir, "sftpgo_api_test.log") logfilePath := filepath.Join(configDir, "sftpgo_api_test.log")

View File

@@ -23,7 +23,6 @@ type Connection struct {
ClientVersion string ClientVersion string
// Remote address for this connection // Remote address for this connection
RemoteAddr net.Addr RemoteAddr net.Addr
netConn net.Conn
channel ssh.Channel channel ssh.Channel
command string command string
} }
@@ -38,11 +37,6 @@ func (c *Connection) GetRemoteAddress() string {
return c.RemoteAddr.String() return c.RemoteAddr.String()
} }
// SetConnDeadline sets a deadline on the network connection so it will be eventually closed
func (c *Connection) SetConnDeadline() {
c.netConn.SetDeadline(time.Now().Add(2 * time.Minute)) //nolint:errcheck
}
// GetCommand returns the SSH command, if any // GetCommand returns the SSH command, if any
func (c *Connection) GetCommand() string { func (c *Connection) GetCommand() string {
return c.command return c.command
@@ -413,11 +407,7 @@ func (c *Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, r
// Disconnect disconnects the client closing the network connection // Disconnect disconnects the client closing the network connection
func (c *Connection) Disconnect() error { func (c *Connection) Disconnect() error {
if c.channel != nil { return c.channel.Close()
err := c.channel.Close()
c.Log(logger.LevelInfo, "channel close, err: %v", err)
}
return c.netConn.Close()
} }
func getOSOpenFlags(requestFlags sftp.FileOpenFlags) (flags int) { func getOSOpenFlags(requestFlags sftp.FileOpenFlags) (flags int) {

View File

@@ -518,7 +518,6 @@ func TestSSHCommandErrors(t *testing.T) {
connection := Connection{ connection := Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs), BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
channel: &mockSSHChannel, channel: &mockSSHChannel,
netConn: client,
} }
cmd := sshCommand{ cmd := sshCommand{
command: "md5sum", command: "md5sum",
@@ -674,7 +673,6 @@ func TestCommandsWithExtensionsFilter(t *testing.T) {
connection := &Connection{ connection := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs), BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
channel: &mockSSHChannel, channel: &mockSSHChannel,
netConn: client,
} }
cmd := sshCommand{ cmd := sshCommand{
command: "md5sum", command: "md5sum",
@@ -747,7 +745,6 @@ func TestSSHCommandsRemoteFs(t *testing.T) {
connection := &Connection{ connection := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs), BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
channel: &mockSSHChannel, channel: &mockSSHChannel,
netConn: client,
} }
cmd := sshCommand{ cmd := sshCommand{
command: "md5sum", command: "md5sum",
@@ -960,7 +957,6 @@ func TestSystemCommandErrors(t *testing.T) {
connection := &Connection{ connection := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs), BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
channel: &mockSSHChannel, channel: &mockSSHChannel,
netConn: client,
} }
var sshCmd sshCommand var sshCmd sshCommand
if runtime.GOOS == osWindows { if runtime.GOOS == osWindows {
@@ -1268,7 +1264,6 @@ func TestSCPCommandHandleErrors(t *testing.T) {
connection := &Connection{ connection := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{}, nil), BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{}, nil),
channel: &mockSSHChannel, channel: &mockSSHChannel,
netConn: client,
} }
scpCommand := scpCommand{ scpCommand := scpCommand{
sshCommand: sshCommand{ sshCommand: sshCommand{
@@ -1309,7 +1304,6 @@ func TestSCPErrorsMockFs(t *testing.T) {
}() }()
connection := &Connection{ connection := &Connection{
channel: &mockSSHChannel, channel: &mockSSHChannel,
netConn: client,
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, u, fs), BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, u, fs),
} }
scpCommand := scpCommand{ scpCommand := scpCommand{
@@ -1364,7 +1358,6 @@ func TestSCPRecursiveDownloadErrors(t *testing.T) {
connection := &Connection{ connection := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{}, fs), BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{}, fs),
channel: &mockSSHChannel, channel: &mockSSHChannel,
netConn: client,
} }
scpCommand := scpCommand{ scpCommand := scpCommand{
sshCommand: sshCommand{ sshCommand: sshCommand{

View File

@@ -287,6 +287,8 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
// handshake completed so remove the deadline, we'll use IdleTimeout configuration from now on // handshake completed so remove the deadline, we'll use IdleTimeout configuration from now on
conn.SetDeadline(time.Time{}) //nolint:errcheck conn.SetDeadline(time.Time{}) //nolint:errcheck
defer conn.Close()
var user dataprovider.User var user dataprovider.User
// Unmarshal cannot fails here and even if it fails we'll have a user with no permissions // Unmarshal cannot fails here and even if it fails we'll have a user with no permissions
@@ -299,62 +301,68 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
if err != nil { if err != nil {
logger.Warn(logSender, "", "could create filesystem for user %#v err: %v", user.Username, err) logger.Warn(logSender, "", "could create filesystem for user %#v err: %v", user.Username, err)
conn.Close()
return return
} }
connection := Connection{ fs.CheckRootPath(user.Username, user.GetUID(), user.GetGID())
BaseConnection: common.NewBaseConnection(connectionID, "sftpd", user, fs),
ClientVersion: string(sconn.ClientVersion()),
RemoteAddr: remoteAddr,
netConn: conn,
channel: nil,
}
connection.Fs.CheckRootPath(user.Username, user.GetUID(), user.GetGID()) logger.Log(logger.LevelInfo, common.ProtocolSSH, connectionID,
"User id: %d, logged in with: %#v, username: %#v, home_dir: %#v remote addr: %#v",
connection.Log(logger.LevelInfo, "User id: %d, logged in with: %#v, username: %#v, home_dir: %#v remote addr: %#v",
user.ID, loginType, user.Username, user.HomeDir, remoteAddr.String()) user.ID, loginType, user.Username, user.HomeDir, remoteAddr.String())
dataprovider.UpdateLastLogin(user) //nolint:errcheck dataprovider.UpdateLastLogin(user) //nolint:errcheck
go ssh.DiscardRequests(reqs) go ssh.DiscardRequests(reqs)
channelCounter := 0
for newChannel := range chans { for newChannel := range chans {
// If its not a session channel we just move on because its not something we // If its not a session channel we just move on because its not something we
// know how to handle at this point. // know how to handle at this point.
if newChannel.ChannelType() != "session" { if newChannel.ChannelType() != "session" {
connection.Log(logger.LevelDebug, "received an unknown channel type: %v", newChannel.ChannelType()) logger.Log(logger.LevelDebug, common.ProtocolSSH, connectionID, "received an unknown channel type: %v",
newChannel.ChannelType())
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") //nolint:errcheck newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") //nolint:errcheck
continue continue
} }
channel, requests, err := newChannel.Accept() channel, requests, err := newChannel.Accept()
if err != nil { if err != nil {
connection.Log(logger.LevelWarn, "could not accept a channel: %v", err) logger.Log(logger.LevelWarn, common.ProtocolSSH, connectionID, "could not accept a channel: %v", err)
continue continue
} }
channelCounter++
// Channels have a type that is dependent on the protocol. For SFTP this is "subsystem" // Channels have a type that is dependent on the protocol. For SFTP this is "subsystem"
// with a payload that (should) be "sftp". Discard anything else we receive ("pty", "shell", etc) // with a payload that (should) be "sftp". Discard anything else we receive ("pty", "shell", etc)
go func(in <-chan *ssh.Request) { go func(in <-chan *ssh.Request, counter int) {
for req := range in { for req := range in {
ok := false ok := false
connID := fmt.Sprintf("%v_%v", connectionID, counter)
switch req.Type { switch req.Type {
case "subsystem": case "subsystem":
if string(req.Payload[4:]) == "sftp" { if string(req.Payload[4:]) == "sftp" {
ok = true ok = true
connection.SetProtocol(common.ProtocolSFTP) connection := Connection{
connection.channel = channel BaseConnection: common.NewBaseConnection(connID, common.ProtocolSFTP, user, fs),
ClientVersion: string(sconn.ClientVersion()),
RemoteAddr: remoteAddr,
channel: channel,
}
go c.handleSftpConnection(channel, &connection) go c.handleSftpConnection(channel, &connection)
} }
case "exec": case "exec":
connection.SetProtocol(common.ProtocolSSH) // protocol will be set later inside processSSHCommand it could be SSH or SCP
ok = processSSHCommand(req.Payload, &connection, channel, c.EnabledSSHCommands) connection := Connection{
BaseConnection: common.NewBaseConnection(connID, "sshd", user, fs),
ClientVersion: string(sconn.ClientVersion()),
RemoteAddr: remoteAddr,
channel: channel,
}
ok = processSSHCommand(req.Payload, &connection, c.EnabledSSHCommands)
} }
req.Reply(ok, nil) //nolint:errcheck req.Reply(ok, nil) //nolint:errcheck
} }
}(requests) }(requests, channelCounter)
} }
} }

View File

@@ -5368,6 +5368,33 @@ func TestPermsSubDirsSetstat(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
} }
func TestOpenUnhandledChannel(t *testing.T) {
u := getTestUser(false)
user, _, err := httpd.AddUser(u, http.StatusOK)
assert.NoError(t, err)
config := &ssh.ClientConfig{
User: user.Username,
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
return nil
},
Auth: []ssh.AuthMethod{ssh.Password(defaultPassword)},
}
conn, err := ssh.Dial("tcp", sftpServerAddr, config)
if assert.NoError(t, err) {
_, _, err = conn.OpenChannel("unhandled", nil)
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "unknown channel type")
}
err = conn.Close()
assert.NoError(t, err)
}
_, err = httpd.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
}
func TestPermsSubDirsCommands(t *testing.T) { func TestPermsSubDirsCommands(t *testing.T) {
usePubKey := true usePubKey := true
u := getTestUser(usePubKey) u := getTestUser(usePubKey)

View File

@@ -48,7 +48,7 @@ type systemCommand struct {
quotaCheckPath string quotaCheckPath string
} }
func processSSHCommand(payload []byte, connection *Connection, channel ssh.Channel, enabledSSHCommands []string) bool { func processSSHCommand(payload []byte, connection *Connection, enabledSSHCommands []string) bool {
var msg sshSubsystemExecMsg var msg sshSubsystemExecMsg
if err := ssh.Unmarshal(payload, &msg); err == nil { if err := ssh.Unmarshal(payload, &msg); err == nil {
name, args, err := parseCommandPayload(msg.Command) name, args, err := parseCommandPayload(msg.Command)
@@ -58,7 +58,6 @@ func processSSHCommand(payload []byte, connection *Connection, channel ssh.Chann
connection.command = msg.Command connection.command = msg.Command
if name == scpCmdName && len(args) >= 2 { if name == scpCmdName && len(args) >= 2 {
connection.SetProtocol(common.ProtocolSCP) connection.SetProtocol(common.ProtocolSCP)
connection.channel = channel
scpCommand := scpCommand{ scpCommand := scpCommand{
sshCommand: sshCommand{ sshCommand: sshCommand{
command: name, command: name,
@@ -70,7 +69,6 @@ func processSSHCommand(payload []byte, connection *Connection, channel ssh.Chann
} }
if name != scpCmdName { if name != scpCmdName {
connection.SetProtocol(common.ProtocolSSH) connection.SetProtocol(common.ProtocolSSH)
connection.channel = channel
sshCommand := sshCommand{ sshCommand := sshCommand{
command: name, command: name,
connection: connection, connection: connection,

View File

@@ -39,9 +39,6 @@ func (c *Connection) GetRemoteAddress() string {
return "" return ""
} }
// SetConnDeadline does nothing
func (c *Connection) SetConnDeadline() {}
// Disconnect closes the active transfer // Disconnect closes the active transfer
func (c *Connection) Disconnect() error { func (c *Connection) Disconnect() error {
return c.SignalTransfersAbort() return c.SignalTransfersAbort()