dataprovider: add support for user status and expiration

an user can now be disabled or expired.

If you are using an SQL database as dataprovider please remember to
execute the sql update script inside "sql" folder.

Fixes #57
This commit is contained in:
Nicola Murino
2019-11-13 11:36:21 +01:00
parent 363b9ccc7f
commit c2ff50c917
35 changed files with 1101 additions and 88 deletions

View File

@@ -13,9 +13,15 @@ import (
bolt "go.etcd.io/bbolt"
)
const (
databaseVersion = 2
)
var (
usersBucket = []byte("users")
usersIDIdxBucket = []byte("users_id_idx")
dbVersionBucket = []byte("db_version")
dbVersionKey = []byte("version")
)
// BoltProvider auth provider for bolt key/value store
@@ -23,6 +29,10 @@ type BoltProvider struct {
dbHandle *bolt.DB
}
type boltDatabaseVersion struct {
Version int
}
func initializeBoltProvider(basePath string) error {
var err error
logSender = BoltDataProviderName
@@ -52,7 +62,16 @@ func initializeBoltProvider(basePath string) error {
providerLog(logger.LevelWarn, "error creating username idx bucket: %v", err)
return err
}
err = dbHandle.Update(func(tx *bolt.Tx) error {
_, e := tx.CreateBucketIfNotExists(dbVersionBucket)
return e
})
if err != nil {
providerLog(logger.LevelWarn, "error creating database version bucket: %v", err)
return err
}
provider = BoltProvider{dbHandle: dbHandle}
err = checkBoltDatabaseVersion(dbHandle)
} else {
providerLog(logger.LevelWarn, "error creating bolt key/value store handler: %v", err)
}
@@ -104,7 +123,7 @@ func (p BoltProvider) getUserByID(ID int64) (User, error) {
}
u := bucket.Get(username)
if u == nil {
return &RecordNotFoundError{err: fmt.Sprintf("username %v and ID: %v does not exist", string(username), ID)}
return &RecordNotFoundError{err: fmt.Sprintf("username %#v and ID: %v does not exist", string(username), ID)}
}
return json.Unmarshal(u, &user)
})
@@ -112,6 +131,30 @@ func (p BoltProvider) getUserByID(ID int64) (User, error) {
return user, err
}
func (p BoltProvider) updateLastLogin(username string) error {
return p.dbHandle.Update(func(tx *bolt.Tx) error {
bucket, _, err := getBuckets(tx)
if err != nil {
return err
}
var u []byte
if u = bucket.Get([]byte(username)); u == nil {
return &RecordNotFoundError{err: fmt.Sprintf("username %#v does not exist, unable to update last login", username)}
}
var user User
err = json.Unmarshal(u, &user)
if err != nil {
return err
}
user.LastLogin = utils.GetTimeAsMsSinceEpoch(time.Now())
buf, err := json.Marshal(user)
if err != nil {
return err
}
return bucket.Put([]byte(username), buf)
})
}
func (p BoltProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error {
return p.dbHandle.Update(func(tx *bolt.Tx) error {
bucket, _, err := getBuckets(tx)
@@ -120,7 +163,7 @@ func (p BoltProvider) updateQuota(username string, filesAdd int, sizeAdd int64,
}
var u []byte
if u = bucket.Get([]byte(username)); u == nil {
return &RecordNotFoundError{err: fmt.Sprintf("username %v does not exist, unable to update quota", username)}
return &RecordNotFoundError{err: fmt.Sprintf("username %#v does not exist, unable to update quota", username)}
}
var user User
err = json.Unmarshal(u, &user)
@@ -322,3 +365,90 @@ func getBuckets(tx *bolt.Tx) (*bolt.Bucket, *bolt.Bucket, error) {
}
return bucket, idxBucket, err
}
func checkBoltDatabaseVersion(dbHandle *bolt.DB) error {
dbVersion, err := getBoltDatabaseVersion(dbHandle)
if err != nil {
return err
}
if dbVersion.Version == databaseVersion {
providerLog(logger.LevelDebug, "bolt database updated, version: %v", dbVersion.Version)
return nil
}
if dbVersion.Version == 1 {
providerLog(logger.LevelInfo, "update bolt database version: 1 -> 2")
usernames, err := getBoltAvailableUsernames(dbHandle)
if err != nil {
return err
}
for _, u := range usernames {
user, err := provider.userExists(u)
if err != nil {
return err
}
user.Status = 1
err = provider.updateUser(user)
if err != nil {
return err
}
providerLog(logger.LevelInfo, "user %#v updated, \"status\" setted to 1", user.Username)
}
return updateBoltDatabaseVersion(dbHandle, 2)
}
return err
}
func getBoltAvailableUsernames(dbHandle *bolt.DB) ([]string, error) {
usernames := []string{}
err := dbHandle.View(func(tx *bolt.Tx) error {
_, idxBucket, err := getBuckets(tx)
if err != nil {
return err
}
cursor := idxBucket.Cursor()
for k, v := cursor.First(); k != nil; k, v = cursor.Next() {
usernames = append(usernames, string(v))
}
return nil
})
return usernames, err
}
func getBoltDatabaseVersion(dbHandle *bolt.DB) (boltDatabaseVersion, error) {
var dbVersion boltDatabaseVersion
err := dbHandle.View(func(tx *bolt.Tx) error {
bucket := tx.Bucket(dbVersionBucket)
if bucket == nil {
return fmt.Errorf("unable to find database version bucket")
}
v := bucket.Get(dbVersionKey)
if v == nil {
dbVersion = boltDatabaseVersion{
Version: 1,
}
return nil
}
return json.Unmarshal(v, &dbVersion)
})
return dbVersion, err
}
func updateBoltDatabaseVersion(dbHandle *bolt.DB, version int) error {
err := dbHandle.Update(func(tx *bolt.Tx) error {
bucket := tx.Bucket(dbVersionBucket)
if bucket == nil {
return fmt.Errorf("unable to find database version bucket")
}
newDbVersion := boltDatabaseVersion{
Version: version,
}
buf, err := json.Marshal(newDbVersion)
if err != nil {
return err
}
return bucket.Put(dbVersionKey, buf)
})
return err
}

View File

@@ -160,6 +160,7 @@ type Provider interface {
deleteUser(user User) error
getUsers(limit int, offset int, order string, username string) ([]User, error)
getUserByID(ID int64) (User, error)
updateLastLogin(username string) error
checkAvailability() error
close() error
}
@@ -203,6 +204,14 @@ func CheckUserAndPubKey(p Provider, username string, pubKey string) (User, strin
return p.validateUserAndPubKey(username, pubKey)
}
// UpdateLastLogin updates the last login fields for the given SFTP user
func UpdateLastLogin(p Provider, user User) error {
if config.ManageUsers == 0 {
return &MethodDisabledError{err: manageUsersDisabledError}
}
return p.updateLastLogin(user.Username)
}
// UpdateUserQuota updates the quota for the given SFTP user adding filesAdd and sizeAdd.
// If reset is true filesAdd and sizeAdd indicates the total files and the total size instead of the difference.
func UpdateUserQuota(p Provider, user User, filesAdd int, sizeAdd int64, reset bool) error {
@@ -211,6 +220,9 @@ func UpdateUserQuota(p Provider, user User, filesAdd int, sizeAdd int64, reset b
} else if config.TrackQuota == 2 && !reset && !user.HasQuotaRestrictions() {
return nil
}
if config.ManageUsers == 0 {
return &MethodDisabledError{err: manageUsersDisabledError}
}
return p.updateQuota(user.Username, filesAdd, sizeAdd, reset)
}
@@ -311,6 +323,9 @@ func validateUser(user *User) error {
if err := validatePermissions(user); err != nil {
return err
}
if user.Status < 0 || user.Status > 1 {
return &ValidationError{err: fmt.Sprintf("invalid user status: %v", user.Status)}
}
if len(user.Password) > 0 && !utils.IsStringPrefixInSlice(user.Password, hashPwdPrefixes) {
pwd, err := argon2id.CreateHash(user.Password, argon2id.DefaultParams)
if err != nil {
@@ -327,8 +342,22 @@ func validateUser(user *User) error {
return nil
}
func checkLoginConditions(user User) error {
if user.Status < 1 {
return fmt.Errorf("user %#v is disabled", user.Username)
}
if user.ExpirationDate > 0 && user.ExpirationDate < utils.GetTimeAsMsSinceEpoch(time.Now()) {
return fmt.Errorf("user %#v is expired, expiration timestamp: %v current timestamp: %v", user.Username,
user.ExpirationDate, utils.GetTimeAsMsSinceEpoch(time.Now()))
}
return nil
}
func checkUserAndPass(user User, password string) (User, error) {
var err error
err := checkLoginConditions(user)
if err != nil {
return user, err
}
if len(user.Password) == 0 {
return user, errors.New("Credentials cannot be null or empty")
}
@@ -372,6 +401,10 @@ func checkUserAndPass(user User, password string) (User, error) {
}
func checkUserAndPubKey(user User, pubKey string) (User, string, error) {
err := checkLoginConditions(user)
if err != nil {
return user, "", err
}
if len(user.PublicKeys) == 0 {
return user, "", errors.New("Invalid credentials")
}

View File

@@ -101,6 +101,21 @@ func (p MemoryProvider) getUserByID(ID int64) (User, error) {
return User{}, &RecordNotFoundError{err: fmt.Sprintf("user with ID %v does not exist", ID)}
}
func (p MemoryProvider) updateLastLogin(username string) error {
p.dbHandle.lock.Lock()
defer p.dbHandle.lock.Unlock()
if p.dbHandle.isClosed {
return errMemoryProviderClosed
}
user, err := p.userExistsInternal(username)
if err != nil {
return err
}
user.LastLogin = utils.GetTimeAsMsSinceEpoch(time.Now())
p.dbHandle.users[user.Username] = user
return nil
}
func (p MemoryProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error {
p.dbHandle.lock.Lock()
defer p.dbHandle.lock.Unlock()

View File

@@ -64,6 +64,10 @@ func (p MySQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64,
return sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p.dbHandle)
}
func (p MySQLProvider) updateLastLogin(username string) error {
return sqlCommonUpdateLastLogin(username, p.dbHandle)
}
func (p MySQLProvider) getUsedQuota(username string) (int, int64, error) {
return sqlCommonGetUsedQuota(username, p.dbHandle)
}

View File

@@ -63,6 +63,10 @@ func (p PGSQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64,
return sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p.dbHandle)
}
func (p PGSQLProvider) updateLastLogin(username string) error {
return sqlCommonUpdateLastLogin(username, p.dbHandle)
}
func (p PGSQLProvider) getUsedQuota(username string) (int, int64, error) {
return sqlCommonGetUsedQuota(username, p.dbHandle)
}

View File

@@ -81,10 +81,27 @@ func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bo
defer stmt.Close()
_, err = stmt.Exec(sizeAdd, filesAdd, utils.GetTimeAsMsSinceEpoch(time.Now()), username)
if err == nil {
providerLog(logger.LevelDebug, "quota updated for user %v, files increment: %v size increment: %v is reset? %v",
providerLog(logger.LevelDebug, "quota updated for user %#v, files increment: %v size increment: %v is reset? %v",
username, filesAdd, sizeAdd, reset)
} else {
providerLog(logger.LevelWarn, "error updating quota for username %v: %v", username, err)
providerLog(logger.LevelWarn, "error updating quota for user %#v: %v", username, err)
}
return err
}
func sqlCommonUpdateLastLogin(username string, dbHandle *sql.DB) error {
q := getUpdateLastLoginQuery()
stmt, err := dbHandle.Prepare(q)
if err != nil {
providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
return err
}
defer stmt.Close()
_, err = stmt.Exec(utils.GetTimeAsMsSinceEpoch(time.Now()), username)
if err == nil {
providerLog(logger.LevelDebug, "last login updated for user %#v", username)
} else {
providerLog(logger.LevelWarn, "error updating last login for user %#v: %v", username, err)
}
return err
}
@@ -142,7 +159,7 @@ func sqlCommonAddUser(user User, dbHandle *sql.DB) error {
return err
}
_, err = stmt.Exec(user.Username, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize,
user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth)
user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status, user.ExpirationDate)
return err
}
@@ -167,7 +184,7 @@ func sqlCommonUpdateUser(user User, dbHandle *sql.DB) error {
return err
}
_, err = stmt.Exec(user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize,
user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.ID)
user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.Status, user.ExpirationDate, user.ID)
return err
}
@@ -224,12 +241,12 @@ func getUserFromDbRow(row *sql.Row, rows *sql.Rows) (User, error) {
if row != nil {
err = row.Scan(&user.ID, &user.Username, &password, &publicKey, &user.HomeDir, &user.UID, &user.GID, &user.MaxSessions,
&user.QuotaSize, &user.QuotaFiles, &permissions, &user.UsedQuotaSize, &user.UsedQuotaFiles, &user.LastQuotaUpdate,
&user.UploadBandwidth, &user.DownloadBandwidth)
&user.UploadBandwidth, &user.DownloadBandwidth, &user.ExpirationDate, &user.LastLogin, &user.Status)
} else {
err = rows.Scan(&user.ID, &user.Username, &password, &publicKey, &user.HomeDir, &user.UID, &user.GID, &user.MaxSessions,
&user.QuotaSize, &user.QuotaFiles, &permissions, &user.UsedQuotaSize, &user.UsedQuotaFiles, &user.LastQuotaUpdate,
&user.UploadBandwidth, &user.DownloadBandwidth)
&user.UploadBandwidth, &user.DownloadBandwidth, &user.ExpirationDate, &user.LastLogin, &user.Status)
}
if err != nil {
if err == sql.ErrNoRows {

View File

@@ -70,6 +70,10 @@ func (p SQLiteProvider) updateQuota(username string, filesAdd int, sizeAdd int64
return sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p.dbHandle)
}
func (p SQLiteProvider) updateLastLogin(username string) error {
return sqlCommonUpdateLastLogin(username, p.dbHandle)
}
func (p SQLiteProvider) getUsedQuota(username string) (int, int64, error) {
return sqlCommonGetUsedQuota(username, p.dbHandle)
}

View File

@@ -4,7 +4,7 @@ import "fmt"
const (
selectUserFields = "id,username,password,public_keys,home_dir,uid,gid,max_sessions,quota_size,quota_files,permissions," +
"used_quota_size,used_quota_files,last_quota_update,upload_bandwidth,download_bandwidth"
"used_quota_size,used_quota_files,last_quota_update,upload_bandwidth,download_bandwidth,expiration_date,last_login,status"
)
func getSQLPlaceholders() []string {
@@ -38,13 +38,17 @@ func getUsersQuery(order string, username string) string {
func getUpdateQuotaQuery(reset bool) string {
if reset {
return fmt.Sprintf(`UPDATE %v SET used_quota_size = %v,used_quota_files = %v,last_quota_update = %v
return fmt.Sprintf(`UPDATE %v SET used_quota_size = %v,used_quota_files = %v,last_quota_update = %v
WHERE username = %v`, config.UsersTable, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
}
return fmt.Sprintf(`UPDATE %v SET used_quota_size = used_quota_size + %v,used_quota_files = used_quota_files + %v,last_quota_update = %v
return fmt.Sprintf(`UPDATE %v SET used_quota_size = used_quota_size + %v,used_quota_files = used_quota_files + %v,last_quota_update = %v
WHERE username = %v`, config.UsersTable, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
}
func getUpdateLastLoginQuery() string {
return fmt.Sprintf(`UPDATE %v SET last_login = %v WHERE username = %v`, config.UsersTable, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getQuotaQuery() string {
return fmt.Sprintf(`SELECT used_quota_size,used_quota_files FROM %v WHERE username = %v`, config.UsersTable,
sqlPlaceholders[0])
@@ -52,17 +56,18 @@ func getQuotaQuery() string {
func getAddUserQuery() string {
return fmt.Sprintf(`INSERT INTO %v (username,password,public_keys,home_dir,uid,gid,max_sessions,quota_size,quota_files,permissions,
used_quota_size,used_quota_files,last_quota_update,upload_bandwidth,download_bandwidth)
VALUES (%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,0,0,0,%v,%v)`, config.UsersTable, sqlPlaceholders[0], sqlPlaceholders[1],
used_quota_size,used_quota_files,last_quota_update,upload_bandwidth,download_bandwidth,status,last_login,expiration_date)
VALUES (%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,0,0,0,%v,%v,%v,0,%v)`, config.UsersTable, sqlPlaceholders[0], sqlPlaceholders[1],
sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7],
sqlPlaceholders[8], sqlPlaceholders[9], sqlPlaceholders[10], sqlPlaceholders[11])
sqlPlaceholders[8], sqlPlaceholders[9], sqlPlaceholders[10], sqlPlaceholders[11], sqlPlaceholders[12], sqlPlaceholders[13])
}
func getUpdateUserQuery() string {
return fmt.Sprintf(`UPDATE %v SET password=%v,public_keys=%v,home_dir=%v,uid=%v,gid=%v,max_sessions=%v,quota_size=%v,
quota_files=%v,permissions=%v,upload_bandwidth=%v,download_bandwidth=%v WHERE id = %v`, config.UsersTable,
quota_files=%v,permissions=%v,upload_bandwidth=%v,download_bandwidth=%v,status=%v,expiration_date=%v WHERE id = %v`, config.UsersTable,
sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5],
sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8], sqlPlaceholders[9], sqlPlaceholders[10], sqlPlaceholders[11])
sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8], sqlPlaceholders[9], sqlPlaceholders[10], sqlPlaceholders[11],
sqlPlaceholders[12], sqlPlaceholders[13])
}
func getDeleteUserQuery() string {

View File

@@ -36,13 +36,16 @@ const (
type User struct {
// Database unique identifier
ID int64 `json:"id"`
// 1 enabled, 0 disabled (login is not allowed)
Status int `json:"status"`
// Username
Username string `json:"username"`
// Account expiration date as unix timestamp in milliseconds. An expired account cannot login.
// 0 means no expiration
ExpirationDate int64 `json:"expiration_date"`
// Password used for password authentication.
// For users created using SFTPGo REST API the password is be stored using argon2id hashing algo.
// Checking passwords stored with bcrypt is supported too.
// Currently, as fallback, there is a clear text password checking but you should not store passwords
// as clear text and this support could be removed at any time, so please don't depend on it.
// Checking passwords stored with bcrypt, pbkdf2 and sha512crypt is supported too.
Password string `json:"password,omitempty"`
// PublicKeys used for public key authentication. At least one between password and a public key is mandatory
PublicKeys []string `json:"public_keys,omitempty"`
@@ -70,6 +73,8 @@ type User struct {
UploadBandwidth int64 `json:"upload_bandwidth"`
// Maximum download bandwidth as KB/s, 0 means unlimited
DownloadBandwidth int64 `json:"download_bandwidth"`
// Last login as unix timestamp in milliseconds
LastLogin int64 `json:"last_login"`
}
// HasPerm returns true if the user has the given permission or any permission
@@ -175,6 +180,10 @@ func (u *User) GetBandwidthAsString() string {
// Number of public keys, max sessions, uid and gid are returned
func (u *User) GetInfoString() string {
var result string
if u.LastLogin > 0 {
t := utils.GetTimeFromMsecSinceEpoch(u.LastLogin)
result += fmt.Sprintf("Last login: %v ", t.Format("2006-01-02 15:04:05")) // YYYY-MM-DD HH:MM:SS
}
if len(u.PublicKeys) > 0 {
result += fmt.Sprintf("Public keys: %v ", len(u.PublicKeys))
}
@@ -190,6 +199,15 @@ func (u *User) GetInfoString() string {
return result
}
// GetExpirationDateAsString returns expiration date formatted as YYYY-MM-DD
func (u *User) GetExpirationDateAsString() string {
if u.ExpirationDate > 0 {
t := utils.GetTimeFromMsecSinceEpoch(u.ExpirationDate)
return t.Format("2006-01-02")
}
return ""
}
func (u *User) getACopy() User {
pubKeys := make([]string, len(u.PublicKeys))
copy(pubKeys, u.PublicKeys)
@@ -212,5 +230,8 @@ func (u *User) getACopy() User {
LastQuotaUpdate: u.LastQuotaUpdate,
UploadBandwidth: u.UploadBandwidth,
DownloadBandwidth: u.DownloadBandwidth,
Status: u.Status,
ExpirationDate: u.ExpirationDate,
LastLogin: u.LastLogin,
}
}