From f45c89fc46c70ee74042153d592956515171042b Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Mon, 19 Apr 2021 08:14:04 +0200 Subject: [PATCH] add rate limiting support for REST API/web admin too --- common/common.go | 9 +++++---- common/common_test.go | 17 +++++++++-------- common/ratelimiter.go | 10 +++++----- common/ratelimiter_test.go | 20 ++++++++++---------- config/config.go | 2 +- config/config_test.go | 3 ++- docs/full-configuration.md | 2 +- docs/rate-limiting.md | 16 ++++++++++++---- ftpd/server.go | 3 ++- httpd/httpd_test.go | 38 ++++++++++++++++++++++++++++++++++++++ httpd/middleware.go | 14 ++++++++++++++ httpd/server.go | 3 ++- sftpd/server.go | 3 ++- sftpgo.json | 3 ++- webdavd/server.go | 5 ++++- 15 files changed, 109 insertions(+), 39 deletions(-) diff --git a/common/common.go b/common/common.go index df21949f..2efb32d4 100644 --- a/common/common.go +++ b/common/common.go @@ -70,6 +70,7 @@ const ( ProtocolSSH = "SSH" ProtocolFTP = "FTP" ProtocolWebDAV = "DAV" + ProtocolHTTP = "HTTP" ) // Upload modes @@ -144,14 +145,14 @@ func Initialize(c Configuration) error { // allow one event to happen. // It returns an error if the time to wait exceeds the max // allowed delay -func LimitRate(protocol, ip string) error { +func LimitRate(protocol, ip string) (time.Duration, error) { for _, limiter := range rateLimiters[protocol] { - if err := limiter.Wait(ip); err != nil { + if delay, err := limiter.Wait(ip); err != nil { logger.Debug(logSender, "", "protocol %v ip %v: %v", protocol, ip, err) - return err + return delay, err } } - return nil + return 0, nil } // ReloadDefender reloads the defender's block and safe lists diff --git a/common/common_test.go b/common/common_test.go index f4ba5c58..6b486576 100644 --- a/common/common_test.go +++ b/common/common_test.go @@ -194,30 +194,31 @@ func TestRateLimitersIntegration(t *testing.T) { err = Initialize(Config) assert.NoError(t, err) - assert.Len(t, rateLimiters, 3) + assert.Len(t, rateLimiters, 4) assert.Len(t, rateLimiters[ProtocolSSH], 1) assert.Len(t, rateLimiters[ProtocolFTP], 2) assert.Len(t, rateLimiters[ProtocolWebDAV], 2) + assert.Len(t, rateLimiters[ProtocolHTTP], 1) source1 := "127.1.1.1" source2 := "127.1.1.2" - err = LimitRate(ProtocolSSH, source1) + _, err = LimitRate(ProtocolSSH, source1) assert.NoError(t, err) - err = LimitRate(ProtocolFTP, source1) + _, err = LimitRate(ProtocolFTP, source1) assert.NoError(t, err) // sleep to allow the add configured burst to the token. // This sleep is not enough to add the per-source burst time.Sleep(20 * time.Millisecond) - err = LimitRate(ProtocolWebDAV, source2) + _, err = LimitRate(ProtocolWebDAV, source2) assert.NoError(t, err) - err = LimitRate(ProtocolFTP, source1) + _, err = LimitRate(ProtocolFTP, source1) assert.Error(t, err) - err = LimitRate(ProtocolWebDAV, source2) + _, err = LimitRate(ProtocolWebDAV, source2) assert.Error(t, err) - err = LimitRate(ProtocolSSH, source1) + _, err = LimitRate(ProtocolSSH, source1) assert.NoError(t, err) - err = LimitRate(ProtocolSSH, source2) + _, err = LimitRate(ProtocolSSH, source2) assert.NoError(t, err) Config = configCopy diff --git a/common/ratelimiter.go b/common/ratelimiter.go index bc7a73e7..4104b957 100644 --- a/common/ratelimiter.go +++ b/common/ratelimiter.go @@ -16,7 +16,7 @@ import ( var ( errNoBucket = errors.New("no bucket found") errReserve = errors.New("unable to reserve token") - rateLimiterProtocolValues = []string{ProtocolSSH, ProtocolFTP, ProtocolWebDAV} + rateLimiterProtocolValues = []string{ProtocolSSH, ProtocolFTP, ProtocolWebDAV, ProtocolHTTP} ) // RateLimiterType defines the supported rate limiters types @@ -130,7 +130,7 @@ type rateLimiter struct { // Wait blocks until the limit allows one event to happen // or returns an error if the time to wait exceeds the max // allowed delay -func (rl *rateLimiter) Wait(source string) error { +func (rl *rateLimiter) Wait(source string) (time.Duration, error) { var res *rate.Reservation if rl.globalBucket != nil { res = rl.globalBucket.Reserve() @@ -143,7 +143,7 @@ func (rl *rateLimiter) Wait(source string) error { } } if !res.OK() { - return errReserve + return 0, errReserve } delay := res.Delay() if delay > rl.maxDelay { @@ -151,10 +151,10 @@ func (rl *rateLimiter) Wait(source string) error { if rl.generateDefenderEvents && rl.globalBucket == nil { AddDefenderEvent(source, HostEventRateExceeded) } - return fmt.Errorf("rate limit exceed, wait time to respect rate %v, max wait time allowed %v", delay, rl.maxDelay) + return delay, fmt.Errorf("rate limit exceed, wait time to respect rate %v, max wait time allowed %v", delay, rl.maxDelay) } time.Sleep(delay) - return nil + return 0, nil } type sourceRateLimiter struct { diff --git a/common/ratelimiter_test.go b/common/ratelimiter_test.go index 7a24740e..7955c0fd 100644 --- a/common/ratelimiter_test.go +++ b/common/ratelimiter_test.go @@ -63,9 +63,9 @@ func TestRateLimiter(t *testing.T) { Protocols: rateLimiterProtocolValues, } limiter := config.getLimiter() - err := limiter.Wait("") + _, err := limiter.Wait("") require.NoError(t, err) - err = limiter.Wait("") + _, err = limiter.Wait("") require.Error(t, err) config.Type = int(rateLimiterTypeSource) @@ -75,17 +75,17 @@ func TestRateLimiter(t *testing.T) { limiter = config.getLimiter() source := "192.168.1.2" - err = limiter.Wait(source) + _, err = limiter.Wait(source) require.NoError(t, err) - err = limiter.Wait(source) + _, err = limiter.Wait(source) require.Error(t, err) // a different source should work - err = limiter.Wait(source + "1") + _, err = limiter.Wait(source + "1") require.NoError(t, err) config.Burst = 0 limiter = config.getLimiter() - err = limiter.Wait(source) + _, err = limiter.Wait(source) require.ErrorIs(t, err, errReserve) } @@ -104,10 +104,10 @@ func TestLimiterCleanup(t *testing.T) { source2 := "10.8.0.2" source3 := "10.8.0.3" source4 := "10.8.0.4" - err := limiter.Wait(source1) + _, err := limiter.Wait(source1) assert.NoError(t, err) time.Sleep(20 * time.Millisecond) - err = limiter.Wait(source2) + _, err = limiter.Wait(source2) assert.NoError(t, err) time.Sleep(20 * time.Millisecond) assert.Len(t, limiter.buckets.buckets, 2) @@ -115,7 +115,7 @@ func TestLimiterCleanup(t *testing.T) { assert.True(t, ok) _, ok = limiter.buckets.buckets[source2] assert.True(t, ok) - err = limiter.Wait(source3) + _, err = limiter.Wait(source3) assert.NoError(t, err) assert.Len(t, limiter.buckets.buckets, 3) _, ok = limiter.buckets.buckets[source1] @@ -125,7 +125,7 @@ func TestLimiterCleanup(t *testing.T) { _, ok = limiter.buckets.buckets[source3] assert.True(t, ok) time.Sleep(20 * time.Millisecond) - err = limiter.Wait(source4) + _, err = limiter.Wait(source4) assert.NoError(t, err) assert.Len(t, limiter.buckets.buckets, 2) _, ok = limiter.buckets.buckets[source3] diff --git a/config/config.go b/config/config.go index 71dd5e9c..1a4512a7 100644 --- a/config/config.go +++ b/config/config.go @@ -74,7 +74,7 @@ var ( Period: 1000, Burst: 1, Type: 2, - Protocols: []string{common.ProtocolSSH, common.ProtocolFTP, common.ProtocolWebDAV}, + Protocols: []string{common.ProtocolSSH, common.ProtocolFTP, common.ProtocolWebDAV, common.ProtocolHTTP}, GenerateDefenderEvents: false, EntriesSoftLimit: 100, EntriesHardLimit: 150, diff --git a/config/config_test.go b/config/config_test.go index 59c66138..408194e9 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -474,10 +474,11 @@ func TestRateLimitersFromEnv(t *testing.T) { require.Equal(t, 1, limiters[1].Burst) require.Equal(t, 2, limiters[1].Type) protocols = limiters[1].Protocols - require.Len(t, protocols, 3) + require.Len(t, protocols, 4) require.True(t, utils.IsStringInSlice(common.ProtocolFTP, protocols)) require.True(t, utils.IsStringInSlice(common.ProtocolSSH, protocols)) require.True(t, utils.IsStringInSlice(common.ProtocolWebDAV, protocols)) + require.True(t, utils.IsStringInSlice(common.ProtocolHTTP, protocols)) require.False(t, limiters[1].GenerateDefenderEvents) require.Equal(t, 100, limiters[1].EntriesSoftLimit) require.Equal(t, 150, limiters[1].EntriesHardLimit) diff --git a/docs/full-configuration.md b/docs/full-configuration.md index ccb5232a..38eab73c 100644 --- a/docs/full-configuration.md +++ b/docs/full-configuration.md @@ -83,7 +83,7 @@ The configuration file contains the following sections: - `period`, integer. Period defines the period as milliseconds. The rate is actually defined by dividing average by period Default: 1000 (1 second). - `burst`, integer. Burst defines the maximum number of requests allowed to go through in the same arbitrarily small period of time. Default: 1 - `type`, integer. 1 means a global rate limiter, independent from the source host. 2 means a per-ip rate limiter. Default: 2 - - `protocols`, list of strings. Available protocols are `SSH`, `FTP`, `DAV`. By default all supported protocols are enabled + - `protocols`, list of strings. Available protocols are `SSH`, `FTP`, `DAV`, `HTTP`. By default all supported protocols are enabled - `generate_defender_events`, boolean. If `true`, the defender is enabled, and this is not a global rate limiter, a new defender event will be generated each time the configured limit is exceeded. Default `false` - `entries_soft_limit`, integer. - `entries_hard_limit`, integer. The number of per-ip rate limiters kept in memory will vary between the soft and hard limit diff --git a/docs/rate-limiting.md b/docs/rate-limiting.md index 2741ab43..b2c9dd4f 100644 --- a/docs/rate-limiting.md +++ b/docs/rate-limiting.md @@ -1,6 +1,6 @@ # Rate limiting -Rate limiting allows to control the number of requests going to the configured services. +Rate limiting allows to control the number of requests going to the SFTPGo services. SFTPGo implements a [token bucket](https://en.wikipedia.org/wiki/Token_bucket) initially full and refilled at the configured rate. The `burst` configuration parameter defines the size of the bucket. The rate is defined by dividing `average` by `period`, so for a rate below 1 req/s, one needs to define a period larger than a second. @@ -8,9 +8,16 @@ Requests that exceed the configured limit will be delayed or denied if they exce SFTPGo allows to define per-protocol rate limiters so you can have different configurations for different protocols. +The supported protocols are: + +- `SSH`, includes SFTP and SSH commands +- `FTP`, includes FTP, FTPES, FTPS +- `DAV`, WebDAV +- `HTTP`, REST API and web admin + You can also define two types of rate limiters: -- global, it is independent from the source host and therefore define a limit for the configured protocol/s +- global, it is independent from the source host and therefore define an aggregate limit for the configured protocol/s - per-host, this type of rate limiter can be connected to the built-in [defender](./defender.md) and generate `score_rate_exceeded` events and thus hosts that repeatedly exceed the configured limit can be automatically blocked If you configure a per-host rate limiter, SFTPGo will keep a rate limiter in memory for each host that connects to the service, you can limit the memory usage using the `entries_soft_limit` and `entries_hard_limit` configuration keys. @@ -27,7 +34,8 @@ You can defines how many rate limiters as you want, but keep in mind that if you "protocols": [ "SSH", "FTP", - "DAV" + "DAV", + "HTTP" ], "generate_defender_events": false, "entries_soft_limit": 100, @@ -48,6 +56,6 @@ You can defines how many rate limiters as you want, but keep in mind that if you ] ``` -we have a global rate limiter that limit the rate for the whole service to 100 req/s and an additional rate limiter that limits the `FTP` protocol to 10 req/s per host. +we have a global rate limiter that limit the aggregate rate for the all the services to 100 req/s and an additional rate limiter that limits the `FTP` protocol to 10 req/s per host. With this configuration, when a client connects via FTP it will be limited first by the global rate limiter and then by the per host rate limiter. Clients connecting via SFTP/WebDAV will be checked only against the global rate limiter. diff --git a/ftpd/server.go b/ftpd/server.go index 99c946f6..13ac646f 100644 --- a/ftpd/server.go +++ b/ftpd/server.go @@ -144,7 +144,8 @@ func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) { logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, configured limit reached") return "Access denied: max allowed connection exceeded", common.ErrConnectionDenied } - if err := common.LimitRate(common.ProtocolFTP, ipAddr); err != nil { + _, err := common.LimitRate(common.ProtocolFTP, ipAddr) + if err != nil { return fmt.Sprintf("Access denied: %v", err.Error()), err } if err := common.Config.ExecutePostConnectHook(ipAddr, common.ProtocolFTP); err != nil { diff --git a/httpd/httpd_test.go b/httpd/httpd_test.go index a56cdc79..82d9fc0b 100644 --- a/httpd/httpd_test.go +++ b/httpd/httpd_test.go @@ -3118,6 +3118,44 @@ func TestLoaddataMode(t *testing.T) { assert.NoError(t, err) } +func TestRateLimiter(t *testing.T) { + oldConfig := config.GetCommonConfig() + + cfg := config.GetCommonConfig() + cfg.RateLimitersConfig = []common.RateLimiterConfig{ + { + Average: 1, + Period: 1000, + Burst: 1, + Type: 1, + Protocols: []string{common.ProtocolHTTP}, + }, + } + + err := common.Initialize(cfg) + assert.NoError(t, err) + + client := &http.Client{ + Timeout: 5 * time.Second, + } + resp, err := client.Get(httpBaseURL + healthzPath) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + err = resp.Body.Close() + assert.NoError(t, err) + + resp, err = client.Get(httpBaseURL + healthzPath) + assert.NoError(t, err) + assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + assert.NotEmpty(t, resp.Header.Get("Retry-After")) + assert.NotEmpty(t, resp.Header.Get("X-Retry-In")) + err = resp.Body.Close() + assert.NoError(t, err) + + err = common.Initialize(oldConfig) + assert.NoError(t, err) +} + func TestHTTPSConnection(t *testing.T) { client := &http.Client{ Timeout: 5 * time.Second, diff --git a/httpd/middleware.go b/httpd/middleware.go index cb8fec5c..e6d50bef 100644 --- a/httpd/middleware.go +++ b/httpd/middleware.go @@ -3,11 +3,13 @@ package httpd import ( "context" "errors" + "fmt" "net/http" "github.com/go-chi/jwtauth/v5" "github.com/lestrrat-go/jwx/jwt" + "github.com/drakkan/sftpgo/common" "github.com/drakkan/sftpgo/logger" "github.com/drakkan/sftpgo/utils" ) @@ -141,3 +143,15 @@ func verifyCSRFHeader(next http.Handler) http.Handler { next.ServeHTTP(w, r) }) } + +func rateLimiter(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if delay, err := common.LimitRate(common.ProtocolHTTP, utils.GetIPFromRemoteAddress(r.RemoteAddr)); err != nil { + w.Header().Set("Retry-After", fmt.Sprintf("%.0f", delay.Seconds())) + w.Header().Set("X-Retry-In", delay.String()) + sendAPIResponse(w, r, err, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) + return + } + next.ServeHTTP(w, r) + }) +} diff --git a/httpd/server.go b/httpd/server.go index 48a24d9d..f1b26f5f 100644 --- a/httpd/server.go +++ b/httpd/server.go @@ -259,6 +259,8 @@ func (s *httpdServer) initializeRouter() { s.router.Use(saveConnectionAddress) s.router.Use(middleware.GetHead) s.router.Use(middleware.StripSlashes) + s.router.Use(middleware.RealIP) + s.router.Use(rateLimiter) s.router.Group(func(r chi.Router) { r.Get(healthzPath, func(w http.ResponseWriter, r *http.Request) { @@ -268,7 +270,6 @@ func (s *httpdServer) initializeRouter() { s.router.Group(func(router chi.Router) { router.Use(middleware.RequestID) - router.Use(middleware.RealIP) router.Use(logger.NewStructuredLogger(logger.GetLogger())) router.Use(middleware.Recoverer) diff --git a/sftpd/server.go b/sftpd/server.go index 74c2d826..ba6e37e4 100644 --- a/sftpd/server.go +++ b/sftpd/server.go @@ -360,7 +360,8 @@ func canAcceptConnection(ip string) bool { logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection refused, configured limit reached") return false } - if err := common.LimitRate(common.ProtocolSSH, ip); err != nil { + _, err := common.LimitRate(common.ProtocolSSH, ip) + if err != nil { return false } if err := common.Config.ExecutePostConnectHook(ip, common.ProtocolSSH); err != nil { diff --git a/sftpgo.json b/sftpgo.json index 6d61b1be..eb2b6d79 100644 --- a/sftpgo.json +++ b/sftpgo.json @@ -35,7 +35,8 @@ "protocols": [ "SSH", "FTP", - "DAV" + "DAV", + "HTTP" ], "generate_defender_events": false, "entries_soft_limit": 100, diff --git a/webdavd/server.go b/webdavd/server.go index 5a076701..2aca18f1 100644 --- a/webdavd/server.go +++ b/webdavd/server.go @@ -158,7 +158,10 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.Error(w, common.ErrConnectionDenied.Error(), http.StatusForbidden) return } - if err := common.LimitRate(common.ProtocolWebDAV, ipAddr); err != nil { + delay, err := common.LimitRate(common.ProtocolWebDAV, ipAddr) + if err != nil { + w.Header().Set("Retry-After", fmt.Sprintf("%.0f", delay.Seconds())) + w.Header().Set("X-Retry-In", delay.String()) http.Error(w, err.Error(), http.StatusTooManyRequests) return }