JWT: replace jwtauth/jwx with lightweight wrapper around go-jose

We replaced the jwtauth and jwx libraries with a minimal custom wrapper
around go-jose because we don’t need the full feature set provided by jwx.
Implementing our own wrapper simplifies the codebase and improves
maintainability.

Moreover, go-jose depends only on the standard library, resulting in a
leaner dependency that still meets all our requirements.

This change also reduces the SFTPGo binary size by approximately 1MB

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino
2025-10-08 18:10:39 +02:00
parent 9ca35c3555
commit 0ae2354fed
31 changed files with 1222 additions and 967 deletions

View File

@@ -24,12 +24,12 @@ import (
"strings"
"time"
"github.com/go-chi/jwtauth/v5"
"github.com/rs/xid"
"github.com/sftpgo/sdk"
"github.com/drakkan/sftpgo/v2/internal/common"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/jwt"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/util"
)
@@ -48,7 +48,7 @@ func (k *contextKey) String() string {
}
func validateJWTToken(w http.ResponseWriter, r *http.Request, audience tokenAudience) error {
token, _, err := jwtauth.FromContext(r.Context())
token, err := jwt.FromContext(r.Context())
var redirectPath string
if audience == tokenAudienceWebAdmin {
@@ -70,7 +70,7 @@ func validateJWTToken(w http.ResponseWriter, r *http.Request, audience tokenAudi
}
}
if err != nil || token == nil {
if err != nil {
logger.Debug(logSender, "", "error getting jwt token: %v", err)
doRedirect(http.StatusText(http.StatusUnauthorized), err)
return errInvalidToken
@@ -82,17 +82,17 @@ func validateJWTToken(w http.ResponseWriter, r *http.Request, audience tokenAudi
return errInvalidToken
}
// a user with a partial token will be always redirected to the appropriate two factor auth page
if err := checkPartialAuth(w, r, audience, token.Audience()); err != nil {
if err := checkPartialAuth(w, r, audience, token.Audience); err != nil {
return err
}
if !slices.Contains(token.Audience(), audience) {
if !token.Audience.Contains(audience) {
logger.Debug(logSender, "", "the token is not valid for audience %q", audience)
doRedirect("Your token audience is not valid", nil)
return errInvalidToken
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := validateIPForToken(token, ipAddr); err != nil {
logger.Debug(logSender, "", "the token with id %q is not valid for the ip address %q", token.JwtID(), ipAddr)
logger.Debug(logSender, "", "the token with id %q is not valid for the ip address %q", token.ID, ipAddr)
doRedirect("Your token is not valid", nil)
return err
}
@@ -104,14 +104,14 @@ func validateJWTToken(w http.ResponseWriter, r *http.Request, audience tokenAudi
}
func (s *httpdServer) validateJWTPartialToken(w http.ResponseWriter, r *http.Request, audience tokenAudience) error {
token, _, err := jwtauth.FromContext(r.Context())
token, err := jwt.FromContext(r.Context())
var notFoundFunc func(w http.ResponseWriter, r *http.Request, err error)
if audience == tokenAudienceWebAdminPartial {
notFoundFunc = s.renderNotFoundPage
} else {
notFoundFunc = s.renderClientNotFoundPage
}
if err != nil || token == nil {
if err != nil {
notFoundFunc(w, r, nil)
return errInvalidToken
}
@@ -119,14 +119,14 @@ func (s *httpdServer) validateJWTPartialToken(w http.ResponseWriter, r *http.Req
notFoundFunc(w, r, nil)
return errInvalidToken
}
if !slices.Contains(token.Audience(), audience) {
logger.Debug(logSender, "", "the partial token with id %q is not valid for audience %q", token.JwtID(), audience)
if !token.Audience.Contains(audience) {
logger.Debug(logSender, "", "the partial token with id %q is not valid for audience %q", token.ID, audience)
notFoundFunc(w, r, nil)
return errInvalidToken
}
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
if err := validateIPForToken(token, ipAddr); err != nil {
logger.Debug(logSender, "", "the partial token with id %q is not valid for the ip address %q", token.JwtID(), ipAddr)
logger.Debug(logSender, "", "the partial token with id %q is not valid for the ip address %q", token.ID, ipAddr)
notFoundFunc(w, r, nil)
return err
}
@@ -194,7 +194,7 @@ func jwtAuthenticatorWebClient(next http.Handler) http.Handler {
func (s *httpdServer) checkHTTPUserPerm(perm string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, claims, err := jwtauth.FromContext(r.Context())
claims, err := jwt.FromContext(r.Context())
if err != nil {
if isWebRequest(r) {
s.renderClientBadRequestPage(w, r, err)
@@ -203,10 +203,8 @@ func (s *httpdServer) checkHTTPUserPerm(perm string) func(next http.Handler) htt
}
return
}
tokenClaims := jwtTokenClaims{}
tokenClaims.Decode(claims)
// for web client perms are negated and not granted
if tokenClaims.hasPerm(perm) {
if claims.HasPerm(perm) {
if isWebRequest(r) {
s.renderClientForbiddenPage(w, r, errors.New("you don't have permission for this action"))
} else {
@@ -223,7 +221,7 @@ func (s *httpdServer) checkHTTPUserPerm(perm string) func(next http.Handler) htt
// checkAuthRequirements checks if the user must set a second factor auth or change the password
func (s *httpdServer) checkAuthRequirements(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, claims, err := jwtauth.FromContext(r.Context())
claims, err := jwt.FromContext(r.Context())
if err != nil {
if isWebRequest(r) {
if isWebClientRequest(r) {
@@ -236,13 +234,11 @@ func (s *httpdServer) checkAuthRequirements(next http.Handler) http.Handler {
}
return
}
tokenClaims := jwtTokenClaims{}
tokenClaims.Decode(claims)
if tokenClaims.MustSetTwoFactorAuth || tokenClaims.MustChangePassword {
if claims.MustSetTwoFactorAuth || claims.MustChangePassword {
var err error
if tokenClaims.MustSetTwoFactorAuth {
if len(tokenClaims.RequiredTwoFactorProtocols) > 0 {
protocols := strings.Join(tokenClaims.RequiredTwoFactorProtocols, ", ")
if claims.MustSetTwoFactorAuth {
if len(claims.RequiredTwoFactorProtocols) > 0 {
protocols := strings.Join(claims.RequiredTwoFactorProtocols, ", ")
err = util.NewI18nError(
util.NewGenericError(
fmt.Sprintf("Two-factor authentication requirements not met, please configure two-factor authentication for the following protocols: %v",
@@ -301,7 +297,7 @@ func (s *httpdServer) requireBuiltinLogin(next http.Handler) http.Handler {
func (s *httpdServer) checkPerms(perms ...string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, claims, err := jwtauth.FromContext(r.Context())
claims, err := jwt.FromContext(r.Context())
if err != nil {
if isWebRequest(r) {
s.renderBadRequestPage(w, r, err)
@@ -310,11 +306,9 @@ func (s *httpdServer) checkPerms(perms ...string) func(next http.Handler) http.H
}
return
}
tokenClaims := jwtTokenClaims{}
tokenClaims.Decode(claims)
for _, perm := range perms {
if !tokenClaims.hasPerm(perm) {
if !claims.HasPerm(perm) {
if isWebRequest(r) {
s.renderForbiddenPage(w, r, util.NewI18nError(fs.ErrPermission, util.I18nError403Message))
} else {
@@ -332,14 +326,14 @@ func (s *httpdServer) checkPerms(perms ...string) func(next http.Handler) http.H
func (s *httpdServer) verifyCSRFHeader(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tokenString := r.Header.Get(csrfHeaderToken)
token, err := jwtauth.VerifyToken(s.csrfTokenAuth, tokenString)
token, err := jwt.VerifyToken(s.csrfTokenAuth, tokenString)
if err != nil || token == nil {
logger.Debug(logSender, "", "error validating CSRF header: %v", err)
sendAPIResponse(w, r, err, "Invalid token", http.StatusForbidden)
return
}
if !slices.Contains(token.Audience(), tokenAudienceCSRF) {
if !token.Audience.Contains(tokenAudienceCSRF) {
logger.Debug(logSender, "", "error validating CSRF header token audience")
sendAPIResponse(w, r, errors.New("the token is not valid"), "", http.StatusForbidden)
return
@@ -359,49 +353,52 @@ func (s *httpdServer) verifyCSRFHeader(next http.Handler) http.Handler {
})
}
func checkNodeToken(tokenAuth *jwtauth.JWTAuth) func(next http.Handler) http.Handler {
func checkNodeToken(tokenAuth *jwt.Signer) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get(dataprovider.NodeTokenHeader)
if token == "" {
bearer := r.Header.Get(dataprovider.NodeTokenHeader)
if bearer == "" {
next.ServeHTTP(w, r)
return
}
if len(token) > 7 && strings.ToUpper(token[0:6]) == "BEARER" {
token = token[7:]
const prefix = "Bearer "
if len(bearer) >= len(prefix) && strings.EqualFold(bearer[:len(prefix)], prefix) {
bearer = bearer[len(prefix):]
}
if invalidatedJWTTokens.Get(token) {
if invalidatedJWTTokens.Get(bearer) {
logger.Debug(logSender, "", "the node token has been invalidated")
sendAPIResponse(w, r, fmt.Errorf("the provided token is not valid"), "", http.StatusUnauthorized)
return
}
admin, role, perms, err := dataprovider.AuthenticateNodeToken(token)
claims, err := dataprovider.AuthenticateNodeToken(bearer)
if err != nil {
logger.Debug(logSender, "", "unable to authenticate node token %q: %v", token, err)
logger.Debug(logSender, "", "unable to authenticate node token %q: %v", bearer, err)
sendAPIResponse(w, r, fmt.Errorf("the provided token cannot be authenticated"), "", http.StatusUnauthorized)
return
}
defer invalidatedJWTTokens.Add(token, time.Now().Add(2*time.Minute).UTC())
defer invalidatedJWTTokens.Add(bearer, time.Now().Add(2*time.Minute).UTC())
c := jwtTokenClaims{
Username: admin,
Permissions: perms,
c := &jwt.Claims{
Username: claims.Username,
Permissions: claims.Permissions,
NodeID: dataprovider.GetNodeName(),
Role: role,
Role: claims.Role,
}
resp, err := c.createTokenResponse(tokenAuth, tokenAudienceAPI, util.GetIPFromRemoteAddress(r.RemoteAddr))
token, err := tokenAuth.SignWithParams(c, tokenAudienceAPI, util.GetIPFromRemoteAddress(r.RemoteAddr), getTokenDuration(tokenAudienceAPI))
if err != nil {
sendAPIResponse(w, r, err, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
r.Header.Set("Authorization", fmt.Sprintf("Bearer %v", resp["access_token"]))
resp := c.BuildTokenResponse(token)
r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", resp.Token))
next.ServeHTTP(w, r)
})
}
}
func checkAPIKeyAuth(tokenAuth *jwtauth.JWTAuth, scope dataprovider.APIKeyScope) func(next http.Handler) http.Handler {
func checkAPIKeyAuth(tokenAuth *jwt.Signer, scope dataprovider.APIKeyScope) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
apiKey := r.Header.Get("X-SFTPGO-API-KEY")
@@ -484,7 +481,7 @@ func checkAPIKeyAuth(tokenAuth *jwtauth.JWTAuth, scope dataprovider.APIKeyScope)
func forbidAPIKeyAuthentication(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
claims, err := getTokenClaims(r)
claims, err := jwt.FromContext(r.Context())
if err != nil || claims.Username == "" {
sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
return
@@ -498,7 +495,7 @@ func forbidAPIKeyAuthentication(next http.Handler) http.Handler {
})
}
func authenticateAdminWithAPIKey(username, keyID string, tokenAuth *jwtauth.JWTAuth, r *http.Request) error {
func authenticateAdminWithAPIKey(username, keyID string, tokenAuth *jwt.Signer, r *http.Request) error {
if username == "" {
return errors.New("the provided key is not associated with any admin and no username was provided")
}
@@ -513,25 +510,26 @@ func authenticateAdminWithAPIKey(username, keyID string, tokenAuth *jwtauth.JWTA
if err := admin.CanLogin(ipAddr); err != nil {
return err
}
c := jwtTokenClaims{
c := &jwt.Claims{
Username: admin.Username,
Permissions: admin.Permissions,
Signature: admin.GetSignature(),
Role: admin.Role,
APIKeyID: keyID,
}
c.Subject = admin.GetSignature()
resp, err := c.createTokenResponse(tokenAuth, tokenAudienceAPI, ipAddr)
token, err := tokenAuth.SignWithParams(c, tokenAudienceAPI, ipAddr, getTokenDuration(tokenAudienceAPI))
if err != nil {
return err
}
r.Header.Set("Authorization", fmt.Sprintf("Bearer %v", resp["access_token"]))
resp := c.BuildTokenResponse(token)
r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", resp.Token))
dataprovider.UpdateAdminLastLogin(&admin)
common.DelayLogin(nil)
return nil
}
func authenticateUserWithAPIKey(username, keyID string, tokenAuth *jwtauth.JWTAuth, r *http.Request) error {
func authenticateUserWithAPIKey(username, keyID string, tokenAuth *jwt.Signer, r *http.Request) error {
ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
protocol := common.ProtocolHTTP
if username == "" {
@@ -569,20 +567,21 @@ func authenticateUserWithAPIKey(username, keyID string, tokenAuth *jwtauth.JWTAu
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure, r)
return common.ErrInternalFailure
}
c := jwtTokenClaims{
c := &jwt.Claims{
Username: user.Username,
Permissions: user.Filters.WebClient,
Signature: user.GetSignature(),
Role: user.Role,
APIKeyID: keyID,
}
c.Subject = user.GetSignature()
resp, err := c.createTokenResponse(tokenAuth, tokenAudienceAPIUser, ipAddr)
token, err := tokenAuth.SignWithParams(c, tokenAudienceAPIUser, ipAddr, getTokenDuration(tokenAudienceAPIUser))
if err != nil {
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure, r)
return err
}
r.Header.Set("Authorization", fmt.Sprintf("Bearer %v", resp["access_token"]))
resp := c.BuildTokenResponse(token)
r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", resp.Token))
dataprovider.UpdateLastLogin(&user)
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, nil, r)