feat: migrate auth0mock from Node.js to Go

Refactor the application to a Go-based architecture for improved
performance and maintainability. Replace the Dockerfile to utilize a
multi-stage build process, enhancing image efficiency. Implement
comprehensive session store tests to ensure reliability and create
new OAuth handlers for managing authentication efficiently. Update 
documentation to reflect these structural changes.
This commit is contained in:
2025-12-29 16:30:37 +01:00
parent 96453e1d15
commit 9992fb4ef1
25 changed files with 1976 additions and 1991 deletions
+212
View File
@@ -0,0 +1,212 @@
package auth
import (
"crypto/rand"
"crypto/rsa"
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jws"
"github.com/lestrrat-go/jwx/v2/jwt"
)
const (
// TokenExpiry is the default token expiration time
TokenExpiry = 2 * time.Hour
)
// JWTService handles JWT signing and JWKS generation
type JWTService struct {
privateKey *rsa.PrivateKey
jwkSet jwk.Set
issuer string
audience string
adminClaim string
emailClaim string
}
// NewJWTService creates a new JWT service with a generated RSA key pair
func NewJWTService(issuer, audience, adminClaim, emailClaim string) (*JWTService, error) {
// Generate RSA 2048-bit key pair
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, fmt.Errorf("generate RSA key: %w", err)
}
// Create JWK from private key
key, err := jwk.FromRaw(privateKey)
if err != nil {
return nil, fmt.Errorf("create JWK from private key: %w", err)
}
// Set key metadata
keyID := uuid.New().String()
if err := key.Set(jwk.KeyIDKey, keyID); err != nil {
return nil, fmt.Errorf("set key ID: %w", err)
}
if err := key.Set(jwk.AlgorithmKey, jwa.RS256); err != nil {
return nil, fmt.Errorf("set algorithm: %w", err)
}
if err := key.Set(jwk.KeyUsageKey, "sig"); err != nil {
return nil, fmt.Errorf("set key usage: %w", err)
}
// Create public key for JWKS
publicKey, err := key.PublicKey()
if err != nil {
return nil, fmt.Errorf("get public key: %w", err)
}
// Create JWKS with public key
jwkSet := jwk.NewSet()
if err := jwkSet.AddKey(publicKey); err != nil {
return nil, fmt.Errorf("add key to set: %w", err)
}
return &JWTService{
privateKey: privateKey,
jwkSet: jwkSet,
issuer: issuer,
audience: audience,
adminClaim: adminClaim,
emailClaim: emailClaim,
}, nil
}
// SignToken creates a signed JWT with the given claims
func (s *JWTService) SignToken(claims map[string]interface{}) (string, error) {
// Build JWT token
builder := jwt.NewBuilder()
now := time.Now()
builder.Issuer(s.issuer)
builder.IssuedAt(now)
builder.Expiration(now.Add(TokenExpiry))
// Add all claims
for key, value := range claims {
builder.Claim(key, value)
}
token, err := builder.Build()
if err != nil {
return "", fmt.Errorf("build token: %w", err)
}
// Create JWK from private key for signing
key, err := jwk.FromRaw(s.privateKey)
if err != nil {
return "", fmt.Errorf("create signing key: %w", err)
}
// Get key ID from JWKS
pubKey, _ := s.jwkSet.Key(0)
keyID := pubKey.KeyID()
if err := key.Set(jwk.KeyIDKey, keyID); err != nil {
return "", fmt.Errorf("set key ID: %w", err)
}
// Sign the token
signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256, key))
if err != nil {
return "", fmt.Errorf("sign token: %w", err)
}
return string(signed), nil
}
// SignAccessToken creates an access token for the given subject
func (s *JWTService) SignAccessToken(subject, clientID, email string, customClaims []map[string]interface{}) (string, error) {
claims := map[string]interface{}{
"sub": subject,
"aud": []string{s.audience},
"azp": clientID,
}
// Add custom claims
for _, cc := range customClaims {
for k, v := range cc {
claims[k] = v
}
}
// Add email claim
claims[s.emailClaim] = email
return s.SignToken(claims)
}
// SignIDToken creates an ID token for the given subject
func (s *JWTService) SignIDToken(subject, clientID, nonce, email, name, givenName, familyName, picture string, customClaims []map[string]interface{}) (string, error) {
claims := map[string]interface{}{
"sub": subject,
"aud": clientID,
"azp": clientID,
"name": name,
"given_name": givenName,
"family_name": familyName,
"email": email,
"picture": picture,
}
if nonce != "" {
claims["nonce"] = nonce
}
// Add custom claims
for _, cc := range customClaims {
for k, v := range cc {
claims[k] = v
}
}
// Add email claim
claims[s.emailClaim] = email
return s.SignToken(claims)
}
// GetJWKS returns the JSON Web Key Set as JSON bytes
func (s *JWTService) GetJWKS() ([]byte, error) {
return json.Marshal(s.jwkSet)
}
// DecodeToken decodes a JWT without verifying the signature
func (s *JWTService) DecodeToken(tokenString string) (map[string]interface{}, error) {
// Parse without verification
msg, err := jws.Parse([]byte(tokenString))
if err != nil {
return nil, fmt.Errorf("parse token: %w", err)
}
var claims map[string]interface{}
if err := json.Unmarshal(msg.Payload(), &claims); err != nil {
return nil, fmt.Errorf("unmarshal claims: %w", err)
}
return claims, nil
}
// Issuer returns the issuer URL
func (s *JWTService) Issuer() string {
return s.issuer
}
// Audience returns the audience
func (s *JWTService) Audience() string {
return s.audience
}
// AdminClaim returns the admin custom claim key
func (s *JWTService) AdminClaim() string {
return s.adminClaim
}
// EmailClaim returns the email custom claim key
func (s *JWTService) EmailClaim() string {
return s.emailClaim
}
+151
View File
@@ -0,0 +1,151 @@
package auth
import (
"encoding/json"
"testing"
)
func TestNewJWTService(t *testing.T) {
service, err := NewJWTService("https://test.example.com/", "https://audience", "https://admin", "https://email")
if err != nil {
t.Fatalf("failed to create JWT service: %v", err)
}
if service.Issuer() != "https://test.example.com/" {
t.Errorf("expected issuer https://test.example.com/, got %s", service.Issuer())
}
if service.Audience() != "https://audience" {
t.Errorf("expected audience https://audience, got %s", service.Audience())
}
}
func TestSignToken(t *testing.T) {
service, err := NewJWTService("https://test.example.com/", "https://audience", "https://admin", "https://email")
if err != nil {
t.Fatalf("failed to create JWT service: %v", err)
}
claims := map[string]interface{}{
"sub": "test-subject",
"aud": "test-audience",
}
token, err := service.SignToken(claims)
if err != nil {
t.Fatalf("failed to sign token: %v", err)
}
if token == "" {
t.Error("expected non-empty token")
}
// Verify token can be decoded
decoded, err := service.DecodeToken(token)
if err != nil {
t.Fatalf("failed to decode token: %v", err)
}
if decoded["sub"] != "test-subject" {
t.Errorf("expected sub=test-subject, got %v", decoded["sub"])
}
}
func TestSignAccessToken(t *testing.T) {
service, err := NewJWTService("https://test.example.com/", "https://audience", "https://admin", "https://email")
if err != nil {
t.Fatalf("failed to create JWT service: %v", err)
}
customClaims := []map[string]interface{}{
{"https://admin": true},
}
token, err := service.SignAccessToken("auth0|user@example.com", "client-id", "user@example.com", customClaims)
if err != nil {
t.Fatalf("failed to sign access token: %v", err)
}
decoded, err := service.DecodeToken(token)
if err != nil {
t.Fatalf("failed to decode token: %v", err)
}
if decoded["sub"] != "auth0|user@example.com" {
t.Errorf("expected sub=auth0|user@example.com, got %v", decoded["sub"])
}
if decoded["https://email"] != "user@example.com" {
t.Errorf("expected email claim, got %v", decoded["https://email"])
}
}
func TestSignIDToken(t *testing.T) {
service, err := NewJWTService("https://test.example.com/", "https://audience", "https://admin", "https://email")
if err != nil {
t.Fatalf("failed to create JWT service: %v", err)
}
token, err := service.SignIDToken(
"auth0|user@example.com",
"client-id",
"test-nonce",
"user@example.com",
"Test User",
"Test",
"User",
"https://example.com/picture.jpg",
nil,
)
if err != nil {
t.Fatalf("failed to sign ID token: %v", err)
}
decoded, err := service.DecodeToken(token)
if err != nil {
t.Fatalf("failed to decode token: %v", err)
}
if decoded["name"] != "Test User" {
t.Errorf("expected name=Test User, got %v", decoded["name"])
}
if decoded["nonce"] != "test-nonce" {
t.Errorf("expected nonce=test-nonce, got %v", decoded["nonce"])
}
}
func TestGetJWKS(t *testing.T) {
service, err := NewJWTService("https://test.example.com/", "https://audience", "https://admin", "https://email")
if err != nil {
t.Fatalf("failed to create JWT service: %v", err)
}
jwks, err := service.GetJWKS()
if err != nil {
t.Fatalf("failed to get JWKS: %v", err)
}
var result map[string]interface{}
if err := json.Unmarshal(jwks, &result); err != nil {
t.Fatalf("failed to parse JWKS: %v", err)
}
keys, ok := result["keys"].([]interface{})
if !ok {
t.Fatal("expected keys array in JWKS")
}
if len(keys) != 1 {
t.Errorf("expected 1 key, got %d", len(keys))
}
key := keys[0].(map[string]interface{})
if key["kty"] != "RSA" {
t.Errorf("expected kty=RSA, got %v", key["kty"])
}
if key["use"] != "sig" {
t.Errorf("expected use=sig, got %v", key["use"])
}
}
+49
View File
@@ -0,0 +1,49 @@
package auth
import (
"crypto/sha256"
"encoding/base64"
"strings"
)
// PKCEMethod represents the code challenge method
type PKCEMethod string
const (
// PKCEMethodPlain uses the verifier directly as the challenge
PKCEMethodPlain PKCEMethod = "plain"
// PKCEMethodS256 uses SHA256 hash of the verifier
PKCEMethodS256 PKCEMethod = "S256"
)
// VerifyPKCE verifies that the code verifier matches the code challenge
func VerifyPKCE(verifier, challenge string, method PKCEMethod) bool {
if verifier == "" || challenge == "" {
return false
}
switch method {
case PKCEMethodPlain, "":
// Plain method or no method specified - direct comparison
return verifier == challenge
case PKCEMethodS256:
// S256 method - SHA256 hash, base64url encoded
computed := ComputeS256Challenge(verifier)
return computed == challenge
default:
return false
}
}
// ComputeS256Challenge computes the S256 code challenge from a verifier
func ComputeS256Challenge(verifier string) string {
hash := sha256.Sum256([]byte(verifier))
return base64URLEncode(hash[:])
}
// base64URLEncode encodes bytes to base64url without padding
func base64URLEncode(data []byte) string {
encoded := base64.URLEncoding.EncodeToString(data)
// Remove padding
return strings.TrimRight(encoded, "=")
}
+74
View File
@@ -0,0 +1,74 @@
package auth
import (
"testing"
)
func TestVerifyPKCE_Plain(t *testing.T) {
verifier := "test-verifier-12345"
challenge := "test-verifier-12345"
if !VerifyPKCE(verifier, challenge, PKCEMethodPlain) {
t.Error("expected plain PKCE verification to succeed")
}
if VerifyPKCE("wrong-verifier", challenge, PKCEMethodPlain) {
t.Error("expected plain PKCE verification to fail with wrong verifier")
}
}
func TestVerifyPKCE_S256(t *testing.T) {
// Test vector from RFC 7636
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
challenge := ComputeS256Challenge(verifier)
if !VerifyPKCE(verifier, challenge, PKCEMethodS256) {
t.Error("expected S256 PKCE verification to succeed")
}
if VerifyPKCE("wrong-verifier", challenge, PKCEMethodS256) {
t.Error("expected S256 PKCE verification to fail with wrong verifier")
}
}
func TestVerifyPKCE_EmptyValues(t *testing.T) {
if VerifyPKCE("", "challenge", PKCEMethodS256) {
t.Error("expected PKCE verification to fail with empty verifier")
}
if VerifyPKCE("verifier", "", PKCEMethodS256) {
t.Error("expected PKCE verification to fail with empty challenge")
}
}
func TestVerifyPKCE_DefaultMethod(t *testing.T) {
verifier := "test-verifier"
challenge := "test-verifier"
// Empty method should default to plain
if !VerifyPKCE(verifier, challenge, "") {
t.Error("expected PKCE verification with empty method to use plain")
}
}
func TestComputeS256Challenge(t *testing.T) {
// Known test case
verifier := "abc123"
challenge := ComputeS256Challenge(verifier)
// Challenge should be base64url encoded without padding
if challenge == "" {
t.Error("expected non-empty challenge")
}
// Should not contain padding
if len(challenge) > 0 && challenge[len(challenge)-1] == '=' {
t.Error("challenge should not have padding")
}
// Same verifier should produce same challenge
challenge2 := ComputeS256Challenge(verifier)
if challenge != challenge2 {
t.Error("same verifier should produce same challenge")
}
}