feat: migrate auth0mock from Node.js to Go
Refactor the application to a Go-based architecture for improved performance and maintainability. Replace the Dockerfile to utilize a multi-stage build process, enhancing image efficiency. Implement comprehensive session store tests to ensure reliability and create new OAuth handlers for managing authentication efficiently. Update documentation to reflect these structural changes.
This commit is contained in:
+212
@@ -0,0 +1,212 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/lestrrat-go/jwx/v2/jws"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
)
|
||||
|
||||
const (
|
||||
// TokenExpiry is the default token expiration time
|
||||
TokenExpiry = 2 * time.Hour
|
||||
)
|
||||
|
||||
// JWTService handles JWT signing and JWKS generation
|
||||
type JWTService struct {
|
||||
privateKey *rsa.PrivateKey
|
||||
jwkSet jwk.Set
|
||||
issuer string
|
||||
audience string
|
||||
adminClaim string
|
||||
emailClaim string
|
||||
}
|
||||
|
||||
// NewJWTService creates a new JWT service with a generated RSA key pair
|
||||
func NewJWTService(issuer, audience, adminClaim, emailClaim string) (*JWTService, error) {
|
||||
// Generate RSA 2048-bit key pair
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate RSA key: %w", err)
|
||||
}
|
||||
|
||||
// Create JWK from private key
|
||||
key, err := jwk.FromRaw(privateKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create JWK from private key: %w", err)
|
||||
}
|
||||
|
||||
// Set key metadata
|
||||
keyID := uuid.New().String()
|
||||
if err := key.Set(jwk.KeyIDKey, keyID); err != nil {
|
||||
return nil, fmt.Errorf("set key ID: %w", err)
|
||||
}
|
||||
if err := key.Set(jwk.AlgorithmKey, jwa.RS256); err != nil {
|
||||
return nil, fmt.Errorf("set algorithm: %w", err)
|
||||
}
|
||||
if err := key.Set(jwk.KeyUsageKey, "sig"); err != nil {
|
||||
return nil, fmt.Errorf("set key usage: %w", err)
|
||||
}
|
||||
|
||||
// Create public key for JWKS
|
||||
publicKey, err := key.PublicKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get public key: %w", err)
|
||||
}
|
||||
|
||||
// Create JWKS with public key
|
||||
jwkSet := jwk.NewSet()
|
||||
if err := jwkSet.AddKey(publicKey); err != nil {
|
||||
return nil, fmt.Errorf("add key to set: %w", err)
|
||||
}
|
||||
|
||||
return &JWTService{
|
||||
privateKey: privateKey,
|
||||
jwkSet: jwkSet,
|
||||
issuer: issuer,
|
||||
audience: audience,
|
||||
adminClaim: adminClaim,
|
||||
emailClaim: emailClaim,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SignToken creates a signed JWT with the given claims
|
||||
func (s *JWTService) SignToken(claims map[string]interface{}) (string, error) {
|
||||
// Build JWT token
|
||||
builder := jwt.NewBuilder()
|
||||
|
||||
now := time.Now()
|
||||
builder.Issuer(s.issuer)
|
||||
builder.IssuedAt(now)
|
||||
builder.Expiration(now.Add(TokenExpiry))
|
||||
|
||||
// Add all claims
|
||||
for key, value := range claims {
|
||||
builder.Claim(key, value)
|
||||
}
|
||||
|
||||
token, err := builder.Build()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("build token: %w", err)
|
||||
}
|
||||
|
||||
// Create JWK from private key for signing
|
||||
key, err := jwk.FromRaw(s.privateKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create signing key: %w", err)
|
||||
}
|
||||
|
||||
// Get key ID from JWKS
|
||||
pubKey, _ := s.jwkSet.Key(0)
|
||||
keyID := pubKey.KeyID()
|
||||
if err := key.Set(jwk.KeyIDKey, keyID); err != nil {
|
||||
return "", fmt.Errorf("set key ID: %w", err)
|
||||
}
|
||||
|
||||
// Sign the token
|
||||
signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256, key))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("sign token: %w", err)
|
||||
}
|
||||
|
||||
return string(signed), nil
|
||||
}
|
||||
|
||||
// SignAccessToken creates an access token for the given subject
|
||||
func (s *JWTService) SignAccessToken(subject, clientID, email string, customClaims []map[string]interface{}) (string, error) {
|
||||
claims := map[string]interface{}{
|
||||
"sub": subject,
|
||||
"aud": []string{s.audience},
|
||||
"azp": clientID,
|
||||
}
|
||||
|
||||
// Add custom claims
|
||||
for _, cc := range customClaims {
|
||||
for k, v := range cc {
|
||||
claims[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// Add email claim
|
||||
claims[s.emailClaim] = email
|
||||
|
||||
return s.SignToken(claims)
|
||||
}
|
||||
|
||||
// SignIDToken creates an ID token for the given subject
|
||||
func (s *JWTService) SignIDToken(subject, clientID, nonce, email, name, givenName, familyName, picture string, customClaims []map[string]interface{}) (string, error) {
|
||||
claims := map[string]interface{}{
|
||||
"sub": subject,
|
||||
"aud": clientID,
|
||||
"azp": clientID,
|
||||
"name": name,
|
||||
"given_name": givenName,
|
||||
"family_name": familyName,
|
||||
"email": email,
|
||||
"picture": picture,
|
||||
}
|
||||
|
||||
if nonce != "" {
|
||||
claims["nonce"] = nonce
|
||||
}
|
||||
|
||||
// Add custom claims
|
||||
for _, cc := range customClaims {
|
||||
for k, v := range cc {
|
||||
claims[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// Add email claim
|
||||
claims[s.emailClaim] = email
|
||||
|
||||
return s.SignToken(claims)
|
||||
}
|
||||
|
||||
// GetJWKS returns the JSON Web Key Set as JSON bytes
|
||||
func (s *JWTService) GetJWKS() ([]byte, error) {
|
||||
return json.Marshal(s.jwkSet)
|
||||
}
|
||||
|
||||
// DecodeToken decodes a JWT without verifying the signature
|
||||
func (s *JWTService) DecodeToken(tokenString string) (map[string]interface{}, error) {
|
||||
// Parse without verification
|
||||
msg, err := jws.Parse([]byte(tokenString))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse token: %w", err)
|
||||
}
|
||||
|
||||
var claims map[string]interface{}
|
||||
if err := json.Unmarshal(msg.Payload(), &claims); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal claims: %w", err)
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// Issuer returns the issuer URL
|
||||
func (s *JWTService) Issuer() string {
|
||||
return s.issuer
|
||||
}
|
||||
|
||||
// Audience returns the audience
|
||||
func (s *JWTService) Audience() string {
|
||||
return s.audience
|
||||
}
|
||||
|
||||
// AdminClaim returns the admin custom claim key
|
||||
func (s *JWTService) AdminClaim() string {
|
||||
return s.adminClaim
|
||||
}
|
||||
|
||||
// EmailClaim returns the email custom claim key
|
||||
func (s *JWTService) EmailClaim() string {
|
||||
return s.emailClaim
|
||||
}
|
||||
@@ -0,0 +1,151 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewJWTService(t *testing.T) {
|
||||
service, err := NewJWTService("https://test.example.com/", "https://audience", "https://admin", "https://email")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create JWT service: %v", err)
|
||||
}
|
||||
|
||||
if service.Issuer() != "https://test.example.com/" {
|
||||
t.Errorf("expected issuer https://test.example.com/, got %s", service.Issuer())
|
||||
}
|
||||
|
||||
if service.Audience() != "https://audience" {
|
||||
t.Errorf("expected audience https://audience, got %s", service.Audience())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignToken(t *testing.T) {
|
||||
service, err := NewJWTService("https://test.example.com/", "https://audience", "https://admin", "https://email")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create JWT service: %v", err)
|
||||
}
|
||||
|
||||
claims := map[string]interface{}{
|
||||
"sub": "test-subject",
|
||||
"aud": "test-audience",
|
||||
}
|
||||
|
||||
token, err := service.SignToken(claims)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to sign token: %v", err)
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
t.Error("expected non-empty token")
|
||||
}
|
||||
|
||||
// Verify token can be decoded
|
||||
decoded, err := service.DecodeToken(token)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to decode token: %v", err)
|
||||
}
|
||||
|
||||
if decoded["sub"] != "test-subject" {
|
||||
t.Errorf("expected sub=test-subject, got %v", decoded["sub"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignAccessToken(t *testing.T) {
|
||||
service, err := NewJWTService("https://test.example.com/", "https://audience", "https://admin", "https://email")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create JWT service: %v", err)
|
||||
}
|
||||
|
||||
customClaims := []map[string]interface{}{
|
||||
{"https://admin": true},
|
||||
}
|
||||
|
||||
token, err := service.SignAccessToken("auth0|user@example.com", "client-id", "user@example.com", customClaims)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to sign access token: %v", err)
|
||||
}
|
||||
|
||||
decoded, err := service.DecodeToken(token)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to decode token: %v", err)
|
||||
}
|
||||
|
||||
if decoded["sub"] != "auth0|user@example.com" {
|
||||
t.Errorf("expected sub=auth0|user@example.com, got %v", decoded["sub"])
|
||||
}
|
||||
|
||||
if decoded["https://email"] != "user@example.com" {
|
||||
t.Errorf("expected email claim, got %v", decoded["https://email"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignIDToken(t *testing.T) {
|
||||
service, err := NewJWTService("https://test.example.com/", "https://audience", "https://admin", "https://email")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create JWT service: %v", err)
|
||||
}
|
||||
|
||||
token, err := service.SignIDToken(
|
||||
"auth0|user@example.com",
|
||||
"client-id",
|
||||
"test-nonce",
|
||||
"user@example.com",
|
||||
"Test User",
|
||||
"Test",
|
||||
"User",
|
||||
"https://example.com/picture.jpg",
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to sign ID token: %v", err)
|
||||
}
|
||||
|
||||
decoded, err := service.DecodeToken(token)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to decode token: %v", err)
|
||||
}
|
||||
|
||||
if decoded["name"] != "Test User" {
|
||||
t.Errorf("expected name=Test User, got %v", decoded["name"])
|
||||
}
|
||||
|
||||
if decoded["nonce"] != "test-nonce" {
|
||||
t.Errorf("expected nonce=test-nonce, got %v", decoded["nonce"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetJWKS(t *testing.T) {
|
||||
service, err := NewJWTService("https://test.example.com/", "https://audience", "https://admin", "https://email")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create JWT service: %v", err)
|
||||
}
|
||||
|
||||
jwks, err := service.GetJWKS()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get JWKS: %v", err)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(jwks, &result); err != nil {
|
||||
t.Fatalf("failed to parse JWKS: %v", err)
|
||||
}
|
||||
|
||||
keys, ok := result["keys"].([]interface{})
|
||||
if !ok {
|
||||
t.Fatal("expected keys array in JWKS")
|
||||
}
|
||||
|
||||
if len(keys) != 1 {
|
||||
t.Errorf("expected 1 key, got %d", len(keys))
|
||||
}
|
||||
|
||||
key := keys[0].(map[string]interface{})
|
||||
if key["kty"] != "RSA" {
|
||||
t.Errorf("expected kty=RSA, got %v", key["kty"])
|
||||
}
|
||||
|
||||
if key["use"] != "sig" {
|
||||
t.Errorf("expected use=sig, got %v", key["use"])
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// PKCEMethod represents the code challenge method
|
||||
type PKCEMethod string
|
||||
|
||||
const (
|
||||
// PKCEMethodPlain uses the verifier directly as the challenge
|
||||
PKCEMethodPlain PKCEMethod = "plain"
|
||||
// PKCEMethodS256 uses SHA256 hash of the verifier
|
||||
PKCEMethodS256 PKCEMethod = "S256"
|
||||
)
|
||||
|
||||
// VerifyPKCE verifies that the code verifier matches the code challenge
|
||||
func VerifyPKCE(verifier, challenge string, method PKCEMethod) bool {
|
||||
if verifier == "" || challenge == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
switch method {
|
||||
case PKCEMethodPlain, "":
|
||||
// Plain method or no method specified - direct comparison
|
||||
return verifier == challenge
|
||||
case PKCEMethodS256:
|
||||
// S256 method - SHA256 hash, base64url encoded
|
||||
computed := ComputeS256Challenge(verifier)
|
||||
return computed == challenge
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// ComputeS256Challenge computes the S256 code challenge from a verifier
|
||||
func ComputeS256Challenge(verifier string) string {
|
||||
hash := sha256.Sum256([]byte(verifier))
|
||||
return base64URLEncode(hash[:])
|
||||
}
|
||||
|
||||
// base64URLEncode encodes bytes to base64url without padding
|
||||
func base64URLEncode(data []byte) string {
|
||||
encoded := base64.URLEncoding.EncodeToString(data)
|
||||
// Remove padding
|
||||
return strings.TrimRight(encoded, "=")
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestVerifyPKCE_Plain(t *testing.T) {
|
||||
verifier := "test-verifier-12345"
|
||||
challenge := "test-verifier-12345"
|
||||
|
||||
if !VerifyPKCE(verifier, challenge, PKCEMethodPlain) {
|
||||
t.Error("expected plain PKCE verification to succeed")
|
||||
}
|
||||
|
||||
if VerifyPKCE("wrong-verifier", challenge, PKCEMethodPlain) {
|
||||
t.Error("expected plain PKCE verification to fail with wrong verifier")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyPKCE_S256(t *testing.T) {
|
||||
// Test vector from RFC 7636
|
||||
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||
challenge := ComputeS256Challenge(verifier)
|
||||
|
||||
if !VerifyPKCE(verifier, challenge, PKCEMethodS256) {
|
||||
t.Error("expected S256 PKCE verification to succeed")
|
||||
}
|
||||
|
||||
if VerifyPKCE("wrong-verifier", challenge, PKCEMethodS256) {
|
||||
t.Error("expected S256 PKCE verification to fail with wrong verifier")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyPKCE_EmptyValues(t *testing.T) {
|
||||
if VerifyPKCE("", "challenge", PKCEMethodS256) {
|
||||
t.Error("expected PKCE verification to fail with empty verifier")
|
||||
}
|
||||
|
||||
if VerifyPKCE("verifier", "", PKCEMethodS256) {
|
||||
t.Error("expected PKCE verification to fail with empty challenge")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyPKCE_DefaultMethod(t *testing.T) {
|
||||
verifier := "test-verifier"
|
||||
challenge := "test-verifier"
|
||||
|
||||
// Empty method should default to plain
|
||||
if !VerifyPKCE(verifier, challenge, "") {
|
||||
t.Error("expected PKCE verification with empty method to use plain")
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeS256Challenge(t *testing.T) {
|
||||
// Known test case
|
||||
verifier := "abc123"
|
||||
challenge := ComputeS256Challenge(verifier)
|
||||
|
||||
// Challenge should be base64url encoded without padding
|
||||
if challenge == "" {
|
||||
t.Error("expected non-empty challenge")
|
||||
}
|
||||
|
||||
// Should not contain padding
|
||||
if len(challenge) > 0 && challenge[len(challenge)-1] == '=' {
|
||||
t.Error("challenge should not have padding")
|
||||
}
|
||||
|
||||
// Same verifier should produce same challenge
|
||||
challenge2 := ComputeS256Challenge(verifier)
|
||||
if challenge != challenge2 {
|
||||
t.Error("same verifier should produce same challenge")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user