httpd: allow to restrict allowed hosts ...

... and to add security headers to the responses

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino
2022-02-17 18:22:27 +01:00
parent 876bf8aa4f
commit f1a255aa6c
15 changed files with 415 additions and 24 deletions

View File

@@ -20,6 +20,7 @@ import (
"github.com/rs/cors"
"github.com/rs/xid"
"github.com/sftpgo/sdk"
"github.com/unrolled/secure"
"github.com/drakkan/sftpgo/v2/common"
"github.com/drakkan/sftpgo/v2/dataprovider"
@@ -939,6 +940,7 @@ func (s *httpdServer) checkConnection(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
ip := net.ParseIP(ipAddr)
areHeadersAllowed := false
if ip != nil {
for _, allow := range s.binding.allowHeadersFrom {
if allow(ip) {
@@ -951,10 +953,16 @@ func (s *httpdServer) checkConnection(next http.Handler) http.Handler {
ctx := context.WithValue(r.Context(), forwardedProtoKey, forwardedProto)
r = r.WithContext(ctx)
}
areHeadersAllowed = true
break
}
}
}
if !areHeadersAllowed {
for idx := range s.binding.Security.proxyHeaders {
r.Header.Del(s.binding.Security.proxyHeaders[idx])
}
}
common.Connections.AddClientConnection(ipAddr)
defer common.Connections.RemoveClientConnection(ipAddr)
@@ -1008,6 +1016,17 @@ func (s *httpdServer) sendForbiddenResponse(w http.ResponseWriter, r *http.Reque
sendAPIResponse(w, r, errors.New(message), message, http.StatusForbidden)
}
func (s *httpdServer) badHostHandler(w http.ResponseWriter, r *http.Request) {
host := r.Host
for _, header := range s.binding.Security.HostsProxyHeaders {
if h := r.Header.Get(header); h != "" {
host = h
break
}
}
s.sendForbiddenResponse(w, r, fmt.Sprintf("The host %#v is not allowed", host))
}
func (s *httpdServer) redirectToWebPath(w http.ResponseWriter, r *http.Request, webPath string) {
if dataprovider.HasAdmin() {
http.Redirect(w, r, webPath, http.StatusFound)
@@ -1037,6 +1056,24 @@ func (s *httpdServer) initializeRouter() {
s.router.Use(s.checkConnection)
s.router.Use(logger.NewStructuredLogger(logger.GetLogger()))
s.router.Use(middleware.Recoverer)
if s.binding.Security.Enabled {
secureMiddleware := secure.New(secure.Options{
AllowedHosts: s.binding.Security.AllowedHosts,
AllowedHostsAreRegex: s.binding.Security.AllowedHostsAreRegex,
HostsProxyHeaders: s.binding.Security.HostsProxyHeaders,
SSLProxyHeaders: s.binding.Security.getHTTPSProxyHeaders(),
STSSeconds: s.binding.Security.STSSeconds,
STSIncludeSubdomains: s.binding.Security.STSIncludeSubdomains,
STSPreload: s.binding.Security.STSPreload,
ContentTypeNosniff: s.binding.Security.ContentTypeNosniff,
ContentSecurityPolicy: s.binding.Security.ContentSecurityPolicy,
PermissionsPolicy: s.binding.Security.PermissionsPolicy,
CrossOriginOpenerPolicy: s.binding.Security.CrossOriginOpenerPolicy,
ExpectCTHeader: s.binding.Security.ExpectCTHeader,
})
secureMiddleware.SetBadHostHandler(http.HandlerFunc(s.badHostHandler))
s.router.Use(secureMiddleware.Handler)
}
if s.cors.Enabled {
c := cors.New(cors.Options{
AllowedOrigins: s.cors.AllowedOrigins,