From e590deebe0fac70ffef04c97e369ab98c3f438b7 Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Sun, 23 Mar 2025 11:34:10 +0100 Subject: [PATCH] db shared sessions: set key and type as primary key Signed-off-by: Nicola Murino --- internal/dataprovider/bolt.go | 16 ++++---- internal/dataprovider/dataprovider.go | 12 +++--- internal/dataprovider/memory.go | 4 +- internal/dataprovider/mysql.go | 56 +++++++++++++++++++++++--- internal/dataprovider/pgsql.go | 57 +++++++++++++++++++++++--- internal/dataprovider/sqlcommon.go | 10 ++--- internal/dataprovider/sqlite.go | 58 ++++++++++++++++++++++++--- internal/dataprovider/sqlqueries.go | 16 ++++---- internal/httpd/internal_test.go | 51 ++++++++++++++++++++++- internal/httpd/oauth2.go | 4 +- internal/httpd/oauth2_test.go | 8 ++-- internal/httpd/oidcmanager.go | 8 ++-- internal/httpd/resetcode.go | 4 +- internal/httpd/token.go | 2 +- internal/httpd/webtask.go | 2 +- 15 files changed, 250 insertions(+), 58 deletions(-) diff --git a/internal/dataprovider/bolt.go b/internal/dataprovider/bolt.go index d3000031..104e9dbb 100644 --- a/internal/dataprovider/bolt.go +++ b/internal/dataprovider/bolt.go @@ -40,7 +40,7 @@ import ( ) const ( - boltDatabaseVersion = 30 + boltDatabaseVersion = 31 ) var ( @@ -2135,11 +2135,11 @@ func (p *BoltProvider) addSharedSession(_ Session) error { return ErrNotImplemented } -func (p *BoltProvider) deleteSharedSession(_ string) error { +func (p *BoltProvider) deleteSharedSession(_ string, _ SessionType) error { return ErrNotImplemented } -func (p *BoltProvider) getSharedSession(_ string) (Session, error) { +func (p *BoltProvider) getSharedSession(_ string, _ SessionType) (Session, error) { return Session{}, ErrNotImplemented } @@ -3186,10 +3186,10 @@ func (p *BoltProvider) migrateDatabase() error { providerLog(logger.LevelError, "%v", err) logger.ErrorToConsole("%v", err) return err - case version == 29: - logger.InfoToConsole("updating database schema version: %d -> 30", version) - providerLog(logger.LevelInfo, "updating database schema version: %d -> 30", version) - return updateBoltDatabaseVersion(p.dbHandle, 30) + case version == 29, version == 30: + logger.InfoToConsole("updating database schema version: %d -> 31", version) + providerLog(logger.LevelInfo, "updating database schema version: %d -> 31", version) + return updateBoltDatabaseVersion(p.dbHandle, 31) default: if version > boltDatabaseVersion { providerLog(logger.LevelError, "database schema version %d is newer than the supported one: %d", version, @@ -3211,7 +3211,7 @@ func (p *BoltProvider) revertDatabase(targetVersion int) error { //nolint:gocycl return errors.New("current version match target version, nothing to do") } switch dbVersion.Version { - case 30: + case 30, 31: logger.InfoToConsole("downgrading database schema version: %d -> 29", dbVersion.Version) providerLog(logger.LevelInfo, "downgrading database schema version: %d -> 29", dbVersion.Version) return updateBoltDatabaseVersion(p.dbHandle, 29) diff --git a/internal/dataprovider/dataprovider.go b/internal/dataprovider/dataprovider.go index 4cd72bde..2b7aeada 100644 --- a/internal/dataprovider/dataprovider.go +++ b/internal/dataprovider/dataprovider.go @@ -824,8 +824,8 @@ type Provider interface { cleanupActiveTransfers(before time.Time) error getActiveTransfers(from time.Time) ([]ActiveTransfer, error) addSharedSession(session Session) error - deleteSharedSession(key string) error - getSharedSession(key string) (Session, error) + deleteSharedSession(key string, sessionType SessionType) error + getSharedSession(key string, sessionType SessionType) (Session, error) cleanupSharedSessions(sessionType SessionType, before int64) error getEventActions(limit, offset int, order string, minimal bool) ([]BaseEventAction, error) dumpEventActions() ([]BaseEventAction, error) @@ -2244,8 +2244,8 @@ func AddSharedSession(session Session) error { } // DeleteSharedSession deletes the session with the specified key -func DeleteSharedSession(key string) error { - err := provider.deleteSharedSession(key) +func DeleteSharedSession(key string, sessionType SessionType) error { + err := provider.deleteSharedSession(key, sessionType) if err != nil { providerLog(logger.LevelError, "unable to add shared session, key %q, err: %v", key, err) } @@ -2253,8 +2253,8 @@ func DeleteSharedSession(key string) error { } // GetSharedSession retrieves the session with the specified key -func GetSharedSession(key string) (Session, error) { - return provider.getSharedSession(key) +func GetSharedSession(key string, sessionType SessionType) (Session, error) { + return provider.getSharedSession(key, sessionType) } // CleanupSharedSessions removes the shared session with the specified type and diff --git a/internal/dataprovider/memory.go b/internal/dataprovider/memory.go index 9cf0b2f6..8c7c0c48 100644 --- a/internal/dataprovider/memory.go +++ b/internal/dataprovider/memory.go @@ -2117,11 +2117,11 @@ func (p *MemoryProvider) addSharedSession(_ Session) error { return ErrNotImplemented } -func (p *MemoryProvider) deleteSharedSession(_ string) error { +func (p *MemoryProvider) deleteSharedSession(_ string, _ SessionType) error { return ErrNotImplemented } -func (p *MemoryProvider) getSharedSession(_ string) (Session, error) { +func (p *MemoryProvider) getSharedSession(_ string, _ SessionType) (Session, error) { return Session{}, ErrNotImplemented } diff --git a/internal/dataprovider/mysql.go b/internal/dataprovider/mysql.go index 78a1f933..4e573379 100644 --- a/internal/dataprovider/mysql.go +++ b/internal/dataprovider/mysql.go @@ -196,6 +196,16 @@ const ( "INSERT INTO {{schema_version}} (version) VALUES (29);" mysqlV30SQL = "ALTER TABLE `{{shares}}` ADD COLUMN `options` longtext NULL;" mysqlV30DownSQL = "ALTER TABLE `{{shares}}` DROP COLUMN `options`;" + mysqlV31SQL = "DROP TABLE IF EXISTS `{{shared_sessions}}` CASCADE;" + + "CREATE TABLE `shared_sessions` (`key` varchar(128) NOT NULL, `type` integer NOT NULL, `data` longtext NOT NULL, " + + "`timestamp` bigint NOT NULL, PRIMARY KEY (`key`, `type`));" + + "CREATE INDEX `{{prefix}}shared_sessions_type_idx` ON `{{shared_sessions}}` (`type`);" + + "CREATE INDEX `{{prefix}}shared_sessions_timestamp_idx` ON `{{shared_sessions}}` (`timestamp`);" + mysqlV31DownSQL = "DROP TABLE IF EXISTS `{{shared_sessions}}` CASCADE;" + + "CREATE TABLE `{{shared_sessions}}` (`key` varchar(128) NOT NULL PRIMARY KEY, " + + "`data` longtext NOT NULL, `type` integer NOT NULL, `timestamp` bigint NOT NULL);" + + "CREATE INDEX `{{prefix}}shared_sessions_type_idx` ON `{{shared_sessions}}` (`type`);" + + "CREATE INDEX `{{prefix}}shared_sessions_timestamp_idx` ON `{{shared_sessions}}` (`timestamp`);" ) // MySQLProvider defines the auth provider for MySQL/MariaDB database @@ -589,12 +599,12 @@ func (p *MySQLProvider) addSharedSession(session Session) error { return sqlCommonAddSession(session, p.dbHandle) } -func (p *MySQLProvider) deleteSharedSession(key string) error { - return sqlCommonDeleteSession(key, p.dbHandle) +func (p *MySQLProvider) deleteSharedSession(key string, sessionType SessionType) error { + return sqlCommonDeleteSession(key, sessionType, p.dbHandle) } -func (p *MySQLProvider) getSharedSession(key string) (Session, error) { - return sqlCommonGetSession(key, p.dbHandle) +func (p *MySQLProvider) getSharedSession(key string, sessionType SessionType) (Session, error) { + return sqlCommonGetSession(key, sessionType, p.dbHandle) } func (p *MySQLProvider) cleanupSharedSessions(sessionType SessionType, before int64) error { @@ -806,6 +816,8 @@ func (p *MySQLProvider) migrateDatabase() error { return err case version == 29: return updateMySQLDatabaseFromV29(p.dbHandle) + case version == 30: + return updateMySQLDatabaseFromV30(p.dbHandle) default: if version > sqlDatabaseVersion { providerLog(logger.LevelError, "database schema version %d is newer than the supported one: %d", version, @@ -830,6 +842,8 @@ func (p *MySQLProvider) revertDatabase(targetVersion int) error { switch dbVersion.Version { case 30: return downgradeMySQLDatabaseFromV30(p.dbHandle) + case 31: + return downgradeMySQLDatabaseFromV31(p.dbHandle) default: return fmt.Errorf("database schema version not handled: %d", dbVersion.Version) } @@ -869,13 +883,27 @@ func (p *MySQLProvider) normalizeError(err error, fieldType int) error { } func updateMySQLDatabaseFromV29(dbHandle *sql.DB) error { - return updateMySQLDatabaseFrom29To30(dbHandle) + if err := updateMySQLDatabaseFrom29To30(dbHandle); err != nil { + return err + } + return updateMySQLDatabaseFromV30(dbHandle) +} + +func updateMySQLDatabaseFromV30(dbHandle *sql.DB) error { + return updateMySQLDatabaseFrom30To31(dbHandle) } func downgradeMySQLDatabaseFromV30(dbHandle *sql.DB) error { return downgradeMySQLDatabaseFrom30To29(dbHandle) } +func downgradeMySQLDatabaseFromV31(dbHandle *sql.DB) error { + if err := downgradeMySQLDatabaseFrom31To30(dbHandle); err != nil { + return err + } + return downgradeMySQLDatabaseFromV30(dbHandle) +} + func updateMySQLDatabaseFrom29To30(dbHandle *sql.DB) error { logger.InfoToConsole("updating database schema version: 29 -> 30") providerLog(logger.LevelInfo, "updating database schema version: 29 -> 30") @@ -891,3 +919,21 @@ func downgradeMySQLDatabaseFrom30To29(dbHandle *sql.DB) error { sql := strings.ReplaceAll(mysqlV30DownSQL, "{{shares}}", sqlTableShares) return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 29, false) } + +func updateMySQLDatabaseFrom30To31(dbHandle *sql.DB) error { + logger.InfoToConsole("updating database schema version: 30 -> 31") + providerLog(logger.LevelInfo, "updating database schema version: 30 -> 31") + + sql := strings.ReplaceAll(mysqlV31SQL, "{{shared_sessions}}", sqlTableSharedSessions) + sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 31, true) +} + +func downgradeMySQLDatabaseFrom31To30(dbHandle *sql.DB) error { + logger.InfoToConsole("downgrading database schema version: 31 -> 30") + providerLog(logger.LevelInfo, "downgrading database schema version: 31 -> 30") + + sql := strings.ReplaceAll(mysqlV31DownSQL, "{{shared_sessions}}", sqlTableSharedSessions) + sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 30, false) +} diff --git a/internal/dataprovider/pgsql.go b/internal/dataprovider/pgsql.go index 5bc0e5fe..19f4c35b 100644 --- a/internal/dataprovider/pgsql.go +++ b/internal/dataprovider/pgsql.go @@ -212,6 +212,17 @@ INSERT INTO {{schema_version}} (version) VALUES (29); ipListsLikeIndex = `CREATE INDEX "{{prefix}}ip_lists_ipornet_like_idx" ON "{{ip_lists}}" ("ipornet" varchar_pattern_ops);` pgsqlV30SQL = `ALTER TABLE "{{shares}}" ADD COLUMN "options" text NULL;` pgsqlV30DownSQL = `ALTER TABLE "{{shares}}" DROP COLUMN "options" CASCADE;` + pgsqlV31SQL = `DROP TABLE "{{shared_sessions}}"; +CREATE TABLE "{{shared_sessions}}" ("key" varchar(128) NOT NULL, "type" integer NOT NULL, +"data" text NOT NULL, "timestamp" bigint NOT NULL, PRIMARY KEY ("key", "type")); +CREATE INDEX "{{prefix}}shared_sessions_type_idx" ON "{{shared_sessions}}" ("type"); +CREATE INDEX "{{prefix}}shared_sessions_timestamp_idx" ON "{{shared_sessions}}" ("timestamp"); +` + pgsqlV31DownSQL = `DROP TABLE "{{shared_sessions}}" CASCADE; +CREATE TABLE "{{shared_sessions}}" ("key" varchar(128) NOT NULL PRIMARY KEY, +"data" text NOT NULL, "type" integer NOT NULL, "timestamp" bigint NOT NULL); +CREATE INDEX "{{prefix}}shared_sessions_type_idx" ON "{{shared_sessions}}" ("type"); +CREATE INDEX "{{prefix}}shared_sessions_timestamp_idx" ON "{{shared_sessions}}" ("timestamp");` ) var ( @@ -607,12 +618,12 @@ func (p *PGSQLProvider) addSharedSession(session Session) error { return sqlCommonAddSession(session, p.dbHandle) } -func (p *PGSQLProvider) deleteSharedSession(key string) error { - return sqlCommonDeleteSession(key, p.dbHandle) +func (p *PGSQLProvider) deleteSharedSession(key string, sessionType SessionType) error { + return sqlCommonDeleteSession(key, sessionType, p.dbHandle) } -func (p *PGSQLProvider) getSharedSession(key string) (Session, error) { - return sqlCommonGetSession(key, p.dbHandle) +func (p *PGSQLProvider) getSharedSession(key string, sessionType SessionType) (Session, error) { + return sqlCommonGetSession(key, sessionType, p.dbHandle) } func (p *PGSQLProvider) cleanupSharedSessions(sessionType SessionType, before int64) error { @@ -830,6 +841,8 @@ func (p *PGSQLProvider) migrateDatabase() error { //nolint:dupl return err case version == 29: return updatePGSQLDatabaseFromV29(p.dbHandle) + case version == 30: + return updatePGSQLDatabaseFromV30(p.dbHandle) default: if version > sqlDatabaseVersion { providerLog(logger.LevelError, "database schema version %d is newer than the supported one: %d", version, @@ -854,6 +867,8 @@ func (p *PGSQLProvider) revertDatabase(targetVersion int) error { switch dbVersion.Version { case 30: return downgradePGSQLDatabaseFromV30(p.dbHandle) + case 31: + return downgradePGSQLDatabaseFromV31(p.dbHandle) default: return fmt.Errorf("database schema version not handled: %d", dbVersion.Version) } @@ -893,13 +908,27 @@ func (p *PGSQLProvider) normalizeError(err error, fieldType int) error { } func updatePGSQLDatabaseFromV29(dbHandle *sql.DB) error { - return updatePGSQLDatabaseFrom29To30(dbHandle) + if err := updatePGSQLDatabaseFrom29To30(dbHandle); err != nil { + return err + } + return updatePGSQLDatabaseFromV30(dbHandle) +} + +func updatePGSQLDatabaseFromV30(dbHandle *sql.DB) error { + return updatePGSQLDatabaseFrom30To31(dbHandle) } func downgradePGSQLDatabaseFromV30(dbHandle *sql.DB) error { return downgradePGSQLDatabaseFrom30To29(dbHandle) } +func downgradePGSQLDatabaseFromV31(dbHandle *sql.DB) error { + if err := downgradePGSQLDatabaseFrom31To30(dbHandle); err != nil { + return err + } + return downgradePGSQLDatabaseFromV30(dbHandle) +} + func updatePGSQLDatabaseFrom29To30(dbHandle *sql.DB) error { logger.InfoToConsole("updating database schema version: 29 -> 30") providerLog(logger.LevelInfo, "updating database schema version: 29 -> 30") @@ -915,3 +944,21 @@ func downgradePGSQLDatabaseFrom30To29(dbHandle *sql.DB) error { sql := strings.ReplaceAll(pgsqlV30DownSQL, "{{shares}}", sqlTableShares) return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 29, false) } + +func updatePGSQLDatabaseFrom30To31(dbHandle *sql.DB) error { + logger.InfoToConsole("updating database schema version: 30 -> 31") + providerLog(logger.LevelInfo, "updating database schema version: 30 -> 31") + + sql := strings.ReplaceAll(pgsqlV31SQL, "{{shared_sessions}}", sqlTableSharedSessions) + sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 31, true) +} + +func downgradePGSQLDatabaseFrom31To30(dbHandle *sql.DB) error { + logger.InfoToConsole("downgrading database schema version: 31 -> 30") + providerLog(logger.LevelInfo, "downgrading database schema version: 31 -> 30") + + sql := strings.ReplaceAll(pgsqlV31DownSQL, "{{shared_sessions}}", sqlTableSharedSessions) + sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 30, false) +} diff --git a/internal/dataprovider/sqlcommon.go b/internal/dataprovider/sqlcommon.go index 7fb626d1..183166f1 100644 --- a/internal/dataprovider/sqlcommon.go +++ b/internal/dataprovider/sqlcommon.go @@ -36,7 +36,7 @@ import ( ) const ( - sqlDatabaseVersion = 30 + sqlDatabaseVersion = 31 defaultSQLQueryTimeout = 10 * time.Second longSQLQueryTimeout = 60 * time.Second ) @@ -3265,14 +3265,14 @@ func sqlCommonAddSession(session Session, dbHandle *sql.DB) error { return err } -func sqlCommonGetSession(key string, dbHandle sqlQuerier) (Session, error) { +func sqlCommonGetSession(key string, sessionType SessionType, dbHandle sqlQuerier) (Session, error) { var session Session ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getSessionQuery() var data []byte // type hint, some driver will use string instead of []byte if the type is any - err := dbHandle.QueryRowContext(ctx, q, key).Scan(&session.Key, &data, &session.Type, &session.Timestamp) + err := dbHandle.QueryRowContext(ctx, q, key, sessionType).Scan(&session.Key, &data, &session.Type, &session.Timestamp) if err != nil { if errors.Is(err, sql.ErrNoRows) { return session, util.NewRecordNotFoundError(err.Error()) @@ -3283,12 +3283,12 @@ func sqlCommonGetSession(key string, dbHandle sqlQuerier) (Session, error) { return session, nil } -func sqlCommonDeleteSession(key string, dbHandle *sql.DB) error { +func sqlCommonDeleteSession(key string, sessionType SessionType, dbHandle *sql.DB) error { ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) defer cancel() q := getDeleteSessionQuery() - res, err := dbHandle.ExecContext(ctx, q, key) + res, err := dbHandle.ExecContext(ctx, q, key, sessionType) if err != nil { return err } diff --git a/internal/dataprovider/sqlite.go b/internal/dataprovider/sqlite.go index b2b77af3..27ded2bc 100644 --- a/internal/dataprovider/sqlite.go +++ b/internal/dataprovider/sqlite.go @@ -183,6 +183,18 @@ INSERT INTO {{schema_version}} (version) VALUES (29); ` sqliteV30SQL = `ALTER TABLE "{{shares}}" ADD COLUMN "options" text NULL;` sqliteV30DownSQL = `ALTER TABLE "{{shares}}" DROP COLUMN "options";` + sqliteV31SQL = `DROP TABLE "{{shared_sessions}}"; +CREATE TABLE "{{shared_sessions}}" ("key" varchar(128) NOT NULL, "type" integer NOT NULL, +"data" text NOT NULL, "timestamp" bigint NOT NULL, PRIMARY KEY ("key", "type")); +CREATE INDEX "{{prefix}}shared_sessions_type_idx" ON "{{shared_sessions}}" ("type"); +CREATE INDEX "{{prefix}}shared_sessions_timestamp_idx" ON "{{shared_sessions}}" ("timestamp"); +` + sqliteV31DownSQL = `DROP TABLE "{{shared_sessions}}"; +CREATE TABLE "{{shared_sessions}}" ("key" varchar(128) NOT NULL PRIMARY KEY, "data" text NOT NULL, +"type" integer NOT NULL, "timestamp" bigint NOT NULL); +CREATE INDEX "{{prefix}}shared_sessions_type_idx" ON "{{shared_sessions}}" ("type"); +CREATE INDEX "{{prefix}}shared_sessions_timestamp_idx" ON "{{shared_sessions}}" ("timestamp"); +` ) // SQLiteProvider defines the auth provider for SQLite database @@ -511,12 +523,12 @@ func (p *SQLiteProvider) addSharedSession(session Session) error { return sqlCommonAddSession(session, p.dbHandle) } -func (p *SQLiteProvider) deleteSharedSession(key string) error { - return sqlCommonDeleteSession(key, p.dbHandle) +func (p *SQLiteProvider) deleteSharedSession(key string, sessionType SessionType) error { + return sqlCommonDeleteSession(key, sessionType, p.dbHandle) } -func (p *SQLiteProvider) getSharedSession(key string) (Session, error) { - return sqlCommonGetSession(key, p.dbHandle) +func (p *SQLiteProvider) getSharedSession(key string, sessionType SessionType) (Session, error) { + return sqlCommonGetSession(key, sessionType, p.dbHandle) } func (p *SQLiteProvider) cleanupSharedSessions(sessionType SessionType, before int64) error { @@ -727,6 +739,8 @@ func (p *SQLiteProvider) migrateDatabase() error { //nolint:dupl return err case version == 29: return updateSQLiteDatabaseFromV29(p.dbHandle) + case version == 30: + return updateSQLiteDatabaseFromV30(p.dbHandle) default: if version > sqlDatabaseVersion { providerLog(logger.LevelError, "database schema version %d is newer than the supported one: %d", version, @@ -751,6 +765,8 @@ func (p *SQLiteProvider) revertDatabase(targetVersion int) error { switch dbVersion.Version { case 30: return downgradeSQLiteDatabaseFromV30(p.dbHandle) + case 31: + return downgradeSQLiteDatabaseFromV31(p.dbHandle) default: return fmt.Errorf("database schema version not handled: %d", dbVersion.Version) } @@ -797,13 +813,27 @@ func executePragmaOptimize(dbHandle *sql.DB) error { } func updateSQLiteDatabaseFromV29(dbHandle *sql.DB) error { - return updateSQLiteDatabaseFrom29To30(dbHandle) + if err := updateSQLiteDatabaseFrom29To30(dbHandle); err != nil { + return err + } + return updateSQLiteDatabaseFromV30(dbHandle) +} + +func updateSQLiteDatabaseFromV30(dbHandle *sql.DB) error { + return updateSQLiteDatabaseFrom30To31(dbHandle) } func downgradeSQLiteDatabaseFromV30(dbHandle *sql.DB) error { return downgradeSQLiteDatabaseFrom30To29(dbHandle) } +func downgradeSQLiteDatabaseFromV31(dbHandle *sql.DB) error { + if err := downgradeSQLiteDatabaseFrom31To30(dbHandle); err != nil { + return err + } + return downgradeSQLiteDatabaseFromV30(dbHandle) +} + func updateSQLiteDatabaseFrom29To30(dbHandle *sql.DB) error { logger.InfoToConsole("updating database schema version: 29 -> 30") providerLog(logger.LevelInfo, "updating database schema version: 29 -> 30") @@ -820,6 +850,24 @@ func downgradeSQLiteDatabaseFrom30To29(dbHandle *sql.DB) error { return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 29, false) } +func updateSQLiteDatabaseFrom30To31(dbHandle *sql.DB) error { + logger.InfoToConsole("updating database schema version: 30 -> 31") + providerLog(logger.LevelInfo, "updating database schema version: 30 -> 31") + + sql := strings.ReplaceAll(sqliteV31SQL, "{{shared_sessions}}", sqlTableSharedSessions) + sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 31, true) +} + +func downgradeSQLiteDatabaseFrom31To30(dbHandle *sql.DB) error { + logger.InfoToConsole("downgrading database schema version: 31 -> 30") + providerLog(logger.LevelInfo, "downgrading database schema version: 31 -> 30") + + sql := strings.ReplaceAll(sqliteV31DownSQL, "{{shared_sessions}}", sqlTableSharedSessions) + sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 30, false) +} + /*func setPragmaFK(dbHandle *sql.DB, value string) error { ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout) defer cancel() diff --git a/internal/dataprovider/sqlqueries.go b/internal/dataprovider/sqlqueries.go index dd4fe35b..957cd82a 100644 --- a/internal/dataprovider/sqlqueries.go +++ b/internal/dataprovider/sqlqueries.go @@ -81,25 +81,27 @@ func getAddSessionQuery() string { "ON DUPLICATE KEY UPDATE `data`=VALUES(`data`), `timestamp`=VALUES(`timestamp`)", sqlTableSharedSessions, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) } - return fmt.Sprintf(`INSERT INTO %s (key,data,type,timestamp) VALUES (%s,%s,%s,%s) ON CONFLICT(key) DO UPDATE SET data= + return fmt.Sprintf(`INSERT INTO %s (key,data,type,timestamp) VALUES (%s,%s,%s,%s) ON CONFLICT(key,type) DO UPDATE SET data= EXCLUDED.data, timestamp=EXCLUDED.timestamp`, sqlTableSharedSessions, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) } func getDeleteSessionQuery() string { if config.Driver == MySQLDataProviderName { - return fmt.Sprintf("DELETE FROM %s WHERE `key` = %s", sqlTableSharedSessions, sqlPlaceholders[0]) + return fmt.Sprintf("DELETE FROM %s WHERE `key` = %s AND `type` = %s", + sqlTableSharedSessions, sqlPlaceholders[0], sqlPlaceholders[1]) } - return fmt.Sprintf(`DELETE FROM %s WHERE key = %s`, sqlTableSharedSessions, sqlPlaceholders[0]) + return fmt.Sprintf(`DELETE FROM %s WHERE key = %s AND type = %s`, + sqlTableSharedSessions, sqlPlaceholders[0], sqlPlaceholders[1]) } func getSessionQuery() string { if config.Driver == MySQLDataProviderName { - return fmt.Sprintf("SELECT `key`,`data`,`type`,`timestamp` FROM %s WHERE `key` = %s", sqlTableSharedSessions, - sqlPlaceholders[0]) + return fmt.Sprintf("SELECT `key`,`data`,`type`,`timestamp` FROM %s WHERE `key` = %s AND `type` = %s", + sqlTableSharedSessions, sqlPlaceholders[0], sqlPlaceholders[1]) } - return fmt.Sprintf(`SELECT key,data,type,timestamp FROM %s WHERE key = %s`, sqlTableSharedSessions, - sqlPlaceholders[0]) + return fmt.Sprintf(`SELECT key,data,type,timestamp FROM %s WHERE key = %s AND type = %s`, + sqlTableSharedSessions, sqlPlaceholders[0], sqlPlaceholders[1]) } func getCleanupSessionsQuery() string { diff --git a/internal/httpd/internal_test.go b/internal/httpd/internal_test.go index 831f6bb0..278c78b3 100644 --- a/internal/httpd/internal_test.go +++ b/internal/httpd/internal_test.go @@ -2395,10 +2395,59 @@ func TestDbTokenManager(t *testing.T) { dbTokenManager.Cleanup() isInvalidated = dbTokenManager.Get(testToken) assert.True(t, isInvalidated) - err := dataprovider.DeleteSharedSession(key) + err := dataprovider.DeleteSharedSession(key, dataprovider.SessionTypeInvalidToken) assert.NoError(t, err) } +func TestDatabaseSharedSessions(t *testing.T) { + if !isSharedProviderSupported() { + t.Skip("this test it is not available with this provider") + } + session1 := dataprovider.Session{ + Key: "1", + Data: map[string]string{"a": "b"}, + Type: dataprovider.SessionTypeOIDCAuth, + Timestamp: 10, + } + err := dataprovider.AddSharedSession(session1) + assert.NoError(t, err) + // Adding another session with the same key but a different type should work + session2 := session1 + session2.Type = dataprovider.SessionTypeOIDCToken + err = dataprovider.AddSharedSession(session2) + assert.NoError(t, err) + err = dataprovider.DeleteSharedSession(session1.Key, dataprovider.SessionTypeInvalidToken) + assert.ErrorIs(t, err, util.ErrNotFound) + _, err = dataprovider.GetSharedSession(session1.Key, dataprovider.SessionTypeResetCode) + assert.ErrorIs(t, err, util.ErrNotFound) + session1Get, err := dataprovider.GetSharedSession(session1.Key, dataprovider.SessionTypeOIDCAuth) + assert.NoError(t, err) + assert.Equal(t, session1.Timestamp, session1Get.Timestamp) + var stored map[string]string + err = json.Unmarshal(session1Get.Data.([]byte), &stored) + assert.NoError(t, err) + assert.Equal(t, session1.Data, stored) + session1.Timestamp = 20 + session1.Data = map[string]string{"c": "d"} + err = dataprovider.AddSharedSession(session1) + assert.NoError(t, err) + session1Get, err = dataprovider.GetSharedSession(session1.Key, dataprovider.SessionTypeOIDCAuth) + assert.NoError(t, err) + assert.Equal(t, session1.Timestamp, session1Get.Timestamp) + stored = make(map[string]string) + err = json.Unmarshal(session1Get.Data.([]byte), &stored) + assert.NoError(t, err) + assert.Equal(t, session1.Data, stored) + err = dataprovider.DeleteSharedSession(session1.Key, dataprovider.SessionTypeOIDCAuth) + assert.NoError(t, err) + err = dataprovider.DeleteSharedSession(session2.Key, dataprovider.SessionTypeOIDCToken) + assert.NoError(t, err) + _, err = dataprovider.GetSharedSession(session1.Key, dataprovider.SessionTypeOIDCAuth) + assert.ErrorIs(t, err, util.ErrNotFound) + _, err = dataprovider.GetSharedSession(session2.Key, dataprovider.SessionTypeOIDCToken) + assert.ErrorIs(t, err, util.ErrNotFound) +} + func TestAllowedProxyUnixDomainSocket(t *testing.T) { b := Binding{ Address: filepath.Join(os.TempDir(), "sock"), diff --git a/internal/httpd/oauth2.go b/internal/httpd/oauth2.go index 88a599d3..fff4b841 100644 --- a/internal/httpd/oauth2.go +++ b/internal/httpd/oauth2.go @@ -132,11 +132,11 @@ func (o *dbOAuth2Manager) addPendingAuth(pendingAuth oauth2PendingAuth) { } func (o *dbOAuth2Manager) removePendingAuth(state string) { - dataprovider.DeleteSharedSession(state) //nolint:errcheck + dataprovider.DeleteSharedSession(state, dataprovider.SessionTypeOAuth2Auth) //nolint:errcheck } func (o *dbOAuth2Manager) getPendingAuth(state string) (oauth2PendingAuth, error) { - session, err := dataprovider.GetSharedSession(state) + session, err := dataprovider.GetSharedSession(state, dataprovider.SessionTypeOAuth2Auth) if err != nil { return oauth2PendingAuth{}, errors.New("oauth2: unable to get the auth request for the specified state") } diff --git a/internal/httpd/oauth2_test.go b/internal/httpd/oauth2_test.go index a676905d..0e488136 100644 --- a/internal/httpd/oauth2_test.go +++ b/internal/httpd/oauth2_test.go @@ -86,7 +86,7 @@ func TestDbOAuth2Manager(t *testing.T) { a, err := m.getPendingAuth(auth.State) assert.NoError(t, err) assert.Equal(t, sdkkms.SecretStatusPlain, a.ClientSecret.GetStatus()) - session, err := dataprovider.GetSharedSession(auth.State) + session, err := dataprovider.GetSharedSession(auth.State, dataprovider.SessionTypeOAuth2Auth) assert.NoError(t, err) authReq := oauth2PendingAuth{} err = json.Unmarshal(session.Data.([]byte), &authReq) @@ -107,10 +107,10 @@ func TestDbOAuth2Manager(t *testing.T) { m.addPendingAuth(auth) _, err = m.getPendingAuth(auth.State) assert.Error(t, err) - _, err = dataprovider.GetSharedSession(auth.State) + _, err = dataprovider.GetSharedSession(auth.State, dataprovider.SessionTypeOAuth2Auth) assert.NoError(t, err) m.cleanup() - _, err = dataprovider.GetSharedSession(auth.State) + _, err = dataprovider.GetSharedSession(auth.State, dataprovider.SessionTypeOAuth2Auth) assert.Error(t, err) _, err = m.decodePendingAuthData("not a byte array") require.Error(t, err) @@ -126,7 +126,7 @@ func TestDbOAuth2Manager(t *testing.T) { } auth.ClientSecret.SetStatus(sdkkms.SecretStatusSecretBox) m.addPendingAuth(auth) - _, err = dataprovider.GetSharedSession(auth.State) + _, err = dataprovider.GetSharedSession(auth.State, dataprovider.SessionTypeOAuth2Auth) assert.Error(t, err) asJSON, err := json.Marshal(auth) assert.NoError(t, err) diff --git a/internal/httpd/oidcmanager.go b/internal/httpd/oidcmanager.go index d0acade6..79336748 100644 --- a/internal/httpd/oidcmanager.go +++ b/internal/httpd/oidcmanager.go @@ -167,11 +167,11 @@ func (o *dbOIDCManager) addPendingAuth(pendingAuth oidcPendingAuth) { } func (o *dbOIDCManager) removePendingAuth(state string) { - dataprovider.DeleteSharedSession(state) //nolint:errcheck + dataprovider.DeleteSharedSession(state, dataprovider.SessionTypeOIDCAuth) //nolint:errcheck } func (o *dbOIDCManager) getPendingAuth(state string) (oidcPendingAuth, error) { - session, err := dataprovider.GetSharedSession(state) + session, err := dataprovider.GetSharedSession(state, dataprovider.SessionTypeOIDCAuth) if err != nil { return oidcPendingAuth{}, errors.New("oidc: unable to get the auth request for the specified state") } @@ -204,7 +204,7 @@ func (o *dbOIDCManager) addToken(token oidcToken) { } func (o *dbOIDCManager) removeToken(cookie string) { - dataprovider.DeleteSharedSession(cookie) //nolint:errcheck + dataprovider.DeleteSharedSession(cookie, dataprovider.SessionTypeOIDCToken) //nolint:errcheck } func (o *dbOIDCManager) updateTokenUsage(token oidcToken) { @@ -215,7 +215,7 @@ func (o *dbOIDCManager) updateTokenUsage(token oidcToken) { } func (o *dbOIDCManager) getToken(cookie string) (oidcToken, error) { - session, err := dataprovider.GetSharedSession(cookie) + session, err := dataprovider.GetSharedSession(cookie, dataprovider.SessionTypeOIDCToken) if err != nil { return oidcToken{}, errors.New("oidc: unable to get the token for the specified session") } diff --git a/internal/httpd/resetcode.go b/internal/httpd/resetcode.go index 6e680c02..0be7d890 100644 --- a/internal/httpd/resetcode.go +++ b/internal/httpd/resetcode.go @@ -110,7 +110,7 @@ func (m *dbResetCodeManager) Add(code *resetCode) error { } func (m *dbResetCodeManager) Get(code string) (*resetCode, error) { - session, err := dataprovider.GetSharedSession(code) + session, err := dataprovider.GetSharedSession(code, dataprovider.SessionTypeResetCode) if err != nil { return nil, err } @@ -132,7 +132,7 @@ func (m *dbResetCodeManager) decodeData(data any) (*resetCode, error) { } func (m *dbResetCodeManager) Delete(code string) error { - return dataprovider.DeleteSharedSession(code) + return dataprovider.DeleteSharedSession(code, dataprovider.SessionTypeResetCode) } func (m *dbResetCodeManager) Cleanup() { diff --git a/internal/httpd/token.go b/internal/httpd/token.go index 93413281..4a0f9873 100644 --- a/internal/httpd/token.go +++ b/internal/httpd/token.go @@ -86,7 +86,7 @@ func (m *dbTokenManager) Add(token string, expiresAt time.Time) { func (m *dbTokenManager) Get(token string) bool { key := m.getKey(token) - _, err := dataprovider.GetSharedSession(key) + _, err := dataprovider.GetSharedSession(key, dataprovider.SessionTypeInvalidToken) return err == nil } diff --git a/internal/httpd/webtask.go b/internal/httpd/webtask.go index 8375c9c0..0c328c06 100644 --- a/internal/httpd/webtask.go +++ b/internal/httpd/webtask.go @@ -93,7 +93,7 @@ func (m *dbTaskManager) Add(data webTaskData) error { } func (m *dbTaskManager) Get(ID string) (webTaskData, error) { - sess, err := dataprovider.GetSharedSession(ID) + sess, err := dataprovider.GetSharedSession(ID, dataprovider.SessionTypeWebTask) if err != nil { return webTaskData{}, err }