package middleware

import (
	"context"
	"fmt"
	"log"
	"net/url"

	jwtmiddleware "github.com/auth0/go-jwt-middleware/v3"
	"github.com/auth0/go-jwt-middleware/v3/jwks"
	"github.com/auth0/go-jwt-middleware/v3/validator"
)

// CustomClaims contains custom claims from the JWT token.
type CustomClaims struct {
	// Roles is decoded from the namespaced "https://unbound.se/roles" claim.
	Roles []string `json:"https://unbound.se/roles"`
}

// Validate implements the validator.CustomClaims interface. It performs no
// additional checks and always returns nil.
func (c CustomClaims) Validate(_ context.Context) error {
	return nil
}

// Auth0 holds the settings used to build a JWT-validating middleware for an
// Auth0 tenant.
type Auth0 struct {
	// domain is interpolated into the issuer as "https://<domain>/".
	// NOTE(review): presumably a bare host such as "tenant.auth0.com" (no
	// scheme, no trailing slash) — confirm with callers.
	domain string
	// audience is the value passed to validator.WithAudience.
	audience string
}

// NewAuth0 returns an Auth0 configured with the given audience and domain.
// The boolean argument is currently ignored.
func NewAuth0(audience, domain string, _ bool) *Auth0 {
	return &Auth0{
		domain:   domain,
		audience: audience,
	}
}

// Response is a JSON payload carrying a single message field.
type Response struct {
	Message string `json:"message"`
}

// NewMiddleware builds a JWT middleware that accepts RS256 tokens whose
// issuer is "https://<domain>/" and whose audience matches the configured
// audience. Signing keys come from a caching JWKS provider configured with
// that issuer URL, and validated claims carry a *CustomClaims.
//
// Credentials are optional (WithCredentialsOptional(true)), so a request
// without a token is not by itself an error; use ClaimsFromContext to check
// whether a request was authenticated.
//
// Each call creates a fresh JWKS provider and therefore its own key cache,
// so build the middleware once and reuse it.
//
// Any construction error is returned wrapped (%w) with the step that failed.
func (a *Auth0) NewMiddleware() (*jwtmiddleware.JWTMiddleware, error) {
	issuer := fmt.Sprintf("https://%s/", a.domain)

	issuerURL, err := url.Parse(issuer)
	if err != nil {
		return nil, fmt.Errorf("failed to parse issuer URL: %w", err)
	}

	provider, err := jwks.NewCachingProvider(jwks.WithIssuerURL(issuerURL))
	if err != nil {
		return nil, fmt.Errorf("failed to create JWKS provider: %w", err)
	}

	jwtValidator, err := validator.New(
		validator.WithKeyFunc(provider.KeyFunc),
		validator.WithAlgorithm(validator.RS256),
		validator.WithIssuer(issuer),
		validator.WithAudience(a.audience),
		validator.WithCustomClaims(func() validator.CustomClaims {
			return &CustomClaims{}
		}),
	)
	if err != nil {
		return nil, fmt.Errorf("failed to create JWT validator: %w", err)
	}

	mw, err := jwtmiddleware.New(
		jwtmiddleware.WithValidator(jwtValidator),
		jwtmiddleware.WithCredentialsOptional(true),
	)
	if err != nil {
		return nil, fmt.Errorf("failed to create JWT middleware: %w", err)
	}
	return mw, nil
}

// Middleware is like NewMiddleware but terminates the process (log.Fatal)
// if the middleware cannot be constructed. The logged text matches the
// wrapped error returned by NewMiddleware. Prefer NewMiddleware in code that
// can handle the error itself.
func (a *Auth0) Middleware() *jwtmiddleware.JWTMiddleware {
	mw, err := a.NewMiddleware()
	if err != nil {
		log.Fatal(err)
	}
	return mw
}

// ClaimsFromContext returns the validated claims that the JWT middleware
// stored in ctx, or nil if jwtmiddleware.GetClaims reports an error.
// NOTE(review): presumably nil means the request carried no valid token,
// which the middleware allows because credentials are optional — confirm.
func ClaimsFromContext(ctx context.Context) *validator.ValidatedClaims {
	claims, err := jwtmiddleware.GetClaims[*validator.ValidatedClaims](ctx)
	if err != nil {
		return nil
	}
	return claims
}