mirror of
https://github.com/drakkan/sftpgo.git
synced 2025-12-07 14:50:55 +03:00
data provider: add CockroachDB support
This commit is contained in:
@@ -10,6 +10,8 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cockroachdb/cockroach-go/v2/crdb"
|
||||
|
||||
"github.com/drakkan/sftpgo/logger"
|
||||
"github.com/drakkan/sftpgo/utils"
|
||||
"github.com/drakkan/sftpgo/vfs"
|
||||
@@ -327,44 +329,38 @@ func sqlCommonAddUser(user *User, dbHandle *sql.DB) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
tx, err := dbHandle.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
q := getAddUserQuery()
|
||||
stmt, err := tx.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
permissions, err := user.GetPermissionsAsJSON()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
publicKeys, err := user.GetPublicKeysAsJSON()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
filters, err := user.GetFiltersAsJSON()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fsConfig, err := user.GetFsConfigAsJSON()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = stmt.ExecContext(ctx, user.Username, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize,
|
||||
user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status, user.ExpirationDate, string(filters),
|
||||
string(fsConfig), user.AdditionalInfo, user.Description)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = generateVirtualFoldersMapping(ctx, user, tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
|
||||
q := getAddUserQuery()
|
||||
stmt, err := tx.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
permissions, err := user.GetPermissionsAsJSON()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
publicKeys, err := user.GetPublicKeysAsJSON()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
filters, err := user.GetFiltersAsJSON()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fsConfig, err := user.GetFsConfigAsJSON()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = stmt.ExecContext(ctx, user.Username, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize,
|
||||
user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status, user.ExpirationDate, string(filters),
|
||||
string(fsConfig), user.AdditionalInfo, user.Description)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return generateVirtualFoldersMapping(ctx, user, tx)
|
||||
})
|
||||
}
|
||||
|
||||
func sqlCommonUpdateUser(user *User, dbHandle *sql.DB) error {
|
||||
@@ -375,44 +371,38 @@ func sqlCommonUpdateUser(user *User, dbHandle *sql.DB) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
tx, err := dbHandle.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
q := getUpdateUserQuery()
|
||||
stmt, err := tx.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
permissions, err := user.GetPermissionsAsJSON()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
publicKeys, err := user.GetPublicKeysAsJSON()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
filters, err := user.GetFiltersAsJSON()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fsConfig, err := user.GetFsConfigAsJSON()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = stmt.ExecContext(ctx, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize,
|
||||
user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status, user.ExpirationDate,
|
||||
string(filters), string(fsConfig), user.AdditionalInfo, user.Description, user.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = generateVirtualFoldersMapping(ctx, user, tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
|
||||
q := getUpdateUserQuery()
|
||||
stmt, err := tx.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
permissions, err := user.GetPermissionsAsJSON()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
publicKeys, err := user.GetPublicKeysAsJSON()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
filters, err := user.GetFiltersAsJSON()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fsConfig, err := user.GetFsConfigAsJSON()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = stmt.ExecContext(ctx, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize,
|
||||
user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status, user.ExpirationDate,
|
||||
string(filters), string(fsConfig), user.AdditionalInfo, user.Description, user.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return generateVirtualFoldersMapping(ctx, user, tx)
|
||||
})
|
||||
}
|
||||
|
||||
func sqlCommonDeleteUser(user *User, dbHandle *sql.DB) error {
|
||||
@@ -1072,24 +1062,38 @@ func sqlCommonUpdateDatabaseVersion(ctx context.Context, dbHandle sqlQuerier, ve
|
||||
return err
|
||||
}
|
||||
|
||||
func sqlCommonExecSQLAndUpdateDBVersion(dbHandle *sql.DB, sql []string, newVersion int) error {
|
||||
func sqlCommonExecSQLAndUpdateDBVersion(dbHandle *sql.DB, sqlQueries []string, newVersion int) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
|
||||
for _, q := range sqlQueries {
|
||||
if strings.TrimSpace(q) == "" {
|
||||
continue
|
||||
}
|
||||
_, err := tx.ExecContext(ctx, q)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return sqlCommonUpdateDatabaseVersion(ctx, tx, newVersion)
|
||||
})
|
||||
}
|
||||
|
||||
func sqlCommonExecuteTx(ctx context.Context, dbHandle *sql.DB, txFn func(*sql.Tx) error) error {
|
||||
tx, err := dbHandle.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, q := range sql {
|
||||
if strings.TrimSpace(q) == "" {
|
||||
continue
|
||||
}
|
||||
_, err = tx.ExecContext(ctx, q)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if config.Driver == CockroachDataProviderName {
|
||||
return crdb.ExecuteTx(ctx, dbHandle, nil, txFn)
|
||||
}
|
||||
err = sqlCommonUpdateDatabaseVersion(ctx, tx, newVersion)
|
||||
|
||||
err = txFn(tx)
|
||||
if err != nil {
|
||||
// we don't change the returned error
|
||||
tx.Rollback() //nolint:errcheck
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
|
||||
Reference in New Issue
Block a user