213 lines
5.0 KiB
Go
213 lines
5.0 KiB
Go
package auth
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"encoding/json"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/lestrrat-go/jwx/v3/jwa"
|
|
"github.com/lestrrat-go/jwx/v3/jwk"
|
|
"github.com/lestrrat-go/jwx/v3/jws"
|
|
"github.com/lestrrat-go/jwx/v3/jwt"
|
|
)
|
|
|
|
// TokenExpiry is the lifetime of every token issued by JWTService: SignToken
// sets the "exp" claim to the issue time plus this duration.
const TokenExpiry = 2 * time.Hour
|
|
|
|
// JWTService handles JWT signing and JWKS generation
|
|
type JWTService struct {
|
|
privateKey *rsa.PrivateKey
|
|
jwkSet jwk.Set
|
|
issuer string
|
|
audience string
|
|
adminClaim string
|
|
emailClaim string
|
|
}
|
|
|
|
// NewJWTService creates a new JWT service with a generated RSA key pair
|
|
func NewJWTService(issuer, audience, adminClaim, emailClaim string) (*JWTService, error) {
|
|
// Generate RSA 2048-bit key pair
|
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("generate RSA key: %w", err)
|
|
}
|
|
|
|
// Create JWK from private key
|
|
key, err := jwk.Import(privateKey)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create JWK from private key: %w", err)
|
|
}
|
|
|
|
// Set key metadata
|
|
keyID := uuid.New().String()
|
|
if err := key.Set(jwk.KeyIDKey, keyID); err != nil {
|
|
return nil, fmt.Errorf("set key ID: %w", err)
|
|
}
|
|
if err := key.Set(jwk.AlgorithmKey, jwa.RS256()); err != nil {
|
|
return nil, fmt.Errorf("set algorithm: %w", err)
|
|
}
|
|
if err := key.Set(jwk.KeyUsageKey, "sig"); err != nil {
|
|
return nil, fmt.Errorf("set key usage: %w", err)
|
|
}
|
|
|
|
// Create public key for JWKS
|
|
publicKey, err := key.PublicKey()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get public key: %w", err)
|
|
}
|
|
|
|
// Create JWKS with public key
|
|
jwkSet := jwk.NewSet()
|
|
if err := jwkSet.AddKey(publicKey); err != nil {
|
|
return nil, fmt.Errorf("add key to set: %w", err)
|
|
}
|
|
|
|
return &JWTService{
|
|
privateKey: privateKey,
|
|
jwkSet: jwkSet,
|
|
issuer: issuer,
|
|
audience: audience,
|
|
adminClaim: adminClaim,
|
|
emailClaim: emailClaim,
|
|
}, nil
|
|
}
|
|
|
|
// SignToken creates a signed JWT with the given claims
|
|
func (s *JWTService) SignToken(claims map[string]interface{}) (string, error) {
|
|
// Build JWT token
|
|
builder := jwt.NewBuilder()
|
|
|
|
now := time.Now()
|
|
builder.Issuer(s.issuer)
|
|
builder.IssuedAt(now)
|
|
builder.Expiration(now.Add(TokenExpiry))
|
|
|
|
// Add all claims
|
|
for key, value := range claims {
|
|
builder.Claim(key, value)
|
|
}
|
|
|
|
token, err := builder.Build()
|
|
if err != nil {
|
|
return "", fmt.Errorf("build token: %w", err)
|
|
}
|
|
|
|
// Create JWK from private key for signing
|
|
key, err := jwk.Import(s.privateKey)
|
|
if err != nil {
|
|
return "", fmt.Errorf("create signing key: %w", err)
|
|
}
|
|
|
|
// Get key ID from JWKS
|
|
pubKey, _ := s.jwkSet.Key(0)
|
|
keyID, _ := pubKey.KeyID()
|
|
if err := key.Set(jwk.KeyIDKey, keyID); err != nil {
|
|
return "", fmt.Errorf("set key ID: %w", err)
|
|
}
|
|
|
|
// Sign the token
|
|
signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), key))
|
|
if err != nil {
|
|
return "", fmt.Errorf("sign token: %w", err)
|
|
}
|
|
|
|
return string(signed), nil
|
|
}
|
|
|
|
// SignAccessToken creates an access token for the given subject
|
|
func (s *JWTService) SignAccessToken(subject, clientID, email string, customClaims []map[string]interface{}) (string, error) {
|
|
claims := map[string]interface{}{
|
|
"sub": subject,
|
|
"aud": []string{s.audience},
|
|
"azp": clientID,
|
|
}
|
|
|
|
// Add custom claims
|
|
for _, cc := range customClaims {
|
|
for k, v := range cc {
|
|
claims[k] = v
|
|
}
|
|
}
|
|
|
|
// Add email claim
|
|
claims[s.emailClaim] = email
|
|
|
|
return s.SignToken(claims)
|
|
}
|
|
|
|
// SignIDToken creates an ID token for the given subject
|
|
func (s *JWTService) SignIDToken(subject, clientID, nonce, email, name, givenName, familyName, picture string, customClaims []map[string]interface{}) (string, error) {
|
|
claims := map[string]interface{}{
|
|
"sub": subject,
|
|
"aud": clientID,
|
|
"azp": clientID,
|
|
"name": name,
|
|
"given_name": givenName,
|
|
"family_name": familyName,
|
|
"email": email,
|
|
"picture": picture,
|
|
}
|
|
|
|
if nonce != "" {
|
|
claims["nonce"] = nonce
|
|
}
|
|
|
|
// Add custom claims
|
|
for _, cc := range customClaims {
|
|
for k, v := range cc {
|
|
claims[k] = v
|
|
}
|
|
}
|
|
|
|
// Add email claim
|
|
claims[s.emailClaim] = email
|
|
|
|
return s.SignToken(claims)
|
|
}
|
|
|
|
// GetJWKS returns the JSON Web Key Set as JSON bytes
|
|
func (s *JWTService) GetJWKS() ([]byte, error) {
|
|
return json.Marshal(s.jwkSet)
|
|
}
|
|
|
|
// DecodeToken decodes a JWT without verifying the signature
|
|
func (s *JWTService) DecodeToken(tokenString string) (map[string]interface{}, error) {
|
|
// Parse without verification
|
|
msg, err := jws.Parse([]byte(tokenString))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse token: %w", err)
|
|
}
|
|
|
|
var claims map[string]interface{}
|
|
if err := json.Unmarshal(msg.Payload(), &claims); err != nil {
|
|
return nil, fmt.Errorf("unmarshal claims: %w", err)
|
|
}
|
|
|
|
return claims, nil
|
|
}
|
|
|
|
// Issuer returns the issuer URL
|
|
func (s *JWTService) Issuer() string {
|
|
return s.issuer
|
|
}
|
|
|
|
// Audience returns the audience
|
|
func (s *JWTService) Audience() string {
|
|
return s.audience
|
|
}
|
|
|
|
// AdminClaim returns the admin custom claim key
|
|
func (s *JWTService) AdminClaim() string {
|
|
return s.adminClaim
|
|
}
|
|
|
|
// EmailClaim returns the email custom claim key
|
|
func (s *JWTService) EmailClaim() string {
|
|
return s.emailClaim
|
|
}
|