package middleware

import (
	"context"
	"fmt"
	"log"
	"net/http"

	"github.com/99designs/gqlgen/graphql"

	"gitea.unbound.se/unboundsoftware/schemas/domain"
)

// Context keys under which Handler stores the authenticated identity.
const (
	// UserKey holds the subject (string) of the request's token claims.
	UserKey = ContextKey("user")
	// OrganizationKey holds the domain.Organization (stored by value)
	// resolved from the request's API key.
	OrganizationKey = ContextKey("organization")
)

// Cache resolves API keys to organizations.
type Cache interface {
	// OrganizationByAPIKey returns the organization owning apiKey, or nil if
	// no organization matches. The raw key is passed in; the implementation
	// is responsible for any hash comparison.
	OrganizationByAPIKey(apiKey string) *domain.Organization
}

// NewAuth returns an AuthMiddleware that resolves API keys through cache.
func NewAuth(cache Cache) *AuthMiddleware {
	return &AuthMiddleware{
		cache: cache,
	}
}

// AuthMiddleware attaches the caller's identity (user and/or organization)
// to the request context and enforces it via the GraphQL auth directive.
type AuthMiddleware struct {
	cache Cache
}

// Handler wraps next and copies the authenticated identity into the request
// context:
//   - when token claims are present (ClaimsFromContext), their subject is
//     stored under UserKey;
//   - when a non-empty API key is present (ApiKeyFromContext) and the cache
//     maps it to an organization, that organization is stored under
//     OrganizationKey.
//
// If ApiKeyFromContext returns an error, the request is answered with status
// 500 and the body "Invalid API Key format", and next is not called.
// Presumably runs after the middleware that places the claims and API key in
// the context — TODO confirm ordering at the router.
func (m *AuthMiddleware) Handler(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()

		if claims := ClaimsFromContext(r.Context()); claims != nil {
			ctx = context.WithValue(ctx, UserKey, claims.RegisteredClaims.Subject)
		}

		apiKey, err := ApiKeyFromContext(r.Context())
		if err != nil {
			// NOTE(review): a malformed key reads like a client error (400/401);
			// 500 is kept until the error cases of ApiKeyFromContext are confirmed.
			w.WriteHeader(http.StatusInternalServerError)
			// The status line is already sent; a failed body write cannot be
			// reported to the client, so the error is deliberately dropped.
			_, _ = w.Write([]byte("Invalid API Key format"))
			return
		}

		// An empty key must never resolve to an organization. Skipping the
		// lookup also avoids the cache's internal hash comparison on requests
		// that carry no key.
		if apiKey != "" {
			if organization := m.cache.OrganizationByAPIKey(apiKey); organization != nil {
				ctx = context.WithValue(ctx, OrganizationKey, *organization)
			}
		}

		next.ServeHTTP(w, r.WithContext(ctx))
	})
}

// UserFromContext returns the user subject stored under UserKey, or "" when
// none is present.
func UserFromContext(ctx context.Context) string {
	if u, ok := ctx.Value(UserKey).(string); ok {
		return u
	}
	return ""
}

// UserHasRole reports whether the token claims in ctx list role among their
// custom Roles. It returns false when there are no claims or when the custom
// claims are not a non-nil *CustomClaims.
func UserHasRole(ctx context.Context, role string) bool {
	claims := ClaimsFromContext(ctx)
	if claims == nil {
		return false
	}
	customClaims, ok := claims.CustomClaims.(*CustomClaims)
	if !ok || customClaims == nil {
		return false
	}
	for _, r := range customClaims.Roles {
		if r == role {
			return true
		}
	}
	return false
}

// OrganizationFromContext returns the ID (via ID.String()) of the
// organization stored under OrganizationKey, or "" when none is present.
func OrganizationFromContext(ctx context.Context) string {
	if org, ok := ctx.Value(OrganizationKey).(domain.Organization); ok {
		return org.ID.String()
	}
	return ""
}

// Directive implements the GraphQL auth directive resolver for gqlgen.
// Its boolean arguments select what the caller must present; a nil argument
// counts as false:
//   - user and organization both true: either a user OR an organization is
//     sufficient;
//   - only user true: a user is required;
//   - only organization true: an organization is required;
//   - neither: the resolver runs unconditionally.
//
// When the requirement is not met, next is not called and an error is
// returned. Each decision is traced through the standard log package.
func (m *AuthMiddleware) Directive(ctx context.Context, _ interface{}, next graphql.Resolver, user *bool, organization *bool) (res interface{}, err error) {
	userRequired := user != nil && *user
	orgRequired := organization != nil && *organization

	hasUser := UserFromContext(ctx) != ""
	hasOrg := OrganizationFromContext(ctx) != ""

	log.Printf("[Auth Directive] userRequired=%v, orgRequired=%v, hasUser=%v, hasOrg=%v", userRequired, orgRequired, hasUser, hasOrg)

	switch {
	case userRequired && orgRequired:
		// Both flags set means EITHER credential is acceptable (OR logic).
		if !hasUser && !hasOrg {
			log.Printf("[Auth Directive] REJECTED: Neither user nor organization available")
			return nil, fmt.Errorf("authentication required: provide either user token or organization API key")
		}
		log.Printf("[Auth Directive] ACCEPTED: Has user=%v OR organization=%v", hasUser, hasOrg)
	case userRequired:
		if !hasUser {
			log.Printf("[Auth Directive] REJECTED: No user available")
			return nil, fmt.Errorf("no user available in request")
		}
		log.Printf("[Auth Directive] ACCEPTED: User authenticated")
	case orgRequired:
		if !hasOrg {
			log.Printf("[Auth Directive] REJECTED: No organization available")
			return nil, fmt.Errorf("no organization available in request")
		}
		log.Printf("[Auth Directive] ACCEPTED: Organization authenticated")
	}

	return next(ctx)
}