diff --git a/common/common.go b/common/common.go index 9badd1fd..c0591dbd 100644 --- a/common/common.go +++ b/common/common.go @@ -181,6 +181,7 @@ func AddDefenderEvent(ip string, event HostEvent) { Config.defender.AddEvent(ip, event) } +// the ticker cannot be started/stopped from multiple goroutines func startIdleTimeoutTicker(duration time.Duration) { stopIdleTimeoutTicker() idleTimeoutTicker = time.NewTicker(duration) diff --git a/httpd/auth_utils.go b/httpd/auth_utils.go index 937342b6..7bcf44d4 100644 --- a/httpd/auth_utils.go +++ b/httpd/auth_utils.go @@ -123,7 +123,7 @@ func (c *jwtTokenClaims) createAndSetCookie(w http.ResponseWriter, tokenAuth *jw return nil } -func (c *jwtTokenClaims) removeCookie(w http.ResponseWriter) { +func (c *jwtTokenClaims) removeCookie(w http.ResponseWriter, r *http.Request) { http.SetCookie(w, &http.Cookie{ Name: "jwt", Value: "", @@ -131,6 +131,37 @@ func (c *jwtTokenClaims) removeCookie(w http.ResponseWriter) { MaxAge: -1, HttpOnly: true, }) + invalidateToken(r) +} + +func isTokenInvalidated(r *http.Request) bool { + isTokenFound := false + token := jwtauth.TokenFromHeader(r) + if token != "" { + isTokenFound = true + if _, ok := invalidatedJWTTokens.Load(token); ok { + return true + } + } + token = jwtauth.TokenFromCookie(r) + if token != "" { + isTokenFound = true + if _, ok := invalidatedJWTTokens.Load(token); ok { + return true + } + } + return !isTokenFound +} + +func invalidateToken(r *http.Request) { + tokenString := jwtauth.TokenFromHeader(r) + if tokenString != "" { + invalidatedJWTTokens.Store(tokenString, time.Now().UTC().Add(tokenDuration)) + } + tokenString = jwtauth.TokenFromCookie(r) + if tokenString != "" { + invalidatedJWTTokens.Store(tokenString, time.Now().UTC().Add(tokenDuration)) + } } func getAdminFromToken(r *http.Request) *dataprovider.Admin { diff --git a/httpd/httpd.go b/httpd/httpd.go index 05db231a..c983dc2b 100644 --- a/httpd/httpd.go +++ b/httpd/httpd.go @@ -11,6 +11,8 @@ import ( "path/filepath" "runtime" "strings" + "sync" + "time" "github.com/go-chi/chi" @@ -26,6 +28,7 @@ import ( const ( logSender = "httpd" tokenPath = "/api/v2/token" + logoutPath = "/api/v2/logout" activeConnectionsPath = "/api/v2/connections" quotaScanPath = "/api/v2/quota-scans" quotaScanVFolderPath = "/api/v2/folder-quota-scans" @@ -69,8 +72,11 @@ const ( ) var ( - backupsPath string - certMgr *common.CertManager + backupsPath string + certMgr *common.CertManager + jwtTokensCleanupTicker *time.Ticker + jwtTokensCleanupDone chan bool + invalidatedJWTTokens sync.Map ) // Binding defines the configuration for a network listener @@ -213,6 +219,7 @@ func (c *Conf) Initialize(configDir string) error { }(binding) } + startJWTTokensCleanupTicker(tokenDuration) return <-exitChannel } @@ -286,3 +293,39 @@ func GetHTTPRouter() http.Handler { server.initializeRouter() return server.router } + +// the ticker cannot be started/stopped from multiple goroutines +func startJWTTokensCleanupTicker(duration time.Duration) { + stopJWTTokensCleanupTicker() + jwtTokensCleanupTicker = time.NewTicker(duration) + jwtTokensCleanupDone = make(chan bool) + + go func() { + for { + select { + case <-jwtTokensCleanupDone: + return + case <-jwtTokensCleanupTicker.C: + cleanupExpiredJWTTokens() + } + } + }() +} + +func stopJWTTokensCleanupTicker() { + if jwtTokensCleanupTicker != nil { + jwtTokensCleanupTicker.Stop() + jwtTokensCleanupDone <- true + jwtTokensCleanupTicker = nil + } +} + +func cleanupExpiredJWTTokens() { + invalidatedJWTTokens.Range(func(key, value interface{}) bool { + exp, ok := value.(time.Time) + if !ok || exp.Before(time.Now().UTC()) { + invalidatedJWTTokens.Delete(key) + } + return true + }) +} diff --git a/httpd/httpd_test.go b/httpd/httpd_test.go index 5a1da291..347b8852 100644 --- a/httpd/httpd_test.go +++ b/httpd/httpd_test.go @@ -61,6 +61,7 @@ const ( updateFolderUsedQuotaPath = "/api/v2/folder-quota-update" defenderUnban = "/api/v2/defender/unban" versionPath = "/api/v2/version" + logoutPath = "/api/v2/logout" healthzPath = "/healthz" webBasePath = "/web" webLoginPath = "/web/login" @@ -3635,6 +3636,26 @@ func TestWebNotFoundURI(t *testing.T) { assert.Equal(t, http.StatusNotFound, resp.StatusCode) } +func TestLogout(t *testing.T) { + token, err := getJWTTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + req, _ := http.NewRequest(http.MethodGet, serverStatusPath, nil) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, _ = http.NewRequest(http.MethodGet, logoutPath, nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, _ = http.NewRequest(http.MethodGet, serverStatusPath, nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusUnauthorized, rr) + assert.Contains(t, rr.Body.String(), "Your token is no longer valid") +} + func TestWebLoginMock(t *testing.T) { form := getAdminLoginForm(defaultTokenAuthUser, defaultTokenAuthPass) req, _ := http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) @@ -3656,12 +3677,29 @@ func TestWebLoginMock(t *testing.T) { rr = executeRequest(req) checkResponseCode(t, http.StatusNotFound, rr) + req, _ = http.NewRequest(http.MethodGet, webStatusPath, nil) + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + req, _ = http.NewRequest(http.MethodGet, webLogoutPath, nil) setJWTCookieForReq(req, token) rr = executeRequest(req) checkResponseCode(t, http.StatusFound, rr) cookie = rr.Header().Get("Cookie") assert.Empty(t, cookie) + + req, _ = http.NewRequest(http.MethodGet, serverStatusPath, nil) + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusUnauthorized, rr) + assert.Contains(t, rr.Body.String(), "Your token is no longer valid") + + req, _ = http.NewRequest(http.MethodGet, webStatusPath, nil) + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + // now try using wrong credentials form = getAdminLoginForm(defaultTokenAuthUser, "wrong pwd") req, _ = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) diff --git a/httpd/internal_test.go b/httpd/internal_test.go index 30abafc4..aadc72fa 100644 --- a/httpd/internal_test.go +++ b/httpd/internal_test.go @@ -755,3 +755,32 @@ func TestGetUserFromTemplate(t *testing.T) { require.Equal(t, "sftp_"+username, userTemplate.FsConfig.SFTPConfig.Username) require.Equal(t, "sftp"+password, userTemplate.FsConfig.SFTPConfig.Password.GetPayload()) } + +func TestJWTTokenCleanup(t *testing.T) { + server := httpdServer{ + tokenAuth: jwtauth.New("HS256", utils.GenerateRandomBytes(32), nil), + } + admin := dataprovider.Admin{ + Username: "newtestadmin", + Password: "password", + Permissions: []string{dataprovider.PermAdminAny}, + } + claims := make(map[string]interface{}) + claims[claimUsernameKey] = admin.Username + claims[claimPermissionsKey] = admin.Permissions + claims[jwt.SubjectKey] = admin.GetSignature() + claims[jwt.ExpirationKey] = time.Now().Add(1 * time.Minute) + _, token, err := server.tokenAuth.Encode(claims) + assert.NoError(t, err) + + req, _ := http.NewRequest(http.MethodGet, versionPath, nil) + assert.True(t, isTokenInvalidated(req)) + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", token)) + + invalidatedJWTTokens.Store(token, time.Now().UTC().Add(-tokenDuration)) + require.True(t, isTokenInvalidated(req)) + startJWTTokensCleanupTicker(100 * time.Millisecond) + assert.Eventually(t, func() bool { return !isTokenInvalidated(req) }, 1*time.Second, 200*time.Millisecond) + stopJWTTokensCleanupTicker() +} diff --git a/httpd/middleware.go b/httpd/middleware.go index e3d0216e..97e611a1 100644 --- a/httpd/middleware.go +++ b/httpd/middleware.go @@ -37,6 +37,11 @@ func jwtAuthenticator(next http.Handler) http.Handler { sendAPIResponse(w, r, err, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } + if isTokenInvalidated(r) { + logger.Debug(logSender, "", "the token has been invalidated") + sendAPIResponse(w, r, nil, "Your token is no longer valid", http.StatusUnauthorized) + return + } // Token is authenticated, pass it through next.ServeHTTP(w, r) @@ -59,6 +64,11 @@ func jwtAuthenticatorWeb(next http.Handler) http.Handler { http.Redirect(w, r, webLoginPath, http.StatusFound) return } + if isTokenInvalidated(r) { + logger.Debug(logSender, "", "the token has been invalidated") + http.Redirect(w, r, webLoginPath, http.StatusFound) + return + } // Token is authenticated, pass it through next.ServeHTTP(w, r) diff --git a/httpd/schema/openapi.yaml b/httpd/schema/openapi.yaml index 07474787..bc3a6e0a 100644 --- a/httpd/schema/openapi.yaml +++ b/httpd/schema/openapi.yaml @@ -2,7 +2,7 @@ openapi: 3.0.3 info: title: SFTPGo description: SFTPGo REST API - version: 2.4.2 + version: 2.4.3 servers: - url: /api/v2 @@ -49,6 +49,27 @@ paths: $ref: '#/components/responses/InternalServerError' default: $ref: '#/components/responses/DefaultResponse' + /logout: + get: + tags: + - token + summary: invalidate the access token + operationId: logout + responses: + 200: + description: successful operation + content: + application/json: + schema: + $ref : '#/components/schemas/ApiResponse' + 401: + $ref: '#/components/responses/Unauthorized' + 403: + $ref: '#/components/responses/Forbidden' + 500: + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' /version: get: tags: diff --git a/httpd/server.go b/httpd/server.go index 773efe16..120911c6 100644 --- a/httpd/server.go +++ b/httpd/server.go @@ -137,6 +137,11 @@ func (s *httpdServer) handleWebLoginPost(w http.ResponseWriter, r *http.Request) http.Redirect(w, r, webUsersPath, http.StatusFound) } +func (s *httpdServer) logout(w http.ResponseWriter, r *http.Request) { + invalidateToken(r) + sendAPIResponse(w, r, nil, "Your token has been invalidated", http.StatusOK) +} + func (s *httpdServer) getToken(w http.ResponseWriter, r *http.Request) { username, password, ok := r.BasicAuth() if !ok { @@ -274,6 +279,7 @@ func (s *httpdServer) initializeRouter() { render.JSON(w, r, version.Get()) }) + router.Get(logoutPath, s.logout) router.Put(adminPwdPath, changeAdminPassword) router.With(checkPerm(dataprovider.PermAdminViewServerStatus)). diff --git a/httpd/web.go b/httpd/web.go index de6d0490..68b9ed5c 100644 --- a/httpd/web.go +++ b/httpd/web.go @@ -980,7 +980,7 @@ func handleWebAdminChangePwdPost(w http.ResponseWriter, r *http.Request) { func handleWebLogout(w http.ResponseWriter, r *http.Request) { c := jwtTokenClaims{} - c.removeCookie(w) + c.removeCookie(w, r) http.Redirect(w, r, webLoginPath, http.StatusFound) }