REST API: add support for API key authentication

This commit is contained in:
Nicola Murino
2021-08-17 18:08:32 +02:00
parent 05c62b9f40
commit fe953d6b38
41 changed files with 3620 additions and 274 deletions

View File

@@ -19,7 +19,7 @@ import (
)
const (
sqlDatabaseVersion = 10
sqlDatabaseVersion = 11
defaultSQLQueryTimeout = 10 * time.Second
longSQLQueryTimeout = 60 * time.Second
)
@@ -34,6 +34,170 @@ type sqlScanner interface {
Scan(dest ...interface{}) error
}
func sqlCommonGetAPIKeyByID(keyID string, dbHandle sqlQuerier) (APIKey, error) {
var apiKey APIKey
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getAPIKeyByIDQuery()
stmt, err := dbHandle.PrepareContext(ctx, q)
if err != nil {
providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
return apiKey, err
}
defer stmt.Close()
row := stmt.QueryRowContext(ctx, keyID)
apiKey, err = getAPIKeyFromDbRow(row)
if err != nil {
return apiKey, err
}
return getAPIKeyWithRelatedFields(ctx, apiKey, dbHandle)
}
func sqlCommonAddAPIKey(apiKey *APIKey, dbHandle *sql.DB) error {
err := apiKey.validate()
if err != nil {
return err
}
userID, adminID, err := sqlCommonGetAPIKeyRelatedIDs(apiKey)
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getAddAPIKeyQuery()
stmt, err := dbHandle.PrepareContext(ctx, q)
if err != nil {
providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
return err
}
defer stmt.Close()
_, err = stmt.ExecContext(ctx, apiKey.KeyID, apiKey.Name, apiKey.Key, apiKey.Scope, apiKey.CreatedAt,
util.GetTimeAsMsSinceEpoch(time.Now()), apiKey.LastUseAt, apiKey.ExpiresAt, apiKey.Description,
userID, adminID)
return err
}
func sqlCommonUpdateAPIKey(apiKey *APIKey, dbHandle *sql.DB) error {
err := apiKey.validate()
if err != nil {
return err
}
userID, adminID, err := sqlCommonGetAPIKeyRelatedIDs(apiKey)
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getUpdateAPIKeyQuery()
stmt, err := dbHandle.PrepareContext(ctx, q)
if err != nil {
providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
return err
}
defer stmt.Close()
_, err = stmt.ExecContext(ctx, apiKey.Name, apiKey.Scope, apiKey.ExpiresAt, userID, adminID,
apiKey.Description, util.GetTimeAsMsSinceEpoch(time.Now()), apiKey.KeyID)
return err
}
func sqlCommonDeleteAPIKey(apiKey *APIKey, dbHandle *sql.DB) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getDeleteAPIKeyQuery()
stmt, err := dbHandle.PrepareContext(ctx, q)
if err != nil {
providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
return err
}
defer stmt.Close()
_, err = stmt.ExecContext(ctx, apiKey.KeyID)
return err
}
func sqlCommonGetAPIKeys(limit, offset int, order string, dbHandle sqlQuerier) ([]APIKey, error) {
apiKeys := make([]APIKey, 0, limit)
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getAPIKeysQuery(order)
stmt, err := dbHandle.PrepareContext(ctx, q)
if err != nil {
providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
return nil, err
}
defer stmt.Close()
rows, err := stmt.QueryContext(ctx, limit, offset)
if err != nil {
return apiKeys, err
}
defer rows.Close()
for rows.Next() {
k, err := getAPIKeyFromDbRow(rows)
if err != nil {
return apiKeys, err
}
k.HideConfidentialData()
apiKeys = append(apiKeys, k)
}
err = rows.Err()
if err != nil {
return apiKeys, err
}
apiKeys, err = getRelatedValuesForAPIKeys(ctx, apiKeys, dbHandle, APIKeyScopeAdmin)
if err != nil {
return apiKeys, err
}
return getRelatedValuesForAPIKeys(ctx, apiKeys, dbHandle, APIKeyScopeUser)
}
func sqlCommonDumpAPIKeys(dbHandle sqlQuerier) ([]APIKey, error) {
apiKeys := make([]APIKey, 0, 30)
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getDumpAPIKeysQuery()
stmt, err := dbHandle.PrepareContext(ctx, q)
if err != nil {
providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
return nil, err
}
defer stmt.Close()
rows, err := stmt.QueryContext(ctx)
if err != nil {
return apiKeys, err
}
defer rows.Close()
for rows.Next() {
k, err := getAPIKeyFromDbRow(rows)
if err != nil {
return apiKeys, err
}
apiKeys = append(apiKeys, k)
}
err = rows.Err()
if err != nil {
return apiKeys, err
}
apiKeys, err = getRelatedValuesForAPIKeys(ctx, apiKeys, dbHandle, APIKeyScopeAdmin)
if err != nil {
return apiKeys, err
}
return getRelatedValuesForAPIKeys(ctx, apiKeys, dbHandle, APIKeyScopeUser)
}
func sqlCommonGetAdminByUsername(username string, dbHandle sqlQuerier) (Admin, error) {
var admin Admin
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
@@ -303,6 +467,25 @@ func sqlCommonGetUsedQuota(username string, dbHandle *sql.DB) (int, int64, error
return usedFiles, usedSize, err
}
func sqlCommonUpdateAPIKeyLastUse(keyID string, dbHandle *sql.DB) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getUpdateAPIKeyLastUseQuery()
stmt, err := dbHandle.PrepareContext(ctx, q)
if err != nil {
providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
return err
}
defer stmt.Close()
_, err = stmt.ExecContext(ctx, util.GetTimeAsMsSinceEpoch(time.Now()), keyID)
if err == nil {
providerLog(logger.LevelDebug, "last use updated for key %#v", keyID)
} else {
providerLog(logger.LevelWarn, "error updating last use for key %#v: %v", keyID, err)
}
return err
}
func sqlCommonUpdateLastLogin(username string, dbHandle *sql.DB) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
@@ -486,6 +669,34 @@ func sqlCommonGetUsers(limit int, offset int, order string, dbHandle sqlQuerier)
return getUsersWithVirtualFolders(ctx, users, dbHandle)
}
func getAPIKeyFromDbRow(row sqlScanner) (APIKey, error) {
var apiKey APIKey
var userID, adminID sql.NullInt64
var description sql.NullString
err := row.Scan(&apiKey.KeyID, &apiKey.Name, &apiKey.Key, &apiKey.Scope, &apiKey.CreatedAt, &apiKey.UpdatedAt,
&apiKey.LastUseAt, &apiKey.ExpiresAt, &description, &userID, &adminID)
if err != nil {
if err == sql.ErrNoRows {
return apiKey, util.NewRecordNotFoundError(err.Error())
}
return apiKey, err
}
if userID.Valid {
apiKey.userID = userID.Int64
}
if adminID.Valid {
apiKey.adminID = adminID.Int64
}
if description.Valid {
apiKey.Description = description.String
}
return apiKey, nil
}
func getAdminFromDbRow(row sqlScanner) (Admin, error) {
var admin Admin
var email, filters, additionalInfo, permissions, description sql.NullString
@@ -526,7 +737,7 @@ func getAdminFromDbRow(row sqlScanner) (Admin, error) {
admin.Description = description.String
}
return admin, err
return admin, nil
}
func getUserFromDbRow(row sqlScanner) (User, error) {
@@ -565,7 +776,7 @@ func getUserFromDbRow(row sqlScanner) (User, error) {
perms := make(map[string][]string)
err = json.Unmarshal([]byte(permissions.String), &perms)
if err != nil {
providerLog(logger.LevelDebug, "unable to deserialize permissions for user %#v: %v", user.Username, err)
providerLog(logger.LevelWarn, "unable to deserialize permissions for user %#v: %v", user.Username, err)
return user, fmt.Errorf("unable to deserialize permissions for user %#v: %v", user.Username, err)
}
user.Permissions = perms
@@ -591,7 +802,7 @@ func getUserFromDbRow(row sqlScanner) (User, error) {
user.Description = description.String
}
user.SetEmptySecretsIfNil()
return user, err
return user, nil
}
func sqlCommonCheckFolderExists(ctx context.Context, name string, dbHandle sqlQuerier) error {
@@ -890,11 +1101,12 @@ func getUserWithVirtualFolders(ctx context.Context, user User, dbHandle sqlQueri
}
func getUsersWithVirtualFolders(ctx context.Context, users []User, dbHandle sqlQuerier) ([]User, error) {
if len(users) == 0 {
return users, nil
}
var err error
usersVirtualFolders := make(map[int64][]vfs.VirtualFolder)
if len(users) == 0 {
return users, err
}
q := getRelatedFoldersForUsersQuery(users)
stmt, err := dbHandle.PrepareContext(ctx, q)
if err != nil {
@@ -947,11 +1159,12 @@ func getUsersWithVirtualFolders(ctx context.Context, users []User, dbHandle sqlQ
}
func getVirtualFoldersWithUsers(folders []vfs.BaseVirtualFolder, dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) {
if len(folders) == 0 {
return folders, nil
}
var err error
vFoldersUsers := make(map[int64][]string)
if len(folders) == 0 {
return folders, err
}
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getRelatedUsersForFoldersQuery(folders)
@@ -1030,6 +1243,94 @@ func sqlCommonGetFolderUsedQuota(mappedPath string, dbHandle *sql.DB) (int, int6
return usedFiles, usedSize, err
}
func getAPIKeyWithRelatedFields(ctx context.Context, apiKey APIKey, dbHandle sqlQuerier) (APIKey, error) {
var apiKeys []APIKey
var err error
scope := APIKeyScopeAdmin
if apiKey.userID > 0 {
scope = APIKeyScopeUser
}
apiKeys, err = getRelatedValuesForAPIKeys(ctx, []APIKey{apiKey}, dbHandle, scope)
if err != nil {
return apiKey, err
}
if len(apiKeys) > 0 {
apiKey = apiKeys[0]
}
return apiKey, nil
}
func getRelatedValuesForAPIKeys(ctx context.Context, apiKeys []APIKey, dbHandle sqlQuerier, scope APIKeyScope) ([]APIKey, error) {
if len(apiKeys) == 0 {
return apiKeys, nil
}
values := make(map[int64]string)
var q string
if scope == APIKeyScopeUser {
q = getRelatedUsersForAPIKeysQuery(apiKeys)
} else {
q = getRelatedAdminsForAPIKeysQuery(apiKeys)
}
stmt, err := dbHandle.PrepareContext(ctx, q)
if err != nil {
providerLog(logger.LevelWarn, "error preparing database query %#v: %v", q, err)
return nil, err
}
defer stmt.Close()
rows, err := stmt.QueryContext(ctx)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var valueID int64
var valueName string
err = rows.Scan(&valueID, &valueName)
if err != nil {
return apiKeys, err
}
values[valueID] = valueName
}
err = rows.Err()
if err != nil {
return apiKeys, err
}
if len(values) == 0 {
return apiKeys, nil
}
for idx := range apiKeys {
ref := &apiKeys[idx]
if scope == APIKeyScopeUser {
ref.User = values[ref.userID]
} else {
ref.Admin = values[ref.adminID]
}
}
return apiKeys, nil
}
func sqlCommonGetAPIKeyRelatedIDs(apiKey *APIKey) (sql.NullInt64, sql.NullInt64, error) {
var userID, adminID sql.NullInt64
if apiKey.User != "" {
u, err := provider.userExists(apiKey.User)
if err != nil {
return userID, adminID, util.NewValidationError(fmt.Sprintf("unable to validate user %v", apiKey.User))
}
userID.Valid = true
userID.Int64 = u.ID
}
if apiKey.Admin != "" {
a, err := provider.adminExists(apiKey.Admin)
if err != nil {
return userID, adminID, util.NewValidationError(fmt.Sprintf("unable to validate admin %v", apiKey.Admin))
}
adminID.Valid = true
adminID.Int64 = a.ID
}
return userID, adminID, nil
}
func sqlCommonGetDatabaseVersion(dbHandle *sql.DB, showInitWarn bool) (schemaVersion, error) {
var result schemaVersion
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)