package middleware

import (
	"context"
	"fmt"
	"net/http"

	"github.com/99designs/gqlgen/graphql"
	"github.com/golang-jwt/jwt/v5"

	"gitlab.com/unboundsoftware/schemas/domain"
)

const (
	// UserKey is the context key under which Handler stores the JWT "sub" claim.
	UserKey = ContextKey("user")
	// OrganizationKey is the context key under which Handler stores the
	// domain.Organization resolved from the request's API key.
	OrganizationKey = ContextKey("organization")
)

// Cache resolves organizations from API keys.
type Cache interface {
	// OrganizationByAPIKey returns the organization owning apiKey, or nil when
	// no organization matches.
	OrganizationByAPIKey(apiKey string) *domain.Organization
}

// NewAuth returns an AuthMiddleware that looks organizations up in cache.
func NewAuth(cache Cache) *AuthMiddleware {
	return &AuthMiddleware{
		cache: cache,
	}
}

// AuthMiddleware attaches the authenticated user and organization to the
// request context. It also provides a gqlgen directive that enforces their
// presence.
type AuthMiddleware struct {
	cache Cache
}

// Handler wraps next so that, before next runs, the request context carries:
//   - under UserKey, the "sub" claim of the JWT returned by TokenFromContext
//     (only when a token is present and its claims are jwt.MapClaims);
//   - under OrganizationKey, the domain.Organization (by value) that the cache
//     associates with the API key returned by ApiKeyFromContext, if any.
//
// When either credential cannot be read from the context, it responds with
// 500 and a short plain message, and does not call next.
func (m *AuthMiddleware) Handler(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()

		token, err := TokenFromContext(r.Context())
		if err != nil {
			// NOTE(review): a malformed client token is arguably a 400/401,
			// not a server error; kept as 500 to preserve existing behavior.
			w.WriteHeader(http.StatusInternalServerError)
			_, _ = w.Write([]byte("Invalid JWT token format"))
			return
		}
		if token != nil {
			// Comma-ok assertion: any other Claims implementation (or nil
			// claims) would otherwise panic here. Such tokens attach no user.
			if claims, ok := token.Claims.(jwt.MapClaims); ok {
				ctx = context.WithValue(ctx, UserKey, claims["sub"])
			}
		}

		apiKey, err := ApiKeyFromContext(r.Context())
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			_, _ = w.Write([]byte("Invalid API Key format"))
			return
		}

		// Cache handles hash comparison internally
		if organization := m.cache.OrganizationByAPIKey(apiKey); organization != nil {
			ctx = context.WithValue(ctx, OrganizationKey, *organization)
		}

		next.ServeHTTP(w, r.WithContext(ctx))
	})
}

// UserFromContext returns the user stored under UserKey. It returns "" when
// nothing is stored there or the stored value is not a string.
func UserFromContext(ctx context.Context) string {
	if u, ok := ctx.Value(UserKey).(string); ok {
		return u
	}
	return ""
}

// OrganizationFromContext returns the ID (as a string) of the
// domain.Organization stored under OrganizationKey. It returns "" when no
// organization value of that type is stored.
func OrganizationFromContext(ctx context.Context) string {
	if org, ok := ctx.Value(OrganizationKey).(domain.Organization); ok {
		return org.ID.String()
	}
	return ""
}

// Directive is a gqlgen directive resolver.
//
// When user is non-nil and true, it fails unless UserFromContext yields a
// non-empty user. When organization is non-nil and true, it fails unless
// OrganizationFromContext yields a non-empty ID. Otherwise it delegates to
// next.
func (m *AuthMiddleware) Directive(ctx context.Context, _ any, next graphql.Resolver, user *bool, organization *bool) (res any, err error) {
	if user != nil && *user {
		if u := UserFromContext(ctx); u == "" {
			return nil, fmt.Errorf("no user available in request")
		}
	}
	if organization != nil && *organization {
		if orgID := OrganizationFromContext(ctx); orgID == "" {
			return nil, fmt.Errorf("no organization available in request")
		}
	}
	return next(ctx)
}