Files
schemas/middleware/auth0.go
T

84 lines
1.9 KiB
Go
Raw Normal View History

2023-04-27 07:09:10 +02:00
package middleware
import (
"context"
"fmt"
"log"
"net/url"
2023-04-27 07:09:10 +02:00
jwtmiddleware "github.com/auth0/go-jwt-middleware/v3"
"github.com/auth0/go-jwt-middleware/v3/jwks"
"github.com/auth0/go-jwt-middleware/v3/validator"
2023-04-27 07:09:10 +02:00
)
// CustomClaims holds the application-specific claims decoded from a validated
// JWT. Roles is read from the namespaced "https://unbound.se/roles" claim.
type CustomClaims struct {
	Roles []string `json:"https://unbound.se/roles"`
}

// Validate implements the validator.CustomClaims interface. No extra checks
// are performed on the custom claims, so it always reports success (nil).
func (CustomClaims) Validate(context.Context) error {
	return nil
}
// Auth0 holds the tenant settings needed to build the JWT-validating
// middleware: the tenant domain (used to derive the issuer URL
// https://<domain>/) and the expected token audience.
type Auth0 struct {
	domain   string
	audience string
}

// NewAuth0 returns an Auth0 for the given API audience and tenant domain.
// The domain is expected to be a bare host name, because the issuer is
// formatted as "https://<domain>/". The trailing bool argument is accepted
// but ignored.
func NewAuth0(audience, domain string, _ bool) *Auth0 {
	cfg := new(Auth0)
	cfg.audience = audience
	cfg.domain = domain
	return cfg
}
// Response is a minimal JSON payload carrying a single human-readable
// message, serialized under the "message" key.
type Response struct {
	Message string `json:"message"`
}
func (a *Auth0) Middleware() *jwtmiddleware.JWTMiddleware {
issuer := fmt.Sprintf("https://%s/", a.domain)
2023-04-27 07:09:10 +02:00
issuerURL, err := url.Parse(issuer)
if err != nil {
log.Fatalf("failed to parse issuer URL: %v", err)
2023-04-27 07:09:10 +02:00
}
provider, err := jwks.NewCachingProvider(jwks.WithIssuerURL(issuerURL))
2023-04-27 07:09:10 +02:00
if err != nil {
log.Fatalf("failed to create JWKS provider: %v", err)
2023-04-27 07:09:10 +02:00
}
jwtValidator, err := validator.New(
validator.WithKeyFunc(provider.KeyFunc),
validator.WithAlgorithm(validator.RS256),
validator.WithIssuer(issuer),
validator.WithAudience(a.audience),
validator.WithCustomClaims(func() validator.CustomClaims {
return &CustomClaims{}
}),
)
2023-04-27 07:09:10 +02:00
if err != nil {
log.Fatalf("failed to create JWT validator: %v", err)
2023-04-27 07:09:10 +02:00
}
jwtMiddleware, err := jwtmiddleware.New(
jwtmiddleware.WithValidator(jwtValidator),
jwtmiddleware.WithCredentialsOptional(true),
)
if err != nil {
log.Fatalf("failed to create JWT middleware: %v", err)
2023-04-27 07:09:10 +02:00
}
return jwtMiddleware
2023-04-27 07:09:10 +02:00
}
func ClaimsFromContext(ctx context.Context) *validator.ValidatedClaims {
claims, err := jwtmiddleware.GetClaims[*validator.ValidatedClaims](ctx)
if err != nil {
return nil
2023-04-27 07:09:10 +02:00
}
return claims
2023-04-27 07:09:10 +02:00
}