mirror of
https://github.com/drakkan/sftpgo.git
synced 2025-12-08 15:28:05 +03:00
@@ -160,7 +160,6 @@ type ActiveConnection interface {
|
|||||||
GetLastActivity() time.Time
|
GetLastActivity() time.Time
|
||||||
GetCommand() string
|
GetCommand() string
|
||||||
Disconnect() error
|
Disconnect() error
|
||||||
SetConnDeadline()
|
|
||||||
AddTransfer(t ActiveTransfer)
|
AddTransfer(t ActiveTransfer)
|
||||||
RemoveTransfer(t ActiveTransfer)
|
RemoveTransfer(t ActiveTransfer)
|
||||||
GetTransfers() []ConnectionTransfer
|
GetTransfers() []ConnectionTransfer
|
||||||
@@ -405,16 +404,7 @@ func (conns *ActiveConnections) Remove(connectionID string) {
|
|||||||
conns.connections[len(conns.connections)-1] = nil
|
conns.connections[len(conns.connections)-1] = nil
|
||||||
conns.connections = conns.connections[:len(conns.connections)-1]
|
conns.connections = conns.connections[:len(conns.connections)-1]
|
||||||
metrics.UpdateActiveConnectionsSize(len(conns.connections))
|
metrics.UpdateActiveConnectionsSize(len(conns.connections))
|
||||||
logger.Debug(c.GetProtocol(), c.GetID(), "connection removed, num open connections: %v",
|
logger.Debug(c.GetProtocol(), c.GetID(), "connection removed, num open connections: %v", len(conns.connections))
|
||||||
len(conns.connections))
|
|
||||||
// we have finished to send data here and most of the time the underlying network connection
|
|
||||||
// is already closed. Sometime a client can still be reading the last sended data, so we set
|
|
||||||
// a deadline instead of directly closing the network connection.
|
|
||||||
// Setting a deadline on an already closed connection has no effect.
|
|
||||||
// We only need to ensure that a connection will not remain indefinitely open and so the
|
|
||||||
// underlying file descriptor is not released.
|
|
||||||
// This should protect us against buggy clients and edge cases.
|
|
||||||
c.SetConnDeadline()
|
|
||||||
} else {
|
} else {
|
||||||
logger.Warn(logSender, "", "connection to remove with id %#v not found!", connectionID)
|
logger.Warn(logSender, "", "connection to remove with id %#v not found!", connectionID)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -68,8 +68,6 @@ func (c *fakeConnection) GetRemoteAddress() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *fakeConnection) SetConnDeadline() {}
|
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
logfilePath := "common_test.log"
|
logfilePath := "common_test.log"
|
||||||
logger.InitLogger(logfilePath, 5, 1, 28, false, zerolog.DebugLevel)
|
logger.InitLogger(logfilePath, 5, 1, 28, false, zerolog.DebugLevel)
|
||||||
|
|||||||
@@ -42,9 +42,6 @@ func (c *Connection) GetRemoteAddress() string {
|
|||||||
return c.clientContext.RemoteAddr().String()
|
return c.clientContext.RemoteAddr().String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetConnDeadline does nothing
|
|
||||||
func (c *Connection) SetConnDeadline() {}
|
|
||||||
|
|
||||||
// Disconnect disconnects the client
|
// Disconnect disconnects the client
|
||||||
func (c *Connection) Disconnect() error {
|
func (c *Connection) Disconnect() error {
|
||||||
return c.clientContext.Close(ftpserver.StatusServiceNotAvailable, "connection closed")
|
return c.clientContext.Close(ftpserver.StatusServiceNotAvailable, "connection closed")
|
||||||
|
|||||||
@@ -114,8 +114,6 @@ func (c *fakeConnection) GetRemoteAddress() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *fakeConnection) SetConnDeadline() {}
|
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
homeBasePath = os.TempDir()
|
homeBasePath = os.TempDir()
|
||||||
logfilePath := filepath.Join(configDir, "sftpgo_api_test.log")
|
logfilePath := filepath.Join(configDir, "sftpgo_api_test.log")
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ type Connection struct {
|
|||||||
ClientVersion string
|
ClientVersion string
|
||||||
// Remote address for this connection
|
// Remote address for this connection
|
||||||
RemoteAddr net.Addr
|
RemoteAddr net.Addr
|
||||||
netConn net.Conn
|
|
||||||
channel ssh.Channel
|
channel ssh.Channel
|
||||||
command string
|
command string
|
||||||
}
|
}
|
||||||
@@ -38,11 +37,6 @@ func (c *Connection) GetRemoteAddress() string {
|
|||||||
return c.RemoteAddr.String()
|
return c.RemoteAddr.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetConnDeadline sets a deadline on the network connection so it will be eventually closed
|
|
||||||
func (c *Connection) SetConnDeadline() {
|
|
||||||
c.netConn.SetDeadline(time.Now().Add(2 * time.Minute)) //nolint:errcheck
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetCommand returns the SSH command, if any
|
// GetCommand returns the SSH command, if any
|
||||||
func (c *Connection) GetCommand() string {
|
func (c *Connection) GetCommand() string {
|
||||||
return c.command
|
return c.command
|
||||||
@@ -413,11 +407,7 @@ func (c *Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, r
|
|||||||
|
|
||||||
// Disconnect disconnects the client closing the network connection
|
// Disconnect disconnects the client closing the network connection
|
||||||
func (c *Connection) Disconnect() error {
|
func (c *Connection) Disconnect() error {
|
||||||
if c.channel != nil {
|
return c.channel.Close()
|
||||||
err := c.channel.Close()
|
|
||||||
c.Log(logger.LevelInfo, "channel close, err: %v", err)
|
|
||||||
}
|
|
||||||
return c.netConn.Close()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func getOSOpenFlags(requestFlags sftp.FileOpenFlags) (flags int) {
|
func getOSOpenFlags(requestFlags sftp.FileOpenFlags) (flags int) {
|
||||||
|
|||||||
@@ -518,7 +518,6 @@ func TestSSHCommandErrors(t *testing.T) {
|
|||||||
connection := Connection{
|
connection := Connection{
|
||||||
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
|
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
|
||||||
channel: &mockSSHChannel,
|
channel: &mockSSHChannel,
|
||||||
netConn: client,
|
|
||||||
}
|
}
|
||||||
cmd := sshCommand{
|
cmd := sshCommand{
|
||||||
command: "md5sum",
|
command: "md5sum",
|
||||||
@@ -674,7 +673,6 @@ func TestCommandsWithExtensionsFilter(t *testing.T) {
|
|||||||
connection := &Connection{
|
connection := &Connection{
|
||||||
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
|
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
|
||||||
channel: &mockSSHChannel,
|
channel: &mockSSHChannel,
|
||||||
netConn: client,
|
|
||||||
}
|
}
|
||||||
cmd := sshCommand{
|
cmd := sshCommand{
|
||||||
command: "md5sum",
|
command: "md5sum",
|
||||||
@@ -747,7 +745,6 @@ func TestSSHCommandsRemoteFs(t *testing.T) {
|
|||||||
connection := &Connection{
|
connection := &Connection{
|
||||||
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
|
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
|
||||||
channel: &mockSSHChannel,
|
channel: &mockSSHChannel,
|
||||||
netConn: client,
|
|
||||||
}
|
}
|
||||||
cmd := sshCommand{
|
cmd := sshCommand{
|
||||||
command: "md5sum",
|
command: "md5sum",
|
||||||
@@ -960,7 +957,6 @@ func TestSystemCommandErrors(t *testing.T) {
|
|||||||
connection := &Connection{
|
connection := &Connection{
|
||||||
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
|
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
|
||||||
channel: &mockSSHChannel,
|
channel: &mockSSHChannel,
|
||||||
netConn: client,
|
|
||||||
}
|
}
|
||||||
var sshCmd sshCommand
|
var sshCmd sshCommand
|
||||||
if runtime.GOOS == osWindows {
|
if runtime.GOOS == osWindows {
|
||||||
@@ -1268,7 +1264,6 @@ func TestSCPCommandHandleErrors(t *testing.T) {
|
|||||||
connection := &Connection{
|
connection := &Connection{
|
||||||
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{}, nil),
|
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{}, nil),
|
||||||
channel: &mockSSHChannel,
|
channel: &mockSSHChannel,
|
||||||
netConn: client,
|
|
||||||
}
|
}
|
||||||
scpCommand := scpCommand{
|
scpCommand := scpCommand{
|
||||||
sshCommand: sshCommand{
|
sshCommand: sshCommand{
|
||||||
@@ -1309,7 +1304,6 @@ func TestSCPErrorsMockFs(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
connection := &Connection{
|
connection := &Connection{
|
||||||
channel: &mockSSHChannel,
|
channel: &mockSSHChannel,
|
||||||
netConn: client,
|
|
||||||
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, u, fs),
|
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, u, fs),
|
||||||
}
|
}
|
||||||
scpCommand := scpCommand{
|
scpCommand := scpCommand{
|
||||||
@@ -1364,7 +1358,6 @@ func TestSCPRecursiveDownloadErrors(t *testing.T) {
|
|||||||
connection := &Connection{
|
connection := &Connection{
|
||||||
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{}, fs),
|
BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, dataprovider.User{}, fs),
|
||||||
channel: &mockSSHChannel,
|
channel: &mockSSHChannel,
|
||||||
netConn: client,
|
|
||||||
}
|
}
|
||||||
scpCommand := scpCommand{
|
scpCommand := scpCommand{
|
||||||
sshCommand: sshCommand{
|
sshCommand: sshCommand{
|
||||||
|
|||||||
@@ -287,6 +287,8 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
|
|||||||
// handshake completed so remove the deadline, we'll use IdleTimeout configuration from now on
|
// handshake completed so remove the deadline, we'll use IdleTimeout configuration from now on
|
||||||
conn.SetDeadline(time.Time{}) //nolint:errcheck
|
conn.SetDeadline(time.Time{}) //nolint:errcheck
|
||||||
|
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
var user dataprovider.User
|
var user dataprovider.User
|
||||||
|
|
||||||
// Unmarshal cannot fails here and even if it fails we'll have a user with no permissions
|
// Unmarshal cannot fails here and even if it fails we'll have a user with no permissions
|
||||||
@@ -299,62 +301,68 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn(logSender, "", "could create filesystem for user %#v err: %v", user.Username, err)
|
logger.Warn(logSender, "", "could create filesystem for user %#v err: %v", user.Username, err)
|
||||||
conn.Close()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
connection := Connection{
|
fs.CheckRootPath(user.Username, user.GetUID(), user.GetGID())
|
||||||
BaseConnection: common.NewBaseConnection(connectionID, "sftpd", user, fs),
|
|
||||||
ClientVersion: string(sconn.ClientVersion()),
|
|
||||||
RemoteAddr: remoteAddr,
|
|
||||||
netConn: conn,
|
|
||||||
channel: nil,
|
|
||||||
}
|
|
||||||
|
|
||||||
connection.Fs.CheckRootPath(user.Username, user.GetUID(), user.GetGID())
|
logger.Log(logger.LevelInfo, common.ProtocolSSH, connectionID,
|
||||||
|
"User id: %d, logged in with: %#v, username: %#v, home_dir: %#v remote addr: %#v",
|
||||||
connection.Log(logger.LevelInfo, "User id: %d, logged in with: %#v, username: %#v, home_dir: %#v remote addr: %#v",
|
|
||||||
user.ID, loginType, user.Username, user.HomeDir, remoteAddr.String())
|
user.ID, loginType, user.Username, user.HomeDir, remoteAddr.String())
|
||||||
dataprovider.UpdateLastLogin(user) //nolint:errcheck
|
dataprovider.UpdateLastLogin(user) //nolint:errcheck
|
||||||
|
|
||||||
go ssh.DiscardRequests(reqs)
|
go ssh.DiscardRequests(reqs)
|
||||||
|
|
||||||
|
channelCounter := 0
|
||||||
for newChannel := range chans {
|
for newChannel := range chans {
|
||||||
// If its not a session channel we just move on because its not something we
|
// If its not a session channel we just move on because its not something we
|
||||||
// know how to handle at this point.
|
// know how to handle at this point.
|
||||||
if newChannel.ChannelType() != "session" {
|
if newChannel.ChannelType() != "session" {
|
||||||
connection.Log(logger.LevelDebug, "received an unknown channel type: %v", newChannel.ChannelType())
|
logger.Log(logger.LevelDebug, common.ProtocolSSH, connectionID, "received an unknown channel type: %v",
|
||||||
|
newChannel.ChannelType())
|
||||||
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") //nolint:errcheck
|
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") //nolint:errcheck
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
channel, requests, err := newChannel.Accept()
|
channel, requests, err := newChannel.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
connection.Log(logger.LevelWarn, "could not accept a channel: %v", err)
|
logger.Log(logger.LevelWarn, common.ProtocolSSH, connectionID, "could not accept a channel: %v", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
channelCounter++
|
||||||
// Channels have a type that is dependent on the protocol. For SFTP this is "subsystem"
|
// Channels have a type that is dependent on the protocol. For SFTP this is "subsystem"
|
||||||
// with a payload that (should) be "sftp". Discard anything else we receive ("pty", "shell", etc)
|
// with a payload that (should) be "sftp". Discard anything else we receive ("pty", "shell", etc)
|
||||||
go func(in <-chan *ssh.Request) {
|
go func(in <-chan *ssh.Request, counter int) {
|
||||||
for req := range in {
|
for req := range in {
|
||||||
ok := false
|
ok := false
|
||||||
|
connID := fmt.Sprintf("%v_%v", connectionID, counter)
|
||||||
|
|
||||||
switch req.Type {
|
switch req.Type {
|
||||||
case "subsystem":
|
case "subsystem":
|
||||||
if string(req.Payload[4:]) == "sftp" {
|
if string(req.Payload[4:]) == "sftp" {
|
||||||
ok = true
|
ok = true
|
||||||
connection.SetProtocol(common.ProtocolSFTP)
|
connection := Connection{
|
||||||
connection.channel = channel
|
BaseConnection: common.NewBaseConnection(connID, common.ProtocolSFTP, user, fs),
|
||||||
|
ClientVersion: string(sconn.ClientVersion()),
|
||||||
|
RemoteAddr: remoteAddr,
|
||||||
|
channel: channel,
|
||||||
|
}
|
||||||
go c.handleSftpConnection(channel, &connection)
|
go c.handleSftpConnection(channel, &connection)
|
||||||
}
|
}
|
||||||
case "exec":
|
case "exec":
|
||||||
connection.SetProtocol(common.ProtocolSSH)
|
// protocol will be set later inside processSSHCommand it could be SSH or SCP
|
||||||
ok = processSSHCommand(req.Payload, &connection, channel, c.EnabledSSHCommands)
|
connection := Connection{
|
||||||
|
BaseConnection: common.NewBaseConnection(connID, "sshd", user, fs),
|
||||||
|
ClientVersion: string(sconn.ClientVersion()),
|
||||||
|
RemoteAddr: remoteAddr,
|
||||||
|
channel: channel,
|
||||||
|
}
|
||||||
|
ok = processSSHCommand(req.Payload, &connection, c.EnabledSSHCommands)
|
||||||
}
|
}
|
||||||
req.Reply(ok, nil) //nolint:errcheck
|
req.Reply(ok, nil) //nolint:errcheck
|
||||||
}
|
}
|
||||||
}(requests)
|
}(requests, channelCounter)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5368,6 +5368,33 @@ func TestPermsSubDirsSetstat(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenUnhandledChannel(t *testing.T) {
|
||||||
|
u := getTestUser(false)
|
||||||
|
user, _, err := httpd.AddUser(u, http.StatusOK)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
config := &ssh.ClientConfig{
|
||||||
|
User: user.Username,
|
||||||
|
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
Auth: []ssh.AuthMethod{ssh.Password(defaultPassword)},
|
||||||
|
}
|
||||||
|
conn, err := ssh.Dial("tcp", sftpServerAddr, config)
|
||||||
|
if assert.NoError(t, err) {
|
||||||
|
_, _, err = conn.OpenChannel("unhandled", nil)
|
||||||
|
if assert.Error(t, err) {
|
||||||
|
assert.Contains(t, err.Error(), "unknown channel type")
|
||||||
|
}
|
||||||
|
err = conn.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
_, err = httpd.RemoveUser(user, http.StatusOK)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
err = os.RemoveAll(user.GetHomeDir())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
func TestPermsSubDirsCommands(t *testing.T) {
|
func TestPermsSubDirsCommands(t *testing.T) {
|
||||||
usePubKey := true
|
usePubKey := true
|
||||||
u := getTestUser(usePubKey)
|
u := getTestUser(usePubKey)
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ type systemCommand struct {
|
|||||||
quotaCheckPath string
|
quotaCheckPath string
|
||||||
}
|
}
|
||||||
|
|
||||||
func processSSHCommand(payload []byte, connection *Connection, channel ssh.Channel, enabledSSHCommands []string) bool {
|
func processSSHCommand(payload []byte, connection *Connection, enabledSSHCommands []string) bool {
|
||||||
var msg sshSubsystemExecMsg
|
var msg sshSubsystemExecMsg
|
||||||
if err := ssh.Unmarshal(payload, &msg); err == nil {
|
if err := ssh.Unmarshal(payload, &msg); err == nil {
|
||||||
name, args, err := parseCommandPayload(msg.Command)
|
name, args, err := parseCommandPayload(msg.Command)
|
||||||
@@ -58,7 +58,6 @@ func processSSHCommand(payload []byte, connection *Connection, channel ssh.Chann
|
|||||||
connection.command = msg.Command
|
connection.command = msg.Command
|
||||||
if name == scpCmdName && len(args) >= 2 {
|
if name == scpCmdName && len(args) >= 2 {
|
||||||
connection.SetProtocol(common.ProtocolSCP)
|
connection.SetProtocol(common.ProtocolSCP)
|
||||||
connection.channel = channel
|
|
||||||
scpCommand := scpCommand{
|
scpCommand := scpCommand{
|
||||||
sshCommand: sshCommand{
|
sshCommand: sshCommand{
|
||||||
command: name,
|
command: name,
|
||||||
@@ -70,7 +69,6 @@ func processSSHCommand(payload []byte, connection *Connection, channel ssh.Chann
|
|||||||
}
|
}
|
||||||
if name != scpCmdName {
|
if name != scpCmdName {
|
||||||
connection.SetProtocol(common.ProtocolSSH)
|
connection.SetProtocol(common.ProtocolSSH)
|
||||||
connection.channel = channel
|
|
||||||
sshCommand := sshCommand{
|
sshCommand := sshCommand{
|
||||||
command: name,
|
command: name,
|
||||||
connection: connection,
|
connection: connection,
|
||||||
|
|||||||
@@ -39,9 +39,6 @@ func (c *Connection) GetRemoteAddress() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetConnDeadline does nothing
|
|
||||||
func (c *Connection) SetConnDeadline() {}
|
|
||||||
|
|
||||||
// Disconnect closes the active transfer
|
// Disconnect closes the active transfer
|
||||||
func (c *Connection) Disconnect() error {
|
func (c *Connection) Disconnect() error {
|
||||||
return c.SignalTransfersAbort()
|
return c.SignalTransfersAbort()
|
||||||
|
|||||||
Reference in New Issue
Block a user