Files
schemas/middleware/auth.go
T

150 lines
3.6 KiB
Go
Raw Normal View History

2023-04-27 07:09:10 +02:00
package middleware
import (
"context"
"fmt"
"net/http"
"github.com/99designs/gqlgen/graphql"
"github.com/golang-jwt/jwt/v5"
2023-04-27 07:09:10 +02:00
"gitlab.com/unboundsoftware/schemas/domain"
)
// Context keys under which authentication data is stored on the request
// context by AuthMiddleware.Handler.
const (
	// UserKey is the context key holding the "sub" claim of the request's JWT.
	UserKey ContextKey = "user"
	// OrganizationKey is the context key holding the domain.Organization
	// resolved from the request's API key.
	OrganizationKey ContextKey = "organization"
)
// Cache is the lookup dependency AuthMiddleware uses to resolve an API key
// to an organization.
type Cache interface {
// OrganizationByAPIKey returns the organization associated with apiKey.
// AuthMiddleware.Handler treats a nil result as "no organization matched"
// and lets the request continue without one.
OrganizationByAPIKey(apiKey string) *domain.Organization
}
// NewAuth builds an AuthMiddleware that resolves API keys through cache.
func NewAuth(cache Cache) *AuthMiddleware {
	mw := new(AuthMiddleware)
	mw.cache = cache
	return mw
}
// AuthMiddleware provides an HTTP middleware (Handler) that records the
// authenticated user and organization on the request context, and a GraphQL
// directive (Directive) that enforces their presence.
type AuthMiddleware struct {
// cache resolves API keys to organizations.
cache Cache
}
func (m *AuthMiddleware) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
token, err := TokenFromContext(r.Context())
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte("Invalid JWT token format"))
return
}
if token != nil {
ctx = context.WithValue(ctx, UserKey, token.Claims.(jwt.MapClaims)["sub"])
}
apiKey, err := ApiKeyFromContext(r.Context())
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte("Invalid API Key format"))
return
}
// Cache handles hash comparison internally
organization := m.cache.OrganizationByAPIKey(apiKey)
if organization != nil {
2023-04-27 07:09:10 +02:00
ctx = context.WithValue(ctx, OrganizationKey, *organization)
}
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// UserFromContext returns the string stored under UserKey in ctx.
// It returns "" when no value is stored or the stored value is not a string.
func UserFromContext(ctx context.Context) string {
	// The comma-ok assertion yields the zero string for both a nil value and
	// a value of any other type.
	user, _ := ctx.Value(UserKey).(string)
	return user
}
// UserHasRole reports whether the JWT obtained from ctx lists role in its
// "https://unbound.se/roles" claim.
//
// It returns false when the token cannot be obtained, is nil, does not carry
// jwt.MapClaims, or when the roles claim is absent or is not a list. Entries
// in the list that are not strings are ignored.
func UserHasRole(ctx context.Context, role string) bool {
	token, err := TokenFromContext(ctx)
	if token == nil || err != nil {
		return false
	}

	claims, isMap := token.Claims.(jwt.MapClaims)
	if !isMap {
		return false
	}

	// Check the custom roles claim. A missing key yields a nil value, for
	// which the checked assertion fails just like it does for a wrong type.
	granted, isList := claims["https://unbound.se/roles"].([]interface{})
	if !isList {
		return false
	}

	for i := range granted {
		name, isString := granted[i].(string)
		if isString && name == role {
			return true
		}
	}
	return false
}
2023-04-27 07:09:10 +02:00
// OrganizationFromContext returns the ID, rendered as a string, of the
// domain.Organization stored under OrganizationKey in ctx.
// It returns "" when no value is stored or the stored value is not a
// domain.Organization.
func OrganizationFromContext(ctx context.Context) string {
	org, found := ctx.Value(OrganizationKey).(domain.Organization)
	if !found {
		return ""
	}
	return org.ID.String()
}
func (m *AuthMiddleware) Directive(ctx context.Context, _ interface{}, next graphql.Resolver, user *bool, organization *bool) (res interface{}, err error) {
userRequired := user != nil && *user
orgRequired := organization != nil && *organization
u := UserFromContext(ctx)
orgId := OrganizationFromContext(ctx)
fmt.Printf("[Auth Directive] userRequired=%v, orgRequired=%v, hasUser=%v, hasOrg=%v\n",
userRequired, orgRequired, u != "", orgId != "")
// If both are required, it means EITHER is acceptable (OR logic)
if userRequired && orgRequired {
if u == "" && orgId == "" {
fmt.Printf("[Auth Directive] REJECTED: Neither user nor organization available\n")
return nil, fmt.Errorf("authentication required: provide either user token or organization API key")
}
fmt.Printf("[Auth Directive] ACCEPTED: Has user=%v OR organization=%v\n", u != "", orgId != "")
return next(ctx)
}
// Only user required
if userRequired {
if u == "" {
fmt.Printf("[Auth Directive] REJECTED: No user available\n")
2023-04-27 07:09:10 +02:00
return nil, fmt.Errorf("no user available in request")
}
fmt.Printf("[Auth Directive] ACCEPTED: User authenticated\n")
2023-04-27 07:09:10 +02:00
}
// Only organization required
if orgRequired {
if orgId == "" {
fmt.Printf("[Auth Directive] REJECTED: No organization available\n")
2023-04-27 07:09:10 +02:00
return nil, fmt.Errorf("no organization available in request")
}
fmt.Printf("[Auth Directive] ACCEPTED: Organization authenticated\n")
2023-04-27 07:09:10 +02:00
}
2023-04-27 07:09:10 +02:00
return next(ctx)
}