mirror of
https://github.com/drakkan/sftpgo.git
synced 2025-12-07 14:50:55 +03:00
allow to store temporary sessions within the data provider
so we can persist password reset codes, OIDC auth sessions and tokens. These features will also work in multi-node setups without sicky sessions now Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
@@ -103,6 +103,8 @@ func TestOIDCInitialization(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestOIDCLoginLogout(t *testing.T) {
|
||||
oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
|
||||
require.True(t, ok)
|
||||
server := getTestOIDCServer()
|
||||
err := server.binding.OIDC.initialize()
|
||||
assert.NoError(t, err)
|
||||
@@ -119,7 +121,7 @@ func TestOIDCLoginLogout(t *testing.T) {
|
||||
State: xid.New().String(),
|
||||
Nonce: xid.New().String(),
|
||||
Audience: tokenAudienceWebClient,
|
||||
IssueAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-10 * time.Minute)),
|
||||
IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-10 * time.Minute)),
|
||||
}
|
||||
oidcMgr.addPendingAuth(expiredAuthReq)
|
||||
rr = httptest.NewRecorder()
|
||||
@@ -209,7 +211,7 @@ func TestOIDCLoginLogout(t *testing.T) {
|
||||
AccessToken: "123",
|
||||
Expiry: time.Now().Add(5 * time.Minute),
|
||||
}
|
||||
token = token.WithExtra(map[string]interface{}{
|
||||
token = token.WithExtra(map[string]any{
|
||||
"id_token": "id_token_val",
|
||||
})
|
||||
server.binding.OIDC.oauth2Config = &mockOAuth2Config{
|
||||
@@ -502,6 +504,8 @@ func TestOIDCLoginLogout(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestOIDCRefreshToken(t *testing.T) {
|
||||
oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
|
||||
require.True(t, ok)
|
||||
token := oidcToken{
|
||||
Cookie: xid.New().String(),
|
||||
AccessToken: xid.New().String(),
|
||||
@@ -542,7 +546,7 @@ func TestOIDCRefreshToken(t *testing.T) {
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "the refreshed token has no id token")
|
||||
}
|
||||
newToken = newToken.WithExtra(map[string]interface{}{
|
||||
newToken = newToken.WithExtra(map[string]any{
|
||||
"id_token": "id_token_val",
|
||||
})
|
||||
newToken.Expiry = time.Time{}
|
||||
@@ -557,7 +561,7 @@ func TestOIDCRefreshToken(t *testing.T) {
|
||||
err = token.refresh(&config, &verifier)
|
||||
assert.ErrorIs(t, err, common.ErrGenericFailure)
|
||||
|
||||
newToken = newToken.WithExtra(map[string]interface{}{
|
||||
newToken = newToken.WithExtra(map[string]any{
|
||||
"id_token": "id_token_val",
|
||||
})
|
||||
newToken.Expiry = time.Now().Add(5 * time.Minute)
|
||||
@@ -597,6 +601,8 @@ func TestOIDCRefreshToken(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestValidateOIDCToken(t *testing.T) {
|
||||
oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
|
||||
require.True(t, ok)
|
||||
server := getTestOIDCServer()
|
||||
err := server.binding.OIDC.initialize()
|
||||
assert.NoError(t, err)
|
||||
@@ -787,7 +793,9 @@ func TestOIDCToken(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestOIDCManager(t *testing.T) {
|
||||
func TestMemoryOIDCManager(t *testing.T) {
|
||||
oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
|
||||
require.True(t, ok)
|
||||
require.Len(t, oidcMgr.pendingAuths, 0)
|
||||
authReq := newOIDCPendingAuth(tokenAudienceWebAdmin)
|
||||
oidcMgr.addPendingAuth(authReq)
|
||||
@@ -796,19 +804,15 @@ func TestOIDCManager(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
oidcMgr.removePendingAuth(authReq.State)
|
||||
require.Len(t, oidcMgr.pendingAuths, 0)
|
||||
authReq.IssueAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-61 * time.Second))
|
||||
authReq.IssuedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-61 * time.Second))
|
||||
oidcMgr.addPendingAuth(authReq)
|
||||
require.Len(t, oidcMgr.pendingAuths, 1)
|
||||
_, err = oidcMgr.getPendingAuth(authReq.State)
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "too old")
|
||||
}
|
||||
oidcMgr.checkCleanup()
|
||||
require.Len(t, oidcMgr.pendingAuths, 1)
|
||||
oidcMgr.lastCleanup = time.Now().Add(-1 * time.Hour)
|
||||
oidcMgr.checkCleanup()
|
||||
oidcMgr.cleanup()
|
||||
require.Len(t, oidcMgr.pendingAuths, 0)
|
||||
assert.True(t, oidcMgr.lastCleanup.After(time.Now().Add(-10*time.Second)))
|
||||
|
||||
token := oidcToken{
|
||||
AccessToken: xid.New().String(),
|
||||
@@ -826,6 +830,7 @@ func TestOIDCManager(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
storedToken, err := oidcMgr.getToken(token.Cookie)
|
||||
assert.NoError(t, err)
|
||||
token.UsedAt = 0 // ensure we don't modify the stored token
|
||||
assert.Greater(t, storedToken.UsedAt, int64(0))
|
||||
token.UsedAt = storedToken.UsedAt
|
||||
assert.Equal(t, token, storedToken)
|
||||
@@ -848,6 +853,12 @@ func TestOIDCManager(t *testing.T) {
|
||||
assert.Greater(t, storedToken.UsedAt, usedAt)
|
||||
token.UsedAt = storedToken.UsedAt
|
||||
assert.Equal(t, token, storedToken)
|
||||
storedToken.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now()) - tokenDeleteInterval - 1
|
||||
oidcMgr.tokens[token.Cookie] = storedToken
|
||||
storedToken, err = oidcMgr.getToken(token.Cookie)
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "token is too old")
|
||||
}
|
||||
oidcMgr.removeToken(xid.New().String())
|
||||
require.Len(t, oidcMgr.tokens, 1)
|
||||
oidcMgr.removeToken(token.Cookie)
|
||||
@@ -859,8 +870,8 @@ func TestOIDCManager(t *testing.T) {
|
||||
newToken := oidcToken{
|
||||
Cookie: xid.New().String(),
|
||||
}
|
||||
oidcMgr.lastCleanup = time.Now().Add(-1 * time.Hour)
|
||||
oidcMgr.addToken(newToken)
|
||||
oidcMgr.cleanup()
|
||||
require.Len(t, oidcMgr.tokens, 1)
|
||||
_, err = oidcMgr.getToken(token.Cookie)
|
||||
assert.Error(t, err)
|
||||
@@ -874,6 +885,8 @@ func TestOIDCPreLoginHook(t *testing.T) {
|
||||
if runtime.GOOS == osWindows {
|
||||
t.Skip("this test is not available on Windows")
|
||||
}
|
||||
oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
|
||||
require.True(t, ok)
|
||||
username := "test_oidc_user_prelogin"
|
||||
u := dataprovider.User{
|
||||
BaseUser: sdk.BaseUser{
|
||||
@@ -902,7 +915,7 @@ func TestOIDCPreLoginHook(t *testing.T) {
|
||||
server.initializeRouter()
|
||||
|
||||
_, err = dataprovider.UserExists(username)
|
||||
_, ok := err.(*util.RecordNotFoundError)
|
||||
_, ok = err.(*util.RecordNotFoundError)
|
||||
assert.True(t, ok)
|
||||
// now login with OIDC
|
||||
authReq := newOIDCPendingAuth(tokenAudienceWebClient)
|
||||
@@ -911,7 +924,7 @@ func TestOIDCPreLoginHook(t *testing.T) {
|
||||
AccessToken: "1234",
|
||||
Expiry: time.Now().Add(5 * time.Minute),
|
||||
}
|
||||
token = token.WithExtra(map[string]interface{}{
|
||||
token = token.WithExtra(map[string]any{
|
||||
"id_token": "id_token_val",
|
||||
})
|
||||
server.binding.OIDC.oauth2Config = &mockOAuth2Config{
|
||||
@@ -979,6 +992,143 @@ func TestOIDCPreLoginHook(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestOIDCIsAdmin(t *testing.T) {
|
||||
type test struct {
|
||||
input any
|
||||
want bool
|
||||
}
|
||||
|
||||
emptySlice := make([]any, 0)
|
||||
|
||||
tests := []test{
|
||||
{input: "admin", want: true},
|
||||
{input: append(emptySlice, "admin"), want: true},
|
||||
{input: append(emptySlice, "user", "admin"), want: true},
|
||||
{input: "user", want: false},
|
||||
{input: emptySlice, want: false},
|
||||
{input: append(emptySlice, 1), want: false},
|
||||
{input: 1, want: false},
|
||||
{input: nil, want: false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
token := oidcToken{
|
||||
Role: tc.input,
|
||||
}
|
||||
assert.Equal(t, tc.want, token.isAdmin(), "%v should return %t", tc.input, tc.want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDbOIDCManager(t *testing.T) {
|
||||
if !isSharedProviderSupported() {
|
||||
t.Skip("this test it is not available with this provider")
|
||||
}
|
||||
mgr := newOIDCManager(1)
|
||||
pendingAuth := newOIDCPendingAuth(tokenAudienceWebAdmin)
|
||||
mgr.addPendingAuth(pendingAuth)
|
||||
authReq, err := mgr.getPendingAuth(pendingAuth.State)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, pendingAuth, authReq)
|
||||
pendingAuth.IssuedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-24 * time.Hour))
|
||||
mgr.addPendingAuth(pendingAuth)
|
||||
_, err = mgr.getPendingAuth(pendingAuth.State)
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "auth request is too old")
|
||||
}
|
||||
mgr.removePendingAuth(pendingAuth.State)
|
||||
_, err = mgr.getPendingAuth(pendingAuth.State)
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "unable to get the auth request for the specified state")
|
||||
}
|
||||
mgr.addPendingAuth(pendingAuth)
|
||||
_, err = mgr.getPendingAuth(pendingAuth.State)
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "auth request is too old")
|
||||
}
|
||||
mgr.cleanup()
|
||||
_, err = mgr.getPendingAuth(pendingAuth.State)
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "unable to get the auth request for the specified state")
|
||||
}
|
||||
|
||||
token := oidcToken{
|
||||
Cookie: xid.New().String(),
|
||||
AccessToken: xid.New().String(),
|
||||
TokenType: "Bearer",
|
||||
RefreshToken: xid.New().String(),
|
||||
ExpiresAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-2 * time.Minute)),
|
||||
SessionID: xid.New().String(),
|
||||
IDToken: xid.New().String(),
|
||||
Nonce: xid.New().String(),
|
||||
Username: xid.New().String(),
|
||||
Permissions: []string{dataprovider.PermAdminAny},
|
||||
Role: "admin",
|
||||
}
|
||||
mgr.addToken(token)
|
||||
tokenGet, err := mgr.getToken(token.Cookie)
|
||||
assert.NoError(t, err)
|
||||
assert.Greater(t, tokenGet.UsedAt, int64(0))
|
||||
token.UsedAt = tokenGet.UsedAt
|
||||
assert.Equal(t, token, tokenGet)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
mgr.updateTokenUsage(token)
|
||||
// no change
|
||||
tokenGet, err = mgr.getToken(token.Cookie)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, token.UsedAt, tokenGet.UsedAt)
|
||||
tokenGet.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-24 * time.Hour))
|
||||
tokenGet.RefreshToken = xid.New().String()
|
||||
mgr.updateTokenUsage(tokenGet)
|
||||
tokenGet, err = mgr.getToken(token.Cookie)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, tokenGet.RefreshToken)
|
||||
assert.NotEqual(t, token.RefreshToken, tokenGet.RefreshToken)
|
||||
assert.Greater(t, tokenGet.UsedAt, token.UsedAt)
|
||||
mgr.removeToken(token.Cookie)
|
||||
tokenGet, err = mgr.getToken(token.Cookie)
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "unable to get the token for the specified session")
|
||||
}
|
||||
// add an expired token
|
||||
token.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-24 * time.Hour))
|
||||
session := dataprovider.Session{
|
||||
Key: token.Cookie,
|
||||
Data: token,
|
||||
Type: dataprovider.SessionTypeOIDCToken,
|
||||
Timestamp: token.UsedAt + tokenDeleteInterval,
|
||||
}
|
||||
err = dataprovider.AddSharedSession(session)
|
||||
assert.NoError(t, err)
|
||||
_, err = mgr.getToken(token.Cookie)
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "token is too old")
|
||||
}
|
||||
mgr.cleanup()
|
||||
_, err = mgr.getToken(token.Cookie)
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "unable to get the token for the specified session")
|
||||
}
|
||||
// adding a session without a key should fail
|
||||
session.Key = ""
|
||||
err = dataprovider.AddSharedSession(session)
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "unable to save a session with an empty key")
|
||||
}
|
||||
session.Key = xid.New().String()
|
||||
session.Type = 1000
|
||||
err = dataprovider.AddSharedSession(session)
|
||||
if assert.Error(t, err) {
|
||||
assert.Contains(t, err.Error(), "invalid session type")
|
||||
}
|
||||
|
||||
dbMgr, ok := mgr.(*dbOIDCManager)
|
||||
if assert.True(t, ok) {
|
||||
_, err = dbMgr.decodePendingAuthData(2)
|
||||
assert.Error(t, err)
|
||||
_, err = dbMgr.decodeTokenData(true)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func getTestOIDCServer() *httpdServer {
|
||||
return &httpdServer{
|
||||
binding: Binding{
|
||||
@@ -1009,29 +1159,3 @@ func getPreLoginScriptContent(user dataprovider.User, nonJSONResponse bool) []by
|
||||
}
|
||||
return content
|
||||
}
|
||||
|
||||
func TestOIDCIsAdmin(t *testing.T) {
|
||||
type test struct {
|
||||
input interface{}
|
||||
want bool
|
||||
}
|
||||
|
||||
emptySlice := make([]interface{}, 0)
|
||||
|
||||
tests := []test{
|
||||
{input: "admin", want: true},
|
||||
{input: append(emptySlice, "admin"), want: true},
|
||||
{input: append(emptySlice, "user", "admin"), want: true},
|
||||
{input: "user", want: false},
|
||||
{input: emptySlice, want: false},
|
||||
{input: append(emptySlice, 1), want: false},
|
||||
{input: 1, want: false},
|
||||
{input: nil, want: false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
token := oidcToken{
|
||||
Role: tc.input,
|
||||
}
|
||||
assert.Equal(t, tc.want, token.isAdmin(), "%v should return %t", tc.input, tc.want)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user