From 0ae2354fed4e9ffce792f1171448ec1818e0b620 Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Wed, 8 Oct 2025 18:10:39 +0200 Subject: [PATCH] JWT: replace jwtauth/jwx with lightweight wrapper around go-jose MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We replaced the jwtauth and jwx libraries with a minimal custom wrapper around go-jose because we don’t need the full feature set provided by jwx. Implementing our own wrapper simplifies the codebase and improves maintainability. Moreover, go-jose depends only on the standard library, resulting in a leaner dependency that still meets all our requirements. This change also reduces the SFTPGo binary size by approximately 1MB Signed-off-by: Nicola Murino --- go.mod | 12 +- go.sum | 20 - internal/dataprovider/node.go | 85 ++--- internal/httpd/api_admin.go | 27 +- internal/httpd/api_eventrule.go | 29 +- internal/httpd/api_events.go | 7 +- internal/httpd/api_folder.go | 15 +- internal/httpd/api_group.go | 15 +- internal/httpd/api_http_user.go | 9 +- internal/httpd/api_iplist.go | 7 +- internal/httpd/api_keys.go | 7 +- internal/httpd/api_maintenance.go | 5 +- internal/httpd/api_mfa.go | 17 +- internal/httpd/api_quota.go | 9 +- internal/httpd/api_retention.go | 3 +- internal/httpd/api_role.go | 7 +- internal/httpd/api_shares.go | 27 +- internal/httpd/api_user.go | 19 +- internal/httpd/api_utils.go | 9 +- internal/httpd/auth_utils.go | 365 ++++--------------- internal/httpd/httpd.go | 8 +- internal/httpd/httpd_test.go | 9 +- internal/httpd/internal_test.go | 588 +++++++++++++++--------------- internal/httpd/middleware.go | 109 +++--- internal/httpd/oidc.go | 14 +- internal/httpd/oidc_test.go | 33 +- internal/httpd/server.go | 152 ++++---- internal/httpd/webadmin.go | 59 +-- internal/httpd/webclient.go | 34 +- internal/jwt/jwt.go | 264 ++++++++++++++ internal/jwt/jwt_test.go | 225 ++++++++++++ 31 files changed, 1222 insertions(+), 967 deletions(-) create mode 100644 internal/jwt/jwt.go create mode 100644 internal/jwt/jwt_test.go diff --git a/go.mod b/go.mod index 80b21341..867a3f1f 100644 --- a/go.mod +++ b/go.mod @@ -25,8 +25,8 @@ require ( github.com/fclairamb/go-log v0.6.0 github.com/go-acme/lego/v4 v4.26.0 github.com/go-chi/chi/v5 v5.2.3 - github.com/go-chi/jwtauth/v5 v5.3.3 github.com/go-chi/render v1.0.3 + github.com/go-jose/go-jose/v4 v4.1.3 github.com/go-sql-driver/mysql v1.9.3 github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 github.com/google/uuid v1.6.0 @@ -36,7 +36,6 @@ require ( github.com/jackc/pgx/v5 v5.7.6 github.com/jlaffaye/ftp v0.2.0 github.com/klauspost/compress v1.18.0 - github.com/lestrrat-go/jwx/v2 v2.1.6 github.com/lithammer/shortuuid/v4 v4.2.0 github.com/mattn/go-sqlite3 v1.14.32 github.com/mhale/smtpd v0.8.3 @@ -110,18 +109,15 @@ require ( github.com/coreos/go-systemd/v22 v22.6.0 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.7 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect - github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect github.com/envoyproxy/go-control-plane/envoy v1.35.0 // indirect github.com/envoyproxy/protoc-gen-validate v1.2.1 // indirect github.com/fatih/color v1.18.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect - github.com/go-jose/go-jose/v4 v4.1.3 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.3.0 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect - github.com/goccy/go-json v0.10.5 // indirect github.com/golang-jwt/jwt/v5 v5.3.0 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/s2a-go v0.1.9 // indirect @@ -137,11 +133,6 @@ require ( github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/kr/fs v0.1.0 // indirect github.com/kylelemons/godebug v1.1.0 // indirect - github.com/lestrrat-go/blackmagic v1.0.4 // indirect - github.com/lestrrat-go/httpcc v1.0.1 // indirect - github.com/lestrrat-go/httprc v1.0.6 // indirect - github.com/lestrrat-go/iter v1.0.2 // indirect - github.com/lestrrat-go/option v1.0.1 // indirect github.com/lufia/plan9stats v0.0.0-20250827001030-24949be3fa54 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect @@ -159,7 +150,6 @@ require ( github.com/prometheus/procfs v0.17.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/sagikazarmark/locafero v0.12.0 // indirect - github.com/segmentio/asm v1.2.1 // indirect github.com/shoenig/go-m1cpu v0.1.7 // indirect github.com/spf13/cast v1.10.0 // indirect github.com/spf13/pflag v1.0.10 // indirect diff --git a/go.sum b/go.sum index a3002f4f..3b29cbb4 100644 --- a/go.sum +++ b/go.sum @@ -124,8 +124,6 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 h1:NMZiJj8QnKe1LgsbDayM4UoHwbvwDRwnI3hwNaAHRnc= -github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0/go.mod h1:ZXNYxsqcloTdSy/rNShjYzMhyjf0LaoftYK0p+A3h40= github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0 h1:EW9gIJRmt9lzk66Fhh4S8VEtURA6QHZqGeSRE9Nb2/U= github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/drakkan/ftp v0.0.0-20240430173938-7ba8270c8e7f h1:S9JUlrOzjK58UKoLqqb40YLyVlt0bcIFtYrvnanV3zc= @@ -159,8 +157,6 @@ github.com/go-acme/lego/v4 v4.26.0 h1:521aEQxNstXvPQcFDDPrJiFfixcCQuvAvm35R4GbyY github.com/go-acme/lego/v4 v4.26.0/go.mod h1:BQVAWgcyzW4IT9eIKHY/RxYlVhoyKyOMXOkq7jK1eEQ= github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE= github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= -github.com/go-chi/jwtauth/v5 v5.3.3 h1:50Uzmacu35/ZP9ER2Ht6SazwPsnLQ9LRJy6zTZJpHEo= -github.com/go-chi/jwtauth/v5 v5.3.3/go.mod h1:O4QvPRuZLZghl9WvfVaON+ARfGzpD2PBX/QY5vUz7aQ= github.com/go-chi/render v1.0.3 h1:AsXqd2a1/INaIfUSKq3G5uA8weYx20FOsM7uSoCyyt4= github.com/go-chi/render v1.0.3/go.mod h1:/gr3hVkmYR0YlEy3LxCuVRFzEu9Ruok+gFqbIofjao0= github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= @@ -181,8 +177,6 @@ github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1 github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= -github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= -github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gofrs/flock v0.12.1 h1:MTLVXXHf8ekldpJk3AKicLij9MdwOWkZ+a/jHHZby9E= github.com/gofrs/flock v0.12.1/go.mod h1:9zxTsyu5xtJ9DK+1tFZyibEV7y3uwDxPPfbxeeHCoD0= @@ -246,18 +240,6 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= -github.com/lestrrat-go/blackmagic v1.0.4 h1:IwQibdnf8l2KoO+qC3uT4OaTWsW7tuRQXy9TRN9QanA= -github.com/lestrrat-go/blackmagic v1.0.4/go.mod h1:6AWFyKNNj0zEXQYfTMPfZrAXUWUfTIZ5ECEUEJaijtw= -github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE= -github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E= -github.com/lestrrat-go/httprc v1.0.6 h1:qgmgIRhpvBqexMJjA/PmwSvhNk679oqD1RbovdCGW8k= -github.com/lestrrat-go/httprc v1.0.6/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo= -github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI= -github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4= -github.com/lestrrat-go/jwx/v2 v2.1.6 h1:hxM1gfDILk/l5ylers6BX/Eq1m/pnxe9NBwW6lVfecA= -github.com/lestrrat-go/jwx/v2 v2.1.6/go.mod h1:Y722kU5r/8mV7fYDifjug0r8FK8mZdw0K0GpJw/l8pU= -github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU= -github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lithammer/shortuuid/v4 v4.2.0 h1:LMFOzVB3996a7b8aBuEXxqOBflbfPQAiVzkIcHO0h8c= @@ -331,8 +313,6 @@ github.com/sagikazarmark/locafero v0.12.0 h1:/NQhBAkUb4+fH1jivKHWusDYFjMOOKU88ee github.com/sagikazarmark/locafero v0.12.0/go.mod h1:sZh36u/YSZ918v0Io+U9ogLYQJ9tLLBmM4eneO6WwsI= github.com/secsy/goftp v0.0.0-20200609142545-aa2de14babf4 h1:PT+ElG/UUFMfqy5HrxJxNzj3QBOf7dZwupeVC+mG1Lo= github.com/secsy/goftp v0.0.0-20200609142545-aa2de14babf4/go.mod h1:MnkX001NG75g3p8bhFycnyIjeQoOjGL6CEIsdE/nKSY= -github.com/segmentio/asm v1.2.1 h1:DTNbBqs57ioxAD4PrArqftgypG4/qNpXoJx8TVXxPR0= -github.com/segmentio/asm v1.2.1/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs= github.com/sftpgo/sdk v0.1.9-0.20241011171103-64fc18a344f9 h1:wlXBnaNfJJJRZjHO2AerSS5gp0ckkYUgBzSXivUo0Wo= github.com/sftpgo/sdk v0.1.9-0.20241011171103-64fc18a344f9/go.mod h1:ehimvlTP+XTEiE3t1CPwWx9n7+6A6OGvMGlZ7ouvKFk= github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/iZ9pI= diff --git a/internal/dataprovider/node.go b/internal/dataprovider/node.go index eb08b103..3f31a733 100644 --- a/internal/dataprovider/node.go +++ b/internal/dataprovider/node.go @@ -28,11 +28,10 @@ import ( "strings" "time" - "github.com/lestrrat-go/jwx/v2/jwa" - "github.com/lestrrat-go/jwx/v2/jwt" - "github.com/rs/xid" + "github.com/go-jose/go-jose/v4" "github.com/drakkan/sftpgo/v2/internal/httpclient" + "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" @@ -46,7 +45,8 @@ const ( const ( // NodeTokenHeader defines the header to use for the node auth token - NodeTokenHeader = "X-SFTPGO-Node" + NodeTokenHeader = "X-SFTPGO-Node" + nodeTokenAudience = "node" ) var ( @@ -132,35 +132,26 @@ func (n *Node) validate() error { return n.Data.validate() } -func (n *Node) authenticate(token string) (string, string, []string, error) { +func (n *Node) authenticate(token string) (*jwt.Claims, error) { if err := n.Data.Key.TryDecrypt(); err != nil { providerLog(logger.LevelError, "unable to decrypt node key: %v", err) - return "", "", nil, err + return nil, err } if token == "" { - return "", "", nil, ErrInvalidCredentials + return nil, ErrInvalidCredentials } - t, err := jwt.Parse([]byte(token), jwt.WithKey(jwa.HS256, []byte(n.Data.Key.GetPayload())), jwt.WithValidate(true)) + claims, err := jwt.VerifyTokenWithKey(token, []jose.SignatureAlgorithm{jose.HS256}, []byte(n.Data.Key.GetPayload())) if err != nil { - return "", "", nil, fmt.Errorf("unable to parse and validate token: %v", err) + return nil, fmt.Errorf("unable to parse and validate token: %v", err) } - var adminUsername, role string - if admin, ok := t.Get("admin"); ok { - if val, ok := admin.(string); ok && val != "" { - adminUsername = val - } + if claims.Username == "" { + return nil, errors.New("no admin username associated with node token") } - if adminUsername == "" { - return "", "", nil, errors.New("no admin username associated with node token") + if !claims.Audience.Contains(nodeTokenAudience) { + return nil, errors.New("invalid node token audience") } - if r, ok := t.Get("role"); ok { - if val, ok := r.(string); ok && val != "" { - role = val - } - } - perms := getPermsFromToken(t) - return adminUsername, role, perms, nil + return claims, nil } // getBaseURL returns the base URL for this node @@ -181,22 +172,22 @@ func (n *Node) generateAuthToken(username, role string, permissions []string) (s if err := n.Data.Key.TryDecrypt(); err != nil { return "", fmt.Errorf("unable to decrypt node key: %w", err) } - now := time.Now().UTC() - - t := jwt.New() - t.Set("admin", username) //nolint:errcheck - t.Set("role", role) //nolint:errcheck - t.Set("perms", permissions) //nolint:errcheck - t.Set(jwt.IssuedAtKey, now) //nolint:errcheck - t.Set(jwt.JwtIDKey, xid.New().String()) //nolint:errcheck - t.Set(jwt.NotBeforeKey, now.Add(-30*time.Second)) //nolint:errcheck - t.Set(jwt.ExpirationKey, now.Add(1*time.Minute)) //nolint:errcheck - - payload, err := jwt.Sign(t, jwt.WithKey(jwa.HS256, []byte(n.Data.Key.GetPayload()))) + signer, err := jwt.NewSigner(jose.HS256, []byte(n.Data.Key.GetPayload())) + if err != nil { + return "", fmt.Errorf("unable to create signer: %w", err) + } + claims := &jwt.Claims{ + Username: username, + Role: role, + Permissions: permissions, + } + claims.Audience = []string{nodeTokenAudience} + claims.SetExpiry(time.Now().Add(1 * time.Minute)) + payload, err := signer.Sign(claims) if err != nil { return "", fmt.Errorf("unable to sign authentication token: %w", err) } - return util.BytesToString(payload), nil + return payload, nil } func (n *Node) prepareRequest(ctx context.Context, username, role, relativeURL, method string, @@ -273,9 +264,9 @@ func (n *Node) SendDeleteRequest(username, role, relativeURL string, permissions } // AuthenticateNodeToken check the validity of the provided token -func AuthenticateNodeToken(token string) (string, string, []string, error) { +func AuthenticateNodeToken(token string) (*jwt.Claims, error) { if currentNode == nil { - return "", "", nil, errNoClusterNodes + return nil, errNoClusterNodes } return currentNode.authenticate(token) } @@ -287,21 +278,3 @@ func GetNodeName() string { } return currentNode.Name } - -func getPermsFromToken(t jwt.Token) []string { - var perms []string - if p, ok := t.Get("perms"); ok { - switch v := p.(type) { - case []any: - for _, elem := range v { - switch elemValue := elem.(type) { - case string: - perms = append(perms, elemValue) - } - } - case []string: - perms = v - } - } - return perms -} diff --git a/internal/httpd/api_admin.go b/internal/httpd/api_admin.go index 2a8c7ec6..2fd093a0 100644 --- a/internal/httpd/api_admin.go +++ b/internal/httpd/api_admin.go @@ -21,10 +21,10 @@ import ( "net/http" "net/url" - "github.com/go-chi/jwtauth/v5" "github.com/go-chi/render" "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/smtp" "github.com/drakkan/sftpgo/v2/internal/util" ) @@ -68,7 +68,7 @@ func renderAdmin(w http.ResponseWriter, r *http.Request, username string, status func addAdmin(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -90,7 +90,7 @@ func addAdmin(w http.ResponseWriter, r *http.Request) { func disableAdmin2FA(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -138,7 +138,7 @@ func updateAdmin(w http.ResponseWriter, r *http.Request) { return } - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -182,7 +182,7 @@ func updateAdmin(w http.ResponseWriter, r *http.Request) { func deleteAdmin(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) username := getURLParam(r, "username") - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -202,7 +202,7 @@ func deleteAdmin(w http.ResponseWriter, r *http.Request) { func getAdminProfile(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -224,7 +224,7 @@ func getAdminProfile(w http.ResponseWriter, r *http.Request) { func updateAdminProfile(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -317,7 +317,7 @@ func doChangeAdminPassword(r *http.Request, currentPassword, newPassword, confir util.I18nErrorChangePwdNoDifferent, ) } - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil { return util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken) } @@ -335,14 +335,3 @@ func doChangeAdminPassword(r *http.Request, currentPassword, newPassword, confir return dataprovider.UpdateAdmin(&admin, dataprovider.ActionExecutorSelf, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role) } - -func getTokenClaims(r *http.Request) (jwtTokenClaims, error) { - tokenClaims := jwtTokenClaims{} - _, claims, err := jwtauth.FromContext(r.Context()) - if err != nil { - return tokenClaims, err - } - tokenClaims.Decode(claims) - - return tokenClaims, nil -} diff --git a/internal/httpd/api_eventrule.go b/internal/httpd/api_eventrule.go index a4daaaa6..b8474af4 100644 --- a/internal/httpd/api_eventrule.go +++ b/internal/httpd/api_eventrule.go @@ -24,6 +24,7 @@ import ( "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/util" ) @@ -42,7 +43,7 @@ func getEventActions(w http.ResponseWriter, r *http.Request) { render.JSON(w, r, actions) } -func renderEventAction(w http.ResponseWriter, r *http.Request, name string, claims *jwtTokenClaims, status int) { +func renderEventAction(w http.ResponseWriter, r *http.Request, name string, claims *jwt.Claims, status int) { action, err := dataprovider.EventActionExists(name) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) @@ -61,19 +62,19 @@ func renderEventAction(w http.ResponseWriter, r *http.Request, name string, clai func getEventActionByName(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } name := getURLParam(r, "name") - renderEventAction(w, r, name, &claims, http.StatusOK) + renderEventAction(w, r, name, claims, http.StatusOK) } func addEventAction(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -91,12 +92,12 @@ func addEventAction(w http.ResponseWriter, r *http.Request) { return } w.Header().Add("Location", fmt.Sprintf("%s/%s", eventActionsPath, url.PathEscape(action.Name))) - renderEventAction(w, r, action.Name, &claims, http.StatusCreated) + renderEventAction(w, r, action.Name, claims, http.StatusCreated) } func updateEventAction(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -136,7 +137,7 @@ func updateEventAction(w http.ResponseWriter, r *http.Request) { func deleteEventAction(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -165,7 +166,7 @@ func getEventRules(w http.ResponseWriter, r *http.Request) { render.JSON(w, r, rules) } -func renderEventRule(w http.ResponseWriter, r *http.Request, name string, claims *jwtTokenClaims, status int) { +func renderEventRule(w http.ResponseWriter, r *http.Request, name string, claims *jwt.Claims, status int) { rule, err := dataprovider.EventRuleExists(name) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) @@ -184,19 +185,19 @@ func renderEventRule(w http.ResponseWriter, r *http.Request, name string, claims func getEventRuleByName(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } name := getURLParam(r, "name") - renderEventRule(w, r, name, &claims, http.StatusOK) + renderEventRule(w, r, name, claims, http.StatusOK) } func addEventRule(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -213,12 +214,12 @@ func addEventRule(w http.ResponseWriter, r *http.Request) { return } w.Header().Add("Location", fmt.Sprintf("%s/%s", eventRulesPath, url.PathEscape(rule.Name))) - renderEventRule(w, r, rule.Name, &claims, http.StatusCreated) + renderEventRule(w, r, rule.Name, claims, http.StatusCreated) } func updateEventRule(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -249,7 +250,7 @@ func updateEventRule(w http.ResponseWriter, r *http.Request) { func deleteEventRule(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return diff --git a/internal/httpd/api_events.go b/internal/httpd/api_events.go index 29e78c5e..a5ad1a5c 100644 --- a/internal/httpd/api_events.go +++ b/internal/httpd/api_events.go @@ -27,6 +27,7 @@ import ( "github.com/sftpgo/sdk/plugin/notifier" "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/plugin" "github.com/drakkan/sftpgo/v2/internal/util" ) @@ -143,7 +144,7 @@ func getLogSearchParamsFromRequest(r *http.Request) (eventsearcher.LogEventSearc func searchFsEvents(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -176,7 +177,7 @@ func searchFsEvents(w http.ResponseWriter, r *http.Request) { func searchProviderEvents(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -211,7 +212,7 @@ func searchProviderEvents(w http.ResponseWriter, r *http.Request) { func searchLogEvents(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return diff --git a/internal/httpd/api_folder.go b/internal/httpd/api_folder.go index debc8483..c46d4ab3 100644 --- a/internal/httpd/api_folder.go +++ b/internal/httpd/api_folder.go @@ -23,6 +23,7 @@ import ( "github.com/go-chi/render" "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/vfs" ) @@ -45,7 +46,7 @@ func getFolders(w http.ResponseWriter, r *http.Request) { func addFolder(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -62,12 +63,12 @@ func addFolder(w http.ResponseWriter, r *http.Request) { return } w.Header().Add("Location", fmt.Sprintf("%s/%s", folderPath, url.PathEscape(folder.Name))) - renderFolder(w, r, folder.Name, &claims, http.StatusCreated) + renderFolder(w, r, folder.Name, claims, http.StatusCreated) } func updateFolder(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -100,7 +101,7 @@ func updateFolder(w http.ResponseWriter, r *http.Request) { sendAPIResponse(w, r, nil, "Folder updated", http.StatusOK) } -func renderFolder(w http.ResponseWriter, r *http.Request, name string, claims *jwtTokenClaims, status int) { +func renderFolder(w http.ResponseWriter, r *http.Request, name string, claims *jwt.Claims, status int) { folder, err := dataprovider.GetFolderByName(name) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) @@ -119,18 +120,18 @@ func renderFolder(w http.ResponseWriter, r *http.Request, name string, claims *j func getFolderByName(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } name := getURLParam(r, "name") - renderFolder(w, r, name, &claims, http.StatusOK) + renderFolder(w, r, name, claims, http.StatusOK) } func deleteFolder(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return diff --git a/internal/httpd/api_group.go b/internal/httpd/api_group.go index 125319d5..beae9ad2 100644 --- a/internal/httpd/api_group.go +++ b/internal/httpd/api_group.go @@ -23,6 +23,7 @@ import ( "github.com/go-chi/render" "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/util" ) @@ -44,7 +45,7 @@ func getGroups(w http.ResponseWriter, r *http.Request) { func addGroup(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -61,12 +62,12 @@ func addGroup(w http.ResponseWriter, r *http.Request) { return } w.Header().Add("Location", fmt.Sprintf("%s/%s", groupPath, url.PathEscape(group.Name))) - renderGroup(w, r, group.Name, &claims, http.StatusCreated) + renderGroup(w, r, group.Name, claims, http.StatusCreated) } func updateGroup(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -98,7 +99,7 @@ func updateGroup(w http.ResponseWriter, r *http.Request) { sendAPIResponse(w, r, nil, "Group updated", http.StatusOK) } -func renderGroup(w http.ResponseWriter, r *http.Request, name string, claims *jwtTokenClaims, status int) { +func renderGroup(w http.ResponseWriter, r *http.Request, name string, claims *jwt.Claims, status int) { group, err := dataprovider.GroupExists(name) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) @@ -117,18 +118,18 @@ func renderGroup(w http.ResponseWriter, r *http.Request, name string, claims *jw func getGroupByName(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } name := getURLParam(r, "name") - renderGroup(w, r, name, &claims, http.StatusOK) + renderGroup(w, r, name, claims, http.StatusOK) } func deleteGroup(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return diff --git a/internal/httpd/api_http_user.go b/internal/httpd/api_http_user.go index 373b7bfe..45bbd52f 100644 --- a/internal/httpd/api_http_user.go +++ b/internal/httpd/api_http_user.go @@ -31,12 +31,13 @@ import ( "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" ) func getUserConnection(w http.ResponseWriter, r *http.Request) (*Connection, error) { - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return nil, fmt.Errorf("invalid token claims %w", err) @@ -457,7 +458,7 @@ func getUserFilesAsZipStream(w http.ResponseWriter, r *http.Request) { func getUserProfile(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -482,7 +483,7 @@ func getUserProfile(w http.ResponseWriter, r *http.Request) { func updateUserProfile(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -557,7 +558,7 @@ func doChangeUserPassword(r *http.Request, currentPassword, newPassword, confirm util.I18nErrorChangePwdNoDifferent, ) } - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { return util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken) } diff --git a/internal/httpd/api_iplist.go b/internal/httpd/api_iplist.go index 5c1de975..d74b68d8 100644 --- a/internal/httpd/api_iplist.go +++ b/internal/httpd/api_iplist.go @@ -25,6 +25,7 @@ import ( "github.com/go-chi/render" "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/util" ) @@ -68,7 +69,7 @@ func getIPListEntry(w http.ResponseWriter, r *http.Request) { func addIPListEntry(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -91,7 +92,7 @@ func addIPListEntry(w http.ResponseWriter, r *http.Request) { func updateIPListEntry(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -125,7 +126,7 @@ func updateIPListEntry(w http.ResponseWriter, r *http.Request) { func deleteIPListEntry(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return diff --git a/internal/httpd/api_keys.go b/internal/httpd/api_keys.go index d95d88de..a9a3bcb2 100644 --- a/internal/httpd/api_keys.go +++ b/internal/httpd/api_keys.go @@ -23,6 +23,7 @@ import ( "github.com/go-chi/render" "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/util" ) @@ -56,7 +57,7 @@ func getAPIKeyByID(w http.ResponseWriter, r *http.Request) { func addAPIKey(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -87,7 +88,7 @@ func addAPIKey(w http.ResponseWriter, r *http.Request) { func updateAPIKey(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -119,7 +120,7 @@ func updateAPIKey(w http.ResponseWriter, r *http.Request) { func deleteAPIKey(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) keyID := getURLParam(r, "id") - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return diff --git a/internal/httpd/api_maintenance.go b/internal/httpd/api_maintenance.go index 6152ab4e..560c702e 100644 --- a/internal/httpd/api_maintenance.go +++ b/internal/httpd/api_maintenance.go @@ -29,6 +29,7 @@ import ( "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/vfs" @@ -115,7 +116,7 @@ func dumpData(w http.ResponseWriter, r *http.Request) { func loadDataFromRequest(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, MaxRestoreSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -143,7 +144,7 @@ func loadDataFromRequest(w http.ResponseWriter, r *http.Request) { func loadData(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return diff --git a/internal/httpd/api_mfa.go b/internal/httpd/api_mfa.go index ccb44f7e..73df7593 100644 --- a/internal/httpd/api_mfa.go +++ b/internal/httpd/api_mfa.go @@ -27,6 +27,7 @@ import ( "github.com/go-chi/render" "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/internal/mfa" "github.com/drakkan/sftpgo/v2/internal/util" @@ -66,13 +67,13 @@ func getTOTPConfigs(w http.ResponseWriter, r *http.Request) { func generateTOTPSecret(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } var accountName string - if claims.hasUserAudience() { + if hasUserAudience(claims) { accountName = fmt.Sprintf("User %q", claims.Username) } else { accountName = fmt.Sprintf("Admin %q", claims.Username) @@ -113,7 +114,7 @@ func getQRCode(w http.ResponseWriter, r *http.Request) { func saveTOTPConfig(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -124,7 +125,7 @@ func saveTOTPConfig(w http.ResponseWriter, r *http.Request) { recoveryCodes = append(recoveryCodes, dataprovider.RecoveryCode{Secret: kms.NewPlainSecret(code)}) } baseURL := webBaseClientPath - if claims.hasUserAudience() { + if hasUserAudience(claims) { if err := saveUserTOTPConfig(claims.Username, r, recoveryCodes); err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) return @@ -164,14 +165,14 @@ func validateTOTPPasscode(w http.ResponseWriter, r *http.Request) { func getRecoveryCodes(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } recoveryCodes := make([]recoveryCode, 0, 12) var accountRecoveryCodes []dataprovider.RecoveryCode - if claims.hasUserAudience() { + if hasUserAudience(claims) { user, err := dataprovider.UserExists(claims.Username, "") if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) @@ -210,7 +211,7 @@ func getRecoveryCodes(w http.ResponseWriter, r *http.Request) { func generateRecoveryCodes(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -222,7 +223,7 @@ func generateRecoveryCodes(w http.ResponseWriter, r *http.Request) { recoveryCodes = append(recoveryCodes, code) accountRecoveryCodes = append(accountRecoveryCodes, dataprovider.RecoveryCode{Secret: kms.NewPlainSecret(code)}) } - if claims.hasUserAudience() { + if hasUserAudience(claims) { user, err := dataprovider.UserExists(claims.Username, "") if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) diff --git a/internal/httpd/api_quota.go b/internal/httpd/api_quota.go index 6beef8e6..db339e39 100644 --- a/internal/httpd/api_quota.go +++ b/internal/httpd/api_quota.go @@ -23,6 +23,7 @@ import ( "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/vfs" ) @@ -44,7 +45,7 @@ type transferQuotaUsage struct { func getUsersQuotaScans(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -91,7 +92,7 @@ func startFolderQuotaScan(w http.ResponseWriter, r *http.Request) { func updateUserTransferQuotaUsage(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -132,7 +133,7 @@ func updateUserTransferQuotaUsage(w http.ResponseWriter, r *http.Request) { } func doUpdateUserQuotaUsage(w http.ResponseWriter, r *http.Request, username string, usage quotaUsage) { - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -204,7 +205,7 @@ func doStartUserQuotaScan(w http.ResponseWriter, r *http.Request, username strin sendAPIResponse(w, r, nil, "Quota tracking is disabled!", http.StatusForbidden) return } - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return diff --git a/internal/httpd/api_retention.go b/internal/httpd/api_retention.go index 11d61b4c..1502a327 100644 --- a/internal/httpd/api_retention.go +++ b/internal/httpd/api_retention.go @@ -20,11 +20,12 @@ import ( "github.com/go-chi/render" "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/jwt" ) func getRetentionChecks(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return diff --git a/internal/httpd/api_role.go b/internal/httpd/api_role.go index 6e9a23ee..d8840155 100644 --- a/internal/httpd/api_role.go +++ b/internal/httpd/api_role.go @@ -23,6 +23,7 @@ import ( "github.com/go-chi/render" "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/util" ) @@ -44,7 +45,7 @@ func getRoles(w http.ResponseWriter, r *http.Request) { func addRole(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -67,7 +68,7 @@ func addRole(w http.ResponseWriter, r *http.Request) { func updateRole(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -119,7 +120,7 @@ func getRoleByName(w http.ResponseWriter, r *http.Request) { func deleteRole(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return diff --git a/internal/httpd/api_shares.go b/internal/httpd/api_shares.go index 92f155c7..1ea1542e 100644 --- a/internal/httpd/api_shares.go +++ b/internal/httpd/api_shares.go @@ -26,20 +26,20 @@ import ( "strings" "time" - "github.com/go-chi/jwtauth/v5" "github.com/go-chi/render" "github.com/rs/xid" "github.com/sftpgo/sdk" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" ) func getShares(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -59,7 +59,7 @@ func getShares(w http.ResponseWriter, r *http.Request) { func getShareByID(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -77,7 +77,7 @@ func getShareByID(w http.ResponseWriter, r *http.Request) { func addShare(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -126,7 +126,7 @@ func addShare(w http.ResponseWriter, r *http.Request) { func updateShare(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -177,7 +177,7 @@ func updateShare(w http.ResponseWriter, r *http.Request) { func deleteShare(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) shareID := getURLParam(r, "id") - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -432,16 +432,16 @@ func (s *httpdServer) uploadFilesToShare(w http.ResponseWriter, r *http.Request) } } -func (s *httpdServer) getShareClaims(r *http.Request, shareID string) (context.Context, *jwtTokenClaims, error) { - token, err := jwtauth.VerifyRequest(s.tokenAuth, r, jwtauth.TokenFromCookie) +func (s *httpdServer) getShareClaims(r *http.Request, shareID string) (context.Context, *jwt.Claims, error) { + token, err := jwt.VerifyRequest(s.tokenAuth, r, jwt.TokenFromCookie) if err != nil || token == nil { return nil, nil, errInvalidToken } - tokenString := jwtauth.TokenFromCookie(r) + tokenString := jwt.TokenFromCookie(r) if tokenString == "" || invalidatedJWTTokens.Get(tokenString) { return nil, nil, errInvalidToken } - if !slices.Contains(token.Audience(), tokenAudienceWebShare) { + if !token.Audience.Contains(tokenAudienceWebShare) { logger.Debug(logSender, "", "invalid token audience for share %q", shareID) return nil, nil, errInvalidToken } @@ -450,13 +450,12 @@ func (s *httpdServer) getShareClaims(r *http.Request, shareID string) (context.C logger.Debug(logSender, "", "token for share %q is not valid for the ip address %q", shareID, ipAddr) return nil, nil, err } - ctx := jwtauth.NewContext(r.Context(), token, nil) - claims, err := getTokenClaims(r.WithContext(ctx)) - if err != nil || claims.Username != shareID { + if token.Username != shareID { logger.Debug(logSender, "", "token not valid for share %q", shareID) return nil, nil, errInvalidToken } - return ctx, &claims, nil + ctx := jwt.NewContext(r.Context(), token, nil) + return ctx, token, nil } func (s *httpdServer) checkWebClientShareCredentials(w http.ResponseWriter, r *http.Request, share *dataprovider.Share) error { diff --git a/internal/httpd/api_user.go b/internal/httpd/api_user.go index d90c7e17..4af45f30 100644 --- a/internal/httpd/api_user.go +++ b/internal/httpd/api_user.go @@ -27,6 +27,7 @@ import ( "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/smtp" "github.com/drakkan/sftpgo/v2/internal/util" @@ -39,7 +40,7 @@ func getUsers(w http.ResponseWriter, r *http.Request) { if err != nil { return } - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -55,16 +56,16 @@ func getUsers(w http.ResponseWriter, r *http.Request) { func getUserByUsername(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return } username := getURLParam(r, "username") - renderUser(w, r, username, &claims, http.StatusOK) + renderUser(w, r, username, claims, http.StatusOK) } -func renderUser(w http.ResponseWriter, r *http.Request, username string, claims *jwtTokenClaims, status int) { +func renderUser(w http.ResponseWriter, r *http.Request, username string, claims *jwt.Claims, status int) { user, err := dataprovider.UserExists(username, claims.Role) if err != nil { sendAPIResponse(w, r, err, "", getRespStatus(err)) @@ -84,7 +85,7 @@ func renderUser(w http.ResponseWriter, r *http.Request, username string, claims func addUser(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -117,12 +118,12 @@ func addUser(w http.ResponseWriter, r *http.Request) { return } w.Header().Add("Location", fmt.Sprintf("%s/%s", userPath, url.PathEscape(user.Username))) - renderUser(w, r, user.Username, &claims, http.StatusCreated) + renderUser(w, r, user.Username, claims, http.StatusCreated) } func disableUser2FA(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -150,7 +151,7 @@ func disableUser2FA(w http.ResponseWriter, r *http.Request) { func updateUser(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -202,7 +203,7 @@ func updateUser(w http.ResponseWriter, r *http.Request) { func deleteUser(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return diff --git a/internal/httpd/api_utils.go b/internal/httpd/api_utils.go index 5d7cf3e0..cef57b53 100644 --- a/internal/httpd/api_utils.go +++ b/internal/httpd/api_utils.go @@ -42,6 +42,7 @@ import ( "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/metric" "github.com/drakkan/sftpgo/v2/internal/plugin" @@ -177,7 +178,7 @@ func getBoolQueryParam(r *http.Request, param string) bool { func getActiveConnections(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -191,7 +192,7 @@ func getActiveConnections(w http.ResponseWriter, r *http.Request) { func handleCloseConnection(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -943,8 +944,8 @@ func getProtocolFromRequest(r *http.Request) string { return common.ProtocolHTTP } -func hideConfidentialData(claims *jwtTokenClaims, r *http.Request) bool { - if !claims.hasPerm(dataprovider.PermAdminAny) { +func hideConfidentialData(claims *jwt.Claims, r *http.Request) bool { + if !claims.HasPerm(dataprovider.PermAdminAny) { return true } return r.URL.Query().Get("confidential_data") != "1" diff --git a/internal/httpd/auth_utils.go b/internal/httpd/auth_utils.go index 3b222ca3..bdd7c5b3 100644 --- a/internal/httpd/auth_utils.go +++ b/internal/httpd/auth_utils.go @@ -15,17 +15,14 @@ package httpd import ( + "crypto/rand" "errors" "fmt" "net/http" - "slices" "time" - "github.com/go-chi/jwtauth/v5" - "github.com/lestrrat-go/jwx/v2/jwt" - "github.com/rs/xid" - "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" ) @@ -52,18 +49,8 @@ const ( ) const ( - claimUsernameKey = "username" - claimPermissionsKey = "permissions" - claimRole = "role" - claimAPIKey = "api_key" - claimNodeID = "node_id" - claimMustChangePasswordKey = "chpwd" - claimMustSetSecondFactorKey = "2fa_required" - claimRequiredTwoFactorProtocols = "2fa_protos" - claimHideUserPageSection = "hus" - claimRef = "ref" - basicRealm = "Basic realm=\"SFTPGo\"" - jwtCookieKey = "jwt" + basicRealm = "Basic realm=\"SFTPGo\"" + jwtCookieKey = "jwt" ) var ( @@ -129,212 +116,26 @@ func getMaxCookieDuration() time.Duration { return result } -type jwtTokenClaims struct { - Username string - Permissions []string - Role string - Signature string - Audience []string - APIKeyID string - NodeID string - MustSetTwoFactorAuth bool - MustChangePassword bool - RequiredTwoFactorProtocols []string - HideUserPageSections int - JwtID string - JwtIssuedAt time.Time - Ref string +func hasUserAudience(claims *jwt.Claims) bool { + return claims.HasAnyAudience([]string{tokenAudienceWebClient, tokenAudienceAPIUser}) } -func (c *jwtTokenClaims) hasUserAudience() bool { - for _, audience := range c.Audience { - if audience == tokenAudienceWebClient || audience == tokenAudienceAPIUser { - return true - } - } - - return false -} - -func (c *jwtTokenClaims) asMap() map[string]any { - claims := make(map[string]any) - - claims[claimUsernameKey] = c.Username - claims[claimPermissionsKey] = c.Permissions - if c.JwtID != "" { - claims[jwt.JwtIDKey] = c.JwtID - } - if !c.JwtIssuedAt.IsZero() { - claims[jwt.IssuedAtKey] = c.JwtIssuedAt - } - if c.Ref != "" { - claims[claimRef] = c.Ref - } - if c.Role != "" { - claims[claimRole] = c.Role - } - if c.APIKeyID != "" { - claims[claimAPIKey] = c.APIKeyID - } - if c.NodeID != "" { - claims[claimNodeID] = c.NodeID - } - claims[jwt.SubjectKey] = c.Signature - if c.MustChangePassword { - claims[claimMustChangePasswordKey] = c.MustChangePassword - } - if c.MustSetTwoFactorAuth { - claims[claimMustSetSecondFactorKey] = c.MustSetTwoFactorAuth - } - if len(c.RequiredTwoFactorProtocols) > 0 { - claims[claimRequiredTwoFactorProtocols] = c.RequiredTwoFactorProtocols - } - if c.HideUserPageSections > 0 { - claims[claimHideUserPageSection] = c.HideUserPageSections - } - - return claims -} - -func (c *jwtTokenClaims) decodeSliceString(val any) []string { - switch v := val.(type) { - case []any: - result := make([]string, 0, len(v)) - for _, elem := range v { - switch elemValue := elem.(type) { - case string: - result = append(result, elemValue) - } - } - return result - case []string: - return v - default: - return nil - } -} - -func (c *jwtTokenClaims) decodeBoolean(val any) bool { - switch v := val.(type) { - case bool: - return v - default: - return false - } -} - -func (c *jwtTokenClaims) decodeString(val any) string { - switch v := val.(type) { - case string: - return v - default: - return "" - } -} - -func (c *jwtTokenClaims) Decode(token map[string]any) { - c.Permissions = nil - c.Username = c.decodeString(token[claimUsernameKey]) - c.Signature = c.decodeString(token[jwt.SubjectKey]) - c.JwtID = c.decodeString(token[jwt.JwtIDKey]) - - audience := token[jwt.AudienceKey] - switch v := audience.(type) { - case []string: - c.Audience = v - } - - if val, ok := token[claimRef]; ok { - c.Ref = c.decodeString(val) - } - - if val, ok := token[claimAPIKey]; ok { - c.APIKeyID = c.decodeString(val) - } - - if val, ok := token[claimNodeID]; ok { - c.NodeID = c.decodeString(val) - } - - if val, ok := token[claimRole]; ok { - c.Role = c.decodeString(val) - } - - permissions := token[claimPermissionsKey] - c.Permissions = c.decodeSliceString(permissions) - - if val, ok := token[claimMustChangePasswordKey]; ok { - c.MustChangePassword = c.decodeBoolean(val) - } - - if val, ok := token[claimMustSetSecondFactorKey]; ok { - c.MustSetTwoFactorAuth = c.decodeBoolean(val) - } - - if val, ok := token[claimRequiredTwoFactorProtocols]; ok { - c.RequiredTwoFactorProtocols = c.decodeSliceString(val) - } - - if val, ok := token[claimHideUserPageSection]; ok { - switch v := val.(type) { - case float64: - c.HideUserPageSections = int(v) - } - } -} - -func (c *jwtTokenClaims) hasPerm(perm string) bool { - if slices.Contains(c.Permissions, dataprovider.PermAdminAny) { - return true - } - - return slices.Contains(c.Permissions, perm) -} - -func (c *jwtTokenClaims) createToken(tokenAuth *jwtauth.JWTAuth, audience tokenAudience, ip string) (jwt.Token, string, error) { - claims := c.asMap() - now := time.Now().UTC() - - if _, ok := claims[jwt.JwtIDKey]; !ok { - claims[jwt.JwtIDKey] = xid.New().String() - } - if _, ok := claims[jwt.IssuedAtKey]; !ok { - claims[jwt.IssuedAtKey] = now - } - claims[jwt.NotBeforeKey] = now.Add(-30 * time.Second) - claims[jwt.ExpirationKey] = now.Add(getTokenDuration(audience)) - claims[jwt.AudienceKey] = []string{audience, ip} - - return tokenAuth.Encode(claims) -} - -func (c *jwtTokenClaims) createTokenResponse(tokenAuth *jwtauth.JWTAuth, audience tokenAudience, ip string) (map[string]any, error) { - token, tokenString, err := c.createToken(tokenAuth, audience, ip) - if err != nil { - return nil, err - } - - response := make(map[string]any) - response["access_token"] = tokenString - response["expires_at"] = token.Expiration().Format(time.RFC3339) - - return response, nil -} - -func (c *jwtTokenClaims) createAndSetCookie(w http.ResponseWriter, r *http.Request, tokenAuth *jwtauth.JWTAuth, +func createAndSetCookie(w http.ResponseWriter, r *http.Request, claims *jwt.Claims, tokenAuth *jwt.Signer, audience tokenAudience, ip string, ) error { - resp, err := c.createTokenResponse(tokenAuth, audience, ip) + duration := getTokenDuration(audience) + token, err := tokenAuth.SignWithParams(claims, audience, ip, duration) if err != nil { return err } + resp := claims.BuildTokenResponse(token) var basePath string if audience == tokenAudienceWebAdmin || audience == tokenAudienceWebAdminPartial { basePath = webBaseAdminPath } else { basePath = webBaseClientPath } - setCookie(w, r, basePath, resp["access_token"].(string), getTokenDuration(audience)) + setCookie(w, r, basePath, resp.Token, duration) return nil } @@ -386,8 +187,8 @@ func isTLS(r *http.Request) bool { func isTokenInvalidated(r *http.Request) bool { var findTokenFns []func(r *http.Request) string - findTokenFns = append(findTokenFns, jwtauth.TokenFromHeader) - findTokenFns = append(findTokenFns, jwtauth.TokenFromCookie) + findTokenFns = append(findTokenFns, jwt.TokenFromHeader) + findTokenFns = append(findTokenFns, jwt.TokenFromCookie) findTokenFns = append(findTokenFns, oidcTokenFromContext) isTokenFound := false @@ -405,89 +206,78 @@ func isTokenInvalidated(r *http.Request) bool { } func invalidateToken(r *http.Request) { - tokenString := jwtauth.TokenFromHeader(r) + tokenString := jwt.TokenFromHeader(r) if tokenString != "" { invalidateTokenString(r, tokenString, apiTokenDuration) } - tokenString = jwtauth.TokenFromCookie(r) + tokenString = jwt.TokenFromCookie(r) if tokenString != "" { invalidateTokenString(r, tokenString, getMaxCookieDuration()) } } func invalidateTokenString(r *http.Request, tokenString string, fallbackDuration time.Duration) { - token, _, err := jwtauth.FromContext(r.Context()) - if err != nil || token == nil { + token, err := jwt.FromContext(r.Context()) + if err != nil { invalidatedJWTTokens.Add(tokenString, time.Now().Add(fallbackDuration).UTC()) return } - invalidatedJWTTokens.Add(tokenString, token.Expiration().Add(1*time.Minute).UTC()) + invalidatedJWTTokens.Add(tokenString, token.Expiry.Time().Add(1*time.Minute).UTC()) } func getUserFromToken(r *http.Request) *dataprovider.User { user := &dataprovider.User{} - _, claims, err := jwtauth.FromContext(r.Context()) + claims, err := jwt.FromContext(r.Context()) if err != nil { return user } - tokenClaims := jwtTokenClaims{} - tokenClaims.Decode(claims) - user.Username = tokenClaims.Username - user.Filters.WebClient = tokenClaims.Permissions - user.Role = tokenClaims.Role + user.Username = claims.Username + user.Filters.WebClient = claims.Permissions + user.Role = claims.Role return user } func getAdminFromToken(r *http.Request) *dataprovider.Admin { admin := &dataprovider.Admin{} - _, claims, err := jwtauth.FromContext(r.Context()) + claims, err := jwt.FromContext(r.Context()) if err != nil { return admin } - tokenClaims := jwtTokenClaims{} - tokenClaims.Decode(claims) - admin.Username = tokenClaims.Username - admin.Permissions = tokenClaims.Permissions - admin.Filters.Preferences.HideUserPageSections = tokenClaims.HideUserPageSections - admin.Role = tokenClaims.Role + admin.Username = claims.Username + admin.Permissions = claims.Permissions + admin.Filters.Preferences.HideUserPageSections = claims.HideUserPageSections + admin.Role = claims.Role return admin } -func createLoginCookie(w http.ResponseWriter, r *http.Request, csrfTokenAuth *jwtauth.JWTAuth, tokenID, basePath, ip string, +func createLoginCookie(w http.ResponseWriter, r *http.Request, csrfTokenAuth *jwt.Signer, tokenID, basePath, ip string, ) { - c := jwtTokenClaims{ - JwtID: tokenID, - } - resp, err := c.createTokenResponse(csrfTokenAuth, tokenAudienceWebLogin, ip) + c := jwt.NewClaims(tokenAudienceWebLogin, ip, getTokenDuration(tokenAudienceWebLogin)) + c.ID = tokenID + resp, err := c.GenerateTokenResponse(csrfTokenAuth) if err != nil { return } - setCookie(w, r, basePath, resp["access_token"].(string), csrfTokenDuration) + setCookie(w, r, basePath, resp.Token, csrfTokenDuration) } -func createCSRFToken(w http.ResponseWriter, r *http.Request, csrfTokenAuth *jwtauth.JWTAuth, tokenID, +func createCSRFToken(w http.ResponseWriter, r *http.Request, csrfTokenAuth *jwt.Signer, tokenID, basePath string, ) string { ip := util.GetIPFromRemoteAddress(r.RemoteAddr) - claims := make(map[string]any) - now := time.Now().UTC() - - claims[jwt.JwtIDKey] = xid.New().String() - claims[jwt.IssuedAtKey] = now - claims[jwt.NotBeforeKey] = now.Add(-30 * time.Second) - claims[jwt.ExpirationKey] = now.Add(csrfTokenDuration) - claims[jwt.AudienceKey] = []string{tokenAudienceCSRF, ip} + claims := jwt.NewClaims(tokenAudienceCSRF, ip, csrfTokenDuration) + claims.ID = rand.Text() if tokenID != "" { createLoginCookie(w, r, csrfTokenAuth, tokenID, basePath, ip) - claims[claimRef] = tokenID + claims.Ref = tokenID } else { - if c, err := getTokenClaims(r); err == nil { - claims[claimRef] = c.JwtID + if c, err := jwt.FromContext(r.Context()); err == nil { + claims.Ref = c.ID } else { logger.Error(logSender, "", "unable to add reference to CSRF token: %v", err) } } - _, tokenString, err := csrfTokenAuth.Encode(claims) + tokenString, err := csrfTokenAuth.Sign(claims) if err != nil { logger.Debug(logSender, "", "unable to create CSRF token: %v", err) return "" @@ -495,15 +285,15 @@ func createCSRFToken(w http.ResponseWriter, r *http.Request, csrfTokenAuth *jwta return tokenString } -func verifyCSRFToken(r *http.Request, csrfTokenAuth *jwtauth.JWTAuth) error { +func verifyCSRFToken(r *http.Request, csrfTokenAuth *jwt.Signer) error { tokenString := r.Form.Get(csrfFormToken) - token, err := jwtauth.VerifyToken(csrfTokenAuth, tokenString) + token, err := jwt.VerifyToken(csrfTokenAuth, tokenString) if err != nil || token == nil { logger.Debug(logSender, "", "error validating CSRF token %q: %v", tokenString, err) return fmt.Errorf("unable to verify form token: %v", err) } - if !slices.Contains(token.Audience(), tokenAudienceCSRF) { + if !token.Audience.Contains(tokenAudienceCSRF) { logger.Debug(logSender, "", "error validating CSRF token audience") return errors.New("the form token is not valid") } @@ -515,19 +305,18 @@ func verifyCSRFToken(r *http.Request, csrfTokenAuth *jwtauth.JWTAuth) error { return checkCSRFTokenRef(r, token) } -func checkCSRFTokenRef(r *http.Request, token jwt.Token) error { - claims, err := getTokenClaims(r) +func checkCSRFTokenRef(r *http.Request, token *jwt.Claims) error { + claims, err := jwt.FromContext(r.Context()) if err != nil { logger.Debug(logSender, "", "error getting token claims for CSRF validation: %v", err) return err } - ref, ok := token.Get(claimRef) - if !ok { + if token.ID == "" { logger.Debug(logSender, "", "error validating CSRF token, missing reference") return errors.New("the form token is not valid") } - if claims.JwtID == "" || claims.JwtID != ref.(string) { - logger.Debug(logSender, "", "error validating CSRF reference, id %q, reference %q", claims.JwtID, ref) + if claims.ID != token.Ref { + logger.Debug(logSender, "", "error validating CSRF reference, id %q, reference %q", claims.ID, token.ID) return errors.New("unexpected form token") } @@ -535,8 +324,8 @@ func checkCSRFTokenRef(r *http.Request, token jwt.Token) error { } func verifyLoginCookie(r *http.Request) error { - token, _, err := jwtauth.FromContext(r.Context()) - if err != nil || token == nil { + token, err := jwt.FromContext(r.Context()) + if err != nil { logger.Debug(logSender, "", "error getting login token: %v", err) return errInvalidToken } @@ -544,8 +333,8 @@ func verifyLoginCookie(r *http.Request) error { logger.Debug(logSender, "", "the login token has been invalidated") return errInvalidToken } - if !slices.Contains(token.Audience(), tokenAudienceWebLogin) { - logger.Debug(logSender, "", "the token with id %q is not valid for audience %q", token.JwtID(), tokenAudienceWebLogin) + if !token.Audience.Contains(tokenAudienceWebLogin) { + logger.Debug(logSender, "", "the token with id %q is not valid for audience %q", token.ID, tokenAudienceWebLogin) return errInvalidToken } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) @@ -555,7 +344,7 @@ func verifyLoginCookie(r *http.Request) error { return nil } -func verifyLoginCookieAndCSRFToken(r *http.Request, csrfTokenAuth *jwtauth.JWTAuth) error { +func verifyLoginCookieAndCSRFToken(r *http.Request, csrfTokenAuth *jwt.Signer) error { if err := verifyLoginCookie(r); err != nil { return err } @@ -565,17 +354,11 @@ func verifyLoginCookieAndCSRFToken(r *http.Request, csrfTokenAuth *jwtauth.JWTAu return nil } -func createOAuth2Token(csrfTokenAuth *jwtauth.JWTAuth, state, ip string) string { - claims := make(map[string]any) - now := time.Now().UTC() +func createOAuth2Token(csrfTokenAuth *jwt.Signer, state, ip string) string { + claims := jwt.NewClaims(tokenAudienceOAuth2, ip, getTokenDuration(tokenAudienceOAuth2)) + claims.ID = state - claims[jwt.JwtIDKey] = state - claims[jwt.IssuedAtKey] = now - claims[jwt.NotBeforeKey] = now.Add(-30 * time.Second) - claims[jwt.ExpirationKey] = now.Add(getTokenDuration(tokenAudienceOAuth2)) - claims[jwt.AudienceKey] = []string{tokenAudienceOAuth2, ip} - - _, tokenString, err := csrfTokenAuth.Encode(claims) + tokenString, err := csrfTokenAuth.Sign(claims) if err != nil { logger.Debug(logSender, "", "unable to create OAuth2 token: %v", err) return "" @@ -583,8 +366,8 @@ func createOAuth2Token(csrfTokenAuth *jwtauth.JWTAuth, state, ip string) string return tokenString } -func verifyOAuth2Token(csrfTokenAuth *jwtauth.JWTAuth, tokenString, ip string) (string, error) { - token, err := jwtauth.VerifyToken(csrfTokenAuth, tokenString) +func verifyOAuth2Token(csrfTokenAuth *jwt.Signer, tokenString, ip string) (string, error) { + token, err := jwt.VerifyToken(csrfTokenAuth, tokenString) if err != nil || token == nil { logger.Debug(logSender, "", "error validating OAuth2 token %q: %v", tokenString, err) return "", util.NewI18nError( @@ -593,7 +376,7 @@ func verifyOAuth2Token(csrfTokenAuth *jwtauth.JWTAuth, tokenString, ip string) ( ) } - if !slices.Contains(token.Audience(), tokenAudienceOAuth2) { + if !token.Audience.Contains(tokenAudienceOAuth2) { logger.Debug(logSender, "", "error validating OAuth2 token audience") return "", util.NewI18nError(errors.New("invalid OAuth2 state"), util.I18nOAuth2InvalidState) } @@ -602,31 +385,29 @@ func verifyOAuth2Token(csrfTokenAuth *jwtauth.JWTAuth, tokenString, ip string) ( logger.Debug(logSender, "", "error validating OAuth2 token IP audience") return "", util.NewI18nError(errors.New("invalid OAuth2 state"), util.I18nOAuth2InvalidState) } - if val, ok := token.Get(jwt.JwtIDKey); ok { - if state, ok := val.(string); ok { - return state, nil - } + if token.ID != "" { + return token.ID, nil } logger.Debug(logSender, "", "jti not found in OAuth2 token") return "", util.NewI18nError(errors.New("invalid OAuth2 state"), util.I18nOAuth2InvalidState) } -func validateIPForToken(token jwt.Token, ip string) error { +func validateIPForToken(token *jwt.Claims, ip string) error { if tokenValidationMode&tokenValidationModeNoIPMatch == 0 { - if !slices.Contains(token.Audience(), ip) { + if !token.Audience.Contains(ip) { return errInvalidToken } } return nil } -func checkTokenSignature(r *http.Request, token jwt.Token) error { +func checkTokenSignature(r *http.Request, token *jwt.Claims) error { if _, ok := r.Context().Value(oidcTokenKey).(string); ok { return nil } var err error if tokenValidationMode&tokenValidationModeUserSignature != 0 { - for _, audience := range token.Audience() { + for _, audience := range token.Audience { switch audience { case tokenAudienceAPI, tokenAudienceWebAdmin: err = validateSignatureForToken(token, dataprovider.GetAdminSignature) @@ -641,22 +422,16 @@ func checkTokenSignature(r *http.Request, token jwt.Token) error { return err } -func validateSignatureForToken(token jwt.Token, getter func(string) (string, error)) error { - username := "" - if u, ok := token.Get(claimUsernameKey); ok { - c := jwtTokenClaims{} - username = c.decodeString(u) - } - - signature, err := getter(username) +func validateSignatureForToken(token *jwt.Claims, getter func(string) (string, error)) error { + signature, err := getter(token.Username) if err != nil { - logger.Debug(logSender, "", "unable to get signature for username %q: %v", username, err) + logger.Debug(logSender, "", "unable to get signature for username %q: %v", token.Username, err) return errInvalidToken } - if signature != "" && signature == token.Subject() { + if signature != "" && signature == token.Subject { return nil } logger.Debug(logSender, "", "signature mismatch for username %q, signature %q, token signature %q", - username, signature, token.Subject()) + token.Username, signature, token.Subject) return errInvalidToken } diff --git a/internal/httpd/httpd.go b/internal/httpd/httpd.go index fe6c088f..c96e791e 100644 --- a/internal/httpd/httpd.go +++ b/internal/httpd/httpd.go @@ -1334,10 +1334,12 @@ func updateWebAdminURLs(baseURL string) { } // GetHTTPRouter returns an HTTP handler suitable to use for test cases -func GetHTTPRouter(b Binding) http.Handler { +func GetHTTPRouter(b Binding) (http.Handler, error) { server := newHttpdServer(b, filepath.Join("..", "..", "static"), "", CorsConfig{}, filepath.Join("..", "..", "openapi")) - server.initializeRouter() - return server.router + if err := server.initializeRouter(); err != nil { + return nil, err + } + return server.router, nil } // the ticker cannot be started/stopped from multiple goroutines diff --git a/internal/httpd/httpd_test.go b/internal/httpd/httpd_test.go index 1c74775a..95f2c461 100644 --- a/internal/httpd/httpd_test.go +++ b/internal/httpd/httpd_test.go @@ -328,7 +328,7 @@ type recoveryCode struct { Used bool `json:"used"` } -func TestMain(m *testing.M) { +func TestMain(m *testing.M) { //nolint:gocyclo homeBasePath = os.TempDir() logfilePath := filepath.Join(configDir, "sftpgo_api_test.log") logger.InitLogger(logfilePath, 5, 1, 28, false, false, zerolog.DebugLevel) @@ -480,7 +480,12 @@ func TestMain(m *testing.M) { waitTCPListening(httpdConf.Bindings[0].GetAddress()) httpd.ReloadCertificateMgr() //nolint:errcheck - testServer = httptest.NewServer(httpd.GetHTTPRouter(httpdConf.Bindings[0])) + handler, err := httpd.GetHTTPRouter(httpdConf.Bindings[0]) + if err != nil { + logger.ErrorToConsole("unable to get http test handler: %v", err) + os.Exit(1) + } + testServer = httptest.NewServer(handler) defer testServer.Close() exitCode := m.Run() diff --git a/internal/httpd/internal_test.go b/internal/httpd/internal_test.go index 6a6788c3..ccc0213e 100644 --- a/internal/httpd/internal_test.go +++ b/internal/httpd/internal_test.go @@ -38,10 +38,9 @@ import ( "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" - "github.com/go-chi/jwtauth/v5" + "github.com/go-jose/go-jose/v4" + josejwt "github.com/go-jose/go-jose/v4/jwt" "github.com/klauspost/compress/zip" - "github.com/lestrrat-go/jwx/v2/jwa" - "github.com/lestrrat-go/jwx/v2/jwt" "github.com/rs/xid" "github.com/sftpgo/sdk" sdkkms "github.com/sftpgo/sdk/kms" @@ -53,6 +52,7 @@ import ( "github.com/drakkan/sftpgo/v2/internal/acme" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/internal/plugin" "github.com/drakkan/sftpgo/v2/internal/util" @@ -307,6 +307,16 @@ func (r *failingWriter) Header() http.Header { return make(http.Header) } +type failingJoseSigner struct{} + +func (s *failingJoseSigner) Sign(payload []byte) (*jose.JSONWebSignature, error) { + return nil, errors.New("sign test error") +} + +func (s *failingJoseSigner) Options() jose.SignerOptions { + return jose.SignerOptions{} +} + func TestShouldBind(t *testing.T) { c := Conf{ Bindings: []Binding{ @@ -445,19 +455,19 @@ func TestTokenDuration(t *testing.T) { func TestVerifyCSRFToken(t *testing.T) { server := httpdServer{} - server.initializeRouter() + err := server.initializeRouter() + require.NoError(t, err) req, err := http.NewRequest(http.MethodPost, webAdminEventActionPath, nil) require.NoError(t, err) - req = req.WithContext(context.WithValue(req.Context(), jwtauth.ErrorCtxKey, fs.ErrPermission)) + req = req.WithContext(jwt.NewContext(req.Context(), &jwt.Claims{}, fs.ErrPermission)) rr := httptest.NewRecorder() tokenString := createCSRFToken(rr, req, server.csrfTokenAuth, "", webBaseAdminPath) assert.NotEmpty(t, tokenString) - token, err := server.csrfTokenAuth.Decode(tokenString) + claims, err := jwt.VerifyToken(server.csrfTokenAuth, tokenString) require.NoError(t, err) - _, ok := token.Get(claimRef) - assert.False(t, ok) + assert.Empty(t, claims.Ref) req.Form = url.Values{} req.Form.Set(csrfFormToken, tokenString) @@ -466,6 +476,18 @@ func TestVerifyCSRFToken(t *testing.T) { req, err = http.NewRequest(http.MethodPost, webAdminEventActionPath, nil) require.NoError(t, err) + req = req.WithContext(jwt.NewContext(req.Context(), &jwt.Claims{Claims: josejwt.Claims{ID: xid.New().String()}}, nil)) + req.Form = url.Values{} + req.Form.Set(csrfFormToken, tokenString) + err = verifyCSRFToken(req, server.csrfTokenAuth) + assert.ErrorContains(t, err, "unexpected form token") + + claims = jwt.NewClaims(tokenAudienceCSRF, "", getTokenDuration(tokenAudienceCSRF)) + tokenString, err = josejwt.Signed(server.csrfTokenAuth.Signer()).Claims(claims).Serialize() + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webAdminEventActionPath, nil) + require.NoError(t, err) + req = req.WithContext(jwt.NewContext(req.Context(), &jwt.Claims{Claims: josejwt.Claims{ID: xid.New().String()}}, nil)) req.Form = url.Values{} req.Form.Set(csrfFormToken, tokenString) err = verifyCSRFToken(req, server.csrfTokenAuth) @@ -474,7 +496,8 @@ func TestVerifyCSRFToken(t *testing.T) { func TestInvalidToken(t *testing.T) { server := httpdServer{} - server.initializeRouter() + err := server.initializeRouter() + require.NoError(t, err) admin := dataprovider.Admin{ Username: "admin", } @@ -485,7 +508,7 @@ func TestInvalidToken(t *testing.T) { rctx := chi.NewRouteContext() rctx.URLParams.Add("username", admin.Username) req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) - req = req.WithContext(context.WithValue(req.Context(), jwtauth.ErrorCtxKey, errFake)) + req = req.WithContext(context.WithValue(req.Context(), jwt.ErrorCtxKey, errFake)) rr := httptest.NewRecorder() updateAdmin(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) @@ -501,7 +524,7 @@ func TestInvalidToken(t *testing.T) { assert.NoError(t, err) req, _ = http.NewRequest(http.MethodPut, "", bytes.NewBuffer(asJSON)) req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) - req = req.WithContext(context.WithValue(req.Context(), jwtauth.ErrorCtxKey, errFake)) + req = req.WithContext(context.WithValue(req.Context(), jwt.ErrorCtxKey, errFake)) rr = httptest.NewRecorder() changeAdminPassword(rr, req) assert.Equal(t, http.StatusInternalServerError, rr.Code) @@ -980,7 +1003,8 @@ func TestTokenSignatureValidation(t *testing.T) { enableWebClient: true, enableRESTAPI: true, } - server.initializeRouter() + err := server.initializeRouter() + require.NoError(t, err) testServer := httptest.NewServer(server.router) defer testServer.Close() @@ -1134,28 +1158,30 @@ func TestTokenSignatureValidation(t *testing.T) { func TestUpdateWebAdminInvalidClaims(t *testing.T) { server := httpdServer{} - server.initializeRouter() + err := server.initializeRouter() + require.NoError(t, err) rr := httptest.NewRecorder() admin := dataprovider.Admin{ Username: "", Password: "password", } - c := jwtTokenClaims{ + c := &jwt.Claims{ Username: admin.Username, Permissions: admin.Permissions, - Signature: admin.GetSignature(), } - token, err := c.createTokenResponse(server.tokenAuth, tokenAudienceWebAdmin, "") + c.Subject = admin.GetSignature() + token, err := server.tokenAuth.SignWithParams(c, tokenAudienceWebAdmin, "", 10*time.Minute) assert.NoError(t, err) + resp := c.BuildTokenResponse(token) req, err := http.NewRequest(http.MethodGet, webAdminPath, nil) assert.NoError(t, err) - req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"])) - parsedToken, err := jwtauth.VerifyRequest(server.tokenAuth, req, jwtauth.TokenFromCookie) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", resp.Token)) + parsedToken, err := jwt.VerifyRequest(server.tokenAuth, req, jwt.TokenFromCookie) assert.NoError(t, err) ctx := req.Context() - ctx = jwtauth.NewContext(ctx, parsedToken, err) + ctx = jwt.NewContext(ctx, parsedToken, err) req = req.WithContext(ctx) form := make(url.Values) @@ -1169,7 +1195,7 @@ func TestUpdateWebAdminInvalidClaims(t *testing.T) { req = req.WithContext(ctx) req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"])) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", resp.Token)) server.handleWebUpdateAdminPost(rr, req) assert.Equal(t, http.StatusOK, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) @@ -1215,7 +1241,8 @@ func TestUpdateSMTPSecrets(t *testing.T) { func TestOAuth2Redirect(t *testing.T) { server := httpdServer{} - server.initializeRouter() + err := server.initializeRouter() + require.NoError(t, err) rr := httptest.NewRecorder() req, err := http.NewRequest(http.MethodGet, webOAuth2RedirectPath+"?state=invalid", nil) @@ -1237,22 +1264,17 @@ func TestOAuth2Redirect(t *testing.T) { func TestOAuth2Token(t *testing.T) { server := httpdServer{} - server.initializeRouter() + err := server.initializeRouter() + require.NoError(t, err) // invalid token - _, err := verifyOAuth2Token(server.csrfTokenAuth, "token", "") + _, err = verifyOAuth2Token(server.csrfTokenAuth, "token", "") if assert.Error(t, err) { assert.Contains(t, err.Error(), "unable to verify OAuth2 state") } // bad audience - claims := make(map[string]any) - now := time.Now().UTC() + claims := jwt.NewClaims(tokenAudienceAPI, "", getTokenDuration(tokenAudienceAPI)) - claims[jwt.JwtIDKey] = xid.New().String() - claims[jwt.NotBeforeKey] = now.Add(-30 * time.Second) - claims[jwt.ExpirationKey] = now.Add(getTokenDuration(tokenAudienceAPI)) - claims[jwt.AudienceKey] = []string{tokenAudienceAPI} - - _, tokenString, err := server.csrfTokenAuth.Encode(claims) + tokenString, err := server.csrfTokenAuth.Sign(claims) assert.NoError(t, err) _, err = verifyOAuth2Token(server.csrfTokenAuth, tokenString, "") if assert.Error(t, err) { @@ -1271,19 +1293,15 @@ func TestOAuth2Token(t *testing.T) { assert.NoError(t, err) assert.Equal(t, state, s) // no jti - claims = make(map[string]any) - - claims[jwt.NotBeforeKey] = now.Add(-30 * time.Second) - claims[jwt.ExpirationKey] = now.Add(getTokenDuration(tokenAudienceOAuth2)) - claims[jwt.AudienceKey] = []string{tokenAudienceOAuth2, "127.1.1.4"} - _, tokenString, err = server.csrfTokenAuth.Encode(claims) + claims = jwt.NewClaims(tokenAudienceOAuth2, "127.1.1.4", getTokenDuration(tokenAudienceOAuth2)) + tokenString, err = josejwt.Signed(server.csrfTokenAuth.Signer()).Claims(claims).Serialize() assert.NoError(t, err) _, err = verifyOAuth2Token(server.csrfTokenAuth, tokenString, "127.1.1.4") if assert.Error(t, err) { assert.Contains(t, err.Error(), "invalid OAuth2 state") } // encode error - server.csrfTokenAuth = jwtauth.New("HT256", util.GenerateRandomBytes(32), nil) + server.csrfTokenAuth.SetSigner(&failingJoseSigner{}) tokenString = createOAuth2Token(server.csrfTokenAuth, xid.New().String(), "") assert.Empty(t, tokenString) @@ -1301,23 +1319,17 @@ func TestOAuth2Token(t *testing.T) { func TestCSRFToken(t *testing.T) { server := httpdServer{} - server.initializeRouter() + err := server.initializeRouter() + require.NoError(t, err) // invalid token req := &http.Request{} - err := verifyCSRFToken(req, server.csrfTokenAuth) + err = verifyCSRFToken(req, server.csrfTokenAuth) if assert.Error(t, err) { assert.Contains(t, err.Error(), "unable to verify form token") } // bad audience - claims := make(map[string]any) - now := time.Now().UTC() - - claims[jwt.JwtIDKey] = xid.New().String() - claims[jwt.NotBeforeKey] = now.Add(-30 * time.Second) - claims[jwt.ExpirationKey] = now.Add(getTokenDuration(tokenAudienceAPI)) - claims[jwt.AudienceKey] = []string{tokenAudienceAPI} - - _, tokenString, err := server.csrfTokenAuth.Encode(claims) + claims := jwt.NewClaims(tokenAudienceAPI, "", getTokenDuration(tokenAudienceAPI)) + tokenString, err := server.csrfTokenAuth.Sign(claims) assert.NoError(t, err) values := url.Values{} values.Set(csrfFormToken, tokenString) @@ -1338,15 +1350,12 @@ func TestCSRFToken(t *testing.T) { assert.Contains(t, err.Error(), "form token is not valid") } - claims[jwt.JwtIDKey] = xid.New().String() - claims[jwt.NotBeforeKey] = now.Add(-30 * time.Second) - claims[jwt.ExpirationKey] = now.Add(getTokenDuration(tokenAudienceAPI)) - claims[jwt.AudienceKey] = []string{tokenAudienceAPI} - _, tokenString, err = server.csrfTokenAuth.Encode(claims) + claims = jwt.NewClaims(tokenAudienceAPI, "", getTokenDuration(tokenAudienceAPI)) + tokenString, err = server.csrfTokenAuth.Sign(claims) assert.NoError(t, err) assert.NotEmpty(t, tokenString) - r := GetHTTPRouter(Binding{ + r, err := GetHTTPRouter(Binding{ Address: "", Port: 8080, EnableWebAdmin: true, @@ -1354,6 +1363,7 @@ func TestCSRFToken(t *testing.T) { EnableRESTAPI: true, RenderOpenAPI: true, }) + assert.NoError(t, err) fn := server.verifyCSRFHeader(r) rr := httptest.NewRecorder() req, _ = http.NewRequest(http.MethodDelete, path.Join(userPath, "username"), nil) @@ -1377,7 +1387,9 @@ func TestCSRFToken(t *testing.T) { assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), "the token is not valid") - csrfTokenAuth := jwtauth.New("PS256", util.GenerateRandomBytes(32), nil) + csrfTokenAuth, err := jwt.NewSigner(jose.HS256, util.GenerateRandomBytes(32)) + require.NoError(t, err) + csrfTokenAuth.SetSigner(&failingJoseSigner{}) tokenString = createCSRFToken(httptest.NewRecorder(), req, csrfTokenAuth, "", webBaseAdminPath) assert.Empty(t, tokenString) rr = httptest.NewRecorder() @@ -1412,24 +1424,29 @@ func TestCreateShareCookieError(t *testing.T) { err = dataprovider.AddShare(share, "", "", "") assert.NoError(t, err) + tokenAuth, err := jwt.NewSigner(jose.HS256, util.GenerateRandomBytes(32)) + require.NoError(t, err) + tokenAuth.SetSigner(&failingJoseSigner{}) + csrfTokenAuth, err := jwt.NewSigner(jose.HS256, util.GenerateRandomBytes(32)) + require.NoError(t, err) + server := httpdServer{ - tokenAuth: jwtauth.New("TS256", util.GenerateRandomBytes(32), nil), - csrfTokenAuth: jwtauth.New(jwa.HS256.String(), util.GenerateRandomBytes(32), nil), + tokenAuth: tokenAuth, + csrfTokenAuth: csrfTokenAuth, } - c := jwtTokenClaims{ - JwtID: xid.New().String(), - } - resp, err := c.createTokenResponse(server.csrfTokenAuth, tokenAudienceWebLogin, "127.0.0.1") + c := jwt.NewClaims(tokenAudienceWebLogin, "127.0.0.1", getTokenDuration(tokenAudienceWebLogin)) + token, err := server.csrfTokenAuth.Sign(c) assert.NoError(t, err) - parsedToken, err := jwtauth.VerifyToken(server.csrfTokenAuth, resp["access_token"].(string)) + resp := c.BuildTokenResponse(token) + parsedToken, err := jwt.VerifyToken(server.csrfTokenAuth, resp.Token) assert.NoError(t, err) req, err := http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, share.ShareID, "login"), nil) assert.NoError(t, err) req.RemoteAddr = "127.0.0.1:4567" ctx := req.Context() - ctx = jwtauth.NewContext(ctx, parsedToken, err) + ctx = jwt.NewContext(ctx, parsedToken, err) req = req.WithContext(ctx) form := make(url.Values) @@ -1442,7 +1459,7 @@ func TestCreateShareCookieError(t *testing.T) { bytes.NewBuffer([]byte(form.Encode()))) assert.NoError(t, err) req.RemoteAddr = "127.0.0.1:2345" - req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", resp["access_token"])) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", resp.Token)) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req = req.WithContext(ctx) req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) @@ -1455,9 +1472,15 @@ func TestCreateShareCookieError(t *testing.T) { } func TestCreateTokenError(t *testing.T) { + tokenAuth, err := jwt.NewSigner(jose.HS256, util.GenerateRandomBytes(32)) + require.NoError(t, err) + tokenAuth.SetSigner(&failingJoseSigner{}) + csrfTokenAuth, err := jwt.NewSigner(jose.HS256, util.GenerateRandomBytes(32)) + require.NoError(t, err) + server := httpdServer{ - tokenAuth: jwtauth.New("PS256", util.GenerateRandomBytes(32), nil), - csrfTokenAuth: jwtauth.New(jwa.HS256.String(), util.GenerateRandomBytes(32), nil), + tokenAuth: tokenAuth, + csrfTokenAuth: csrfTokenAuth, } rr := httptest.NewRecorder() admin := dataprovider.Admin{ @@ -1481,19 +1504,20 @@ func TestCreateTokenError(t *testing.T) { server.generateAndSendUserToken(rr, req, "", user) assert.Equal(t, http.StatusInternalServerError, rr.Code) - c := jwtTokenClaims{ - JwtID: xid.New().String(), - } - token, err := c.createTokenResponse(server.csrfTokenAuth, tokenAudienceWebLogin, "") + c := &jwt.Claims{} + c.ID = xid.New().String() + c.SetExpiry(time.Now().Add(1 * time.Minute)) + tokenString, err := server.csrfTokenAuth.SignWithParams(c, tokenAudienceAPI, "", getTokenDuration(tokenAudienceAPI)) assert.NoError(t, err) + token := c.BuildTokenResponse(tokenString) req, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil) assert.NoError(t, err) - req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"])) - parsedToken, err := jwtauth.VerifyRequest(server.csrfTokenAuth, req, jwtauth.TokenFromCookie) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token.Token)) + parsedToken, err := jwt.VerifyRequest(server.csrfTokenAuth, req, jwt.TokenFromCookie) assert.NoError(t, err) ctx := req.Context() - ctx = jwtauth.NewContext(ctx, parsedToken, err) + ctx = jwt.NewContext(ctx, parsedToken, err) req = req.WithContext(ctx) rr = httptest.NewRecorder() @@ -1506,10 +1530,10 @@ func TestCreateTokenError(t *testing.T) { req, _ = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode()))) req.Header.Set("Cookie", cookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - parsedToken, err = jwtauth.VerifyRequest(server.csrfTokenAuth, req, jwtauth.TokenFromCookie) + parsedToken, err = jwt.VerifyRequest(server.csrfTokenAuth, req, jwt.TokenFromCookie) assert.NoError(t, err) ctx = req.Context() - ctx = jwtauth.NewContext(ctx, parsedToken, err) + ctx = jwt.NewContext(ctx, parsedToken, err) req = req.WithContext(ctx) server.handleWebAdminLoginPost(rr, req) assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) @@ -1587,6 +1611,7 @@ func TestCreateTokenError(t *testing.T) { assert.Equal(t, http.StatusInternalServerError, rr.Code, rr.Body.String()) req, _ = http.NewRequest(http.MethodPost, webAdminTwoFactorPath+"?a=a%C3%AO%GC", bytes.NewBuffer([]byte(form.Encode()))) + req = req.WithContext(jwt.NewContext(req.Context(), &jwt.Claims{}, nil)) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = httptest.NewRecorder() server.handleWebAdminTwoFactorPost(rr, req) @@ -1594,6 +1619,7 @@ func TestCreateTokenError(t *testing.T) { assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidForm) req, _ = http.NewRequest(http.MethodPost, webAdminTwoFactorRecoveryPath+"?a=a%C3%AO%GD", bytes.NewBuffer([]byte(form.Encode()))) + req = req.WithContext(jwt.NewContext(req.Context(), &jwt.Claims{}, nil)) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = httptest.NewRecorder() server.handleWebAdminTwoFactorRecoveryPost(rr, req) @@ -1601,6 +1627,7 @@ func TestCreateTokenError(t *testing.T) { assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidForm) req, _ = http.NewRequest(http.MethodPost, webClientTwoFactorPath+"?a=a%C3%AO%GC", bytes.NewBuffer([]byte(form.Encode()))) + req = req.WithContext(jwt.NewContext(req.Context(), &jwt.Claims{}, nil)) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = httptest.NewRecorder() server.handleWebClientTwoFactorPost(rr, req) @@ -1608,6 +1635,7 @@ func TestCreateTokenError(t *testing.T) { assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidForm) req, _ = http.NewRequest(http.MethodPost, webClientTwoFactorRecoveryPath+"?a=a%C3%AO%GD", bytes.NewBuffer([]byte(form.Encode()))) + req = req.WithContext(jwt.NewContext(req.Context(), &jwt.Claims{}, nil)) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rr = httptest.NewRecorder() server.handleWebClientTwoFactorRecoveryPost(rr, req) @@ -1673,11 +1701,11 @@ func TestCreateTokenError(t *testing.T) { req, err = http.NewRequest(http.MethodGet, webClientLoginPath, nil) assert.NoError(t, err) - req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"])) - parsedToken, err = jwtauth.VerifyRequest(server.csrfTokenAuth, req, jwtauth.TokenFromCookie) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token.Token)) + parsedToken, err = jwt.VerifyRequest(server.csrfTokenAuth, req, jwt.TokenFromCookie) assert.NoError(t, err) ctx = req.Context() - ctx = jwtauth.NewContext(ctx, parsedToken, err) + ctx = jwt.NewContext(ctx, parsedToken, err) req = req.WithContext(ctx) rr = httptest.NewRecorder() @@ -1713,7 +1741,7 @@ func TestCreateTokenError(t *testing.T) { } func TestAPIKeyAuthForbidden(t *testing.T) { - r := GetHTTPRouter(Binding{ + r, err := GetHTTPRouter(Binding{ Address: "", Port: 8080, EnableWebAdmin: true, @@ -1721,6 +1749,7 @@ func TestAPIKeyAuthForbidden(t *testing.T) { EnableRESTAPI: true, RenderOpenAPI: true, }) + require.NoError(t, err) fn := forbidAPIKeyAuthentication(r) rr := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodGet, versionPath, nil) @@ -1730,12 +1759,14 @@ func TestAPIKeyAuthForbidden(t *testing.T) { } func TestJWTTokenValidation(t *testing.T) { - tokenAuth := jwtauth.New(jwa.HS256.String(), util.GenerateRandomBytes(32), nil) - claims := make(map[string]any) - claims["username"] = defaultAdminUsername - claims[jwt.ExpirationKey] = time.Now().UTC().Add(-1 * time.Hour) - token, _, err := tokenAuth.Encode(claims) - assert.NoError(t, err) + tokenAuth, err := jwt.NewSigner(jose.HS256, util.GenerateRandomBytes(32)) + require.NoError(t, err) + claims := &jwt.Claims{ + Username: defaultAdminUsername, + } + claims.SetExpiry(time.Now().UTC().Add(-1 * time.Hour)) + _, err = tokenAuth.SignWithParams(claims, tokenAudienceWebAdmin, "", getTokenDuration(tokenAudienceWebAdmin)) + require.NoError(t, err) server := httpdServer{ binding: Binding{ @@ -1747,19 +1778,20 @@ func TestJWTTokenValidation(t *testing.T) { RenderOpenAPI: true, }, } - server.initializeRouter() + err = server.initializeRouter() + require.NoError(t, err) r := server.router fn := jwtAuthenticatorAPI(r) rr := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodGet, userPath, nil) - ctx := jwtauth.NewContext(req.Context(), token, nil) + ctx := jwt.NewContext(req.Context(), claims, nil) fn.ServeHTTP(rr, req.WithContext(ctx)) assert.Equal(t, http.StatusUnauthorized, rr.Code) fn = jwtAuthenticatorWebAdmin(r) rr = httptest.NewRecorder() req, _ = http.NewRequest(http.MethodGet, webUserPath, nil) - ctx = jwtauth.NewContext(req.Context(), token, nil) + ctx = jwt.NewContext(req.Context(), claims, nil) fn.ServeHTTP(rr, req.WithContext(ctx)) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location")) @@ -1767,7 +1799,7 @@ func TestJWTTokenValidation(t *testing.T) { fn = jwtAuthenticatorWebClient(r) rr = httptest.NewRecorder() req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil) - ctx = jwtauth.NewContext(req.Context(), token, nil) + ctx = jwt.NewContext(req.Context(), claims, nil) fn.ServeHTTP(rr, req.WithContext(ctx)) assert.Equal(t, http.StatusFound, rr.Code) assert.Equal(t, webClientLoginPath, rr.Header().Get("Location")) @@ -1777,7 +1809,7 @@ func TestJWTTokenValidation(t *testing.T) { fn = permFn(r) rr = httptest.NewRecorder() req, _ = http.NewRequest(http.MethodGet, userPath, nil) - ctx = jwtauth.NewContext(req.Context(), token, errTest) + ctx = jwt.NewContext(req.Context(), claims, errTest) fn.ServeHTTP(rr, req.WithContext(ctx)) assert.Equal(t, http.StatusBadRequest, rr.Code) @@ -1786,7 +1818,7 @@ func TestJWTTokenValidation(t *testing.T) { rr = httptest.NewRecorder() req, _ = http.NewRequest(http.MethodGet, webUserPath, nil) req.RequestURI = webUserPath - ctx = jwtauth.NewContext(req.Context(), token, errTest) + ctx = jwt.NewContext(req.Context(), claims, errTest) fn.ServeHTTP(rr, req.WithContext(ctx)) assert.Equal(t, http.StatusBadRequest, rr.Code) @@ -1795,14 +1827,14 @@ func TestJWTTokenValidation(t *testing.T) { rr = httptest.NewRecorder() req, _ = http.NewRequest(http.MethodPost, webClientProfilePath, nil) req.RequestURI = webClientProfilePath - ctx = jwtauth.NewContext(req.Context(), token, errTest) + ctx = jwt.NewContext(req.Context(), claims, errTest) fn.ServeHTTP(rr, req.WithContext(ctx)) assert.Equal(t, http.StatusBadRequest, rr.Code) rr = httptest.NewRecorder() req, _ = http.NewRequest(http.MethodPost, userProfilePath, nil) req.RequestURI = userProfilePath - ctx = jwtauth.NewContext(req.Context(), token, errTest) + ctx = jwt.NewContext(req.Context(), claims, errTest) fn.ServeHTTP(rr, req.WithContext(ctx)) assert.Equal(t, http.StatusBadRequest, rr.Code) @@ -1810,7 +1842,7 @@ func TestJWTTokenValidation(t *testing.T) { rr = httptest.NewRecorder() req, _ = http.NewRequest(http.MethodPost, webClientProfilePath, nil) req.RequestURI = webClientProfilePath - ctx = jwtauth.NewContext(req.Context(), token, errTest) + ctx = jwt.NewContext(req.Context(), claims, errTest) fn.ServeHTTP(rr, req.WithContext(ctx)) assert.Equal(t, http.StatusBadRequest, rr.Code) @@ -1818,50 +1850,56 @@ func TestJWTTokenValidation(t *testing.T) { rr = httptest.NewRecorder() req, _ = http.NewRequest(http.MethodPost, webGroupsPath, nil) req.RequestURI = webGroupsPath - ctx = jwtauth.NewContext(req.Context(), token, errTest) + ctx = jwt.NewContext(req.Context(), claims, errTest) fn.ServeHTTP(rr, req.WithContext(ctx)) assert.Equal(t, http.StatusBadRequest, rr.Code) rr = httptest.NewRecorder() req, _ = http.NewRequest(http.MethodPost, userSharesPath, nil) req.RequestURI = userSharesPath - ctx = jwtauth.NewContext(req.Context(), token, errTest) + ctx = jwt.NewContext(req.Context(), claims, errTest) fn.ServeHTTP(rr, req.WithContext(ctx)) assert.Equal(t, http.StatusBadRequest, rr.Code) } func TestUpdateContextFromCookie(t *testing.T) { + tokenAuth, err := jwt.NewSigner(jose.HS256, util.GenerateRandomBytes(32)) + require.NoError(t, err) server := httpdServer{ - tokenAuth: jwtauth.New(jwa.HS256.String(), util.GenerateRandomBytes(32), nil), + tokenAuth: tokenAuth, } req, _ := http.NewRequest(http.MethodGet, tokenPath, nil) - claims := make(map[string]any) - claims["a"] = "b" - token, _, err := server.tokenAuth.Encode(claims) + claims := jwt.NewClaims(tokenAudienceWebClient, "", getTokenDuration(tokenAudienceWebClient)) + _, err = server.tokenAuth.Sign(claims) assert.NoError(t, err) - ctx := jwtauth.NewContext(req.Context(), token, nil) - server.updateContextFromCookie(req.WithContext(ctx)) + ctx := jwt.NewContext(req.Context(), claims, nil) + req = server.updateContextFromCookie(req.WithContext(ctx)) + token, err := jwt.FromContext(req.Context()) + require.NoError(t, err) + require.True(t, token.Audience.Contains(tokenAudienceWebClient)) + require.NotEmpty(t, token.ID) } func TestCookieExpiration(t *testing.T) { + tokenAuth, err := jwt.NewSigner(jose.HS256, util.GenerateRandomBytes(32)) + require.NoError(t, err) server := httpdServer{ - tokenAuth: jwtauth.New(jwa.HS256.String(), util.GenerateRandomBytes(32), nil), + tokenAuth: tokenAuth, } - err := errors.New("test error") + err = errors.New("test error") rr := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodGet, tokenPath, nil) - ctx := jwtauth.NewContext(req.Context(), nil, err) + ctx := jwt.NewContext(req.Context(), nil, err) server.checkCookieExpiration(rr, req.WithContext(ctx)) cookie := rr.Header().Get("Set-Cookie") assert.Empty(t, cookie) req, _ = http.NewRequest(http.MethodGet, tokenPath, nil) - claims := make(map[string]any) - claims["a"] = "b" - token, _, err := server.tokenAuth.Encode(claims) + claims := jwt.NewClaims(tokenAudienceWebClient, "", getTokenDuration(tokenAudienceWebClient)) + _, err = server.tokenAuth.Sign(claims) assert.NoError(t, err) - ctx = jwtauth.NewContext(req.Context(), token, nil) + ctx = jwt.NewContext(req.Context(), claims, nil) server.checkCookieExpiration(rr, req.WithContext(ctx)) cookie = rr.Header().Get("Set-Cookie") assert.Empty(t, cookie) @@ -1871,17 +1909,16 @@ func TestCookieExpiration(t *testing.T) { Password: "password", Permissions: []string{dataprovider.PermAdminAny}, } - claims = make(map[string]any) - claims[claimUsernameKey] = admin.Username - claims[claimPermissionsKey] = admin.Permissions - claims[jwt.JwtIDKey] = xid.New().String() - claims[jwt.SubjectKey] = admin.GetSignature() - claims[jwt.ExpirationKey] = time.Now().Add(1 * time.Minute) - claims[jwt.AudienceKey] = []string{tokenAudienceAPI} - token, _, err = server.tokenAuth.Encode(claims) + claims = jwt.NewClaims(tokenAudienceAPI, "", getTokenDuration(tokenAudienceAPI)) + claims.Username = admin.Username + claims.Permissions = admin.Permissions + claims.Subject = admin.GetSignature() + claims.SetExpiry(time.Now().Add(1 * time.Minute)) + _, err = server.tokenAuth.Sign(claims) assert.NoError(t, err) + req, _ = http.NewRequest(http.MethodGet, tokenPath, nil) - ctx = jwtauth.NewContext(req.Context(), token, nil) + ctx = jwt.NewContext(req.Context(), claims, nil) server.checkCookieExpiration(rr, req.WithContext(ctx)) cookie = rr.Header().Get("Set-Cookie") assert.Empty(t, cookie) @@ -1890,7 +1927,7 @@ func TestCookieExpiration(t *testing.T) { err = dataprovider.AddAdmin(&admin, "", "", "") assert.NoError(t, err) req, _ = http.NewRequest(http.MethodGet, tokenPath, nil) - ctx = jwtauth.NewContext(req.Context(), token, nil) + ctx = jwt.NewContext(req.Context(), claims, nil) server.checkCookieExpiration(rr, req.WithContext(ctx)) cookie = rr.Header().Get("Set-Cookie") assert.Empty(t, cookie) @@ -1900,7 +1937,7 @@ func TestCookieExpiration(t *testing.T) { err = dataprovider.UpdateAdmin(&admin, "", "", "") assert.NoError(t, err) req, _ = http.NewRequest(http.MethodGet, tokenPath, nil) - ctx = jwtauth.NewContext(req.Context(), token, nil) + ctx = jwt.NewContext(req.Context(), claims, nil) server.checkCookieExpiration(rr, req.WithContext(ctx)) cookie = rr.Header().Get("Set-Cookie") assert.Empty(t, cookie) @@ -1908,33 +1945,32 @@ func TestCookieExpiration(t *testing.T) { admin, err = dataprovider.AdminExists(admin.Username) assert.NoError(t, err) tokenID := xid.New().String() - claims = make(map[string]any) - claims[claimUsernameKey] = admin.Username - claims[claimPermissionsKey] = admin.Permissions - claims[jwt.JwtIDKey] = tokenID - claims[jwt.IssuedAtKey] = time.Now() - claims[jwt.SubjectKey] = admin.GetSignature() - claims[jwt.ExpirationKey] = time.Now().Add(1 * time.Minute) - claims[jwt.AudienceKey] = []string{tokenAudienceAPI} - token, _, err = server.tokenAuth.Encode(claims) + claims = jwt.NewClaims(tokenAudienceAPI, "", getTokenDuration(tokenAudienceAPI)) + claims.ID = tokenID + claims.Username = admin.Username + claims.Permissions = admin.Permissions + claims.Subject = admin.GetSignature() + claims.SetExpiry(time.Now().Add(1 * time.Minute)) + _, err = server.tokenAuth.Sign(claims) assert.NoError(t, err) + req, _ = http.NewRequest(http.MethodGet, tokenPath, nil) req.RemoteAddr = "192.168.8.1:1234" - ctx = jwtauth.NewContext(req.Context(), token, nil) + ctx = jwt.NewContext(req.Context(), claims, nil) server.checkCookieExpiration(rr, req.WithContext(ctx)) cookie = rr.Header().Get("Set-Cookie") assert.Empty(t, cookie) req, _ = http.NewRequest(http.MethodGet, tokenPath, nil) req.RemoteAddr = "172.16.1.12:4567" - ctx = jwtauth.NewContext(req.Context(), token, nil) + ctx = jwt.NewContext(req.Context(), claims, nil) server.checkCookieExpiration(rr, req.WithContext(ctx)) cookie = rr.Header().Get("Set-Cookie") assert.True(t, strings.HasPrefix(cookie, "jwt=")) req.Header.Set("Cookie", cookie) - token, err = jwtauth.VerifyRequest(server.tokenAuth, req, jwtauth.TokenFromCookie) + c, err := jwt.VerifyRequest(server.tokenAuth, req, jwt.TokenFromCookie) if assert.NoError(t, err) { - assert.Equal(t, tokenID, token.JwtID()) + assert.Equal(t, tokenID, c.ID) } err = dataprovider.DeleteAdmin(admin.Username, "", "", "") @@ -1953,20 +1989,18 @@ func TestCookieExpiration(t *testing.T) { user.Permissions = make(map[string][]string) user.Permissions["/"] = []string{"*"} - claims = make(map[string]any) - claims[claimUsernameKey] = user.Username - claims[claimPermissionsKey] = user.Filters.WebClient - claims[jwt.JwtIDKey] = tokenID - claims[jwt.IssuedAtKey] = time.Now() - claims[jwt.SubjectKey] = user.GetSignature() - claims[jwt.ExpirationKey] = time.Now().Add(1 * time.Minute) - claims[jwt.AudienceKey] = []string{tokenAudienceWebClient} - token, _, err = server.tokenAuth.Encode(claims) + claims = jwt.NewClaims(tokenAudienceWebClient, "", getTokenDuration(tokenAudienceWebClient)) + claims.ID = tokenID + claims.Username = user.Username + claims.Permissions = user.Filters.WebClient + claims.Subject = user.GetSignature() + claims.SetExpiry(time.Now().Add(1 * time.Minute)) + _, err = server.tokenAuth.Sign(claims) assert.NoError(t, err) rr = httptest.NewRecorder() req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil) - ctx = jwtauth.NewContext(req.Context(), token, nil) + ctx = jwt.NewContext(req.Context(), claims, nil) server.checkCookieExpiration(rr, req.WithContext(ctx)) cookie = rr.Header().Get("Set-Cookie") assert.Empty(t, cookie) @@ -1974,7 +2008,7 @@ func TestCookieExpiration(t *testing.T) { err = dataprovider.AddUser(&user, "", "", "") assert.NoError(t, err) req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil) - ctx = jwtauth.NewContext(req.Context(), token, nil) + ctx = jwt.NewContext(req.Context(), claims, nil) server.checkCookieExpiration(rr, req.WithContext(ctx)) cookie = rr.Header().Get("Set-Cookie") assert.Empty(t, cookie) @@ -1989,46 +2023,52 @@ func TestCookieExpiration(t *testing.T) { assert.NoError(t, err) issuedAt := time.Now().Add(-1 * time.Minute) expiresAt := time.Now().Add(1 * time.Minute) - claims = make(map[string]any) - claims[claimUsernameKey] = user.Username - claims[claimPermissionsKey] = user.Filters.WebClient - claims[jwt.JwtIDKey] = tokenID - claims[jwt.IssuedAtKey] = issuedAt - claims[jwt.SubjectKey] = user.GetSignature() - claims[jwt.ExpirationKey] = expiresAt - claims[jwt.AudienceKey] = []string{tokenAudienceWebClient} - token, _, err = server.tokenAuth.Encode(claims) + + claims = jwt.NewClaims(tokenAudienceWebClient, "", getTokenDuration(tokenAudienceWebClient)) + claims.ID = tokenID + claims.Username = user.Username + claims.Permissions = user.Filters.WebClient + claims.Subject = user.GetSignature() + claims.SetExpiry(expiresAt) + claims.SetIssuedAt(issuedAt) + _, err = server.tokenAuth.Sign(claims) assert.NoError(t, err) req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil) req.RemoteAddr = "172.16.3.12:4567" - ctx = jwtauth.NewContext(req.Context(), token, nil) + ctx = jwt.NewContext(req.Context(), claims, nil) server.checkCookieExpiration(rr, req.WithContext(ctx)) cookie = rr.Header().Get("Set-Cookie") assert.Empty(t, cookie) req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil) req.RemoteAddr = "172.16.4.16:4567" - ctx = jwtauth.NewContext(req.Context(), token, nil) + ctx = jwt.NewContext(req.Context(), claims, nil) server.checkCookieExpiration(rr, req.WithContext(ctx)) cookie = rr.Header().Get("Set-Cookie") assert.NotEmpty(t, cookie) req.Header.Set("Cookie", cookie) - token, err = jwtauth.VerifyRequest(server.tokenAuth, req, jwtauth.TokenFromCookie) + c, err = jwt.VerifyRequest(server.tokenAuth, req, jwt.TokenFromCookie) if assert.NoError(t, err) { - assert.Equal(t, tokenID, token.JwtID()) - assert.Equal(t, issuedAt.Unix(), token.IssuedAt().Unix()) - assert.NotEqual(t, expiresAt.Unix(), token.Expiration().Unix()) + assert.Equal(t, tokenID, c.ID) + assert.Equal(t, issuedAt.Unix(), c.IssuedAt.Time().Unix()) + assert.NotEqual(t, expiresAt.Unix(), c.Expiry.Time().Unix()) } // test a cookie issued more that 12 hours ago - claims[jwt.IssuedAtKey] = time.Now().Add(-24 * time.Hour) - token, _, err = server.tokenAuth.Encode(claims) + claims = jwt.NewClaims(tokenAudienceWebClient, "", getTokenDuration(tokenAudienceWebClient)) + claims.ID = tokenID + claims.Username = user.Username + claims.Permissions = user.Filters.WebClient + claims.Subject = user.GetSignature() + claims.SetExpiry(expiresAt) + claims.SetIssuedAt(time.Now().Add(-24 * time.Hour)) + _, err = server.tokenAuth.Sign(claims) assert.NoError(t, err) rr = httptest.NewRecorder() req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil) req.RemoteAddr = "172.16.4.16:6789" - ctx = jwtauth.NewContext(req.Context(), token, nil) + ctx = jwt.NewContext(req.Context(), claims, nil) server.checkCookieExpiration(rr, req.WithContext(ctx)) cookie = rr.Header().Get("Set-Cookie") assert.Empty(t, cookie) @@ -2040,20 +2080,19 @@ func TestCookieExpiration(t *testing.T) { user, err = dataprovider.UserExists(user.Username, "") assert.NoError(t, err) - claims = make(map[string]any) - claims[claimUsernameKey] = user.Username - claims[claimPermissionsKey] = user.Filters.WebClient - claims[jwt.JwtIDKey] = tokenID - claims[jwt.IssuedAtKey] = issuedAt - claims[jwt.SubjectKey] = user.GetSignature() - claims[jwt.ExpirationKey] = time.Now().Add(1 * time.Minute) - claims[jwt.AudienceKey] = []string{tokenAudienceWebClient} - token, _, err = server.tokenAuth.Encode(claims) + claims = jwt.NewClaims(tokenAudienceWebClient, "", getTokenDuration(tokenAudienceWebClient)) + claims.ID = tokenID + claims.Username = user.Username + claims.Permissions = user.Filters.WebClient + claims.Subject = user.GetSignature() + claims.SetExpiry(time.Now().Add(1 * time.Minute)) + claims.SetIssuedAt(issuedAt) + _, err = server.tokenAuth.Sign(claims) assert.NoError(t, err) rr = httptest.NewRecorder() req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil) - ctx = jwtauth.NewContext(req.Context(), token, nil) + ctx = jwt.NewContext(req.Context(), claims, nil) server.checkCookieExpiration(rr, req.WithContext(ctx)) cookie = rr.Header().Get("Set-Cookie") assert.Empty(t, cookie) @@ -2083,6 +2122,7 @@ func TestChangePwdValidationErrors(t *testing.T) { require.Error(t, err) req, _ := http.NewRequest(http.MethodPut, adminPwdPath, nil) + req = req.WithContext(jwt.NewContext(req.Context(), &jwt.Claims{Claims: josejwt.Claims{ID: xid.New().String()}}, nil)) err = doChangeAdminPassword(req, "currentpwd", "newpwd", "newpwd") assert.Error(t, err) } @@ -2090,23 +2130,24 @@ func TestChangePwdValidationErrors(t *testing.T) { func TestRenderUnexistingFolder(t *testing.T) { rr := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodPost, folderPath, nil) - renderFolder(rr, req, "path not mapped", &jwtTokenClaims{}, http.StatusOK) + renderFolder(rr, req, "path not mapped", &jwt.Claims{}, http.StatusOK) assert.Equal(t, http.StatusNotFound, rr.Code) } func TestCloseConnectionHandler(t *testing.T) { - tokenAuth := jwtauth.New(jwa.HS256.String(), util.GenerateRandomBytes(32), nil) - claims := make(map[string]any) - claims["username"] = defaultAdminUsername - claims[jwt.ExpirationKey] = time.Now().UTC().Add(1 * time.Hour) - token, _, err := tokenAuth.Encode(claims) + tokenAuth, err := jwt.NewSigner(jose.HS256, util.GenerateRandomBytes(32)) + require.NoError(t, err) + claims := jwt.NewClaims(tokenAudienceAPI, "", getTokenDuration(tokenAudienceAPI)) + claims.Username = defaultAdminUsername + claims.SetExpiry(time.Now().UTC().Add(1 * time.Hour)) + _, err = tokenAuth.Sign(claims) assert.NoError(t, err) req, err := http.NewRequest(http.MethodDelete, activeConnectionsPath+"/connectionID", nil) assert.NoError(t, err) rctx := chi.NewRouteContext() rctx.URLParams.Add("connectionID", "") req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) - req = req.WithContext(context.WithValue(req.Context(), jwtauth.TokenCtxKey, token)) + req = req.WithContext(context.WithValue(req.Context(), jwt.TokenCtxKey, claims)) rr := httptest.NewRecorder() handleCloseConnection(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) @@ -2301,20 +2342,22 @@ func TestGetUserFromTemplate(t *testing.T) { } func TestJWTTokenCleanup(t *testing.T) { + tokenAuth, err := jwt.NewSigner(jose.HS256, util.GenerateRandomBytes(32)) + require.NoError(t, err) server := httpdServer{ - tokenAuth: jwtauth.New(jwa.HS256.String(), util.GenerateRandomBytes(32), nil), + tokenAuth: tokenAuth, } admin := dataprovider.Admin{ Username: "newtestadmin", Password: "password", Permissions: []string{dataprovider.PermAdminAny}, } - claims := make(map[string]any) - claims[claimUsernameKey] = admin.Username - claims[claimPermissionsKey] = admin.Permissions - claims[jwt.SubjectKey] = admin.GetSignature() - claims[jwt.ExpirationKey] = time.Now().Add(1 * time.Minute) - _, token, err := server.tokenAuth.Encode(claims) + claims := jwt.NewClaims(tokenAudienceAPI, "", getTokenDuration(tokenAudienceAPI)) + claims.Username = admin.Username + claims.Permissions = admin.Permissions + claims.Subject = admin.GetSignature() + claims.SetExpiry(time.Now().Add(1 * time.Minute)) + token, err := server.tokenAuth.Sign(claims) assert.NoError(t, err) req, _ := http.NewRequest(http.MethodGet, versionPath, nil) @@ -2324,7 +2367,7 @@ func TestJWTTokenCleanup(t *testing.T) { invalidateTokenString(req, fakeToken, -100*time.Millisecond) assert.True(t, invalidatedJWTTokens.Get(fakeToken)) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", token)) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) invalidatedJWTTokens.Add(token, time.Now().Add(-getTokenDuration(tokenAudienceWebAdmin)).UTC()) require.True(t, isTokenInvalidated(req)) @@ -2459,7 +2502,8 @@ func TestProxyHeaders(t *testing.T) { err = b.parseAllowedProxy() assert.NoError(t, err) server := newHttpdServer(b, "", "", CorsConfig{Enabled: true}, "") - server.initializeRouter() + err = server.initializeRouter() + require.NoError(t, err) testServer := httptest.NewServer(server.router) defer testServer.Close() @@ -2493,10 +2537,10 @@ func TestProxyHeaders(t *testing.T) { cookie := rr.Header().Get("Set-Cookie") assert.NotEmpty(t, cookie) req.Header.Set("Cookie", cookie) - parsedToken, err := jwtauth.VerifyRequest(server.csrfTokenAuth, req, jwtauth.TokenFromCookie) + parsedToken, err := jwt.VerifyRequest(server.csrfTokenAuth, req, jwt.TokenFromCookie) assert.NoError(t, err) ctx := req.Context() - ctx = jwtauth.NewContext(ctx, parsedToken, err) + ctx = jwt.NewContext(ctx, parsedToken, err) req = req.WithContext(ctx) form := make(url.Values) @@ -2522,10 +2566,10 @@ func TestProxyHeaders(t *testing.T) { loginCookie := rr.Header().Get("Set-Cookie") assert.NotEmpty(t, loginCookie) req.Header.Set("Cookie", loginCookie) - parsedToken, err = jwtauth.VerifyRequest(server.csrfTokenAuth, req, jwtauth.TokenFromCookie) + parsedToken, err = jwt.VerifyRequest(server.csrfTokenAuth, req, jwt.TokenFromCookie) assert.NoError(t, err) ctx = req.Context() - ctx = jwtauth.NewContext(ctx, parsedToken, err) + ctx = jwt.NewContext(ctx, parsedToken, err) req = req.WithContext(ctx) form.Set(csrfFormToken, createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseAdminPath)) @@ -2562,10 +2606,10 @@ func TestProxyHeaders(t *testing.T) { loginCookie = rr.Header().Get("Set-Cookie") assert.NotEmpty(t, loginCookie) req.Header.Set("Cookie", loginCookie) - parsedToken, err = jwtauth.VerifyRequest(server.csrfTokenAuth, req, jwtauth.TokenFromCookie) + parsedToken, err = jwt.VerifyRequest(server.csrfTokenAuth, req, jwt.TokenFromCookie) assert.NoError(t, err) ctx = req.Context() - ctx = jwtauth.NewContext(ctx, parsedToken, err) + ctx = jwt.NewContext(ctx, parsedToken, err) req = req.WithContext(ctx) form.Set(csrfFormToken, createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseAdminPath)) @@ -2591,10 +2635,10 @@ func TestProxyHeaders(t *testing.T) { loginCookie = rr.Header().Get("Set-Cookie") assert.NotEmpty(t, loginCookie) req.Header.Set("Cookie", loginCookie) - parsedToken, err = jwtauth.VerifyRequest(server.csrfTokenAuth, req, jwtauth.TokenFromCookie) + parsedToken, err = jwt.VerifyRequest(server.csrfTokenAuth, req, jwt.TokenFromCookie) assert.NoError(t, err) ctx = req.Context() - ctx = jwtauth.NewContext(ctx, parsedToken, err) + ctx = jwt.NewContext(ctx, parsedToken, err) req = req.WithContext(ctx) form.Set(csrfFormToken, createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseAdminPath)) @@ -2625,7 +2669,8 @@ func TestRecoverer(t *testing.T) { EnableRESTAPI: true, } server := newHttpdServer(b, "../static", "", CorsConfig{}, "../openapi") - server.initializeRouter() + err := server.initializeRouter() + require.NoError(t, err) server.router.Get(recoveryPath, func(_ http.ResponseWriter, _ *http.Request) { panic("panic") }) @@ -2786,7 +2831,8 @@ func TestWebAdminRedirect(t *testing.T) { EnableRESTAPI: true, } server := newHttpdServer(b, "../static", "", CorsConfig{}, "../openapi") - server.initializeRouter() + err := server.initializeRouter() + require.NoError(t, err) testServer := httptest.NewServer(server.router) defer testServer.Close() @@ -3074,7 +3120,8 @@ func TestChangeUserPwd(t *testing.T) { func TestWebUserInvalidClaims(t *testing.T) { server := httpdServer{} - server.initializeRouter() + err := server.initializeRouter() + require.NoError(t, err) rr := httptest.NewRecorder() user := dataprovider.User{ @@ -3083,79 +3130,81 @@ func TestWebUserInvalidClaims(t *testing.T) { Password: "pwd", }, } - c := jwtTokenClaims{ + c := &jwt.Claims{ Username: user.Username, Permissions: nil, - Signature: user.GetSignature(), } - token, err := c.createTokenResponse(server.tokenAuth, tokenAudienceWebClient, "") + c.Subject = user.GetSignature() + c.SetExpiry(time.Now().Add(10 * time.Minute)) + c.Audience = []string{tokenAudienceAPI} + token, err := server.tokenAuth.Sign(c) assert.NoError(t, err) req, _ := http.NewRequest(http.MethodGet, webClientFilesPath, nil) - req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"])) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token)) server.handleClientGetFiles(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) rr = httptest.NewRecorder() req, _ = http.NewRequest(http.MethodGet, webClientDirsPath, nil) - req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"])) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token)) server.handleClientGetDirContents(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorDirList403) rr = httptest.NewRecorder() req, _ = http.NewRequest(http.MethodGet, webClientDownloadZipPath, nil) - req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"])) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token)) server.handleWebClientDownloadZip(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) rr = httptest.NewRecorder() req, _ = http.NewRequest(http.MethodGet, webClientEditFilePath, nil) - req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"])) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token)) server.handleClientEditFile(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) rr = httptest.NewRecorder() req, _ = http.NewRequest(http.MethodGet, webClientSharePath, nil) - req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"])) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token)) server.handleClientAddShareGet(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) rr = httptest.NewRecorder() req, _ = http.NewRequest(http.MethodGet, webClientSharePath, nil) - req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"])) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token)) server.handleClientUpdateShareGet(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) rr = httptest.NewRecorder() req, _ = http.NewRequest(http.MethodPost, webClientSharePath, nil) - req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"])) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token)) server.handleClientAddSharePost(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) rr = httptest.NewRecorder() req, _ = http.NewRequest(http.MethodPost, webClientSharePath+"/id", nil) - req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"])) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token)) server.handleClientUpdateSharePost(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) rr = httptest.NewRecorder() req, _ = http.NewRequest(http.MethodGet, webClientSharesPath+jsonAPISuffix, nil) - req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"])) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token)) getAllShares(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) rr = httptest.NewRecorder() req, _ = http.NewRequest(http.MethodGet, webClientViewPDFPath, nil) - req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"])) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token)) server.handleClientGetPDF(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) @@ -3163,7 +3212,8 @@ func TestWebUserInvalidClaims(t *testing.T) { func TestInvalidClaims(t *testing.T) { server := httpdServer{} - server.initializeRouter() + err := server.initializeRouter() + require.NoError(t, err) rr := httptest.NewRecorder() user := dataprovider.User{ @@ -3172,21 +3222,21 @@ func TestInvalidClaims(t *testing.T) { Password: "pwd", }, } - c := jwtTokenClaims{ + c := &jwt.Claims{ Username: user.Username, Permissions: nil, - Signature: user.GetSignature(), } - token, err := c.createTokenResponse(server.tokenAuth, tokenAudienceWebClient, "") + c.Subject = user.GetSignature() + token, err := server.tokenAuth.SignWithParams(c, tokenAudienceWebClient, "", getTokenDuration(tokenAudienceWebClient)) assert.NoError(t, err) req, err := http.NewRequest(http.MethodGet, webClientProfilePath, nil) assert.NoError(t, err) - req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"])) - parsedToken, err := jwtauth.VerifyRequest(server.tokenAuth, req, jwtauth.TokenFromCookie) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token)) + parsedToken, err := jwt.VerifyRequest(server.tokenAuth, req, jwt.TokenFromCookie) assert.NoError(t, err) ctx := req.Context() - ctx = jwtauth.NewContext(ctx, parsedToken, err) + ctx = jwt.NewContext(ctx, parsedToken, err) req = req.WithContext(ctx) form := make(url.Values) @@ -3196,7 +3246,7 @@ func TestInvalidClaims(t *testing.T) { assert.NoError(t, err) req = req.WithContext(ctx) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"])) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token)) server.handleWebClientProfilePost(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) @@ -3204,21 +3254,21 @@ func TestInvalidClaims(t *testing.T) { Username: "", Password: user.Password, } - c = jwtTokenClaims{ + c = &jwt.Claims{ Username: admin.Username, Permissions: nil, - Signature: admin.GetSignature(), } - token, err = c.createTokenResponse(server.tokenAuth, tokenAudienceWebAdmin, "") + c.Subject = admin.GetSignature() + token, err = server.tokenAuth.SignWithParams(c, tokenAudienceWebAdmin, "", getTokenDuration(tokenAudienceWebAdmin)) assert.NoError(t, err) req, err = http.NewRequest(http.MethodGet, webAdminProfilePath, nil) assert.NoError(t, err) - req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"])) - parsedToken, err = jwtauth.VerifyRequest(server.tokenAuth, req, jwtauth.TokenFromCookie) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token)) + parsedToken, err = jwt.VerifyRequest(server.tokenAuth, req, jwt.TokenFromCookie) assert.NoError(t, err) ctx = req.Context() - ctx = jwtauth.NewContext(ctx, parsedToken, err) + ctx = jwt.NewContext(ctx, parsedToken, err) req = req.WithContext(ctx) form = make(url.Values) @@ -3228,7 +3278,7 @@ func TestInvalidClaims(t *testing.T) { assert.NoError(t, err) req = req.WithContext(ctx) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", token["access_token"])) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token)) server.handleWebAdminProfilePost(rr, req) assert.Equal(t, http.StatusForbidden, rr.Code) assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) @@ -3252,12 +3302,14 @@ func TestSigningKey(t *testing.T) { server1 := httpdServer{ signingPassphrase: signingPassphrase, } - server1.initializeRouter() + err := server1.initializeRouter() + require.NoError(t, err) server2 := httpdServer{ signingPassphrase: signingPassphrase, } - server2.initializeRouter() + err = server2.initializeRouter() + require.NoError(t, err) user := dataprovider.User{ BaseUser: sdk.BaseUser{ @@ -3265,18 +3317,17 @@ func TestSigningKey(t *testing.T) { Password: "pwd", }, } - c := jwtTokenClaims{ + c := &jwt.Claims{ Username: user.Username, Permissions: nil, - Signature: user.GetSignature(), } - token, err := c.createTokenResponse(server1.tokenAuth, tokenAudienceWebClient, "") + c.Subject = user.GetSignature() + token, err := server1.tokenAuth.SignWithParams(c, tokenAudienceWebClient, "", getTokenDuration(tokenAudienceWebClient)) assert.NoError(t, err) - accessToken := token["access_token"].(string) - assert.NotEmpty(t, accessToken) - _, err = server1.tokenAuth.Decode(accessToken) + assert.NotEmpty(t, token) + _, err = jwt.VerifyToken(server1.tokenAuth, token) assert.NoError(t, err) - _, err = server2.tokenAuth.Decode(accessToken) + _, err = jwt.VerifyToken(server2.tokenAuth, token) assert.NoError(t, err) } @@ -3425,7 +3476,8 @@ func TestSecureMiddlewareIntegration(t *testing.T) { assert.NoError(t, err) assert.Equal(t, []string{forwardedHostHeader, xForwardedProto}, server.binding.Security.proxyHeaders) assert.Equal(t, map[string]string{xForwardedProto: "https"}, server.binding.Security.getHTTPSProxyHeaders()) - server.initializeRouter() + err = server.initializeRouter() + require.NoError(t, err) rr := httptest.NewRecorder() r, err := http.NewRequest(http.MethodGet, webClientLoginPath, nil) @@ -3501,7 +3553,8 @@ func TestRESTAPIDisabled(t *testing.T) { enableWebClient: true, enableRESTAPI: false, } - server.initializeRouter() + err := server.initializeRouter() + require.NoError(t, err) assert.False(t, server.enableRESTAPI) rr := httptest.NewRecorder() r, err := http.NewRequest(http.MethodGet, healthzPath, nil) @@ -3538,7 +3591,8 @@ func TestWebAdminSetupWithInstallCode(t *testing.T) { enableWebClient: true, enableRESTAPI: true, } - server.initializeRouter() + err = server.initializeRouter() + require.NoError(t, err) for _, webURL := range []string{"/", webBasePath, webBaseAdminPath, webAdminLoginPath, webClientLoginPath} { rr := httptest.NewRecorder() @@ -3556,10 +3610,10 @@ func TestWebAdminSetupWithInstallCode(t *testing.T) { assert.Equal(t, http.StatusOK, rr.Code) cookie := rr.Header().Get("Set-Cookie") r.Header.Set("Cookie", cookie) - parsedToken, err := jwtauth.VerifyRequest(server.csrfTokenAuth, r, jwtauth.TokenFromCookie) + parsedToken, err := jwt.VerifyRequest(server.csrfTokenAuth, r, jwt.TokenFromCookie) assert.NoError(t, err) ctx := r.Context() - ctx = jwtauth.NewContext(ctx, parsedToken, err) + ctx = jwt.NewContext(ctx, parsedToken, err) r = r.WithContext(ctx) form := make(url.Values) @@ -3624,10 +3678,10 @@ func TestWebAdminSetupWithInstallCode(t *testing.T) { assert.Equal(t, http.StatusOK, rr.Code) cookie = rr.Header().Get("Set-Cookie") r.Header.Set("Cookie", cookie) - parsedToken, err = jwtauth.VerifyRequest(server.csrfTokenAuth, r, jwtauth.TokenFromCookie) + parsedToken, err = jwt.VerifyRequest(server.csrfTokenAuth, r, jwt.TokenFromCookie) assert.NoError(t, err) ctx = r.Context() - ctx = jwtauth.NewContext(ctx, parsedToken, err) + ctx = jwt.NewContext(ctx, parsedToken, err) r = r.WithContext(ctx) form = make(url.Values) @@ -3711,43 +3765,6 @@ func TestDbResetCodeManager(t *testing.T) { } } -func TestDecodeToken(t *testing.T) { - nodeID := "nodeID" - token := map[string]any{ - claimUsernameKey: defaultAdminUsername, - claimPermissionsKey: []string{dataprovider.PermAdminAny}, - jwt.SubjectKey: "", - claimNodeID: nodeID, - claimMustChangePasswordKey: false, - claimMustSetSecondFactorKey: true, - claimRef: "ref", - } - c := jwtTokenClaims{} - c.Decode(token) - assert.Equal(t, defaultAdminUsername, c.Username) - assert.Equal(t, nodeID, c.NodeID) - assert.False(t, c.MustChangePassword) - assert.True(t, c.MustSetTwoFactorAuth) - assert.Equal(t, "ref", c.Ref) - - asMap := c.asMap() - asMap[claimMustChangePasswordKey] = false - assert.Equal(t, token, asMap) - - token[claimMustChangePasswordKey] = 10 - c = jwtTokenClaims{} - c.Decode(token) - assert.False(t, c.MustChangePassword) - - token[claimMustChangePasswordKey] = true - c = jwtTokenClaims{} - c.Decode(token) - assert.True(t, c.MustChangePassword) - - claims := c.asMap() - assert.Equal(t, token, claims) -} - func TestEventRoleFilter(t *testing.T) { defaultVal := "default" req, err := http.NewRequest(http.MethodGet, fsEventsPath+"?role=role1", nil) @@ -3878,7 +3895,8 @@ func TestHTTPSRedirect(t *testing.T) { }, }, } - server.initializeRouter() + err = server.initializeRouter() + require.NoError(t, err) rr := httptest.NewRecorder() r, err := http.NewRequest(http.MethodGet, path.Join(acmeChallengeURI, tokenName), nil) @@ -3931,7 +3949,8 @@ func TestDisabledAdminLoginMethods(t *testing.T) { enableWebClient: true, enableRESTAPI: true, } - server.initializeRouter() + err := server.initializeRouter() + require.NoError(t, err) testServer := httptest.NewServer(server.router) defer testServer.Close() @@ -3986,7 +4005,8 @@ func TestDisabledUserLoginMethods(t *testing.T) { enableWebClient: true, enableRESTAPI: true, } - server.initializeRouter() + err := server.initializeRouter() + require.NoError(t, err) testServer := httptest.NewServer(server.router) defer testServer.Close() diff --git a/internal/httpd/middleware.go b/internal/httpd/middleware.go index 9e5f2d70..199fa89f 100644 --- a/internal/httpd/middleware.go +++ b/internal/httpd/middleware.go @@ -24,12 +24,12 @@ import ( "strings" "time" - "github.com/go-chi/jwtauth/v5" "github.com/rs/xid" "github.com/sftpgo/sdk" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" ) @@ -48,7 +48,7 @@ func (k *contextKey) String() string { } func validateJWTToken(w http.ResponseWriter, r *http.Request, audience tokenAudience) error { - token, _, err := jwtauth.FromContext(r.Context()) + token, err := jwt.FromContext(r.Context()) var redirectPath string if audience == tokenAudienceWebAdmin { @@ -70,7 +70,7 @@ func validateJWTToken(w http.ResponseWriter, r *http.Request, audience tokenAudi } } - if err != nil || token == nil { + if err != nil { logger.Debug(logSender, "", "error getting jwt token: %v", err) doRedirect(http.StatusText(http.StatusUnauthorized), err) return errInvalidToken @@ -82,17 +82,17 @@ func validateJWTToken(w http.ResponseWriter, r *http.Request, audience tokenAudi return errInvalidToken } // a user with a partial token will be always redirected to the appropriate two factor auth page - if err := checkPartialAuth(w, r, audience, token.Audience()); err != nil { + if err := checkPartialAuth(w, r, audience, token.Audience); err != nil { return err } - if !slices.Contains(token.Audience(), audience) { + if !token.Audience.Contains(audience) { logger.Debug(logSender, "", "the token is not valid for audience %q", audience) doRedirect("Your token audience is not valid", nil) return errInvalidToken } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) if err := validateIPForToken(token, ipAddr); err != nil { - logger.Debug(logSender, "", "the token with id %q is not valid for the ip address %q", token.JwtID(), ipAddr) + logger.Debug(logSender, "", "the token with id %q is not valid for the ip address %q", token.ID, ipAddr) doRedirect("Your token is not valid", nil) return err } @@ -104,14 +104,14 @@ func validateJWTToken(w http.ResponseWriter, r *http.Request, audience tokenAudi } func (s *httpdServer) validateJWTPartialToken(w http.ResponseWriter, r *http.Request, audience tokenAudience) error { - token, _, err := jwtauth.FromContext(r.Context()) + token, err := jwt.FromContext(r.Context()) var notFoundFunc func(w http.ResponseWriter, r *http.Request, err error) if audience == tokenAudienceWebAdminPartial { notFoundFunc = s.renderNotFoundPage } else { notFoundFunc = s.renderClientNotFoundPage } - if err != nil || token == nil { + if err != nil { notFoundFunc(w, r, nil) return errInvalidToken } @@ -119,14 +119,14 @@ func (s *httpdServer) validateJWTPartialToken(w http.ResponseWriter, r *http.Req notFoundFunc(w, r, nil) return errInvalidToken } - if !slices.Contains(token.Audience(), audience) { - logger.Debug(logSender, "", "the partial token with id %q is not valid for audience %q", token.JwtID(), audience) + if !token.Audience.Contains(audience) { + logger.Debug(logSender, "", "the partial token with id %q is not valid for audience %q", token.ID, audience) notFoundFunc(w, r, nil) return errInvalidToken } ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) if err := validateIPForToken(token, ipAddr); err != nil { - logger.Debug(logSender, "", "the partial token with id %q is not valid for the ip address %q", token.JwtID(), ipAddr) + logger.Debug(logSender, "", "the partial token with id %q is not valid for the ip address %q", token.ID, ipAddr) notFoundFunc(w, r, nil) return err } @@ -194,7 +194,7 @@ func jwtAuthenticatorWebClient(next http.Handler) http.Handler { func (s *httpdServer) checkHTTPUserPerm(perm string) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, claims, err := jwtauth.FromContext(r.Context()) + claims, err := jwt.FromContext(r.Context()) if err != nil { if isWebRequest(r) { s.renderClientBadRequestPage(w, r, err) @@ -203,10 +203,8 @@ func (s *httpdServer) checkHTTPUserPerm(perm string) func(next http.Handler) htt } return } - tokenClaims := jwtTokenClaims{} - tokenClaims.Decode(claims) // for web client perms are negated and not granted - if tokenClaims.hasPerm(perm) { + if claims.HasPerm(perm) { if isWebRequest(r) { s.renderClientForbiddenPage(w, r, errors.New("you don't have permission for this action")) } else { @@ -223,7 +221,7 @@ func (s *httpdServer) checkHTTPUserPerm(perm string) func(next http.Handler) htt // checkAuthRequirements checks if the user must set a second factor auth or change the password func (s *httpdServer) checkAuthRequirements(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, claims, err := jwtauth.FromContext(r.Context()) + claims, err := jwt.FromContext(r.Context()) if err != nil { if isWebRequest(r) { if isWebClientRequest(r) { @@ -236,13 +234,11 @@ func (s *httpdServer) checkAuthRequirements(next http.Handler) http.Handler { } return } - tokenClaims := jwtTokenClaims{} - tokenClaims.Decode(claims) - if tokenClaims.MustSetTwoFactorAuth || tokenClaims.MustChangePassword { + if claims.MustSetTwoFactorAuth || claims.MustChangePassword { var err error - if tokenClaims.MustSetTwoFactorAuth { - if len(tokenClaims.RequiredTwoFactorProtocols) > 0 { - protocols := strings.Join(tokenClaims.RequiredTwoFactorProtocols, ", ") + if claims.MustSetTwoFactorAuth { + if len(claims.RequiredTwoFactorProtocols) > 0 { + protocols := strings.Join(claims.RequiredTwoFactorProtocols, ", ") err = util.NewI18nError( util.NewGenericError( fmt.Sprintf("Two-factor authentication requirements not met, please configure two-factor authentication for the following protocols: %v", @@ -301,7 +297,7 @@ func (s *httpdServer) requireBuiltinLogin(next http.Handler) http.Handler { func (s *httpdServer) checkPerms(perms ...string) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, claims, err := jwtauth.FromContext(r.Context()) + claims, err := jwt.FromContext(r.Context()) if err != nil { if isWebRequest(r) { s.renderBadRequestPage(w, r, err) @@ -310,11 +306,9 @@ func (s *httpdServer) checkPerms(perms ...string) func(next http.Handler) http.H } return } - tokenClaims := jwtTokenClaims{} - tokenClaims.Decode(claims) for _, perm := range perms { - if !tokenClaims.hasPerm(perm) { + if !claims.HasPerm(perm) { if isWebRequest(r) { s.renderForbiddenPage(w, r, util.NewI18nError(fs.ErrPermission, util.I18nError403Message)) } else { @@ -332,14 +326,14 @@ func (s *httpdServer) checkPerms(perms ...string) func(next http.Handler) http.H func (s *httpdServer) verifyCSRFHeader(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tokenString := r.Header.Get(csrfHeaderToken) - token, err := jwtauth.VerifyToken(s.csrfTokenAuth, tokenString) + token, err := jwt.VerifyToken(s.csrfTokenAuth, tokenString) if err != nil || token == nil { logger.Debug(logSender, "", "error validating CSRF header: %v", err) sendAPIResponse(w, r, err, "Invalid token", http.StatusForbidden) return } - if !slices.Contains(token.Audience(), tokenAudienceCSRF) { + if !token.Audience.Contains(tokenAudienceCSRF) { logger.Debug(logSender, "", "error validating CSRF header token audience") sendAPIResponse(w, r, errors.New("the token is not valid"), "", http.StatusForbidden) return @@ -359,49 +353,52 @@ func (s *httpdServer) verifyCSRFHeader(next http.Handler) http.Handler { }) } -func checkNodeToken(tokenAuth *jwtauth.JWTAuth) func(next http.Handler) http.Handler { +func checkNodeToken(tokenAuth *jwt.Signer) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - token := r.Header.Get(dataprovider.NodeTokenHeader) - if token == "" { + bearer := r.Header.Get(dataprovider.NodeTokenHeader) + if bearer == "" { next.ServeHTTP(w, r) return } - if len(token) > 7 && strings.ToUpper(token[0:6]) == "BEARER" { - token = token[7:] + const prefix = "Bearer " + if len(bearer) >= len(prefix) && strings.EqualFold(bearer[:len(prefix)], prefix) { + bearer = bearer[len(prefix):] } - if invalidatedJWTTokens.Get(token) { + if invalidatedJWTTokens.Get(bearer) { logger.Debug(logSender, "", "the node token has been invalidated") sendAPIResponse(w, r, fmt.Errorf("the provided token is not valid"), "", http.StatusUnauthorized) return } - admin, role, perms, err := dataprovider.AuthenticateNodeToken(token) + claims, err := dataprovider.AuthenticateNodeToken(bearer) if err != nil { - logger.Debug(logSender, "", "unable to authenticate node token %q: %v", token, err) + logger.Debug(logSender, "", "unable to authenticate node token %q: %v", bearer, err) sendAPIResponse(w, r, fmt.Errorf("the provided token cannot be authenticated"), "", http.StatusUnauthorized) return } - defer invalidatedJWTTokens.Add(token, time.Now().Add(2*time.Minute).UTC()) + defer invalidatedJWTTokens.Add(bearer, time.Now().Add(2*time.Minute).UTC()) - c := jwtTokenClaims{ - Username: admin, - Permissions: perms, + c := &jwt.Claims{ + Username: claims.Username, + Permissions: claims.Permissions, NodeID: dataprovider.GetNodeName(), - Role: role, + Role: claims.Role, } - resp, err := c.createTokenResponse(tokenAuth, tokenAudienceAPI, util.GetIPFromRemoteAddress(r.RemoteAddr)) + + token, err := tokenAuth.SignWithParams(c, tokenAudienceAPI, util.GetIPFromRemoteAddress(r.RemoteAddr), getTokenDuration(tokenAudienceAPI)) if err != nil { sendAPIResponse(w, r, err, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } - r.Header.Set("Authorization", fmt.Sprintf("Bearer %v", resp["access_token"])) + resp := c.BuildTokenResponse(token) + r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", resp.Token)) next.ServeHTTP(w, r) }) } } -func checkAPIKeyAuth(tokenAuth *jwtauth.JWTAuth, scope dataprovider.APIKeyScope) func(next http.Handler) http.Handler { +func checkAPIKeyAuth(tokenAuth *jwt.Signer, scope dataprovider.APIKeyScope) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { apiKey := r.Header.Get("X-SFTPGO-API-KEY") @@ -484,7 +481,7 @@ func checkAPIKeyAuth(tokenAuth *jwtauth.JWTAuth, scope dataprovider.APIKeyScope) func forbidAPIKeyAuthentication(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return @@ -498,7 +495,7 @@ func forbidAPIKeyAuthentication(next http.Handler) http.Handler { }) } -func authenticateAdminWithAPIKey(username, keyID string, tokenAuth *jwtauth.JWTAuth, r *http.Request) error { +func authenticateAdminWithAPIKey(username, keyID string, tokenAuth *jwt.Signer, r *http.Request) error { if username == "" { return errors.New("the provided key is not associated with any admin and no username was provided") } @@ -513,25 +510,26 @@ func authenticateAdminWithAPIKey(username, keyID string, tokenAuth *jwtauth.JWTA if err := admin.CanLogin(ipAddr); err != nil { return err } - c := jwtTokenClaims{ + c := &jwt.Claims{ Username: admin.Username, Permissions: admin.Permissions, - Signature: admin.GetSignature(), Role: admin.Role, APIKeyID: keyID, } + c.Subject = admin.GetSignature() - resp, err := c.createTokenResponse(tokenAuth, tokenAudienceAPI, ipAddr) + token, err := tokenAuth.SignWithParams(c, tokenAudienceAPI, ipAddr, getTokenDuration(tokenAudienceAPI)) if err != nil { return err } - r.Header.Set("Authorization", fmt.Sprintf("Bearer %v", resp["access_token"])) + resp := c.BuildTokenResponse(token) + r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", resp.Token)) dataprovider.UpdateAdminLastLogin(&admin) common.DelayLogin(nil) return nil } -func authenticateUserWithAPIKey(username, keyID string, tokenAuth *jwtauth.JWTAuth, r *http.Request) error { +func authenticateUserWithAPIKey(username, keyID string, tokenAuth *jwt.Signer, r *http.Request) error { ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) protocol := common.ProtocolHTTP if username == "" { @@ -569,20 +567,21 @@ func authenticateUserWithAPIKey(username, keyID string, tokenAuth *jwtauth.JWTAu updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure, r) return common.ErrInternalFailure } - c := jwtTokenClaims{ + c := &jwt.Claims{ Username: user.Username, Permissions: user.Filters.WebClient, - Signature: user.GetSignature(), Role: user.Role, APIKeyID: keyID, } + c.Subject = user.GetSignature() - resp, err := c.createTokenResponse(tokenAuth, tokenAudienceAPIUser, ipAddr) + token, err := tokenAuth.SignWithParams(c, tokenAudienceAPIUser, ipAddr, getTokenDuration(tokenAudienceAPIUser)) if err != nil { updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure, r) return err } - r.Header.Set("Authorization", fmt.Sprintf("Bearer %v", resp["access_token"])) + resp := c.BuildTokenResponse(token) + r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", resp.Token)) dataprovider.UpdateLastLogin(&user) updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, nil, r) diff --git a/internal/httpd/oidc.go b/internal/httpd/oidc.go index 8f612cde..c1a1d22b 100644 --- a/internal/httpd/oidc.go +++ b/internal/httpd/oidc.go @@ -31,6 +31,7 @@ import ( "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/httpclient" + "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/util" ) @@ -551,19 +552,20 @@ func (s *httpdServer) oidcTokenAuthenticator(audience tokenAudience) func(next h if err != nil { return } - jwtTokenClaims := jwtTokenClaims{ - JwtID: token.Cookie, + claims := jwt.Claims{ Username: dataprovider.ConvertName(token.Username), Permissions: token.Permissions, Role: token.TokenRole, HideUserPageSections: token.HideUserPageSections, } + claims.ID = token.Cookie if audience == tokenAudienceWebClient { - jwtTokenClaims.MustSetTwoFactorAuth = token.MustSetTwoFactorAuth - jwtTokenClaims.MustChangePassword = token.MustChangePassword - jwtTokenClaims.RequiredTwoFactorProtocols = token.RequiredTwoFactorProtocols + claims.MustSetTwoFactorAuth = token.MustSetTwoFactorAuth + claims.MustChangePassword = token.MustChangePassword + claims.RequiredTwoFactorProtocols = token.RequiredTwoFactorProtocols } - _, tokenString, err := jwtTokenClaims.createToken(s.tokenAuth, audience, util.GetIPFromRemoteAddress(r.RemoteAddr)) + tokenString, err := s.tokenAuth.SignWithParams(&claims, audience, util.GetIPFromRemoteAddress(r.RemoteAddr), + getTokenDuration(audience)) if err != nil { setFlashMessage(w, r, newFlashMessage("Unable to create cookie", util.I18nError500Message)) if audience == tokenAudienceWebAdmin { diff --git a/internal/httpd/oidc_test.go b/internal/httpd/oidc_test.go index cb0a4fc3..bb32e5cf 100644 --- a/internal/httpd/oidc_test.go +++ b/internal/httpd/oidc_test.go @@ -32,7 +32,6 @@ import ( "unsafe" "github.com/coreos/go-oidc/v3/oidc" - "github.com/go-chi/jwtauth/v5" "github.com/rs/xid" "github.com/sftpgo/sdk" "github.com/stretchr/testify/assert" @@ -41,6 +40,7 @@ import ( "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/internal/vfs" @@ -142,7 +142,8 @@ func TestOIDCLoginLogout(t *testing.T) { server := getTestOIDCServer() err := server.binding.OIDC.initialize() assert.NoError(t, err) - server.initializeRouter() + err = server.initializeRouter() + require.NoError(t, err) rr := httptest.NewRecorder() r, err := http.NewRequest(http.MethodGet, webOIDCRedirectPath, nil) @@ -768,7 +769,8 @@ func TestValidateOIDCToken(t *testing.T) { server := getTestOIDCServer() err := server.binding.OIDC.initialize() assert.NoError(t, err) - server.initializeRouter() + err = server.initializeRouter() + require.NoError(t, err) rr := httptest.NewRecorder() r, err := http.NewRequest(http.MethodGet, webClientLogoutPath, nil) @@ -796,7 +798,7 @@ func TestValidateOIDCToken(t *testing.T) { oidcMgr.removeToken(token.Cookie) assert.Len(t, oidcMgr.tokens, 0) - server.tokenAuth = jwtauth.New("PS256", util.GenerateRandomBytes(32), nil) + server.tokenAuth.SetSigner(&failingJoseSigner{}) token = oidcToken{ Cookie: util.GenerateOpaqueString(), AccessToken: util.GenerateUniqueID(), @@ -833,11 +835,12 @@ func TestSkipOIDCAuth(t *testing.T) { server := getTestOIDCServer() err := server.binding.OIDC.initialize() assert.NoError(t, err) - server.initializeRouter() - jwtTokenClaims := jwtTokenClaims{ - Username: "user", - } - _, tokenString, err := jwtTokenClaims.createToken(server.tokenAuth, tokenAudienceWebClient, "") + err = server.initializeRouter() + require.NoError(t, err) + + claims := jwt.NewClaims(tokenAudienceWebClient, "", getTokenDuration(tokenAudienceWebClient)) + claims.Username = "user" + tokenString, err := server.tokenAuth.Sign(claims) assert.NoError(t, err) rr := httptest.NewRecorder() r, err := http.NewRequest(http.MethodGet, webClientLogoutPath, nil) @@ -968,7 +971,8 @@ func TestOIDCImplicitRoles(t *testing.T) { server.binding.OIDC.ImplicitRoles = true err := server.binding.OIDC.initialize() assert.NoError(t, err) - server.initializeRouter() + err = server.initializeRouter() + require.NoError(t, err) authReq := newOIDCPendingAuth(tokenAudienceWebAdmin) oidcMgr.addPendingAuth(authReq) @@ -1241,7 +1245,8 @@ func TestOIDCEvMgrIntegration(t *testing.T) { server.binding.OIDC.CustomFields = []string{"custom1.sub", "custom2"} err = server.binding.OIDC.initialize() assert.NoError(t, err) - server.initializeRouter() + err = server.initializeRouter() + require.NoError(t, err) // login a user with OIDC _, err = dataprovider.UserExists(username, "") assert.ErrorIs(t, err, util.ErrNotFound) @@ -1378,7 +1383,8 @@ func TestOIDCPreLoginHook(t *testing.T) { server.binding.OIDC.CustomFields = []string{"field1", "field2"} err = server.binding.OIDC.initialize() assert.NoError(t, err) - server.initializeRouter() + err = server.initializeRouter() + require.NoError(t, err) _, err = dataprovider.UserExists(username, "") assert.ErrorIs(t, err, util.ErrNotFound) @@ -1554,7 +1560,8 @@ func TestOIDCWithLoginFormsDisabled(t *testing.T) { server.binding.EnableWebClient = true err := server.binding.OIDC.initialize() assert.NoError(t, err) - server.initializeRouter() + err = server.initializeRouter() + require.NoError(t, err) // login with an admin user authReq := newOIDCPendingAuth(tokenAudienceWebAdmin) oidcMgr.addPendingAuth(authReq) diff --git a/internal/httpd/server.go b/internal/httpd/server.go index c572307c..14e1a1df 100644 --- a/internal/httpd/server.go +++ b/internal/httpd/server.go @@ -16,6 +16,7 @@ package httpd import ( "context" + "crypto/rand" "crypto/tls" "crypto/x509" "errors" @@ -32,9 +33,8 @@ import ( "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" - "github.com/go-chi/jwtauth/v5" "github.com/go-chi/render" - "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/go-jose/go-jose/v4" "github.com/rs/cors" "github.com/rs/xid" "github.com/sftpgo/sdk" @@ -43,6 +43,7 @@ import ( "github.com/drakkan/sftpgo/v2/internal/acme" "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/mfa" "github.com/drakkan/sftpgo/v2/internal/smtp" @@ -69,8 +70,8 @@ type httpdServer struct { renderOpenAPI bool isShared int router *chi.Mux - tokenAuth *jwtauth.JWTAuth - csrfTokenAuth *jwtauth.JWTAuth + tokenAuth *jwt.Signer + csrfTokenAuth *jwt.Signer signingPassphrase string cors CorsConfig } @@ -99,7 +100,9 @@ func (s *httpdServer) setShared(value int) { } func (s *httpdServer) listenAndServe() error { - s.initializeRouter() + if err := s.initializeRouter(); err != nil { + return err + } httpServer := &http.Server{ Handler: s.router, ReadHeaderTimeout: 30 * time.Second, @@ -173,7 +176,7 @@ func (s *httpdServer) renderClientLoginPage(w http.ResponseWriter, r *http.Reque Title: util.I18nLoginTitle, CurrentURL: webClientLoginPath, Error: err, - CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseClientPath), + CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, rand.Text(), webBaseClientPath), Branding: s.binding.webClientBranding(), Languages: s.binding.languages(), FormDisabled: s.binding.isWebClientLoginFormDisabled(), @@ -327,7 +330,7 @@ func (s *httpdServer) handleWebClientPasswordResetPost(w http.ResponseWriter, r func (s *httpdServer) handleWebClientTwoFactorRecoveryPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil { s.renderNotFoundPage(w, r, nil) return @@ -393,7 +396,7 @@ func (s *httpdServer) handleWebClientTwoFactorRecoveryPost(w http.ResponseWriter func (s *httpdServer) handleWebClientTwoFactorPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil { s.renderNotFoundPage(w, r, nil) return @@ -451,7 +454,7 @@ func (s *httpdServer) handleWebClientTwoFactorPost(w http.ResponseWriter, r *htt func (s *httpdServer) handleWebAdminTwoFactorRecoveryPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil { s.renderNotFoundPage(w, r, nil) return @@ -511,7 +514,7 @@ func (s *httpdServer) handleWebAdminTwoFactorRecoveryPost(w http.ResponseWriter, func (s *httpdServer) handleWebAdminTwoFactorPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil { s.renderNotFoundPage(w, r, nil) return @@ -592,7 +595,7 @@ func (s *httpdServer) renderAdminLoginPage(w http.ResponseWriter, r *http.Reques Title: util.I18nLoginTitle, CurrentURL: webAdminLoginPath, Error: err, - CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseAdminPath), + CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, rand.Text(), webBaseAdminPath), Branding: s.binding.webAdminBranding(), Languages: s.binding.languages(), FormDisabled: s.binding.isWebAdminLoginFormDisabled(), @@ -735,15 +738,15 @@ func (s *httpdServer) loginUser( w http.ResponseWriter, r *http.Request, user *dataprovider.User, connectionID, ipAddr string, isSecondFactorAuth bool, errorFunc func(w http.ResponseWriter, r *http.Request, err *util.I18nError), ) { - c := jwtTokenClaims{ + c := &jwt.Claims{ Username: user.Username, Permissions: user.Filters.WebClient, - Signature: user.GetSignature(), Role: user.Role, MustSetTwoFactorAuth: user.MustSetSecondFactor(), MustChangePassword: user.MustChangePassword(), RequiredTwoFactorProtocols: user.Filters.TwoFactorAuthProtocols, } + c.Subject = user.GetSignature() audience := tokenAudienceWebClient if user.Filters.TOTPConfig.Enabled && slices.Contains(user.Filters.TOTPConfig.Protocols, common.ProtocolHTTP) && @@ -751,7 +754,7 @@ func (s *httpdServer) loginUser( audience = tokenAudienceWebClientPartial } - err := c.createAndSetCookie(w, r, s.tokenAuth, audience, ipAddr) + err := createAndSetCookie(w, r, c, s.tokenAuth, audience, ipAddr) if err != nil { logger.Warn(logSender, connectionID, "unable to set user login cookie %v", err) updateLoginMetrics(user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure, r) @@ -781,22 +784,22 @@ func (s *httpdServer) loginAdmin( isSecondFactorAuth bool, errorFunc func(w http.ResponseWriter, r *http.Request, err *util.I18nError), ipAddr string, ) { - c := jwtTokenClaims{ + c := &jwt.Claims{ Username: admin.Username, Permissions: admin.Permissions, Role: admin.Role, - Signature: admin.GetSignature(), HideUserPageSections: admin.Filters.Preferences.HideUserPageSections, MustSetTwoFactorAuth: admin.Filters.RequireTwoFactor && !admin.Filters.TOTPConfig.Enabled, MustChangePassword: admin.Filters.RequirePasswordChange, } + c.Subject = admin.GetSignature() audience := tokenAudienceWebAdmin if admin.Filters.TOTPConfig.Enabled && admin.CanManageMFA() && !isSecondFactorAuth { audience = tokenAudienceWebAdminPartial } - err := c.createAndSetCookie(w, r, s.tokenAuth, audience, ipAddr) + err := createAndSetCookie(w, r, c, s.tokenAuth, audience, ipAddr) if err != nil { logger.Warn(logSender, "", "unable to set admin login cookie %v", err) if errorFunc == nil { @@ -907,17 +910,17 @@ func (s *httpdServer) getUserToken(w http.ResponseWriter, r *http.Request) { } func (s *httpdServer) generateAndSendUserToken(w http.ResponseWriter, r *http.Request, ipAddr string, user dataprovider.User) { - c := jwtTokenClaims{ + c := &jwt.Claims{ Username: user.Username, Permissions: user.Filters.WebClient, - Signature: user.GetSignature(), Role: user.Role, MustSetTwoFactorAuth: user.MustSetSecondFactor(), MustChangePassword: user.MustChangePassword(), RequiredTwoFactorProtocols: user.Filters.TwoFactorAuthProtocols, } + c.Subject = user.GetSignature() - resp, err := c.createTokenResponse(s.tokenAuth, tokenAudienceAPIUser, ipAddr) + token, err := s.tokenAuth.SignWithParams(c, tokenAudienceAPIUser, ipAddr, getTokenDuration(tokenAudienceAPIUser)) if err != nil { updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure, r) sendAPIResponse(w, r, err, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) @@ -926,7 +929,7 @@ func (s *httpdServer) generateAndSendUserToken(w http.ResponseWriter, r *http.Re updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err, r) dataprovider.UpdateLastLogin(&user) - render.JSON(w, r, resp) + render.JSON(w, r, c.BuildTokenResponse(token)) } func (s *httpdServer) getToken(w http.ResponseWriter, r *http.Request) { @@ -976,17 +979,16 @@ func (s *httpdServer) getToken(w http.ResponseWriter, r *http.Request) { } func (s *httpdServer) generateAndSendToken(w http.ResponseWriter, r *http.Request, admin dataprovider.Admin, ip string) { - c := jwtTokenClaims{ + c := &jwt.Claims{ Username: admin.Username, Permissions: admin.Permissions, Role: admin.Role, - Signature: admin.GetSignature(), MustSetTwoFactorAuth: admin.Filters.RequireTwoFactor && !admin.Filters.TOTPConfig.Enabled, MustChangePassword: admin.Filters.RequirePasswordChange, } + c.Subject = admin.GetSignature() - resp, err := c.createTokenResponse(s.tokenAuth, tokenAudienceAPI, ip) - + token, err := s.tokenAuth.SignWithParams(c, tokenAudienceAPI, ip, getTokenDuration(tokenAudienceAPI)) if err != nil { sendAPIResponse(w, r, err, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return @@ -994,42 +996,39 @@ func (s *httpdServer) generateAndSendToken(w http.ResponseWriter, r *http.Reques dataprovider.UpdateAdminLastLogin(&admin) common.DelayLogin(nil) - render.JSON(w, r, resp) + render.JSON(w, r, c.BuildTokenResponse(token)) } func (s *httpdServer) checkCookieExpiration(w http.ResponseWriter, r *http.Request) { if _, ok := r.Context().Value(oidcTokenKey).(string); ok { return } - token, claims, err := jwtauth.FromContext(r.Context()) - if err != nil || token == nil { + claims, err := jwt.FromContext(r.Context()) + if err != nil { return } - tokenClaims := jwtTokenClaims{} - tokenClaims.Decode(claims) - if tokenClaims.Username == "" || tokenClaims.Signature == "" { + if claims.Username == "" || claims.Subject == "" { return } - if time.Until(token.Expiration()) > cookieRefreshThreshold { + if time.Until(claims.Expiry.Time()) > cookieRefreshThreshold { return } - if (time.Since(token.IssuedAt()) + cookieTokenDuration) > maxTokenDuration { + if (time.Since(claims.IssuedAt.Time()) + cookieTokenDuration) > maxTokenDuration { return } - tokenClaims.JwtIssuedAt = token.IssuedAt() - if slices.Contains(token.Audience(), tokenAudienceWebClient) { - s.refreshClientToken(w, r, &tokenClaims) + if claims.Audience.Contains(tokenAudienceWebClient) { + s.refreshClientToken(w, r, claims) } else { - s.refreshAdminToken(w, r, &tokenClaims) + s.refreshAdminToken(w, r, claims) } } -func (s *httpdServer) refreshClientToken(w http.ResponseWriter, r *http.Request, tokenClaims *jwtTokenClaims) { +func (s *httpdServer) refreshClientToken(w http.ResponseWriter, r *http.Request, tokenClaims *jwt.Claims) { user, err := dataprovider.GetUserWithGroupSettings(tokenClaims.Username, "") if err != nil { return } - if user.GetSignature() != tokenClaims.Signature { + if user.GetSignature() != tokenClaims.Subject { logger.Debug(logSender, "", "signature mismatch for user %q, unable to refresh cookie", user.Username) return } @@ -1045,15 +1044,15 @@ func (s *httpdServer) refreshClientToken(w http.ResponseWriter, r *http.Request, tokenClaims.Permissions = user.Filters.WebClient tokenClaims.Role = user.Role logger.Debug(logSender, "", "cookie refreshed for user %q", user.Username) - tokenClaims.createAndSetCookie(w, r, s.tokenAuth, tokenAudienceWebClient, util.GetIPFromRemoteAddress(r.RemoteAddr)) //nolint:errcheck + createAndSetCookie(w, r, tokenClaims, s.tokenAuth, tokenAudienceWebClient, util.GetIPFromRemoteAddress(r.RemoteAddr)) //nolint:errcheck } -func (s *httpdServer) refreshAdminToken(w http.ResponseWriter, r *http.Request, tokenClaims *jwtTokenClaims) { +func (s *httpdServer) refreshAdminToken(w http.ResponseWriter, r *http.Request, tokenClaims *jwt.Claims) { admin, err := dataprovider.AdminExists(tokenClaims.Username) if err != nil { return } - if admin.GetSignature() != tokenClaims.Signature { + if admin.GetSignature() != tokenClaims.Subject { logger.Debug(logSender, "", "signature mismatch for admin %q, unable to refresh cookie", admin.Username) return } @@ -1066,18 +1065,18 @@ func (s *httpdServer) refreshAdminToken(w http.ResponseWriter, r *http.Request, tokenClaims.Role = admin.Role tokenClaims.HideUserPageSections = admin.Filters.Preferences.HideUserPageSections logger.Debug(logSender, "", "cookie refreshed for admin %q", admin.Username) - tokenClaims.createAndSetCookie(w, r, s.tokenAuth, tokenAudienceWebAdmin, ipAddr) //nolint:errcheck + createAndSetCookie(w, r, tokenClaims, s.tokenAuth, tokenAudienceWebAdmin, ipAddr) //nolint:errcheck } func (s *httpdServer) updateContextFromCookie(r *http.Request) *http.Request { - token, _, err := jwtauth.FromContext(r.Context()) - if token == nil || err != nil { + _, err := jwt.FromContext(r.Context()) + if err != nil { _, err = r.Cookie(jwtCookieKey) if err != nil { return r } - token, err = jwtauth.VerifyRequest(s.tokenAuth, r, jwtauth.TokenFromCookie) - ctx := jwtauth.NewContext(r.Context(), token, err) + token, err := jwt.VerifyRequest(s.tokenAuth, r, jwt.TokenFromCookie) + ctx := jwt.NewContext(r.Context(), token, err) return r.WithContext(ctx) } return r @@ -1235,10 +1234,18 @@ func (s *httpdServer) mustCheckPath(r *http.Request) bool { return !strings.HasPrefix(urlPath, webStaticFilesPath) && !strings.HasPrefix(urlPath, acmeChallengeURI) } -func (s *httpdServer) initializeRouter() { +func (s *httpdServer) initializeRouter() error { + signer, err := jwt.NewSigner(jose.HS256, getSigningKey(s.signingPassphrase)) + if err != nil { + return err + } + csrfSigner, err := jwt.NewSigner(jose.HS256, getSigningKey(s.signingPassphrase)) + if err != nil { + return err + } var hasHTTPSRedirect bool - s.tokenAuth = jwtauth.New(jwa.HS256.String(), getSigningKey(s.signingPassphrase), nil) - s.csrfTokenAuth = jwtauth.New(jwa.HS256.String(), getSigningKey(s.signingPassphrase), nil) + s.tokenAuth = signer + s.csrfTokenAuth = csrfSigner s.router = chi.NewRouter() s.router.Use(middleware.RequestID) @@ -1336,6 +1343,7 @@ func (s *httpdServer) initializeRouter() { s.setupWebClientRoutes() s.setupWebAdminRoutes() + return nil } func (s *httpdServer) setupRESTAPIRoutes() { @@ -1351,7 +1359,7 @@ func (s *httpdServer) setupRESTAPIRoutes() { if !s.binding.isAdminAPIKeyAuthDisabled() { router.Use(checkAPIKeyAuth(s.tokenAuth, dataprovider.APIKeyScopeAdmin)) } - router.Use(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromHeader)) + router.Use(jwt.Verify(s.tokenAuth, jwt.TokenFromHeader)) router.Use(jwtAuthenticatorAPI) router.Get(versionPath, func(w http.ResponseWriter, r *http.Request) { @@ -1480,7 +1488,7 @@ func (s *httpdServer) setupRESTAPIRoutes() { if !s.binding.isUserAPIKeyAuthDisabled() { router.Use(checkAPIKeyAuth(s.tokenAuth, dataprovider.APIKeyScopeUser)) } - router.Use(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromHeader)) + router.Use(jwt.Verify(s.tokenAuth, jwt.TokenFromHeader)) router.Use(jwtAuthenticatorAPIUser) router.With(forbidAPIKeyAuthentication).Get(userLogoutPath, s.logout) @@ -1568,31 +1576,31 @@ func (s *httpdServer) setupWebClientRoutes() { s.router.Get(webClientOIDCLoginPath, s.handleWebClientOIDCLogin) } if !s.binding.isWebClientLoginFormDisabled() { - s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)). + s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)). Post(webClientLoginPath, s.handleWebClientLoginPost) s.router.Get(webClientForgotPwdPath, s.handleWebClientForgotPwd) - s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)). + s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)). Post(webClientForgotPwdPath, s.handleWebClientForgotPwdPost) - s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)). + s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)). Get(webClientResetPwdPath, s.handleWebClientPasswordReset) - s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)). + s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)). Post(webClientResetPwdPath, s.handleWebClientPasswordResetPost) - s.router.With(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie), + s.router.With(jwt.Verify(s.tokenAuth, jwt.TokenFromCookie), s.jwtAuthenticatorPartial(tokenAudienceWebClientPartial)). Get(webClientTwoFactorPath, s.handleWebClientTwoFactor) - s.router.With(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie), + s.router.With(jwt.Verify(s.tokenAuth, jwt.TokenFromCookie), s.jwtAuthenticatorPartial(tokenAudienceWebClientPartial)). Post(webClientTwoFactorPath, s.handleWebClientTwoFactorPost) - s.router.With(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie), + s.router.With(jwt.Verify(s.tokenAuth, jwt.TokenFromCookie), s.jwtAuthenticatorPartial(tokenAudienceWebClientPartial)). Get(webClientTwoFactorRecoveryPath, s.handleWebClientTwoFactorRecovery) - s.router.With(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie), + s.router.With(jwt.Verify(s.tokenAuth, jwt.TokenFromCookie), s.jwtAuthenticatorPartial(tokenAudienceWebClientPartial)). Post(webClientTwoFactorRecoveryPath, s.handleWebClientTwoFactorRecoveryPost) } // share routes available to external users s.router.Get(webClientPubSharesPath+"/{id}/login", s.handleClientShareLoginGet) - s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)). + s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)). Post(webClientPubSharesPath+"/{id}/login", s.handleClientShareLoginPost) s.router.Get(webClientPubSharesPath+"/{id}/logout", s.handleClientShareLogout) s.router.Get(webClientPubSharesPath+"/{id}", s.downloadFromShare) @@ -1611,7 +1619,7 @@ func (s *httpdServer) setupWebClientRoutes() { if s.binding.OIDC.isEnabled() { router.Use(s.oidcTokenAuthenticator(tokenAudienceWebClient)) } - router.Use(jwtauth.Verify(s.tokenAuth, oidcTokenFromContext, jwtauth.TokenFromCookie)) + router.Use(jwt.Verify(s.tokenAuth, oidcTokenFromContext, jwt.TokenFromCookie)) router.Use(jwtAuthenticatorWebClient) router.Get(webClientLogoutPath, s.handleWebClientLogout) @@ -1702,29 +1710,29 @@ func (s *httpdServer) setupWebAdminRoutes() { } s.router.Get(webOAuth2RedirectPath, s.handleOAuth2TokenRedirect) s.router.Get(webAdminSetupPath, s.handleWebAdminSetupGet) - s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)). + s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)). Post(webAdminSetupPath, s.handleWebAdminSetupPost) if !s.binding.isWebAdminLoginFormDisabled() { - s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)). + s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)). Post(webAdminLoginPath, s.handleWebAdminLoginPost) - s.router.With(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie), + s.router.With(jwt.Verify(s.tokenAuth, jwt.TokenFromCookie), s.jwtAuthenticatorPartial(tokenAudienceWebAdminPartial)). Get(webAdminTwoFactorPath, s.handleWebAdminTwoFactor) - s.router.With(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie), + s.router.With(jwt.Verify(s.tokenAuth, jwt.TokenFromCookie), s.jwtAuthenticatorPartial(tokenAudienceWebAdminPartial)). Post(webAdminTwoFactorPath, s.handleWebAdminTwoFactorPost) - s.router.With(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie), + s.router.With(jwt.Verify(s.tokenAuth, jwt.TokenFromCookie), s.jwtAuthenticatorPartial(tokenAudienceWebAdminPartial)). Get(webAdminTwoFactorRecoveryPath, s.handleWebAdminTwoFactorRecovery) - s.router.With(jwtauth.Verify(s.tokenAuth, jwtauth.TokenFromCookie), + s.router.With(jwt.Verify(s.tokenAuth, jwt.TokenFromCookie), s.jwtAuthenticatorPartial(tokenAudienceWebAdminPartial)). Post(webAdminTwoFactorRecoveryPath, s.handleWebAdminTwoFactorRecoveryPost) s.router.Get(webAdminForgotPwdPath, s.handleWebAdminForgotPwd) - s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)). + s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)). Post(webAdminForgotPwdPath, s.handleWebAdminForgotPwdPost) - s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)). + s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)). Get(webAdminResetPwdPath, s.handleWebAdminPasswordReset) - s.router.With(jwtauth.Verify(s.csrfTokenAuth, jwtauth.TokenFromCookie)). + s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)). Post(webAdminResetPwdPath, s.handleWebAdminPasswordResetPost) } @@ -1732,7 +1740,7 @@ func (s *httpdServer) setupWebAdminRoutes() { if s.binding.OIDC.isEnabled() { router.Use(s.oidcTokenAuthenticator(tokenAudienceWebAdmin)) } - router.Use(jwtauth.Verify(s.tokenAuth, oidcTokenFromContext, jwtauth.TokenFromCookie)) + router.Use(jwt.Verify(s.tokenAuth, oidcTokenFromContext, jwt.TokenFromCookie)) router.Use(jwtAuthenticatorWebAdmin) router.Get(webLogoutPath, s.handleWebAdminLogout) diff --git a/internal/httpd/webadmin.go b/internal/httpd/webadmin.go index 38820cc8..b93cbb6b 100644 --- a/internal/httpd/webadmin.go +++ b/internal/httpd/webadmin.go @@ -16,6 +16,7 @@ package httpd import ( "context" + "crypto/rand" "encoding/json" "errors" "fmt" @@ -31,7 +32,6 @@ import ( "strings" "time" - "github.com/rs/xid" "github.com/sftpgo/sdk" sdkkms "github.com/sftpgo/sdk/kms" @@ -39,6 +39,7 @@ import ( "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/internal/ftpd" + "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/mfa" @@ -726,7 +727,7 @@ func (s *httpdServer) renderForgotPwdPage(w http.ResponseWriter, r *http.Request commonBasePage: getCommonBasePage(r), CurrentURL: webAdminForgotPwdPath, Error: err, - CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseAdminPath), + CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, rand.Text(), webBaseAdminPath), LoginURL: webAdminLoginPath, Title: util.I18nForgotPwdTitle, Branding: s.binding.webAdminBranding(), @@ -863,7 +864,7 @@ func (s *httpdServer) renderAdminSetupPage(w http.ResponseWriter, r *http.Reques commonBasePage: getCommonBasePage(r), Title: util.I18nSetupTitle, CurrentURL: webAdminSetupPath, - CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseAdminPath), + CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, rand.Text(), webBaseAdminPath), Username: username, HasInstallationCode: installationCode != "", InstallationCodeHint: installationCodeHint, @@ -2964,7 +2965,7 @@ func (s *httpdServer) handleWebAdminProfilePost(w http.ResponseWriter, r *http.R s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderProfilePage(w, r, util.NewI18nError(err, util.I18nErrorInvalidToken)) return @@ -2992,7 +2993,7 @@ func (s *httpdServer) handleWebMaintenance(w http.ResponseWriter, r *http.Reques func (s *httpdServer) handleWebRestore(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, MaxRestoreSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return @@ -3045,7 +3046,7 @@ func (s *httpdServer) handleWebRestore(w http.ResponseWriter, r *http.Request) { func getAllAdmins(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, nil, util.I18nErrorInvalidToken, http.StatusForbidden) return @@ -3103,7 +3104,7 @@ func (s *httpdServer) handleWebUpdateAdminGet(w http.ResponseWriter, r *http.Req func (s *httpdServer) handleWebAddAdminPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return @@ -3163,7 +3164,7 @@ func (s *httpdServer) handleWebUpdateAdminPost(w http.ResponseWriter, r *http.Re } updatedAdmin.Filters.TOTPConfig = admin.Filters.TOTPConfig updatedAdmin.Filters.RecoveryCodes = admin.Filters.RecoveryCodes - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderAddUpdateAdminPage(w, r, &updatedAdmin, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken), false) return @@ -3214,7 +3215,7 @@ func (s *httpdServer) handleWebDefenderPage(w http.ResponseWriter, r *http.Reque func getAllUsers(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, nil, util.I18nErrorInvalidToken, http.StatusForbidden) return @@ -3234,7 +3235,7 @@ func getAllUsers(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleGetWebUsers(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return @@ -3264,7 +3265,7 @@ func (s *httpdServer) handleWebTemplateFolderGet(w http.ResponseWriter, r *http. func (s *httpdServer) handleWebTemplateFolderPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return @@ -3361,7 +3362,7 @@ func (s *httpdServer) handleWebTemplateUserGet(w http.ResponseWriter, r *http.Re func (s *httpdServer) handleWebTemplateUserPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return @@ -3429,7 +3430,7 @@ func (s *httpdServer) handleWebAddUserGet(w http.ResponseWriter, r *http.Request func (s *httpdServer) handleWebUpdateUserGet(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return @@ -3447,7 +3448,7 @@ func (s *httpdServer) handleWebUpdateUserGet(w http.ResponseWriter, r *http.Requ func (s *httpdServer) handleWebAddUserPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return @@ -3485,7 +3486,7 @@ func (s *httpdServer) handleWebAddUserPost(w http.ResponseWriter, r *http.Reques func (s *httpdServer) handleWebUpdateUserPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return @@ -3552,7 +3553,7 @@ func (s *httpdServer) handleWebGetStatus(w http.ResponseWriter, r *http.Request) func (s *httpdServer) handleWebGetConnections(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return @@ -3569,7 +3570,7 @@ func (s *httpdServer) handleWebAddFolderGet(w http.ResponseWriter, r *http.Reque func (s *httpdServer) handleWebAddFolderPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return @@ -3621,7 +3622,7 @@ func (s *httpdServer) handleWebUpdateFolderGet(w http.ResponseWriter, r *http.Re func (s *httpdServer) handleWebUpdateFolderPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return @@ -3756,7 +3757,7 @@ func (s *httpdServer) handleWebAddGroupGet(w http.ResponseWriter, r *http.Reques func (s *httpdServer) handleWebAddGroupPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return @@ -3794,7 +3795,7 @@ func (s *httpdServer) handleWebUpdateGroupGet(w http.ResponseWriter, r *http.Req func (s *httpdServer) handleWebUpdateGroupPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return @@ -3881,7 +3882,7 @@ func (s *httpdServer) handleWebAddEventActionGet(w http.ResponseWriter, r *http. func (s *httpdServer) handleWebAddEventActionPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return @@ -3918,7 +3919,7 @@ func (s *httpdServer) handleWebUpdateEventActionGet(w http.ResponseWriter, r *ht func (s *httpdServer) handleWebUpdateEventActionPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return @@ -3992,7 +3993,7 @@ func (s *httpdServer) handleWebAddEventRuleGet(w http.ResponseWriter, r *http.Re func (s *httpdServer) handleWebAddEventRulePost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return @@ -4030,7 +4031,7 @@ func (s *httpdServer) handleWebUpdateEventRuleGet(w http.ResponseWriter, r *http func (s *httpdServer) handleWebUpdateEventRulePost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return @@ -4114,7 +4115,7 @@ func (s *httpdServer) handleWebAddRolePost(w http.ResponseWriter, r *http.Reques s.renderRolePage(w, r, role, genericPageModeAdd, err) return } - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return @@ -4146,7 +4147,7 @@ func (s *httpdServer) handleWebUpdateRoleGet(w http.ResponseWriter, r *http.Requ func (s *httpdServer) handleWebUpdateRolePost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return @@ -4228,7 +4229,7 @@ func (s *httpdServer) handleWebAddIPListEntryPost(w http.ResponseWriter, r *http return } entry.Type = listType - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return @@ -4265,7 +4266,7 @@ func (s *httpdServer) handleWebUpdateIPListEntryGet(w http.ResponseWriter, r *ht func (s *httpdServer) handleWebUpdateIPListEntryPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return @@ -4315,7 +4316,7 @@ func (s *httpdServer) handleWebConfigs(w http.ResponseWriter, r *http.Request) { func (s *httpdServer) handleWebConfigsPost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return diff --git a/internal/httpd/webclient.go b/internal/httpd/webclient.go index 23ab9cbf..fa44e0e5 100644 --- a/internal/httpd/webclient.go +++ b/internal/httpd/webclient.go @@ -16,6 +16,7 @@ package httpd import ( "bytes" + "crypto/rand" "encoding/json" "errors" "fmt" @@ -38,6 +39,7 @@ import ( "github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" "github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/internal/mfa" "github.com/drakkan/sftpgo/v2/internal/smtp" @@ -568,7 +570,7 @@ func (s *httpdServer) renderClientForgotPwdPage(w http.ResponseWriter, r *http.R commonBasePage: getCommonBasePage(r), CurrentURL: webClientForgotPwdPath, Error: err, - CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseClientPath), + CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, rand.Text(), webBaseClientPath), LoginURL: webClientLoginPath, Title: util.I18nForgotPwdTitle, Branding: s.binding.webClientBranding(), @@ -597,7 +599,7 @@ func (s *httpdServer) renderShareLoginPage(w http.ResponseWriter, r *http.Reques Title: util.I18nShareLoginTitle, CurrentURL: r.RequestURI, Error: err, - CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, xid.New().String(), webBaseClientPath), + CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, rand.Text(), webBaseClientPath), Branding: s.binding.webClientBranding(), Languages: s.binding.languages(), CheckRedirect: false, @@ -878,7 +880,7 @@ func (s *httpdServer) renderClientChangePasswordPage(w http.ResponseWriter, r *h func (s *httpdServer) handleWebClientDownloadZip(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxMultipartMem) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return @@ -1175,7 +1177,7 @@ func (s *httpdServer) handleShareGetPDF(w http.ResponseWriter, r *http.Request) func (s *httpdServer) handleClientGetDirContents(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, nil, util.I18nErrorDirList403, http.StatusForbidden) return @@ -1261,7 +1263,7 @@ func (s *httpdServer) handleClientGetDirContents(w http.ResponseWriter, r *http. func (s *httpdServer) handleClientGetFiles(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return @@ -1319,7 +1321,7 @@ func (s *httpdServer) handleClientGetFiles(w http.ResponseWriter, r *http.Reques func (s *httpdServer) handleClientEditFile(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return @@ -1395,7 +1397,7 @@ func (s *httpdServer) handleClientEditFile(w http.ResponseWriter, r *http.Reques func (s *httpdServer) handleClientAddShareGet(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return @@ -1437,7 +1439,7 @@ func (s *httpdServer) handleClientAddShareGet(w http.ResponseWriter, r *http.Req func (s *httpdServer) handleClientUpdateShareGet(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return @@ -1455,7 +1457,7 @@ func (s *httpdServer) handleClientUpdateShareGet(w http.ResponseWriter, r *http. func (s *httpdServer) handleClientAddSharePost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return @@ -1514,7 +1516,7 @@ func (s *httpdServer) handleClientAddSharePost(w http.ResponseWriter, r *http.Re func (s *httpdServer) handleClientUpdateSharePost(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return @@ -1583,7 +1585,7 @@ func (s *httpdServer) handleClientUpdateSharePost(w http.ResponseWriter, r *http func getAllShares(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, nil, util.I18nErrorInvalidToken, http.StatusForbidden) return @@ -1633,7 +1635,7 @@ func (s *httpdServer) handleWebClientProfilePost(w http.ResponseWriter, r *http. s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) return } - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return @@ -1807,7 +1809,7 @@ func (s *httpdServer) handleClientViewPDF(w http.ResponseWriter, r *http.Request func (s *httpdServer) handleClientGetPDF(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) return @@ -1914,13 +1916,13 @@ func (s *httpdServer) handleClientShareLoginPost(w http.ResponseWriter, r *http. next := path.Clean(r.URL.Query().Get("next")) baseShareURL := path.Join(webClientPubSharesPath, share.ShareID) isRedirect, redirectTo := checkShareRedirectURL(next, baseShareURL) - c := jwtTokenClaims{ + c := &jwt.Claims{ Username: shareID, } if isRedirect { c.Ref = next } - err = c.createAndSetCookie(w, r, s.tokenAuth, tokenAudienceWebShare, ipAddr) + err = createAndSetCookie(w, r, c, s.tokenAuth, tokenAudienceWebShare, ipAddr) if err != nil { s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nError500Message)) return @@ -2082,7 +2084,7 @@ func checkShareRedirectURL(next, base string) (bool, string) { func getWebTask(w http.ResponseWriter, r *http.Request) { r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) - claims, err := getTokenClaims(r) + claims, err := jwt.FromContext(r.Context()) if err != nil || claims.Username == "" { sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) return diff --git a/internal/jwt/jwt.go b/internal/jwt/jwt.go new file mode 100644 index 00000000..48079dc4 --- /dev/null +++ b/internal/jwt/jwt.go @@ -0,0 +1,264 @@ +// Copyright (C) 2025 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +// Package jwt provides functionality for creating, parsing, and validating +// JSON Web Tokens (JWT) used in authentication and authorization workflows. +package jwt + +import ( + "context" + "errors" + "fmt" + "net/http" + "slices" + "strings" + "time" + + "github.com/go-jose/go-jose/v4" + "github.com/go-jose/go-jose/v4/jwt" + "github.com/rs/xid" +) + +var ( + TokenCtxKey = &contextKey{"Token"} + ErrorCtxKey = &contextKey{"Error"} +) + +// contextKey is a value for use with context.WithValue. It's used as +// a pointer so it fits in an interface{} without allocation. This technique +// for defining context keys was copied from Go 1.7's new use of context in net/http. +type contextKey struct { + name string +} + +func (k *contextKey) String() string { + return "jwt context value " + k.name +} + +func NewClaims(audience, ip string, duration time.Duration) *Claims { + now := time.Now() + claims := &Claims{} + claims.IssuedAt = jwt.NewNumericDate(now) + claims.NotBefore = jwt.NewNumericDate(now.Add(-10 * time.Second)) + claims.Expiry = jwt.NewNumericDate(now.Add(duration)) + claims.Audience = []string{audience, ip} + return claims +} + +type Claims struct { + jwt.Claims + Username string `json:"username,omitempty"` + Permissions []string `json:"permissions,omitempty"` + Role string `json:"role,omitempty"` + APIKeyID string `json:"api_key,omitempty"` + NodeID string `json:"node_id,omitempty"` + MustSetTwoFactorAuth bool `json:"2fa_required,omitempty"` + MustChangePassword bool `json:"chpwd,omitempty"` + RequiredTwoFactorProtocols []string `json:"2fa_protos,omitempty"` + HideUserPageSections int `json:"hus,omitempty"` + Ref string `json:"ref,omitempty"` +} + +func (c *Claims) SetIssuedAt(t time.Time) { + c.IssuedAt = jwt.NewNumericDate(t) +} + +func (c *Claims) SetNotBefore(t time.Time) { + c.NotBefore = jwt.NewNumericDate(t) +} + +func (c *Claims) SetExpiry(t time.Time) { + c.Expiry = jwt.NewNumericDate(t) +} + +func (c *Claims) HasPerm(perm string) bool { + for _, p := range c.Permissions { + if p == "*" || p == perm { + return true + } + } + return false +} + +func (c *Claims) HasAnyAudience(audiences []string) bool { + for _, a := range c.Audience { + if slices.Contains(audiences, a) { + return true + } + } + return false +} + +func (c *Claims) GenerateTokenResponse(signer *Signer) (TokenResponse, error) { + token, err := signer.Sign(c) + if err != nil { + return TokenResponse{}, err + } + return c.BuildTokenResponse(token), nil +} + +func (c *Claims) BuildTokenResponse(token string) TokenResponse { + return TokenResponse{Token: token, Expiry: c.Expiry.Time().UTC().Format(time.RFC3339)} +} + +type TokenResponse struct { + Token string `json:"access_token"` + Expiry string `json:"expires_at"` +} + +func NewSigner(algo jose.SignatureAlgorithm, key any) (*Signer, error) { + opts := (&jose.SignerOptions{}).WithType("JWT") + signer, err := jose.NewSigner(jose.SigningKey{Algorithm: algo, Key: key}, opts) + if err != nil { + return nil, err + } + return &Signer{ + signer: signer, + algo: []jose.SignatureAlgorithm{algo}, + key: key, + }, nil +} + +type Signer struct { + algo []jose.SignatureAlgorithm + signer jose.Signer + key any +} + +func (s *Signer) Sign(claims *Claims) (string, error) { + if claims.ID == "" { + claims.ID = xid.New().String() + } + if claims.IssuedAt == nil { + claims.IssuedAt = jwt.NewNumericDate(time.Now()) + } + if claims.NotBefore == nil { + claims.NotBefore = jwt.NewNumericDate(time.Now().Add(-10 * time.Second)) + } + if claims.Expiry == nil { + return "", errors.New("expiration must be set") + } + if len(claims.Audience) == 0 { + return "", errors.New("audience must be set") + } + + return jwt.Signed(s.signer).Claims(claims).Serialize() +} + +func (s *Signer) Signer() jose.Signer { + return s.signer +} + +func (s *Signer) SetSigner(signer jose.Signer) { + s.signer = signer +} + +func (s *Signer) SignWithParams(claims *Claims, audience, ip string, duration time.Duration) (string, error) { + claims.Expiry = jwt.NewNumericDate(time.Now().Add(duration)) + claims.Audience = []string{audience, ip} + return s.Sign(claims) +} + +func NewContext(ctx context.Context, claims *Claims, err error) context.Context { + ctx = context.WithValue(ctx, TokenCtxKey, claims) + ctx = context.WithValue(ctx, ErrorCtxKey, err) + return ctx +} + +func FromContext(ctx context.Context) (*Claims, error) { + val := ctx.Value(TokenCtxKey) + token, ok := val.(*Claims) + if !ok && val != nil { + return nil, fmt.Errorf("invalid type for TokenCtxKey: %T", val) + } + + valErr := ctx.Value(ErrorCtxKey) + err, ok := valErr.(error) + if !ok && valErr != nil { + return nil, fmt.Errorf("invalid type for ErrorCtxKey: %T", valErr) + } + if token == nil { + return nil, errors.New("no token found") + } + + return token, err +} + +func Verify(s *Signer, findTokenFns ...func(r *http.Request) string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + hfn := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + token, err := VerifyRequest(s, r, findTokenFns...) + ctx = NewContext(ctx, token, err) + next.ServeHTTP(w, r.WithContext(ctx)) + } + return http.HandlerFunc(hfn) + } +} + +func VerifyRequest(s *Signer, r *http.Request, findTokenFns ...func(r *http.Request) string) (*Claims, error) { + var tokenString string + for _, fn := range findTokenFns { + tokenString = fn(r) + if tokenString != "" { + break + } + } + if tokenString == "" { + return nil, errors.New("no token found") + } + return VerifyToken(s, tokenString) +} + +func VerifyToken(s *Signer, payload string) (*Claims, error) { + return VerifyTokenWithKey(payload, s.algo, s.key) +} + +func VerifyTokenWithKey(payload string, algo []jose.SignatureAlgorithm, key any) (*Claims, error) { + token, err := jwt.ParseSigned(payload, algo) + if err != nil { + return nil, err + } + var claims Claims + err = token.Claims(key, &claims) + if err != nil { + return nil, err + } + if err := claims.ValidateWithLeeway(jwt.Expected{Time: time.Now()}, 15*time.Second); err != nil { + return nil, err + } + return &claims, nil +} + +// TokenFromCookie tries to retrieve the token string from a cookie named +// "jwt". +func TokenFromCookie(r *http.Request) string { + cookie, err := r.Cookie("jwt") + if err != nil { + return "" + } + return cookie.Value +} + +// TokenFromHeader tries to retrieve the token string from the +// "Authorization" request header: "Authorization: BEARER T". +func TokenFromHeader(r *http.Request) string { + // Get token from authorization header. + bearer := r.Header.Get("Authorization") + const prefix = "Bearer " + if len(bearer) >= len(prefix) && strings.EqualFold(bearer[:len(prefix)], prefix) { + return bearer[len(prefix):] + } + return "" +} diff --git a/internal/jwt/jwt_test.go b/internal/jwt/jwt_test.go new file mode 100644 index 00000000..9d0a465d --- /dev/null +++ b/internal/jwt/jwt_test.go @@ -0,0 +1,225 @@ +package jwt + +import ( + "context" + "errors" + "fmt" + "io/fs" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-jose/go-jose/v4" + "github.com/go-jose/go-jose/v4/jwt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/drakkan/sftpgo/v2/internal/util" +) + +type failingJoseSigner struct{} + +func (s *failingJoseSigner) Sign(payload []byte) (*jose.JSONWebSignature, error) { + return nil, errors.New("sign test error") +} + +func (s *failingJoseSigner) Options() jose.SignerOptions { + return jose.SignerOptions{} +} + +func TestJWTToken(t *testing.T) { + s, err := NewSigner(jose.HS256, util.GenerateRandomBytes(32)) + require.NoError(t, err) + username := util.GenerateUniqueID() + claims := Claims{ + Username: username, + Claims: jwt.Claims{ + Audience: jwt.Audience{"test"}, + Expiry: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)), + NotBefore: jwt.NewNumericDate(time.Now()), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + } + token, err := s.Sign(&claims) + require.NoError(t, err) + require.NotEmpty(t, token) + + parsed, err := VerifyToken(s, token) + require.NoError(t, err) + require.Equal(t, username, parsed.Username) + + ja1, err := NewSigner(jose.HS256, util.GenerateRandomBytes(32)) + require.NoError(t, err) + + token, err = ja1.Sign(&claims) + require.NoError(t, err) + require.NotEmpty(t, token) + _, err = VerifyToken(s, token) + require.Error(t, err) + _, err = VerifyToken(ja1, token) + require.NoError(t, err) +} + +func TestClaims(t *testing.T) { + claims := NewClaims(util.GenerateUniqueID(), "", 10*time.Minute) + s, err := NewSigner(jose.HS256, util.GenerateRandomBytes(32)) + require.NoError(t, err) + token, err := s.Sign(claims) + require.NoError(t, err) + assert.NotEmpty(t, token) + assert.NotNil(t, claims.Expiry) + assert.NotNil(t, claims.IssuedAt) + assert.NotNil(t, claims.NotBefore) + + claims = &Claims{ + Permissions: []string{"myperm"}, + } + claims.SetExpiry(time.Now().Add(1 * time.Minute)) + claims.Audience = []string{"testaudience"} + _, err = s.Sign(claims) + assert.NoError(t, err) + assert.NotNil(t, claims.IssuedAt) + assert.NotNil(t, claims.NotBefore) + assert.True(t, claims.HasAnyAudience([]string{util.GenerateUniqueID(), util.GenerateUniqueID(), "testaudience"})) + assert.False(t, claims.HasAnyAudience([]string{util.GenerateUniqueID()})) + assert.True(t, claims.HasPerm("myperm")) + assert.False(t, claims.HasPerm(util.GenerateUniqueID())) + resp, err := claims.GenerateTokenResponse(s) + require.NoError(t, err) + assert.NotEmpty(t, resp.Token) + assert.Equal(t, claims.Expiry.Time().UTC().Format(time.RFC3339), resp.Expiry) + claims.SetIssuedAt(time.Now()) + claims.SetNotBefore(time.Now().Add(10 * time.Minute)) + token, err = s.SignWithParams(claims, util.GenerateUniqueID(), "127.0.0.1", time.Minute) + assert.NoError(t, err) + _, err = VerifyToken(s, token) + assert.ErrorContains(t, err, "nbf") + claims = &Claims{} + _, err = s.Sign(claims) + assert.ErrorContains(t, err, "expiration must be set") + claims.SetExpiry(time.Now()) + _, err = s.Sign(claims) + assert.ErrorContains(t, err, "audience must be set") + claims = &Claims{} + _, err = s.SignWithParams(claims, util.GenerateUniqueID(), "", time.Minute) + assert.NoError(t, err) +} + +func TestClaimsPermissions(t *testing.T) { + c := Claims{ + Permissions: []string{"*"}, + } + assert.True(t, c.HasPerm(util.GenerateUniqueID())) + c.Permissions = []string{"list"} + assert.False(t, c.HasPerm(util.GenerateUniqueID())) + assert.True(t, c.HasPerm("list")) +} + +func TestErrors(t *testing.T) { + s, err := NewSigner(jose.HS256, util.GenerateRandomBytes(32)) + require.NoError(t, err) + _, err = VerifyToken(s, util.GenerateUniqueID()) + assert.Error(t, err) + claims := &Claims{} + claims.SetExpiry(time.Now().Add(-1 * time.Minute)) + token, err := jwt.Signed(s.Signer()).Claims(claims).Serialize() + assert.NoError(t, err) + _, err = VerifyToken(s, token) + assert.ErrorContains(t, err, "exp") + claims.SetExpiry(time.Now().Add(2 * time.Minute)) + claims.SetIssuedAt(time.Now().Add(1 * time.Minute)) + token, err = jwt.Signed(s.Signer()).Claims(claims).Serialize() + assert.NoError(t, err) + _, err = VerifyToken(s, token) + assert.ErrorContains(t, err, "iat") + claims.SetIssuedAt(time.Now()) + claims.SetNotBefore(time.Now().Add(1 * time.Minute)) + token, err = jwt.Signed(s.Signer()).Claims(claims).Serialize() + assert.NoError(t, err) + _, err = VerifyToken(s, token) + assert.ErrorContains(t, err, "nbf") + + s.SetSigner(&failingJoseSigner{}) + claims = NewClaims(util.GenerateUniqueID(), "", time.Minute) + _, err = s.Sign(claims) + assert.Error(t, err) + _, err = claims.GenerateTokenResponse(s) + assert.Error(t, err) + // Wrong algorithm + _, err = NewSigner("PS256", util.GenerateRandomBytes(32)) + assert.Error(t, err) +} + +func TestTokenFromRequest(t *testing.T) { + claims := NewClaims(util.GenerateUniqueID(), "", 10*time.Minute) + s, err := NewSigner(jose.HS256, util.GenerateRandomBytes(32)) + require.NoError(t, err) + token, err := s.Sign(claims) + require.NoError(t, err) + assert.NotEmpty(t, token) + req, err := http.NewRequest(http.MethodGet, "/", nil) + require.NoError(t, err) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token)) + cookie := TokenFromCookie(req) + assert.Equal(t, token, cookie) + req, err = http.NewRequest(http.MethodGet, "/", nil) + require.NoError(t, err) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + _, err = VerifyRequest(s, req, TokenFromHeader) + assert.NoError(t, err) + req.Header.Set("Authorization", token) + assert.Empty(t, TokenFromHeader(req)) + assert.Empty(t, TokenFromCookie(req)) + _, err = VerifyRequest(s, req, TokenFromCookie) + assert.ErrorContains(t, err, "no token found") +} + +func TestContext(t *testing.T) { + claims := &Claims{ + Username: util.GenerateUniqueID(), + } + s, err := NewSigner(jose.HS256, util.GenerateRandomBytes(32)) + require.NoError(t, err) + token, err := s.SignWithParams(claims, util.GenerateUniqueID(), "", time.Minute) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, "/", nil) + require.NoError(t, err) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + h := Verify(s, TokenFromHeader) + wrapped := h(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token, err := FromContext(r.Context()) + assert.Nil(t, err) + assert.Equal(t, claims.Username, token.Username) + w.WriteHeader(http.StatusOK) + })) + rr := httptest.NewRecorder() + wrapped.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code) + + _, err = FromContext(context.Background()) + assert.ErrorContains(t, err, "no token found") + + ctx := NewContext(context.Background(), &Claims{}, fs.ErrClosed) + _, err = FromContext(ctx) + assert.Equal(t, fs.ErrClosed, err) + + ctx = context.WithValue(context.Background(), TokenCtxKey, "1") + _, err = FromContext(ctx) + assert.ErrorContains(t, err, "invalid type for TokenCtxKey") + + ctx = context.WithValue(context.Background(), ErrorCtxKey, 2) + _, err = FromContext(ctx) + assert.ErrorContains(t, err, "invalid type for ErrorCtxKey") + claims = NewClaims(util.GenerateUniqueID(), "127.1.1.1", time.Minute) + _, err = s.Sign(claims) + require.NoError(t, err) + ctx = context.WithValue(context.Background(), TokenCtxKey, claims) + claimsFromContext, err := FromContext(ctx) + assert.NoError(t, err) + assert.Equal(t, claims, claimsFromContext) + + assert.Equal(t, "jwt context value Token", TokenCtxKey.String()) +}