httpfs: add support for UNIX domain sockets

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino
2022-06-12 18:29:49 +02:00
parent 0b9a96ec6b
commit 6f4475ff72
7 changed files with 158 additions and 12 deletions

View File

@@ -9,6 +9,7 @@ import (
"io"
"io/fs"
"mime"
"net"
"net/http"
"net/url"
"os"
@@ -42,6 +43,10 @@ type HTTPFsConfig struct {
APIKey *kms.Secret `json:"api_key,omitempty"`
}
func (c *HTTPFsConfig) isUnixDomainSocket() bool {
return strings.HasPrefix(c.Endpoint, "http://unix") || strings.HasPrefix(c.Endpoint, "https://unix")
}
// HideConfidentialData hides confidential data
func (c *HTTPFsConfig) HideConfidentialData() {
if c.Password != nil {
@@ -95,13 +100,19 @@ func (c *HTTPFsConfig) validate() error {
return errors.New("httpfs: endpoint cannot be empty")
}
c.Endpoint = strings.TrimRight(c.Endpoint, "/")
_, err := url.Parse(c.Endpoint)
endpointURL, err := url.Parse(c.Endpoint)
if err != nil {
return fmt.Errorf("httpfs: invalid endpoint: %w", err)
}
if !util.IsStringPrefixInSlice(c.Endpoint, supportedEndpointSchema) {
return errors.New("httpfs: invalid endpoint schema: http and https are supported")
}
if endpointURL.Host == "unix" {
socketPath := endpointURL.Query().Get("socket_path")
if !filepath.IsAbs(socketPath) {
return fmt.Errorf("httpfs: invalid unix domain socket path: %q", socketPath)
}
}
if c.Password.IsEncrypted() && !c.Password.IsValid() {
return errors.New("httpfs: invalid encrypted password")
}
@@ -179,14 +190,43 @@ func NewHTTPFs(connectionID, localTempDir, mountPath string, config HTTPFsConfig
transport.MaxResponseHeaderBytes = 1 << 16
transport.WriteBufferSize = 1 << 16
transport.ReadBufferSize = 1 << 16
if fs.config.isUnixDomainSocket() {
endpointURL, err := url.Parse(fs.config.Endpoint)
if err != nil {
return nil, err
}
if endpointURL.Host == "unix" {
socketPath := endpointURL.Query().Get("socket_path")
if !filepath.IsAbs(socketPath) {
return nil, fmt.Errorf("httpfs: invalid unix domain socket path: %q", socketPath)
}
if endpointURL.Scheme == "https" {
transport.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
var tlsConfig *tls.Config
var d tls.Dialer
if config.SkipTLSVerify {
tlsConfig = getInsecureTLSConfig()
}
d.Config = tlsConfig
return d.DialContext(ctx, "unix", socketPath)
}
} else {
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
var d net.Dialer
return d.DialContext(ctx, "unix", socketPath)
}
}
endpointURL.Path = path.Join(endpointURL.Path, endpointURL.Query().Get("api_prefix"))
endpointURL.RawQuery = ""
endpointURL.RawFragment = ""
fs.config.Endpoint = endpointURL.String()
}
}
if config.SkipTLSVerify {
if transport.TLSClientConfig != nil {
transport.TLSClientConfig.InsecureSkipVerify = true
} else {
transport.TLSClientConfig = &tls.Config{
NextProtos: []string{"h2", "http/1.1"},
InsecureSkipVerify: true,
}
transport.TLSClientConfig = getInsecureTLSConfig()
}
}
fs.client = &http.Client{
@@ -646,6 +686,13 @@ func getErrorFromResponseCode(code int) error {
}
}
func getInsecureTLSConfig() *tls.Config {
return &tls.Config{
NextProtos: []string{"h2", "http/1.1"},
InsecureSkipVerify: true,
}
}
type wrapReader struct {
reader io.Reader
}