From b67cd0d3df127bcd5ca5e84ab5a5c704a502c4d0 Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Tue, 11 May 2021 08:04:57 +0200 Subject: [PATCH] ensure no client is connected before running max connections test cases --- common/common.go | 5 +++++ common/common_test.go | 5 +++++ ftpd/cryptfs_test.go | 2 ++ ftpd/ftpd_test.go | 14 ++++++++++++++ sftpd/sftpd_test.go | 8 ++++++++ webdavd/webdavd_test.go | 24 ++++++++++++++++++++++-- 6 files changed, 56 insertions(+), 2 deletions(-) diff --git a/common/common.go b/common/common.go index fe753aca..8ec0980f 100644 --- a/common/common.go +++ b/common/common.go @@ -711,6 +711,11 @@ func (conns *ActiveConnections) RemoveClientConnection(ipAddr string) { conns.clients.remove(ipAddr) } +// GetClientConnections returns the total number of client connections +func (conns *ActiveConnections) GetClientConnections() int32 { + return conns.clients.getTotal() +} + // IsNewConnectionAllowed returns false if the maximum number of concurrent allowed connections is exceeded func (conns *ActiveConnections) IsNewConnectionAllowed(ipAddr string) bool { if Config.MaxTotalConnections == 0 && Config.MaxPerHostConnections == 0 { diff --git a/common/common_test.go b/common/common_test.go index 2cd8259d..0c3e47cb 100644 --- a/common/common_test.go +++ b/common/common_test.go @@ -276,9 +276,13 @@ func TestMaxConnectionPerHost(t *testing.T) { Connections.AddClientConnection(ipAddr) assert.False(t, Connections.IsNewConnectionAllowed(ipAddr)) + assert.Equal(t, int32(3), Connections.GetClientConnections()) Connections.RemoveClientConnection(ipAddr) Connections.RemoveClientConnection(ipAddr) + Connections.RemoveClientConnection(ipAddr) + + assert.Equal(t, int32(0), Connections.GetClientConnections()) Config.MaxPerHostConnections = oldValue } @@ -357,6 +361,7 @@ func TestIdleConnections(t *testing.T) { defer Connections.RUnlock() return len(Connections.sshConnections) == 0 }, 1*time.Second, 200*time.Millisecond) + assert.Equal(t, int32(0), Connections.GetClientConnections()) stopIdleTimeoutTicker() assert.True(t, customConn1.isClosed) assert.True(t, customConn2.isClosed) diff --git a/ftpd/cryptfs_test.go b/ftpd/cryptfs_test.go index b449e76e..d16092f9 100644 --- a/ftpd/cryptfs_test.go +++ b/ftpd/cryptfs_test.go @@ -117,6 +117,8 @@ func TestBasicFTPHandlingCryptFs(t *testing.T) { err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) assert.Eventually(t, func() bool { return len(common.Connections.GetStats()) == 0 }, 1*time.Second, 50*time.Millisecond) + assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond, + 50*time.Millisecond) } func TestZeroBytesTransfersCryptFs(t *testing.T) { diff --git a/ftpd/ftpd_test.go b/ftpd/ftpd_test.go index 5aabdc1c..b10c8025 100644 --- a/ftpd/ftpd_test.go +++ b/ftpd/ftpd_test.go @@ -560,6 +560,8 @@ func TestBasicFTPHandling(t *testing.T) { err = os.RemoveAll(localUser.GetHomeDir()) assert.NoError(t, err) assert.Eventually(t, func() bool { return len(common.Connections.GetStats()) == 0 }, 1*time.Second, 50*time.Millisecond) + assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond, + 50*time.Millisecond) } func TestLoginInvalidCredentials(t *testing.T) { @@ -756,10 +758,15 @@ func TestPostConnectHook(t *testing.T) { common.Config.PostConnectHook = "" } +//nolint:dupl func TestMaxConnections(t *testing.T) { oldValue := common.Config.MaxTotalConnections common.Config.MaxTotalConnections = 1 + assert.Eventually(t, func() bool { + return common.Connections.GetClientConnections() == 0 + }, 1000*time.Millisecond, 50*time.Millisecond) + user := getTestUser() err := dataprovider.AddUser(&user) assert.NoError(t, err) @@ -781,10 +788,15 @@ func TestMaxConnections(t *testing.T) { common.Config.MaxTotalConnections = oldValue } +//nolint:dupl func TestMaxPerHostConnections(t *testing.T) { oldValue := common.Config.MaxPerHostConnections common.Config.MaxPerHostConnections = 1 + assert.Eventually(t, func() bool { + return common.Connections.GetClientConnections() == 0 + }, 1000*time.Millisecond, 50*time.Millisecond) + user := getTestUser() err := dataprovider.AddUser(&user) assert.NoError(t, err) @@ -2689,6 +2701,8 @@ func TestNestedVirtualFolders(t *testing.T) { err = os.RemoveAll(localUser.GetHomeDir()) assert.NoError(t, err) assert.Eventually(t, func() bool { return len(common.Connections.GetStats()) == 0 }, 1*time.Second, 50*time.Millisecond) + assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond, + 50*time.Millisecond) } func checkBasicFTP(client *ftp.ServerConn) error { diff --git a/sftpd/sftpd_test.go b/sftpd/sftpd_test.go index f6252138..3591ccb3 100644 --- a/sftpd/sftpd_test.go +++ b/sftpd/sftpd_test.go @@ -2906,6 +2906,10 @@ func TestMaxConnections(t *testing.T) { oldValue := common.Config.MaxTotalConnections common.Config.MaxTotalConnections = 1 + assert.Eventually(t, func() bool { + return common.Connections.GetClientConnections() == 0 + }, 1000*time.Millisecond, 50*time.Millisecond) + usePubKey := true user := getTestUser(usePubKey) err := dataprovider.AddUser(&user) @@ -2937,6 +2941,10 @@ func TestMaxPerHostConnections(t *testing.T) { oldValue := common.Config.MaxPerHostConnections common.Config.MaxPerHostConnections = 1 + assert.Eventually(t, func() bool { + return common.Connections.GetClientConnections() == 0 + }, 1000*time.Millisecond, 50*time.Millisecond) + usePubKey := true user := getTestUser(usePubKey) err := dataprovider.AddUser(&user) diff --git a/webdavd/webdavd_test.go b/webdavd/webdavd_test.go index 41ff56f6..dff71450 100644 --- a/webdavd/webdavd_test.go +++ b/webdavd/webdavd_test.go @@ -919,6 +919,10 @@ func TestMaxConnections(t *testing.T) { oldValue := common.Config.MaxTotalConnections common.Config.MaxTotalConnections = 1 + assert.Eventually(t, func() bool { + return common.Connections.GetClientConnections() == 0 + }, 1000*time.Millisecond, 50*time.Millisecond) + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) client := getWebDavClient(user, true, nil) @@ -944,6 +948,10 @@ func TestMaxPerHostConnections(t *testing.T) { oldValue := common.Config.MaxPerHostConnections common.Config.MaxPerHostConnections = 1 + assert.Eventually(t, func() bool { + return common.Connections.GetClientConnections() == 0 + }, 1000*time.Millisecond, 50*time.Millisecond) + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) assert.NoError(t, err) client := getWebDavClient(user, true, nil) @@ -1188,7 +1196,7 @@ func TestQuotaLimits(t *testing.T) { if !assert.NoError(t, err, "username: %v", user.Username) { info, err := os.Stat(testFilePath) if assert.NoError(t, err) { - fmt.Printf("local file size %v", info.Size()) + fmt.Printf("local file size: %v\n", info.Size()) } printLatestLogs(20) } @@ -2580,7 +2588,19 @@ func createTestFile(path string, size int64) error { if err != nil { return err } - return os.WriteFile(path, content, os.ModePerm) + + f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm) + if err != nil { + return err + } + _, err = f.Write(content) + if err == nil { + err = f.Sync() + } + if err1 := f.Close(); err1 != nil && err == nil { + err = err1 + } + return err } func printLatestLogs(maxNumberOfLines int) {