feat: organizations and API keys
This commit is contained in:
+5
-25
@@ -4,26 +4,17 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/99designs/gqlgen/graphql"
|
||||
"github.com/apex/log"
|
||||
)
|
||||
|
||||
type ContextKey string
|
||||
|
||||
const (
|
||||
ApiKey = ContextKey("apikey")
|
||||
)
|
||||
|
||||
func NewApiKey(apiKey string, logger log.Interface) *ApiKeyMiddleware {
|
||||
return &ApiKeyMiddleware{
|
||||
apiKey: apiKey,
|
||||
}
|
||||
func NewApiKey() *ApiKeyMiddleware {
|
||||
return &ApiKeyMiddleware{}
|
||||
}
|
||||
|
||||
type ApiKeyMiddleware struct {
|
||||
apiKey string
|
||||
}
|
||||
type ApiKeyMiddleware struct{}
|
||||
|
||||
func (m *ApiKeyMiddleware) Handler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -37,23 +28,12 @@ func (m *ApiKeyMiddleware) Handler(next http.Handler) http.Handler {
|
||||
})
|
||||
}
|
||||
|
||||
func (m *ApiKeyMiddleware) Directive(ctx context.Context, obj interface{}, next graphql.Resolver) (res interface{}, err error) {
|
||||
key, err := m.fromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if key != m.apiKey {
|
||||
return nil, fmt.Errorf("invalid API-key")
|
||||
}
|
||||
return next(ctx)
|
||||
}
|
||||
|
||||
func (m *ApiKeyMiddleware) fromContext(ctx context.Context) (string, error) {
|
||||
func ApiKeyFromContext(ctx context.Context) (string, error) {
|
||||
if value := ctx.Value(ApiKey); value != nil {
|
||||
if u, ok := value.(string); ok {
|
||||
return u, nil
|
||||
}
|
||||
return "", fmt.Errorf("current API-key is in wrong format")
|
||||
}
|
||||
return "", fmt.Errorf("no API-key found")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/99designs/gqlgen/graphql"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
|
||||
"gitlab.com/unboundsoftware/schemas/domain"
|
||||
"gitlab.com/unboundsoftware/schemas/hash"
|
||||
)
|
||||
|
||||
const (
|
||||
UserKey = ContextKey("user")
|
||||
OrganizationKey = ContextKey("organization")
|
||||
)
|
||||
|
||||
type Cache interface {
|
||||
OrganizationByAPIKey(apiKey string) *domain.Organization
|
||||
}
|
||||
|
||||
func NewAuth(cache Cache) *AuthMiddleware {
|
||||
return &AuthMiddleware{
|
||||
cache: cache,
|
||||
}
|
||||
}
|
||||
|
||||
type AuthMiddleware struct {
|
||||
cache Cache
|
||||
}
|
||||
|
||||
func (m *AuthMiddleware) Handler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
token, err := TokenFromContext(r.Context())
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = w.Write([]byte("Invalid JWT token format"))
|
||||
return
|
||||
}
|
||||
if token != nil {
|
||||
ctx = context.WithValue(ctx, UserKey, token.Claims.(jwt.MapClaims)["sub"])
|
||||
}
|
||||
apiKey, err := ApiKeyFromContext(r.Context())
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = w.Write([]byte("Invalid API Key format"))
|
||||
return
|
||||
}
|
||||
if organization := m.cache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||
ctx = context.WithValue(ctx, OrganizationKey, *organization)
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func UserFromContext(ctx context.Context) string {
|
||||
if value := ctx.Value(UserKey); value != nil {
|
||||
if u, ok := value.(string); ok {
|
||||
return u
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func OrganizationFromContext(ctx context.Context) string {
|
||||
if value := ctx.Value(OrganizationKey); value != nil {
|
||||
if u, ok := value.(domain.Organization); ok {
|
||||
return u.ID.String()
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *AuthMiddleware) Directive(ctx context.Context, _ interface{}, next graphql.Resolver, user *bool, organization *bool) (res interface{}, err error) {
|
||||
if user != nil && *user {
|
||||
if u := UserFromContext(ctx); u == "" {
|
||||
return nil, fmt.Errorf("no user available in request")
|
||||
}
|
||||
}
|
||||
if organization != nil && *organization {
|
||||
if orgId := OrganizationFromContext(ctx); orgId == "" {
|
||||
return nil, fmt.Errorf("no organization available in request")
|
||||
}
|
||||
}
|
||||
return next(ctx)
|
||||
}
|
||||
@@ -0,0 +1,189 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
mw "github.com/auth0/go-jwt-middleware/v2"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Auth0 struct {
|
||||
domain string
|
||||
audience string
|
||||
client *http.Client
|
||||
cache JwksCache
|
||||
}
|
||||
|
||||
func NewAuth0(audience, domain string, strictSsl bool) *Auth0 {
|
||||
customTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
customTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: !strictSsl}
|
||||
client := &http.Client{Transport: customTransport}
|
||||
|
||||
return &Auth0{
|
||||
domain: domain,
|
||||
audience: audience,
|
||||
client: client,
|
||||
cache: JwksCache{
|
||||
RWMutex: &sync.RWMutex{},
|
||||
cache: make(map[string]cacheItem),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type Jwks struct {
|
||||
Keys []JSONWebKeys `json:"keys"`
|
||||
}
|
||||
|
||||
type JSONWebKeys struct {
|
||||
Kty string `json:"kty"`
|
||||
Kid string `json:"kid"`
|
||||
Use string `json:"use"`
|
||||
N string `json:"n"`
|
||||
E string `json:"e"`
|
||||
X5c []string `json:"x5c"`
|
||||
}
|
||||
|
||||
func (a *Auth0) ValidationKeyGetter() func(token *jwt.Token) (interface{}, error) {
|
||||
issuer := fmt.Sprintf("https://%s/", a.domain)
|
||||
return func(token *jwt.Token) (interface{}, error) {
|
||||
// Verify 'aud' claim
|
||||
aud := a.audience
|
||||
checkAud := token.Claims.(jwt.MapClaims).VerifyAudience(aud, false)
|
||||
if !checkAud {
|
||||
return token, errors.New("Invalid audience.")
|
||||
}
|
||||
// Verify 'iss' claim
|
||||
iss := issuer
|
||||
checkIss := token.Claims.(jwt.MapClaims).VerifyIssuer(iss, false)
|
||||
if !checkIss {
|
||||
return token, errors.New("Invalid issuer.")
|
||||
}
|
||||
|
||||
cert, err := a.getPemCert(token)
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
|
||||
result, _ := jwt.ParseRSAPublicKeyFromPEM([]byte(cert))
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Auth0) Middleware() *mw.JWTMiddleware {
|
||||
jwtMiddleware := mw.New(func(ctx context.Context, token string) (interface{}, error) {
|
||||
jwtToken, err := jwt.Parse(token, a.ValidationKeyGetter())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, ok := jwtToken.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", jwtToken.Header["alg"])
|
||||
}
|
||||
err = jwtToken.Claims.Valid()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return jwtToken, nil
|
||||
},
|
||||
mw.WithTokenExtractor(func(r *http.Request) (string, error) {
|
||||
token := r.Header.Get("Authorization")
|
||||
if strings.HasPrefix(token, "Bearer ") {
|
||||
return token[7:], nil
|
||||
}
|
||||
return "", nil
|
||||
}),
|
||||
mw.WithCredentialsOptional(true),
|
||||
)
|
||||
|
||||
return jwtMiddleware
|
||||
}
|
||||
|
||||
func TokenFromContext(ctx context.Context) (*jwt.Token, error) {
|
||||
if value := ctx.Value(mw.ContextKey{}); value != nil {
|
||||
if u, ok := value.(*jwt.Token); ok {
|
||||
return u, nil
|
||||
}
|
||||
return nil, fmt.Errorf("token is in wrong format")
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Auth0) cacheGetWellknown(url string) (*Jwks, error) {
|
||||
if value := a.cache.get(url); value != nil {
|
||||
return value, nil
|
||||
}
|
||||
jwks := &Jwks{}
|
||||
resp, err := a.client.Get(url)
|
||||
if err != nil {
|
||||
return jwks, err
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
err = json.NewDecoder(resp.Body).Decode(jwks)
|
||||
if err == nil && jwks != nil {
|
||||
a.cache.put(url, jwks)
|
||||
}
|
||||
return jwks, err
|
||||
}
|
||||
|
||||
func (a *Auth0) getPemCert(token *jwt.Token) (string, error) {
|
||||
jwks, err := a.cacheGetWellknown(fmt.Sprintf("https://%s/.well-known/jwks.json", a.domain))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var cert string
|
||||
for k := range jwks.Keys {
|
||||
if token.Header["kid"] == jwks.Keys[k].Kid {
|
||||
cert = "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----"
|
||||
}
|
||||
}
|
||||
|
||||
if cert == "" {
|
||||
err := errors.New("Unable to find appropriate key.")
|
||||
return cert, err
|
||||
}
|
||||
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
type JwksCache struct {
|
||||
*sync.RWMutex
|
||||
cache map[string]cacheItem
|
||||
}
|
||||
type cacheItem struct {
|
||||
data *Jwks
|
||||
expiration time.Time
|
||||
}
|
||||
|
||||
func (c *JwksCache) get(url string) *Jwks {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
if value, ok := c.cache[url]; ok {
|
||||
if time.Now().After(value.expiration) {
|
||||
return nil
|
||||
}
|
||||
return value.data
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *JwksCache) put(url string, jwks *Jwks) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
c.cache[url] = cacheItem{
|
||||
data: jwks,
|
||||
expiration: time.Now().Add(time.Minute * 60),
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
package middleware
|
||||
|
||||
type ContextKey string
|
||||
Reference in New Issue
Block a user