mirror of
https://github.com/drakkan/sftpgo.git
synced 2025-12-07 23:00:55 +03:00
allow to store temporary sessions within the data provider
so we can persist password reset codes, OIDC auth sessions and tokens. These features will also work in multi-node setups without sicky sessions now Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
@@ -20,7 +20,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
sqlDatabaseVersion = 18
|
||||
sqlDatabaseVersion = 19
|
||||
defaultSQLQueryTimeout = 10 * time.Second
|
||||
longSQLQueryTimeout = 60 * time.Second
|
||||
)
|
||||
@@ -37,7 +37,7 @@ type sqlQuerier interface {
|
||||
}
|
||||
|
||||
type sqlScanner interface {
|
||||
Scan(dest ...interface{}) error
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
func sqlReplaceAll(sql string) string {
|
||||
@@ -203,8 +203,11 @@ func sqlCommonDeleteShare(share Share, dbHandle *sql.DB) error {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.ExecContext(ctx, share.ShareID)
|
||||
return err
|
||||
res, err := stmt.ExecContext(ctx, share.ShareID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlCommonRequireRowAffected(res)
|
||||
}
|
||||
|
||||
func sqlCommonGetShares(limit, offset int, order, username string, dbHandle sqlQuerier) ([]Share, error) {
|
||||
@@ -352,8 +355,11 @@ func sqlCommonDeleteAPIKey(apiKey APIKey, dbHandle *sql.DB) error {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.ExecContext(ctx, apiKey.KeyID)
|
||||
return err
|
||||
res, err := stmt.ExecContext(ctx, apiKey.KeyID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlCommonRequireRowAffected(res)
|
||||
}
|
||||
|
||||
func sqlCommonGetAPIKeys(limit, offset int, order string, dbHandle sqlQuerier) ([]APIKey, error) {
|
||||
@@ -532,8 +538,11 @@ func sqlCommonDeleteAdmin(admin Admin, dbHandle *sql.DB) error {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.ExecContext(ctx, admin.Username)
|
||||
return err
|
||||
res, err := stmt.ExecContext(ctx, admin.Username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlCommonRequireRowAffected(res)
|
||||
}
|
||||
|
||||
func sqlCommonGetAdmins(limit, offset int, order string, dbHandle sqlQuerier) ([]Admin, error) {
|
||||
@@ -667,7 +676,7 @@ func sqlCommonGetUsersInGroups(names []string, dbHandle sqlQuerier) ([]string, e
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
args := make([]interface{}, 0, len(names))
|
||||
args := make([]any, 0, len(names))
|
||||
for _, name := range names {
|
||||
args = append(args, name)
|
||||
}
|
||||
@@ -705,7 +714,7 @@ func sqlCommonGetGroupsWithNames(names []string, dbHandle sqlQuerier) ([]Group,
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
args := make([]interface{}, 0, len(names))
|
||||
args := make([]any, 0, len(names))
|
||||
for _, name := range names {
|
||||
args = append(args, name)
|
||||
}
|
||||
@@ -849,8 +858,11 @@ func sqlCommonDeleteGroup(group Group, dbHandle *sql.DB) error {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.ExecContext(ctx, group.Name)
|
||||
return err
|
||||
res, err := stmt.ExecContext(ctx, group.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlCommonRequireRowAffected(res)
|
||||
}
|
||||
|
||||
func sqlCommonGetUserByUsername(username string, dbHandle sqlQuerier) (User, error) {
|
||||
@@ -1206,8 +1218,11 @@ func sqlCommonDeleteUser(user User, dbHandle *sql.DB) error {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.ExecContext(ctx, user.ID)
|
||||
return err
|
||||
res, err := stmt.ExecContext(ctx, user.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlCommonRequireRowAffected(res)
|
||||
}
|
||||
|
||||
func sqlCommonDumpUsers(dbHandle sqlQuerier) ([]User, error) {
|
||||
@@ -1389,7 +1404,7 @@ func sqlCommonGetUsersRangeForQuotaCheck(usernames []string, dbHandle sqlQuerier
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
queryArgs := make([]interface{}, 0, len(usernames))
|
||||
queryArgs := make([]any, 0, len(usernames))
|
||||
for idx := range usernames {
|
||||
queryArgs = append(queryArgs, usernames[idx])
|
||||
}
|
||||
@@ -1730,11 +1745,12 @@ func sqlCommonDeleteDefenderHost(ip string, dbHandle sqlQuerier) error {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.ExecContext(ctx, ip)
|
||||
res, err := stmt.ExecContext(ctx, ip)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "unable to delete defender host %#v: %v", ip, err)
|
||||
return err
|
||||
}
|
||||
return err
|
||||
return sqlCommonRequireRowAffected(res)
|
||||
}
|
||||
|
||||
func sqlCommonAddDefenderHostAndEvent(ip string, score int, dbHandle *sql.DB) error {
|
||||
@@ -2160,8 +2176,11 @@ func sqlCommonDeleteFolder(folder vfs.BaseVirtualFolder, dbHandle sqlQuerier) er
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.ExecContext(ctx, folder.ID)
|
||||
return err
|
||||
res, err := stmt.ExecContext(ctx, folder.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlCommonRequireRowAffected(res)
|
||||
}
|
||||
|
||||
func sqlCommonDumpFolders(dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) {
|
||||
@@ -2911,6 +2930,86 @@ func sqlCommonGetAPIKeyRelatedIDs(apiKey *APIKey) (sql.NullInt64, sql.NullInt64,
|
||||
return userID, adminID, nil
|
||||
}
|
||||
|
||||
func sqlCommonAddSession(session Session, dbHandle *sql.DB) error {
|
||||
if err := session.validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := json.Marshal(session.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
q := getAddSessionQuery()
|
||||
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
_, err = stmt.ExecContext(ctx, session.Key, data, session.Type, session.Timestamp)
|
||||
return err
|
||||
}
|
||||
|
||||
func sqlCommonGetSession(key string, dbHandle sqlQuerier) (Session, error) {
|
||||
var session Session
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
q := getSessionQuery()
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
|
||||
return session, err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
var data []byte // type hint, some driver will use string instead of []byte if the type is any
|
||||
err = stmt.QueryRowContext(ctx, key).Scan(&session.Key, &data, &session.Type, &session.Timestamp)
|
||||
if err != nil {
|
||||
return session, err
|
||||
}
|
||||
session.Data = data
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func sqlCommonDeleteSession(key string, dbHandle *sql.DB) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
q := getDeleteSessionQuery()
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
res, err := stmt.ExecContext(ctx, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlCommonRequireRowAffected(res)
|
||||
}
|
||||
|
||||
func sqlCommonCleanupSessions(sessionType SessionType, before int64, dbHandle *sql.DB) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
q := getCleanupSessionsQuery()
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.ExecContext(ctx, sessionType, before)
|
||||
return err
|
||||
}
|
||||
|
||||
func sqlCommonGetDatabaseVersion(dbHandle sqlQuerier, showInitWarn bool) (schemaVersion, error) {
|
||||
var result schemaVersion
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
@@ -2931,6 +3030,16 @@ func sqlCommonGetDatabaseVersion(dbHandle sqlQuerier, showInitWarn bool) (schema
|
||||
return result, err
|
||||
}
|
||||
|
||||
func sqlCommonRequireRowAffected(res sql.Result) error {
|
||||
// MariaDB/MySQL returns 0 rows affected for updates that don't change anything
|
||||
// so we don't check rows affected for updates
|
||||
affected, err := res.RowsAffected()
|
||||
if err == nil && affected == 0 {
|
||||
return util.NewRecordNotFoundError(sql.ErrNoRows.Error())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sqlCommonUpdateDatabaseVersion(ctx context.Context, dbHandle sqlQuerier, version int) error {
|
||||
q := getUpdateDBVersionQuery()
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
@@ -2943,8 +3052,8 @@ func sqlCommonUpdateDatabaseVersion(ctx context.Context, dbHandle sqlQuerier, ve
|
||||
return err
|
||||
}
|
||||
|
||||
func sqlCommonExecSQLAndUpdateDBVersion(dbHandle *sql.DB, sqlQueries []string, newVersion int) error {
|
||||
if err := sqlAquireLock(dbHandle); err != nil {
|
||||
func sqlCommonExecSQLAndUpdateDBVersion(dbHandle *sql.DB, sqlQueries []string, newVersion int, isUp bool) error {
|
||||
if err := sqlAcquireLock(dbHandle); err != nil {
|
||||
return err
|
||||
}
|
||||
defer sqlReleaseLock(dbHandle)
|
||||
@@ -2954,10 +3063,12 @@ func sqlCommonExecSQLAndUpdateDBVersion(dbHandle *sql.DB, sqlQueries []string, n
|
||||
|
||||
if newVersion > 0 {
|
||||
currentVersion, err := sqlCommonGetDatabaseVersion(dbHandle, false)
|
||||
if err == nil && currentVersion.Version >= newVersion {
|
||||
providerLog(logger.LevelInfo, "current schema version: %v, requested: %v, did you execute simultaneous migrations?",
|
||||
currentVersion.Version, newVersion)
|
||||
return nil
|
||||
if err == nil {
|
||||
if (isUp && currentVersion.Version >= newVersion) || (!isUp && currentVersion.Version <= newVersion) {
|
||||
providerLog(logger.LevelInfo, "current schema version: %v, requested: %v, did you execute simultaneous migrations?",
|
||||
currentVersion.Version, newVersion)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2978,7 +3089,7 @@ func sqlCommonExecSQLAndUpdateDBVersion(dbHandle *sql.DB, sqlQueries []string, n
|
||||
})
|
||||
}
|
||||
|
||||
func sqlAquireLock(dbHandle *sql.DB) error {
|
||||
func sqlAcquireLock(dbHandle *sql.DB) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user