add groups support

Using groups simplifies the administration of multiple accounts by
letting you assign settings once to a group, instead of multiple
times to each individual user.

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino
2022-04-25 15:49:11 +02:00
parent 857b6cc10a
commit 504cd3efda
53 changed files with 6986 additions and 1076 deletions

View File

@@ -11,6 +11,7 @@ import (
"time"
"github.com/cockroachdb/cockroach-go/v2/crdb"
"github.com/sftpgo/sdk"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/util"
@@ -18,13 +19,15 @@ import (
)
const (
sqlDatabaseVersion = 16
sqlDatabaseVersion = 17
defaultSQLQueryTimeout = 10 * time.Second
longSQLQueryTimeout = 60 * time.Second
)
var (
errSQLFoldersAssosaction = errors.New("unable to associate virtual folders to user")
errSQLFoldersAssociation = errors.New("unable to associate virtual folders to user")
errSQLGroupsAssociation = errors.New("unable to associate groups to user")
errSQLUsersAssociation = errors.New("unable to associate users to group")
errSchemaVersionEmpty = errors.New("we can't determine schema version because the schema_migration table is empty. The SFTPGo database might be corrupted. Consider using the \"resetprovider\" sub-command")
)
@@ -36,6 +39,25 @@ type sqlScanner interface {
Scan(dest ...interface{}) error
}
func sqlReplaceAll(sql string) string {
sql = strings.ReplaceAll(sql, "{{schema_version}}", sqlTableSchemaVersion)
sql = strings.ReplaceAll(sql, "{{admins}}", sqlTableAdmins)
sql = strings.ReplaceAll(sql, "{{folders}}", sqlTableFolders)
sql = strings.ReplaceAll(sql, "{{users}}", sqlTableUsers)
sql = strings.ReplaceAll(sql, "{{groups}}", sqlTableGroups)
sql = strings.ReplaceAll(sql, "{{folders_mapping}}", sqlTableFoldersMapping)
sql = strings.ReplaceAll(sql, "{{users_folders_mapping}}", sqlTableUsersFoldersMapping)
sql = strings.ReplaceAll(sql, "{{users_groups_mapping}}", sqlTableUsersGroupsMapping)
sql = strings.ReplaceAll(sql, "{{groups_folders_mapping}}", sqlTableGroupsFoldersMapping)
sql = strings.ReplaceAll(sql, "{{api_keys}}", sqlTableAPIKeys)
sql = strings.ReplaceAll(sql, "{{shares}}", sqlTableShares)
sql = strings.ReplaceAll(sql, "{{defender_events}}", sqlTableDefenderEvents)
sql = strings.ReplaceAll(sql, "{{defender_hosts}}", sqlTableDefenderHosts)
sql = strings.ReplaceAll(sql, "{{active_transfers}}", sqlTableActiveTransfers)
sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
return sql
}
func sqlCommonGetShareByID(shareID, username string, dbHandle sqlQuerier) (Share, error) {
var share Share
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
@@ -169,7 +191,7 @@ func sqlCommonUpdateShare(share *Share, dbHandle *sql.DB) error {
return err
}
func sqlCommonDeleteShare(share *Share, dbHandle *sql.DB) error {
func sqlCommonDeleteShare(share Share, dbHandle *sql.DB) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
@@ -319,7 +341,7 @@ func sqlCommonUpdateAPIKey(apiKey *APIKey, dbHandle *sql.DB) error {
return err
}
func sqlCommonDeleteAPIKey(apiKey *APIKey, dbHandle *sql.DB) error {
func sqlCommonDeleteAPIKey(apiKey APIKey, dbHandle *sql.DB) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getDeleteAPIKeyQuery()
@@ -499,7 +521,7 @@ func sqlCommonUpdateAdmin(admin *Admin, dbHandle *sql.DB) error {
return err
}
func sqlCommonDeleteAdmin(admin *Admin, dbHandle *sql.DB) error {
func sqlCommonDeleteAdmin(admin Admin, dbHandle *sql.DB) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getDeleteAdminQuery()
@@ -574,10 +596,264 @@ func sqlCommonDumpAdmins(dbHandle sqlQuerier) ([]Admin, error) {
return admins, rows.Err()
}
func sqlCommonGetGroupByName(name string, dbHandle sqlQuerier) (Group, error) {
var group Group
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getGroupByNameQuery()
stmt, err := dbHandle.PrepareContext(ctx, q)
if err != nil {
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
return group, err
}
defer stmt.Close()
row := stmt.QueryRowContext(ctx, name)
group, err = getGroupFromDbRow(row)
if err != nil {
return group, err
}
group, err = getGroupWithVirtualFolders(ctx, group, dbHandle)
if err != nil {
return group, err
}
return getGroupWithUsers(ctx, group, dbHandle)
}
func sqlCommonDumpGroups(dbHandle sqlQuerier) ([]Group, error) {
groups := make([]Group, 0, 50)
ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
defer cancel()
q := getDumpGroupsQuery()
stmt, err := dbHandle.PrepareContext(ctx, q)
if err != nil {
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
return nil, err
}
defer stmt.Close()
rows, err := stmt.QueryContext(ctx)
if err != nil {
return groups, err
}
defer rows.Close()
for rows.Next() {
group, err := getGroupFromDbRow(rows)
if err != nil {
return groups, err
}
group.PrepareForRendering()
groups = append(groups, group)
}
err = rows.Err()
return groups, err
}
func sqlCommonGetUsersInGroups(names []string, dbHandle sqlQuerier) ([]string, error) {
if len(names) == 0 {
return nil, nil
}
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getUsersInGroupsQuery(len(names))
stmt, err := dbHandle.PrepareContext(ctx, q)
if err != nil {
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
return nil, err
}
defer stmt.Close()
args := make([]interface{}, 0, len(names))
for _, name := range names {
args = append(args, name)
}
usernames := make([]string, 0, len(names))
rows, err := stmt.QueryContext(ctx, args...)
if err == nil {
defer rows.Close()
for rows.Next() {
var username string
err = rows.Scan(&username)
if err != nil {
return usernames, err
}
usernames = append(usernames, username)
}
}
return usernames, rows.Err()
}
func sqlCommonGetGroupsWithNames(names []string, dbHandle sqlQuerier) ([]Group, error) {
if len(names) == 0 {
return nil, nil
}
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getGroupsWithNamesQuery(len(names))
stmt, err := dbHandle.PrepareContext(ctx, q)
if err != nil {
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
return nil, err
}
defer stmt.Close()
args := make([]interface{}, 0, len(names))
for _, name := range names {
args = append(args, name)
}
groups := make([]Group, 0, len(names))
rows, err := stmt.QueryContext(ctx, args...)
if err == nil {
defer rows.Close()
for rows.Next() {
group, err := getGroupFromDbRow(rows)
if err != nil {
return groups, err
}
groups = append(groups, group)
}
}
err = rows.Err()
if err != nil {
return groups, err
}
return getGroupsWithVirtualFolders(ctx, groups, dbHandle)
}
func sqlCommonGetGroups(limit int, offset int, order string, minimal bool, dbHandle sqlQuerier) ([]Group, error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getGroupsQuery(order, minimal)
stmt, err := dbHandle.PrepareContext(ctx, q)
if err != nil {
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
return nil, err
}
defer stmt.Close()
groups := make([]Group, 0, limit)
rows, err := stmt.QueryContext(ctx, limit, offset)
if err == nil {
defer rows.Close()
for rows.Next() {
var group Group
if minimal {
err = rows.Scan(&group.ID, &group.Name)
} else {
group, err = getGroupFromDbRow(rows)
}
if err != nil {
return groups, err
}
groups = append(groups, group)
}
}
err = rows.Err()
if err != nil {
return groups, err
}
if minimal {
return groups, nil
}
groups, err = getGroupsWithVirtualFolders(ctx, groups, dbHandle)
if err != nil {
return groups, err
}
groups, err = getGroupsWithUsers(ctx, groups, dbHandle)
if err != nil {
return groups, err
}
for idx := range groups {
groups[idx].PrepareForRendering()
}
return groups, nil
}
func sqlCommonAddGroup(group *Group, dbHandle *sql.DB) error {
if err := group.validate(); err != nil {
return err
}
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
q := getAddGroupQuery()
stmt, err := tx.PrepareContext(ctx, q)
if err != nil {
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
return err
}
defer stmt.Close()
settings, err := json.Marshal(group.UserSettings)
if err != nil {
return err
}
_, err = stmt.ExecContext(ctx, group.Name, group.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
util.GetTimeAsMsSinceEpoch(time.Now()), string(settings))
if err != nil {
return err
}
return generateGroupVirtualFoldersMapping(ctx, group, tx)
})
}
func sqlCommonUpdateGroup(group *Group, dbHandle *sql.DB) error {
if err := group.validate(); err != nil {
return err
}
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
q := getUpdateGroupQuery()
stmt, err := tx.PrepareContext(ctx, q)
if err != nil {
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
return err
}
defer stmt.Close()
settings, err := json.Marshal(group.UserSettings)
if err != nil {
return err
}
_, err = stmt.ExecContext(ctx, group.Description, settings, util.GetTimeAsMsSinceEpoch(time.Now()), group.Name)
if err != nil {
return err
}
return generateGroupVirtualFoldersMapping(ctx, group, tx)
})
}
func sqlCommonDeleteGroup(group Group, dbHandle *sql.DB) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getDeleteGroupQuery()
stmt, err := dbHandle.PrepareContext(ctx, q)
if err != nil {
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
return err
}
defer stmt.Close()
_, err = stmt.ExecContext(ctx, group.Name)
return err
}
func sqlCommonGetUserByUsername(username string, dbHandle sqlQuerier) (User, error) {
var user User
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getUserByUsernameQuery()
stmt, err := dbHandle.PrepareContext(ctx, q)
if err != nil {
@@ -591,7 +867,11 @@ func sqlCommonGetUserByUsername(username string, dbHandle sqlQuerier) (User, err
if err != nil {
return user, err
}
return getUserWithVirtualFolders(ctx, user, dbHandle)
user, err = getUserWithVirtualFolders(ctx, user, dbHandle)
if err != nil {
return user, err
}
return getUserWithGroups(ctx, user, dbHandle)
}
func sqlCommonValidateUserAndPass(username, password, ip, protocol string, dbHandle *sql.DB) (User, error) {
@@ -834,7 +1114,10 @@ func sqlCommonAddUser(user *User, dbHandle *sql.DB) error {
if err != nil {
return err
}
return generateVirtualFoldersMapping(ctx, user, tx)
if err := generateUserVirtualFoldersMapping(ctx, user, tx); err != nil {
return err
}
return generateUserGroupMapping(ctx, user, tx)
})
}
@@ -893,11 +1176,14 @@ func sqlCommonUpdateUser(user *User, dbHandle *sql.DB) error {
if err != nil {
return err
}
return generateVirtualFoldersMapping(ctx, user, tx)
if err := generateUserVirtualFoldersMapping(ctx, user, tx); err != nil {
return err
}
return generateUserGroupMapping(ctx, user, tx)
})
}
func sqlCommonDeleteUser(user *User, dbHandle *sql.DB) error {
func sqlCommonDeleteUser(user User, dbHandle *sql.DB) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getDeleteUserQuery()
@@ -943,7 +1229,11 @@ func sqlCommonDumpUsers(dbHandle sqlQuerier) ([]User, error) {
if err != nil {
return users, err
}
return getUsersWithVirtualFolders(ctx, users, dbHandle)
users, err = getUsersWithVirtualFolders(ctx, users, dbHandle)
if err != nil {
return users, err
}
return getUsersWithGroups(ctx, users, dbHandle)
}
func sqlCommonGetRecentlyUpdatedUsers(after int64, dbHandle sqlQuerier) ([]User, error) {
@@ -973,7 +1263,34 @@ func sqlCommonGetRecentlyUpdatedUsers(after int64, dbHandle sqlQuerier) ([]User,
if err != nil {
return users, err
}
return getUsersWithVirtualFolders(ctx, users, dbHandle)
users, err = getUsersWithVirtualFolders(ctx, users, dbHandle)
if err != nil {
return users, err
}
users, err = getUsersWithGroups(ctx, users, dbHandle)
if err != nil {
return users, err
}
var groupNames []string
for _, u := range users {
for _, g := range u.Groups {
groupNames = append(groupNames, g.Name)
}
}
groupNames = util.RemoveDuplicates(groupNames)
groups, err := sqlCommonGetGroupsWithNames(groupNames, dbHandle)
if err != nil {
return users, err
}
groupsMapping := make(map[string]Group)
for _, g := range groups {
groupsMapping[g.Name] = g
}
for idx := range users {
ref := &users[idx]
ref.applyGroupSettings(groupsMapping)
}
return users, nil
}
func sqlCommonGetUsersForQuotaCheck(toFetch map[string]bool, dbHandle sqlQuerier) ([]User, error) {
@@ -1021,6 +1338,29 @@ func sqlCommonGetUsersForQuotaCheck(toFetch map[string]bool, dbHandle sqlQuerier
return users, err
}
users = append(users, usersWithFolders...)
users, err = getUsersWithGroups(ctx, users, dbHandle)
if err != nil {
return users, err
}
var groupNames []string
for _, u := range users {
for _, g := range u.Groups {
groupNames = append(groupNames, g.Name)
}
}
groupNames = util.RemoveDuplicates(groupNames)
groups, err := sqlCommonGetGroupsWithNames(groupNames, dbHandle)
if err != nil {
return users, err
}
groupsMapping := make(map[string]Group)
for _, g := range groups {
groupsMapping[g.Name] = g
}
for idx := range users {
ref := &users[idx]
ref.applyGroupSettings(groupsMapping)
}
return users, nil
}
@@ -1188,7 +1528,6 @@ func sqlCommonGetUsers(limit int, offset int, order string, dbHandle sqlQuerier)
if err != nil {
return users, err
}
u.PrepareForRendering()
users = append(users, u)
}
}
@@ -1196,7 +1535,18 @@ func sqlCommonGetUsers(limit int, offset int, order string, dbHandle sqlQuerier)
if err != nil {
return users, err
}
return getUsersWithVirtualFolders(ctx, users, dbHandle)
users, err = getUsersWithVirtualFolders(ctx, users, dbHandle)
if err != nil {
return users, err
}
users, err = getUsersWithGroups(ctx, users, dbHandle)
if err != nil {
return users, err
}
for idx := range users {
users[idx].PrepareForRendering()
}
return users, nil
}
func sqlCommonGetDefenderHosts(from int64, limit int, dbHandle sqlQuerier) ([]DefenderEntry, error) {
@@ -1572,6 +1922,31 @@ func getAdminFromDbRow(row sqlScanner) (Admin, error) {
return admin, nil
}
func getGroupFromDbRow(row sqlScanner) (Group, error) {
var group Group
var userSettings, description sql.NullString
err := row.Scan(&group.ID, &group.Name, &description, &group.CreatedAt, &group.UpdatedAt, &userSettings)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return group, util.NewRecordNotFoundError(err.Error())
}
return group, err
}
if description.Valid {
group.Description = description.String
}
if userSettings.Valid {
var settings GroupUserSettings
err = json.Unmarshal([]byte(userSettings.String), &settings)
if err == nil {
group.UserSettings = settings
}
}
return group, nil
}
func getUserFromDbRow(row sqlScanner) (User, error) {
var user User
var permissions sql.NullString
@@ -1701,6 +2076,13 @@ func sqlCommonGetFolderByName(ctx context.Context, name string, dbHandle sqlQuer
if len(folders) != 1 {
return folder, fmt.Errorf("unable to associate users with folder %#v", name)
}
folders, err = getVirtualFoldersWithGroups([]vfs.BaseVirtualFolder{folders[0]}, dbHandle)
if err != nil {
return folder, err
}
if len(folders) != 1 {
return folder, fmt.Errorf("unable to associate groups with folder %#v", name)
}
return folders[0], nil
}
@@ -1775,7 +2157,7 @@ func sqlCommonUpdateFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) e
return err
}
func sqlCommonDeleteFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
func sqlCommonDeleteFolder(folder vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getDeleteFolderQuery()
@@ -1829,17 +2211,14 @@ func sqlCommonDumpFolders(dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error)
folders = append(folders, folder)
}
err = rows.Err()
if err != nil {
return folders, err
}
return getVirtualFoldersWithUsers(folders, dbHandle)
return folders, err
}
func sqlCommonGetFolders(limit, offset int, order string, dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) {
func sqlCommonGetFolders(limit, offset int, order string, minimal bool, dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) {
folders := make([]vfs.BaseVirtualFolder, 0, limit)
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getFoldersQuery(order)
q := getFoldersQuery(order, minimal)
stmt, err := dbHandle.PrepareContext(ctx, q)
if err != nil {
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
@@ -1854,23 +2233,30 @@ func sqlCommonGetFolders(limit, offset int, order string, dbHandle sqlQuerier) (
defer rows.Close()
for rows.Next() {
var folder vfs.BaseVirtualFolder
var mappedPath, description, fsConfig sql.NullString
err = rows.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
&folder.LastQuotaUpdate, &folder.Name, &description, &fsConfig)
if err != nil {
return folders, err
}
if mappedPath.Valid {
folder.MappedPath = mappedPath.String
}
if description.Valid {
folder.Description = description.String
}
if fsConfig.Valid {
var fs vfs.Filesystem
err = json.Unmarshal([]byte(fsConfig.String), &fs)
if err == nil {
folder.FsConfig = fs
if minimal {
err = rows.Scan(&folder.ID, &folder.Name)
if err != nil {
return folders, err
}
} else {
var mappedPath, description, fsConfig sql.NullString
err = rows.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
&folder.LastQuotaUpdate, &folder.Name, &description, &fsConfig)
if err != nil {
return folders, err
}
if mappedPath.Valid {
folder.MappedPath = mappedPath.String
}
if description.Valid {
folder.Description = description.String
}
if fsConfig.Valid {
var fs vfs.Filesystem
err = json.Unmarshal([]byte(fsConfig.String), &fs)
if err == nil {
folder.FsConfig = fs
}
}
}
folder.PrepareForRendering()
@@ -1881,11 +2267,18 @@ func sqlCommonGetFolders(limit, offset int, order string, dbHandle sqlQuerier) (
if err != nil {
return folders, err
}
return getVirtualFoldersWithUsers(folders, dbHandle)
if minimal {
return folders, nil
}
folders, err = getVirtualFoldersWithUsers(folders, dbHandle)
if err != nil {
return folders, err
}
return getVirtualFoldersWithGroups(folders, dbHandle)
}
func sqlCommonClearFolderMapping(ctx context.Context, user *User, dbHandle sqlQuerier) error {
q := getClearFolderMappingQuery()
func sqlCommonClearUserFolderMapping(ctx context.Context, user *User, dbHandle sqlQuerier) error {
q := getClearUserFolderMappingQuery()
stmt, err := dbHandle.PrepareContext(ctx, q)
if err != nil {
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
@@ -1896,8 +2289,32 @@ func sqlCommonClearFolderMapping(ctx context.Context, user *User, dbHandle sqlQu
return err
}
func sqlCommonAddFolderMapping(ctx context.Context, user *User, folder *vfs.VirtualFolder, dbHandle sqlQuerier) error {
q := getAddFolderMappingQuery()
func sqlCommonClearGroupFolderMapping(ctx context.Context, group *Group, dbHandle sqlQuerier) error {
q := getClearGroupFolderMappingQuery()
stmt, err := dbHandle.PrepareContext(ctx, q)
if err != nil {
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
return err
}
defer stmt.Close()
_, err = stmt.ExecContext(ctx, group.Name)
return err
}
func sqlCommonClearUserGroupMapping(ctx context.Context, user *User, dbHandle sqlQuerier) error {
q := getClearUserGroupMappingQuery()
stmt, err := dbHandle.PrepareContext(ctx, q)
if err != nil {
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
return err
}
defer stmt.Close()
_, err = stmt.ExecContext(ctx, user.Username)
return err
}
func sqlCommonAddUserFolderMapping(ctx context.Context, user *User, folder *vfs.VirtualFolder, dbHandle sqlQuerier) error {
q := getAddUserFolderMappingQuery()
stmt, err := dbHandle.PrepareContext(ctx, q)
if err != nil {
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
@@ -1908,8 +2325,52 @@ func sqlCommonAddFolderMapping(ctx context.Context, user *User, folder *vfs.Virt
return err
}
func generateVirtualFoldersMapping(ctx context.Context, user *User, dbHandle sqlQuerier) error {
err := sqlCommonClearFolderMapping(ctx, user, dbHandle)
func sqlCommonAddGroupFolderMapping(ctx context.Context, group *Group, folder *vfs.VirtualFolder, dbHandle sqlQuerier) error {
q := getAddGroupFolderMappingQuery()
stmt, err := dbHandle.PrepareContext(ctx, q)
if err != nil {
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
return err
}
defer stmt.Close()
_, err = stmt.ExecContext(ctx, folder.VirtualPath, folder.QuotaSize, folder.QuotaFiles, folder.Name, group.Name)
return err
}
func sqlCommonAddUserGroupMapping(ctx context.Context, username, groupName string, groupType int, dbHandle sqlQuerier) error {
q := getAddUserGroupMappingQuery()
stmt, err := dbHandle.PrepareContext(ctx, q)
if err != nil {
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
return err
}
defer stmt.Close()
_, err = stmt.ExecContext(ctx, username, groupName, groupType)
return err
}
func generateGroupVirtualFoldersMapping(ctx context.Context, group *Group, dbHandle sqlQuerier) error {
err := sqlCommonClearGroupFolderMapping(ctx, group, dbHandle)
if err != nil {
return err
}
for idx := range group.VirtualFolders {
vfolder := &group.VirtualFolders[idx]
f, err := sqlCommonAddOrUpdateFolder(ctx, &vfolder.BaseVirtualFolder, 0, 0, 0, dbHandle)
if err != nil {
return err
}
vfolder.BaseVirtualFolder = f
err = sqlCommonAddGroupFolderMapping(ctx, group, vfolder, dbHandle)
if err != nil {
return err
}
}
return err
}
func generateUserVirtualFoldersMapping(ctx context.Context, user *User, dbHandle sqlQuerier) error {
err := sqlCommonClearUserFolderMapping(ctx, user, dbHandle)
if err != nil {
return err
}
@@ -1920,7 +2381,7 @@ func generateVirtualFoldersMapping(ctx context.Context, user *User, dbHandle sql
return err
}
vfolder.BaseVirtualFolder = f
err = sqlCommonAddFolderMapping(ctx, user, vfolder, dbHandle)
err = sqlCommonAddUserFolderMapping(ctx, user, vfolder, dbHandle)
if err != nil {
return err
}
@@ -1928,15 +2389,18 @@ func generateVirtualFoldersMapping(ctx context.Context, user *User, dbHandle sql
return err
}
func getUserWithVirtualFolders(ctx context.Context, user User, dbHandle sqlQuerier) (User, error) {
users, err := getUsersWithVirtualFolders(ctx, []User{user}, dbHandle)
func generateUserGroupMapping(ctx context.Context, user *User, dbHandle sqlQuerier) error {
err := sqlCommonClearUserGroupMapping(ctx, user, dbHandle)
if err != nil {
return user, err
return err
}
if len(users) == 0 {
return user, errSQLFoldersAssosaction
for _, group := range user.Groups {
err = sqlCommonAddUserGroupMapping(ctx, user.Username, group.Name, group.Type, dbHandle)
if err != nil {
return err
}
}
return users[0], err
return err
}
func getDefenderHostsWithScores(ctx context.Context, hosts []DefenderEntry, from int64, idForScores []int64,
@@ -1994,6 +2458,17 @@ func getDefenderHostsWithScores(ctx context.Context, hosts []DefenderEntry, from
return result, nil
}
func getUserWithVirtualFolders(ctx context.Context, user User, dbHandle sqlQuerier) (User, error) {
users, err := getUsersWithVirtualFolders(ctx, []User{user}, dbHandle)
if err != nil {
return user, err
}
if len(users) == 0 {
return user, errSQLFoldersAssociation
}
return users[0], err
}
func getUsersWithVirtualFolders(ctx context.Context, users []User, dbHandle sqlQuerier) ([]User, error) {
if len(users) == 0 {
return users, nil
@@ -2052,13 +2527,232 @@ func getUsersWithVirtualFolders(ctx context.Context, users []User, dbHandle sqlQ
return users, err
}
func getUserWithGroups(ctx context.Context, user User, dbHandle sqlQuerier) (User, error) {
users, err := getUsersWithGroups(ctx, []User{user}, dbHandle)
if err != nil {
return user, err
}
if len(users) == 0 {
return user, errSQLGroupsAssociation
}
return users[0], err
}
func getUsersWithGroups(ctx context.Context, users []User, dbHandle sqlQuerier) ([]User, error) {
if len(users) == 0 {
return users, nil
}
var err error
usersGroups := make(map[int64][]sdk.GroupMapping)
q := getRelatedGroupsForUsersQuery(users)
stmt, err := dbHandle.PrepareContext(ctx, q)
if err != nil {
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
return nil, err
}
defer stmt.Close()
rows, err := stmt.QueryContext(ctx)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var group sdk.GroupMapping
var userID int64
err = rows.Scan(&group.Name, &group.Type, &userID)
if err != nil {
return users, err
}
usersGroups[userID] = append(usersGroups[userID], group)
}
err = rows.Err()
if err != nil {
return users, err
}
if len(usersGroups) == 0 {
return users, err
}
for idx := range users {
ref := &users[idx]
ref.Groups = usersGroups[ref.ID]
}
return users, err
}
func getGroupWithUsers(ctx context.Context, group Group, dbHandle sqlQuerier) (Group, error) {
groups, err := getGroupsWithUsers(ctx, []Group{group}, dbHandle)
if err != nil {
return group, err
}
if len(groups) == 0 {
return group, errSQLUsersAssociation
}
return groups[0], err
}
func getGroupWithVirtualFolders(ctx context.Context, group Group, dbHandle sqlQuerier) (Group, error) {
groups, err := getGroupsWithVirtualFolders(ctx, []Group{group}, dbHandle)
if err != nil {
return group, err
}
if len(groups) == 0 {
return group, errSQLFoldersAssociation
}
return groups[0], err
}
func getGroupsWithVirtualFolders(ctx context.Context, groups []Group, dbHandle sqlQuerier) ([]Group, error) {
if len(groups) == 0 {
return groups, nil
}
var err error
q := getRelatedFoldersForGroupsQuery(groups)
stmt, err := dbHandle.PrepareContext(ctx, q)
if err != nil {
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
return nil, err
}
defer stmt.Close()
rows, err := stmt.QueryContext(ctx)
if err != nil {
return nil, err
}
defer rows.Close()
groupsVirtualFolders := make(map[int64][]vfs.VirtualFolder)
for rows.Next() {
var groupID int64
var folder vfs.VirtualFolder
var mappedPath, fsConfig, description sql.NullString
err = rows.Scan(&folder.ID, &folder.Name, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
&folder.LastQuotaUpdate, &folder.VirtualPath, &folder.QuotaSize, &folder.QuotaFiles, &groupID, &fsConfig,
&description)
if err != nil {
return groups, err
}
if mappedPath.Valid {
folder.MappedPath = mappedPath.String
}
if description.Valid {
folder.Description = description.String
}
if fsConfig.Valid {
var fs vfs.Filesystem
err = json.Unmarshal([]byte(fsConfig.String), &fs)
if err == nil {
folder.FsConfig = fs
}
}
groupsVirtualFolders[groupID] = append(groupsVirtualFolders[groupID], folder)
}
err = rows.Err()
if err != nil {
return groups, err
}
if len(groupsVirtualFolders) == 0 {
return groups, err
}
for idx := range groups {
ref := &groups[idx]
ref.VirtualFolders = groupsVirtualFolders[ref.ID]
}
return groups, err
}
func getGroupsWithUsers(ctx context.Context, groups []Group, dbHandle sqlQuerier) ([]Group, error) {
if len(groups) == 0 {
return groups, nil
}
var err error
q := getRelatedUsersForGroupsQuery(groups)
stmt, err := dbHandle.PrepareContext(ctx, q)
if err != nil {
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
return nil, err
}
defer stmt.Close()
rows, err := stmt.QueryContext(ctx)
if err != nil {
return nil, err
}
defer rows.Close()
groupsUsers := make(map[int64][]string)
for rows.Next() {
var username string
var groupID int64
err = rows.Scan(&groupID, &username)
if err != nil {
return groups, err
}
groupsUsers[groupID] = append(groupsUsers[groupID], username)
}
err = rows.Err()
if err != nil {
return groups, err
}
if len(groupsUsers) == 0 {
return groups, err
}
for idx := range groups {
ref := &groups[idx]
ref.Users = groupsUsers[ref.ID]
}
return groups, err
}
func getVirtualFoldersWithGroups(folders []vfs.BaseVirtualFolder, dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) {
if len(folders) == 0 {
return folders, nil
}
var err error
vFoldersGroups := make(map[int64][]string)
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getRelatedGroupsForFoldersQuery(folders)
stmt, err := dbHandle.PrepareContext(ctx, q)
if err != nil {
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
return nil, err
}
defer stmt.Close()
rows, err := stmt.QueryContext(ctx)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var name string
var folderID int64
err = rows.Scan(&folderID, &name)
if err != nil {
return folders, err
}
vFoldersGroups[folderID] = append(vFoldersGroups[folderID], name)
}
err = rows.Err()
if err != nil {
return folders, err
}
if len(vFoldersGroups) == 0 {
return folders, err
}
for idx := range folders {
ref := &folders[idx]
ref.Groups = vFoldersGroups[ref.ID]
}
return folders, err
}
func getVirtualFoldersWithUsers(folders []vfs.BaseVirtualFolder, dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) {
if len(folders) == 0 {
return folders, nil
}
var err error
vFoldersUsers := make(map[int64][]string)
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getRelatedUsersForFoldersQuery(folders)
@@ -2073,6 +2767,8 @@ func getVirtualFoldersWithUsers(folders []vfs.BaseVirtualFolder, dbHandle sqlQue
return nil, err
}
defer rows.Close()
vFoldersUsers := make(map[int64][]string)
for rows.Next() {
var username string
var folderID int64