Files
auth0mock/auth/jwt.go
T

213 lines
5.0 KiB
Go

package auth
import (
"crypto/rand"
"crypto/rsa"
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
"github.com/lestrrat-go/jwx/v3/jwa"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/lestrrat-go/jwx/v3/jws"
"github.com/lestrrat-go/jwx/v3/jwt"
)
// TokenExpiry is how long a token signed by SignToken stays valid: its "exp"
// claim is set this far after its "iat" claim.
const TokenExpiry time.Duration = 2 * time.Hour
// JWTService signs RS256 JWTs with a single RSA key and exposes the matching
// public key as a JSON Web Key Set.
type JWTService struct {
	// privateKey signs every token; NewJWTService generates it.
	privateKey *rsa.PrivateKey
	// jwkSet holds the public half of privateKey and is what GetJWKS serializes.
	jwkSet jwk.Set

	// issuer fills the "iss" claim; audience is the sole "aud" entry of access tokens.
	issuer, audience string
	// adminClaim and emailClaim are the names of custom claims configured for this service.
	adminClaim, emailClaim string
}
// NewJWTService generates a fresh 2048-bit RSA key pair and returns a service
// that signs tokens with it.
//
// The private key is tagged with a random UUID key ID, the RS256 algorithm and
// "sig" usage. Its public half is derived after that metadata is set and is the
// only key in the service's JWKS. Every call generates a new key pair.
// issuer, audience, adminClaim and emailClaim are stored unchanged.
func NewJWTService(issuer, audience, adminClaim, emailClaim string) (*JWTService, error) {
	rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, fmt.Errorf("generate RSA key: %w", err)
	}

	signingKey, err := jwk.Import(rsaKey)
	if err != nil {
		return nil, fmt.Errorf("create JWK from private key: %w", err)
	}

	// Metadata is applied in this fixed order. Each entry carries the prefix
	// used to wrap a failure from Set.
	metadata := []struct {
		name   string
		value  interface{}
		action string
	}{
		{jwk.KeyIDKey, uuid.New().String(), "set key ID"},
		{jwk.AlgorithmKey, jwa.RS256(), "set algorithm"},
		{jwk.KeyUsageKey, "sig", "set key usage"},
	}
	for _, m := range metadata {
		if err := signingKey.Set(m.name, m.value); err != nil {
			return nil, fmt.Errorf("%s: %w", m.action, err)
		}
	}

	verifyKey, err := signingKey.PublicKey()
	if err != nil {
		return nil, fmt.Errorf("get public key: %w", err)
	}

	published := jwk.NewSet()
	if err := published.AddKey(verifyKey); err != nil {
		return nil, fmt.Errorf("add key to set: %w", err)
	}

	svc := &JWTService{
		privateKey: rsaKey,
		jwkSet:     published,
		issuer:     issuer,
		audience:   audience,
		adminClaim: adminClaim,
		emailClaim: emailClaim,
	}
	return svc, nil
}
// SignToken signs claims as an RS256 JWT and returns its compact serialization.
//
// Every token starts with "iss" set to the service issuer, "iat" set to now and
// "exp" set to now+TokenExpiry. The entries of claims are added after these
// defaults, so a same-named entry (for example "exp") replaces them. The
// signature header carries the key ID of the public key in the service's JWKS,
// so verifiers can select the matching key.
//
// It returns an error when the token cannot be built, the signing key cannot be
// prepared, the JWKS has no key (or that key has no ID), or signing fails.
func (s *JWTService) SignToken(claims map[string]interface{}) (string, error) {
	now := time.Now()
	builder := jwt.NewBuilder().
		Issuer(s.issuer).
		IssuedAt(now).
		Expiration(now.Add(TokenExpiry))
	for name, value := range claims {
		builder.Claim(name, value)
	}
	token, err := builder.Build()
	if err != nil {
		return "", fmt.Errorf("build token: %w", err)
	}

	key, err := jwk.Import(s.privateKey)
	if err != nil {
		return "", fmt.Errorf("create signing key: %w", err)
	}

	// The "kid" header must match the key advertised in the JWKS. Check both
	// lookups rather than discarding their ok flags: an empty set yields a nil
	// key (calling KeyID on it panics), and a missing ID would sign with an
	// empty key ID.
	pubKey, ok := s.jwkSet.Key(0)
	if !ok {
		return "", fmt.Errorf("get key ID: JWKS contains no keys")
	}
	keyID, ok := pubKey.KeyID()
	if !ok {
		return "", fmt.Errorf("get key ID: public key has no key ID")
	}
	if err := key.Set(jwk.KeyIDKey, keyID); err != nil {
		return "", fmt.Errorf("set key ID: %w", err)
	}

	signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), key))
	if err != nil {
		return "", fmt.Errorf("sign token: %w", err)
	}
	return string(signed), nil
}
// SignAccessToken issues a signed access token for subject.
//
// The token's "aud" is a one-element list holding the service audience, and
// its "azp" is clientID. The maps in customClaims are merged in order: later
// maps overwrite earlier ones, and any of them may override "sub", "aud" or
// "azp". The email is then stored under the configured email claim name, which
// wins over a same-named custom claim. When no email claim name is configured,
// the email is omitted rather than written under an empty claim name.
func (s *JWTService) SignAccessToken(subject, clientID, email string, customClaims []map[string]interface{}) (string, error) {
	claims := map[string]interface{}{
		"sub": subject,
		"aud": []string{s.audience},
		"azp": clientID,
	}
	for _, extra := range customClaims {
		for k, v := range extra {
			claims[k] = v
		}
	}
	// An empty claim name would produce a meaningless `"": email` entry.
	if s.emailClaim != "" {
		claims[s.emailClaim] = email
	}
	return s.SignToken(claims)
}
// SignIDToken issues a signed ID token for subject.
//
// clientID is used as both "aud" and "azp". The profile fields name,
// givenName, familyName, email and picture are copied into the matching
// standard claim names. "nonce" is included only when nonce is non-empty. The
// maps in customClaims are merged in order and may override any of the above.
// Finally the email is also stored under the configured email claim name,
// winning over a same-named custom claim. When no email claim name is
// configured, that extra copy is skipped rather than written under an empty
// claim name.
func (s *JWTService) SignIDToken(subject, clientID, nonce, email, name, givenName, familyName, picture string, customClaims []map[string]interface{}) (string, error) {
	claims := map[string]interface{}{
		"sub":         subject,
		"aud":         clientID,
		"azp":         clientID,
		"name":        name,
		"given_name":  givenName,
		"family_name": familyName,
		"email":       email,
		"picture":     picture,
	}
	if nonce != "" {
		claims["nonce"] = nonce
	}
	for _, extra := range customClaims {
		for k, v := range extra {
			claims[k] = v
		}
	}
	// An empty claim name would produce a meaningless `"": email` entry.
	if s.emailClaim != "" {
		claims[s.emailClaim] = email
	}
	return s.SignToken(claims)
}
// GetJWKS returns the service's JSON Web Key Set (public keys only) encoded as
// JSON, suitable for serving to token verifiers.
//
// A marshaling failure is wrapped with context, like every other error path
// in this service; the underlying error remains reachable via errors.Is/As.
func (s *JWTService) GetJWKS() ([]byte, error) {
	encoded, err := json.Marshal(s.jwkSet)
	if err != nil {
		return nil, fmt.Errorf("marshal JWKS: %w", err)
	}
	return encoded, nil
}
// DecodeToken returns the claims carried in tokenString's payload WITHOUT
// checking its signature, so the result must not be trusted for
// authorization decisions.
//
// The payload is decoded with encoding/json into a generic map, so JSON
// numbers (such as "exp" and "iat") come back as float64. Errors from parsing
// the token or decoding its payload are wrapped with context.
func (s *JWTService) DecodeToken(tokenString string) (map[string]interface{}, error) {
	// jws.Parse only splits and decodes the message; no key is involved.
	message, parseErr := jws.Parse([]byte(tokenString))
	if parseErr != nil {
		return nil, fmt.Errorf("parse token: %w", parseErr)
	}

	payload := message.Payload()
	var decoded map[string]interface{}
	if err := json.Unmarshal(payload, &decoded); err != nil {
		return nil, fmt.Errorf("unmarshal claims: %w", err)
	}
	return decoded, nil
}
// Issuer reports the value SignToken writes into the "iss" claim of every token.
func (s *JWTService) Issuer() string { return s.issuer }
// Audience reports the configured audience, which SignAccessToken places as the
// single entry of the "aud" claim.
func (s *JWTService) Audience() string { return s.audience }
// AdminClaim reports the configured name of the admin custom claim, as passed
// to NewJWTService.
func (s *JWTService) AdminClaim() string { return s.adminClaim }
// EmailClaim reports the custom claim name under which SignAccessToken and
// SignIDToken store the user's email address.
func (s *JWTService) EmailClaim() string { return s.emailClaim }