fix: migrate to go-jwt-middleware v3 API
schemas / check-release (pull_request) Successful in 1m57s
schemas / vulnerabilities (pull_request) Successful in 2m48s
schemas / check (pull_request) Successful in 8m17s
pre-commit / pre-commit (pull_request) Successful in 11m38s
schemas / build (pull_request) Successful in 5m31s
schemas / deploy-prod (pull_request) Has been skipped
schemas / check-release (pull_request) Successful in 1m57s
schemas / vulnerabilities (pull_request) Successful in 2m48s
schemas / check (pull_request) Successful in 8m17s
pre-commit / pre-commit (pull_request) Successful in 11m38s
schemas / build (pull_request) Successful in 5m31s
schemas / deploy-prod (pull_request) Has been skipped
- Use validator and jwks packages for JWT validation - Replace manual JWKS caching with jwks.NewCachingProvider - Add CustomClaims struct for https://unbound.se/roles claim - Rename TokenFromContext to ClaimsFromContext - Update middleware/auth.go to use new claims structure - Update tests to use core.SetClaims and validator.ValidatedClaims Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -9,10 +9,8 @@ require (
|
|||||||
github.com/alecthomas/kong v1.13.0
|
github.com/alecthomas/kong v1.13.0
|
||||||
github.com/apex/log v1.9.0
|
github.com/apex/log v1.9.0
|
||||||
github.com/auth0/go-jwt-middleware/v3 v3.0.0
|
github.com/auth0/go-jwt-middleware/v3 v3.0.0
|
||||||
github.com/golang-jwt/jwt/v5 v5.3.0
|
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
github.com/jmoiron/sqlx v1.4.0
|
github.com/jmoiron/sqlx v1.4.0
|
||||||
github.com/pkg/errors v0.9.1
|
|
||||||
github.com/pressly/goose/v3 v3.26.0
|
github.com/pressly/goose/v3 v3.26.0
|
||||||
github.com/rs/cors v1.11.1
|
github.com/rs/cors v1.11.1
|
||||||
github.com/sparetimecoders/goamqp v0.3.3
|
github.com/sparetimecoders/goamqp v0.3.3
|
||||||
@@ -60,6 +58,7 @@ require (
|
|||||||
github.com/lestrrat-go/option/v2 v2.0.0 // indirect
|
github.com/lestrrat-go/option/v2 v2.0.0 // indirect
|
||||||
github.com/lib/pq v1.10.9 // indirect
|
github.com/lib/pq v1.10.9 // indirect
|
||||||
github.com/mfridman/interpolate v0.0.2 // indirect
|
github.com/mfridman/interpolate v0.0.2 // indirect
|
||||||
|
github.com/pkg/errors v0.9.1 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
github.com/rabbitmq/amqp091-go v1.10.0 // indirect
|
github.com/rabbitmq/amqp091-go v1.10.0 // indirect
|
||||||
github.com/segmentio/asm v1.2.1 // indirect
|
github.com/segmentio/asm v1.2.1 // indirect
|
||||||
|
|||||||
@@ -63,8 +63,6 @@ github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
|||||||
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||||
github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM=
|
github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM=
|
||||||
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
||||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
|
||||||
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
|
||||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||||
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||||
|
|||||||
+9
-26
@@ -6,7 +6,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/99designs/gqlgen/graphql"
|
"github.com/99designs/gqlgen/graphql"
|
||||||
"github.com/golang-jwt/jwt/v5"
|
|
||||||
|
|
||||||
"gitea.unbound.se/unboundsoftware/schemas/domain"
|
"gitea.unbound.se/unboundsoftware/schemas/domain"
|
||||||
)
|
)
|
||||||
@@ -33,14 +32,9 @@ type AuthMiddleware struct {
|
|||||||
func (m *AuthMiddleware) Handler(next http.Handler) http.Handler {
|
func (m *AuthMiddleware) Handler(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
token, err := TokenFromContext(r.Context())
|
claims := ClaimsFromContext(r.Context())
|
||||||
if err != nil {
|
if claims != nil {
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
ctx = context.WithValue(ctx, UserKey, claims.RegisteredClaims.Subject)
|
||||||
_, _ = w.Write([]byte("Invalid JWT token format"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if token != nil {
|
|
||||||
ctx = context.WithValue(ctx, UserKey, token.Claims.(jwt.MapClaims)["sub"])
|
|
||||||
}
|
}
|
||||||
apiKey, err := ApiKeyFromContext(r.Context())
|
apiKey, err := ApiKeyFromContext(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -68,29 +62,18 @@ func UserFromContext(ctx context.Context) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func UserHasRole(ctx context.Context, role string) bool {
|
func UserHasRole(ctx context.Context, role string) bool {
|
||||||
token, err := TokenFromContext(ctx)
|
claims := ClaimsFromContext(ctx)
|
||||||
if err != nil || token == nil {
|
if claims == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
claims, ok := token.Claims.(jwt.MapClaims)
|
customClaims, ok := claims.CustomClaims.(*CustomClaims)
|
||||||
if !ok {
|
if !ok || customClaims == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check the custom roles claim
|
for _, r := range customClaims.Roles {
|
||||||
rolesInterface, ok := claims["https://unbound.se/roles"]
|
if r == role {
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
roles, ok := rolesInterface.([]interface{})
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, r := range roles {
|
|
||||||
if roleStr, ok := r.(string); ok && roleStr == role {
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+50
-141
@@ -2,39 +2,34 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"log"
|
||||||
"strings"
|
"net/url"
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
mw "github.com/auth0/go-jwt-middleware/v3"
|
jwtmiddleware "github.com/auth0/go-jwt-middleware/v3"
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/auth0/go-jwt-middleware/v3/jwks"
|
||||||
"github.com/pkg/errors"
|
"github.com/auth0/go-jwt-middleware/v3/validator"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// CustomClaims contains custom claims from the JWT token.
|
||||||
|
type CustomClaims struct {
|
||||||
|
Roles []string `json:"https://unbound.se/roles"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate implements the validator.CustomClaims interface.
|
||||||
|
func (c CustomClaims) Validate(_ context.Context) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type Auth0 struct {
|
type Auth0 struct {
|
||||||
domain string
|
domain string
|
||||||
audience string
|
audience string
|
||||||
client *http.Client
|
|
||||||
cache JwksCache
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAuth0(audience, domain string, strictSsl bool) *Auth0 {
|
func NewAuth0(audience, domain string, _ bool) *Auth0 {
|
||||||
customTransport := http.DefaultTransport.(*http.Transport).Clone()
|
|
||||||
customTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: !strictSsl}
|
|
||||||
client := &http.Client{Transport: customTransport}
|
|
||||||
|
|
||||||
return &Auth0{
|
return &Auth0{
|
||||||
domain: domain,
|
domain: domain,
|
||||||
audience: audience,
|
audience: audience,
|
||||||
client: client,
|
|
||||||
cache: JwksCache{
|
|
||||||
RWMutex: &sync.RWMutex{},
|
|
||||||
cache: make(map[string]cacheItem),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -42,133 +37,47 @@ type Response struct {
|
|||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Jwks struct {
|
func (a *Auth0) Middleware() *jwtmiddleware.JWTMiddleware {
|
||||||
Keys []JSONWebKeys `json:"keys"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type JSONWebKeys struct {
|
|
||||||
Kty string `json:"kty"`
|
|
||||||
Kid string `json:"kid"`
|
|
||||||
Use string `json:"use"`
|
|
||||||
N string `json:"n"`
|
|
||||||
E string `json:"e"`
|
|
||||||
X5c []string `json:"x5c"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Auth0) ValidationKeyGetter() func(token *jwt.Token) (interface{}, error) {
|
|
||||||
return func(token *jwt.Token) (interface{}, error) {
|
|
||||||
// Verify 'aud' claim
|
|
||||||
|
|
||||||
cert, err := a.getPemCert(token)
|
|
||||||
if err != nil {
|
|
||||||
panic(err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
result, _ := jwt.ParseRSAPublicKeyFromPEM([]byte(cert))
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Auth0) Middleware() *mw.JWTMiddleware {
|
|
||||||
issuer := fmt.Sprintf("https://%s/", a.domain)
|
issuer := fmt.Sprintf("https://%s/", a.domain)
|
||||||
jwtMiddleware := mw.New(func(ctx context.Context, token string) (interface{}, error) {
|
|
||||||
jwtToken, err := jwt.Parse(token, a.ValidationKeyGetter(), jwt.WithAudience(a.audience), jwt.WithIssuer(issuer))
|
issuerURL, err := url.Parse(issuer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
log.Fatalf("failed to parse issuer URL: %v", err)
|
||||||
}
|
}
|
||||||
if _, ok := jwtToken.Method.(*jwt.SigningMethodRSA); !ok {
|
|
||||||
return nil, fmt.Errorf("unexpected signing method: %v", jwtToken.Header["alg"])
|
provider, err := jwks.NewCachingProvider(jwks.WithIssuerURL(issuerURL))
|
||||||
}
|
if err != nil {
|
||||||
return jwtToken, nil
|
log.Fatalf("failed to create JWKS provider: %v", err)
|
||||||
},
|
}
|
||||||
mw.WithTokenExtractor(func(r *http.Request) (string, error) {
|
|
||||||
token := r.Header.Get("Authorization")
|
jwtValidator, err := validator.New(
|
||||||
if strings.HasPrefix(token, "Bearer ") {
|
validator.WithKeyFunc(provider.KeyFunc),
|
||||||
return token[7:], nil
|
validator.WithAlgorithm(validator.RS256),
|
||||||
}
|
validator.WithIssuer(issuer),
|
||||||
return "", nil
|
validator.WithAudience(a.audience),
|
||||||
|
validator.WithCustomClaims(func() validator.CustomClaims {
|
||||||
|
return &CustomClaims{}
|
||||||
}),
|
}),
|
||||||
mw.WithCredentialsOptional(true),
|
|
||||||
)
|
)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed to create JWT validator: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
jwtMiddleware, err := jwtmiddleware.New(
|
||||||
|
jwtmiddleware.WithValidator(jwtValidator),
|
||||||
|
jwtmiddleware.WithCredentialsOptional(true),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed to create JWT middleware: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return jwtMiddleware
|
return jwtMiddleware
|
||||||
}
|
}
|
||||||
|
|
||||||
func TokenFromContext(ctx context.Context) (*jwt.Token, error) {
|
func ClaimsFromContext(ctx context.Context) *validator.ValidatedClaims {
|
||||||
if value := ctx.Value(mw.ContextKey{}); value != nil {
|
claims, err := jwtmiddleware.GetClaims[*validator.ValidatedClaims](ctx)
|
||||||
if u, ok := value.(*jwt.Token); ok {
|
|
||||||
return u, nil
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("token is in wrong format")
|
|
||||||
}
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Auth0) cacheGetWellknown(url string) (*Jwks, error) {
|
|
||||||
if value := a.cache.get(url); value != nil {
|
|
||||||
return value, nil
|
|
||||||
}
|
|
||||||
jwks := &Jwks{}
|
|
||||||
resp, err := a.client.Get(url)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return jwks, err
|
return nil
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
_ = resp.Body.Close()
|
|
||||||
}()
|
|
||||||
err = json.NewDecoder(resp.Body).Decode(jwks)
|
|
||||||
if err == nil && jwks != nil {
|
|
||||||
a.cache.put(url, jwks)
|
|
||||||
}
|
|
||||||
return jwks, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Auth0) getPemCert(token *jwt.Token) (string, error) {
|
|
||||||
jwks, err := a.cacheGetWellknown(fmt.Sprintf("https://%s/.well-known/jwks.json", a.domain))
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
var cert string
|
|
||||||
for k := range jwks.Keys {
|
|
||||||
if token.Header["kid"] == jwks.Keys[k].Kid {
|
|
||||||
cert = "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if cert == "" {
|
|
||||||
err := errors.New("Unable to find appropriate key.")
|
|
||||||
return cert, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return cert, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type JwksCache struct {
|
|
||||||
*sync.RWMutex
|
|
||||||
cache map[string]cacheItem
|
|
||||||
}
|
|
||||||
type cacheItem struct {
|
|
||||||
data *Jwks
|
|
||||||
expiration time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *JwksCache) get(url string) *Jwks {
|
|
||||||
c.RLock()
|
|
||||||
defer c.RUnlock()
|
|
||||||
if value, ok := c.cache[url]; ok {
|
|
||||||
if time.Now().After(value.expiration) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return value.data
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *JwksCache) put(url string, jwks *Jwks) {
|
|
||||||
c.Lock()
|
|
||||||
defer c.Unlock()
|
|
||||||
c.cache[url] = cacheItem{
|
|
||||||
data: jwks,
|
|
||||||
expiration: time.Now().Add(time.Minute * 60),
|
|
||||||
}
|
}
|
||||||
|
return claims
|
||||||
}
|
}
|
||||||
|
|||||||
+73
-71
@@ -6,8 +6,8 @@ import (
|
|||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
mw "github.com/auth0/go-jwt-middleware/v3"
|
"github.com/auth0/go-jwt-middleware/v3/core"
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/auth0/go-jwt-middleware/v3/validator"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/mock"
|
"github.com/stretchr/testify/mock"
|
||||||
@@ -155,9 +155,11 @@ func TestAuthMiddleware_Handler_WithValidJWT(t *testing.T) {
|
|||||||
mockCache.On("OrganizationByAPIKey", "").Return(nil)
|
mockCache.On("OrganizationByAPIKey", "").Return(nil)
|
||||||
|
|
||||||
userID := "user-123"
|
userID := "user-123"
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
claims := &validator.ValidatedClaims{
|
||||||
"sub": userID,
|
RegisteredClaims: validator.RegisteredClaims{
|
||||||
})
|
Subject: userID,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
// Create a test handler that checks the context
|
// Create a test handler that checks the context
|
||||||
var capturedUser string
|
var capturedUser string
|
||||||
@@ -170,9 +172,9 @@ func TestAuthMiddleware_Handler_WithValidJWT(t *testing.T) {
|
|||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
})
|
})
|
||||||
|
|
||||||
// Create request with JWT token in context
|
// Create request with JWT claims in context
|
||||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
ctx := context.WithValue(req.Context(), mw.ContextKey{}, token)
|
ctx := core.SetClaims(req.Context(), claims)
|
||||||
req = req.WithContext(ctx)
|
req = req.WithContext(ctx)
|
||||||
|
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
@@ -209,28 +211,35 @@ func TestAuthMiddleware_Handler_APIKeyErrorHandling(t *testing.T) {
|
|||||||
assert.Contains(t, rec.Body.String(), "Invalid API Key format")
|
assert.Contains(t, rec.Body.String(), "Invalid API Key format")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthMiddleware_Handler_JWTErrorHandling(t *testing.T) {
|
func TestAuthMiddleware_Handler_JWTMissingClaims(t *testing.T) {
|
||||||
// Setup
|
// Setup
|
||||||
mockCache := new(MockCache)
|
mockCache := new(MockCache)
|
||||||
authMiddleware := NewAuth(mockCache)
|
authMiddleware := NewAuth(mockCache)
|
||||||
|
|
||||||
|
// The middleware passes the plaintext API key (cache handles hashing)
|
||||||
|
mockCache.On("OrganizationByAPIKey", "").Return(nil)
|
||||||
|
|
||||||
|
// Create a test handler that checks the context
|
||||||
|
var capturedUser string
|
||||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if user := r.Context().Value(UserKey); user != nil {
|
||||||
|
if u, ok := user.(string); ok {
|
||||||
|
capturedUser = u
|
||||||
|
}
|
||||||
|
}
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
})
|
})
|
||||||
|
|
||||||
// Create request with invalid JWT token type in context
|
// Create request without JWT claims - user should not be set
|
||||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
ctx := context.WithValue(req.Context(), mw.ContextKey{}, "not-a-token") // Invalid type
|
|
||||||
req = req.WithContext(ctx)
|
|
||||||
|
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
// Execute
|
// Execute
|
||||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||||
|
|
||||||
// Assert
|
// Assert
|
||||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
assert.Contains(t, rec.Body.String(), "Invalid JWT token format")
|
assert.Empty(t, capturedUser, "User should not be set when no claims in context")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthMiddleware_Handler_BothJWTAndAPIKey(t *testing.T) {
|
func TestAuthMiddleware_Handler_BothJWTAndAPIKey(t *testing.T) {
|
||||||
@@ -249,9 +258,11 @@ func TestAuthMiddleware_Handler_BothJWTAndAPIKey(t *testing.T) {
|
|||||||
userID := "user-123"
|
userID := "user-123"
|
||||||
apiKey := "test-api-key-123"
|
apiKey := "test-api-key-123"
|
||||||
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
claims := &validator.ValidatedClaims{
|
||||||
"sub": userID,
|
RegisteredClaims: validator.RegisteredClaims{
|
||||||
})
|
Subject: userID,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
// Mock expects plaintext key (cache handles hashing internally)
|
// Mock expects plaintext key (cache handles hashing internally)
|
||||||
mockCache.On("OrganizationByAPIKey", apiKey).Return(expectedOrg)
|
mockCache.On("OrganizationByAPIKey", apiKey).Return(expectedOrg)
|
||||||
@@ -273,9 +284,9 @@ func TestAuthMiddleware_Handler_BothJWTAndAPIKey(t *testing.T) {
|
|||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
})
|
})
|
||||||
|
|
||||||
// Create request with both JWT and API key in context
|
// Create request with both JWT claims and API key in context
|
||||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
ctx := context.WithValue(req.Context(), mw.ContextKey{}, token)
|
ctx := core.SetClaims(req.Context(), claims)
|
||||||
ctx = context.WithValue(ctx, ApiKey, apiKey)
|
ctx = context.WithValue(ctx, ApiKey, apiKey)
|
||||||
req = req.WithContext(ctx)
|
req = req.WithContext(ctx)
|
||||||
|
|
||||||
@@ -475,13 +486,17 @@ func TestAuthMiddleware_Directive_NoRequirements(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUserHasRole_WithValidRole(t *testing.T) {
|
func TestUserHasRole_WithValidRole(t *testing.T) {
|
||||||
// Create token with roles claim
|
// Create claims with roles
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
claims := &validator.ValidatedClaims{
|
||||||
"sub": "user-123",
|
RegisteredClaims: validator.RegisteredClaims{
|
||||||
"https://unbound.se/roles": []interface{}{"admin", "user"},
|
Subject: "user-123",
|
||||||
})
|
},
|
||||||
|
CustomClaims: &CustomClaims{
|
||||||
|
Roles: []string{"admin", "user"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
ctx := context.WithValue(context.Background(), mw.ContextKey{}, token)
|
ctx := core.SetClaims(context.Background(), claims)
|
||||||
|
|
||||||
// Test for existing role
|
// Test for existing role
|
||||||
hasRole := UserHasRole(ctx, "admin")
|
hasRole := UserHasRole(ctx, "admin")
|
||||||
@@ -492,13 +507,17 @@ func TestUserHasRole_WithValidRole(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUserHasRole_WithoutRole(t *testing.T) {
|
func TestUserHasRole_WithoutRole(t *testing.T) {
|
||||||
// Create token with roles claim
|
// Create claims with roles
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
claims := &validator.ValidatedClaims{
|
||||||
"sub": "user-123",
|
RegisteredClaims: validator.RegisteredClaims{
|
||||||
"https://unbound.se/roles": []interface{}{"user"},
|
Subject: "user-123",
|
||||||
})
|
},
|
||||||
|
CustomClaims: &CustomClaims{
|
||||||
|
Roles: []string{"user"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
ctx := context.WithValue(context.Background(), mw.ContextKey{}, token)
|
ctx := core.SetClaims(context.Background(), claims)
|
||||||
|
|
||||||
// Test for non-existing role
|
// Test for non-existing role
|
||||||
hasRole := UserHasRole(ctx, "admin")
|
hasRole := UserHasRole(ctx, "admin")
|
||||||
@@ -506,59 +525,42 @@ func TestUserHasRole_WithoutRole(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUserHasRole_WithoutRolesClaim(t *testing.T) {
|
func TestUserHasRole_WithoutRolesClaim(t *testing.T) {
|
||||||
// Create token without roles claim
|
// Create claims without custom claims
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
claims := &validator.ValidatedClaims{
|
||||||
"sub": "user-123",
|
RegisteredClaims: validator.RegisteredClaims{
|
||||||
})
|
Subject: "user-123",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
ctx := context.WithValue(context.Background(), mw.ContextKey{}, token)
|
ctx := core.SetClaims(context.Background(), claims)
|
||||||
|
|
||||||
// Test should return false when roles claim is missing
|
// Test should return false when custom claims is missing
|
||||||
hasRole := UserHasRole(ctx, "admin")
|
hasRole := UserHasRole(ctx, "admin")
|
||||||
assert.False(t, hasRole)
|
assert.False(t, hasRole)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUserHasRole_WithoutToken(t *testing.T) {
|
func TestUserHasRole_WithoutClaims(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
// Test should return false when no token in context
|
// Test should return false when no claims in context
|
||||||
hasRole := UserHasRole(ctx, "admin")
|
hasRole := UserHasRole(ctx, "admin")
|
||||||
assert.False(t, hasRole)
|
assert.False(t, hasRole)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUserHasRole_WithInvalidTokenType(t *testing.T) {
|
func TestUserHasRole_WithEmptyRoles(t *testing.T) {
|
||||||
// Put invalid token type in context
|
// Create claims with empty roles
|
||||||
ctx := context.WithValue(context.Background(), mw.ContextKey{}, "not-a-token")
|
claims := &validator.ValidatedClaims{
|
||||||
|
RegisteredClaims: validator.RegisteredClaims{
|
||||||
|
Subject: "user-123",
|
||||||
|
},
|
||||||
|
CustomClaims: &CustomClaims{
|
||||||
|
Roles: []string{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
// Test should return false when token type is invalid
|
ctx := core.SetClaims(context.Background(), claims)
|
||||||
hasRole := UserHasRole(ctx, "admin")
|
|
||||||
assert.False(t, hasRole)
|
// Test should return false when roles array is empty
|
||||||
}
|
|
||||||
|
|
||||||
func TestUserHasRole_WithInvalidRolesType(t *testing.T) {
|
|
||||||
// Create token with invalid roles type
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
|
||||||
"sub": "user-123",
|
|
||||||
"https://unbound.se/roles": "not-an-array",
|
|
||||||
})
|
|
||||||
|
|
||||||
ctx := context.WithValue(context.Background(), mw.ContextKey{}, token)
|
|
||||||
|
|
||||||
// Test should return false when roles type is invalid
|
|
||||||
hasRole := UserHasRole(ctx, "admin")
|
|
||||||
assert.False(t, hasRole)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUserHasRole_WithInvalidRoleElementType(t *testing.T) {
|
|
||||||
// Create token with invalid role element types
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
|
||||||
"sub": "user-123",
|
|
||||||
"https://unbound.se/roles": []interface{}{123, 456}, // Numbers instead of strings
|
|
||||||
})
|
|
||||||
|
|
||||||
ctx := context.WithValue(context.Background(), mw.ContextKey{}, token)
|
|
||||||
|
|
||||||
// Test should return false when role elements are not strings
|
|
||||||
hasRole := UserHasRole(ctx, "admin")
|
hasRole := UserHasRole(ctx, "admin")
|
||||||
assert.False(t, hasRole)
|
assert.False(t, hasRole)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user