dataprovider move db handle to provider struct

This is needed to support non SQL providers
This commit is contained in:
Nicola Murino
2019-08-11 14:53:37 +02:00
parent 51aacae3c5
commit cb87fe811a
7 changed files with 74 additions and 58 deletions

View File

@@ -40,7 +40,14 @@ $ go get -u github.com/drakkan/sftpgo
Make sure [Git is installed](https://git-scm.com/downloads) on your machine and in your system's `PATH`. Make sure [Git is installed](https://git-scm.com/downloads) on your machine and in your system's `PATH`.
Version info can be embedded populating the following variables at build time: SFTPGo depends on [go-sqlite3](https://github.com/mattn/go-sqlite3) that is a CGO package and so it requires a `C` compiler at build time.
On Linux and macOS a compiler is easy to install or already installed, on Windows you need to download [MinGW-w64](https://sourceforge.net/projects/mingw-w64/files/) and build SFTPGo from it's command prompt.
The compiler is a build time only dependency, it is not not required at runtime.
If you don't need SQLite, you can also get/build SFTPGo setting the environment variable `GCO_ENABLED` to 0, this way SQLite support will be disabled but PostgreSQL and MySQL will work and you don't need a `C` compiler for building.
Version info, such as git commit and build date, can be embedded setting the following string variables at build time:
- `github.com/drakkan/sftpgo/utils.commit` - `github.com/drakkan/sftpgo/utils.commit`
- `github.com/drakkan/sftpgo/utils.date` - `github.com/drakkan/sftpgo/utils.date`
@@ -54,11 +61,11 @@ go build -i -ldflags "-s -w -X github.com/drakkan/sftpgo/utils.commit=`git descr
and you will get a version that includes git commit and build date like this one: and you will get a version that includes git commit and build date like this one:
```bash ```bash
./sftpgo -v sftpgo -v
SFTPGo version: 0.9.0-dev-90607d4-dirty-2019-08-08T19:28:36Z SFTPGo version: 0.9.0-dev-90607d4-dirty-2019-08-08T19:28:36Z
``` ```
A systemd sample [service](https://github.com/drakkan/sftpgo/tree/master/init/sftpgo.service "systemd service") can be found inside the source tree. For Linux, a systemd sample [service](https://github.com/drakkan/sftpgo/tree/master/init/sftpgo.service "systemd service") can be found inside the source tree.
Alternately you can use distro packages: Alternately you can use distro packages:

View File

@@ -4,7 +4,6 @@
package dataprovider package dataprovider
import ( import (
"database/sql"
"fmt" "fmt"
"path/filepath" "path/filepath"
"strings" "strings"
@@ -33,7 +32,6 @@ const (
var ( var (
// SupportedProviders data provider configured in the sftpgo.conf file must match of these strings // SupportedProviders data provider configured in the sftpgo.conf file must match of these strings
SupportedProviders = []string{SQLiteDataProviderName, PGSSQLDataProviderName, MySQLDataProviderName} SupportedProviders = []string{SQLiteDataProviderName, PGSSQLDataProviderName, MySQLDataProviderName}
dbHandle *sql.DB
config Config config Config
provider Provider provider Provider
sqlPlaceholders []string sqlPlaceholders []string
@@ -124,13 +122,10 @@ func Initialize(cnf Config, basePath string) error {
config = cnf config = cnf
sqlPlaceholders = getSQLPlaceholders() sqlPlaceholders = getSQLPlaceholders()
if config.Driver == SQLiteDataProviderName { if config.Driver == SQLiteDataProviderName {
provider = SQLiteProvider{}
return initializeSQLiteProvider(basePath) return initializeSQLiteProvider(basePath)
} else if config.Driver == PGSSQLDataProviderName { } else if config.Driver == PGSSQLDataProviderName {
provider = PGSQLProvider{}
return initializePGSQLProvider() return initializePGSQLProvider()
} else if config.Driver == MySQLDataProviderName { } else if config.Driver == MySQLDataProviderName {
provider = MySQLProvider{}
return initializeMySQLProvider() return initializeMySQLProvider()
} }
return fmt.Errorf("Unsupported data provider: %v", config.Driver) return fmt.Errorf("Unsupported data provider: %v", config.Driver)
@@ -226,7 +221,8 @@ func validateUser(user *User) error {
return &ValidationError{err: fmt.Sprintf("Invalid permission: %v", p)} return &ValidationError{err: fmt.Sprintf("Invalid permission: %v", p)}
} }
} }
if len(user.Password) > 0 && !strings.HasPrefix(user.Password, argonPwdPrefix) { if len(user.Password) > 0 && !strings.HasPrefix(user.Password, argonPwdPrefix) &&
!strings.HasPrefix(user.Password, bcryptPwdPrefix) {
pwd, err := argon2id.CreateHash(user.Password, argon2id.DefaultParams) pwd, err := argon2id.CreateHash(user.Password, argon2id.DefaultParams)
if err != nil { if err != nil {
return err return err

View File

@@ -11,6 +11,7 @@ import (
// MySQLProvider auth provider for MySQL/MariaDB database // MySQLProvider auth provider for MySQL/MariaDB database
type MySQLProvider struct { type MySQLProvider struct {
dbHandle *sql.DB
} }
func initializeMySQLProvider() error { func initializeMySQLProvider() error {
@@ -22,13 +23,14 @@ func initializeMySQLProvider() error {
} else { } else {
connectionString = config.ConnectionString connectionString = config.ConnectionString
} }
dbHandle, err = sql.Open("mysql", connectionString) dbHandle, err := sql.Open("mysql", connectionString)
if err == nil { if err == nil {
numCPU := runtime.NumCPU() numCPU := runtime.NumCPU()
logger.Debug(logSender, "mysql database handle created, connection string: '%v', pool size: %v", connectionString, numCPU) logger.Debug(logSender, "mysql database handle created, connection string: '%v', pool size: %v", connectionString, numCPU)
dbHandle.SetMaxIdleConns(numCPU) dbHandle.SetMaxIdleConns(numCPU)
dbHandle.SetMaxOpenConns(numCPU) dbHandle.SetMaxOpenConns(numCPU)
dbHandle.SetConnMaxLifetime(1800 * time.Second) dbHandle.SetConnMaxLifetime(1800 * time.Second)
provider = MySQLProvider{dbHandle: dbHandle}
} else { } else {
logger.Warn(logSender, "error creating mysql database handler, connection string: '%v', error: %v", connectionString, err) logger.Warn(logSender, "error creating mysql database handler, connection string: '%v', error: %v", connectionString, err)
} }
@@ -36,24 +38,24 @@ func initializeMySQLProvider() error {
} }
func (p MySQLProvider) validateUserAndPass(username string, password string) (User, error) { func (p MySQLProvider) validateUserAndPass(username string, password string) (User, error) {
return sqlCommonValidateUserAndPass(username, password) return sqlCommonValidateUserAndPass(username, password, p.dbHandle)
} }
func (p MySQLProvider) validateUserAndPubKey(username string, publicKey string) (User, error) { func (p MySQLProvider) validateUserAndPubKey(username string, publicKey string) (User, error) {
return sqlCommonValidateUserAndPubKey(username, publicKey) return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle)
} }
func (p MySQLProvider) getUserByID(ID int64) (User, error) { func (p MySQLProvider) getUserByID(ID int64) (User, error) {
return sqlCommonGetUserByID(ID) return sqlCommonGetUserByID(ID, p.dbHandle)
} }
func (p MySQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error { func (p MySQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error {
tx, err := dbHandle.Begin() tx, err := p.dbHandle.Begin()
if err != nil { if err != nil {
logger.Warn(logSender, "error starting transaction to update quota for user %v: %v", username, err) logger.Warn(logSender, "error starting transaction to update quota for user %v: %v", username, err)
return err return err
} }
err = sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p) err = sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p.dbHandle)
if err == nil { if err == nil {
err = tx.Commit() err = tx.Commit()
} else { } else {
@@ -66,25 +68,25 @@ func (p MySQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64,
} }
func (p MySQLProvider) getUsedQuota(username string) (int, int64, error) { func (p MySQLProvider) getUsedQuota(username string) (int, int64, error) {
return sqlCommonGetUsedQuota(username) return sqlCommonGetUsedQuota(username, p.dbHandle)
} }
func (p MySQLProvider) userExists(username string) (User, error) { func (p MySQLProvider) userExists(username string) (User, error) {
return sqlCommonCheckUserExists(username) return sqlCommonCheckUserExists(username, p.dbHandle)
} }
func (p MySQLProvider) addUser(user User) error { func (p MySQLProvider) addUser(user User) error {
return sqlCommonAddUser(user) return sqlCommonAddUser(user, p.dbHandle)
} }
func (p MySQLProvider) updateUser(user User) error { func (p MySQLProvider) updateUser(user User) error {
return sqlCommonUpdateUser(user) return sqlCommonUpdateUser(user, p.dbHandle)
} }
func (p MySQLProvider) deleteUser(user User) error { func (p MySQLProvider) deleteUser(user User) error {
return sqlCommonDeleteUser(user) return sqlCommonDeleteUser(user, p.dbHandle)
} }
func (p MySQLProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) { func (p MySQLProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) {
return sqlCommonGetUsers(limit, offset, order, username) return sqlCommonGetUsers(limit, offset, order, username, p.dbHandle)
} }

View File

@@ -10,6 +10,7 @@ import (
// PGSQLProvider auth provider for PostgreSQL database // PGSQLProvider auth provider for PostgreSQL database
type PGSQLProvider struct { type PGSQLProvider struct {
dbHandle *sql.DB
} }
func initializePGSQLProvider() error { func initializePGSQLProvider() error {
@@ -21,12 +22,13 @@ func initializePGSQLProvider() error {
} else { } else {
connectionString = config.ConnectionString connectionString = config.ConnectionString
} }
dbHandle, err = sql.Open("postgres", connectionString) dbHandle, err := sql.Open("postgres", connectionString)
if err == nil { if err == nil {
numCPU := runtime.NumCPU() numCPU := runtime.NumCPU()
logger.Debug(logSender, "postgres database handle created, connection string: '%v', pool size: %v", connectionString, numCPU) logger.Debug(logSender, "postgres database handle created, connection string: '%v', pool size: %v", connectionString, numCPU)
dbHandle.SetMaxIdleConns(numCPU) dbHandle.SetMaxIdleConns(numCPU)
dbHandle.SetMaxOpenConns(numCPU) dbHandle.SetMaxOpenConns(numCPU)
provider = PGSQLProvider{dbHandle: dbHandle}
} else { } else {
logger.Warn(logSender, "error creating postgres database handler, connection string: '%v', error: %v", connectionString, err) logger.Warn(logSender, "error creating postgres database handler, connection string: '%v', error: %v", connectionString, err)
} }
@@ -34,24 +36,24 @@ func initializePGSQLProvider() error {
} }
func (p PGSQLProvider) validateUserAndPass(username string, password string) (User, error) { func (p PGSQLProvider) validateUserAndPass(username string, password string) (User, error) {
return sqlCommonValidateUserAndPass(username, password) return sqlCommonValidateUserAndPass(username, password, p.dbHandle)
} }
func (p PGSQLProvider) validateUserAndPubKey(username string, publicKey string) (User, error) { func (p PGSQLProvider) validateUserAndPubKey(username string, publicKey string) (User, error) {
return sqlCommonValidateUserAndPubKey(username, publicKey) return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle)
} }
func (p PGSQLProvider) getUserByID(ID int64) (User, error) { func (p PGSQLProvider) getUserByID(ID int64) (User, error) {
return sqlCommonGetUserByID(ID) return sqlCommonGetUserByID(ID, p.dbHandle)
} }
func (p PGSQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error { func (p PGSQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error {
tx, err := dbHandle.Begin() tx, err := p.dbHandle.Begin()
if err != nil { if err != nil {
logger.Warn(logSender, "error starting transaction to update quota for user %v: %v", username, err) logger.Warn(logSender, "error starting transaction to update quota for user %v: %v", username, err)
return err return err
} }
err = sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p) err = sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p.dbHandle)
if err == nil { if err == nil {
err = tx.Commit() err = tx.Commit()
} else { } else {
@@ -64,25 +66,25 @@ func (p PGSQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64,
} }
func (p PGSQLProvider) getUsedQuota(username string) (int, int64, error) { func (p PGSQLProvider) getUsedQuota(username string) (int, int64, error) {
return sqlCommonGetUsedQuota(username) return sqlCommonGetUsedQuota(username, p.dbHandle)
} }
func (p PGSQLProvider) userExists(username string) (User, error) { func (p PGSQLProvider) userExists(username string) (User, error) {
return sqlCommonCheckUserExists(username) return sqlCommonCheckUserExists(username, p.dbHandle)
} }
func (p PGSQLProvider) addUser(user User) error { func (p PGSQLProvider) addUser(user User) error {
return sqlCommonAddUser(user) return sqlCommonAddUser(user, p.dbHandle)
} }
func (p PGSQLProvider) updateUser(user User) error { func (p PGSQLProvider) updateUser(user User) error {
return sqlCommonUpdateUser(user) return sqlCommonUpdateUser(user, p.dbHandle)
} }
func (p PGSQLProvider) deleteUser(user User) error { func (p PGSQLProvider) deleteUser(user User) error {
return sqlCommonDeleteUser(user) return sqlCommonDeleteUser(user, p.dbHandle)
} }
func (p PGSQLProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) { func (p PGSQLProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) {
return sqlCommonGetUsers(limit, offset, order, username) return sqlCommonGetUsers(limit, offset, order, username, p.dbHandle)
} }

View File

@@ -16,7 +16,7 @@ import (
"github.com/drakkan/sftpgo/utils" "github.com/drakkan/sftpgo/utils"
) )
func getUserByUsername(username string) (User, error) { func getUserByUsername(username string, dbHandle *sql.DB) (User, error) {
var user User var user User
q := getUserByUsernameQuery() q := getUserByUsernameQuery()
stmt, err := dbHandle.Prepare(q) stmt, err := dbHandle.Prepare(q)
@@ -30,12 +30,12 @@ func getUserByUsername(username string) (User, error) {
return getUserFromDbRow(row, nil) return getUserFromDbRow(row, nil)
} }
func sqlCommonValidateUserAndPass(username string, password string) (User, error) { func sqlCommonValidateUserAndPass(username string, password string, dbHandle *sql.DB) (User, error) {
var user User var user User
if len(password) == 0 { if len(password) == 0 {
return user, errors.New("Credentials cannot be null or empty") return user, errors.New("Credentials cannot be null or empty")
} }
user, err := getUserByUsername(username) user, err := getUserByUsername(username, dbHandle)
if err != nil { if err != nil {
logger.Warn(logSender, "error authenticating user: %v, error: %v", username, err) logger.Warn(logSender, "error authenticating user: %v, error: %v", username, err)
} else { } else {
@@ -68,12 +68,12 @@ func sqlCommonValidateUserAndPass(username string, password string) (User, error
return user, err return user, err
} }
func sqlCommonValidateUserAndPubKey(username string, pubKey string) (User, error) { func sqlCommonValidateUserAndPubKey(username string, pubKey string, dbHandle *sql.DB) (User, error) {
var user User var user User
if len(pubKey) == 0 { if len(pubKey) == 0 {
return user, errors.New("Credentials cannot be null or empty") return user, errors.New("Credentials cannot be null or empty")
} }
user, err := getUserByUsername(username) user, err := getUserByUsername(username, dbHandle)
if err != nil { if err != nil {
logger.Warn(logSender, "error authenticating user: %v, error: %v", username, err) logger.Warn(logSender, "error authenticating user: %v, error: %v", username, err)
return user, err return user, err
@@ -95,7 +95,7 @@ func sqlCommonValidateUserAndPubKey(username string, pubKey string) (User, error
return user, errors.New("Invalid credentials") return user, errors.New("Invalid credentials")
} }
func sqlCommonGetUserByID(ID int64) (User, error) { func sqlCommonGetUserByID(ID int64, dbHandle *sql.DB) (User, error) {
var user User var user User
q := getUserByIDQuery() q := getUserByIDQuery()
stmt, err := dbHandle.Prepare(q) stmt, err := dbHandle.Prepare(q)
@@ -109,7 +109,7 @@ func sqlCommonGetUserByID(ID int64) (User, error) {
return getUserFromDbRow(row, nil) return getUserFromDbRow(row, nil)
} }
func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bool, p Provider) error { func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bool, dbHandle *sql.DB) error {
q := getUpdateQuotaQuery(reset) q := getUpdateQuotaQuery(reset)
stmt, err := dbHandle.Prepare(q) stmt, err := dbHandle.Prepare(q)
if err != nil { if err != nil {
@@ -127,7 +127,7 @@ func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bo
return err return err
} }
func sqlCommonGetUsedQuota(username string) (int, int64, error) { func sqlCommonGetUsedQuota(username string, dbHandle *sql.DB) (int, int64, error) {
q := getQuotaQuery() q := getQuotaQuery()
stmt, err := dbHandle.Prepare(q) stmt, err := dbHandle.Prepare(q)
if err != nil { if err != nil {
@@ -146,7 +146,7 @@ func sqlCommonGetUsedQuota(username string) (int, int64, error) {
return usedFiles, usedSize, err return usedFiles, usedSize, err
} }
func sqlCommonCheckUserExists(username string) (User, error) { func sqlCommonCheckUserExists(username string, dbHandle *sql.DB) (User, error) {
var user User var user User
q := getUserByUsernameQuery() q := getUserByUsernameQuery()
stmt, err := dbHandle.Prepare(q) stmt, err := dbHandle.Prepare(q)
@@ -159,7 +159,7 @@ func sqlCommonCheckUserExists(username string) (User, error) {
return getUserFromDbRow(row, nil) return getUserFromDbRow(row, nil)
} }
func sqlCommonAddUser(user User) error { func sqlCommonAddUser(user User, dbHandle *sql.DB) error {
err := validateUser(&user) err := validateUser(&user)
if err != nil { if err != nil {
return err return err
@@ -184,7 +184,7 @@ func sqlCommonAddUser(user User) error {
return err return err
} }
func sqlCommonUpdateUser(user User) error { func sqlCommonUpdateUser(user User, dbHandle *sql.DB) error {
err := validateUser(&user) err := validateUser(&user)
if err != nil { if err != nil {
return err return err
@@ -209,7 +209,7 @@ func sqlCommonUpdateUser(user User) error {
return err return err
} }
func sqlCommonDeleteUser(user User) error { func sqlCommonDeleteUser(user User, dbHandle *sql.DB) error {
q := getDeleteUserQuery() q := getDeleteUserQuery()
stmt, err := dbHandle.Prepare(q) stmt, err := dbHandle.Prepare(q)
if err != nil { if err != nil {
@@ -221,7 +221,7 @@ func sqlCommonDeleteUser(user User) error {
return err return err
} }
func sqlCommonGetUsers(limit int, offset int, order string, username string) ([]User, error) { func sqlCommonGetUsers(limit int, offset int, order string, username string, dbHandle *sql.DB) ([]User, error) {
users := []User{} users := []User{}
q := getUsersQuery(order, username) q := getUsersQuery(order, username)
stmt, err := dbHandle.Prepare(q) stmt, err := dbHandle.Prepare(q)

View File

@@ -12,6 +12,7 @@ import (
// SQLiteProvider auth provider for SQLite database // SQLiteProvider auth provider for SQLite database
type SQLiteProvider struct { type SQLiteProvider struct {
dbHandle *sql.DB
} }
func initializeSQLiteProvider(basePath string) error { func initializeSQLiteProvider(basePath string) error {
@@ -36,10 +37,11 @@ func initializeSQLiteProvider(basePath string) error {
} else { } else {
connectionString = config.ConnectionString connectionString = config.ConnectionString
} }
dbHandle, err = sql.Open("sqlite3", connectionString) dbHandle, err := sql.Open("sqlite3", connectionString)
if err == nil { if err == nil {
logger.Debug(logSender, "sqlite database handle created, connection string: '%v'", connectionString) logger.Debug(logSender, "sqlite database handle created, connection string: '%v'", connectionString)
dbHandle.SetMaxOpenConns(1) dbHandle.SetMaxOpenConns(1)
provider = SQLiteProvider{dbHandle: dbHandle}
} else { } else {
logger.Warn(logSender, "error creating sqlite database handler, connection string: '%v', error: %v", connectionString, err) logger.Warn(logSender, "error creating sqlite database handler, connection string: '%v', error: %v", connectionString, err)
} }
@@ -47,43 +49,43 @@ func initializeSQLiteProvider(basePath string) error {
} }
func (p SQLiteProvider) validateUserAndPass(username string, password string) (User, error) { func (p SQLiteProvider) validateUserAndPass(username string, password string) (User, error) {
return sqlCommonValidateUserAndPass(username, password) return sqlCommonValidateUserAndPass(username, password, p.dbHandle)
} }
func (p SQLiteProvider) validateUserAndPubKey(username string, publicKey string) (User, error) { func (p SQLiteProvider) validateUserAndPubKey(username string, publicKey string) (User, error) {
return sqlCommonValidateUserAndPubKey(username, publicKey) return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle)
} }
func (p SQLiteProvider) getUserByID(ID int64) (User, error) { func (p SQLiteProvider) getUserByID(ID int64) (User, error) {
return sqlCommonGetUserByID(ID) return sqlCommonGetUserByID(ID, p.dbHandle)
} }
func (p SQLiteProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error { func (p SQLiteProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error {
// we keep only 1 open connection (SetMaxOpenConns(1)) so a transaction is not needed and it could block // we keep only 1 open connection (SetMaxOpenConns(1)) so a transaction is not needed and it could block
// the database access since it will try to open a new connection // the database access since it will try to open a new connection
return sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p) return sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p.dbHandle)
} }
func (p SQLiteProvider) getUsedQuota(username string) (int, int64, error) { func (p SQLiteProvider) getUsedQuota(username string) (int, int64, error) {
return sqlCommonGetUsedQuota(username) return sqlCommonGetUsedQuota(username, p.dbHandle)
} }
func (p SQLiteProvider) userExists(username string) (User, error) { func (p SQLiteProvider) userExists(username string) (User, error) {
return sqlCommonCheckUserExists(username) return sqlCommonCheckUserExists(username, p.dbHandle)
} }
func (p SQLiteProvider) addUser(user User) error { func (p SQLiteProvider) addUser(user User) error {
return sqlCommonAddUser(user) return sqlCommonAddUser(user, p.dbHandle)
} }
func (p SQLiteProvider) updateUser(user User) error { func (p SQLiteProvider) updateUser(user User) error {
return sqlCommonUpdateUser(user) return sqlCommonUpdateUser(user, p.dbHandle)
} }
func (p SQLiteProvider) deleteUser(user User) error { func (p SQLiteProvider) deleteUser(user User) error {
return sqlCommonDeleteUser(user) return sqlCommonDeleteUser(user, p.dbHandle)
} }
func (p SQLiteProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) { func (p SQLiteProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) {
return sqlCommonGetUsers(limit, offset, order, username) return sqlCommonGetUsers(limit, offset, order, username, p.dbHandle)
} }

View File

@@ -99,11 +99,18 @@ func TestUploadFiles(t *testing.T) {
uploadMode = oldUploadMode uploadMode = oldUploadMode
} }
func TestLoginWithInvalidHome(t *testing.T) { func TestWithInvalidHome(t *testing.T) {
u := dataprovider.User{} u := dataprovider.User{}
u.HomeDir = "home_rel_path" u.HomeDir = "home_rel_path"
_, err := loginUser(u) _, err := loginUser(u)
if err == nil { if err == nil {
t.Errorf("login a user with an invalid home_dir must fail") t.Errorf("login a user with an invalid home_dir must fail")
} }
c := Connection{
User: u,
}
err = c.isSubDir("dir_rel_path")
if err == nil {
t.Errorf("tested path is not a home subdir")
}
} }