dataprovider move db handle to provider struct

This is needed to support non SQL providers
This commit is contained in:
Nicola Murino
2019-08-11 14:53:37 +02:00
parent 51aacae3c5
commit cb87fe811a
7 changed files with 74 additions and 58 deletions

View File

@@ -16,7 +16,7 @@ import (
"github.com/drakkan/sftpgo/utils"
)
func getUserByUsername(username string) (User, error) {
func getUserByUsername(username string, dbHandle *sql.DB) (User, error) {
var user User
q := getUserByUsernameQuery()
stmt, err := dbHandle.Prepare(q)
@@ -30,12 +30,12 @@ func getUserByUsername(username string) (User, error) {
return getUserFromDbRow(row, nil)
}
func sqlCommonValidateUserAndPass(username string, password string) (User, error) {
func sqlCommonValidateUserAndPass(username string, password string, dbHandle *sql.DB) (User, error) {
var user User
if len(password) == 0 {
return user, errors.New("Credentials cannot be null or empty")
}
user, err := getUserByUsername(username)
user, err := getUserByUsername(username, dbHandle)
if err != nil {
logger.Warn(logSender, "error authenticating user: %v, error: %v", username, err)
} else {
@@ -68,12 +68,12 @@ func sqlCommonValidateUserAndPass(username string, password string) (User, error
return user, err
}
func sqlCommonValidateUserAndPubKey(username string, pubKey string) (User, error) {
func sqlCommonValidateUserAndPubKey(username string, pubKey string, dbHandle *sql.DB) (User, error) {
var user User
if len(pubKey) == 0 {
return user, errors.New("Credentials cannot be null or empty")
}
user, err := getUserByUsername(username)
user, err := getUserByUsername(username, dbHandle)
if err != nil {
logger.Warn(logSender, "error authenticating user: %v, error: %v", username, err)
return user, err
@@ -95,7 +95,7 @@ func sqlCommonValidateUserAndPubKey(username string, pubKey string) (User, error
return user, errors.New("Invalid credentials")
}
func sqlCommonGetUserByID(ID int64) (User, error) {
func sqlCommonGetUserByID(ID int64, dbHandle *sql.DB) (User, error) {
var user User
q := getUserByIDQuery()
stmt, err := dbHandle.Prepare(q)
@@ -109,7 +109,7 @@ func sqlCommonGetUserByID(ID int64) (User, error) {
return getUserFromDbRow(row, nil)
}
func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bool, p Provider) error {
func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bool, dbHandle *sql.DB) error {
q := getUpdateQuotaQuery(reset)
stmt, err := dbHandle.Prepare(q)
if err != nil {
@@ -127,7 +127,7 @@ func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bo
return err
}
func sqlCommonGetUsedQuota(username string) (int, int64, error) {
func sqlCommonGetUsedQuota(username string, dbHandle *sql.DB) (int, int64, error) {
q := getQuotaQuery()
stmt, err := dbHandle.Prepare(q)
if err != nil {
@@ -146,7 +146,7 @@ func sqlCommonGetUsedQuota(username string) (int, int64, error) {
return usedFiles, usedSize, err
}
func sqlCommonCheckUserExists(username string) (User, error) {
func sqlCommonCheckUserExists(username string, dbHandle *sql.DB) (User, error) {
var user User
q := getUserByUsernameQuery()
stmt, err := dbHandle.Prepare(q)
@@ -159,7 +159,7 @@ func sqlCommonCheckUserExists(username string) (User, error) {
return getUserFromDbRow(row, nil)
}
func sqlCommonAddUser(user User) error {
func sqlCommonAddUser(user User, dbHandle *sql.DB) error {
err := validateUser(&user)
if err != nil {
return err
@@ -184,7 +184,7 @@ func sqlCommonAddUser(user User) error {
return err
}
func sqlCommonUpdateUser(user User) error {
func sqlCommonUpdateUser(user User, dbHandle *sql.DB) error {
err := validateUser(&user)
if err != nil {
return err
@@ -209,7 +209,7 @@ func sqlCommonUpdateUser(user User) error {
return err
}
func sqlCommonDeleteUser(user User) error {
func sqlCommonDeleteUser(user User, dbHandle *sql.DB) error {
q := getDeleteUserQuery()
stmt, err := dbHandle.Prepare(q)
if err != nil {
@@ -221,7 +221,7 @@ func sqlCommonDeleteUser(user User) error {
return err
}
func sqlCommonGetUsers(limit int, offset int, order string, username string) ([]User, error) {
func sqlCommonGetUsers(limit int, offset int, order string, username string, dbHandle *sql.DB) ([]User, error) {
users := []User{}
q := getUsersQuery(order, username)
stmt, err := dbHandle.Prepare(q)