sshd: map each channel with a new connection

Fixes #169
This commit is contained in:
Nicola Murino
2020-09-18 10:52:53 +02:00
parent 98a6d138d4
commit 2df0dd1f70
10 changed files with 57 additions and 61 deletions

View File

@@ -287,6 +287,8 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
// handshake completed so remove the deadline, we'll use IdleTimeout configuration from now on
conn.SetDeadline(time.Time{}) //nolint:errcheck
defer conn.Close()
var user dataprovider.User
// Unmarshal cannot fails here and even if it fails we'll have a user with no permissions
@@ -299,62 +301,68 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
if err != nil {
logger.Warn(logSender, "", "could create filesystem for user %#v err: %v", user.Username, err)
conn.Close()
return
}
connection := Connection{
BaseConnection: common.NewBaseConnection(connectionID, "sftpd", user, fs),
ClientVersion: string(sconn.ClientVersion()),
RemoteAddr: remoteAddr,
netConn: conn,
channel: nil,
}
fs.CheckRootPath(user.Username, user.GetUID(), user.GetGID())
connection.Fs.CheckRootPath(user.Username, user.GetUID(), user.GetGID())
connection.Log(logger.LevelInfo, "User id: %d, logged in with: %#v, username: %#v, home_dir: %#v remote addr: %#v",
logger.Log(logger.LevelInfo, common.ProtocolSSH, connectionID,
"User id: %d, logged in with: %#v, username: %#v, home_dir: %#v remote addr: %#v",
user.ID, loginType, user.Username, user.HomeDir, remoteAddr.String())
dataprovider.UpdateLastLogin(user) //nolint:errcheck
go ssh.DiscardRequests(reqs)
channelCounter := 0
for newChannel := range chans {
// If its not a session channel we just move on because its not something we
// know how to handle at this point.
if newChannel.ChannelType() != "session" {
connection.Log(logger.LevelDebug, "received an unknown channel type: %v", newChannel.ChannelType())
logger.Log(logger.LevelDebug, common.ProtocolSSH, connectionID, "received an unknown channel type: %v",
newChannel.ChannelType())
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") //nolint:errcheck
continue
}
channel, requests, err := newChannel.Accept()
if err != nil {
connection.Log(logger.LevelWarn, "could not accept a channel: %v", err)
logger.Log(logger.LevelWarn, common.ProtocolSSH, connectionID, "could not accept a channel: %v", err)
continue
}
channelCounter++
// Channels have a type that is dependent on the protocol. For SFTP this is "subsystem"
// with a payload that (should) be "sftp". Discard anything else we receive ("pty", "shell", etc)
go func(in <-chan *ssh.Request) {
go func(in <-chan *ssh.Request, counter int) {
for req := range in {
ok := false
connID := fmt.Sprintf("%v_%v", connectionID, counter)
switch req.Type {
case "subsystem":
if string(req.Payload[4:]) == "sftp" {
ok = true
connection.SetProtocol(common.ProtocolSFTP)
connection.channel = channel
connection := Connection{
BaseConnection: common.NewBaseConnection(connID, common.ProtocolSFTP, user, fs),
ClientVersion: string(sconn.ClientVersion()),
RemoteAddr: remoteAddr,
channel: channel,
}
go c.handleSftpConnection(channel, &connection)
}
case "exec":
connection.SetProtocol(common.ProtocolSSH)
ok = processSSHCommand(req.Payload, &connection, channel, c.EnabledSSHCommands)
// protocol will be set later inside processSSHCommand it could be SSH or SCP
connection := Connection{
BaseConnection: common.NewBaseConnection(connID, "sshd", user, fs),
ClientVersion: string(sconn.ClientVersion()),
RemoteAddr: remoteAddr,
channel: channel,
}
ok = processSSHCommand(req.Payload, &connection, c.EnabledSSHCommands)
}
req.Reply(ok, nil) //nolint:errcheck
}
}(requests)
}(requests, channelCounter)
}
}