mirror of
https://github.com/drakkan/sftpgo.git
synced 2025-12-07 06:40:54 +03:00
add memory data provider and use it for portable mode
This commit is contained in:
@@ -161,7 +161,7 @@ func (p BoltProvider) userExists(username string) (User, error) {
|
|||||||
}
|
}
|
||||||
u := bucket.Get([]byte(username))
|
u := bucket.Get([]byte(username))
|
||||||
if u == nil {
|
if u == nil {
|
||||||
return &RecordNotFoundError{err: fmt.Sprintf("username %v does not exist", user.Username)}
|
return &RecordNotFoundError{err: fmt.Sprintf("username %v does not exist", username)}
|
||||||
}
|
}
|
||||||
return json.Unmarshal(u, &user)
|
return json.Unmarshal(u, &user)
|
||||||
})
|
})
|
||||||
@@ -242,6 +242,9 @@ func (p BoltProvider) deleteUser(user User) error {
|
|||||||
func (p BoltProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) {
|
func (p BoltProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) {
|
||||||
users := []User{}
|
users := []User{}
|
||||||
var err error
|
var err error
|
||||||
|
if limit <= 0 {
|
||||||
|
return users, err
|
||||||
|
}
|
||||||
if len(username) > 0 {
|
if len(username) > 0 {
|
||||||
if offset == 0 {
|
if offset == 0 {
|
||||||
user, err := p.userExists(username)
|
user, err := p.userExists(username)
|
||||||
@@ -252,9 +255,6 @@ func (p BoltProvider) getUsers(limit int, offset int, order string, username str
|
|||||||
return users, err
|
return users, err
|
||||||
}
|
}
|
||||||
err = p.dbHandle.View(func(tx *bolt.Tx) error {
|
err = p.dbHandle.View(func(tx *bolt.Tx) error {
|
||||||
if limit <= 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
bucket, _, err := getBuckets(tx)
|
bucket, _, err := getBuckets(tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -37,6 +37,8 @@ const (
|
|||||||
MySQLDataProviderName = "mysql"
|
MySQLDataProviderName = "mysql"
|
||||||
// BoltDataProviderName name for bbolt key/value store provider
|
// BoltDataProviderName name for bbolt key/value store provider
|
||||||
BoltDataProviderName = "bolt"
|
BoltDataProviderName = "bolt"
|
||||||
|
// MemoryDataProviderName name for memory provider
|
||||||
|
MemoryDataProviderName = "memory"
|
||||||
|
|
||||||
argonPwdPrefix = "$argon2id$"
|
argonPwdPrefix = "$argon2id$"
|
||||||
bcryptPwdPrefix = "$2a$"
|
bcryptPwdPrefix = "$2a$"
|
||||||
@@ -50,7 +52,8 @@ const (
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
// SupportedProviders data provider configured in the sftpgo.conf file must match of these strings
|
// SupportedProviders data provider configured in the sftpgo.conf file must match of these strings
|
||||||
SupportedProviders = []string{SQLiteDataProviderName, PGSQLDataProviderName, MySQLDataProviderName, BoltDataProviderName}
|
SupportedProviders = []string{SQLiteDataProviderName, PGSQLDataProviderName, MySQLDataProviderName,
|
||||||
|
BoltDataProviderName, MemoryDataProviderName}
|
||||||
// ValidPerms list that contains all the valid permissions for an user
|
// ValidPerms list that contains all the valid permissions for an user
|
||||||
ValidPerms = []string{PermAny, PermListItems, PermDownload, PermUpload, PermOverwrite, PermRename, PermDelete,
|
ValidPerms = []string{PermAny, PermListItems, PermDownload, PermUpload, PermOverwrite, PermRename, PermDelete,
|
||||||
PermCreateDirs, PermCreateSymlinks}
|
PermCreateDirs, PermCreateSymlinks}
|
||||||
@@ -179,6 +182,8 @@ func Initialize(cnf Config, basePath string) error {
|
|||||||
err = initializeMySQLProvider()
|
err = initializeMySQLProvider()
|
||||||
} else if config.Driver == BoltDataProviderName {
|
} else if config.Driver == BoltDataProviderName {
|
||||||
err = initializeBoltProvider(basePath)
|
err = initializeBoltProvider(basePath)
|
||||||
|
} else if config.Driver == MemoryDataProviderName {
|
||||||
|
err = initializeMemoryProvider()
|
||||||
} else {
|
} else {
|
||||||
err = fmt.Errorf("unsupported data provider: %v", config.Driver)
|
err = fmt.Errorf("unsupported data provider: %v", config.Driver)
|
||||||
}
|
}
|
||||||
|
|||||||
279
dataprovider/memory.go
Normal file
279
dataprovider/memory.go
Normal file
@@ -0,0 +1,279 @@
|
|||||||
|
package dataprovider
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/drakkan/sftpgo/logger"
|
||||||
|
"github.com/drakkan/sftpgo/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errMemoryProviderClosed = errors.New("memory provider is closed")
|
||||||
|
)
|
||||||
|
|
||||||
|
type memoryProviderHandle struct {
|
||||||
|
isClosed bool
|
||||||
|
// slice with ordered usernames
|
||||||
|
usernames []string
|
||||||
|
// mapping between ID and username
|
||||||
|
usersIdx map[int64]string
|
||||||
|
// map for users, username is the key
|
||||||
|
users map[string]User
|
||||||
|
lock *sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// MemoryProvider auth provider for a memory store
|
||||||
|
type MemoryProvider struct {
|
||||||
|
dbHandle *memoryProviderHandle
|
||||||
|
}
|
||||||
|
|
||||||
|
func initializeMemoryProvider() error {
|
||||||
|
provider = MemoryProvider{
|
||||||
|
dbHandle: &memoryProviderHandle{
|
||||||
|
isClosed: false,
|
||||||
|
usernames: []string{},
|
||||||
|
usersIdx: make(map[int64]string),
|
||||||
|
users: make(map[string]User),
|
||||||
|
lock: new(sync.Mutex),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p MemoryProvider) checkAvailability() error {
|
||||||
|
p.dbHandle.lock.Lock()
|
||||||
|
defer p.dbHandle.lock.Unlock()
|
||||||
|
if p.dbHandle.isClosed {
|
||||||
|
return errMemoryProviderClosed
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p MemoryProvider) close() error {
|
||||||
|
p.dbHandle.lock.Lock()
|
||||||
|
defer p.dbHandle.lock.Unlock()
|
||||||
|
if p.dbHandle.isClosed {
|
||||||
|
return errMemoryProviderClosed
|
||||||
|
}
|
||||||
|
p.dbHandle.isClosed = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p MemoryProvider) validateUserAndPass(username string, password string) (User, error) {
|
||||||
|
var user User
|
||||||
|
if len(password) == 0 {
|
||||||
|
return user, errors.New("Credentials cannot be null or empty")
|
||||||
|
}
|
||||||
|
user, err := p.userExists(username)
|
||||||
|
if err != nil {
|
||||||
|
providerLog(logger.LevelWarn, "error authenticating user: %v, error: %v", username, err)
|
||||||
|
return user, err
|
||||||
|
}
|
||||||
|
return checkUserAndPass(user, password)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p MemoryProvider) validateUserAndPubKey(username string, pubKey string) (User, string, error) {
|
||||||
|
var user User
|
||||||
|
if len(pubKey) == 0 {
|
||||||
|
return user, "", errors.New("Credentials cannot be null or empty")
|
||||||
|
}
|
||||||
|
user, err := p.userExists(username)
|
||||||
|
if err != nil {
|
||||||
|
providerLog(logger.LevelWarn, "error authenticating user: %v, error: %v", username, err)
|
||||||
|
return user, "", err
|
||||||
|
}
|
||||||
|
return checkUserAndPubKey(user, pubKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p MemoryProvider) getUserByID(ID int64) (User, error) {
|
||||||
|
p.dbHandle.lock.Lock()
|
||||||
|
defer p.dbHandle.lock.Unlock()
|
||||||
|
if p.dbHandle.isClosed {
|
||||||
|
return User{}, errMemoryProviderClosed
|
||||||
|
}
|
||||||
|
if val, ok := p.dbHandle.usersIdx[ID]; ok {
|
||||||
|
return p.userExistsInternal(val)
|
||||||
|
}
|
||||||
|
return User{}, &RecordNotFoundError{err: fmt.Sprintf("user with ID %v does not exist", ID)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p MemoryProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error {
|
||||||
|
p.dbHandle.lock.Lock()
|
||||||
|
defer p.dbHandle.lock.Unlock()
|
||||||
|
if p.dbHandle.isClosed {
|
||||||
|
return errMemoryProviderClosed
|
||||||
|
}
|
||||||
|
user, err := p.userExistsInternal(username)
|
||||||
|
if err != nil {
|
||||||
|
providerLog(logger.LevelWarn, "unable to update quota for user %v error: %v", username, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if reset {
|
||||||
|
user.UsedQuotaSize = sizeAdd
|
||||||
|
user.UsedQuotaFiles = filesAdd
|
||||||
|
} else {
|
||||||
|
user.UsedQuotaSize += sizeAdd
|
||||||
|
user.UsedQuotaFiles += filesAdd
|
||||||
|
}
|
||||||
|
user.LastQuotaUpdate = utils.GetTimeAsMsSinceEpoch(time.Now())
|
||||||
|
p.dbHandle.users[user.Username] = user
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p MemoryProvider) getUsedQuota(username string) (int, int64, error) {
|
||||||
|
p.dbHandle.lock.Lock()
|
||||||
|
defer p.dbHandle.lock.Unlock()
|
||||||
|
if p.dbHandle.isClosed {
|
||||||
|
return 0, 0, errMemoryProviderClosed
|
||||||
|
}
|
||||||
|
user, err := p.userExistsInternal(username)
|
||||||
|
if err != nil {
|
||||||
|
providerLog(logger.LevelWarn, "unable to get quota for user %v error: %v", username, err)
|
||||||
|
return 0, 0, err
|
||||||
|
}
|
||||||
|
return user.UsedQuotaFiles, user.UsedQuotaSize, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p MemoryProvider) addUser(user User) error {
|
||||||
|
p.dbHandle.lock.Lock()
|
||||||
|
defer p.dbHandle.lock.Unlock()
|
||||||
|
if p.dbHandle.isClosed {
|
||||||
|
return errMemoryProviderClosed
|
||||||
|
}
|
||||||
|
err := validateUser(&user)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = p.userExistsInternal(user.Username)
|
||||||
|
if err == nil {
|
||||||
|
return fmt.Errorf("username %v already exists", user.Username)
|
||||||
|
}
|
||||||
|
user.ID = p.getNextID()
|
||||||
|
p.dbHandle.users[user.Username] = user
|
||||||
|
p.dbHandle.usersIdx[user.ID] = user.Username
|
||||||
|
p.dbHandle.usernames = append(p.dbHandle.usernames, user.Username)
|
||||||
|
sort.Strings(p.dbHandle.usernames)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p MemoryProvider) updateUser(user User) error {
|
||||||
|
p.dbHandle.lock.Lock()
|
||||||
|
defer p.dbHandle.lock.Unlock()
|
||||||
|
if p.dbHandle.isClosed {
|
||||||
|
return errMemoryProviderClosed
|
||||||
|
}
|
||||||
|
err := validateUser(&user)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = p.userExistsInternal(user.Username)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
p.dbHandle.users[user.Username] = user
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p MemoryProvider) deleteUser(user User) error {
|
||||||
|
p.dbHandle.lock.Lock()
|
||||||
|
defer p.dbHandle.lock.Unlock()
|
||||||
|
if p.dbHandle.isClosed {
|
||||||
|
return errMemoryProviderClosed
|
||||||
|
}
|
||||||
|
_, err := p.userExistsInternal(user.Username)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
delete(p.dbHandle.users, user.Username)
|
||||||
|
delete(p.dbHandle.usersIdx, user.ID)
|
||||||
|
// this could be more efficient
|
||||||
|
p.dbHandle.usernames = []string{}
|
||||||
|
for username := range p.dbHandle.users {
|
||||||
|
p.dbHandle.usernames = append(p.dbHandle.usernames, username)
|
||||||
|
}
|
||||||
|
sort.Strings(p.dbHandle.usernames)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p MemoryProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) {
|
||||||
|
users := []User{}
|
||||||
|
var err error
|
||||||
|
p.dbHandle.lock.Lock()
|
||||||
|
defer p.dbHandle.lock.Unlock()
|
||||||
|
if p.dbHandle.isClosed {
|
||||||
|
return users, errMemoryProviderClosed
|
||||||
|
}
|
||||||
|
if limit <= 0 {
|
||||||
|
return users, err
|
||||||
|
}
|
||||||
|
if len(username) > 0 {
|
||||||
|
if offset == 0 {
|
||||||
|
user, err := p.userExistsInternal(username)
|
||||||
|
if err == nil {
|
||||||
|
user.Password = ""
|
||||||
|
users = append(users, user)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return users, err
|
||||||
|
}
|
||||||
|
itNum := 0
|
||||||
|
if order == "ASC" {
|
||||||
|
for _, username := range p.dbHandle.usernames {
|
||||||
|
itNum++
|
||||||
|
if itNum <= offset {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
user := p.dbHandle.users[username]
|
||||||
|
user.Password = ""
|
||||||
|
users = append(users, user)
|
||||||
|
if len(users) >= limit {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for i := len(p.dbHandle.usernames) - 1; i >= 0; i-- {
|
||||||
|
itNum++
|
||||||
|
if itNum <= offset {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
username := p.dbHandle.usernames[i]
|
||||||
|
user := p.dbHandle.users[username]
|
||||||
|
user.Password = ""
|
||||||
|
users = append(users, user)
|
||||||
|
if len(users) >= limit {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return users, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p MemoryProvider) userExists(username string) (User, error) {
|
||||||
|
p.dbHandle.lock.Lock()
|
||||||
|
defer p.dbHandle.lock.Unlock()
|
||||||
|
if p.dbHandle.isClosed {
|
||||||
|
return User{}, errMemoryProviderClosed
|
||||||
|
}
|
||||||
|
return p.userExistsInternal(username)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p MemoryProvider) userExistsInternal(username string) (User, error) {
|
||||||
|
if val, ok := p.dbHandle.users[username]; ok {
|
||||||
|
return val.getACopy(), nil
|
||||||
|
}
|
||||||
|
return User{}, &RecordNotFoundError{err: fmt.Sprintf("username %v does not exist", username)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p MemoryProvider) getNextID() int64 {
|
||||||
|
nextID := int64(1)
|
||||||
|
for id := range p.dbHandle.usersIdx {
|
||||||
|
if id >= nextID {
|
||||||
|
nextID = id + 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nextID
|
||||||
|
}
|
||||||
@@ -189,3 +189,28 @@ func (u *User) GetInfoString() string {
|
|||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *User) getACopy() User {
|
||||||
|
pubKeys := make([]string, len(u.PublicKeys))
|
||||||
|
copy(pubKeys, u.PublicKeys)
|
||||||
|
permissions := make([]string, len(u.Permissions))
|
||||||
|
copy(permissions, u.Permissions)
|
||||||
|
return User{
|
||||||
|
ID: u.ID,
|
||||||
|
Username: u.Username,
|
||||||
|
Password: u.Password,
|
||||||
|
PublicKeys: pubKeys,
|
||||||
|
HomeDir: u.HomeDir,
|
||||||
|
UID: u.UID,
|
||||||
|
GID: u.GID,
|
||||||
|
MaxSessions: u.MaxSessions,
|
||||||
|
QuotaSize: u.QuotaSize,
|
||||||
|
QuotaFiles: u.QuotaFiles,
|
||||||
|
Permissions: permissions,
|
||||||
|
UsedQuotaSize: u.UsedQuotaSize,
|
||||||
|
UsedQuotaFiles: u.UsedQuotaFiles,
|
||||||
|
LastQuotaUpdate: u.LastQuotaUpdate,
|
||||||
|
UploadBandwidth: u.UploadBandwidth,
|
||||||
|
DownloadBandwidth: u.DownloadBandwidth,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -143,11 +143,9 @@ func (s *Service) StartPortableMode(sftpdPort int, enableSCP bool) error {
|
|||||||
}
|
}
|
||||||
tempDir := os.TempDir()
|
tempDir := os.TempDir()
|
||||||
instanceID := xid.New().String()
|
instanceID := xid.New().String()
|
||||||
databasePath := filepath.Join(tempDir, instanceID+".db")
|
|
||||||
s.LogFilePath = filepath.Join(tempDir, instanceID+".log")
|
s.LogFilePath = filepath.Join(tempDir, instanceID+".log")
|
||||||
dataProviderConf := config.GetProviderConf()
|
dataProviderConf := config.GetProviderConf()
|
||||||
dataProviderConf.Driver = dataprovider.BoltDataProviderName
|
dataProviderConf.Driver = dataprovider.MemoryDataProviderName
|
||||||
dataProviderConf.Name = databasePath
|
|
||||||
config.SetProviderConf(dataProviderConf)
|
config.SetProviderConf(dataProviderConf)
|
||||||
httpdConf := config.GetHTTPDConfig()
|
httpdConf := config.GetHTTPDConfig()
|
||||||
httpdConf.BindPort = 0
|
httpdConf.BindPort = 0
|
||||||
|
|||||||
@@ -277,8 +277,7 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
|
|||||||
if err := ssh.Unmarshal(req.Payload, &msg); err == nil {
|
if err := ssh.Unmarshal(req.Payload, &msg); err == nil {
|
||||||
name, scpArgs, err := parseCommandPayload(msg.Command)
|
name, scpArgs, err := parseCommandPayload(msg.Command)
|
||||||
connection.Log(logger.LevelDebug, logSender, "new exec command: %#v args: %v user: %v, error: %v",
|
connection.Log(logger.LevelDebug, logSender, "new exec command: %#v args: %v user: %v, error: %v",
|
||||||
name, scpArgs,
|
name, scpArgs, connection.User.Username, err)
|
||||||
connection.User.Username, err)
|
|
||||||
if err == nil && name == "scp" && len(scpArgs) >= 2 {
|
if err == nil && name == "scp" && len(scpArgs) >= 2 {
|
||||||
ok = true
|
ok = true
|
||||||
connection.protocol = protocolSCP
|
connection.protocol = protocolSCP
|
||||||
|
|||||||
@@ -200,7 +200,7 @@ func TestBasicSFTPHandling(t *testing.T) {
|
|||||||
testFileSize := int64(65535)
|
testFileSize := int64(65535)
|
||||||
expectedQuotaSize := user.UsedQuotaSize + testFileSize
|
expectedQuotaSize := user.UsedQuotaSize + testFileSize
|
||||||
expectedQuotaFiles := user.UsedQuotaFiles + 1
|
expectedQuotaFiles := user.UsedQuotaFiles + 1
|
||||||
err = createTestFile(testFilePath, testFileSize)
|
createTestFile(testFilePath, testFileSize)
|
||||||
err = sftpUploadFile(testFilePath, path.Join("/missing_dir", testFileName), testFileSize, client)
|
err = sftpUploadFile(testFilePath, path.Join("/missing_dir", testFileName), testFileSize, client)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("upload a file to a missing dir must fail")
|
t.Errorf("upload a file to a missing dir must fail")
|
||||||
|
|||||||
Reference in New Issue
Block a user