SQL providers: make sure we don't exceed the allowed placeholders

Fixes #1415

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino
2023-09-12 19:15:40 +02:00
parent fac022090d
commit e54fd46a9e
8 changed files with 206 additions and 63 deletions

View File

@@ -1644,6 +1644,112 @@ func TestIPList(t *testing.T) {
}
}
func TestSQLPlaceholderLimits(t *testing.T) {
numGroups := 120
numUsers := 120
var groupMapping []sdk.GroupMapping
folder := vfs.BaseVirtualFolder{
Name: "testfolder",
MappedPath: filepath.Join(os.TempDir(), "folder"),
}
err := dataprovider.AddFolder(&folder, "", "", "")
assert.NoError(t, err)
for i := 0; i < numGroups; i++ {
group := dataprovider.Group{
BaseGroup: sdk.BaseGroup{
Name: fmt.Sprintf("testgroup%d", i),
},
UserSettings: dataprovider.GroupUserSettings{
BaseGroupUserSettings: sdk.BaseGroupUserSettings{
Permissions: map[string][]string{
fmt.Sprintf("/dir%d", i): {dataprovider.PermAny},
},
},
},
}
group.VirtualFolders = append(group.VirtualFolders, vfs.VirtualFolder{
BaseVirtualFolder: folder,
VirtualPath: "/vdir",
})
err := dataprovider.AddGroup(&group, "", "", "")
assert.NoError(t, err)
groupMapping = append(groupMapping, sdk.GroupMapping{
Name: group.Name,
Type: sdk.GroupTypeSecondary,
})
}
user := dataprovider.User{
BaseUser: sdk.BaseUser{
Username: "testusername",
HomeDir: filepath.Join(os.TempDir(), "testhome"),
Status: 1,
Permissions: map[string][]string{
"/": {dataprovider.PermAny},
},
},
Groups: groupMapping,
}
err = dataprovider.AddUser(&user, "", "", "")
assert.NoError(t, err)
users, err := dataprovider.GetUsersForQuotaCheck(map[string]bool{user.Username: true})
assert.NoError(t, err)
if assert.Len(t, users, 1) {
for i := 0; i < numGroups; i++ {
_, ok := users[0].Permissions[fmt.Sprintf("/dir%d", i)]
assert.True(t, ok)
}
}
err = dataprovider.DeleteUser(user.Username, "", "", "")
assert.NoError(t, err)
for i := 0; i < numUsers; i++ {
user := dataprovider.User{
BaseUser: sdk.BaseUser{
Username: fmt.Sprintf("testusername%d", i),
HomeDir: filepath.Join(os.TempDir()),
Status: 1,
Permissions: map[string][]string{
"/": {dataprovider.PermAny},
},
},
Groups: []sdk.GroupMapping{
{
Name: "testgroup0",
Type: sdk.GroupTypePrimary,
},
},
}
err := dataprovider.AddUser(&user, "", "", "")
assert.NoError(t, err)
}
time.Sleep(100 * time.Millisecond)
err = dataprovider.DeleteFolder(folder.Name, "", "", "")
assert.NoError(t, err)
for i := 0; i < numUsers; i++ {
username := fmt.Sprintf("testusername%d", i)
user, err := dataprovider.UserExists(username, "")
assert.NoError(t, err)
assert.Greater(t, user.UpdatedAt, user.CreatedAt)
err = dataprovider.DeleteUser(username, "", "", "")
assert.NoError(t, err)
}
for i := 0; i < numGroups; i++ {
groupName := fmt.Sprintf("testgroup%d", i)
err = dataprovider.DeleteGroup(groupName, "", "", "")
assert.NoError(t, err)
}
}
func BenchmarkBcryptHashing(b *testing.B) {
bcryptPassword := "bcryptpassword"
for i := 0; i < b.N; i++ {

View File

@@ -597,7 +597,7 @@ func TestDataTransferExceeded(t *testing.T) {
func TestGetUsersForQuotaCheck(t *testing.T) {
usersToFetch := make(map[string]bool)
for i := 0; i < 50; i++ {
for i := 0; i < 70; i++ {
usersToFetch[fmt.Sprintf("user%v", i)] = i%2 == 0
}
@@ -605,7 +605,7 @@ func TestGetUsersForQuotaCheck(t *testing.T) {
assert.NoError(t, err)
assert.Len(t, users, 0)
for i := 0; i < 40; i++ {
for i := 0; i < 60; i++ {
folder := vfs.BaseVirtualFolder{
Name: fmt.Sprintf("f%v", i),
MappedPath: filepath.Join(os.TempDir(), fmt.Sprintf("f%v", i)),
@@ -641,7 +641,7 @@ func TestGetUsersForQuotaCheck(t *testing.T) {
users, err = dataprovider.GetUsersForQuotaCheck(usersToFetch)
assert.NoError(t, err)
assert.Len(t, users, 40)
assert.Len(t, users, 60)
for _, user := range users {
userIdxStr := strings.Replace(user.Username, "user", "", 1)
@@ -665,7 +665,7 @@ func TestGetUsersForQuotaCheck(t *testing.T) {
assert.Equal(t, int64(0), total)
}
for i := 0; i < 40; i++ {
for i := 0; i < 60; i++ {
err = dataprovider.DeleteUser(fmt.Sprintf("user%v", i), "", "", "")
assert.NoError(t, err)
err = dataprovider.DeleteFolder(fmt.Sprintf("f%v", i), "", "", "")