mirror of
https://github.com/drakkan/sftpgo.git
synced 2025-12-06 22:30:56 +03:00
@@ -217,7 +217,7 @@ func sqlCommonGetUserByUsername(username string, dbHandle sqlQuerier) (User, err
|
||||
func sqlCommonValidateUserAndPass(username, password, ip, protocol string, dbHandle *sql.DB) (User, error) {
|
||||
var user User
|
||||
if password == "" {
|
||||
return user, errors.New("Credentials cannot be null or empty")
|
||||
return user, errors.New("credentials cannot be null or empty")
|
||||
}
|
||||
user, err := sqlCommonGetUserByUsername(username, dbHandle)
|
||||
if err != nil {
|
||||
@@ -243,7 +243,7 @@ func sqlCommonValidateUserAndTLSCertificate(username, protocol string, tlsCert *
|
||||
func sqlCommonValidateUserAndPubKey(username string, pubKey []byte, dbHandle *sql.DB) (User, string, error) {
|
||||
var user User
|
||||
if len(pubKey) == 0 {
|
||||
return user, "", errors.New("Credentials cannot be null or empty")
|
||||
return user, "", errors.New("credentials cannot be null or empty")
|
||||
}
|
||||
user, err := sqlCommonGetUserByUsername(username, dbHandle)
|
||||
if err != nil {
|
||||
@@ -587,7 +587,7 @@ func getUserFromDbRow(row sqlScanner) (User, error) {
|
||||
}
|
||||
}
|
||||
if fsConfig.Valid {
|
||||
var fs Filesystem
|
||||
var fs vfs.Filesystem
|
||||
err = json.Unmarshal([]byte(fsConfig.String), &fs)
|
||||
if err == nil {
|
||||
user.FsConfig = fs
|
||||
@@ -603,7 +603,20 @@ func getUserFromDbRow(row sqlScanner) (User, error) {
|
||||
return user, err
|
||||
}
|
||||
|
||||
func sqlCommonCheckFolderExists(ctx context.Context, name string, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) {
|
||||
func sqlCommonCheckFolderExists(ctx context.Context, name string, dbHandle sqlQuerier) error {
|
||||
var folderName string
|
||||
q := checkFolderNameQuery()
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
row := stmt.QueryRowContext(ctx, name)
|
||||
return row.Scan(&folderName)
|
||||
}
|
||||
|
||||
func sqlCommonGetFolder(ctx context.Context, name string, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) {
|
||||
var folder vfs.BaseVirtualFolder
|
||||
q := getFolderByNameQuery()
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
@@ -613,9 +626,9 @@ func sqlCommonCheckFolderExists(ctx context.Context, name string, dbHandle sqlQu
|
||||
}
|
||||
defer stmt.Close()
|
||||
row := stmt.QueryRowContext(ctx, name)
|
||||
var mappedPath, description sql.NullString
|
||||
var mappedPath, description, fsConfig sql.NullString
|
||||
err = row.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles, &folder.LastQuotaUpdate,
|
||||
&folder.Name, &description)
|
||||
&folder.Name, &description, &fsConfig)
|
||||
if err == sql.ErrNoRows {
|
||||
return folder, &RecordNotFoundError{err: err.Error()}
|
||||
}
|
||||
@@ -625,11 +638,18 @@ func sqlCommonCheckFolderExists(ctx context.Context, name string, dbHandle sqlQu
|
||||
if description.Valid {
|
||||
folder.Description = description.String
|
||||
}
|
||||
if fsConfig.Valid {
|
||||
var fs vfs.Filesystem
|
||||
err = json.Unmarshal([]byte(fsConfig.String), &fs)
|
||||
if err == nil {
|
||||
folder.FsConfig = fs
|
||||
}
|
||||
}
|
||||
return folder, err
|
||||
}
|
||||
|
||||
func sqlCommonGetFolderByName(ctx context.Context, name string, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) {
|
||||
folder, err := sqlCommonCheckFolderExists(ctx, name, dbHandle)
|
||||
folder, err := sqlCommonGetFolder(ctx, name, dbHandle)
|
||||
if err != nil {
|
||||
return folder, err
|
||||
}
|
||||
@@ -643,23 +663,30 @@ func sqlCommonGetFolderByName(ctx context.Context, name string, dbHandle sqlQuer
|
||||
return folders[0], nil
|
||||
}
|
||||
|
||||
func sqlCommonAddOrGetFolder(ctx context.Context, baseFolder vfs.BaseVirtualFolder, usedQuotaSize int64, usedQuotaFiles int, lastQuotaUpdate int64, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) {
|
||||
folder, err := sqlCommonCheckFolderExists(ctx, baseFolder.Name, dbHandle)
|
||||
if _, ok := err.(*RecordNotFoundError); ok {
|
||||
f := &vfs.BaseVirtualFolder{
|
||||
Name: baseFolder.Name,
|
||||
MappedPath: baseFolder.MappedPath,
|
||||
UsedQuotaSize: usedQuotaSize,
|
||||
UsedQuotaFiles: usedQuotaFiles,
|
||||
LastQuotaUpdate: lastQuotaUpdate,
|
||||
}
|
||||
err = sqlCommonAddFolder(f, dbHandle)
|
||||
func sqlCommonAddOrUpdateFolder(ctx context.Context, baseFolder *vfs.BaseVirtualFolder, usedQuotaSize int64,
|
||||
usedQuotaFiles int, lastQuotaUpdate int64, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) {
|
||||
var folder vfs.BaseVirtualFolder
|
||||
// FIXME: we could use an UPSERT here, this SELECT could be racy
|
||||
err := sqlCommonCheckFolderExists(ctx, baseFolder.Name, dbHandle)
|
||||
switch err {
|
||||
case nil:
|
||||
err = sqlCommonUpdateFolder(baseFolder, dbHandle)
|
||||
if err != nil {
|
||||
return folder, err
|
||||
}
|
||||
return sqlCommonCheckFolderExists(ctx, baseFolder.Name, dbHandle)
|
||||
case sql.ErrNoRows:
|
||||
baseFolder.UsedQuotaFiles = usedQuotaFiles
|
||||
baseFolder.UsedQuotaSize = usedQuotaSize
|
||||
baseFolder.LastQuotaUpdate = lastQuotaUpdate
|
||||
err = sqlCommonAddFolder(baseFolder, dbHandle)
|
||||
if err != nil {
|
||||
return folder, err
|
||||
}
|
||||
default:
|
||||
return folder, err
|
||||
}
|
||||
return folder, err
|
||||
|
||||
return sqlCommonGetFolder(ctx, baseFolder.Name, dbHandle)
|
||||
}
|
||||
|
||||
func sqlCommonAddFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
|
||||
@@ -667,6 +694,10 @@ func sqlCommonAddFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) erro
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fsConfig, err := json.Marshal(folder.FsConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
q := getAddFolderQuery()
|
||||
@@ -677,15 +708,19 @@ func sqlCommonAddFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) erro
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.ExecContext(ctx, folder.MappedPath, folder.UsedQuotaSize, folder.UsedQuotaFiles,
|
||||
folder.LastQuotaUpdate, folder.Name, folder.Description)
|
||||
folder.LastQuotaUpdate, folder.Name, folder.Description, string(fsConfig))
|
||||
return err
|
||||
}
|
||||
|
||||
func sqlCommonUpdateFolder(folder *vfs.BaseVirtualFolder, dbHandle *sql.DB) error {
|
||||
func sqlCommonUpdateFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) error {
|
||||
err := ValidateFolder(folder)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fsConfig, err := json.Marshal(folder.FsConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
q := getUpdateFolderQuery()
|
||||
@@ -695,7 +730,7 @@ func sqlCommonUpdateFolder(folder *vfs.BaseVirtualFolder, dbHandle *sql.DB) erro
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.ExecContext(ctx, folder.MappedPath, folder.Description, folder.Name)
|
||||
_, err = stmt.ExecContext(ctx, folder.MappedPath, folder.Description, string(fsConfig), folder.Name)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -731,9 +766,9 @@ func sqlCommonDumpFolders(dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error)
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var folder vfs.BaseVirtualFolder
|
||||
var mappedPath, description sql.NullString
|
||||
var mappedPath, description, fsConfig sql.NullString
|
||||
err = rows.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
|
||||
&folder.LastQuotaUpdate, &folder.Name, &description)
|
||||
&folder.LastQuotaUpdate, &folder.Name, &description, &fsConfig)
|
||||
if err != nil {
|
||||
return folders, err
|
||||
}
|
||||
@@ -743,6 +778,13 @@ func sqlCommonDumpFolders(dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error)
|
||||
if description.Valid {
|
||||
folder.Description = description.String
|
||||
}
|
||||
if fsConfig.Valid {
|
||||
var fs vfs.Filesystem
|
||||
err = json.Unmarshal([]byte(fsConfig.String), &fs)
|
||||
if err == nil {
|
||||
folder.FsConfig = fs
|
||||
}
|
||||
}
|
||||
folders = append(folders, folder)
|
||||
}
|
||||
err = rows.Err()
|
||||
@@ -771,9 +813,9 @@ func sqlCommonGetFolders(limit, offset int, order string, dbHandle sqlQuerier) (
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var folder vfs.BaseVirtualFolder
|
||||
var mappedPath, description sql.NullString
|
||||
var mappedPath, description, fsConfig sql.NullString
|
||||
err = rows.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
|
||||
&folder.LastQuotaUpdate, &folder.Name, &description)
|
||||
&folder.LastQuotaUpdate, &folder.Name, &description, &fsConfig)
|
||||
if err != nil {
|
||||
return folders, err
|
||||
}
|
||||
@@ -783,6 +825,14 @@ func sqlCommonGetFolders(limit, offset int, order string, dbHandle sqlQuerier) (
|
||||
if description.Valid {
|
||||
folder.Description = description.String
|
||||
}
|
||||
if fsConfig.Valid {
|
||||
var fs vfs.Filesystem
|
||||
err = json.Unmarshal([]byte(fsConfig.String), &fs)
|
||||
if err == nil {
|
||||
folder.FsConfig = fs
|
||||
}
|
||||
}
|
||||
folder.HideConfidentialData()
|
||||
folders = append(folders, folder)
|
||||
}
|
||||
|
||||
@@ -805,7 +855,7 @@ func sqlCommonClearFolderMapping(ctx context.Context, user *User, dbHandle sqlQu
|
||||
return err
|
||||
}
|
||||
|
||||
func sqlCommonAddFolderMapping(ctx context.Context, user *User, folder vfs.VirtualFolder, dbHandle sqlQuerier) error {
|
||||
func sqlCommonAddFolderMapping(ctx context.Context, user *User, folder *vfs.VirtualFolder, dbHandle sqlQuerier) error {
|
||||
q := getAddFolderMappingQuery()
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
@@ -822,8 +872,9 @@ func generateVirtualFoldersMapping(ctx context.Context, user *User, dbHandle sql
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, vfolder := range user.VirtualFolders {
|
||||
f, err := sqlCommonAddOrGetFolder(ctx, vfolder.BaseVirtualFolder, 0, 0, 0, dbHandle)
|
||||
for idx := range user.VirtualFolders {
|
||||
vfolder := &user.VirtualFolders[idx]
|
||||
f, err := sqlCommonAddOrUpdateFolder(ctx, &vfolder.BaseVirtualFolder, 0, 0, 0, dbHandle)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -870,15 +921,26 @@ func getUsersWithVirtualFolders(users []User, dbHandle sqlQuerier) ([]User, erro
|
||||
for rows.Next() {
|
||||
var folder vfs.VirtualFolder
|
||||
var userID int64
|
||||
var mappedPath sql.NullString
|
||||
var mappedPath, fsConfig, description sql.NullString
|
||||
err = rows.Scan(&folder.ID, &folder.Name, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
|
||||
&folder.LastQuotaUpdate, &folder.VirtualPath, &folder.QuotaSize, &folder.QuotaFiles, &userID)
|
||||
&folder.LastQuotaUpdate, &folder.VirtualPath, &folder.QuotaSize, &folder.QuotaFiles, &userID, &fsConfig,
|
||||
&description)
|
||||
if err != nil {
|
||||
return users, err
|
||||
}
|
||||
if mappedPath.Valid {
|
||||
folder.MappedPath = mappedPath.String
|
||||
}
|
||||
if description.Valid {
|
||||
folder.Description = description.String
|
||||
}
|
||||
if fsConfig.Valid {
|
||||
var fs vfs.Filesystem
|
||||
err = json.Unmarshal([]byte(fsConfig.String), &fs)
|
||||
if err == nil {
|
||||
folder.FsConfig = fs
|
||||
}
|
||||
}
|
||||
usersVirtualFolders[userID] = append(usersVirtualFolders[userID], folder)
|
||||
}
|
||||
err = rows.Err()
|
||||
|
||||
Reference in New Issue
Block a user