Add support for graceful shutdown

Fixes #1014

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino
2022-10-22 11:56:41 +02:00
parent 87045284cc
commit db0e58ae7e
32 changed files with 371 additions and 124 deletions

View File

@@ -135,6 +135,7 @@ var (
ErrNoCredentials = errors.New("no credential provided")
ErrInternalFailure = errors.New("internal failure")
ErrTransferAborted = errors.New("transfer aborted")
ErrShuttingDown = errors.New("the service is shutting down")
errNoTransfer = errors.New("requested transfer not found")
errTransferMismatch = errors.New("transfer mismatch")
)
@@ -153,11 +154,13 @@ var (
ProtocolHTTP, ProtocolHTTPShare, ProtocolOIDC}
disconnHookProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP}
// the map key is the protocol, for each protocol we can have multiple rate limiters
rateLimiters map[string][]*rateLimiter
rateLimiters map[string][]*rateLimiter
isShuttingDown atomic.Bool
)
// Initialize sets the common configuration
func Initialize(c Configuration, isShared int) error {
isShuttingDown.Store(false)
Config = c
Config.Actions.ExecuteOn = util.RemoveDuplicates(Config.Actions.ExecuteOn, true)
Config.Actions.ExecuteSync = util.RemoveDuplicates(Config.Actions.ExecuteSync, true)
@@ -220,6 +223,67 @@ func Initialize(c Configuration, isShared int) error {
return nil
}
// CheckClosing returns an error if the service is closing
func CheckClosing() error {
if isShuttingDown.Load() {
return ErrShuttingDown
}
return nil
}
// WaitForTransfers waits, for the specified grace time, for currently ongoing
// client-initiated transfer sessions to completes.
// A zero graceTime means no wait
func WaitForTransfers(graceTime int) {
if graceTime == 0 {
return
}
if isShuttingDown.Swap(true) {
return
}
if activeHooks.Load() == 0 && getActiveConnections() == 0 {
return
}
graceTimer := time.NewTimer(time.Duration(graceTime) * time.Second)
ticker := time.NewTicker(3 * time.Second)
for {
select {
case <-ticker.C:
hooks := activeHooks.Load()
logger.Info(logSender, "", "active hooks: %d", hooks)
if hooks == 0 && getActiveConnections() == 0 {
logger.Info(logSender, "", "no more active connections, graceful shutdown")
ticker.Stop()
graceTimer.Stop()
return
}
case <-graceTimer.C:
logger.Info(logSender, "", "grace time expired, hard shutdown")
ticker.Stop()
return
}
}
}
// getActiveConnections returns the number of connections with active transfers
func getActiveConnections() int {
var activeConns int
Connections.RLock()
for _, c := range Connections.connections {
if len(c.GetTransfers()) > 0 {
activeConns++
}
}
Connections.RUnlock()
logger.Info(logSender, "", "number of connections with active transfers: %d", activeConns)
return activeConns
}
// LimitRate blocks until all the configured rate limiters
// allow one event to happen.
// It returns an error if the time to wait exceeds the max
@@ -1051,30 +1115,34 @@ func (conns *ActiveConnections) GetClientConnections() int32 {
return conns.clients.getTotal()
}
// IsNewConnectionAllowed returns false if the maximum number of concurrent allowed connections is exceeded
// or a whitelist is defined and the specified ipAddr is not listed
func (conns *ActiveConnections) IsNewConnectionAllowed(ipAddr string) bool {
// IsNewConnectionAllowed returns an error if the maximum number of concurrent allowed
// connections is exceeded or a whitelist is defined and the specified ipAddr is not listed
// or the service is shutting down
func (conns *ActiveConnections) IsNewConnectionAllowed(ipAddr string) error {
if isShuttingDown.Load() {
return ErrShuttingDown
}
if Config.whitelist != nil {
if !Config.whitelist.isAllowed(ipAddr) {
return false
return ErrConnectionDenied
}
}
if Config.MaxTotalConnections == 0 && Config.MaxPerHostConnections == 0 {
return true
return nil
}
if Config.MaxPerHostConnections > 0 {
if total := conns.clients.getTotalFrom(ipAddr); total > Config.MaxPerHostConnections {
logger.Debug(logSender, "", "active connections from %v %v/%v", ipAddr, total, Config.MaxPerHostConnections)
logger.Info(logSender, "", "active connections from %s %d/%d", ipAddr, total, Config.MaxPerHostConnections)
AddDefenderEvent(ipAddr, HostEventLimitExceeded)
return false
return ErrConnectionDenied
}
}
if Config.MaxTotalConnections > 0 {
if total := conns.clients.getTotal(); total > int32(Config.MaxTotalConnections) {
logger.Debug(logSender, "", "active client connections %v/%v", total, Config.MaxTotalConnections)
return false
logger.Info(logSender, "", "active client connections %d/%d", total, Config.MaxTotalConnections)
return ErrConnectionDenied
}
// on a single SFTP connection we could have multiple SFTP channels or commands
@@ -1083,10 +1151,13 @@ func (conns *ActiveConnections) IsNewConnectionAllowed(ipAddr string) bool {
conns.RLock()
defer conns.RUnlock()
return len(conns.connections) < Config.MaxTotalConnections
if sess := len(conns.connections); sess >= Config.MaxTotalConnections {
logger.Info(logSender, "", "active client sessions %d/%d", sess, Config.MaxTotalConnections)
return ErrConnectionDenied
}
}
return true
return nil
}
// GetStats returns stats for active connections