mirror of
https://github.com/drakkan/sftpgo.git
synced 2025-12-08 07:10:56 +03:00
sftpd: properly handle listener accept errors
continue on temporary errors and exit from the serve loop for the other ones
This commit is contained in:
@@ -141,7 +141,7 @@ func (c *Configuration) Initialize(configDir string) error {
|
|||||||
PassivePortRange: c.PassivePortRange,
|
PassivePortRange: c.PassivePortRange,
|
||||||
}
|
}
|
||||||
|
|
||||||
exitChannel := make(chan error)
|
exitChannel := make(chan error, 1)
|
||||||
|
|
||||||
for idx, binding := range c.Bindings {
|
for idx, binding := range c.Bindings {
|
||||||
if !binding.IsValid() {
|
if !binding.IsValid() {
|
||||||
|
|||||||
@@ -1866,3 +1866,71 @@ func TestRecoverer(t *testing.T) {
|
|||||||
assert.EqualError(t, err, common.ErrGenericFailure.Error())
|
assert.EqualError(t, err, common.ErrGenericFailure.Error())
|
||||||
assert.Len(t, common.Connections.GetStats(), 0)
|
assert.Len(t, common.Connections.GetStats(), 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestListernerAcceptErrors(t *testing.T) {
|
||||||
|
errFake := errors.New("a fake error")
|
||||||
|
listener := newFakeListener(errFake)
|
||||||
|
c := Configuration{}
|
||||||
|
err := c.serve(listener, nil)
|
||||||
|
require.EqualError(t, err, errFake.Error())
|
||||||
|
err = listener.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
errNetFake := &fakeNetError{error: errFake}
|
||||||
|
listener = newFakeListener(errNetFake)
|
||||||
|
err = c.serve(listener, nil)
|
||||||
|
require.EqualError(t, err, errFake.Error())
|
||||||
|
err = listener.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeNetError struct {
|
||||||
|
error
|
||||||
|
count int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *fakeNetError) Timeout() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *fakeNetError) Temporary() bool {
|
||||||
|
e.count++
|
||||||
|
return e.count < 10
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *fakeNetError) Error() string {
|
||||||
|
return e.error.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeListener struct {
|
||||||
|
server net.Conn
|
||||||
|
client net.Conn
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *fakeListener) Accept() (net.Conn, error) {
|
||||||
|
return l.client, l.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *fakeListener) Close() error {
|
||||||
|
errClient := l.client.Close()
|
||||||
|
errServer := l.server.Close()
|
||||||
|
if errServer != nil {
|
||||||
|
return errServer
|
||||||
|
}
|
||||||
|
return errClient
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *fakeListener) Addr() net.Addr {
|
||||||
|
return l.server.LocalAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFakeListener(err error) net.Listener {
|
||||||
|
server, client := net.Pipe()
|
||||||
|
|
||||||
|
return &fakeListener{
|
||||||
|
server: server,
|
||||||
|
client: client,
|
||||||
|
err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pires/go-proxyproto"
|
|
||||||
"github.com/pkg/sftp"
|
"github.com/pkg/sftp"
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
@@ -224,7 +223,7 @@ func (c *Configuration) Initialize(configDir string) error {
|
|||||||
c.configureLoginBanner(serverConfig, configDir)
|
c.configureLoginBanner(serverConfig, configDir)
|
||||||
c.checkSSHCommands()
|
c.checkSSHCommands()
|
||||||
|
|
||||||
exitChannel := make(chan error)
|
exitChannel := make(chan error, 1)
|
||||||
serviceStatus.Bindings = nil
|
serviceStatus.Bindings = nil
|
||||||
|
|
||||||
for _, binding := range c.Bindings {
|
for _, binding := range c.Bindings {
|
||||||
@@ -234,7 +233,27 @@ func (c *Configuration) Initialize(configDir string) error {
|
|||||||
serviceStatus.Bindings = append(serviceStatus.Bindings, binding)
|
serviceStatus.Bindings = append(serviceStatus.Bindings, binding)
|
||||||
|
|
||||||
go func(binding Binding) {
|
go func(binding Binding) {
|
||||||
exitChannel <- c.listenAndServe(binding, serverConfig)
|
addr := binding.GetAddress()
|
||||||
|
listener, err := net.Listen("tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn(logSender, "", "error starting listener on address %v: %v", addr, err)
|
||||||
|
exitChannel <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if binding.ApplyProxyConfig {
|
||||||
|
proxyListener, err := common.Config.GetProxyListener(listener)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn(logSender, "", "error enabling proxy listener: %v", err)
|
||||||
|
exitChannel <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if proxyListener != nil {
|
||||||
|
listener = proxyListener
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
exitChannel <- c.serve(listener, serverConfig)
|
||||||
}(binding)
|
}(binding)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -244,34 +263,32 @@ func (c *Configuration) Initialize(configDir string) error {
|
|||||||
return <-exitChannel
|
return <-exitChannel
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Configuration) listenAndServe(binding Binding, serverConfig *ssh.ServerConfig) error {
|
func (c *Configuration) serve(listener net.Listener, serverConfig *ssh.ServerConfig) error {
|
||||||
addr := binding.GetAddress()
|
|
||||||
listener, err := net.Listen("tcp", addr)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warn(logSender, "", "error starting listener on address %v: %v", addr, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
var proxyListener *proxyproto.Listener
|
|
||||||
|
|
||||||
if binding.ApplyProxyConfig {
|
|
||||||
proxyListener, err = common.Config.GetProxyListener(listener)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warn(logSender, "", "error enabling proxy listener: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
logger.Info(logSender, "", "server listener registered, address: %v", listener.Addr().String())
|
logger.Info(logSender, "", "server listener registered, address: %v", listener.Addr().String())
|
||||||
|
var tempDelay time.Duration // how long to sleep on accept failure
|
||||||
|
|
||||||
for {
|
for {
|
||||||
var conn net.Conn
|
conn, err := listener.Accept()
|
||||||
if proxyListener != nil {
|
if err != nil {
|
||||||
conn, err = proxyListener.Accept()
|
if ne, ok := err.(net.Error); ok && ne.Temporary() {
|
||||||
|
if tempDelay == 0 {
|
||||||
|
tempDelay = 5 * time.Millisecond
|
||||||
} else {
|
} else {
|
||||||
conn, err = listener.Accept()
|
tempDelay *= 2
|
||||||
}
|
}
|
||||||
if conn != nil && err == nil {
|
if max := 1 * time.Second; tempDelay > max {
|
||||||
|
tempDelay = max
|
||||||
|
}
|
||||||
|
logger.Warn(logSender, "", "accept error: %v; retrying in %v", err, tempDelay)
|
||||||
|
time.Sleep(tempDelay)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
logger.Warn(logSender, "", "unrecoverable accept error: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
go c.AcceptInboundConnection(conn, serverConfig)
|
go c.AcceptInboundConnection(conn, serverConfig)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Configuration) configureSecurityOptions(serverConfig *ssh.ServerConfig) {
|
func (c *Configuration) configureSecurityOptions(serverConfig *ssh.ServerConfig) {
|
||||||
|
|||||||
@@ -8420,7 +8420,7 @@ func sftpDownloadFile(remoteSourcePath string, localDestPath string, expectedSiz
|
|||||||
}
|
}
|
||||||
|
|
||||||
func sftpUploadNonBlocking(localSourcePath string, remoteDestPath string, expectedSize int64, client *sftp.Client) <-chan error {
|
func sftpUploadNonBlocking(localSourcePath string, remoteDestPath string, expectedSize int64, client *sftp.Client) <-chan error {
|
||||||
c := make(chan error)
|
c := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
c <- sftpUploadFile(localSourcePath, remoteDestPath, expectedSize, client)
|
c <- sftpUploadFile(localSourcePath, remoteDestPath, expectedSize, client)
|
||||||
}()
|
}()
|
||||||
@@ -8428,7 +8428,7 @@ func sftpUploadNonBlocking(localSourcePath string, remoteDestPath string, expect
|
|||||||
}
|
}
|
||||||
|
|
||||||
func sftpDownloadNonBlocking(remoteSourcePath string, localDestPath string, expectedSize int64, client *sftp.Client) <-chan error {
|
func sftpDownloadNonBlocking(remoteSourcePath string, localDestPath string, expectedSize int64, client *sftp.Client) <-chan error {
|
||||||
c := make(chan error)
|
c := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
c <- sftpDownloadFile(remoteSourcePath, localDestPath, expectedSize, client)
|
c <- sftpDownloadFile(remoteSourcePath, localDestPath, expectedSize, client)
|
||||||
}()
|
}()
|
||||||
|
|||||||
@@ -140,7 +140,7 @@ func (c *Configuration) Initialize(configDir string) error {
|
|||||||
|
|
||||||
server.status.Bindings = nil
|
server.status.Bindings = nil
|
||||||
|
|
||||||
exitChannel := make(chan error)
|
exitChannel := make(chan error, 1)
|
||||||
|
|
||||||
for _, binding := range c.Bindings {
|
for _, binding := range c.Bindings {
|
||||||
if !binding.IsValid() {
|
if !binding.IsValid() {
|
||||||
|
|||||||
Reference in New Issue
Block a user