diff --git a/ftpd/ftpd.go b/ftpd/ftpd.go index 27f10fd1..95024825 100644 --- a/ftpd/ftpd.go +++ b/ftpd/ftpd.go @@ -141,7 +141,7 @@ func (c *Configuration) Initialize(configDir string) error { PassivePortRange: c.PassivePortRange, } - exitChannel := make(chan error) + exitChannel := make(chan error, 1) for idx, binding := range c.Bindings { if !binding.IsValid() { diff --git a/sftpd/internal_test.go b/sftpd/internal_test.go index 35448bc1..676c27a6 100644 --- a/sftpd/internal_test.go +++ b/sftpd/internal_test.go @@ -1866,3 +1866,71 @@ func TestRecoverer(t *testing.T) { assert.EqualError(t, err, common.ErrGenericFailure.Error()) assert.Len(t, common.Connections.GetStats(), 0) } + +func TestListernerAcceptErrors(t *testing.T) { + errFake := errors.New("a fake error") + listener := newFakeListener(errFake) + c := Configuration{} + err := c.serve(listener, nil) + require.EqualError(t, err, errFake.Error()) + err = listener.Close() + require.NoError(t, err) + + errNetFake := &fakeNetError{error: errFake} + listener = newFakeListener(errNetFake) + err = c.serve(listener, nil) + require.EqualError(t, err, errFake.Error()) + err = listener.Close() + require.NoError(t, err) +} + +type fakeNetError struct { + error + count int +} + +func (e *fakeNetError) Timeout() bool { + return false +} + +func (e *fakeNetError) Temporary() bool { + e.count++ + return e.count < 10 +} + +func (e *fakeNetError) Error() string { + return e.error.Error() +} + +type fakeListener struct { + server net.Conn + client net.Conn + err error +} + +func (l *fakeListener) Accept() (net.Conn, error) { + return l.client, l.err +} + +func (l *fakeListener) Close() error { + errClient := l.client.Close() + errServer := l.server.Close() + if errServer != nil { + return errServer + } + return errClient +} + +func (l *fakeListener) Addr() net.Addr { + return l.server.LocalAddr() +} + +func newFakeListener(err error) net.Listener { + server, client := net.Pipe() + + return &fakeListener{ + server: server, + client: client, + err: err, + } +} diff --git a/sftpd/server.go b/sftpd/server.go index 6eee274d..68f1883f 100644 --- a/sftpd/server.go +++ b/sftpd/server.go @@ -15,7 +15,6 @@ import ( "strings" "time" - "github.com/pires/go-proxyproto" "github.com/pkg/sftp" "golang.org/x/crypto/ssh" @@ -224,7 +223,7 @@ func (c *Configuration) Initialize(configDir string) error { c.configureLoginBanner(serverConfig, configDir) c.checkSSHCommands() - exitChannel := make(chan error) + exitChannel := make(chan error, 1) serviceStatus.Bindings = nil for _, binding := range c.Bindings { @@ -234,7 +233,27 @@ func (c *Configuration) Initialize(configDir string) error { serviceStatus.Bindings = append(serviceStatus.Bindings, binding) go func(binding Binding) { - exitChannel <- c.listenAndServe(binding, serverConfig) + addr := binding.GetAddress() + listener, err := net.Listen("tcp", addr) + if err != nil { + logger.Warn(logSender, "", "error starting listener on address %v: %v", addr, err) + exitChannel <- err + return + } + + if binding.ApplyProxyConfig { + proxyListener, err := common.Config.GetProxyListener(listener) + if err != nil { + logger.Warn(logSender, "", "error enabling proxy listener: %v", err) + exitChannel <- err + return + } + if proxyListener != nil { + listener = proxyListener + } + } + + exitChannel <- c.serve(listener, serverConfig) }(binding) } @@ -244,33 +263,31 @@ func (c *Configuration) Initialize(configDir string) error { return <-exitChannel } -func (c *Configuration) listenAndServe(binding Binding, serverConfig *ssh.ServerConfig) error { - addr := binding.GetAddress() - listener, err := net.Listen("tcp", addr) - if err != nil { - logger.Warn(logSender, "", "error starting listener on address %v: %v", addr, err) - return err - } - var proxyListener *proxyproto.Listener +func (c *Configuration) serve(listener net.Listener, serverConfig *ssh.ServerConfig) error { + logger.Info(logSender, "", "server listener registered, address: %v", listener.Addr().String()) + var tempDelay time.Duration // how long to sleep on accept failure - if binding.ApplyProxyConfig { - proxyListener, err = common.Config.GetProxyListener(listener) + for { + conn, err := listener.Accept() if err != nil { - logger.Warn(logSender, "", "error enabling proxy listener: %v", err) + if ne, ok := err.(net.Error); ok && ne.Temporary() { + if tempDelay == 0 { + tempDelay = 5 * time.Millisecond + } else { + tempDelay *= 2 + } + if max := 1 * time.Second; tempDelay > max { + tempDelay = max + } + logger.Warn(logSender, "", "accept error: %v; retrying in %v", err, tempDelay) + time.Sleep(tempDelay) + continue + } + logger.Warn(logSender, "", "unrecoverable accept error: %v", err) return err } - } - logger.Info(logSender, "", "server listener registered, address: %v", listener.Addr().String()) - for { - var conn net.Conn - if proxyListener != nil { - conn, err = proxyListener.Accept() - } else { - conn, err = listener.Accept() - } - if conn != nil && err == nil { - go c.AcceptInboundConnection(conn, serverConfig) - } + + go c.AcceptInboundConnection(conn, serverConfig) } } diff --git a/sftpd/sftpd_test.go b/sftpd/sftpd_test.go index 441ec231..d833ee3f 100644 --- a/sftpd/sftpd_test.go +++ b/sftpd/sftpd_test.go @@ -8420,7 +8420,7 @@ func sftpDownloadFile(remoteSourcePath string, localDestPath string, expectedSiz } func sftpUploadNonBlocking(localSourcePath string, remoteDestPath string, expectedSize int64, client *sftp.Client) <-chan error { - c := make(chan error) + c := make(chan error, 1) go func() { c <- sftpUploadFile(localSourcePath, remoteDestPath, expectedSize, client) }() @@ -8428,7 +8428,7 @@ func sftpUploadNonBlocking(localSourcePath string, remoteDestPath string, expect } func sftpDownloadNonBlocking(remoteSourcePath string, localDestPath string, expectedSize int64, client *sftp.Client) <-chan error { - c := make(chan error) + c := make(chan error, 1) go func() { c <- sftpDownloadFile(remoteSourcePath, localDestPath, expectedSize, client) }() diff --git a/webdavd/webdavd.go b/webdavd/webdavd.go index 8295fbfc..5231b1bb 100644 --- a/webdavd/webdavd.go +++ b/webdavd/webdavd.go @@ -140,7 +140,7 @@ func (c *Configuration) Initialize(configDir string) error { server.status.Bindings = nil - exitChannel := make(chan error) + exitChannel := make(chan error, 1) for _, binding := range c.Bindings { if !binding.IsValid() {