sftpd: properly handle listener accept errors

continue on temporary errors and exit from the serve loop for the
other ones
This commit is contained in:
Nicola Murino
2020-12-23 19:53:07 +01:00
parent 7ab7941ddd
commit 187a5b1908
5 changed files with 115 additions and 30 deletions

View File

@@ -141,7 +141,7 @@ func (c *Configuration) Initialize(configDir string) error {
PassivePortRange: c.PassivePortRange, PassivePortRange: c.PassivePortRange,
} }
exitChannel := make(chan error) exitChannel := make(chan error, 1)
for idx, binding := range c.Bindings { for idx, binding := range c.Bindings {
if !binding.IsValid() { if !binding.IsValid() {

View File

@@ -1866,3 +1866,71 @@ func TestRecoverer(t *testing.T) {
assert.EqualError(t, err, common.ErrGenericFailure.Error()) assert.EqualError(t, err, common.ErrGenericFailure.Error())
assert.Len(t, common.Connections.GetStats(), 0) assert.Len(t, common.Connections.GetStats(), 0)
} }
func TestListernerAcceptErrors(t *testing.T) {
errFake := errors.New("a fake error")
listener := newFakeListener(errFake)
c := Configuration{}
err := c.serve(listener, nil)
require.EqualError(t, err, errFake.Error())
err = listener.Close()
require.NoError(t, err)
errNetFake := &fakeNetError{error: errFake}
listener = newFakeListener(errNetFake)
err = c.serve(listener, nil)
require.EqualError(t, err, errFake.Error())
err = listener.Close()
require.NoError(t, err)
}
type fakeNetError struct {
error
count int
}
func (e *fakeNetError) Timeout() bool {
return false
}
func (e *fakeNetError) Temporary() bool {
e.count++
return e.count < 10
}
func (e *fakeNetError) Error() string {
return e.error.Error()
}
type fakeListener struct {
server net.Conn
client net.Conn
err error
}
func (l *fakeListener) Accept() (net.Conn, error) {
return l.client, l.err
}
func (l *fakeListener) Close() error {
errClient := l.client.Close()
errServer := l.server.Close()
if errServer != nil {
return errServer
}
return errClient
}
func (l *fakeListener) Addr() net.Addr {
return l.server.LocalAddr()
}
func newFakeListener(err error) net.Listener {
server, client := net.Pipe()
return &fakeListener{
server: server,
client: client,
err: err,
}
}

View File

@@ -15,7 +15,6 @@ import (
"strings" "strings"
"time" "time"
"github.com/pires/go-proxyproto"
"github.com/pkg/sftp" "github.com/pkg/sftp"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
@@ -224,7 +223,7 @@ func (c *Configuration) Initialize(configDir string) error {
c.configureLoginBanner(serverConfig, configDir) c.configureLoginBanner(serverConfig, configDir)
c.checkSSHCommands() c.checkSSHCommands()
exitChannel := make(chan error) exitChannel := make(chan error, 1)
serviceStatus.Bindings = nil serviceStatus.Bindings = nil
for _, binding := range c.Bindings { for _, binding := range c.Bindings {
@@ -234,7 +233,27 @@ func (c *Configuration) Initialize(configDir string) error {
serviceStatus.Bindings = append(serviceStatus.Bindings, binding) serviceStatus.Bindings = append(serviceStatus.Bindings, binding)
go func(binding Binding) { go func(binding Binding) {
exitChannel <- c.listenAndServe(binding, serverConfig) addr := binding.GetAddress()
listener, err := net.Listen("tcp", addr)
if err != nil {
logger.Warn(logSender, "", "error starting listener on address %v: %v", addr, err)
exitChannel <- err
return
}
if binding.ApplyProxyConfig {
proxyListener, err := common.Config.GetProxyListener(listener)
if err != nil {
logger.Warn(logSender, "", "error enabling proxy listener: %v", err)
exitChannel <- err
return
}
if proxyListener != nil {
listener = proxyListener
}
}
exitChannel <- c.serve(listener, serverConfig)
}(binding) }(binding)
} }
@@ -244,35 +263,33 @@ func (c *Configuration) Initialize(configDir string) error {
return <-exitChannel return <-exitChannel
} }
func (c *Configuration) listenAndServe(binding Binding, serverConfig *ssh.ServerConfig) error { func (c *Configuration) serve(listener net.Listener, serverConfig *ssh.ServerConfig) error {
addr := binding.GetAddress()
listener, err := net.Listen("tcp", addr)
if err != nil {
logger.Warn(logSender, "", "error starting listener on address %v: %v", addr, err)
return err
}
var proxyListener *proxyproto.Listener
if binding.ApplyProxyConfig {
proxyListener, err = common.Config.GetProxyListener(listener)
if err != nil {
logger.Warn(logSender, "", "error enabling proxy listener: %v", err)
return err
}
}
logger.Info(logSender, "", "server listener registered, address: %v", listener.Addr().String()) logger.Info(logSender, "", "server listener registered, address: %v", listener.Addr().String())
var tempDelay time.Duration // how long to sleep on accept failure
for { for {
var conn net.Conn conn, err := listener.Accept()
if proxyListener != nil { if err != nil {
conn, err = proxyListener.Accept() if ne, ok := err.(net.Error); ok && ne.Temporary() {
if tempDelay == 0 {
tempDelay = 5 * time.Millisecond
} else { } else {
conn, err = listener.Accept() tempDelay *= 2
} }
if conn != nil && err == nil { if max := 1 * time.Second; tempDelay > max {
tempDelay = max
}
logger.Warn(logSender, "", "accept error: %v; retrying in %v", err, tempDelay)
time.Sleep(tempDelay)
continue
}
logger.Warn(logSender, "", "unrecoverable accept error: %v", err)
return err
}
go c.AcceptInboundConnection(conn, serverConfig) go c.AcceptInboundConnection(conn, serverConfig)
} }
} }
}
func (c *Configuration) configureSecurityOptions(serverConfig *ssh.ServerConfig) { func (c *Configuration) configureSecurityOptions(serverConfig *ssh.ServerConfig) {
if len(c.KexAlgorithms) > 0 { if len(c.KexAlgorithms) > 0 {

View File

@@ -8420,7 +8420,7 @@ func sftpDownloadFile(remoteSourcePath string, localDestPath string, expectedSiz
} }
func sftpUploadNonBlocking(localSourcePath string, remoteDestPath string, expectedSize int64, client *sftp.Client) <-chan error { func sftpUploadNonBlocking(localSourcePath string, remoteDestPath string, expectedSize int64, client *sftp.Client) <-chan error {
c := make(chan error) c := make(chan error, 1)
go func() { go func() {
c <- sftpUploadFile(localSourcePath, remoteDestPath, expectedSize, client) c <- sftpUploadFile(localSourcePath, remoteDestPath, expectedSize, client)
}() }()
@@ -8428,7 +8428,7 @@ func sftpUploadNonBlocking(localSourcePath string, remoteDestPath string, expect
} }
func sftpDownloadNonBlocking(remoteSourcePath string, localDestPath string, expectedSize int64, client *sftp.Client) <-chan error { func sftpDownloadNonBlocking(remoteSourcePath string, localDestPath string, expectedSize int64, client *sftp.Client) <-chan error {
c := make(chan error) c := make(chan error, 1)
go func() { go func() {
c <- sftpDownloadFile(remoteSourcePath, localDestPath, expectedSize, client) c <- sftpDownloadFile(remoteSourcePath, localDestPath, expectedSize, client)
}() }()

View File

@@ -140,7 +140,7 @@ func (c *Configuration) Initialize(configDir string) error {
server.status.Bindings = nil server.status.Bindings = nil
exitChannel := make(chan error) exitChannel := make(chan error, 1)
for _, binding := range c.Bindings { for _, binding := range c.Bindings {
if !binding.IsValid() { if !binding.IsValid() {