From 600268ebb8d9fcafd3e07461666a4bfba92c7863 Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Tue, 25 May 2021 08:36:01 +0200 Subject: [PATCH] httpclient: allow to set custom headers --- common/actions.go | 4 +- common/common.go | 6 +-- common/protocol_test.go | 3 +- config/config.go | 29 +++++++++++++ config/config_test.go | 71 +++++++++++++++++++++++++++++++ dataprovider/dataprovider.go | 42 +++++------------- docs/full-configuration.md | 4 ++ httpclient/httpclient.go | 82 +++++++++++++++++++++++++++++++++++- httpd/auth_utils.go | 4 +- httpd/httpd.go | 1 - sftpgo.json | 3 +- utils/utils.go | 2 +- 12 files changed, 204 insertions(+), 47 deletions(-) diff --git a/common/actions.go b/common/actions.go index 582772e1..3459b7ae 100644 --- a/common/actions.go +++ b/common/actions.go @@ -155,12 +155,10 @@ func (h *defaultActionHandler) handleHTTP(notification *ActionNotification) erro startTime := time.Now() respCode := 0 - httpClient := httpclient.GetRetraybleHTTPClient() - var b bytes.Buffer _ = json.NewEncoder(&b).Encode(notification) - resp, err := httpClient.Post(u.String(), "application/json", &b) + resp, err := httpclient.RetryablePost(Config.Actions.Hook, "application/json", &b) if err == nil { respCode = resp.StatusCode resp.Body.Close() diff --git a/common/common.go b/common/common.go index 8ec0980f..f811a2fb 100644 --- a/common/common.go +++ b/common/common.go @@ -420,8 +420,7 @@ func (c *Configuration) ExecuteStartupHook() error { return err } startTime := time.Now() - httpClient := httpclient.GetRetraybleHTTPClient() - resp, err := httpClient.Get(url.String()) + resp, err := httpclient.RetryableGet(url.String()) if err != nil { logger.Warn(logSender, "", "Error executing startup hook: %v", err) return err @@ -457,13 +456,12 @@ func (c *Configuration) ExecutePostConnectHook(ipAddr, protocol string) error { ipAddr, c.PostConnectHook, err) return err } - httpClient := httpclient.GetRetraybleHTTPClient() q := url.Query() q.Add("ip", ipAddr) q.Add("protocol", protocol) url.RawQuery = q.Encode() - resp, err := httpClient.Get(url.String()) + resp, err := httpclient.RetryableGet(url.String()) if err != nil { logger.Warn(protocol, "", "Login from ip %#v denied, error executing post connect hook: %v", ipAddr, err) return err diff --git a/common/protocol_test.go b/common/protocol_test.go index 9b5d50c9..726b5a69 100644 --- a/common/protocol_test.go +++ b/common/protocol_test.go @@ -2631,8 +2631,7 @@ func TestNonLocalCrossRenameNonLocalBaseUser(t *testing.T) { } func TestProxyProtocol(t *testing.T) { - httpClient := httpclient.GetHTTPClient() - resp, err := httpClient.Get(fmt.Sprintf("http://%v", httpProxyAddr)) + resp, err := httpclient.Get(fmt.Sprintf("http://%v", httpProxyAddr)) if assert.NoError(t, err) { defer resp.Body.Close() assert.Equal(t, http.StatusBadRequest, resp.StatusCode) diff --git a/config/config.go b/config/config.go index 2caef2ec..b31ce0f9 100644 --- a/config/config.go +++ b/config/config.go @@ -256,6 +256,7 @@ func Init() { CACertificates: nil, Certificates: nil, SkipTLSVerify: false, + Headers: nil, }, KMSConfig: kms.Configuration{ Secrets: kms.Secrets{ @@ -577,6 +578,7 @@ func loadBindingsFromEnv() { getWebDAVDBindingFromEnv(idx) getHTTPDBindingFromEnv(idx) getHTTPClientCertificatesFromEnv(idx) + getHTTPClientHeadersFromEnv(idx) } } @@ -889,6 +891,33 @@ func getHTTPClientCertificatesFromEnv(idx int) { } } +func getHTTPClientHeadersFromEnv(idx int) { + header := httpclient.Header{} + + key, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTP__HEADERS__%v__KEY", idx)) + if ok { + header.Key = key + } + + value, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTP__HEADERS__%v__VALUE", idx)) + if ok { + header.Value = value + } + + url, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTP__HEADERS__%v__URL", idx)) + if ok { + header.URL = url + } + + if header.Key != "" && header.Value != "" { + if len(globalConf.HTTPConfig.Headers) > idx { + globalConf.HTTPConfig.Headers[idx] = header + } else { + globalConf.HTTPConfig.Headers = append(globalConf.HTTPConfig.Headers, header) + } + } +} + func setViperDefaults() { viper.SetDefault("common.idle_timeout", globalConf.Common.IdleTimeout) viper.SetDefault("common.upload_mode", globalConf.Common.UploadMode) diff --git a/config/config_test.go b/config/config_test.go index 421ca35f..3f142349 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -751,6 +751,77 @@ func TestHTTPClientCertificatesFromEnv(t *testing.T) { require.Equal(t, "key9", config.GetHTTPConfig().Certificates[1].Key) } +func TestHTTPClientHeadersFromEnv(t *testing.T) { + reset() + + configDir := ".." + confName := tempConfigName + ".json" + configFilePath := filepath.Join(configDir, confName) + err := config.LoadConfig(configDir, "") + assert.NoError(t, err) + httpConf := config.GetHTTPConfig() + httpConf.Headers = append(httpConf.Headers, httpclient.Header{ + Key: "key", + Value: "value", + URL: "url", + }) + c := make(map[string]httpclient.Config) + c["http"] = httpConf + jsonConf, err := json.Marshal(c) + require.NoError(t, err) + err = os.WriteFile(configFilePath, jsonConf, os.ModePerm) + require.NoError(t, err) + err = config.LoadConfig(configDir, confName) + require.NoError(t, err) + require.Len(t, config.GetHTTPConfig().Headers, 1) + require.Equal(t, "key", config.GetHTTPConfig().Headers[0].Key) + require.Equal(t, "value", config.GetHTTPConfig().Headers[0].Value) + require.Equal(t, "url", config.GetHTTPConfig().Headers[0].URL) + + os.Setenv("SFTPGO_HTTP__HEADERS__0__KEY", "key0") + os.Setenv("SFTPGO_HTTP__HEADERS__0__VALUE", "value0") + os.Setenv("SFTPGO_HTTP__HEADERS__0__URL", "url0") + os.Setenv("SFTPGO_HTTP__HEADERS__8__KEY", "key8") + os.Setenv("SFTPGO_HTTP__HEADERS__9__KEY", "key9") + os.Setenv("SFTPGO_HTTP__HEADERS__9__VALUE", "value9") + os.Setenv("SFTPGO_HTTP__HEADERS__9__URL", "url9") + + t.Cleanup(func() { + os.Unsetenv("SFTPGO_HTTP__HEADERS__0__KEY") + os.Unsetenv("SFTPGO_HTTP__HEADERS__0__VALUE") + os.Unsetenv("SFTPGO_HTTP__HEADERS__0__URL") + os.Unsetenv("SFTPGO_HTTP__HEADERS__8__KEY") + os.Unsetenv("SFTPGO_HTTP__HEADERS__9__KEY") + os.Unsetenv("SFTPGO_HTTP__HEADERS__9__VALUE") + os.Unsetenv("SFTPGO_HTTP__HEADERS__9__URL") + }) + + err = config.LoadConfig(configDir, confName) + require.NoError(t, err) + require.Len(t, config.GetHTTPConfig().Headers, 2) + require.Equal(t, "key0", config.GetHTTPConfig().Headers[0].Key) + require.Equal(t, "value0", config.GetHTTPConfig().Headers[0].Value) + require.Equal(t, "url0", config.GetHTTPConfig().Headers[0].URL) + require.Equal(t, "key9", config.GetHTTPConfig().Headers[1].Key) + require.Equal(t, "value9", config.GetHTTPConfig().Headers[1].Value) + require.Equal(t, "url9", config.GetHTTPConfig().Headers[1].URL) + + err = os.Remove(configFilePath) + assert.NoError(t, err) + + config.Init() + + err = config.LoadConfig(configDir, "") + require.NoError(t, err) + require.Len(t, config.GetHTTPConfig().Headers, 2) + require.Equal(t, "key0", config.GetHTTPConfig().Headers[0].Key) + require.Equal(t, "value0", config.GetHTTPConfig().Headers[0].Value) + require.Equal(t, "url0", config.GetHTTPConfig().Headers[0].URL) + require.Equal(t, "key9", config.GetHTTPConfig().Headers[1].Key) + require.Equal(t, "value9", config.GetHTTPConfig().Headers[1].Value) + require.Equal(t, "url9", config.GetHTTPConfig().Headers[1].URL) +} + func TestConfigFromEnv(t *testing.T) { reset() diff --git a/dataprovider/dataprovider.go b/dataprovider/dataprovider.go index 6332de0a..91c8f6bf 100644 --- a/dataprovider/dataprovider.go +++ b/dataprovider/dataprovider.go @@ -1912,15 +1912,14 @@ func validateKeyboardAuthResponse(response keyboardAuthHookResponse) error { return nil } -func sendKeyboardAuthHTTPReq(url *url.URL, request keyboardAuthHookRequest) (keyboardAuthHookResponse, error) { +func sendKeyboardAuthHTTPReq(url string, request keyboardAuthHookRequest) (keyboardAuthHookResponse, error) { var response keyboardAuthHookResponse - httpClient := httpclient.GetHTTPClient() reqAsJSON, err := json.Marshal(request) if err != nil { providerLog(logger.LevelWarn, "error serializing keyboard interactive auth request: %v", err) return response, err } - resp, err := httpClient.Post(url.String(), "application/json", bytes.NewBuffer(reqAsJSON)) + resp, err := httpclient.Post(url, "application/json", bytes.NewBuffer(reqAsJSON)) if err != nil { providerLog(logger.LevelWarn, "error getting keyboard interactive auth hook HTTP response: %v", err) return response, err @@ -1935,12 +1934,6 @@ func sendKeyboardAuthHTTPReq(url *url.URL, request keyboardAuthHookRequest) (key func executeKeyboardInteractiveHTTPHook(user *User, authHook string, client ssh.KeyboardInteractiveChallenge, ip, protocol string) (int, error) { authResult := 0 - var url *url.URL - url, err := url.Parse(authHook) - if err != nil { - providerLog(logger.LevelWarn, "invalid url for keyboard interactive hook %#v, error: %v", authHook, err) - return authResult, err - } requestID := xid.New().String() req := keyboardAuthHookRequest{ Username: user.Username, @@ -1949,8 +1942,9 @@ func executeKeyboardInteractiveHTTPHook(user *User, authHook string, client ssh. RequestID: requestID, } var response keyboardAuthHookResponse + var err error for { - response, err = sendKeyboardAuthHTTPReq(url, req) + response, err = sendKeyboardAuthHTTPReq(authHook, req) if err != nil { return authResult, err } @@ -2120,12 +2114,6 @@ func isCheckPasswordHookDefined(protocol string) bool { func getPasswordHookResponse(username, password, ip, protocol string) ([]byte, error) { if strings.HasPrefix(config.CheckPasswordHook, "http") { var result []byte - var url *url.URL - url, err := url.Parse(config.CheckPasswordHook) - if err != nil { - providerLog(logger.LevelWarn, "invalid url for check password hook %#v, error: %v", config.CheckPasswordHook, err) - return result, err - } req := checkPasswordRequest{ Username: username, Password: password, @@ -2136,8 +2124,7 @@ func getPasswordHookResponse(username, password, ip, protocol string) ([]byte, e if err != nil { return result, err } - httpClient := httpclient.GetHTTPClient() - resp, err := httpClient.Post(url.String(), "application/json", bytes.NewBuffer(reqAsJSON)) + resp, err := httpclient.Post(config.CheckPasswordHook, "application/json", bytes.NewBuffer(reqAsJSON)) if err != nil { providerLog(logger.LevelWarn, "error getting check password hook response: %v", err) return result, err @@ -2192,8 +2179,8 @@ func getPreLoginHookResponse(loginMethod, ip, protocol string, userAsJSON []byte q.Add("ip", ip) q.Add("protocol", protocol) url.RawQuery = q.Encode() - httpClient := httpclient.GetHTTPClient() - resp, err := httpClient.Post(url.String(), "application/json", bytes.NewBuffer(userAsJSON)) + + resp, err := httpclient.Post(url.String(), "application/json", bytes.NewBuffer(userAsJSON)) if err != nil { providerLog(logger.LevelWarn, "error getting pre-login hook response: %v", err) return result, err @@ -2318,8 +2305,7 @@ func ExecutePostLoginHook(user *User, loginMethod, ip, protocol string, err erro startTime := time.Now() respCode := 0 - httpClient := httpclient.GetRetraybleHTTPClient() - resp, err := httpClient.Post(url.String(), "application/json", bytes.NewBuffer(userAsJSON)) + resp, err := httpclient.RetryablePost(url.String(), "application/json", bytes.NewBuffer(userAsJSON)) if err == nil { respCode = resp.StatusCode resp.Body.Close() @@ -2353,14 +2339,7 @@ func getExternalAuthResponse(username, password, pkey, keyboardInteractive, ip, } } if strings.HasPrefix(config.ExternalAuthHook, "http") { - var url *url.URL var result []byte - url, err := url.Parse(config.ExternalAuthHook) - if err != nil { - providerLog(logger.LevelWarn, "invalid url for external auth hook %#v, error: %v", config.ExternalAuthHook, err) - return result, err - } - httpClient := httpclient.GetHTTPClient() authRequest := make(map[string]string) authRequest["username"] = username authRequest["ip"] = ip @@ -2377,7 +2356,7 @@ func getExternalAuthResponse(username, password, pkey, keyboardInteractive, ip, providerLog(logger.LevelWarn, "error serializing external auth request: %v", err) return result, err } - resp, err := httpClient.Post(url.String(), "application/json", bytes.NewBuffer(authRequestAsJSON)) + resp, err := httpclient.Post(config.ExternalAuthHook, "application/json", bytes.NewBuffer(authRequestAsJSON)) if err != nil { providerLog(logger.LevelWarn, "error getting external auth hook HTTP response: %v", err) return result, err @@ -2561,8 +2540,7 @@ func executeAction(operation string, user *User) { q.Add("action", operation) url.RawQuery = q.Encode() startTime := time.Now() - httpClient := httpclient.GetRetraybleHTTPClient() - resp, err := httpClient.Post(url.String(), "application/json", bytes.NewBuffer(userAsJSON)) + resp, err := httpclient.RetryablePost(url.String(), "application/json", bytes.NewBuffer(userAsJSON)) respCode := 0 if err == nil { respCode = resp.StatusCode diff --git a/docs/full-configuration.md b/docs/full-configuration.md index b87d93e8..4dd4381f 100644 --- a/docs/full-configuration.md +++ b/docs/full-configuration.md @@ -248,6 +248,10 @@ The configuration file contains the following sections: - `cert`, string. Path to the certificate file. The path can be absolute or relative to the config dir. - `key`, string. Path to the key file. The path can be absolute or relative to the config dir. - `skip_tls_verify`, boolean. if enabled the HTTP client accepts any TLS certificate presented by the server and any host name in that certificate. In this mode, TLS is susceptible to man-in-the-middle attacks. This should be used only for testing. + - `headers`, list of structs. You can define a list of http headers to add to each hook. Each struct has the following fields: + - `key`, string + - `value`, string. The header is silently ignored if `key` or `value` are empty + - `url`, string, optional. If not empty, the header will be added only if the request URL starts with the one specified here - **kms**, configuration for the Key Management Service, more details can be found [here](./kms.md) - `secrets` - `url` diff --git a/httpclient/httpclient.go b/httpclient/httpclient.go index 2e75c215..c4dd0cb4 100644 --- a/httpclient/httpclient.go +++ b/httpclient/httpclient.go @@ -4,9 +4,11 @@ import ( "crypto/tls" "crypto/x509" "fmt" + "io" "net/http" "os" "path/filepath" + "strings" "time" "github.com/hashicorp/go-retryablehttp" @@ -21,6 +23,15 @@ type TLSKeyPair struct { Key string `json:"key" mapstructure:"key"` } +// Header defines an HTTP header. +// If the URL is not empty, the header is added only if the +// requested URL starts with the one specified +type Header struct { + Key string `json:"key" mapstructure:"key"` + Value string `json:"value" mapstructure:"value"` + URL string `json:"url" mapstructure:"url"` +} + // Config defines the configuration for HTTP clients. // HTTP clients are used for executing hooks such as the ones used for // custom actions, external authentication and pre-login user modifications @@ -44,7 +55,9 @@ type Config struct { // the server and any host name in that certificate. // In this mode, TLS is susceptible to man-in-the-middle attacks. // This should be used only for testing. - SkipTLSVerify bool `json:"skip_tls_verify" mapstructure:"skip_tls_verify"` + SkipTLSVerify bool `json:"skip_tls_verify" mapstructure:"skip_tls_verify"` + // Headers defines a list of http headers to add to each request + Headers []Header `json:"headers" mapstructure:"headers"` customTransport *http.Transport tlsConfig *tls.Config } @@ -76,6 +89,13 @@ func (c *Config) Initialize(configDir string) error { if err != nil { return err } + var headers []Header + for _, h := range c.Headers { + if h.Key != "" && h.Value != "" { + headers = append(headers, h) + } + } + c.Headers = headers httpConfig = *c return nil } @@ -162,3 +182,63 @@ func GetRetraybleHTTPClient() *retryablehttp.Client { return client } + +// Get issues a GET to the specified URL +func Get(url string) (*http.Response, error) { + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, err + } + addHeaders(req, url) + return GetHTTPClient().Do(req) +} + +// Post issues a POST to the specified URL +func Post(url string, contentType string, body io.Reader) (*http.Response, error) { + req, err := http.NewRequest(http.MethodPost, url, body) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", contentType) + addHeaders(req, url) + return GetHTTPClient().Do(req) +} + +// RetryableGet issues a GET to the specified URL using the retryable client +func RetryableGet(url string) (*http.Response, error) { + req, err := retryablehttp.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, err + } + addHeadersToRetryableReq(req, url) + return GetRetraybleHTTPClient().Do(req) +} + +// RetryablePost issues a POST to the specified URL using the retryable client +func RetryablePost(url string, contentType string, body io.Reader) (*http.Response, error) { + req, err := retryablehttp.NewRequest(http.MethodPost, url, body) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", contentType) + addHeadersToRetryableReq(req, url) + return GetRetraybleHTTPClient().Do(req) +} + +func addHeaders(req *http.Request, url string) { + for idx := range httpConfig.Headers { + h := &httpConfig.Headers[idx] + if h.URL == "" || strings.HasPrefix(url, h.URL) { + req.Header.Set(h.Key, h.Value) + } + } +} + +func addHeadersToRetryableReq(req *retryablehttp.Request, url string) { + for idx := range httpConfig.Headers { + h := &httpConfig.Headers[idx] + if h.URL == "" || strings.HasPrefix(url, h.URL) { + req.Header.Set(h.Key, h.Value) + } + } +} diff --git a/httpd/auth_utils.go b/httpd/auth_utils.go index 4dd2ff74..617aed75 100644 --- a/httpd/auth_utils.go +++ b/httpd/auth_utils.go @@ -31,8 +31,8 @@ const ( ) var ( - tokenDuration = 10 * time.Minute - tokenRefreshMin = 5 * time.Minute + tokenDuration = 15 * time.Minute + tokenRefreshMin = 10 * time.Minute ) type jwtTokenClaims struct { diff --git a/httpd/httpd.go b/httpd/httpd.go index 84ab9a31..8a9db285 100644 --- a/httpd/httpd.go +++ b/httpd/httpd.go @@ -1,7 +1,6 @@ // Package httpd implements REST API and Web interface for SFTPGo. // The OpenAPI 3 schema for the exposed API can be found inside the source tree: // https://github.com/drakkan/sftpgo/blob/main/httpd/schema/openapi.yaml -// A basic Web interface to manage users and connections is provided too package httpd import ( diff --git a/sftpgo.json b/sftpgo.json index b9c7cc1f..659457e8 100644 --- a/sftpgo.json +++ b/sftpgo.json @@ -217,7 +217,8 @@ "retry_max": 3, "ca_certificates": [], "certificates": [], - "skip_tls_verify": false + "skip_tls_verify": false, + "headers": [] }, "kms": { "secrets": { diff --git a/utils/utils.go b/utils/utils.go index bf8a4191..8f8e1abb 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -404,7 +404,7 @@ func createDirPathIfMissing(file string, perm os.FileMode) error { func GenerateRandomBytes(length int) []byte { b := make([]byte, length) _, err := io.ReadFull(rand.Reader, b) - if err != nil { + if err == nil { return b }