add support for limiting max concurrent client connections

This commit is contained in:
Nicola Murino
2020-12-15 19:29:30 +01:00
parent ea0bf5e4c8
commit f34462e3c3
11 changed files with 149 additions and 19 deletions

View File

@@ -247,7 +247,9 @@ type Configuration struct {
// Absolute path to an external program or an HTTP URL to invoke after a user connects // Absolute path to an external program or an HTTP URL to invoke after a user connects
// and before he tries to login. It allows you to reject the connection based on the source // and before he tries to login. It allows you to reject the connection based on the source
// ip address. Leave empty do disable. // ip address. Leave empty do disable.
PostConnectHook string `json:"post_connect_hook" mapstructure:"post_connect_hook"` PostConnectHook string `json:"post_connect_hook" mapstructure:"post_connect_hook"`
// Maximum number of concurrent client connections. 0 means unlimited
MaxTotalConnections int `json:"max_total_connections" mapstructure:"max_total_connections"`
idleTimeoutAsDuration time.Duration idleTimeoutAsDuration time.Duration
idleLoginTimeout time.Duration idleLoginTimeout time.Duration
} }
@@ -544,6 +546,18 @@ func (conns *ActiveConnections) checkIdles() {
conns.RUnlock() conns.RUnlock()
} }
// IsNewConnectionAllowed returns false if the maximum number of concurrent allowed connections is exceeded
func (conns *ActiveConnections) IsNewConnectionAllowed() bool {
if Config.MaxTotalConnections == 0 {
return true
}
conns.RLock()
defer conns.RUnlock()
return len(conns.connections) < Config.MaxTotalConnections
}
// GetStats returns stats for active connections // GetStats returns stats for active connections
func (conns *ActiveConnections) GetStats() []ConnectionStatus { func (conns *ActiveConnections) GetStats() []ConnectionStatus {
conns.RLock() conns.RLock()

View File

@@ -225,6 +225,26 @@ func TestSSHConnections(t *testing.T) {
assert.NoError(t, sshConn3.Close()) assert.NoError(t, sshConn3.Close())
} }
func TestMaxConnections(t *testing.T) {
oldValue := Config.MaxTotalConnections
Config.MaxTotalConnections = 1
assert.True(t, Connections.IsNewConnectionAllowed())
c := NewBaseConnection("id", ProtocolSFTP, dataprovider.User{}, nil)
fakeConn := &fakeConnection{
BaseConnection: c,
}
Connections.Add(fakeConn)
assert.Len(t, Connections.GetStats(), 1)
assert.False(t, Connections.IsNewConnectionAllowed())
res := Connections.Close(fakeConn.GetID())
assert.True(t, res)
assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 300*time.Millisecond, 50*time.Millisecond)
Config.MaxTotalConnections = oldValue
}
func TestIdleConnections(t *testing.T) { func TestIdleConnections(t *testing.T) {
configCopy := Config configCopy := Config
@@ -310,6 +330,7 @@ func TestCloseConnection(t *testing.T) {
fakeConn := &fakeConnection{ fakeConn := &fakeConnection{
BaseConnection: c, BaseConnection: c,
} }
assert.True(t, Connections.IsNewConnectionAllowed())
Connections.Add(fakeConn) Connections.Add(fakeConn)
assert.Len(t, Connections.GetStats(), 1) assert.Len(t, Connections.GetStats(), 1)
res := Connections.Close(fakeConn.GetID()) res := Connections.Close(fakeConn.GetID())

View File

@@ -65,9 +65,11 @@ func Init() {
ExecuteOn: []string{}, ExecuteOn: []string{},
Hook: "", Hook: "",
}, },
SetstatMode: 0, SetstatMode: 0,
ProxyProtocol: 0, ProxyProtocol: 0,
ProxyAllowed: []string{}, ProxyAllowed: []string{},
PostConnectHook: "",
MaxTotalConnections: 0,
}, },
SFTPD: sftpd.Configuration{ SFTPD: sftpd.Configuration{
Banner: defaultSFTPDBanner, Banner: defaultSFTPDBanner,
@@ -413,6 +415,7 @@ func setViperDefaults() {
viper.SetDefault("common.proxy_protocol", globalConf.Common.ProxyProtocol) viper.SetDefault("common.proxy_protocol", globalConf.Common.ProxyProtocol)
viper.SetDefault("common.proxy_allowed", globalConf.Common.ProxyAllowed) viper.SetDefault("common.proxy_allowed", globalConf.Common.ProxyAllowed)
viper.SetDefault("common.post_connect_hook", globalConf.Common.PostConnectHook) viper.SetDefault("common.post_connect_hook", globalConf.Common.PostConnectHook)
viper.SetDefault("common.max_total_connections", globalConf.Common.MaxTotalConnections)
viper.SetDefault("sftpd.bind_port", globalConf.SFTPD.BindPort) viper.SetDefault("sftpd.bind_port", globalConf.SFTPD.BindPort)
viper.SetDefault("sftpd.bind_address", globalConf.SFTPD.BindAddress) viper.SetDefault("sftpd.bind_address", globalConf.SFTPD.BindAddress)
viper.SetDefault("sftpd.max_auth_tries", globalConf.SFTPD.MaxAuthTries) viper.SetDefault("sftpd.max_auth_tries", globalConf.SFTPD.MaxAuthTries)

View File

@@ -63,6 +63,7 @@ The configuration file contains the following sections:
- If `proxy_protocol` is set to 1 and we receive a proxy header from an IP that is not in the list then the connection will be accepted and the header will be ignored - If `proxy_protocol` is set to 1 and we receive a proxy header from an IP that is not in the list then the connection will be accepted and the header will be ignored
- If `proxy_protocol` is set to 2 and we receive a proxy header from an IP that is not in the list then the connection will be rejected - If `proxy_protocol` is set to 2 and we receive a proxy header from an IP that is not in the list then the connection will be rejected
- `post_connect_hook`, string. Absolute path to the command to execute or HTTP URL to notify. See [Post connect hook](./post-connect-hook.md) for more details. Leave empty to disable - `post_connect_hook`, string. Absolute path to the command to execute or HTTP URL to notify. See [Post connect hook](./post-connect-hook.md) for more details. Leave empty to disable
- `max_total_connections`, integer. Maximum number of concurrent client connections. 0 means unlimited
- **"sftpd"**, the configuration for the SFTP server - **"sftpd"**, the configuration for the SFTP server
- `bind_port`, integer. The port used for serving SFTP requests. 0 means disabled. Default: 2022 - `bind_port`, integer. The port used for serving SFTP requests. 0 means disabled. Default: 2022
- `bind_address`, string. Leave blank to listen on all available network interfaces. Default: "" - `bind_address`, string. Leave blank to listen on all available network interfaces. Default: ""

View File

@@ -502,6 +502,29 @@ func TestPostConnectHook(t *testing.T) {
common.Config.PostConnectHook = "" common.Config.PostConnectHook = ""
} }
func TestMaxConnections(t *testing.T) {
oldValue := common.Config.MaxTotalConnections
common.Config.MaxTotalConnections = 1
user, _, err := httpd.AddUser(getTestUser(), http.StatusOK)
assert.NoError(t, err)
client, err := getFTPClient(user, true)
if assert.NoError(t, err) {
err = checkBasicFTP(client)
assert.NoError(t, err)
_, err = getFTPClient(user, false)
assert.Error(t, err)
err = client.Quit()
assert.NoError(t, err)
}
_, err = httpd.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
common.Config.MaxTotalConnections = oldValue
}
func TestMaxSessions(t *testing.T) { func TestMaxSessions(t *testing.T) {
u := getTestUser() u := getTestUser()
u.MaxSessions = 1 u.MaxSessions = 1

View File

@@ -98,8 +98,12 @@ func (s *Server) GetSettings() (*ftpserver.Settings, error) {
// ClientConnected is called to send the very first welcome message // ClientConnected is called to send the very first welcome message
func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) { func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) {
if !common.Connections.IsNewConnectionAllowed() {
logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, configured limit reached")
return "", common.ErrConnectionDenied
}
if err := common.Config.ExecutePostConnectHook(cc.RemoteAddr().String(), common.ProtocolFTP); err != nil { if err := common.Config.ExecutePostConnectHook(cc.RemoteAddr().String(), common.ProtocolFTP); err != nil {
return common.ErrConnectionDenied.Error(), err return "", err
} }
connID := fmt.Sprintf("%v", cc.ID()) connID := fmt.Sprintf("%v", cc.ID())
user := dataprovider.User{} user := dataprovider.User{}

View File

@@ -277,23 +277,22 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
logger.Error(logSender, "", "panic in AcceptInboundConnection: %#v stack strace: %v", r, string(debug.Stack())) logger.Error(logSender, "", "panic in AcceptInboundConnection: %#v stack strace: %v", r, string(debug.Stack()))
} }
}() }()
if !common.Connections.IsNewConnectionAllowed() {
logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection refused, configured limit reached")
conn.Close()
return
}
// Before beginning a handshake must be performed on the incoming net.Conn // Before beginning a handshake must be performed on the incoming net.Conn
// we'll set a Deadline for handshake to complete, the default is 2 minutes as OpenSSH // we'll set a Deadline for handshake to complete, the default is 2 minutes as OpenSSH
conn.SetDeadline(time.Now().Add(handshakeTimeout)) //nolint:errcheck conn.SetDeadline(time.Now().Add(handshakeTimeout)) //nolint:errcheck
remoteAddr := conn.RemoteAddr() if err := common.Config.ExecutePostConnectHook(conn.RemoteAddr().String(), common.ProtocolSSH); err != nil {
if err := common.Config.ExecutePostConnectHook(remoteAddr.String(), common.ProtocolSSH); err != nil {
conn.Close() conn.Close()
return return
} }
sconn, chans, reqs, err := ssh.NewServerConn(conn, config) sconn, chans, reqs, err := ssh.NewServerConn(conn, config)
if err != nil { if err != nil {
logger.Debug(logSender, "", "failed to accept an incoming connection: %v", err) logger.Debug(logSender, "", "failed to accept an incoming connection: %v", err)
if _, ok := err.(*ssh.ServerAuthError); !ok { checkAuthError(conn, err)
ip := utils.GetIPFromRemoteAddress(remoteAddr.String())
logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTryed, common.ProtocolSSH, err.Error())
metrics.AddNoAuthTryed()
dataprovider.ExecutePostLoginHook("", dataprovider.LoginMethodNoAuthTryed, ip, common.ProtocolSSH, err)
}
return return
} }
// handshake completed so remove the deadline, we'll use IdleTimeout configuration from now on // handshake completed so remove the deadline, we'll use IdleTimeout configuration from now on
@@ -315,7 +314,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
logger.Log(logger.LevelInfo, common.ProtocolSSH, connectionID, logger.Log(logger.LevelInfo, common.ProtocolSSH, connectionID,
"User id: %d, logged in with: %#v, username: %#v, home_dir: %#v remote addr: %#v", "User id: %d, logged in with: %#v, username: %#v, home_dir: %#v remote addr: %#v",
user.ID, loginType, user.Username, user.HomeDir, remoteAddr.String()) user.ID, loginType, user.Username, user.HomeDir, conn.RemoteAddr().String())
dataprovider.UpdateLastLogin(user) //nolint:errcheck dataprovider.UpdateLastLogin(user) //nolint:errcheck
sshConnection := common.NewSSHConnection(connectionID, conn) sshConnection := common.NewSSHConnection(connectionID, conn)
@@ -354,13 +353,13 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
switch req.Type { switch req.Type {
case "subsystem": case "subsystem":
if string(req.Payload[4:]) == "sftp" { if string(req.Payload[4:]) == "sftp" {
fs, err := user.GetFilesystem(connectionID) fs, err := user.GetFilesystem(connID)
if err == nil { if err == nil {
ok = true ok = true
connection := Connection{ connection := Connection{
BaseConnection: common.NewBaseConnection(connID, common.ProtocolSFTP, user, fs), BaseConnection: common.NewBaseConnection(connID, common.ProtocolSFTP, user, fs),
ClientVersion: string(sconn.ClientVersion()), ClientVersion: string(sconn.ClientVersion()),
RemoteAddr: remoteAddr, RemoteAddr: conn.RemoteAddr(),
channel: channel, channel: channel,
} }
go c.handleSftpConnection(channel, &connection) go c.handleSftpConnection(channel, &connection)
@@ -368,12 +367,12 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve
} }
case "exec": case "exec":
// protocol will be set later inside processSSHCommand it could be SSH or SCP // protocol will be set later inside processSSHCommand it could be SSH or SCP
fs, err := user.GetFilesystem(connectionID) fs, err := user.GetFilesystem(connID)
if err == nil { if err == nil {
connection := Connection{ connection := Connection{
BaseConnection: common.NewBaseConnection(connID, "sshd_exec", user, fs), BaseConnection: common.NewBaseConnection(connID, "sshd_exec", user, fs),
ClientVersion: string(sconn.ClientVersion()), ClientVersion: string(sconn.ClientVersion()),
RemoteAddr: remoteAddr, RemoteAddr: conn.RemoteAddr(),
channel: channel, channel: channel,
} }
ok = processSSHCommand(req.Payload, &connection, c.EnabledSSHCommands) ok = processSSHCommand(req.Payload, &connection, c.EnabledSSHCommands)
@@ -420,6 +419,15 @@ func (c *Configuration) createHandler(connection *Connection) sftp.Handlers {
} }
} }
func checkAuthError(conn net.Conn, err error) {
if _, ok := err.(*ssh.ServerAuthError); !ok {
ip := utils.GetIPFromRemoteAddress(conn.RemoteAddr().String())
logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTryed, common.ProtocolSSH, err.Error())
metrics.AddNoAuthTryed()
dataprovider.ExecutePostLoginHook("", dataprovider.LoginMethodNoAuthTryed, ip, common.ProtocolSSH, err)
}
}
func checkRootPath(user *dataprovider.User, connectionID string) error { func checkRootPath(user *dataprovider.User, connectionID string) error {
if user.FsConfig.Provider != dataprovider.SFTPFilesystemProvider { if user.FsConfig.Provider != dataprovider.SFTPFilesystemProvider {
// for sftp fs check root path does nothing so don't open a useless SFTP connection // for sftp fs check root path does nothing so don't open a useless SFTP connection

View File

@@ -2441,6 +2441,31 @@ func TestQuotaDisabledError(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
} }
func TestMaxConnections(t *testing.T) {
oldValue := common.Config.MaxTotalConnections
common.Config.MaxTotalConnections = 1
usePubKey := true
u := getTestUser(usePubKey)
user, _, err := httpd.AddUser(u, http.StatusOK)
assert.NoError(t, err)
client, err := getSftpClient(user, usePubKey)
if assert.NoError(t, err) {
defer client.Close()
assert.NoError(t, checkBasicSFTP(client))
c, err := getSftpClient(user, usePubKey)
if !assert.Error(t, err, "max sessions exceeded, new login should not succeed") {
c.Close()
}
}
_, err = httpd.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
common.Config.MaxTotalConnections = oldValue
}
func TestMaxSessions(t *testing.T) { func TestMaxSessions(t *testing.T) {
usePubKey := false usePubKey := false
u := getTestUser(usePubKey) u := getTestUser(usePubKey)

View File

@@ -9,7 +9,8 @@
"setstat_mode": 0, "setstat_mode": 0,
"proxy_protocol": 0, "proxy_protocol": 0,
"proxy_allowed": [], "proxy_allowed": [],
"post_connect_hook": "" "post_connect_hook": "",
"max_total_connections": 0
}, },
"sftpd": { "sftpd": {
"bind_port": 2022, "bind_port": 2022,

View File

@@ -112,6 +112,11 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
http.Error(w, common.ErrGenericFailure.Error(), http.StatusInternalServerError) http.Error(w, common.ErrGenericFailure.Error(), http.StatusInternalServerError)
} }
}() }()
if !common.Connections.IsNewConnectionAllowed() {
logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, configured limit reached")
http.Error(w, common.ErrConnectionDenied.Error(), http.StatusServiceUnavailable)
return
}
checkRemoteAddress(r) checkRemoteAddress(r)
if err := common.Config.ExecutePostConnectHook(r.RemoteAddr, common.ProtocolWebDAV); err != nil { if err := common.Config.ExecutePostConnectHook(r.RemoteAddr, common.ProtocolWebDAV); err != nil {
http.Error(w, common.ErrConnectionDenied.Error(), http.StatusForbidden) http.Error(w, common.ErrConnectionDenied.Error(), http.StatusForbidden)

View File

@@ -650,6 +650,31 @@ func TestPostConnectHook(t *testing.T) {
common.Config.PostConnectHook = "" common.Config.PostConnectHook = ""
} }
func TestMaxConnections(t *testing.T) {
oldValue := common.Config.MaxTotalConnections
common.Config.MaxTotalConnections = 1
user, _, err := httpd.AddUser(getTestUser(), http.StatusOK)
assert.NoError(t, err)
client := getWebDavClient(user)
assert.NoError(t, checkBasicFunc(client))
// now add a fake connection
fs := vfs.NewOsFs("id", os.TempDir(), nil)
connection := &webdavd.Connection{
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user, fs),
}
common.Connections.Add(connection)
assert.Error(t, checkBasicFunc(client))
common.Connections.Remove(connection.GetID())
_, err = httpd.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
assert.Len(t, common.Connections.GetStats(), 0)
common.Config.MaxTotalConnections = oldValue
}
func TestMaxSessions(t *testing.T) { func TestMaxSessions(t *testing.T) {
u := getTestUser() u := getTestUser()
u.MaxSessions = 1 u.MaxSessions = 1