webdav: refactor server initialization

This commit is contained in:
Nicola Murino
2021-01-03 09:51:54 +01:00
parent 1e1c46ae1b
commit 4b522a2455
7 changed files with 67 additions and 62 deletions

View File

@@ -177,6 +177,7 @@ func (c *Configuration) Initialize(configDir string) error {
go func(s *Server) { go func(s *Server) {
ftpServer := ftpserver.NewFtpServer(s) ftpServer := ftpserver.NewFtpServer(s)
logger.Info(logSender, "", "starting FTP serving, binding: %v", s.binding.GetAddress())
exitChannel <- ftpServer.ListenAndServe() exitChannel <- ftpServer.ListenAndServe()
}(server) }(server)

View File

@@ -165,9 +165,9 @@ func (c Conf) Initialize(configDir string) error {
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
} }
httpServer.TLSConfig = config httpServer.TLSConfig = config
return utils.HTTPListenAndServe(httpServer, c.BindAddress, c.BindPort, true) return utils.HTTPListenAndServe(httpServer, c.BindAddress, c.BindPort, true, logSender)
} }
return utils.HTTPListenAndServe(httpServer, c.BindAddress, c.BindPort, false) return utils.HTTPListenAndServe(httpServer, c.BindAddress, c.BindPort, false, logSender)
} }
// ReloadTLSCertificate reloads the TLS certificate and key from the configured paths // ReloadTLSCertificate reloads the TLS certificate and key from the configured paths

View File

@@ -93,9 +93,9 @@ func (c Conf) Initialize(configDir string) error {
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
} }
httpServer.TLSConfig = config httpServer.TLSConfig = config
return utils.HTTPListenAndServe(httpServer, c.BindAddress, c.BindPort, true) return utils.HTTPListenAndServe(httpServer, c.BindAddress, c.BindPort, true, logSender)
} }
return utils.HTTPListenAndServe(httpServer, c.BindAddress, c.BindPort, false) return utils.HTTPListenAndServe(httpServer, c.BindAddress, c.BindPort, false, logSender)
} }
// ReloadTLSCertificate reloads the TLS certificate and key from the configured paths // ReloadTLSCertificate reloads the TLS certificate and key from the configured paths

View File

@@ -385,7 +385,7 @@ func createDirPathIfMissing(file string, perm os.FileMode) error {
// HTTPListenAndServe is a wrapper for ListenAndServe that support both tcp // HTTPListenAndServe is a wrapper for ListenAndServe that support both tcp
// and Unix-domain sockets // and Unix-domain sockets
func HTTPListenAndServe(srv *http.Server, address string, port int, isTLS bool) error { func HTTPListenAndServe(srv *http.Server, address string, port int, isTLS bool, logSender string) error {
var listener net.Listener var listener net.Listener
var err error var err error
@@ -408,6 +408,8 @@ func HTTPListenAndServe(srv *http.Server, address string, port int, isTLS bool)
return err return err
} }
logger.Info(logSender, "", "server listener registered, address: %v TLS enabled: %v", listener.Addr().String(), isTLS)
defer listener.Close() defer listener.Close()
if isTLS { if isTLS {

View File

@@ -26,8 +26,7 @@ import (
) )
const ( const (
configDir = ".." testFile = "test_dav_file"
testFile = "test_dav_file"
) )
var ( var (
@@ -165,8 +164,11 @@ func TestUserInvalidParams(t *testing.T) {
}, },
}, },
} }
server, err := newServer(c, configDir)
assert.NoError(t, err) server := webDavServer{
config: c,
binding: c.Bindings[0],
}
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", u.Username), nil) req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", u.Username), nil)
assert.NoError(t, err) assert.NoError(t, err)
@@ -686,8 +688,10 @@ func TestBasicUsersCache(t *testing.T) {
}, },
}, },
} }
server, err := newServer(c, configDir) server := webDavServer{
assert.NoError(t, err) config: c,
binding: c.Bindings[0],
}
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user.Username), nil) req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user.Username), nil)
assert.NoError(t, err) assert.NoError(t, err)
@@ -807,8 +811,10 @@ func TestUsersCacheSizeAndExpiration(t *testing.T) {
}, },
}, },
} }
server, err := newServer(c, configDir) server := webDavServer{
assert.NoError(t, err) config: c,
binding: c.Bindings[0],
}
ipAddr := "127.0.1.1" ipAddr := "127.0.1.1"
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user1.Username), nil) req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user1.Username), nil)
@@ -953,8 +959,10 @@ func TestRecoverer(t *testing.T) {
}, },
}, },
} }
server, err := newServer(c, configDir) server := webDavServer{
assert.NoError(t, err) config: c,
binding: c.Bindings[0],
}
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
server.ServeHTTP(rr, nil) server.ServeHTTP(rr, nil)
assert.Equal(t, http.StatusInternalServerError, rr.Code) assert.Equal(t, http.StatusInternalServerError, rr.Code)

View File

@@ -32,34 +32,13 @@ var (
type webDavServer struct { type webDavServer struct {
config *Configuration config *Configuration
certMgr *common.CertManager binding Binding
status ServiceStatus
} }
func newServer(config *Configuration, configDir string) (*webDavServer, error) { func (s *webDavServer) listenAndServe() error {
var err error
server := &webDavServer{
config: config,
certMgr: nil,
}
certificateFile := getConfigPath(config.CertificateFile, configDir)
certificateKeyFile := getConfigPath(config.CertificateKeyFile, configDir)
if certificateFile != "" && certificateKeyFile != "" {
server.certMgr, err = common.NewCertManager(certificateFile, certificateKeyFile, logSender)
if err != nil {
return server, err
}
if err := server.certMgr.LoadRootCAs(config.CACertificates, configDir); err != nil {
return server, err
}
}
return server, nil
}
func (s *webDavServer) listenAndServe(binding Binding) error {
httpServer := &http.Server{ httpServer := &http.Server{
Addr: binding.GetAddress(), Addr: s.binding.GetAddress(),
Handler: server, Handler: s,
ReadHeaderTimeout: 30 * time.Second, ReadHeaderTimeout: 30 * time.Second,
IdleTimeout: 120 * time.Second, IdleTimeout: 120 * time.Second,
MaxHeaderBytes: 1 << 16, // 64KB MaxHeaderBytes: 1 << 16, // 64KB
@@ -74,22 +53,24 @@ func (s *webDavServer) listenAndServe(binding Binding) error {
AllowCredentials: s.config.Cors.AllowCredentials, AllowCredentials: s.config.Cors.AllowCredentials,
OptionsPassthrough: true, OptionsPassthrough: true,
}) })
httpServer.Handler = c.Handler(server) httpServer.Handler = c.Handler(s)
} }
if s.certMgr != nil && binding.EnableHTTPS { if certMgr != nil && s.binding.EnableHTTPS {
server.status.Bindings = append(server.status.Bindings, binding) serviceStatus.Bindings = append(serviceStatus.Bindings, s.binding)
httpServer.TLSConfig = &tls.Config{ httpServer.TLSConfig = &tls.Config{
GetCertificate: s.certMgr.GetCertificateFunc(), GetCertificate: certMgr.GetCertificateFunc(),
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
} }
if binding.ClientAuthType == 1 { if s.binding.ClientAuthType == 1 {
httpServer.TLSConfig.ClientCAs = s.certMgr.GetRootCAs() httpServer.TLSConfig.ClientCAs = certMgr.GetRootCAs()
httpServer.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert httpServer.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
} }
logger.Info(logSender, "", "starting HTTPS serving, binding: %v", s.binding.GetAddress())
return httpServer.ListenAndServeTLS("", "") return httpServer.ListenAndServeTLS("", "")
} }
binding.EnableHTTPS = false s.binding.EnableHTTPS = false
server.status.Bindings = append(server.status.Bindings, binding) serviceStatus.Bindings = append(serviceStatus.Bindings, s.binding)
logger.Info(logSender, "", "starting HTTP serving, binding: %v", s.binding.GetAddress())
return httpServer.ListenAndServe() return httpServer.ListenAndServe()
} }

View File

@@ -22,7 +22,9 @@ const (
) )
var ( var (
server *webDavServer //server *webDavServer
certMgr *common.CertManager
serviceStatus ServiceStatus
) )
// ServiceStatus defines the service status // ServiceStatus defines the service status
@@ -107,10 +109,7 @@ type Configuration struct {
// GetStatus returns the server status // GetStatus returns the server status
func GetStatus() ServiceStatus { func GetStatus() ServiceStatus {
if server == nil { return serviceStatus
return ServiceStatus{}
}
return server.status
} }
// ShouldBind returns true if there is at least a valid binding // ShouldBind returns true if there is at least a valid binding
@@ -126,7 +125,6 @@ func (c *Configuration) ShouldBind() bool {
// Initialize configures and starts the WebDAV server // Initialize configures and starts the WebDAV server
func (c *Configuration) Initialize(configDir string) error { func (c *Configuration) Initialize(configDir string) error {
var err error
logger.Debug(logSender, "", "initializing WebDAV server with config %+v", *c) logger.Debug(logSender, "", "initializing WebDAV server with config %+v", *c)
mimeTypeCache = mimeCache{ mimeTypeCache = mimeCache{
maxSize: c.Cache.MimeTypes.MaxSize, maxSize: c.Cache.MimeTypes.MaxSize,
@@ -138,12 +136,23 @@ func (c *Configuration) Initialize(configDir string) error {
if !c.ShouldBind() { if !c.ShouldBind() {
return common.ErrNoBinding return common.ErrNoBinding
} }
server, err = newServer(c, configDir)
if err != nil { certificateFile := getConfigPath(c.CertificateFile, configDir)
return err certificateKeyFile := getConfigPath(c.CertificateKeyFile, configDir)
if certificateFile != "" && certificateKeyFile != "" {
mgr, err := common.NewCertManager(certificateFile, certificateKeyFile, logSender)
if err != nil {
return err
}
if err := mgr.LoadRootCAs(c.CACertificates, configDir); err != nil {
return err
}
certMgr = mgr
} }
server.status.Bindings = nil serviceStatus = ServiceStatus{
Bindings: nil,
}
exitChannel := make(chan error, 1) exitChannel := make(chan error, 1)
@@ -153,19 +162,23 @@ func (c *Configuration) Initialize(configDir string) error {
} }
go func(binding Binding) { go func(binding Binding) {
exitChannel <- server.listenAndServe(binding) server := webDavServer{
config: c,
binding: binding,
}
exitChannel <- server.listenAndServe()
}(binding) }(binding)
} }
server.status.IsActive = true serviceStatus.IsActive = true
return <-exitChannel return <-exitChannel
} }
// ReloadTLSCertificate reloads the TLS certificate and key from the configured paths // ReloadTLSCertificate reloads the TLS certificate and key from the configured paths
func ReloadTLSCertificate() error { func ReloadTLSCertificate() error {
if server != nil && server.certMgr != nil { if certMgr != nil {
return server.certMgr.LoadCertificate(logSender) return certMgr.LoadCertificate(logSender)
} }
return nil return nil
} }