mirror of
https://github.com/drakkan/sftpgo.git
synced 2025-12-07 06:40:54 +03:00
REST API: add support for API key authentication
This commit is contained in:
@@ -19,7 +19,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
sqlDatabaseVersion = 10
|
||||
sqlDatabaseVersion = 11
|
||||
defaultSQLQueryTimeout = 10 * time.Second
|
||||
longSQLQueryTimeout = 60 * time.Second
|
||||
)
|
||||
@@ -34,6 +34,170 @@ type sqlScanner interface {
|
||||
Scan(dest ...interface{}) error
|
||||
}
|
||||
|
||||
func sqlCommonGetAPIKeyByID(keyID string, dbHandle sqlQuerier) (APIKey, error) {
|
||||
var apiKey APIKey
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
q := getAPIKeyByIDQuery()
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
|
||||
return apiKey, err
|
||||
}
|
||||
defer stmt.Close()
|
||||
row := stmt.QueryRowContext(ctx, keyID)
|
||||
|
||||
apiKey, err = getAPIKeyFromDbRow(row)
|
||||
if err != nil {
|
||||
return apiKey, err
|
||||
}
|
||||
return getAPIKeyWithRelatedFields(ctx, apiKey, dbHandle)
|
||||
}
|
||||
|
||||
func sqlCommonAddAPIKey(apiKey *APIKey, dbHandle *sql.DB) error {
|
||||
err := apiKey.validate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userID, adminID, err := sqlCommonGetAPIKeyRelatedIDs(apiKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
q := getAddAPIKeyQuery()
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
_, err = stmt.ExecContext(ctx, apiKey.KeyID, apiKey.Name, apiKey.Key, apiKey.Scope, apiKey.CreatedAt,
|
||||
util.GetTimeAsMsSinceEpoch(time.Now()), apiKey.LastUseAt, apiKey.ExpiresAt, apiKey.Description,
|
||||
userID, adminID)
|
||||
return err
|
||||
}
|
||||
|
||||
func sqlCommonUpdateAPIKey(apiKey *APIKey, dbHandle *sql.DB) error {
|
||||
err := apiKey.validate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userID, adminID, err := sqlCommonGetAPIKeyRelatedIDs(apiKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
q := getUpdateAPIKeyQuery()
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
_, err = stmt.ExecContext(ctx, apiKey.Name, apiKey.Scope, apiKey.ExpiresAt, userID, adminID,
|
||||
apiKey.Description, util.GetTimeAsMsSinceEpoch(time.Now()), apiKey.KeyID)
|
||||
return err
|
||||
}
|
||||
|
||||
func sqlCommonDeleteAPIKey(apiKey *APIKey, dbHandle *sql.DB) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
q := getDeleteAPIKeyQuery()
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.ExecContext(ctx, apiKey.KeyID)
|
||||
return err
|
||||
}
|
||||
|
||||
func sqlCommonGetAPIKeys(limit, offset int, order string, dbHandle sqlQuerier) ([]APIKey, error) {
|
||||
apiKeys := make([]APIKey, 0, limit)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
q := getAPIKeysQuery(order)
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
|
||||
return nil, err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
rows, err := stmt.QueryContext(ctx, limit, offset)
|
||||
if err != nil {
|
||||
return apiKeys, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
k, err := getAPIKeyFromDbRow(rows)
|
||||
if err != nil {
|
||||
return apiKeys, err
|
||||
}
|
||||
k.HideConfidentialData()
|
||||
apiKeys = append(apiKeys, k)
|
||||
}
|
||||
err = rows.Err()
|
||||
if err != nil {
|
||||
return apiKeys, err
|
||||
}
|
||||
apiKeys, err = getRelatedValuesForAPIKeys(ctx, apiKeys, dbHandle, APIKeyScopeAdmin)
|
||||
if err != nil {
|
||||
return apiKeys, err
|
||||
}
|
||||
|
||||
return getRelatedValuesForAPIKeys(ctx, apiKeys, dbHandle, APIKeyScopeUser)
|
||||
}
|
||||
|
||||
func sqlCommonDumpAPIKeys(dbHandle sqlQuerier) ([]APIKey, error) {
|
||||
apiKeys := make([]APIKey, 0, 30)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
q := getDumpAPIKeysQuery()
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
|
||||
return nil, err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
rows, err := stmt.QueryContext(ctx)
|
||||
if err != nil {
|
||||
return apiKeys, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
k, err := getAPIKeyFromDbRow(rows)
|
||||
if err != nil {
|
||||
return apiKeys, err
|
||||
}
|
||||
apiKeys = append(apiKeys, k)
|
||||
}
|
||||
err = rows.Err()
|
||||
if err != nil {
|
||||
return apiKeys, err
|
||||
}
|
||||
apiKeys, err = getRelatedValuesForAPIKeys(ctx, apiKeys, dbHandle, APIKeyScopeAdmin)
|
||||
if err != nil {
|
||||
return apiKeys, err
|
||||
}
|
||||
|
||||
return getRelatedValuesForAPIKeys(ctx, apiKeys, dbHandle, APIKeyScopeUser)
|
||||
}
|
||||
|
||||
func sqlCommonGetAdminByUsername(username string, dbHandle sqlQuerier) (Admin, error) {
|
||||
var admin Admin
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
@@ -303,6 +467,25 @@ func sqlCommonGetUsedQuota(username string, dbHandle *sql.DB) (int, int64, error
|
||||
return usedFiles, usedSize, err
|
||||
}
|
||||
|
||||
func sqlCommonUpdateAPIKeyLastUse(keyID string, dbHandle *sql.DB) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
q := getUpdateAPIKeyLastUseQuery()
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.ExecContext(ctx, util.GetTimeAsMsSinceEpoch(time.Now()), keyID)
|
||||
if err == nil {
|
||||
providerLog(logger.LevelDebug, "last use updated for key %#v", keyID)
|
||||
} else {
|
||||
providerLog(logger.LevelWarn, "error updating last use for key %#v: %v", keyID, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func sqlCommonUpdateLastLogin(username string, dbHandle *sql.DB) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
@@ -486,6 +669,34 @@ func sqlCommonGetUsers(limit int, offset int, order string, dbHandle sqlQuerier)
|
||||
return getUsersWithVirtualFolders(ctx, users, dbHandle)
|
||||
}
|
||||
|
||||
func getAPIKeyFromDbRow(row sqlScanner) (APIKey, error) {
|
||||
var apiKey APIKey
|
||||
var userID, adminID sql.NullInt64
|
||||
var description sql.NullString
|
||||
|
||||
err := row.Scan(&apiKey.KeyID, &apiKey.Name, &apiKey.Key, &apiKey.Scope, &apiKey.CreatedAt, &apiKey.UpdatedAt,
|
||||
&apiKey.LastUseAt, &apiKey.ExpiresAt, &description, &userID, &adminID)
|
||||
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return apiKey, util.NewRecordNotFoundError(err.Error())
|
||||
}
|
||||
return apiKey, err
|
||||
}
|
||||
|
||||
if userID.Valid {
|
||||
apiKey.userID = userID.Int64
|
||||
}
|
||||
if adminID.Valid {
|
||||
apiKey.adminID = adminID.Int64
|
||||
}
|
||||
if description.Valid {
|
||||
apiKey.Description = description.String
|
||||
}
|
||||
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
func getAdminFromDbRow(row sqlScanner) (Admin, error) {
|
||||
var admin Admin
|
||||
var email, filters, additionalInfo, permissions, description sql.NullString
|
||||
@@ -526,7 +737,7 @@ func getAdminFromDbRow(row sqlScanner) (Admin, error) {
|
||||
admin.Description = description.String
|
||||
}
|
||||
|
||||
return admin, err
|
||||
return admin, nil
|
||||
}
|
||||
|
||||
func getUserFromDbRow(row sqlScanner) (User, error) {
|
||||
@@ -565,7 +776,7 @@ func getUserFromDbRow(row sqlScanner) (User, error) {
|
||||
perms := make(map[string][]string)
|
||||
err = json.Unmarshal([]byte(permissions.String), &perms)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelDebug, "unable to deserialize permissions for user %#v: %v", user.Username, err)
|
||||
providerLog(logger.LevelWarn, "unable to deserialize permissions for user %#v: %v", user.Username, err)
|
||||
return user, fmt.Errorf("unable to deserialize permissions for user %#v: %v", user.Username, err)
|
||||
}
|
||||
user.Permissions = perms
|
||||
@@ -591,7 +802,7 @@ func getUserFromDbRow(row sqlScanner) (User, error) {
|
||||
user.Description = description.String
|
||||
}
|
||||
user.SetEmptySecretsIfNil()
|
||||
return user, err
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func sqlCommonCheckFolderExists(ctx context.Context, name string, dbHandle sqlQuerier) error {
|
||||
@@ -890,11 +1101,12 @@ func getUserWithVirtualFolders(ctx context.Context, user User, dbHandle sqlQueri
|
||||
}
|
||||
|
||||
func getUsersWithVirtualFolders(ctx context.Context, users []User, dbHandle sqlQuerier) ([]User, error) {
|
||||
if len(users) == 0 {
|
||||
return users, nil
|
||||
}
|
||||
|
||||
var err error
|
||||
usersVirtualFolders := make(map[int64][]vfs.VirtualFolder)
|
||||
if len(users) == 0 {
|
||||
return users, err
|
||||
}
|
||||
q := getRelatedFoldersForUsersQuery(users)
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
@@ -947,11 +1159,12 @@ func getUsersWithVirtualFolders(ctx context.Context, users []User, dbHandle sqlQ
|
||||
}
|
||||
|
||||
func getVirtualFoldersWithUsers(folders []vfs.BaseVirtualFolder, dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) {
|
||||
if len(folders) == 0 {
|
||||
return folders, nil
|
||||
}
|
||||
|
||||
var err error
|
||||
vFoldersUsers := make(map[int64][]string)
|
||||
if len(folders) == 0 {
|
||||
return folders, err
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
q := getRelatedUsersForFoldersQuery(folders)
|
||||
@@ -1030,6 +1243,94 @@ func sqlCommonGetFolderUsedQuota(mappedPath string, dbHandle *sql.DB) (int, int6
|
||||
return usedFiles, usedSize, err
|
||||
}
|
||||
|
||||
func getAPIKeyWithRelatedFields(ctx context.Context, apiKey APIKey, dbHandle sqlQuerier) (APIKey, error) {
|
||||
var apiKeys []APIKey
|
||||
var err error
|
||||
|
||||
scope := APIKeyScopeAdmin
|
||||
if apiKey.userID > 0 {
|
||||
scope = APIKeyScopeUser
|
||||
}
|
||||
apiKeys, err = getRelatedValuesForAPIKeys(ctx, []APIKey{apiKey}, dbHandle, scope)
|
||||
if err != nil {
|
||||
return apiKey, err
|
||||
}
|
||||
if len(apiKeys) > 0 {
|
||||
apiKey = apiKeys[0]
|
||||
}
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
func getRelatedValuesForAPIKeys(ctx context.Context, apiKeys []APIKey, dbHandle sqlQuerier, scope APIKeyScope) ([]APIKey, error) {
|
||||
if len(apiKeys) == 0 {
|
||||
return apiKeys, nil
|
||||
}
|
||||
values := make(map[int64]string)
|
||||
var q string
|
||||
if scope == APIKeyScopeUser {
|
||||
q = getRelatedUsersForAPIKeysQuery(apiKeys)
|
||||
} else {
|
||||
q = getRelatedAdminsForAPIKeysQuery(apiKeys)
|
||||
}
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
|
||||
return nil, err
|
||||
}
|
||||
defer stmt.Close()
|
||||
rows, err := stmt.QueryContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var valueID int64
|
||||
var valueName string
|
||||
err = rows.Scan(&valueID, &valueName)
|
||||
if err != nil {
|
||||
return apiKeys, err
|
||||
}
|
||||
values[valueID] = valueName
|
||||
}
|
||||
err = rows.Err()
|
||||
if err != nil {
|
||||
return apiKeys, err
|
||||
}
|
||||
if len(values) == 0 {
|
||||
return apiKeys, nil
|
||||
}
|
||||
for idx := range apiKeys {
|
||||
ref := &apiKeys[idx]
|
||||
if scope == APIKeyScopeUser {
|
||||
ref.User = values[ref.userID]
|
||||
} else {
|
||||
ref.Admin = values[ref.adminID]
|
||||
}
|
||||
}
|
||||
return apiKeys, nil
|
||||
}
|
||||
|
||||
func sqlCommonGetAPIKeyRelatedIDs(apiKey *APIKey) (sql.NullInt64, sql.NullInt64, error) {
|
||||
var userID, adminID sql.NullInt64
|
||||
if apiKey.User != "" {
|
||||
u, err := provider.userExists(apiKey.User)
|
||||
if err != nil {
|
||||
return userID, adminID, util.NewValidationError(fmt.Sprintf("unable to validate user %v", apiKey.User))
|
||||
}
|
||||
userID.Valid = true
|
||||
userID.Int64 = u.ID
|
||||
}
|
||||
if apiKey.Admin != "" {
|
||||
a, err := provider.adminExists(apiKey.Admin)
|
||||
if err != nil {
|
||||
return userID, adminID, util.NewValidationError(fmt.Sprintf("unable to validate admin %v", apiKey.Admin))
|
||||
}
|
||||
adminID.Valid = true
|
||||
adminID.Int64 = a.ID
|
||||
}
|
||||
return userID, adminID, nil
|
||||
}
|
||||
|
||||
func sqlCommonGetDatabaseVersion(dbHandle *sql.DB, showInitWarn bool) (schemaVersion, error) {
|
||||
var result schemaVersion
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
|
||||
Reference in New Issue
Block a user