data provider: add CockroachDB support

This commit is contained in:
Nicola Murino
2021-03-23 19:14:15 +01:00
parent 8a1249878a
commit 70e035315e
12 changed files with 227 additions and 117 deletions

View File

@@ -10,6 +10,8 @@ import (
"strings"
"time"
"github.com/cockroachdb/cockroach-go/v2/crdb"
"github.com/drakkan/sftpgo/logger"
"github.com/drakkan/sftpgo/utils"
"github.com/drakkan/sftpgo/vfs"
@@ -327,44 +329,38 @@ func sqlCommonAddUser(user *User, dbHandle *sql.DB) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
tx, err := dbHandle.BeginTx(ctx, nil)
if err != nil {
return err
}
q := getAddUserQuery()
stmt, err := tx.PrepareContext(ctx, q)
if err != nil {
providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
return err
}
defer stmt.Close()
permissions, err := user.GetPermissionsAsJSON()
if err != nil {
return err
}
publicKeys, err := user.GetPublicKeysAsJSON()
if err != nil {
return err
}
filters, err := user.GetFiltersAsJSON()
if err != nil {
return err
}
fsConfig, err := user.GetFsConfigAsJSON()
if err != nil {
return err
}
_, err = stmt.ExecContext(ctx, user.Username, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize,
user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status, user.ExpirationDate, string(filters),
string(fsConfig), user.AdditionalInfo, user.Description)
if err != nil {
return err
}
err = generateVirtualFoldersMapping(ctx, user, tx)
if err != nil {
return err
}
return tx.Commit()
return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
q := getAddUserQuery()
stmt, err := tx.PrepareContext(ctx, q)
if err != nil {
providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
return err
}
defer stmt.Close()
permissions, err := user.GetPermissionsAsJSON()
if err != nil {
return err
}
publicKeys, err := user.GetPublicKeysAsJSON()
if err != nil {
return err
}
filters, err := user.GetFiltersAsJSON()
if err != nil {
return err
}
fsConfig, err := user.GetFsConfigAsJSON()
if err != nil {
return err
}
_, err = stmt.ExecContext(ctx, user.Username, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize,
user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status, user.ExpirationDate, string(filters),
string(fsConfig), user.AdditionalInfo, user.Description)
if err != nil {
return err
}
return generateVirtualFoldersMapping(ctx, user, tx)
})
}
func sqlCommonUpdateUser(user *User, dbHandle *sql.DB) error {
@@ -375,44 +371,38 @@ func sqlCommonUpdateUser(user *User, dbHandle *sql.DB) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
tx, err := dbHandle.BeginTx(ctx, nil)
if err != nil {
return err
}
q := getUpdateUserQuery()
stmt, err := tx.PrepareContext(ctx, q)
if err != nil {
providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
return err
}
defer stmt.Close()
permissions, err := user.GetPermissionsAsJSON()
if err != nil {
return err
}
publicKeys, err := user.GetPublicKeysAsJSON()
if err != nil {
return err
}
filters, err := user.GetFiltersAsJSON()
if err != nil {
return err
}
fsConfig, err := user.GetFsConfigAsJSON()
if err != nil {
return err
}
_, err = stmt.ExecContext(ctx, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize,
user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status, user.ExpirationDate,
string(filters), string(fsConfig), user.AdditionalInfo, user.Description, user.ID)
if err != nil {
return err
}
err = generateVirtualFoldersMapping(ctx, user, tx)
if err != nil {
return err
}
return tx.Commit()
return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
q := getUpdateUserQuery()
stmt, err := tx.PrepareContext(ctx, q)
if err != nil {
providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
return err
}
defer stmt.Close()
permissions, err := user.GetPermissionsAsJSON()
if err != nil {
return err
}
publicKeys, err := user.GetPublicKeysAsJSON()
if err != nil {
return err
}
filters, err := user.GetFiltersAsJSON()
if err != nil {
return err
}
fsConfig, err := user.GetFsConfigAsJSON()
if err != nil {
return err
}
_, err = stmt.ExecContext(ctx, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize,
user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status, user.ExpirationDate,
string(filters), string(fsConfig), user.AdditionalInfo, user.Description, user.ID)
if err != nil {
return err
}
return generateVirtualFoldersMapping(ctx, user, tx)
})
}
func sqlCommonDeleteUser(user *User, dbHandle *sql.DB) error {
@@ -1072,24 +1062,38 @@ func sqlCommonUpdateDatabaseVersion(ctx context.Context, dbHandle sqlQuerier, ve
return err
}
func sqlCommonExecSQLAndUpdateDBVersion(dbHandle *sql.DB, sql []string, newVersion int) error {
func sqlCommonExecSQLAndUpdateDBVersion(dbHandle *sql.DB, sqlQueries []string, newVersion int) error {
ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
defer cancel()
return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
for _, q := range sqlQueries {
if strings.TrimSpace(q) == "" {
continue
}
_, err := tx.ExecContext(ctx, q)
if err != nil {
return err
}
}
return sqlCommonUpdateDatabaseVersion(ctx, tx, newVersion)
})
}
func sqlCommonExecuteTx(ctx context.Context, dbHandle *sql.DB, txFn func(*sql.Tx) error) error {
tx, err := dbHandle.BeginTx(ctx, nil)
if err != nil {
return err
}
for _, q := range sql {
if strings.TrimSpace(q) == "" {
continue
}
_, err = tx.ExecContext(ctx, q)
if err != nil {
return err
}
if config.Driver == CockroachDataProviderName {
return crdb.ExecuteTx(ctx, dbHandle, nil, txFn)
}
err = sqlCommonUpdateDatabaseVersion(ctx, tx, newVersion)
err = txFn(tx)
if err != nil {
// we don't change the returned error
tx.Rollback() //nolint:errcheck
return err
}
return tx.Commit()