fix connection limits

an SFTP client can start multiple transfers on a single connection

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino
2024-10-26 21:18:19 +02:00
parent c69fbe6bf9
commit ae1487d733
24 changed files with 707 additions and 7 deletions

View File

@@ -37,6 +37,7 @@ import (
ftpserver "github.com/fclairamb/ftpserverlib"
"github.com/jlaffaye/ftp"
"github.com/pkg/sftp"
"github.com/pquerna/otp"
"github.com/pquerna/otp/totp"
"github.com/rs/zerolog"
@@ -44,6 +45,7 @@ import (
sdkkms "github.com/sftpgo/sdk/kms"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
"github.com/drakkan/sftpgo/v2/internal/common"
"github.com/drakkan/sftpgo/v2/internal/config"
@@ -671,6 +673,7 @@ func TestBasicFTPHandling(t *testing.T) {
assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 50*time.Millisecond)
assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond,
50*time.Millisecond)
assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
}
func TestHTTPFs(t *testing.T) {
@@ -715,6 +718,7 @@ func TestHTTPFs(t *testing.T) {
assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 50*time.Millisecond)
assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond,
50*time.Millisecond)
assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
}
func TestListDirWithWildcards(t *testing.T) {
@@ -1735,6 +1739,66 @@ func TestMaxPerHostConnections(t *testing.T) {
common.Config.MaxPerHostConnections = oldValue
}
func TestMaxTransfers(t *testing.T) {
oldValue := common.Config.MaxPerHostConnections
common.Config.MaxPerHostConnections = 2
assert.Eventually(t, func() bool {
return common.Connections.GetClientConnections() == 0
}, 1000*time.Millisecond, 50*time.Millisecond)
user := getTestUser()
err := dataprovider.AddUser(&user, "", "", "")
assert.NoError(t, err)
user.Password = ""
testFilePath := filepath.Join(homeBasePath, testFileName)
testFileSize := int64(65535)
err = createTestFile(testFilePath, testFileSize)
assert.NoError(t, err)
conn, sftpClient, err := getSftpClient(user)
assert.NoError(t, err)
defer conn.Close()
defer sftpClient.Close()
f1, err := sftpClient.Create("file1")
assert.NoError(t, err)
f2, err := sftpClient.Create("file2")
assert.NoError(t, err)
_, err = f1.Write([]byte(" "))
assert.NoError(t, err)
_, err = f2.Write([]byte(" "))
assert.NoError(t, err)
client, err := getFTPClient(user, true, nil)
if assert.NoError(t, err) {
err = checkBasicFTP(client)
assert.NoError(t, err)
err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
assert.Error(t, err)
localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0)
assert.Error(t, err)
err := client.Quit()
assert.NoError(t, err)
err = os.Remove(localDownloadPath)
assert.NoError(t, err)
}
err = f1.Close()
assert.NoError(t, err)
err = f2.Close()
assert.NoError(t, err)
err = dataprovider.DeleteUser(user.Username, "", "", "")
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
common.Config.MaxPerHostConnections = oldValue
}
func TestRateLimiter(t *testing.T) {
oldConfig := config.GetCommonConfig()
@@ -3962,6 +4026,7 @@ func TestNestedVirtualFolders(t *testing.T) {
assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 50*time.Millisecond)
assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond,
50*time.Millisecond)
assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
}
func checkBasicFTP(client *ftp.ServerConn) error {
@@ -4213,6 +4278,30 @@ func getPreLoginScriptContent(user dataprovider.User, nonJSONResponse bool) []by
return content
}
func getSftpClient(user dataprovider.User) (*ssh.Client, *sftp.Client, error) {
var sftpClient *sftp.Client
config := &ssh.ClientConfig{
User: user.Username,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: 5 * time.Second,
}
if user.Password != "" {
config.Auth = []ssh.AuthMethod{ssh.Password(user.Password)}
} else {
config.Auth = []ssh.AuthMethod{ssh.Password(defaultPassword)}
}
conn, err := ssh.Dial("tcp", sftpServerAddr, config)
if err != nil {
return conn, sftpClient, err
}
sftpClient, err = sftp.NewClient(conn)
if err != nil {
conn.Close()
}
return conn, sftpClient, err
}
func getExitCodeScriptContent(exitCode int) []byte {
content := []byte("#!/bin/sh\n\n")
content = append(content, []byte(fmt.Sprintf("exit %v", exitCode))...)