190 lines
4.1 KiB
Go
190 lines
4.1 KiB
Go
|
|
package middleware
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"crypto/tls"
|
||
|
|
"encoding/json"
|
||
|
|
"fmt"
|
||
|
|
"net/http"
|
||
|
|
"strings"
|
||
|
|
"sync"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
mw "github.com/auth0/go-jwt-middleware/v2"
|
||
|
|
"github.com/golang-jwt/jwt/v4"
|
||
|
|
"github.com/pkg/errors"
|
||
|
|
)
|
||
|
|
|
||
|
|
type Auth0 struct {
|
||
|
|
domain string
|
||
|
|
audience string
|
||
|
|
client *http.Client
|
||
|
|
cache JwksCache
|
||
|
|
}
|
||
|
|
|
||
|
|
// jwksFetchTimeout bounds a single JWKS download so that a slow or
// unreachable Auth0 endpoint cannot stall token validation indefinitely.
const jwksFetchTimeout = 10 * time.Second

// NewAuth0 returns an Auth0 validator for tokens issued by the tenant at
// domain for the given audience.
//
// domain is the tenant host name without scheme; the issuer and JWKS URLs are
// built from it. When strictSsl is false, TLS certificate verification is
// disabled for JWKS downloads. Only do that against endpoints you control
// (e.g. local development with self-signed certificates).
func NewAuth0(audience, domain string, strictSsl bool) *Auth0 {
	// Copy the default transport so proxy, dialer and connection-pool settings
	// are kept. Fall back to a fresh transport instead of panicking if
	// http.DefaultTransport has been replaced with a different RoundTripper.
	var transport *http.Transport
	if dt, ok := http.DefaultTransport.(*http.Transport); ok {
		transport = dt.Clone()
	} else {
		transport = &http.Transport{Proxy: http.ProxyFromEnvironment}
	}
	transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: !strictSsl}

	return &Auth0{
		domain:   domain,
		audience: audience,
		client: &http.Client{
			Transport: transport,
			Timeout:   jwksFetchTimeout,
		},
		cache: JwksCache{
			RWMutex: &sync.RWMutex{},
			cache:   make(map[string]cacheItem),
		},
	}
}
|
||
|
|
|
||
|
|
// Response is a minimal JSON payload carrying a single human-readable
// message. It serializes as {"message":"..."}.
type Response struct {
	Message string `json:"message"` // text shown to the client
}
|
||
|
|
|
||
|
|
// Jwks is a JSON Web Key Set document, as served by the tenant at
// https://<domain>/.well-known/jwks.json.
type Jwks struct {
	Keys []JSONWebKeys `json:"keys"`
}

// JSONWebKeys is one key entry of a JWKS document. Members not listed here
// are ignored when decoding.
type JSONWebKeys struct {
	// Key type and key ID; Kid is matched against a token's "kid" header.
	Kty string `json:"kty"`
	Kid string `json:"kid"`
	// Intended key use.
	Use string `json:"use"`
	// RSA public-key parameters as published in the JWK ("n", "e").
	N string `json:"n"`
	E string `json:"e"`
	// Certificate chain; the first element is wrapped in PEM armor to
	// obtain the verification key.
	X5c []string `json:"x5c"`
}
|
||
|
|
|
||
|
|
// ValidationKeyGetter returns a jwt.Parse key callback for this tenant.
//
// The callback checks three things before returning a key:
//   - the token uses an RSA signing method;
//   - its "aud" claim contains a.audience;
//   - its "iss" claim equals "https://<domain>/".
//
// Both claims are required, so tokens missing them are rejected. If all
// checks pass, it returns the RSA public key whose JWKS entry matches the
// token's "kid" header. Every failure is reported as an error, which makes
// jwt.Parse reject the token. The callback never panics.
func (a *Auth0) ValidationKeyGetter() func(token *jwt.Token) (interface{}, error) {
	issuer := fmt.Sprintf("https://%s/", a.domain)

	return func(token *jwt.Token) (interface{}, error) {
		// The JWKS only yields RSA keys. Refuse any other algorithm before
		// handing out a key (guards against algorithm-confusion attacks).
		if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}

		claims, ok := token.Claims.(jwt.MapClaims)
		if !ok {
			return nil, fmt.Errorf("unexpected claims type %T", token.Claims)
		}

		// Verify 'aud' claim (must be present and contain our audience).
		if !claims.VerifyAudience(a.audience, true) {
			return nil, errors.New("Invalid audience.")
		}

		// Verify 'iss' claim (must be present and name this tenant).
		if !claims.VerifyIssuer(issuer, true) {
			return nil, errors.New("Invalid issuer.")
		}

		cert, err := a.getPemCert(token)
		if err != nil {
			return nil, fmt.Errorf("fetching signing certificate: %w", err)
		}

		// A parse failure must be surfaced. Returning a nil *rsa.PublicKey
		// would make signature verification dereference nil.
		key, err := jwt.ParseRSAPublicKeyFromPEM([]byte(cert))
		if err != nil {
			return nil, fmt.Errorf("parsing signing certificate: %w", err)
		}
		return key, nil
	}
}
|
||
|
|
|
||
|
|
func (a *Auth0) Middleware() *mw.JWTMiddleware {
|
||
|
|
jwtMiddleware := mw.New(func(ctx context.Context, token string) (interface{}, error) {
|
||
|
|
jwtToken, err := jwt.Parse(token, a.ValidationKeyGetter())
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
if _, ok := jwtToken.Method.(*jwt.SigningMethodRSA); !ok {
|
||
|
|
return nil, fmt.Errorf("unexpected signing method: %v", jwtToken.Header["alg"])
|
||
|
|
}
|
||
|
|
err = jwtToken.Claims.Valid()
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
return jwtToken, nil
|
||
|
|
},
|
||
|
|
mw.WithTokenExtractor(func(r *http.Request) (string, error) {
|
||
|
|
token := r.Header.Get("Authorization")
|
||
|
|
if strings.HasPrefix(token, "Bearer ") {
|
||
|
|
return token[7:], nil
|
||
|
|
}
|
||
|
|
return "", nil
|
||
|
|
}),
|
||
|
|
mw.WithCredentialsOptional(true),
|
||
|
|
)
|
||
|
|
|
||
|
|
return jwtMiddleware
|
||
|
|
}
|
||
|
|
|
||
|
|
// TokenFromContext returns the *jwt.Token stored in ctx under the
// go-jwt-middleware context key.
//
// It returns (nil, nil) when ctx holds nothing under that key. It returns an
// error when the stored value is not a *jwt.Token.
func TokenFromContext(ctx context.Context) (*jwt.Token, error) {
	stored := ctx.Value(mw.ContextKey{})
	if stored == nil {
		return nil, nil
	}

	token, isToken := stored.(*jwt.Token)
	if !isToken {
		return nil, fmt.Errorf("token is in wrong format")
	}
	return token, nil
}
|
||
|
|
|
||
|
|
// cacheGetWellknown returns the JWKS document served at url.
//
// A fresh cached copy is returned when available. Otherwise the document is
// downloaded with a.client, decoded, and cached. Only a 200 OK response that
// decodes cleanly is cached. On any error the returned *Jwks is nil.
func (a *Auth0) cacheGetWellknown(url string) (*Jwks, error) {
	if cached := a.cache.get(url); cached != nil {
		return cached, nil
	}

	resp, err := a.client.Get(url)
	if err != nil {
		return nil, err
	}
	defer func() {
		_ = resp.Body.Close()
	}()

	// Without this check, an error page can decode into an empty key set.
	// That set would then be cached, breaking every validation until the
	// entry expires.
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("fetching JWKS from %s: unexpected status %s", url, resp.Status)
	}

	jwks := &Jwks{}
	if err := json.NewDecoder(resp.Body).Decode(jwks); err != nil {
		return nil, err
	}
	a.cache.put(url, jwks)
	return jwks, nil
}
|
||
|
|
|
||
|
|
func (a *Auth0) getPemCert(token *jwt.Token) (string, error) {
|
||
|
|
jwks, err := a.cacheGetWellknown(fmt.Sprintf("https://%s/.well-known/jwks.json", a.domain))
|
||
|
|
if err != nil {
|
||
|
|
return "", err
|
||
|
|
}
|
||
|
|
var cert string
|
||
|
|
for k := range jwks.Keys {
|
||
|
|
if token.Header["kid"] == jwks.Keys[k].Kid {
|
||
|
|
cert = "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----"
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
if cert == "" {
|
||
|
|
err := errors.New("Unable to find appropriate key.")
|
||
|
|
return cert, err
|
||
|
|
}
|
||
|
|
|
||
|
|
return cert, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
type JwksCache struct {
|
||
|
|
*sync.RWMutex
|
||
|
|
cache map[string]cacheItem
|
||
|
|
}
|
||
|
|
type cacheItem struct {
|
||
|
|
data *Jwks
|
||
|
|
expiration time.Time
|
||
|
|
}
|
||
|
|
|
||
|
|
// get returns the cached JWKS for url. It returns nil when there is no entry
// or the entry has expired. Only the read lock is held.
func (c *JwksCache) get(url string) *Jwks {
	c.RLock()
	defer c.RUnlock()

	entry, present := c.cache[url]
	if !present || time.Now().After(entry.expiration) {
		return nil
	}
	return entry.data
}
|
||
|
|
|
||
|
|
func (c *JwksCache) put(url string, jwks *Jwks) {
|
||
|
|
c.Lock()
|
||
|
|
defer c.Unlock()
|
||
|
|
c.cache[url] = cacheItem{
|
||
|
|
data: jwks,
|
||
|
|
expiration: time.Now().Add(time.Minute * 60),
|
||
|
|
}
|
||
|
|
}
|