Files
schemas/middleware/auth0.go
T

190 lines
4.1 KiB
Go
Raw Normal View History

2023-04-27 07:09:10 +02:00
package middleware
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"net/http"
"strings"
"sync"
"time"
mw "github.com/auth0/go-jwt-middleware/v2"
"github.com/golang-jwt/jwt/v4"
"github.com/pkg/errors"
)
// Auth0 bundles the settings and shared state used to validate JWTs against
// an Auth0 tenant.
type Auth0 struct {
	// domain is the tenant host name; the expected issuer and the JWKS URL are
	// both derived from it. audience is the value the token's "aud" claim is
	// verified against.
	domain, audience string

	// client performs the HTTP request that downloads the JWKS document.
	client *http.Client

	// cache keeps downloaded JWKS documents, keyed by the URL they came from.
	cache JwksCache
}
// NewAuth0 builds an Auth0 validator for the given API audience and tenant
// domain.
//
// When strictSsl is false, TLS certificate verification is disabled for the
// HTTP client used to download the JWKS document. That is insecure and should
// only be used against development endpoints.
//
// The returned value owns an HTTP client with a bounded request timeout and an
// empty JWKS cache.
func NewAuth0(audience, domain string, strictSsl bool) *Auth0 {
	// The JWKS download happens while a token is being validated, so it must
	// not be able to block forever on an unresponsive identity provider.
	const requestTimeout = 10 * time.Second

	// http.DefaultTransport is a replaceable package variable. Fall back to a
	// fresh transport instead of panicking when it is not a *http.Transport.
	var transport *http.Transport
	if base, ok := http.DefaultTransport.(*http.Transport); ok {
		transport = base.Clone()
	} else {
		transport = &http.Transport{Proxy: http.ProxyFromEnvironment}
	}
	transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: !strictSsl}

	return &Auth0{
		domain:   domain,
		audience: audience,
		client: &http.Client{
			Transport: transport,
			Timeout:   requestTimeout,
		},
		cache: JwksCache{
			RWMutex: &sync.RWMutex{},
			cache:   make(map[string]cacheItem),
		},
	}
}
// Response is a JSON payload carrying a single human-readable message.
type Response struct {
	// Message is encoded under the "message" key.
	Message string `json:"message"`
}
// Jwks is the decoded form of a JSON Web Key Set document.
type Jwks struct {
	// Keys holds the entries of the document's "keys" array.
	Keys []JSONWebKeys `json:"keys"`
}
// JSONWebKeys is a single entry of a JSON Web Key Set. Each field is decoded
// from the JSON member named in its tag.
type JSONWebKeys struct {
	Kty string `json:"kty"`
	// Kid is compared against the token's "kid" header to select a key.
	Kid string `json:"kid"`
	Use string `json:"use"`
	N   string `json:"n"`
	E   string `json:"e"`
	// X5c is a list of certificate strings; the first entry is the one that
	// gets wrapped into a PEM block for signature verification.
	X5c []string `json:"x5c"`
}
// ValidationKeyGetter returns a key function suitable for jwt.Parse.
//
// The returned function checks the token's "aud" claim against the configured
// audience and its "iss" claim against "https://<domain>/", then looks up the
// signing certificate matching the token's "kid" header and returns the RSA
// public key parsed from it.
//
// Every failure is reported through the error result; the function never
// panics, so a JWKS outage or an unknown key ID rejects the token instead of
// crashing the request handler.
func (a *Auth0) ValidationKeyGetter() func(token *jwt.Token) (interface{}, error) {
	issuer := fmt.Sprintf("https://%s/", a.domain)

	return func(token *jwt.Token) (interface{}, error) {
		claims, ok := token.Claims.(jwt.MapClaims)
		if !ok {
			return nil, errors.New("unexpected claims type")
		}

		// Both claims are verified as required: with required=false a token
		// that simply omits the claim would be accepted.
		if !claims.VerifyAudience(a.audience, true) {
			return nil, errors.New("Invalid audience.")
		}
		if !claims.VerifyIssuer(issuer, true) {
			return nil, errors.New("Invalid issuer.")
		}

		cert, err := a.getPemCert(token)
		if err != nil {
			return nil, err
		}

		key, err := jwt.ParseRSAPublicKeyFromPEM([]byte(cert))
		if err != nil {
			return nil, fmt.Errorf("parsing signing certificate: %w", err)
		}
		return key, nil
	}
}
// Middleware assembles the JWT middleware for this Auth0 configuration.
//
// Tokens are read from an "Authorization: Bearer <token>" header. A request
// without such a header yields an empty token, and the middleware is created
// with mw.WithCredentialsOptional(true). A token that is present must parse,
// be signed with an RSA method, and carry valid claims.
func (a *Auth0) Middleware() *mw.JWTMiddleware {
	// validateToken parses the raw token and returns the *jwt.Token on success.
	validateToken := func(_ context.Context, raw string) (interface{}, error) {
		parsed, err := jwt.Parse(raw, a.ValidationKeyGetter())
		if err != nil {
			return nil, err
		}
		if _, isRSA := parsed.Method.(*jwt.SigningMethodRSA); !isRSA {
			return nil, fmt.Errorf("unexpected signing method: %v", parsed.Header["alg"])
		}
		if err := parsed.Claims.Valid(); err != nil {
			return nil, err
		}
		return parsed, nil
	}

	// bearerToken strips the "Bearer " scheme from the Authorization header.
	bearerToken := func(r *http.Request) (string, error) {
		const scheme = "Bearer "
		header := r.Header.Get("Authorization")
		if !strings.HasPrefix(header, scheme) {
			return "", nil
		}
		return header[len(scheme):], nil
	}

	return mw.New(
		validateToken,
		mw.WithTokenExtractor(bearerToken),
		mw.WithCredentialsOptional(true),
	)
}
// TokenFromContext returns the token stored in ctx under mw.ContextKey{}.
//
// It returns (nil, nil) when nothing is stored under that key, and an error
// when the stored value is not a *jwt.Token.
func TokenFromContext(ctx context.Context) (*jwt.Token, error) {
	switch stored := ctx.Value(mw.ContextKey{}).(type) {
	case nil:
		// No value under the key: not an error, there is simply no token.
		return nil, nil
	case *jwt.Token:
		return stored, nil
	default:
		return nil, fmt.Errorf("token is in wrong format")
	}
}
// cacheGetWellknown returns the JWKS document published at url.
//
// A cached, unexpired copy is returned when available. Otherwise the document
// is downloaded with the Auth0 HTTP client, decoded, stored in the cache and
// returned.
//
// Only a 200 response that decodes successfully is cached. On any error the
// returned *Jwks is nil.
func (a *Auth0) cacheGetWellknown(url string) (*Jwks, error) {
	if cached := a.cache.get(url); cached != nil {
		return cached, nil
	}

	resp, err := a.client.Get(url)
	if err != nil {
		return nil, fmt.Errorf("fetching JWKS from %s: %w", url, err)
	}
	defer func() {
		// The body has already been consumed or abandoned by this point; a
		// close error carries no actionable information.
		_ = resp.Body.Close()
	}()

	// An error response may still carry a JSON body. Decoding it would produce
	// an empty key set that would then be cached and reject every token until
	// the entry expires.
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("fetching JWKS from %s: unexpected status %s", url, resp.Status)
	}

	jwks := &Jwks{}
	if err := json.NewDecoder(resp.Body).Decode(jwks); err != nil {
		return nil, fmt.Errorf("decoding JWKS from %s: %w", url, err)
	}

	a.cache.put(url, jwks)
	return jwks, nil
}
// getPemCert returns the PEM-encoded certificate for the key that signed
// token.
//
// The JWKS document is loaded from "https://<domain>/.well-known/jwks.json"
// (through the cache) and searched for an entry whose Kid equals the token's
// "kid" header. The first x5c entry of that key is wrapped in PEM certificate
// markers.
//
// An error is returned when the JWKS cannot be loaded or when no matching key
// with at least one x5c entry exists.
func (a *Auth0) getPemCert(token *jwt.Token) (string, error) {
	jwks, err := a.cacheGetWellknown(fmt.Sprintf("https://%s/.well-known/jwks.json", a.domain))
	if err != nil {
		return "", err
	}

	var cert string
	for i := range jwks.Keys {
		key := &jwks.Keys[i]
		if token.Header["kid"] != key.Kid {
			continue
		}
		// A matching key without certificates cannot be turned into a PEM
		// block; skipping it avoids indexing into an empty slice.
		if len(key.X5c) == 0 {
			continue
		}
		cert = "-----BEGIN CERTIFICATE-----\n" + key.X5c[0] + "\n-----END CERTIFICATE-----"
	}

	if cert == "" {
		return "", errors.New("Unable to find appropriate key.")
	}
	return cert, nil
}
// JwksCache is a mutex-guarded, in-memory store of JWKS documents keyed by the
// URL they were downloaded from.
type JwksCache struct {
	// The embedded lock guards cache; it is a pointer and must be set before
	// the cache is used.
	*sync.RWMutex

	// cache maps a JWKS URL to its stored document and expiry time.
	cache map[string]cacheItem
}
// cacheItem is one JwksCache entry: a JWKS document plus the instant after
// which it is no longer served.
type cacheItem struct {
	// data is the stored JWKS document.
	data *Jwks
	// expiration is the time after which lookups treat the entry as absent.
	expiration time.Time
}
// get returns the JWKS stored for url, or nil when there is no entry or the
// entry has passed its expiration time. The lookup runs under the read lock.
func (c *JwksCache) get(url string) *Jwks {
	c.RLock()
	defer c.RUnlock()

	entry, found := c.cache[url]
	if !found || time.Now().After(entry.expiration) {
		return nil
	}
	return entry.data
}
// put stores jwks under url, replacing any previous entry, and marks it to
// expire one hour from now. The write runs under the exclusive lock.
func (c *JwksCache) put(url string, jwks *Jwks) {
	c.Lock()
	defer c.Unlock()

	entry := cacheItem{data: jwks}
	entry.expiration = time.Now().Add(time.Hour)
	c.cache[url] = entry
}