SQL providers: make sure we don't exceed the allowed placeholders

Fixes #1415

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino
2023-09-12 19:16:54 +02:00
parent 9906caefd5
commit cf1cc25a48
8 changed files with 207 additions and 66 deletions

View File

@@ -958,63 +958,86 @@ func sqlCommonGetUsersInGroups(names []string, dbHandle sqlQuerier) ([]string, e
if len(names) == 0 {
return nil, nil
}
maxNames := len(sqlPlaceholders)
usernames := make([]string, 0, len(names))
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getUsersInGroupsQuery(len(names))
args := make([]any, 0, len(names))
for _, name := range names {
args = append(args, name)
}
for len(names) > 0 {
if maxNames > len(names) {
maxNames = len(names)
}
usernames := make([]string, 0, len(names))
rows, err := dbHandle.QueryContext(ctx, q, args...)
if err != nil {
return nil, err
}
defer rows.Close()
q := getUsersInGroupsQuery(maxNames)
args := make([]any, 0, maxNames)
for _, name := range names[:maxNames] {
args = append(args, name)
}
for rows.Next() {
var username string
err = rows.Scan(&username)
rows, err := dbHandle.QueryContext(ctx, q, args...)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var username string
err = rows.Scan(&username)
if err != nil {
return usernames, err
}
usernames = append(usernames, username)
}
err = rows.Err()
if err != nil {
return usernames, err
}
usernames = append(usernames, username)
names = names[maxNames:]
}
return usernames, rows.Err()
return usernames, nil
}
func sqlCommonGetGroupsWithNames(names []string, dbHandle sqlQuerier) ([]Group, error) {
if len(names) == 0 {
return nil, nil
}
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getGroupsWithNamesQuery(len(names))
args := make([]any, 0, len(names))
for _, name := range names {
args = append(args, name)
}
maxNames := len(sqlPlaceholders)
groups := make([]Group, 0, len(names))
rows, err := dbHandle.QueryContext(ctx, q, args...)
if err != nil {
return groups, err
}
defer rows.Close()
for len(names) > 0 {
if maxNames > len(names) {
maxNames = len(names)
}
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
for rows.Next() {
group, err := getGroupFromDbRow(rows)
q := getGroupsWithNamesQuery(maxNames)
args := make([]any, 0, maxNames)
for _, name := range names[:maxNames] {
args = append(args, name)
}
rows, err := dbHandle.QueryContext(ctx, q, args...)
if err != nil {
return groups, err
}
groups = append(groups, group)
}
err = rows.Err()
if err != nil {
return groups, err
defer rows.Close()
for rows.Next() {
group, err := getGroupFromDbRow(rows)
if err != nil {
return groups, err
}
groups = append(groups, group)
}
err = rows.Err()
if err != nil {
return groups, err
}
names = names[maxNames:]
}
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
return getGroupsWithVirtualFolders(ctx, groups, dbHandle)
}
@@ -1535,6 +1558,9 @@ func sqlCommonGetRecentlyUpdatedUsers(after int64, dbHandle sqlQuerier) ([]User,
}
}
groupNames = util.RemoveDuplicates(groupNames, false)
if len(groupNames) == 0 {
return users, nil
}
groups, err := sqlCommonGetGroupsWithNames(groupNames, dbHandle)
if err != nil {
return users, err
@@ -1553,15 +1579,23 @@ func sqlCommonGetRecentlyUpdatedUsers(after int64, dbHandle sqlQuerier) ([]User,
return users, nil
}
func sqlGetMaxUsersForQuotaCheckRange() int {
maxUsers := 50
if maxUsers > len(sqlPlaceholders) {
maxUsers = len(sqlPlaceholders)
}
return maxUsers
}
func sqlCommonGetUsersForQuotaCheck(toFetch map[string]bool, dbHandle sqlQuerier) ([]User, error) {
users := make([]User, 0, 30)
maxUsers := sqlGetMaxUsersForQuotaCheckRange()
users := make([]User, 0, maxUsers)
usernames := make([]string, 0, len(toFetch))
for k := range toFetch {
usernames = append(usernames, k)
}
maxUsers := 30
for len(usernames) > 0 {
if maxUsers > len(usernames) {
maxUsers = len(usernames)