diff --git a/internal/dataprovider/bolt.go b/internal/dataprovider/bolt.go index c90c2f08..40d0f88a 100644 --- a/internal/dataprovider/bolt.go +++ b/internal/dataprovider/bolt.go @@ -39,7 +39,7 @@ import ( ) const ( - boltDatabaseVersion = 32 + boltDatabaseVersion = 33 ) var ( @@ -446,9 +446,6 @@ func (p *BoltProvider) addAdmin(admin *Admin) error { admin.LastLogin = 0 admin.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) admin.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) - sort.Slice(admin.Groups, func(i, j int) bool { - return admin.Groups[i].Name < admin.Groups[j].Name - }) for idx := range admin.Groups { err = p.addAdminToGroupMapping(admin.Username, admin.Groups[idx].Name, groupBucket) if err != nil { @@ -507,9 +504,6 @@ func (p *BoltProvider) updateAdmin(admin *Admin) error { if err = p.addAdminToRole(admin.Username, admin.Role, rolesBucket); err != nil { return err } - sort.Slice(admin.Groups, func(i, j int) bool { - return admin.Groups[i].Name < admin.Groups[j].Name - }) for idx := range admin.Groups { err = p.addAdminToGroupMapping(admin.Username, admin.Groups[idx].Name, groupBucket) if err != nil { @@ -721,18 +715,12 @@ func (p *BoltProvider) addUser(user *User) error { if err := p.addUserToRole(user.Username, user.Role, rolesBucket); err != nil { return err } - sort.Slice(user.VirtualFolders, func(i, j int) bool { - return user.VirtualFolders[i].Name < user.VirtualFolders[j].Name - }) for idx := range user.VirtualFolders { err = p.addRelationToFolderMapping(user.VirtualFolders[idx].Name, user, nil, foldersBucket) if err != nil { return err } } - sort.Slice(user.Groups, func(i, j int) bool { - return user.Groups[i].Name < user.Groups[j].Name - }) for idx := range user.Groups { err = p.addUserToGroupMapping(user.Username, user.Groups[idx].Name, groupBucket) if err != nil { @@ -1504,9 +1492,6 @@ func (p *BoltProvider) addGroup(group *Group) error { group.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) group.Users = nil group.Admins = nil - sort.Slice(group.VirtualFolders, func(i, j int) bool { - return group.VirtualFolders[i].Name < group.VirtualFolders[j].Name - }) for idx := range group.VirtualFolders { err = p.addRelationToFolderMapping(group.VirtualFolders[idx].Name, nil, group, foldersBucket) if err != nil { @@ -1549,9 +1534,6 @@ func (p *BoltProvider) updateGroup(group *Group) error { return err } } - sort.Slice(group.VirtualFolders, func(i, j int) bool { - return group.VirtualFolders[i].Name < group.VirtualFolders[j].Name - }) for idx := range group.VirtualFolders { err = p.addRelationToFolderMapping(group.VirtualFolders[idx].Name, nil, group, foldersBucket) if err != nil { @@ -3185,13 +3167,15 @@ func (p *BoltProvider) migrateDatabase() error { providerLog(logger.LevelError, "%v", err) logger.ErrorToConsole("%v", err) return err - case version == 29, version == 30, version == 31: - logger.InfoToConsole("updating database schema version: %d -> 32", version) - providerLog(logger.LevelInfo, "updating database schema version: %d -> 32", version) - if err := updateEventActions(); err != nil { - return err + case version == 29, version == 30, version == 31, version == 32: + logger.InfoToConsole("updating database schema version: %d -> 33", version) + providerLog(logger.LevelInfo, "updating database schema version: %d -> 33", version) + if version <= 31 { + if err := updateEventActions(); err != nil { + return err + } } - return updateBoltDatabaseVersion(p.dbHandle, 32) + return updateBoltDatabaseVersion(p.dbHandle, 33) default: if version > boltDatabaseVersion { providerLog(logger.LevelError, "database schema version %d is newer than the supported one: %d", version, @@ -3213,10 +3197,10 @@ func (p *BoltProvider) revertDatabase(targetVersion int) error { //nolint:gocycl return errors.New("current version match target version, nothing to do") } switch dbVersion.Version { - case 30, 31, 32: + case 30, 31, 32, 33: logger.InfoToConsole("downgrading database schema version: %d -> 29", dbVersion.Version) providerLog(logger.LevelInfo, "downgrading database schema version: %d -> 29", dbVersion.Version) - if dbVersion.Version == 32 { + if dbVersion.Version >= 32 { if err := restoreEventActions(); err != nil { return err } @@ -3745,18 +3729,12 @@ func (p *BoltProvider) updateUserRelations(tx *bolt.Tx, user *User, oldUser User if err = p.removeUserFromRole(oldUser.Username, oldUser.Role, rolesBucket); err != nil { return err } - sort.Slice(user.VirtualFolders, func(i, j int) bool { - return user.VirtualFolders[i].Name < user.VirtualFolders[j].Name - }) for idx := range user.VirtualFolders { err = p.addRelationToFolderMapping(user.VirtualFolders[idx].Name, user, nil, foldersBucket) if err != nil { return err } } - sort.Slice(user.Groups, func(i, j int) bool { - return user.Groups[i].Name < user.Groups[j].Name - }) for idx := range user.Groups { err = p.addUserToGroupMapping(user.Username, user.Groups[idx].Name, groupsBucket) if err != nil { diff --git a/internal/dataprovider/dataprovider.go b/internal/dataprovider/dataprovider.go index 8f3df50a..805c68e3 100644 --- a/internal/dataprovider/dataprovider.go +++ b/internal/dataprovider/dataprovider.go @@ -113,7 +113,6 @@ const ( operationDelete = "delete" sqlPrefixValidChars = "abcdefghijklmnopqrstuvwxyz_0123456789" maxHookResponseSize = 1048576 // 1MB - iso8601UTCFormat = "2006-01-02T15:04:05Z" ) // Supported algorithms for hashing passwords. diff --git a/internal/dataprovider/memory.go b/internal/dataprovider/memory.go index 8c7c0c48..e8595291 100644 --- a/internal/dataprovider/memory.go +++ b/internal/dataprovider/memory.go @@ -376,9 +376,6 @@ func (p *MemoryProvider) addUser(user *User) error { if err := p.addUserToRole(user.Username, user.Role); err != nil { return err } - sort.Slice(user.Groups, func(i, j int) bool { - return user.Groups[i].Name < user.Groups[j].Name - }) var mappedGroups []string for idx := range user.Groups { if err = p.addUserToGroupMapping(user.Username, user.Groups[idx].Name); err != nil { @@ -390,9 +387,6 @@ func (p *MemoryProvider) addUser(user *User) error { } mappedGroups = append(mappedGroups, user.Groups[idx].Name) } - sort.Slice(user.VirtualFolders, func(i, j int) bool { - return user.VirtualFolders[i].Name < user.VirtualFolders[j].Name - }) var mappedFolders []string for idx := range user.VirtualFolders { if err = p.addUserToFolderMapping(user.Username, user.VirtualFolders[idx].Name); err != nil { @@ -438,9 +432,6 @@ func (p *MemoryProvider) updateUser(user *User) error { //nolint:gocyclo for idx := range u.Groups { p.removeUserFromGroupMapping(u.Username, u.Groups[idx].Name) } - sort.Slice(user.Groups, func(i, j int) bool { - return user.Groups[i].Name < user.Groups[j].Name - }) for idx := range user.Groups { if err = p.addUserToGroupMapping(user.Username, user.Groups[idx].Name); err != nil { // try to add old mapping @@ -456,9 +447,6 @@ func (p *MemoryProvider) updateUser(user *User) error { //nolint:gocyclo for _, oldFolder := range u.VirtualFolders { p.removeRelationFromFolderMapping(oldFolder.Name, u.Username, "") } - sort.Slice(user.VirtualFolders, func(i, j int) bool { - return user.VirtualFolders[i].Name < user.VirtualFolders[j].Name - }) for idx := range user.VirtualFolders { if err = p.addUserToFolderMapping(user.Username, user.VirtualFolders[idx].Name); err != nil { // try to add old mapping @@ -771,9 +759,6 @@ func (p *MemoryProvider) addAdmin(admin *Admin) error { return err } var mappedAdmins []string - sort.Slice(admin.Groups, func(i, j int) bool { - return admin.Groups[i].Name < admin.Groups[j].Name - }) for idx := range admin.Groups { if err = p.addAdminToGroupMapping(admin.Username, admin.Groups[idx].Name); err != nil { // try to remove group mapping @@ -816,9 +801,6 @@ func (p *MemoryProvider) updateAdmin(admin *Admin) error { for idx := range a.Groups { p.removeAdminFromGroupMapping(a.Username, a.Groups[idx].Name) } - sort.Slice(admin.Groups, func(i, j int) bool { - return admin.Groups[i].Name < admin.Groups[j].Name - }) for idx := range admin.Groups { if err = p.addAdminToGroupMapping(admin.Username, admin.Groups[idx].Name); err != nil { // try to add old mapping @@ -1082,9 +1064,6 @@ func (p *MemoryProvider) addGroup(group *Group) error { group.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) group.Users = nil group.Admins = nil - sort.Slice(group.VirtualFolders, func(i, j int) bool { - return group.VirtualFolders[i].Name < group.VirtualFolders[j].Name - }) var mappedFolders []string for idx := range group.VirtualFolders { if err = p.addGroupToFolderMapping(group.Name, group.VirtualFolders[idx].Name); err != nil { @@ -1118,9 +1097,6 @@ func (p *MemoryProvider) updateGroup(group *Group) error { for _, oldFolder := range g.VirtualFolders { p.removeRelationFromFolderMapping(oldFolder.Name, "", g.Name) } - sort.Slice(group.VirtualFolders, func(i, j int) bool { - return group.VirtualFolders[i].Name < group.VirtualFolders[j].Name - }) for idx := range group.VirtualFolders { if err = p.addGroupToFolderMapping(group.Name, group.VirtualFolders[idx].Name); err != nil { // try to add old mapping diff --git a/internal/dataprovider/mysql.go b/internal/dataprovider/mysql.go index fabbd07d..12400c28 100644 --- a/internal/dataprovider/mysql.go +++ b/internal/dataprovider/mysql.go @@ -205,6 +205,22 @@ const ( "`data` longtext NOT NULL, `type` integer NOT NULL, `timestamp` bigint NOT NULL);" + "CREATE INDEX `{{prefix}}shared_sessions_type_idx` ON `{{shared_sessions}}` (`type`);" + "CREATE INDEX `{{prefix}}shared_sessions_timestamp_idx` ON `{{shared_sessions}}` (`timestamp`);" + mysqlV33SQL = "ALTER TABLE `{{admins_groups_mapping}}` ADD COLUMN `sort_order` integer DEFAULT 0 NOT NULL; " + + "ALTER TABLE `{{admins_groups_mapping}}` ALTER COLUMN `sort_order` DROP DEFAULT; " + + "ALTER TABLE `{{groups_folders_mapping}}` ADD COLUMN `sort_order` integer DEFAULT 0 NOT NULL; " + + "ALTER TABLE `{{groups_folders_mapping}}` ALTER COLUMN `sort_order` DROP DEFAULT; " + + "ALTER TABLE `{{users_folders_mapping}}` ADD COLUMN `sort_order` integer DEFAULT 0 NOT NULL; " + + "ALTER TABLE `{{users_folders_mapping}}` ALTER COLUMN `sort_order` DROP DEFAULT; " + + "ALTER TABLE `{{users_groups_mapping}}` ADD COLUMN `sort_order` integer DEFAULT 0 NOT NULL; " + + "ALTER TABLE `{{users_groups_mapping}}` ALTER COLUMN `sort_order` DROP DEFAULT; " + + "CREATE INDEX `{{prefix}}admins_groups_mapping_sort_order_idx` ON `{{admins_groups_mapping}}` (`sort_order`); " + + "CREATE INDEX `{{prefix}}groups_folders_mapping_sort_order_idx` ON `{{groups_folders_mapping}}` (`sort_order`); " + + "CREATE INDEX `{{prefix}}users_folders_mapping_sort_order_idx` ON `{{users_folders_mapping}}` (`sort_order`);" + + "CREATE INDEX `{{prefix}}users_groups_mapping_sort_order_idx` ON `{{users_groups_mapping}}` (`sort_order`);" + mysqlV33DownSQL = "ALTER TABLE `{{users_groups_mapping}}` DROP COLUMN `sort_order`; " + + "ALTER TABLE `{{users_folders_mapping}}` DROP COLUMN `sort_order`; " + + "ALTER TABLE `{{groups_folders_mapping}}` DROP COLUMN `sort_order`; " + + "ALTER TABLE `{{admins_groups_mapping}}` DROP COLUMN `sort_order`; " ) // MySQLProvider defines the auth provider for MySQL/MariaDB database @@ -819,6 +835,8 @@ func (p *MySQLProvider) migrateDatabase() error { return updateMySQLDatabaseFromV30(p.dbHandle) case version == 31: return updateMySQLDatabaseFromV31(p.dbHandle) + case version == 32: + return updateMySQLDatabaseFromV32(p.dbHandle) default: if version > sqlDatabaseVersion { providerLog(logger.LevelError, "database schema version %d is newer than the supported one: %d", version, @@ -847,6 +865,8 @@ func (p *MySQLProvider) revertDatabase(targetVersion int) error { return downgradeMySQLDatabaseFromV31(p.dbHandle) case 32: return downgradeMySQLDatabaseFromV32(p.dbHandle) + case 33: + return downgradeMySQLDatabaseFromV33(p.dbHandle) default: return fmt.Errorf("database schema version not handled: %d", dbVersion.Version) } @@ -900,7 +920,14 @@ func updateMySQLDatabaseFromV30(dbHandle *sql.DB) error { } func updateMySQLDatabaseFromV31(dbHandle *sql.DB) error { - return updateSQLDatabaseFrom31To32(dbHandle) + if err := updateSQLDatabaseFrom31To32(dbHandle); err != nil { + return err + } + return updateMySQLDatabaseFromV32(dbHandle) +} + +func updateMySQLDatabaseFromV32(dbHandle *sql.DB) error { + return updateMySQLDatabaseFrom32To33(dbHandle) } func downgradeMySQLDatabaseFromV30(dbHandle *sql.DB) error { @@ -921,6 +948,13 @@ func downgradeMySQLDatabaseFromV32(dbHandle *sql.DB) error { return downgradeMySQLDatabaseFromV31(dbHandle) } +func downgradeMySQLDatabaseFromV33(dbHandle *sql.DB) error { + if err := downgradeMySQLDatabaseFrom33To32(dbHandle); err != nil { + return err + } + return downgradeMySQLDatabaseFromV32(dbHandle) +} + func updateMySQLDatabaseFrom29To30(dbHandle *sql.DB) error { logger.InfoToConsole("updating database schema version: 29 -> 30") providerLog(logger.LevelInfo, "updating database schema version: 29 -> 30") @@ -954,3 +988,28 @@ func downgradeMySQLDatabaseFrom31To30(dbHandle *sql.DB) error { sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix) return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 30, false) } + +func updateMySQLDatabaseFrom32To33(dbHandle *sql.DB) error { + logger.InfoToConsole("updating database schema version: 32 -> 33") + providerLog(logger.LevelInfo, "updating database schema version: 32 -> 33") + + sql := strings.ReplaceAll(mysqlV33SQL, "{{prefix}}", config.SQLTablesPrefix) + sql = strings.ReplaceAll(sql, "{{users_folders_mapping}}", sqlTableUsersFoldersMapping) + sql = strings.ReplaceAll(sql, "{{users_groups_mapping}}", sqlTableUsersGroupsMapping) + sql = strings.ReplaceAll(sql, "{{admins_groups_mapping}}", sqlTableAdminsGroupsMapping) + sql = strings.ReplaceAll(sql, "{{groups_folders_mapping}}", sqlTableGroupsFoldersMapping) + sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 33, true) +} + +func downgradeMySQLDatabaseFrom33To32(dbHandle *sql.DB) error { + logger.InfoToConsole("downgrading database schema version: 33 -> 32") + providerLog(logger.LevelInfo, "downgrading database schema version: 33 -> 32") + + sql := mysqlV33DownSQL + sql = strings.ReplaceAll(sql, "{{users_folders_mapping}}", sqlTableUsersFoldersMapping) + sql = strings.ReplaceAll(sql, "{{users_groups_mapping}}", sqlTableUsersGroupsMapping) + sql = strings.ReplaceAll(sql, "{{admins_groups_mapping}}", sqlTableAdminsGroupsMapping) + sql = strings.ReplaceAll(sql, "{{groups_folders_mapping}}", sqlTableGroupsFoldersMapping) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 32, false) +} diff --git a/internal/dataprovider/pgsql.go b/internal/dataprovider/pgsql.go index ae38e9ed..ea5c2156 100644 --- a/internal/dataprovider/pgsql.go +++ b/internal/dataprovider/pgsql.go @@ -221,7 +221,26 @@ CREATE INDEX "{{prefix}}shared_sessions_timestamp_idx" ON "{{shared_sessions}}" CREATE TABLE "{{shared_sessions}}" ("key" varchar(128) NOT NULL PRIMARY KEY, "data" text NOT NULL, "type" integer NOT NULL, "timestamp" bigint NOT NULL); CREATE INDEX "{{prefix}}shared_sessions_type_idx" ON "{{shared_sessions}}" ("type"); -CREATE INDEX "{{prefix}}shared_sessions_timestamp_idx" ON "{{shared_sessions}}" ("timestamp");` +CREATE INDEX "{{prefix}}shared_sessions_timestamp_idx" ON "{{shared_sessions}}" ("timestamp"); +` + pgsqlV33SQL = `ALTER TABLE "{{admins_groups_mapping}}" ADD COLUMN "sort_order" integer DEFAULT 0 NOT NULL; +ALTER TABLE "{{admins_groups_mapping}}" ALTER COLUMN "sort_order" DROP DEFAULT; +ALTER TABLE "{{groups_folders_mapping}}" ADD COLUMN "sort_order" integer DEFAULT 0 NOT NULL; +ALTER TABLE "{{groups_folders_mapping}}" ALTER COLUMN "sort_order" DROP DEFAULT; +ALTER TABLE "{{users_folders_mapping}}" ADD COLUMN "sort_order" integer DEFAULT 0 NOT NULL; +ALTER TABLE "{{users_folders_mapping}}" ALTER COLUMN "sort_order" DROP DEFAULT; +ALTER TABLE "{{users_groups_mapping}}" ADD COLUMN "sort_order" integer DEFAULT 0 NOT NULL; +ALTER TABLE "{{users_groups_mapping}}" ALTER COLUMN "sort_order" DROP DEFAULT; +CREATE INDEX "{{prefix}}admins_groups_mapping_sort_order_idx" ON "{{admins_groups_mapping}}" ("sort_order"); +CREATE INDEX "{{prefix}}groups_folders_mapping_sort_order_idx" ON "{{groups_folders_mapping}}" ("sort_order"); +CREATE INDEX "{{prefix}}users_folders_mapping_sort_order_idx" ON "{{users_folders_mapping}}" ("sort_order"); +CREATE INDEX "{{prefix}}users_groups_mapping_sort_order_idx" ON "{{users_groups_mapping}}" ("sort_order"); +` + pgsqlV33DownSQL = `ALTER TABLE "{{users_groups_mapping}}" DROP COLUMN "sort_order" CASCADE; +ALTER TABLE "{{users_folders_mapping}}" DROP COLUMN "sort_order" CASCADE; +ALTER TABLE "{{groups_folders_mapping}}" DROP COLUMN "sort_order" CASCADE; +ALTER TABLE "{{admins_groups_mapping}}" DROP COLUMN "sort_order" CASCADE; +` ) var ( @@ -844,6 +863,8 @@ func (p *PGSQLProvider) migrateDatabase() error { //nolint:dupl return updatePGSQLDatabaseFromV30(p.dbHandle) case version == 31: return updatePGSQLDatabaseFromV31(p.dbHandle) + case version == 32: + return updatePGSQLDatabaseFromV32(p.dbHandle) default: if version > sqlDatabaseVersion { providerLog(logger.LevelError, "database schema version %d is newer than the supported one: %d", version, @@ -872,6 +893,8 @@ func (p *PGSQLProvider) revertDatabase(targetVersion int) error { return downgradePGSQLDatabaseFromV31(p.dbHandle) case 32: return downgradePGSQLDatabaseFromV32(p.dbHandle) + case 33: + return downgradePGSQLDatabaseFromV33(p.dbHandle) default: return fmt.Errorf("database schema version not handled: %d", dbVersion.Version) } @@ -925,7 +948,14 @@ func updatePGSQLDatabaseFromV30(dbHandle *sql.DB) error { } func updatePGSQLDatabaseFromV31(dbHandle *sql.DB) error { - return updateSQLDatabaseFrom31To32(dbHandle) + if err := updateSQLDatabaseFrom31To32(dbHandle); err != nil { + return err + } + return updatePGSQLDatabaseFromV32(dbHandle) +} + +func updatePGSQLDatabaseFromV32(dbHandle *sql.DB) error { + return updatePGSQLDatabaseFrom32To33(dbHandle) } func downgradePGSQLDatabaseFromV30(dbHandle *sql.DB) error { @@ -946,6 +976,13 @@ func downgradePGSQLDatabaseFromV32(dbHandle *sql.DB) error { return downgradePGSQLDatabaseFromV31(dbHandle) } +func downgradePGSQLDatabaseFromV33(dbHandle *sql.DB) error { + if err := downgradePGSQLDatabaseFrom33To32(dbHandle); err != nil { + return err + } + return downgradePGSQLDatabaseFromV32(dbHandle) +} + func updatePGSQLDatabaseFrom29To30(dbHandle *sql.DB) error { logger.InfoToConsole("updating database schema version: 29 -> 30") providerLog(logger.LevelInfo, "updating database schema version: 29 -> 30") @@ -979,3 +1016,28 @@ func downgradePGSQLDatabaseFrom31To30(dbHandle *sql.DB) error { sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix) return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 30, false) } + +func updatePGSQLDatabaseFrom32To33(dbHandle *sql.DB) error { + logger.InfoToConsole("updating database schema version: 32 -> 33") + providerLog(logger.LevelInfo, "updating database schema version: 32 -> 33") + + sql := strings.ReplaceAll(pgsqlV33SQL, "{{prefix}}", config.SQLTablesPrefix) + sql = strings.ReplaceAll(sql, "{{users_folders_mapping}}", sqlTableUsersFoldersMapping) + sql = strings.ReplaceAll(sql, "{{users_groups_mapping}}", sqlTableUsersGroupsMapping) + sql = strings.ReplaceAll(sql, "{{admins_groups_mapping}}", sqlTableAdminsGroupsMapping) + sql = strings.ReplaceAll(sql, "{{groups_folders_mapping}}", sqlTableGroupsFoldersMapping) + sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 33, true) +} + +func downgradePGSQLDatabaseFrom33To32(dbHandle *sql.DB) error { + logger.InfoToConsole("downgrading database schema version: 33 -> 32") + providerLog(logger.LevelInfo, "downgrading database schema version: 33 -> 32") + + sql := pgsqlV33DownSQL + sql = strings.ReplaceAll(sql, "{{users_folders_mapping}}", sqlTableUsersFoldersMapping) + sql = strings.ReplaceAll(sql, "{{users_groups_mapping}}", sqlTableUsersGroupsMapping) + sql = strings.ReplaceAll(sql, "{{admins_groups_mapping}}", sqlTableAdminsGroupsMapping) + sql = strings.ReplaceAll(sql, "{{groups_folders_mapping}}", sqlTableGroupsFoldersMapping) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 32, false) +} diff --git a/internal/dataprovider/sqlcommon.go b/internal/dataprovider/sqlcommon.go index aba4675f..541b5396 100644 --- a/internal/dataprovider/sqlcommon.go +++ b/internal/dataprovider/sqlcommon.go @@ -36,7 +36,7 @@ import ( ) const ( - sqlDatabaseVersion = 32 + sqlDatabaseVersion = 33 defaultSQLQueryTimeout = 10 * time.Second longSQLQueryTimeout = 60 * time.Second ) @@ -2523,9 +2523,9 @@ func sqlCommonClearUserGroupMapping(ctx context.Context, user *User, dbHandle sq return err } -func sqlCommonAddUserFolderMapping(ctx context.Context, user *User, folder *vfs.VirtualFolder, dbHandle sqlQuerier) error { +func sqlCommonAddUserFolderMapping(ctx context.Context, user *User, folder *vfs.VirtualFolder, sortOrder int, dbHandle sqlQuerier) error { q := getAddUserFolderMappingQuery() - _, err := dbHandle.ExecContext(ctx, q, folder.VirtualPath, folder.QuotaSize, folder.QuotaFiles, folder.Name, user.Username) + _, err := dbHandle.ExecContext(ctx, q, folder.VirtualPath, folder.QuotaSize, folder.QuotaFiles, folder.Name, user.Username, sortOrder) return err } @@ -2535,27 +2535,29 @@ func sqlCommonClearAdminGroupMapping(ctx context.Context, admin *Admin, dbHandle return err } -func sqlCommonAddGroupFolderMapping(ctx context.Context, group *Group, folder *vfs.VirtualFolder, dbHandle sqlQuerier) error { +func sqlCommonAddGroupFolderMapping(ctx context.Context, group *Group, folder *vfs.VirtualFolder, sortOrder int, + dbHandle sqlQuerier, +) error { q := getAddGroupFolderMappingQuery() - _, err := dbHandle.ExecContext(ctx, q, folder.VirtualPath, folder.QuotaSize, folder.QuotaFiles, folder.Name, group.Name) + _, err := dbHandle.ExecContext(ctx, q, folder.VirtualPath, folder.QuotaSize, folder.QuotaFiles, folder.Name, group.Name, sortOrder) return err } -func sqlCommonAddUserGroupMapping(ctx context.Context, username, groupName string, groupType int, dbHandle sqlQuerier) error { +func sqlCommonAddUserGroupMapping(ctx context.Context, username, groupName string, groupType, sortOrder int, dbHandle sqlQuerier) error { q := getAddUserGroupMappingQuery() - _, err := dbHandle.ExecContext(ctx, q, username, groupName, groupType) + _, err := dbHandle.ExecContext(ctx, q, username, groupName, groupType, sortOrder) return err } func sqlCommonAddAdminGroupMapping(ctx context.Context, username, groupName string, mappingOptions AdminGroupMappingOptions, - dbHandle sqlQuerier, + sortOrder int, dbHandle sqlQuerier, ) error { options, err := json.Marshal(mappingOptions) if err != nil { return err } q := getAddAdminGroupMappingQuery() - _, err = dbHandle.ExecContext(ctx, q, username, groupName, options) + _, err = dbHandle.ExecContext(ctx, q, username, groupName, options, sortOrder) return err } @@ -2566,7 +2568,7 @@ func generateGroupVirtualFoldersMapping(ctx context.Context, group *Group, dbHan } for idx := range group.VirtualFolders { vfolder := &group.VirtualFolders[idx] - err = sqlCommonAddGroupFolderMapping(ctx, group, vfolder, dbHandle) + err = sqlCommonAddGroupFolderMapping(ctx, group, vfolder, idx, dbHandle) if err != nil { return err } @@ -2581,7 +2583,7 @@ func generateUserVirtualFoldersMapping(ctx context.Context, user *User, dbHandle } for idx := range user.VirtualFolders { vfolder := &user.VirtualFolders[idx] - err = sqlCommonAddUserFolderMapping(ctx, user, vfolder, dbHandle) + err = sqlCommonAddUserFolderMapping(ctx, user, vfolder, idx, dbHandle) if err != nil { return err } @@ -2594,8 +2596,8 @@ func generateUserGroupMapping(ctx context.Context, user *User, dbHandle sqlQueri if err != nil { return err } - for _, group := range user.Groups { - err = sqlCommonAddUserGroupMapping(ctx, user.Username, group.Name, group.Type, dbHandle) + for idx, group := range user.Groups { + err = sqlCommonAddUserGroupMapping(ctx, user.Username, group.Name, group.Type, idx, dbHandle) if err != nil { return err } @@ -2608,8 +2610,8 @@ func generateAdminGroupMapping(ctx context.Context, admin *Admin, dbHandle sqlQu if err != nil { return err } - for _, group := range admin.Groups { - err = sqlCommonAddAdminGroupMapping(ctx, admin.Username, group.Name, group.Options, dbHandle) + for idx, group := range admin.Groups { + err = sqlCommonAddAdminGroupMapping(ctx, admin.Username, group.Name, group.Options, idx, dbHandle) if err != nil { return err } diff --git a/internal/dataprovider/sqlite.go b/internal/dataprovider/sqlite.go index 54347169..04311c19 100644 --- a/internal/dataprovider/sqlite.go +++ b/internal/dataprovider/sqlite.go @@ -193,6 +193,24 @@ CREATE TABLE "{{shared_sessions}}" ("key" varchar(128) NOT NULL PRIMARY KEY, "da "type" integer NOT NULL, "timestamp" bigint NOT NULL); CREATE INDEX "{{prefix}}shared_sessions_type_idx" ON "{{shared_sessions}}" ("type"); CREATE INDEX "{{prefix}}shared_sessions_timestamp_idx" ON "{{shared_sessions}}" ("timestamp"); +` + sqliteV33SQL = `ALTER TABLE "{{admins_groups_mapping}}" ADD COLUMN "sort_order" integer DEFAULT 0 NOT NULL; +ALTER TABLE "{{groups_folders_mapping}}" ADD COLUMN "sort_order" integer DEFAULT 0 NOT NULL; +ALTER TABLE "{{users_folders_mapping}}" ADD COLUMN "sort_order" integer DEFAULT 0 NOT NULL; +ALTER TABLE "{{users_groups_mapping}}" ADD COLUMN "sort_order" integer DEFAULT 0 NOT NULL; +CREATE INDEX "{{prefix}}admins_groups_mapping_sort_order_idx" ON "{{admins_groups_mapping}}" ("sort_order"); +CREATE INDEX "{{prefix}}groups_folders_mapping_sort_order_idx" ON "{{groups_folders_mapping}}" ("sort_order"); +CREATE INDEX "{{prefix}}users_folders_mapping_sort_order_idx" ON "{{users_folders_mapping}}" ("sort_order"); +CREATE INDEX "{{prefix}}users_groups_mapping_sort_order_idx" ON "{{users_groups_mapping}}" ("sort_order"); +` + sqliteV33DownSQL = `DROP INDEX "{{prefix}}users_groups_mapping_sort_order_idx"; +DROP INDEX "{{prefix}}users_folders_mapping_sort_order_idx"; +DROP INDEX "{{prefix}}groups_folders_mapping_sort_order_idx"; +DROP INDEX "{{prefix}}admins_groups_mapping_sort_order_idx"; +ALTER TABLE "{{users_groups_mapping}}" DROP COLUMN "sort_order"; +ALTER TABLE "{{users_folders_mapping}}" DROP COLUMN "sort_order"; +ALTER TABLE "{{groups_folders_mapping}}" DROP COLUMN "sort_order"; +ALTER TABLE "{{admins_groups_mapping}}" DROP COLUMN "sort_order"; ` ) @@ -742,6 +760,8 @@ func (p *SQLiteProvider) migrateDatabase() error { //nolint:dupl return updateSQLiteDatabaseFromV30(p.dbHandle) case version == 31: return updateSQLiteDatabaseFromV31(p.dbHandle) + case version == 32: + return updateSQLiteDatabaseFromV32(p.dbHandle) default: if version > sqlDatabaseVersion { providerLog(logger.LevelError, "database schema version %d is newer than the supported one: %d", version, @@ -770,6 +790,8 @@ func (p *SQLiteProvider) revertDatabase(targetVersion int) error { return downgradeSQLiteDatabaseFromV31(p.dbHandle) case 32: return downgradeSQLiteDatabaseFromV32(p.dbHandle) + case 33: + return downgradeSQLiteDatabaseFromV33(p.dbHandle) default: return fmt.Errorf("database schema version not handled: %d", dbVersion.Version) } @@ -830,7 +852,14 @@ func updateSQLiteDatabaseFromV30(dbHandle *sql.DB) error { } func updateSQLiteDatabaseFromV31(dbHandle *sql.DB) error { - return updateSQLDatabaseFrom31To32(dbHandle) + if err := updateSQLDatabaseFrom31To32(dbHandle); err != nil { + return err + } + return updateSQLiteDatabaseFromV32(dbHandle) +} + +func updateSQLiteDatabaseFromV32(dbHandle *sql.DB) error { + return updateSQLiteDatabaseFrom32To33(dbHandle) } func downgradeSQLiteDatabaseFromV30(dbHandle *sql.DB) error { @@ -851,6 +880,13 @@ func downgradeSQLiteDatabaseFromV32(dbHandle *sql.DB) error { return downgradeSQLiteDatabaseFromV31(dbHandle) } +func downgradeSQLiteDatabaseFromV33(dbHandle *sql.DB) error { + if err := downgradeSQLiteDatabaseFrom33To32(dbHandle); err != nil { + return err + } + return downgradeSQLiteDatabaseFromV32(dbHandle) +} + func updateSQLiteDatabaseFrom29To30(dbHandle *sql.DB) error { logger.InfoToConsole("updating database schema version: 29 -> 30") providerLog(logger.LevelInfo, "updating database schema version: 29 -> 30") @@ -885,6 +921,31 @@ func downgradeSQLiteDatabaseFrom31To30(dbHandle *sql.DB) error { return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 30, false) } +func updateSQLiteDatabaseFrom32To33(dbHandle *sql.DB) error { + logger.InfoToConsole("updating database schema version: 32 -> 33") + providerLog(logger.LevelInfo, "updating database schema version: 32 -> 33") + + sql := strings.ReplaceAll(sqliteV33SQL, "{{prefix}}", config.SQLTablesPrefix) + sql = strings.ReplaceAll(sql, "{{users_folders_mapping}}", sqlTableUsersFoldersMapping) + sql = strings.ReplaceAll(sql, "{{users_groups_mapping}}", sqlTableUsersGroupsMapping) + sql = strings.ReplaceAll(sql, "{{admins_groups_mapping}}", sqlTableAdminsGroupsMapping) + sql = strings.ReplaceAll(sql, "{{groups_folders_mapping}}", sqlTableGroupsFoldersMapping) + sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 33, true) +} + +func downgradeSQLiteDatabaseFrom33To32(dbHandle *sql.DB) error { + logger.InfoToConsole("downgrading database schema version: 33 -> 32") + providerLog(logger.LevelInfo, "downgrading database schema version: 33 -> 32") + + sql := strings.ReplaceAll(sqliteV33DownSQL, "{{prefix}}", config.SQLTablesPrefix) + sql = strings.ReplaceAll(sql, "{{users_folders_mapping}}", sqlTableUsersFoldersMapping) + sql = strings.ReplaceAll(sql, "{{users_groups_mapping}}", sqlTableUsersGroupsMapping) + sql = strings.ReplaceAll(sql, "{{admins_groups_mapping}}", sqlTableAdminsGroupsMapping) + sql = strings.ReplaceAll(sql, "{{groups_folders_mapping}}", sqlTableGroupsFoldersMapping) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 32, false) +} + /*func setPragmaFK(dbHandle *sql.DB, value string) error { ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout) defer cancel() diff --git a/internal/dataprovider/sqlqueries.go b/internal/dataprovider/sqlqueries.go index 957cd82a..2d5aaf6b 100644 --- a/internal/dataprovider/sqlqueries.go +++ b/internal/dataprovider/sqlqueries.go @@ -767,10 +767,10 @@ func getClearUserGroupMappingQuery() string { } func getAddUserGroupMappingQuery() string { - return fmt.Sprintf(`INSERT INTO %s (user_id,group_id,group_type) VALUES ((SELECT id FROM %s WHERE username = %s), - (SELECT id FROM %s WHERE name = %s),%s)`, + return fmt.Sprintf(`INSERT INTO %s (user_id,group_id,group_type,sort_order) VALUES ((SELECT id FROM %s WHERE username = %s), + (SELECT id FROM %s WHERE name = %s),%s,%s)`, sqlTableUsersGroupsMapping, sqlTableUsers, sqlPlaceholders[0], getSQLQuotedName(sqlTableGroups), - sqlPlaceholders[1], sqlPlaceholders[2]) + sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) } func getClearAdminGroupMappingQuery() string { @@ -779,10 +779,10 @@ func getClearAdminGroupMappingQuery() string { } func getAddAdminGroupMappingQuery() string { - return fmt.Sprintf(`INSERT INTO %s (admin_id,group_id,options) VALUES ((SELECT id FROM %s WHERE username = %s), - (SELECT id FROM %s WHERE name = %s),%s)`, + return fmt.Sprintf(`INSERT INTO %s (admin_id,group_id,options,sort_order) VALUES ((SELECT id FROM %s WHERE username = %s), + (SELECT id FROM %s WHERE name = %s),%s,%s)`, sqlTableAdminsGroupsMapping, sqlTableAdmins, sqlPlaceholders[0], getSQLQuotedName(sqlTableGroups), - sqlPlaceholders[1], sqlPlaceholders[2]) + sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) } func getClearGroupFolderMappingQuery() string { @@ -791,10 +791,10 @@ func getClearGroupFolderMappingQuery() string { } func getAddGroupFolderMappingQuery() string { - return fmt.Sprintf(`INSERT INTO %s (virtual_path,quota_size,quota_files,folder_id,group_id) - VALUES (%s,%s,%s,(SELECT id FROM %s WHERE name = %s),(SELECT id FROM %s WHERE name = %s))`, + return fmt.Sprintf(`INSERT INTO %s (virtual_path,quota_size,quota_files,folder_id,group_id,sort_order) + VALUES (%s,%s,%s,(SELECT id FROM %s WHERE name = %s),(SELECT id FROM %s WHERE name = %s),%s)`, sqlTableGroupsFoldersMapping, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlTableFolders, - sqlPlaceholders[3], getSQLQuotedName(sqlTableGroups), sqlPlaceholders[4]) + sqlPlaceholders[3], getSQLQuotedName(sqlTableGroups), sqlPlaceholders[4], sqlPlaceholders[5]) } func getClearUserFolderMappingQuery() string { @@ -803,10 +803,10 @@ func getClearUserFolderMappingQuery() string { } func getAddUserFolderMappingQuery() string { - return fmt.Sprintf(`INSERT INTO %s (virtual_path,quota_size,quota_files,folder_id,user_id) - VALUES (%s,%s,%s,(SELECT id FROM %s WHERE name = %s),(SELECT id FROM %s WHERE username = %s))`, + return fmt.Sprintf(`INSERT INTO %s (virtual_path,quota_size,quota_files,folder_id,user_id,sort_order) + VALUES (%s,%s,%s,(SELECT id FROM %s WHERE name = %s),(SELECT id FROM %s WHERE username = %s),%s)`, sqlTableUsersFoldersMapping, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlTableFolders, - sqlPlaceholders[3], sqlTableUsers, sqlPlaceholders[4]) + sqlPlaceholders[3], sqlTableUsers, sqlPlaceholders[4], sqlPlaceholders[5]) } func getFoldersQuery(order string, minimal bool) string { @@ -848,7 +848,7 @@ func getRelatedGroupsForUsersQuery(users []User) string { sb.WriteString(")") } return fmt.Sprintf(`SELECT g.name,ug.group_type,ug.user_id FROM %s g INNER JOIN %s ug ON g.id = ug.group_id WHERE - ug.user_id IN %s ORDER BY g.name`, getSQLQuotedName(sqlTableGroups), sqlTableUsersGroupsMapping, sb.String()) + ug.user_id IN %s ORDER BY ug.sort_order`, getSQLQuotedName(sqlTableGroups), sqlTableUsersGroupsMapping, sb.String()) } func getRelatedGroupsForAdminsQuery(admins []Admin) string { @@ -865,7 +865,7 @@ func getRelatedGroupsForAdminsQuery(admins []Admin) string { sb.WriteString(")") } return fmt.Sprintf(`SELECT g.name,ag.options,ag.admin_id FROM %s g INNER JOIN %s ag ON g.id = ag.group_id WHERE - ag.admin_id IN %s ORDER BY g.name`, getSQLQuotedName(sqlTableGroups), sqlTableAdminsGroupsMapping, sb.String()) + ag.admin_id IN %s ORDER BY ag.sort_order`, getSQLQuotedName(sqlTableGroups), sqlTableAdminsGroupsMapping, sb.String()) } func getRelatedFoldersForUsersQuery(users []User) string { @@ -883,7 +883,7 @@ func getRelatedFoldersForUsersQuery(users []User) string { } return fmt.Sprintf(`SELECT f.id,f.name,f.path,f.used_quota_size,f.used_quota_files,f.last_quota_update,fm.virtual_path, fm.quota_size,fm.quota_files,fm.user_id,f.filesystem,f.description FROM %s f INNER JOIN %s fm ON f.id = fm.folder_id WHERE - fm.user_id IN %s ORDER BY f.name`, sqlTableFolders, sqlTableUsersFoldersMapping, sb.String()) + fm.user_id IN %s ORDER BY fm.sort_order`, sqlTableFolders, sqlTableUsersFoldersMapping, sb.String()) } func getRelatedUsersForFoldersQuery(folders []vfs.BaseVirtualFolder) string { @@ -970,7 +970,7 @@ func getRelatedFoldersForGroupsQuery(groups []Group) string { } return fmt.Sprintf(`SELECT f.id,f.name,f.path,f.used_quota_size,f.used_quota_files,f.last_quota_update,fm.virtual_path, fm.quota_size,fm.quota_files,fm.group_id,f.filesystem,f.description FROM %s f INNER JOIN %s fm ON f.id = fm.folder_id WHERE - fm.group_id IN %s ORDER BY f.name`, sqlTableFolders, sqlTableGroupsFoldersMapping, sb.String()) + fm.group_id IN %s ORDER BY fm.sort_order`, sqlTableFolders, sqlTableGroupsFoldersMapping, sb.String()) } func getActiveTransfersQuery() string { diff --git a/internal/httpd/httpd_test.go b/internal/httpd/httpd_test.go index e54c513f..1c74775a 100644 --- a/internal/httpd/httpd_test.go +++ b/internal/httpd/httpd_test.go @@ -960,6 +960,282 @@ func TestTLSCert(t *testing.T) { assert.NoError(t, err) } +func TestSortRelatedFolders(t *testing.T) { + folder1 := util.GenerateUniqueID() + folder2 := util.GenerateUniqueID() + folder3 := util.GenerateUniqueID() + + f1 := vfs.BaseVirtualFolder{ + Name: folder1, + MappedPath: filepath.Clean(os.TempDir()), + } + f2 := vfs.BaseVirtualFolder{ + Name: folder2, + MappedPath: filepath.Clean(os.TempDir()), + } + f3 := vfs.BaseVirtualFolder{ + Name: folder3, + MappedPath: filepath.Clean(os.TempDir()), + } + _, _, err := httpdtest.AddFolder(f1, http.StatusCreated) + assert.NoError(t, err) + _, _, err = httpdtest.AddFolder(f2, http.StatusCreated) + assert.NoError(t, err) + _, _, err = httpdtest.AddFolder(f3, http.StatusCreated) + assert.NoError(t, err) + + u := getTestUser() + u.VirtualFolders = []vfs.VirtualFolder{ + { + BaseVirtualFolder: f1, + VirtualPath: "/" + folder1, + }, + { + BaseVirtualFolder: f2, + VirtualPath: "/" + folder2, + }, + { + BaseVirtualFolder: f3, + VirtualPath: "/" + folder3, + }, + } + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, user.VirtualFolders, 3) { + assert.Equal(t, folder1, user.VirtualFolders[0].Name) + assert.Equal(t, folder2, user.VirtualFolders[1].Name) + assert.Equal(t, folder3, user.VirtualFolders[2].Name) + } + // Update + user.VirtualFolders = []vfs.VirtualFolder{ + { + BaseVirtualFolder: f2, + VirtualPath: "/" + folder2, + }, + { + BaseVirtualFolder: f1, + VirtualPath: "/" + folder1, + }, + { + BaseVirtualFolder: f3, + VirtualPath: "/" + folder3, + }, + } + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, user.VirtualFolders, 3) { + assert.Equal(t, folder2, user.VirtualFolders[0].Name) + assert.Equal(t, folder1, user.VirtualFolders[1].Name) + assert.Equal(t, folder3, user.VirtualFolders[2].Name) + } + + g := getTestGroup() + g.VirtualFolders = []vfs.VirtualFolder{ + { + BaseVirtualFolder: f1, + VirtualPath: "/" + folder1, + }, + { + BaseVirtualFolder: f2, + VirtualPath: "/" + folder2, + }, + { + BaseVirtualFolder: f3, + VirtualPath: "/" + folder3, + }, + } + group, _, err := httpdtest.AddGroup(g, http.StatusCreated) + assert.NoError(t, err) + group, _, err = httpdtest.GetGroupByName(group.Name, http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, group.VirtualFolders, 3) { + assert.Equal(t, folder1, group.VirtualFolders[0].Name) + assert.Equal(t, folder2, group.VirtualFolders[1].Name) + assert.Equal(t, folder3, group.VirtualFolders[2].Name) + } + group, _, err = httpdtest.GetGroupByName(group.Name, http.StatusOK) + assert.NoError(t, err) + group.VirtualFolders = []vfs.VirtualFolder{ + { + BaseVirtualFolder: f3, + VirtualPath: "/" + folder3, + }, + { + BaseVirtualFolder: f1, + VirtualPath: "/" + folder1, + }, + { + BaseVirtualFolder: f2, + VirtualPath: "/" + folder2, + }, + } + group, _, err = httpdtest.UpdateGroup(group, http.StatusOK) + assert.NoError(t, err) + group, _, err = httpdtest.GetGroupByName(group.Name, http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, group.VirtualFolders, 3) { + assert.Equal(t, folder3, group.VirtualFolders[0].Name) + assert.Equal(t, folder1, group.VirtualFolders[1].Name) + assert.Equal(t, folder2, group.VirtualFolders[2].Name) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveGroup(group, http.StatusOK) + assert.NoError(t, err) + + _, err = httpdtest.RemoveFolder(f1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(f2, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(f3, http.StatusOK) + assert.NoError(t, err) +} + +func TestSortRelatedGroups(t *testing.T) { + name1 := util.GenerateUniqueID() + name2 := util.GenerateUniqueID() + name3 := util.GenerateUniqueID() + + g1 := getTestGroup() + g1.Name = name1 + g2 := getTestGroup() + g2.Name = name2 + g3 := getTestGroup() + g3.Name = name3 + + group1, _, err := httpdtest.AddGroup(g1, http.StatusCreated) + assert.NoError(t, err) + group2, _, err := httpdtest.AddGroup(g2, http.StatusCreated) + assert.NoError(t, err) + group3, _, err := httpdtest.AddGroup(g3, http.StatusCreated) + assert.NoError(t, err) + + u := getTestUser() + u.Groups = []sdk.GroupMapping{ + { + Name: name1, + Type: sdk.GroupTypePrimary, + }, + { + Name: name2, + Type: sdk.GroupTypeSecondary, + }, + { + Name: name3, + Type: sdk.GroupTypeMembership, + }, + } + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, user.Groups, 3) { + assert.Equal(t, name1, user.Groups[0].Name) + assert.Equal(t, name2, user.Groups[1].Name) + assert.Equal(t, name3, user.Groups[2].Name) + } + user.Groups = []sdk.GroupMapping{ + { + Name: name2, + Type: sdk.GroupTypeSecondary, + }, + { + Name: name3, + Type: sdk.GroupTypeMembership, + }, + { + Name: name1, + Type: sdk.GroupTypePrimary, + }, + } + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, user.Groups, 3) { + assert.Equal(t, name2, user.Groups[0].Name) + assert.Equal(t, name3, user.Groups[1].Name) + assert.Equal(t, name1, user.Groups[2].Name) + } + + a := getTestAdmin() + a.Username = altAdminUsername + a.Groups = []dataprovider.AdminGroupMapping{ + { + Name: name3, + Options: dataprovider.AdminGroupMappingOptions{ + AddToUsersAs: dataprovider.GroupAddToUsersAsSecondary, + }, + }, + { + Name: name2, + Options: dataprovider.AdminGroupMappingOptions{ + AddToUsersAs: dataprovider.GroupAddToUsersAsPrimary, + }, + }, + { + Name: name1, + Options: dataprovider.AdminGroupMappingOptions{ + AddToUsersAs: dataprovider.GroupAddToUsersAsMembership, + }, + }, + } + admin, _, err := httpdtest.AddAdmin(a, http.StatusCreated) + assert.NoError(t, err) + admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, admin.Groups, 3) { + assert.Equal(t, name3, admin.Groups[0].Name) + assert.Equal(t, name2, admin.Groups[1].Name) + assert.Equal(t, name1, admin.Groups[2].Name) + } + admin.Groups = []dataprovider.AdminGroupMapping{ + { + Name: name1, + Options: dataprovider.AdminGroupMappingOptions{ + AddToUsersAs: dataprovider.GroupAddToUsersAsPrimary, + }, + }, + { + Name: name3, + Options: dataprovider.AdminGroupMappingOptions{ + AddToUsersAs: dataprovider.GroupAddToUsersAsMembership, + }, + }, + { + Name: name2, + Options: dataprovider.AdminGroupMappingOptions{ + AddToUsersAs: dataprovider.GroupAddToUsersAsSecondary, + }, + }, + } + admin, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK) + assert.NoError(t, err) + admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, admin.Groups, 3) { + assert.Equal(t, name1, admin.Groups[0].Name) + assert.Equal(t, name3, admin.Groups[1].Name) + assert.Equal(t, name2, admin.Groups[2].Name) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveAdmin(admin, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveGroup(group1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveGroup(group2, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveGroup(group3, http.StatusOK) + assert.NoError(t, err) +} + func TestBasicGroupHandling(t *testing.T) { g := getTestGroup() g.UserSettings.Filters.TLSCerts = []string{"invalid cert"} // ignored for groups