9992fb4ef1
Refactor the application to a Go-based architecture for improved performance and maintainability. Replace the Dockerfile to utilize a multi-stage build process, enhancing image efficiency. Implement comprehensive session store tests to ensure reliability and create new OAuth handlers for managing authentication efficiently. Update documentation to reflect these structural changes.
213 lines
5.1 KiB
Go
213 lines
5.1 KiB
Go
package auth
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"encoding/json"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/lestrrat-go/jwx/v2/jwa"
|
|
"github.com/lestrrat-go/jwx/v2/jwk"
|
|
"github.com/lestrrat-go/jwx/v2/jws"
|
|
"github.com/lestrrat-go/jwx/v2/jwt"
|
|
)
|
|
|
|
// TokenExpiry is the default lifetime of tokens issued by SignToken: the
// "exp" claim is set to the issue time plus this duration.
const TokenExpiry = 2 * time.Hour
|
|
|
|
// JWTService handles JWT signing and JWKS generation
|
|
type JWTService struct {
|
|
privateKey *rsa.PrivateKey
|
|
jwkSet jwk.Set
|
|
issuer string
|
|
audience string
|
|
adminClaim string
|
|
emailClaim string
|
|
}
|
|
|
|
// NewJWTService creates a new JWT service with a generated RSA key pair
|
|
func NewJWTService(issuer, audience, adminClaim, emailClaim string) (*JWTService, error) {
|
|
// Generate RSA 2048-bit key pair
|
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("generate RSA key: %w", err)
|
|
}
|
|
|
|
// Create JWK from private key
|
|
key, err := jwk.FromRaw(privateKey)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create JWK from private key: %w", err)
|
|
}
|
|
|
|
// Set key metadata
|
|
keyID := uuid.New().String()
|
|
if err := key.Set(jwk.KeyIDKey, keyID); err != nil {
|
|
return nil, fmt.Errorf("set key ID: %w", err)
|
|
}
|
|
if err := key.Set(jwk.AlgorithmKey, jwa.RS256); err != nil {
|
|
return nil, fmt.Errorf("set algorithm: %w", err)
|
|
}
|
|
if err := key.Set(jwk.KeyUsageKey, "sig"); err != nil {
|
|
return nil, fmt.Errorf("set key usage: %w", err)
|
|
}
|
|
|
|
// Create public key for JWKS
|
|
publicKey, err := key.PublicKey()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get public key: %w", err)
|
|
}
|
|
|
|
// Create JWKS with public key
|
|
jwkSet := jwk.NewSet()
|
|
if err := jwkSet.AddKey(publicKey); err != nil {
|
|
return nil, fmt.Errorf("add key to set: %w", err)
|
|
}
|
|
|
|
return &JWTService{
|
|
privateKey: privateKey,
|
|
jwkSet: jwkSet,
|
|
issuer: issuer,
|
|
audience: audience,
|
|
adminClaim: adminClaim,
|
|
emailClaim: emailClaim,
|
|
}, nil
|
|
}
|
|
|
|
// SignToken creates a signed JWT with the given claims
|
|
func (s *JWTService) SignToken(claims map[string]interface{}) (string, error) {
|
|
// Build JWT token
|
|
builder := jwt.NewBuilder()
|
|
|
|
now := time.Now()
|
|
builder.Issuer(s.issuer)
|
|
builder.IssuedAt(now)
|
|
builder.Expiration(now.Add(TokenExpiry))
|
|
|
|
// Add all claims
|
|
for key, value := range claims {
|
|
builder.Claim(key, value)
|
|
}
|
|
|
|
token, err := builder.Build()
|
|
if err != nil {
|
|
return "", fmt.Errorf("build token: %w", err)
|
|
}
|
|
|
|
// Create JWK from private key for signing
|
|
key, err := jwk.FromRaw(s.privateKey)
|
|
if err != nil {
|
|
return "", fmt.Errorf("create signing key: %w", err)
|
|
}
|
|
|
|
// Get key ID from JWKS
|
|
pubKey, _ := s.jwkSet.Key(0)
|
|
keyID := pubKey.KeyID()
|
|
if err := key.Set(jwk.KeyIDKey, keyID); err != nil {
|
|
return "", fmt.Errorf("set key ID: %w", err)
|
|
}
|
|
|
|
// Sign the token
|
|
signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256, key))
|
|
if err != nil {
|
|
return "", fmt.Errorf("sign token: %w", err)
|
|
}
|
|
|
|
return string(signed), nil
|
|
}
|
|
|
|
// SignAccessToken creates an access token for the given subject
|
|
func (s *JWTService) SignAccessToken(subject, clientID, email string, customClaims []map[string]interface{}) (string, error) {
|
|
claims := map[string]interface{}{
|
|
"sub": subject,
|
|
"aud": []string{s.audience},
|
|
"azp": clientID,
|
|
}
|
|
|
|
// Add custom claims
|
|
for _, cc := range customClaims {
|
|
for k, v := range cc {
|
|
claims[k] = v
|
|
}
|
|
}
|
|
|
|
// Add email claim
|
|
claims[s.emailClaim] = email
|
|
|
|
return s.SignToken(claims)
|
|
}
|
|
|
|
// SignIDToken creates an ID token for the given subject
|
|
func (s *JWTService) SignIDToken(subject, clientID, nonce, email, name, givenName, familyName, picture string, customClaims []map[string]interface{}) (string, error) {
|
|
claims := map[string]interface{}{
|
|
"sub": subject,
|
|
"aud": clientID,
|
|
"azp": clientID,
|
|
"name": name,
|
|
"given_name": givenName,
|
|
"family_name": familyName,
|
|
"email": email,
|
|
"picture": picture,
|
|
}
|
|
|
|
if nonce != "" {
|
|
claims["nonce"] = nonce
|
|
}
|
|
|
|
// Add custom claims
|
|
for _, cc := range customClaims {
|
|
for k, v := range cc {
|
|
claims[k] = v
|
|
}
|
|
}
|
|
|
|
// Add email claim
|
|
claims[s.emailClaim] = email
|
|
|
|
return s.SignToken(claims)
|
|
}
|
|
|
|
// GetJWKS returns the JSON Web Key Set as JSON bytes
|
|
func (s *JWTService) GetJWKS() ([]byte, error) {
|
|
return json.Marshal(s.jwkSet)
|
|
}
|
|
|
|
// DecodeToken decodes a JWT without verifying the signature
|
|
func (s *JWTService) DecodeToken(tokenString string) (map[string]interface{}, error) {
|
|
// Parse without verification
|
|
msg, err := jws.Parse([]byte(tokenString))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse token: %w", err)
|
|
}
|
|
|
|
var claims map[string]interface{}
|
|
if err := json.Unmarshal(msg.Payload(), &claims); err != nil {
|
|
return nil, fmt.Errorf("unmarshal claims: %w", err)
|
|
}
|
|
|
|
return claims, nil
|
|
}
|
|
|
|
// Issuer returns the issuer URL
|
|
func (s *JWTService) Issuer() string {
|
|
return s.issuer
|
|
}
|
|
|
|
// Audience returns the audience
|
|
func (s *JWTService) Audience() string {
|
|
return s.audience
|
|
}
|
|
|
|
// AdminClaim returns the admin custom claim key
|
|
func (s *JWTService) AdminClaim() string {
|
|
return s.adminClaim
|
|
}
|
|
|
|
// EmailClaim returns the email custom claim key
|
|
func (s *JWTService) EmailClaim() string {
|
|
return s.emailClaim
|
|
}
|