fix connection limits

an SFTP client can start multiple transfers on a single connection

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino
2024-10-26 21:18:19 +02:00
parent c69fbe6bf9
commit ae1487d733
24 changed files with 707 additions and 7 deletions

View File

@@ -38,11 +38,13 @@ import (
"time"
"github.com/minio/sio"
"github.com/pkg/sftp"
"github.com/rs/zerolog"
"github.com/sftpgo/sdk"
sdkkms "github.com/sftpgo/sdk/kms"
"github.com/stretchr/testify/assert"
"github.com/studio-b12/gowebdav"
"golang.org/x/crypto/ssh"
"github.com/drakkan/sftpgo/v2/internal/common"
"github.com/drakkan/sftpgo/v2/internal/config"
@@ -637,6 +639,7 @@ func TestBasicHandling(t *testing.T) {
assert.NoError(t, err)
assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 },
1*time.Second, 100*time.Millisecond)
assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
status := webdavd.GetStatus()
assert.True(t, status.IsActive)
}
@@ -721,6 +724,7 @@ func TestBasicHandlingCryptFs(t *testing.T) {
assert.NoError(t, err)
assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 },
1*time.Second, 100*time.Millisecond)
assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
}
func TestBufferedUser(t *testing.T) {
@@ -1010,6 +1014,8 @@ func TestRenameWithLock(t *testing.T) {
err = resp.Body.Close()
assert.NoError(t, err)
err = os.Remove(testFilePath)
assert.NoError(t, err)
_, err = httpdtest.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
@@ -1077,6 +1083,7 @@ func TestPropPatch(t *testing.T) {
assert.NoError(t, err)
assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 },
1*time.Second, 100*time.Millisecond)
assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
}
func TestLoginInvalidPwd(t *testing.T) {
@@ -1520,6 +1527,7 @@ func TestPreDownloadHook(t *testing.T) {
assert.NoError(t, err)
assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 },
1*time.Second, 100*time.Millisecond)
assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
common.Config.Actions.ExecuteOn = []string{common.OperationPreDownload}
common.Config.Actions.Hook = preDownloadPath
@@ -1570,6 +1578,7 @@ func TestPreUploadHook(t *testing.T) {
assert.NoError(t, err)
assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 },
1*time.Second, 100*time.Millisecond)
assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
common.Config.Actions.ExecuteOn = oldExecuteOn
common.Config.Actions.Hook = oldHook
@@ -1633,6 +1642,7 @@ func TestMaxConnections(t *testing.T) {
assert.NoError(t, err)
assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 },
1*time.Second, 100*time.Millisecond)
assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
common.Config.MaxTotalConnections = oldValue
}
@@ -1665,6 +1675,61 @@ func TestMaxPerHostConnections(t *testing.T) {
assert.NoError(t, err)
assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 },
1*time.Second, 100*time.Millisecond)
assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
common.Config.MaxPerHostConnections = oldValue
}
func TestMaxTransfers(t *testing.T) {
oldValue := common.Config.MaxPerHostConnections
common.Config.MaxPerHostConnections = 2
assert.Eventually(t, func() bool {
return common.Connections.GetClientConnections() == 0
}, 1000*time.Millisecond, 50*time.Millisecond)
user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
assert.NoError(t, err)
client := getWebDavClient(user, true, nil)
assert.NoError(t, checkBasicFunc(client))
conn, sftpClient, err := getSftpClient(user)
assert.NoError(t, err)
defer conn.Close()
defer sftpClient.Close()
f1, err := sftpClient.Create("file1")
assert.NoError(t, err)
f2, err := sftpClient.Create("file2")
assert.NoError(t, err)
_, err = f1.Write([]byte(" "))
assert.NoError(t, err)
_, err = f2.Write([]byte(" "))
assert.NoError(t, err)
testFilePath := filepath.Join(homeBasePath, testFileName)
testFileSize := int64(65535)
err = createTestFile(testFilePath, testFileSize)
assert.NoError(t, err)
err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword,
false, testFileSize, client)
assert.Error(t, err)
err = os.Remove(testFilePath)
assert.NoError(t, err)
err = f1.Close()
assert.NoError(t, err)
err = f2.Close()
assert.NoError(t, err)
_, err = httpdtest.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 },
1*time.Second, 100*time.Millisecond)
assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
common.Config.MaxPerHostConnections = oldValue
}
@@ -1712,6 +1777,7 @@ func TestMaxSessions(t *testing.T) {
assert.NoError(t, err)
assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 },
1*time.Second, 100*time.Millisecond)
assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
}
func TestLoginWithIPilters(t *testing.T) {
@@ -2171,6 +2237,7 @@ func TestClientClose(t *testing.T) {
wg.Wait()
assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 },
1*time.Second, 100*time.Millisecond)
assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
err = os.Remove(localDownloadPath)
assert.NoError(t, err)
@@ -3276,6 +3343,7 @@ func TestNestedVirtualFolders(t *testing.T) {
assert.NoError(t, err)
assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 },
1*time.Second, 100*time.Millisecond)
assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
}
func checkBasicFunc(client *gowebdav.Client) error {
@@ -3472,6 +3540,30 @@ func getTestUserWithCryptFs() dataprovider.User {
return user
}
func getSftpClient(user dataprovider.User) (*ssh.Client, *sftp.Client, error) {
var sftpClient *sftp.Client
config := &ssh.ClientConfig{
User: user.Username,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: 5 * time.Second,
}
if user.Password != "" {
config.Auth = []ssh.AuthMethod{ssh.Password(user.Password)}
} else {
config.Auth = []ssh.AuthMethod{ssh.Password(defaultPassword)}
}
conn, err := ssh.Dial("tcp", sftpServerAddr, config)
if err != nil {
return conn, sftpClient, err
}
sftpClient, err = sftp.NewClient(conn)
if err != nil {
conn.Close()
}
return conn, sftpClient, err
}
func getEncryptedFileSize(size int64) (int64, error) {
encSize, err := sio.EncryptedSize(uint64(size))
return int64(encSize) + 33, err