package middleware import ( "context" "crypto/tls" "encoding/json" "fmt" "net/http" "strings" "sync" "time" mw "github.com/auth0/go-jwt-middleware/v2" "github.com/golang-jwt/jwt/v5" "github.com/pkg/errors" ) type Auth0 struct { domain string audience string client *http.Client cache JwksCache } func NewAuth0(audience, domain string, strictSsl bool) *Auth0 { customTransport := http.DefaultTransport.(*http.Transport).Clone() customTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: !strictSsl} client := &http.Client{Transport: customTransport} return &Auth0{ domain: domain, audience: audience, client: client, cache: JwksCache{ RWMutex: &sync.RWMutex{}, cache: make(map[string]cacheItem), }, } } type Response struct { Message string `json:"message"` } type Jwks struct { Keys []JSONWebKeys `json:"keys"` } type JSONWebKeys struct { Kty string `json:"kty"` Kid string `json:"kid"` Use string `json:"use"` N string `json:"n"` E string `json:"e"` X5c []string `json:"x5c"` } func (a *Auth0) ValidationKeyGetter() func(token *jwt.Token) (interface{}, error) { return func(token *jwt.Token) (interface{}, error) { // Verify 'aud' claim cert, err := a.getPemCert(token) if err != nil { panic(err.Error()) } result, _ := jwt.ParseRSAPublicKeyFromPEM([]byte(cert)) return result, nil } } func (a *Auth0) Middleware() *mw.JWTMiddleware { issuer := fmt.Sprintf("https://%s/", a.domain) jwtMiddleware := mw.New(func(ctx context.Context, token string) (interface{}, error) { jwtToken, err := jwt.Parse(token, a.ValidationKeyGetter(), jwt.WithAudience(a.audience), jwt.WithIssuer(issuer)) if err != nil { return nil, err } if _, ok := jwtToken.Method.(*jwt.SigningMethodRSA); !ok { return nil, fmt.Errorf("unexpected signing method: %v", jwtToken.Header["alg"]) } return jwtToken, nil }, mw.WithTokenExtractor(func(r *http.Request) (string, error) { token := r.Header.Get("Authorization") if strings.HasPrefix(token, "Bearer ") { return token[7:], nil } return "", nil }), mw.WithCredentialsOptional(true), ) return jwtMiddleware } func TokenFromContext(ctx context.Context) (*jwt.Token, error) { if value := ctx.Value(mw.ContextKey{}); value != nil { if u, ok := value.(*jwt.Token); ok { return u, nil } return nil, fmt.Errorf("token is in wrong format") } return nil, nil } func (a *Auth0) cacheGetWellknown(url string) (*Jwks, error) { if value := a.cache.get(url); value != nil { return value, nil } jwks := &Jwks{} resp, err := a.client.Get(url) if err != nil { return jwks, err } defer func() { _ = resp.Body.Close() }() err = json.NewDecoder(resp.Body).Decode(jwks) if err == nil && jwks != nil { a.cache.put(url, jwks) } return jwks, err } func (a *Auth0) getPemCert(token *jwt.Token) (string, error) { jwks, err := a.cacheGetWellknown(fmt.Sprintf("https://%s/.well-known/jwks.json", a.domain)) if err != nil { return "", err } var cert string for k := range jwks.Keys { if token.Header["kid"] == jwks.Keys[k].Kid { cert = "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----" } } if cert == "" { err := errors.New("Unable to find appropriate key.") return cert, err } return cert, nil } type JwksCache struct { *sync.RWMutex cache map[string]cacheItem } type cacheItem struct { data *Jwks expiration time.Time } func (c *JwksCache) get(url string) *Jwks { c.RLock() defer c.RUnlock() if value, ok := c.cache[url]; ok { if time.Now().After(value.expiration) { return nil } return value.data } return nil } func (c *JwksCache) put(url string, jwks *Jwks) { c.Lock() defer c.Unlock() c.cache[url] = cacheItem{ data: jwks, expiration: time.Now().Add(time.Minute * 60), } }