allow IPs in defender safe list to exceed max per-host connections

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino
2024-02-27 18:22:21 +01:00
parent 12f599fd65
commit 799fdd7098
9 changed files with 80 additions and 43 deletions

View File

@@ -383,13 +383,14 @@ func GetDefenderScore(ip string) (int, error) {
return Config.defender.GetScore(ip)
}
// AddDefenderEvent adds the specified defender event for the given IP
func AddDefenderEvent(ip, protocol string, event HostEvent) {
// AddDefenderEvent adds the specified defender event for the given IP.
// Returns true if the IP is in the defender's safe list.
func AddDefenderEvent(ip, protocol string, event HostEvent) bool {
if Config.defender == nil {
return
return false
}
Config.defender.AddEvent(ip, protocol, event)
return Config.defender.AddEvent(ip, protocol, event)
}
func startPeriodicChecks(duration time.Duration, isShared int) {
@@ -1191,9 +1192,12 @@ func (conns *ActiveConnections) IsNewConnectionAllowed(ipAddr, protocol string)
if Config.MaxPerHostConnections > 0 {
if total := conns.clients.getTotalFrom(ipAddr); total > Config.MaxPerHostConnections {
logger.Info(logSender, "", "active connections from %s %d/%d", ipAddr, total, Config.MaxPerHostConnections)
AddDefenderEvent(ipAddr, protocol, HostEventLimitExceeded)
return ErrConnectionDenied
if !AddDefenderEvent(ipAddr, protocol, HostEventLimitExceeded) {
logger.Warn(logSender, "", "connection denied, active connections from IP %q: %d/%d",
ipAddr, total, Config.MaxPerHostConnections)
return ErrConnectionDenied
}
logger.Info(logSender, "", "active connections from safe IP %q: %d", ipAddr, total)
}
}

View File

@@ -668,9 +668,26 @@ func TestConnectionRoles(t *testing.T) {
}
func TestMaxConnectionPerHost(t *testing.T) {
oldValue := Config.MaxPerHostConnections
defender, err := newInMemoryDefender(&DefenderConfig{
Enabled: true,
Driver: DefenderDriverMemory,
BanTime: 30,
BanTimeIncrement: 50,
Threshold: 15,
ScoreInvalid: 2,
ScoreValid: 1,
ScoreLimitExceeded: 3,
ObservationTime: 30,
EntriesSoftLimit: 100,
EntriesHardLimit: 150,
})
require.NoError(t, err)
oldMaxPerHostConn := Config.MaxPerHostConnections
oldDefender := Config.defender
Config.MaxPerHostConnections = 2
Config.defender = defender
ipAddr := "192.168.9.9"
Connections.AddClientConnection(ipAddr)
@@ -682,14 +699,30 @@ func TestMaxConnectionPerHost(t *testing.T) {
Connections.AddClientConnection(ipAddr)
assert.Error(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolFTP))
assert.Equal(t, int32(3), Connections.GetClientConnections())
// Add the IP to the defender safe list
entry := dataprovider.IPListEntry{
IPOrNet: ipAddr,
Type: dataprovider.IPListTypeDefender,
Mode: dataprovider.ListModeAllow,
}
err = dataprovider.AddIPListEntry(&entry, "", "", "")
assert.NoError(t, err)
Connections.AddClientConnection(ipAddr)
assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolSSH))
err = dataprovider.DeleteIPListEntry(entry.IPOrNet, dataprovider.IPListTypeDefender, "", "", "")
assert.NoError(t, err)
Connections.RemoveClientConnection(ipAddr)
Connections.RemoveClientConnection(ipAddr)
Connections.RemoveClientConnection(ipAddr)
Connections.RemoveClientConnection(ipAddr)
assert.Equal(t, int32(0), Connections.GetClientConnections())
Config.MaxPerHostConnections = oldValue
Config.MaxPerHostConnections = oldMaxPerHostConn
Config.defender = oldDefender
}
func TestIdleConnections(t *testing.T) {

View File

@@ -47,7 +47,7 @@ var (
type Defender interface {
GetHosts() ([]dataprovider.DefenderEntry, error)
GetHost(ip string) (dataprovider.DefenderEntry, error)
AddEvent(ip, protocol string, event HostEvent)
AddEvent(ip, protocol string, event HostEvent) bool
IsBanned(ip, protocol string) bool
IsSafe(ip, protocol string) bool
GetBanTime(ip string) (*time.Time, error)

View File

@@ -88,17 +88,18 @@ func (d *dbDefender) DeleteHost(ip string) bool {
}
// AddEvent adds an event for the given IP.
// This method must be called for clients not yet banned
func (d *dbDefender) AddEvent(ip, protocol string, event HostEvent) {
// This method must be called for clients not yet banned.
// Returns true if the IP is in the defender's safe list.
func (d *dbDefender) AddEvent(ip, protocol string, event HostEvent) bool {
if d.IsSafe(ip, protocol) {
return
return true
}
score := d.baseDefender.getScore(event)
host, err := dataprovider.AddDefenderEvent(ip, score, d.getStartObservationTime())
if err != nil {
return
return false
}
d.baseDefender.logEvent(ip, protocol, event, host.Score)
if host.Score > d.config.Threshold {
@@ -118,6 +119,7 @@ func (d *dbDefender) AddEvent(ip, protocol string, event HostEvent) {
if err == nil {
d.cleanup()
}
return false
}
// GetBanTime returns the ban time for the given IP or nil if the IP is not banned

View File

@@ -170,10 +170,11 @@ func (d *memoryDefender) DeleteHost(ip string) bool {
}
// AddEvent adds an event for the given IP.
// This method must be called for clients not yet banned
func (d *memoryDefender) AddEvent(ip, protocol string, event HostEvent) {
// This method must be called for clients not yet banned.
// Returns true if the IP is in the defender's safe list.
func (d *memoryDefender) AddEvent(ip, protocol string, event HostEvent) bool {
if d.IsSafe(ip, protocol) {
return
return true
}
d.Lock()
@@ -182,7 +183,7 @@ func (d *memoryDefender) AddEvent(ip, protocol string, event HostEvent) {
// ignore events for already banned hosts
if v, ok := d.banned[ip]; ok {
if v.After(time.Now()) {
return
return false
}
delete(d.banned, ip)
}
@@ -231,6 +232,7 @@ func (d *memoryDefender) AddEvent(ip, protocol string, event HostEvent) {
}
d.cleanupHosts()
}
return false
}
func (d *memoryDefender) countBanned() int {

View File

@@ -146,6 +146,14 @@ func getURLParam(r *http.Request, key string) string {
return unescaped
}
func getURLPath(r *http.Request) string {
rctx := chi.RouteContext(r.Context())
if rctx != nil && rctx.RoutePath != "" {
return rctx.RoutePath
}
return r.URL.Path
}
func getCommaSeparatedQueryParam(r *http.Request, key string) []string {
var result []string

View File

@@ -1220,25 +1220,13 @@ func (s *httpdServer) redirectToWebPath(w http.ResponseWriter, r *http.Request,
// The StripSlashes causes infinite redirects at the root path if used with http.FileServer.
// We also don't strip paths with more than one trailing slash, see #1434
func (s *httpdServer) mustStripSlash(r *http.Request) bool {
var urlPath string
rctx := chi.RouteContext(r.Context())
if rctx != nil && rctx.RoutePath != "" {
urlPath = rctx.RoutePath
} else {
urlPath = r.URL.Path
}
urlPath := getURLPath(r)
return !strings.HasSuffix(urlPath, "//") && !strings.HasPrefix(urlPath, webOpenAPIPath) &&
!strings.HasPrefix(urlPath, webStaticFilesPath) && !strings.HasPrefix(urlPath, acmeChallengeURI)
}
func (s *httpdServer) mustCheckPath(r *http.Request) bool {
var urlPath string
rctx := chi.RouteContext(r.Context())
if rctx != nil && rctx.RoutePath != "" {
urlPath = rctx.RoutePath
} else {
urlPath = r.URL.Path
}
urlPath := getURLPath(r)
return !strings.HasPrefix(urlPath, webStaticFilesPath) && !strings.HasPrefix(urlPath, acmeChallengeURI)
}