add roles

Fixes #837

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino
2022-11-16 19:04:50 +01:00
parent a9207857cf
commit 5a222807b7
83 changed files with 4285 additions and 806 deletions

View File

@@ -540,7 +540,7 @@ func TestUserMaxSessions(t *testing.T) {
Connections.Lock()
Connections.removeUserConnection(userTestUsername)
Connections.Unlock()
assert.Len(t, Connections.GetStats(), 0)
assert.Len(t, Connections.GetStats(""), 0)
}
func TestMaxConnections(t *testing.T) {
@@ -562,12 +562,12 @@ func TestMaxConnections(t *testing.T) {
}
err := Connections.Add(fakeConn)
assert.NoError(t, err)
assert.Len(t, Connections.GetStats(), 1)
assert.Len(t, Connections.GetStats(""), 1)
assert.Error(t, Connections.IsNewConnectionAllowed(ipAddr))
res := Connections.Close(fakeConn.GetID())
res := Connections.Close(fakeConn.GetID(), "")
assert.True(t, res)
assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 300*time.Millisecond, 50*time.Millisecond)
assert.Eventually(t, func() bool { return len(Connections.GetStats("")) == 0 }, 300*time.Millisecond, 50*time.Millisecond)
assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr))
Connections.AddClientConnection(ipAddr)
@@ -580,6 +580,33 @@ func TestMaxConnections(t *testing.T) {
Config.MaxTotalConnections = oldValue
}
func TestConnectionRoles(t *testing.T) {
username := "testUsername"
role1 := "testRole1"
role2 := "testRole2"
c := NewBaseConnection("id", ProtocolSFTP, "", "", dataprovider.User{
BaseUser: sdk.BaseUser{
Username: username,
Role: role1,
},
})
fakeConn := &fakeConnection{
BaseConnection: c,
}
err := Connections.Add(fakeConn)
assert.NoError(t, err)
assert.Len(t, Connections.GetStats(""), 1)
assert.Len(t, Connections.GetStats(role1), 1)
assert.Len(t, Connections.GetStats(role2), 0)
res := Connections.Close(fakeConn.GetID(), role2)
assert.False(t, res)
assert.Len(t, Connections.GetStats(""), 1)
res = Connections.Close(fakeConn.GetID(), role1)
assert.True(t, res)
assert.Eventually(t, func() bool { return len(Connections.GetStats("")) == 0 }, 300*time.Millisecond, 50*time.Millisecond)
}
func TestMaxConnectionPerHost(t *testing.T) {
oldValue := Config.MaxPerHostConnections
@@ -660,7 +687,7 @@ func TestIdleConnections(t *testing.T) {
err = Connections.Add(fakeConn)
assert.NoError(t, err)
assert.Equal(t, Connections.GetActiveSessions(username), 2)
assert.Len(t, Connections.GetStats(), 3)
assert.Len(t, Connections.GetStats(""), 3)
Connections.RLock()
assert.Len(t, Connections.sshConnections, 2)
Connections.RUnlock()
@@ -673,12 +700,12 @@ func TestIdleConnections(t *testing.T) {
return len(Connections.sshConnections) == 1
}, 1*time.Second, 200*time.Millisecond)
stopEventScheduler()
assert.Len(t, Connections.GetStats(), 2)
assert.Len(t, Connections.GetStats(""), 2)
c.lastActivity.Store(time.Now().Add(-24 * time.Hour).UnixNano())
cFTP.lastActivity.Store(time.Now().Add(-24 * time.Hour).UnixNano())
sshConn2.lastActivity.Store(c.lastActivity.Load())
startPeriodicChecks(100 * time.Millisecond)
assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 2*time.Second, 200*time.Millisecond)
assert.Eventually(t, func() bool { return len(Connections.GetStats("")) == 0 }, 2*time.Second, 200*time.Millisecond)
assert.Eventually(t, func() bool {
Connections.RLock()
defer Connections.RUnlock()
@@ -700,11 +727,11 @@ func TestCloseConnection(t *testing.T) {
assert.NoError(t, Connections.IsNewConnectionAllowed("127.0.0.1"))
err := Connections.Add(fakeConn)
assert.NoError(t, err)
assert.Len(t, Connections.GetStats(), 1)
res := Connections.Close(fakeConn.GetID())
assert.Len(t, Connections.GetStats(""), 1)
res := Connections.Close(fakeConn.GetID(), "")
assert.True(t, res)
assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 300*time.Millisecond, 50*time.Millisecond)
res = Connections.Close(fakeConn.GetID())
assert.Eventually(t, func() bool { return len(Connections.GetStats("")) == 0 }, 300*time.Millisecond, 50*time.Millisecond)
res = Connections.Close(fakeConn.GetID(), "")
assert.False(t, res)
Connections.Remove(fakeConn.GetID())
}
@@ -716,8 +743,8 @@ func TestSwapConnection(t *testing.T) {
}
err := Connections.Add(fakeConn)
assert.NoError(t, err)
if assert.Len(t, Connections.GetStats(), 1) {
assert.Equal(t, "", Connections.GetStats()[0].Username)
if assert.Len(t, Connections.GetStats(""), 1) {
assert.Equal(t, "", Connections.GetStats("")[0].Username)
}
c = NewBaseConnection("id", ProtocolFTP, "", "", dataprovider.User{
BaseUser: sdk.BaseUser{
@@ -743,12 +770,12 @@ func TestSwapConnection(t *testing.T) {
Connections.Remove(fakeConn1.ID)
err = Connections.Swap(fakeConn)
assert.NoError(t, err)
if assert.Len(t, Connections.GetStats(), 1) {
assert.Equal(t, userTestUsername, Connections.GetStats()[0].Username)
if assert.Len(t, Connections.GetStats(""), 1) {
assert.Equal(t, userTestUsername, Connections.GetStats("")[0].Username)
}
res := Connections.Close(fakeConn.GetID())
res := Connections.Close(fakeConn.GetID(), "")
assert.True(t, res)
assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 300*time.Millisecond, 50*time.Millisecond)
assert.Eventually(t, func() bool { return len(Connections.GetStats("")) == 0 }, 300*time.Millisecond, 50*time.Millisecond)
err = Connections.Swap(fakeConn)
assert.Error(t, err)
}
@@ -800,7 +827,7 @@ func TestConnectionStatus(t *testing.T) {
err = Connections.Add(fakeConn3)
assert.NoError(t, err)
stats := Connections.GetStats()
stats := Connections.GetStats("")
assert.Len(t, stats, 3)
for _, stat := range stats {
assert.Equal(t, stat.Username, username)
@@ -838,24 +865,24 @@ func TestConnectionStatus(t *testing.T) {
assert.Error(t, err)
Connections.Remove(fakeConn1.GetID())
stats = Connections.GetStats()
stats = Connections.GetStats("")
assert.Len(t, stats, 2)
assert.Equal(t, fakeConn3.GetID(), stats[0].ConnectionID)
assert.Equal(t, fakeConn2.GetID(), stats[1].ConnectionID)
Connections.Remove(fakeConn2.GetID())
stats = Connections.GetStats()
stats = Connections.GetStats("")
assert.Len(t, stats, 1)
assert.Equal(t, fakeConn3.GetID(), stats[0].ConnectionID)
Connections.Remove(fakeConn3.GetID())
stats = Connections.GetStats()
stats = Connections.GetStats("")
assert.Len(t, stats, 0)
}
func TestQuotaScans(t *testing.T) {
username := "username"
assert.True(t, QuotaScans.AddUserQuotaScan(username))
assert.False(t, QuotaScans.AddUserQuotaScan(username))
usersScans := QuotaScans.GetUsersQuotaScans()
assert.True(t, QuotaScans.AddUserQuotaScan(username, ""))
assert.False(t, QuotaScans.AddUserQuotaScan(username, ""))
usersScans := QuotaScans.GetUsersQuotaScans("")
if assert.Len(t, usersScans, 1) {
assert.Equal(t, usersScans[0].Username, username)
assert.Equal(t, QuotaScans.UserScans[0].StartTime, usersScans[0].StartTime)
@@ -865,7 +892,7 @@ func TestQuotaScans(t *testing.T) {
assert.True(t, QuotaScans.RemoveUserQuotaScan(username))
assert.False(t, QuotaScans.RemoveUserQuotaScan(username))
assert.Len(t, QuotaScans.GetUsersQuotaScans(), 0)
assert.Len(t, QuotaScans.GetUsersQuotaScans(""), 0)
assert.Len(t, usersScans, 1)
folderName := "folder"
@@ -880,6 +907,24 @@ func TestQuotaScans(t *testing.T) {
assert.Len(t, QuotaScans.GetVFoldersQuotaScans(), 0)
}
func TestQuotaScansRole(t *testing.T) {
username := "u"
role1 := "r1"
role2 := "r2"
assert.True(t, QuotaScans.AddUserQuotaScan(username, role1))
assert.False(t, QuotaScans.AddUserQuotaScan(username, ""))
usersScans := QuotaScans.GetUsersQuotaScans("")
assert.Len(t, usersScans, 1)
assert.Empty(t, usersScans[0].Role)
usersScans = QuotaScans.GetUsersQuotaScans(role1)
assert.Len(t, usersScans, 1)
usersScans = QuotaScans.GetUsersQuotaScans(role2)
assert.Len(t, usersScans, 0)
assert.True(t, QuotaScans.RemoveUserQuotaScan(username))
assert.False(t, QuotaScans.RemoveUserQuotaScan(username))
assert.Len(t, QuotaScans.GetUsersQuotaScans(""), 0)
}
func TestProxyProtocolVersion(t *testing.T) {
c := Configuration{
ProxyProtocol: 0,
@@ -1342,13 +1387,13 @@ func TestUpdateTransferTimestamps(t *testing.T) {
err = dataprovider.UpdateUserTransferTimestamps(username, true)
assert.NoError(t, err)
userGet, err := dataprovider.UserExists(username)
userGet, err := dataprovider.UserExists(username, "")
assert.NoError(t, err)
assert.Greater(t, userGet.FirstUpload, int64(0))
assert.Equal(t, int64(0), user.FirstDownload)
err = dataprovider.UpdateUserTransferTimestamps(username, false)
assert.NoError(t, err)
userGet, err = dataprovider.UserExists(username)
userGet, err = dataprovider.UserExists(username, "")
assert.NoError(t, err)
assert.Greater(t, userGet.FirstUpload, int64(0))
assert.Greater(t, userGet.FirstDownload, int64(0))
@@ -1358,23 +1403,40 @@ func TestUpdateTransferTimestamps(t *testing.T) {
err = dataprovider.UpdateUserTransferTimestamps(username, false)
assert.Error(t, err)
// cleanup
err = dataprovider.DeleteUser(username, "", "")
err = dataprovider.DeleteUser(username, "", "", "")
assert.NoError(t, err)
}
func TestMetadataAPI(t *testing.T) {
username := "metadatauser"
require.False(t, ActiveMetadataChecks.Remove(username))
require.True(t, ActiveMetadataChecks.Add(username))
require.False(t, ActiveMetadataChecks.Add(username))
checks := ActiveMetadataChecks.Get()
require.True(t, ActiveMetadataChecks.Add(username, ""))
require.False(t, ActiveMetadataChecks.Add(username, ""))
checks := ActiveMetadataChecks.Get("")
require.Len(t, checks, 1)
checks[0].Username = username + "a"
checks = ActiveMetadataChecks.Get()
checks = ActiveMetadataChecks.Get("")
require.Len(t, checks, 1)
require.Equal(t, username, checks[0].Username)
require.True(t, ActiveMetadataChecks.Remove(username))
require.Len(t, ActiveMetadataChecks.Get(), 0)
require.Len(t, ActiveMetadataChecks.Get(""), 0)
}
func TestMetadataAPIRole(t *testing.T) {
username := "muser"
role1 := "r1"
role2 := "r2"
require.True(t, ActiveMetadataChecks.Add(username, role2))
require.False(t, ActiveMetadataChecks.Add(username, ""))
checks := ActiveMetadataChecks.Get("")
require.Len(t, checks, 1)
assert.Empty(t, checks[0].Role)
checks = ActiveMetadataChecks.Get(role1)
require.Len(t, checks, 0)
checks = ActiveMetadataChecks.Get(role2)
require.Len(t, checks, 1)
require.True(t, ActiveMetadataChecks.Remove(username))
require.Len(t, ActiveMetadataChecks.Get(""), 0)
}
func BenchmarkBcryptHashing(b *testing.B) {