From ddbe40cefa189ec3d662a1664e4bb9fb5ef33bd1 Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Sat, 2 Aug 2025 18:00:45 +0200 Subject: [PATCH] HTTPD, WebDAV: use http.ResponseController backport from Enterprise edition Signed-off-by: Nicola Murino --- internal/httpd/api_http_user.go | 7 +-- internal/httpd/api_shares.go | 7 +-- internal/httpd/api_utils.go | 9 +++ internal/httpd/handler.go | 14 +++++ internal/httpd/internal_test.go | 53 +++++++++------- internal/httpd/server.go | 7 ++- internal/httpd/webclient.go | 35 ++++------- internal/util/timeoutlistener.go | 100 ------------------------------- internal/util/util.go | 4 +- internal/webdavd/handler.go | 14 +++++ internal/webdavd/server.go | 24 +++++--- 11 files changed, 105 insertions(+), 169 deletions(-) delete mode 100644 internal/util/timeoutlistener.go diff --git a/internal/httpd/api_http_user.go b/internal/httpd/api_http_user.go index f8a8111a..373b7bfe 100644 --- a/internal/httpd/api_http_user.go +++ b/internal/httpd/api_http_user.go @@ -53,11 +53,8 @@ func getUserConnection(w http.ResponseWriter, r *http.Request) (*Connection, err sendAPIResponse(w, r, err, http.StatusText(http.StatusForbidden), http.StatusForbidden) return nil, err } - connection := &Connection{ - BaseConnection: common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r), - r.RemoteAddr, user), - request: r, - } + baseConn := common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r), r.RemoteAddr, user) + connection := newConnection(baseConn, w, r) if err = common.Connections.Add(connection); err != nil { sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests) return connection, err diff --git a/internal/httpd/api_shares.go b/internal/httpd/api_shares.go index ba9e2527..92f155c7 100644 --- a/internal/httpd/api_shares.go +++ b/internal/httpd/api_shares.go @@ -532,11 +532,8 @@ func (s *httpdServer) checkPublicShare(w http.ResponseWriter, r *http.Request, v return share, nil, err } connID := xid.New().String() - connection := &Connection{ - BaseConnection: common.NewBaseConnection(connID, common.ProtocolHTTPShare, util.GetHTTPLocalAddress(r), - r.RemoteAddr, user), - request: r, - } + baseConn := common.NewBaseConnection(connID, common.ProtocolHTTPShare, util.GetHTTPLocalAddress(r), r.RemoteAddr, user) + connection := newConnection(baseConn, w, r) return share, connection, nil } diff --git a/internal/httpd/api_utils.go b/internal/httpd/api_utils.go index 853973d0..7fc7567a 100644 --- a/internal/httpd/api_utils.go +++ b/internal/httpd/api_utils.go @@ -946,3 +946,12 @@ func hideConfidentialData(claims *jwtTokenClaims, r *http.Request) bool { } return r.URL.Query().Get("confidential_data") != "1" } + +func responseControllerDeadlines(rc *http.ResponseController, read, write time.Time) { + if err := rc.SetReadDeadline(read); err != nil { + logger.Error(logSender, "", "unable to set read timeout to %s: %v", read, err) + } + if err := rc.SetWriteDeadline(write); err != nil { + logger.Error(logSender, "", "unable to set write timeout to %s: %v", write, err) + } +} diff --git a/internal/httpd/handler.go b/internal/httpd/handler.go index 9821af69..77364b64 100644 --- a/internal/httpd/handler.go +++ b/internal/httpd/handler.go @@ -35,6 +35,17 @@ import ( type Connection struct { *common.BaseConnection request *http.Request + rc *http.ResponseController +} + +func newConnection(conn *common.BaseConnection, w http.ResponseWriter, r *http.Request) *Connection { + rc := http.NewResponseController(w) + responseControllerDeadlines(rc, time.Time{}, time.Time{}) + return &Connection{ + BaseConnection: conn, + request: r, + rc: rc, + } } // GetClientVersion returns the connected client's version. @@ -60,6 +71,9 @@ func (c *Connection) GetRemoteAddress() string { // Disconnect closes the active transfer func (c *Connection) Disconnect() (err error) { + if c.rc != nil { + responseControllerDeadlines(c.rc, time.Now().Add(5*time.Second), time.Now().Add(5*time.Second)) + } return c.SignalTransfersAbort() } diff --git a/internal/httpd/internal_test.go b/internal/httpd/internal_test.go index 7176eab6..6a6788c3 100644 --- a/internal/httpd/internal_test.go +++ b/internal/httpd/internal_test.go @@ -2686,10 +2686,11 @@ func TestCompressorAbortHandler(t *testing.T) { assert.Equal(t, http.ErrAbortHandler, rcv) }() - connection := &Connection{ - BaseConnection: common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", dataprovider.User{}), - request: nil, - } + connection := newConnection( + common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", dataprovider.User{}), + nil, + nil, + ) share := &dataprovider.Share{} renderCompressedFiles(&failingWriter{}, connection, "", nil, share) } @@ -2711,10 +2712,11 @@ func TestZipErrors(t *testing.T) { } user.Permissions = make(map[string][]string) user.Permissions["/"] = []string{dataprovider.PermAny} - connection := &Connection{ - BaseConnection: common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", user), - request: nil, - } + connection := newConnection( + common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", user), + nil, + nil, + ) testDir := filepath.Join(os.TempDir(), "testDir") err := os.MkdirAll(testDir, os.ModePerm) @@ -2935,10 +2937,11 @@ func TestConnection(t *testing.T) { } user.Permissions = make(map[string][]string) user.Permissions["/"] = []string{dataprovider.PermAny} - connection := &Connection{ - BaseConnection: common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", user), - request: nil, - } + connection := newConnection( + common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", user), + nil, + nil, + ) assert.Empty(t, connection.GetClientVersion()) assert.Empty(t, connection.GetRemoteAddress()) assert.Empty(t, connection.GetCommand()) @@ -2959,10 +2962,11 @@ func TestGetFileWriterErrors(t *testing.T) { } user.Permissions = make(map[string][]string) user.Permissions["/"] = []string{dataprovider.PermAny} - connection := &Connection{ - BaseConnection: common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", user), - request: nil, - } + connection := newConnection( + common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", user), + nil, + nil, + ) _, err := connection.getFileWriter("name") assert.Error(t, err) @@ -2975,10 +2979,11 @@ func TestGetFileWriterErrors(t *testing.T) { }, AccessSecret: kms.NewPlainSecret("secret"), } - connection = &Connection{ - BaseConnection: common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", user), - request: nil, - } + connection = newConnection( + common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", user), + nil, + nil, + ) _, err = connection.getFileWriter("/path") assert.Error(t, err) } @@ -3007,9 +3012,11 @@ func TestHTTPDFile(t *testing.T) { } user.Permissions = make(map[string][]string) user.Permissions["/"] = []string{dataprovider.PermAny} - connection := &Connection{ - BaseConnection: common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", user), - } + connection := newConnection( + common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", user), + nil, + nil, + ) fs, err := user.GetFilesystem("") assert.NoError(t, err) diff --git a/internal/httpd/server.go b/internal/httpd/server.go index 235014ad..c572307c 100644 --- a/internal/httpd/server.go +++ b/internal/httpd/server.go @@ -103,8 +103,6 @@ func (s *httpdServer) listenAndServe() error { httpServer := &http.Server{ Handler: s.router, ReadHeaderTimeout: 30 * time.Second, - ReadTimeout: 60 * time.Second, - WriteTimeout: 60 * time.Second, IdleTimeout: 60 * time.Second, MaxHeaderBytes: 1 << 16, // 64KB ErrorLog: log.New(&logger.StdLoggerWrapper{Sender: logSender}, "", 0), @@ -1087,6 +1085,11 @@ func (s *httpdServer) updateContextFromCookie(r *http.Request) *http.Request { func (s *httpdServer) parseHeaders(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + responseControllerDeadlines( + http.NewResponseController(w), + time.Now().Add(60*time.Second), + time.Now().Add(60*time.Second), + ) w.Header().Set("Server", version.GetServerVersion("/", false)) ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) var ip net.IP diff --git a/internal/httpd/webclient.go b/internal/httpd/webclient.go index eadd95eb..23ab9cbf 100644 --- a/internal/httpd/webclient.go +++ b/internal/httpd/webclient.go @@ -906,11 +906,8 @@ func (s *httpdServer) handleWebClientDownloadZip(w http.ResponseWriter, r *http. s.renderClientForbiddenPage(w, r, err) return } - connection := &Connection{ - BaseConnection: common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r), - r.RemoteAddr, user), - request: r, - } + baseConn := common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r), r.RemoteAddr, user) + connection := newConnection(baseConn, w, r) if err = common.Connections.Add(connection); err != nil { s.renderClientMessagePage(w, r, util.I18nError429Title, http.StatusTooManyRequests, util.NewI18nError(err, util.I18nError429Message), "") @@ -1197,11 +1194,8 @@ func (s *httpdServer) handleClientGetDirContents(w http.ResponseWriter, r *http. sendAPIResponse(w, r, err, getI18NErrorString(err, util.I18nErrorDirList403), http.StatusForbidden) return } - connection := &Connection{ - BaseConnection: common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r), - r.RemoteAddr, user), - request: r, - } + baseConn := common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r), r.RemoteAddr, user) + connection := newConnection(baseConn, w, r) if err = common.Connections.Add(connection); err != nil { sendAPIResponse(w, r, err, util.I18nErrorDirList429, http.StatusTooManyRequests) return @@ -1287,11 +1281,8 @@ func (s *httpdServer) handleClientGetFiles(w http.ResponseWriter, r *http.Reques s.renderClientForbiddenPage(w, r, err) return } - connection := &Connection{ - BaseConnection: common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r), - r.RemoteAddr, user), - request: r, - } + baseConn := common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r), r.RemoteAddr, user) + connection := newConnection(baseConn, w, r) if err = common.Connections.Add(connection); err != nil { s.renderClientMessagePage(w, r, util.I18nError429Title, http.StatusTooManyRequests, util.NewI18nError(err, util.I18nError429Message), "") @@ -1348,11 +1339,8 @@ func (s *httpdServer) handleClientEditFile(w http.ResponseWriter, r *http.Reques s.renderClientForbiddenPage(w, r, err) return } - connection := &Connection{ - BaseConnection: common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r), - r.RemoteAddr, user), - request: r, - } + baseConn := common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r), r.RemoteAddr, user) + connection := newConnection(baseConn, w, r) if err = common.Connections.Add(connection); err != nil { s.renderClientMessagePage(w, r, util.I18nError429Title, http.StatusTooManyRequests, util.NewI18nError(err, util.I18nError429Message), "") @@ -1844,11 +1832,8 @@ func (s *httpdServer) handleClientGetPDF(w http.ResponseWriter, r *http.Request) s.renderClientForbiddenPage(w, r, err) return } - connection := &Connection{ - BaseConnection: common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r), - r.RemoteAddr, user), - request: r, - } + baseConn := common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r), r.RemoteAddr, user) + connection := newConnection(baseConn, w, r) if err = common.Connections.Add(connection); err != nil { s.renderClientMessagePage(w, r, util.I18nError429Title, http.StatusTooManyRequests, util.NewI18nError(err, util.I18nError429Message), "") diff --git a/internal/util/timeoutlistener.go b/internal/util/timeoutlistener.go deleted file mode 100644 index fce93cca..00000000 --- a/internal/util/timeoutlistener.go +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright (C) 2019 Nicola Murino -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published -// by the Free Software Foundation, version 3. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package util - -import ( - "net" - "sync/atomic" - "time" -) - -type listener struct { - net.Listener - ReadTimeout time.Duration - WriteTimeout time.Duration -} - -func (l *listener) Accept() (net.Conn, error) { - c, err := l.Listener.Accept() - if err != nil { - return nil, err - } - tc := &Conn{ - Conn: c, - ReadTimeout: l.ReadTimeout, - WriteTimeout: l.WriteTimeout, - ReadThreshold: int32((l.ReadTimeout * 1024) / time.Second), - WriteThreshold: int32((l.WriteTimeout * 1024) / time.Second), - } - tc.BytesReadFromDeadline.Store(0) - tc.BytesWrittenFromDeadline.Store(0) - return tc, nil -} - -// Conn wraps a net.Conn, and sets a deadline for every read -// and write operation. -type Conn struct { - net.Conn - ReadTimeout time.Duration - WriteTimeout time.Duration - ReadThreshold int32 - WriteThreshold int32 - BytesReadFromDeadline atomic.Int32 - BytesWrittenFromDeadline atomic.Int32 -} - -func (c *Conn) Read(b []byte) (n int, err error) { - if c.BytesReadFromDeadline.Load() > c.ReadThreshold { - c.BytesReadFromDeadline.Store(0) - // we set both read and write deadlines here otherwise after the request - // is read writing the response fails with an i/o timeout error - err = c.SetDeadline(time.Now().Add(c.ReadTimeout)) - if err != nil { - return 0, err - } - } - n, err = c.Conn.Read(b) - c.BytesReadFromDeadline.Add(int32(n)) - return -} - -func (c *Conn) Write(b []byte) (n int, err error) { - if c.BytesWrittenFromDeadline.Load() > c.WriteThreshold { - c.BytesWrittenFromDeadline.Store(0) - // we extend the read deadline too, not sure it's necessary, - // but it doesn't hurt - err = c.SetDeadline(time.Now().Add(c.WriteTimeout)) - if err != nil { - return - } - } - n, err = c.Conn.Write(b) - c.BytesWrittenFromDeadline.Add(int32(n)) - return -} - -func newListener(network, addr string, readTimeout, writeTimeout time.Duration) (net.Listener, error) { - l, err := net.Listen(network, addr) - if err != nil { - return nil, err - } - - tl := &listener{ - Listener: l, - ReadTimeout: readTimeout, - WriteTimeout: writeTimeout, - } - return tl, nil -} diff --git a/internal/util/util.go b/internal/util/util.go index 39b81c54..d3eb7dc1 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -593,7 +593,7 @@ func HTTPListenAndServe(srv *http.Server, address string, port int, isTLS bool, logger.Error(logSender, "", "error creating Unix-domain socket parent dir: %v", err) } os.Remove(address) - listener, err = newListener("unix", address, srv.ReadTimeout, srv.WriteTimeout) + listener, err = net.Listen("unix", address) if err == nil { // should a chmod err be fatal? if errChmod := os.Chmod(address, 0770); errChmod != nil { @@ -602,7 +602,7 @@ func HTTPListenAndServe(srv *http.Server, address string, port int, isTLS bool, } } else { CheckTCP4Port(port) - listener, err = newListener("tcp", fmt.Sprintf("%s:%d", address, port), srv.ReadTimeout, srv.WriteTimeout) + listener, err = net.Listen("tcp", fmt.Sprintf("%s:%d", address, port)) } if err != nil { return err diff --git a/internal/webdavd/handler.go b/internal/webdavd/handler.go index 54f42a18..2fc3ca1b 100644 --- a/internal/webdavd/handler.go +++ b/internal/webdavd/handler.go @@ -36,6 +36,17 @@ import ( type Connection struct { *common.BaseConnection request *http.Request + rc *http.ResponseController +} + +func newConnection(conn *common.BaseConnection, w http.ResponseWriter, r *http.Request) *Connection { + rc := http.NewResponseController(w) + responseControllerDeadlines(rc, time.Time{}, time.Time{}) + return &Connection{ + BaseConnection: conn, + request: r, + rc: rc, + } } func (c *Connection) getModificationTime() time.Time { @@ -73,6 +84,9 @@ func (c *Connection) GetRemoteAddress() string { // Disconnect closes the active transfer func (c *Connection) Disconnect() error { + if c.rc != nil { + responseControllerDeadlines(c.rc, time.Now().Add(5*time.Second), time.Now().Add(5*time.Second)) + } return c.SignalTransfersAbort() } diff --git a/internal/webdavd/server.go b/internal/webdavd/server.go index de99490f..06533320 100644 --- a/internal/webdavd/server.go +++ b/internal/webdavd/server.go @@ -55,8 +55,6 @@ func (s *webDavServer) listenAndServe(compressor *middleware.Compressor) error { handler := compressor.Handler(s) httpServer := &http.Server{ ReadHeaderTimeout: 30 * time.Second, - ReadTimeout: 60 * time.Second, - WriteTimeout: 60 * time.Second, IdleTimeout: 60 * time.Second, MaxHeaderBytes: 1 << 16, // 64KB ErrorLog: log.New(&logger.StdLoggerWrapper{Sender: logSender}, "", 0), @@ -170,6 +168,11 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { } }() + responseControllerDeadlines( + http.NewResponseController(w), + time.Now().Add(60*time.Second), + time.Now().Add(60*time.Second), + ) w.Header().Set("Server", version.GetServerVersion("/", false)) ipAddr := s.checkRemoteAddress(r) @@ -228,11 +231,9 @@ func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - connection := &Connection{ - BaseConnection: common.NewBaseConnection(connectionID, common.ProtocolWebDAV, util.GetHTTPLocalAddress(r), - r.RemoteAddr, user), - request: r, - } + baseConn := common.NewBaseConnection(connectionID, common.ProtocolWebDAV, util.GetHTTPLocalAddress(r), + r.RemoteAddr, user) + connection := newConnection(baseConn, w, r) if err = common.Connections.Add(connection); err != nil { errClose := user.CloseFs() logger.Warn(logSender, connectionID, "unable add connection: %v close fs error: %v", err, errClose) @@ -389,6 +390,15 @@ func (s *webDavServer) checkRemoteAddress(r *http.Request) string { return ipAddr } +func responseControllerDeadlines(rc *http.ResponseController, read, write time.Time) { + if err := rc.SetReadDeadline(read); err != nil { + logger.Error(logSender, "", "unable to set read timeout to %s: %v", read, err) + } + if err := rc.SetWriteDeadline(write); err != nil { + logger.Error(logSender, "", "unable to set write timeout to %s: %v", write, err) + } +} + func writeLog(r *http.Request, status int, err error) { scheme := "http" cipherSuite := ""