From bb0c08be06478e486d25a2849db71563cc044fdf Mon Sep 17 00:00:00 2001 From: Joakim Olsson Date: Thu, 20 Nov 2025 08:09:00 +0100 Subject: [PATCH 1/5] fix: enhance API key handling and logging in middleware Refactor API key processing to improve clarity and reduce code duplication. Introduce detailed logging for schema updates and initializations, capturing relevant context information. Use background context for async operations to avoid blocking. Implement organization lookup logic in the WebSocket init function for consistent API key handling across connections. --- cmd/service/service.go | 19 ++++++++++++++ graph/schema.resolvers.go | 52 +++++++++++++++++++++++++++++++++++---- middleware/auth.go | 4 ++- 3 files changed, 69 insertions(+), 6 deletions(-) diff --git a/cmd/service/service.go b/cmd/service/service.go index dbefcf2..147d612 100644 --- a/cmd/service/service.go +++ b/cmd/service/service.go @@ -30,6 +30,7 @@ import ( "gitlab.com/unboundsoftware/schemas/domain" "gitlab.com/unboundsoftware/schemas/graph" "gitlab.com/unboundsoftware/schemas/graph/generated" + "gitlab.com/unboundsoftware/schemas/hash" "gitlab.com/unboundsoftware/schemas/logging" "gitlab.com/unboundsoftware/schemas/middleware" "gitlab.com/unboundsoftware/schemas/monitoring" @@ -210,6 +211,24 @@ func start(closeEvents chan error, logger *slog.Logger, connectToAmqpFunc func(u srv.AddTransport(transport.Websocket{ KeepAlivePingInterval: 10 * time.Second, + InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { + // Extract API key from WebSocket connection_init payload + if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" { + logger.Info("WebSocket connection with API key", "has_key", true) + ctx = context.WithValue(ctx, middleware.ApiKey, apiKey) + + // Look up organization by API key (same logic as auth middleware) + if organization := serviceCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { + logger.Info("WebSocket: Organization found for API key", "org_id", organization.ID.String()) + ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization) + } else { + logger.Warn("WebSocket: No organization found for API key") + } + } else { + logger.Info("WebSocket connection without API key") + } + return ctx, &initPayload, nil + }, }) srv.AddTransport(transport.Options{}) srv.AddTransport(transport.GET{}) diff --git a/graph/schema.resolvers.go b/graph/schema.resolvers.go index 6c68229..00f6202 100644 --- a/graph/schema.resolvers.go +++ b/graph/schema.resolvers.go @@ -123,6 +123,13 @@ func (r *mutationResolver) UpdateSubGraph(ctx context.Context, input model.Input // Publish schema update to subscribers go func() { services, lastUpdate := r.Cache.Services(orgId, input.Ref, "") + r.Logger.Info("Publishing schema update after subgraph change", + "ref", input.Ref, + "orgId", orgId, + "lastUpdate", lastUpdate, + "servicesCount", len(services), + ) + subGraphs := make([]*model.SubGraph, len(services)) for i, id := range services { sg, err := r.fetchSubGraph(context.Background(), id) @@ -149,12 +156,21 @@ func (r *mutationResolver) UpdateSubGraph(ctx context.Context, input model.Input } // Publish to all subscribers of this ref - r.PubSub.Publish(input.Ref, &model.SchemaUpdate{ + update := &model.SchemaUpdate{ Ref: input.Ref, ID: lastUpdate, SubGraphs: subGraphs, CosmoRouterConfig: &cosmoConfig, - }) + } + + r.Logger.Info("Publishing schema update to subscribers", + "ref", update.Ref, + "id", update.ID, + "subGraphsCount", len(update.SubGraphs), + "cosmoConfigLength", len(cosmoConfig), + ) + + r.PubSub.Publish(input.Ref, update) }() return r.toGqlSubGraph(subGraph), nil @@ -225,8 +241,15 @@ func (r *queryResolver) Supergraph(ctx context.Context, ref string, isAfter *str // SchemaUpdates is the resolver for the schemaUpdates field. func (r *subscriptionResolver) SchemaUpdates(ctx context.Context, ref string) (<-chan *model.SchemaUpdate, error) { orgId := middleware.OrganizationFromContext(ctx) + + r.Logger.Info("SchemaUpdates subscription started", + "ref", ref, + "orgId", orgId, + ) + _, err := r.apiKeyCanAccessRef(ctx, ref, false) if err != nil { + r.Logger.Error("API key cannot access ref", "error", err, "ref", ref) return nil, err } @@ -235,12 +258,22 @@ func (r *subscriptionResolver) SchemaUpdates(ctx context.Context, ref string) (< // Send initial state immediately go func() { + // Use background context for async operation + bgCtx := context.Background() + services, lastUpdate := r.Cache.Services(orgId, ref, "") + r.Logger.Info("Preparing initial schema update", + "ref", ref, + "orgId", orgId, + "lastUpdate", lastUpdate, + "servicesCount", len(services), + ) + subGraphs := make([]*model.SubGraph, len(services)) for i, id := range services { - sg, err := r.fetchSubGraph(ctx, id) + sg, err := r.fetchSubGraph(bgCtx, id) if err != nil { - r.Logger.Error("fetch subgraph for initial update", "error", err) + r.Logger.Error("fetch subgraph for initial update", "error", err, "id", id) continue } subGraphs[i] = &model.SubGraph{ @@ -262,12 +295,21 @@ func (r *subscriptionResolver) SchemaUpdates(ctx context.Context, ref string) (< } // Send initial update - ch <- &model.SchemaUpdate{ + update := &model.SchemaUpdate{ Ref: ref, ID: lastUpdate, SubGraphs: subGraphs, CosmoRouterConfig: &cosmoConfig, } + + r.Logger.Info("Sending initial schema update", + "ref", update.Ref, + "id", update.ID, + "subGraphsCount", len(update.SubGraphs), + "cosmoConfigLength", len(cosmoConfig), + ) + + ch <- update }() // Clean up subscription when context is done diff --git a/middleware/auth.go b/middleware/auth.go index 6704946..bbe3a9e 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -49,7 +49,9 @@ func (m *AuthMiddleware) Handler(next http.Handler) http.Handler { _, _ = w.Write([]byte("Invalid API Key format")) return } - if organization := m.cache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { + hashedKey := hash.String(apiKey) + organization := m.cache.OrganizationByAPIKey(hashedKey) + if organization != nil { ctx = context.WithValue(ctx, OrganizationKey, *organization) } From 4d18cf4175b60b83a9a873afab3a507c7ab9b016 Mon Sep 17 00:00:00 2001 From: Joakim Olsson Date: Thu, 20 Nov 2025 14:24:39 +0100 Subject: [PATCH 2/5] feat(tests): add unit tests for WebSocket initialization logic Adds unit tests for the WebSocket initialization function to validate behavior with valid, invalid, and absent API keys. Introduces a mock cache implementation to simulate organization retrieval based on hashed API keys. Ensures proper context value setting upon initialization, enhancing test coverage and reliability for API key handling in WebSocket connections. --- cmd/service/service_test.go | 336 ++++++++++++++++++++++++++ middleware/auth_test.go | 467 ++++++++++++++++++++++++++++++++++++ 2 files changed, 803 insertions(+) create mode 100644 cmd/service/service_test.go create mode 100644 middleware/auth_test.go diff --git a/cmd/service/service_test.go b/cmd/service/service_test.go new file mode 100644 index 0000000..c9680d9 --- /dev/null +++ b/cmd/service/service_test.go @@ -0,0 +1,336 @@ +package main + +import ( + "context" + "testing" + + "github.com/99designs/gqlgen/graphql/handler/transport" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gitlab.com/unboundsoftware/eventsourced/eventsourced" + + "gitlab.com/unboundsoftware/schemas/domain" + "gitlab.com/unboundsoftware/schemas/hash" + "gitlab.com/unboundsoftware/schemas/middleware" +) + +// MockCache is a mock implementation for testing +type MockCache struct { + organizations map[string]*domain.Organization +} + +func (m *MockCache) OrganizationByAPIKey(apiKey string) *domain.Organization { + return m.organizations[apiKey] +} + +func TestWebSocketInitFunc_WithValidAPIKey(t *testing.T) { + // Setup + orgID := uuid.New() + org := &domain.Organization{ + BaseAggregate: eventsourced.BaseAggregate{ + ID: eventsourced.IdFromString(orgID.String()), + }, + Name: "Test Organization", + } + + apiKey := "test-api-key-123" + hashedKey := hash.String(apiKey) + + mockCache := &MockCache{ + organizations: map[string]*domain.Organization{ + hashedKey: org, + }, + } + + // Create InitFunc (simulating the WebSocket InitFunc logic) + initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { + // Extract API key from WebSocket connection_init payload + if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" { + ctx = context.WithValue(ctx, middleware.ApiKey, apiKey) + + // Look up organization by API key + if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { + ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization) + } + } + return ctx, &initPayload, nil + } + + // Test + ctx := context.Background() + initPayload := transport.InitPayload{ + "X-Api-Key": apiKey, + } + + resultCtx, resultPayload, err := initFunc(ctx, initPayload) + + // Assert + require.NoError(t, err) + require.NotNil(t, resultPayload) + + // Check API key is in context + if value := resultCtx.Value(middleware.ApiKey); value != nil { + assert.Equal(t, apiKey, value.(string)) + } else { + t.Fatal("API key not found in context") + } + + // Check organization is in context + if value := resultCtx.Value(middleware.OrganizationKey); value != nil { + capturedOrg, ok := value.(domain.Organization) + require.True(t, ok, "Organization should be of correct type") + assert.Equal(t, org.Name, capturedOrg.Name) + assert.Equal(t, org.ID.String(), capturedOrg.ID.String()) + } else { + t.Fatal("Organization not found in context") + } +} + +func TestWebSocketInitFunc_WithInvalidAPIKey(t *testing.T) { + // Setup + mockCache := &MockCache{ + organizations: map[string]*domain.Organization{}, + } + + apiKey := "invalid-api-key" + + // Create InitFunc + initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { + // Extract API key from WebSocket connection_init payload + if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" { + ctx = context.WithValue(ctx, middleware.ApiKey, apiKey) + + // Look up organization by API key + if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { + ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization) + } + } + return ctx, &initPayload, nil + } + + // Test + ctx := context.Background() + initPayload := transport.InitPayload{ + "X-Api-Key": apiKey, + } + + resultCtx, resultPayload, err := initFunc(ctx, initPayload) + + // Assert + require.NoError(t, err) + require.NotNil(t, resultPayload) + + // Check API key is in context + if value := resultCtx.Value(middleware.ApiKey); value != nil { + assert.Equal(t, apiKey, value.(string)) + } else { + t.Fatal("API key not found in context") + } + + // Check organization is NOT in context (since API key is invalid) + value := resultCtx.Value(middleware.OrganizationKey) + assert.Nil(t, value, "Organization should not be set for invalid API key") +} + +func TestWebSocketInitFunc_WithoutAPIKey(t *testing.T) { + // Setup + mockCache := &MockCache{ + organizations: map[string]*domain.Organization{}, + } + + // Create InitFunc + initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { + // Extract API key from WebSocket connection_init payload + if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" { + ctx = context.WithValue(ctx, middleware.ApiKey, apiKey) + + // Look up organization by API key + if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { + ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization) + } + } + return ctx, &initPayload, nil + } + + // Test + ctx := context.Background() + initPayload := transport.InitPayload{} + + resultCtx, resultPayload, err := initFunc(ctx, initPayload) + + // Assert + require.NoError(t, err) + require.NotNil(t, resultPayload) + + // Check API key is NOT in context + value := resultCtx.Value(middleware.ApiKey) + assert.Nil(t, value, "API key should not be set when not provided") + + // Check organization is NOT in context + value = resultCtx.Value(middleware.OrganizationKey) + assert.Nil(t, value, "Organization should not be set when API key is not provided") +} + +func TestWebSocketInitFunc_WithEmptyAPIKey(t *testing.T) { + // Setup + mockCache := &MockCache{ + organizations: map[string]*domain.Organization{}, + } + + // Create InitFunc + initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { + // Extract API key from WebSocket connection_init payload + if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" { + ctx = context.WithValue(ctx, middleware.ApiKey, apiKey) + + // Look up organization by API key + if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { + ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization) + } + } + return ctx, &initPayload, nil + } + + // Test + ctx := context.Background() + initPayload := transport.InitPayload{ + "X-Api-Key": "", // Empty string + } + + resultCtx, resultPayload, err := initFunc(ctx, initPayload) + + // Assert + require.NoError(t, err) + require.NotNil(t, resultPayload) + + // Check API key is NOT in context (because empty string fails the condition) + value := resultCtx.Value(middleware.ApiKey) + assert.Nil(t, value, "API key should not be set when empty") + + // Check organization is NOT in context + value = resultCtx.Value(middleware.OrganizationKey) + assert.Nil(t, value, "Organization should not be set when API key is empty") +} + +func TestWebSocketInitFunc_WithWrongTypeAPIKey(t *testing.T) { + // Setup + mockCache := &MockCache{ + organizations: map[string]*domain.Organization{}, + } + + // Create InitFunc + initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { + // Extract API key from WebSocket connection_init payload + if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" { + ctx = context.WithValue(ctx, middleware.ApiKey, apiKey) + + // Look up organization by API key + if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { + ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization) + } + } + return ctx, &initPayload, nil + } + + // Test + ctx := context.Background() + initPayload := transport.InitPayload{ + "X-Api-Key": 12345, // Wrong type (int instead of string) + } + + resultCtx, resultPayload, err := initFunc(ctx, initPayload) + + // Assert + require.NoError(t, err) + require.NotNil(t, resultPayload) + + // Check API key is NOT in context (type assertion fails) + value := resultCtx.Value(middleware.ApiKey) + assert.Nil(t, value, "API key should not be set when wrong type") + + // Check organization is NOT in context + value = resultCtx.Value(middleware.OrganizationKey) + assert.Nil(t, value, "Organization should not be set when API key has wrong type") +} + +func TestWebSocketInitFunc_WithMultipleOrganizations(t *testing.T) { + // Setup - create multiple organizations + org1ID := uuid.New() + org1 := &domain.Organization{ + BaseAggregate: eventsourced.BaseAggregate{ + ID: eventsourced.IdFromString(org1ID.String()), + }, + Name: "Organization 1", + } + + org2ID := uuid.New() + org2 := &domain.Organization{ + BaseAggregate: eventsourced.BaseAggregate{ + ID: eventsourced.IdFromString(org2ID.String()), + }, + Name: "Organization 2", + } + + apiKey1 := "api-key-org-1" + apiKey2 := "api-key-org-2" + hashedKey1 := hash.String(apiKey1) + hashedKey2 := hash.String(apiKey2) + + mockCache := &MockCache{ + organizations: map[string]*domain.Organization{ + hashedKey1: org1, + hashedKey2: org2, + }, + } + + // Create InitFunc + initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { + // Extract API key from WebSocket connection_init payload + if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" { + ctx = context.WithValue(ctx, middleware.ApiKey, apiKey) + + // Look up organization by API key + if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { + ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization) + } + } + return ctx, &initPayload, nil + } + + // Test with first API key + ctx1 := context.Background() + initPayload1 := transport.InitPayload{ + "X-Api-Key": apiKey1, + } + + resultCtx1, _, err := initFunc(ctx1, initPayload1) + require.NoError(t, err) + + if value := resultCtx1.Value(middleware.OrganizationKey); value != nil { + capturedOrg, ok := value.(domain.Organization) + require.True(t, ok) + assert.Equal(t, org1.Name, capturedOrg.Name) + assert.Equal(t, org1.ID.String(), capturedOrg.ID.String()) + } else { + t.Fatal("Organization 1 not found in context") + } + + // Test with second API key + ctx2 := context.Background() + initPayload2 := transport.InitPayload{ + "X-Api-Key": apiKey2, + } + + resultCtx2, _, err := initFunc(ctx2, initPayload2) + require.NoError(t, err) + + if value := resultCtx2.Value(middleware.OrganizationKey); value != nil { + capturedOrg, ok := value.(domain.Organization) + require.True(t, ok) + assert.Equal(t, org2.Name, capturedOrg.Name) + assert.Equal(t, org2.ID.String(), capturedOrg.ID.String()) + } else { + t.Fatal("Organization 2 not found in context") + } +} diff --git a/middleware/auth_test.go b/middleware/auth_test.go new file mode 100644 index 0000000..3e456c8 --- /dev/null +++ b/middleware/auth_test.go @@ -0,0 +1,467 @@ +package middleware + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + mw "github.com/auth0/go-jwt-middleware/v2" + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "gitlab.com/unboundsoftware/eventsourced/eventsourced" + + "gitlab.com/unboundsoftware/schemas/domain" + "gitlab.com/unboundsoftware/schemas/hash" +) + +// MockCache is a mock implementation of the Cache interface +type MockCache struct { + mock.Mock +} + +func (m *MockCache) OrganizationByAPIKey(apiKey string) *domain.Organization { + args := m.Called(apiKey) + if args.Get(0) == nil { + return nil + } + return args.Get(0).(*domain.Organization) +} + +func TestAuthMiddleware_Handler_WithValidAPIKey(t *testing.T) { + // Setup + mockCache := new(MockCache) + authMiddleware := NewAuth(mockCache) + + orgID := uuid.New() + expectedOrg := &domain.Organization{ + BaseAggregate: eventsourced.BaseAggregate{ + ID: eventsourced.IdFromString(orgID.String()), + }, + Name: "Test Organization", + } + + apiKey := "test-api-key-123" + hashedKey := hash.String(apiKey) + + mockCache.On("OrganizationByAPIKey", hashedKey).Return(expectedOrg) + + // Create a test handler that checks the context + var capturedOrg *domain.Organization + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if org := r.Context().Value(OrganizationKey); org != nil { + if o, ok := org.(domain.Organization); ok { + capturedOrg = &o + } + } + w.WriteHeader(http.StatusOK) + }) + + // Create request with API key in context + req := httptest.NewRequest(http.MethodGet, "/test", nil) + ctx := context.WithValue(req.Context(), ApiKey, apiKey) + req = req.WithContext(ctx) + + rec := httptest.NewRecorder() + + // Execute + authMiddleware.Handler(testHandler).ServeHTTP(rec, req) + + // Assert + assert.Equal(t, http.StatusOK, rec.Code) + require.NotNil(t, capturedOrg) + assert.Equal(t, expectedOrg.Name, capturedOrg.Name) + assert.Equal(t, expectedOrg.ID.String(), capturedOrg.ID.String()) + mockCache.AssertExpectations(t) +} + +func TestAuthMiddleware_Handler_WithInvalidAPIKey(t *testing.T) { + // Setup + mockCache := new(MockCache) + authMiddleware := NewAuth(mockCache) + + apiKey := "invalid-api-key" + hashedKey := hash.String(apiKey) + + mockCache.On("OrganizationByAPIKey", hashedKey).Return(nil) + + // Create a test handler that checks the context + var capturedOrg *domain.Organization + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if org := r.Context().Value(OrganizationKey); org != nil { + if o, ok := org.(domain.Organization); ok { + capturedOrg = &o + } + } + w.WriteHeader(http.StatusOK) + }) + + // Create request with API key in context + req := httptest.NewRequest(http.MethodGet, "/test", nil) + ctx := context.WithValue(req.Context(), ApiKey, apiKey) + req = req.WithContext(ctx) + + rec := httptest.NewRecorder() + + // Execute + authMiddleware.Handler(testHandler).ServeHTTP(rec, req) + + // Assert + assert.Equal(t, http.StatusOK, rec.Code) + assert.Nil(t, capturedOrg, "Organization should not be set for invalid API key") + mockCache.AssertExpectations(t) +} + +func TestAuthMiddleware_Handler_WithoutAPIKey(t *testing.T) { + // Setup + mockCache := new(MockCache) + authMiddleware := NewAuth(mockCache) + + // The middleware always hashes the API key (even if empty) and calls the cache + emptyKeyHash := hash.String("") + mockCache.On("OrganizationByAPIKey", emptyKeyHash).Return(nil) + + // Create a test handler that checks the context + var capturedOrg *domain.Organization + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if org := r.Context().Value(OrganizationKey); org != nil { + if o, ok := org.(domain.Organization); ok { + capturedOrg = &o + } + } + w.WriteHeader(http.StatusOK) + }) + + // Create request without API key + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + + // Execute + authMiddleware.Handler(testHandler).ServeHTTP(rec, req) + + // Assert + assert.Equal(t, http.StatusOK, rec.Code) + assert.Nil(t, capturedOrg, "Organization should not be set without API key") + mockCache.AssertExpectations(t) +} + +func TestAuthMiddleware_Handler_WithValidJWT(t *testing.T) { + // Setup + mockCache := new(MockCache) + authMiddleware := NewAuth(mockCache) + + // The middleware always hashes the API key (even if empty) and calls the cache + emptyKeyHash := hash.String("") + mockCache.On("OrganizationByAPIKey", emptyKeyHash).Return(nil) + + userID := "user-123" + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "sub": userID, + }) + + // Create a test handler that checks the context + var capturedUser string + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if user := r.Context().Value(UserKey); user != nil { + if u, ok := user.(string); ok { + capturedUser = u + } + } + w.WriteHeader(http.StatusOK) + }) + + // Create request with JWT token in context + req := httptest.NewRequest(http.MethodGet, "/test", nil) + ctx := context.WithValue(req.Context(), mw.ContextKey{}, token) + req = req.WithContext(ctx) + + rec := httptest.NewRecorder() + + // Execute + authMiddleware.Handler(testHandler).ServeHTTP(rec, req) + + // Assert + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, userID, capturedUser) +} + +func TestAuthMiddleware_Handler_APIKeyErrorHandling(t *testing.T) { + // Setup + mockCache := new(MockCache) + authMiddleware := NewAuth(mockCache) + + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + // Create request with invalid API key type in context + req := httptest.NewRequest(http.MethodGet, "/test", nil) + ctx := context.WithValue(req.Context(), ApiKey, 12345) // Invalid type + req = req.WithContext(ctx) + + rec := httptest.NewRecorder() + + // Execute + authMiddleware.Handler(testHandler).ServeHTTP(rec, req) + + // Assert + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Contains(t, rec.Body.String(), "Invalid API Key format") +} + +func TestAuthMiddleware_Handler_JWTErrorHandling(t *testing.T) { + // Setup + mockCache := new(MockCache) + authMiddleware := NewAuth(mockCache) + + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + // Create request with invalid JWT token type in context + req := httptest.NewRequest(http.MethodGet, "/test", nil) + ctx := context.WithValue(req.Context(), mw.ContextKey{}, "not-a-token") // Invalid type + req = req.WithContext(ctx) + + rec := httptest.NewRecorder() + + // Execute + authMiddleware.Handler(testHandler).ServeHTTP(rec, req) + + // Assert + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Contains(t, rec.Body.String(), "Invalid JWT token format") +} + +func TestAuthMiddleware_Handler_BothJWTAndAPIKey(t *testing.T) { + // Setup + mockCache := new(MockCache) + authMiddleware := NewAuth(mockCache) + + orgID := uuid.New() + expectedOrg := &domain.Organization{ + BaseAggregate: eventsourced.BaseAggregate{ + ID: eventsourced.IdFromString(orgID.String()), + }, + Name: "Test Organization", + } + + userID := "user-123" + apiKey := "test-api-key-123" + hashedKey := hash.String(apiKey) + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "sub": userID, + }) + + mockCache.On("OrganizationByAPIKey", hashedKey).Return(expectedOrg) + + // Create a test handler that checks both user and organization in context + var capturedUser string + var capturedOrg *domain.Organization + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if user := r.Context().Value(UserKey); user != nil { + if u, ok := user.(string); ok { + capturedUser = u + } + } + if org := r.Context().Value(OrganizationKey); org != nil { + if o, ok := org.(domain.Organization); ok { + capturedOrg = &o + } + } + w.WriteHeader(http.StatusOK) + }) + + // Create request with both JWT and API key in context + req := httptest.NewRequest(http.MethodGet, "/test", nil) + ctx := context.WithValue(req.Context(), mw.ContextKey{}, token) + ctx = context.WithValue(ctx, ApiKey, apiKey) + req = req.WithContext(ctx) + + rec := httptest.NewRecorder() + + // Execute + authMiddleware.Handler(testHandler).ServeHTTP(rec, req) + + // Assert + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, userID, capturedUser) + require.NotNil(t, capturedOrg) + assert.Equal(t, expectedOrg.Name, capturedOrg.Name) + mockCache.AssertExpectations(t) +} + +func TestUserFromContext(t *testing.T) { + tests := []struct { + name string + ctx context.Context + expected string + }{ + { + name: "with valid user", + ctx: context.WithValue(context.Background(), UserKey, "user-123"), + expected: "user-123", + }, + { + name: "without user", + ctx: context.Background(), + expected: "", + }, + { + name: "with invalid type", + ctx: context.WithValue(context.Background(), UserKey, 123), + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := UserFromContext(tt.ctx) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestOrganizationFromContext(t *testing.T) { + orgID := uuid.New() + org := domain.Organization{ + BaseAggregate: eventsourced.BaseAggregate{ + ID: eventsourced.IdFromString(orgID.String()), + }, + Name: "Test Org", + } + + tests := []struct { + name string + ctx context.Context + expected string + }{ + { + name: "with valid organization", + ctx: context.WithValue(context.Background(), OrganizationKey, org), + expected: orgID.String(), + }, + { + name: "without organization", + ctx: context.Background(), + expected: "", + }, + { + name: "with invalid type", + ctx: context.WithValue(context.Background(), OrganizationKey, "not-an-org"), + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := OrganizationFromContext(tt.ctx) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestAuthMiddleware_Directive_RequiresUser(t *testing.T) { + mockCache := new(MockCache) + authMiddleware := NewAuth(mockCache) + + requireUser := true + + // Test with user present + ctx := context.WithValue(context.Background(), UserKey, "user-123") + _, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) { + return "success", nil + }, &requireUser, nil) + assert.NoError(t, err) + + // Test without user + ctx = context.Background() + _, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) { + return "success", nil + }, &requireUser, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no user available in request") +} + +func TestAuthMiddleware_Directive_RequiresOrganization(t *testing.T) { + mockCache := new(MockCache) + authMiddleware := NewAuth(mockCache) + + requireOrg := true + orgID := uuid.New() + org := domain.Organization{ + BaseAggregate: eventsourced.BaseAggregate{ + ID: eventsourced.IdFromString(orgID.String()), + }, + Name: "Test Org", + } + + // Test with organization present + ctx := context.WithValue(context.Background(), OrganizationKey, org) + _, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) { + return "success", nil + }, nil, &requireOrg) + assert.NoError(t, err) + + // Test without organization + ctx = context.Background() + _, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) { + return "success", nil + }, nil, &requireOrg) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no organization available in request") +} + +func TestAuthMiddleware_Directive_RequiresBoth(t *testing.T) { + mockCache := new(MockCache) + authMiddleware := NewAuth(mockCache) + + requireUser := true + requireOrg := true + orgID := uuid.New() + org := domain.Organization{ + BaseAggregate: eventsourced.BaseAggregate{ + ID: eventsourced.IdFromString(orgID.String()), + }, + Name: "Test Org", + } + + // Test with both present + ctx := context.WithValue(context.Background(), UserKey, "user-123") + ctx = context.WithValue(ctx, OrganizationKey, org) + _, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) { + return "success", nil + }, &requireUser, &requireOrg) + assert.NoError(t, err) + + // Test with only user + ctx = context.WithValue(context.Background(), UserKey, "user-123") + _, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) { + return "success", nil + }, &requireUser, &requireOrg) + assert.Error(t, err) + + // Test with only organization + ctx = context.WithValue(context.Background(), OrganizationKey, org) + _, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) { + return "success", nil + }, &requireUser, &requireOrg) + assert.Error(t, err) +} + +func TestAuthMiddleware_Directive_NoRequirements(t *testing.T) { + mockCache := new(MockCache) + authMiddleware := NewAuth(mockCache) + + // Test with no requirements + ctx := context.Background() + result, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) { + return "success", nil + }, nil, nil) + assert.NoError(t, err) + assert.Equal(t, "success", result) +} From 9368d77bc81bfcb9b8fe33965b7652e3ae970131 Mon Sep 17 00:00:00 2001 From: Joakim Olsson Date: Thu, 20 Nov 2025 17:02:19 +0100 Subject: [PATCH 3/5] feat: add latestSchema query for retrieving schema updates Implements the `latestSchema` query to fetch the latest schema updates for an organization. This change enhances the GraphQL API by allowing clients to retrieve the most recent schema version and its associated subgraphs. The resolver performs necessary access checks, logs relevant information, and generates the Cosmo router configuration from fetched subgraph SDLs, returning structured schema update details. --- Dockerfile | 10 ++- go.mod | 5 +- go.sum | 2 + graph/cosmo.go | 94 +++++++++++++++++++++------- graph/cosmo_test.go | 49 --------------- graph/generated/generated.go | 116 +++++++++++++++++++++++++++++++++++ graph/schema.graphqls | 1 + graph/schema.resolvers.go | 66 ++++++++++++++++++++ 8 files changed, 270 insertions(+), 73 deletions(-) diff --git a/Dockerfile b/Dockerfile index 3512d54..67e36fd 100644 --- a/Dockerfile +++ b/Dockerfile @@ -24,9 +24,17 @@ RUN GOOS=linux GOARCH=amd64 go build \ FROM scratch as export COPY --from=build /build/coverage.txt / -FROM scratch +FROM node:22-alpine ENV TZ Europe/Stockholm + +# Install wgc CLI globally for Cosmo Router composition +RUN npm install -g wgc@latest + +# Copy timezone data and certificates COPY --from=build /usr/share/zoneinfo /usr/share/zoneinfo COPY --from=build /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ + +# Copy the service binary COPY --from=build /release/service / + CMD ["/service"] diff --git a/go.mod b/go.mod index 127627f..6a4d2f0 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/apex/log v1.9.0 github.com/auth0/go-jwt-middleware/v2 v2.3.0 github.com/golang-jwt/jwt/v5 v5.3.0 + github.com/google/uuid v1.6.0 github.com/jmoiron/sqlx v1.4.0 github.com/pkg/errors v0.9.1 github.com/pressly/goose/v3 v3.26.0 @@ -30,6 +31,7 @@ require ( go.opentelemetry.io/otel/sdk/log v0.14.0 go.opentelemetry.io/otel/sdk/metric v1.38.0 go.opentelemetry.io/otel/trace v1.38.0 + gopkg.in/yaml.v3 v3.0.1 ) require ( @@ -41,7 +43,6 @@ require ( github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/gorilla/websocket v1.5.1 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect @@ -51,6 +52,7 @@ require ( github.com/rabbitmq/amqp091-go v1.10.0 // indirect github.com/sethvargo/go-retry v0.3.0 // indirect github.com/sosodev/duration v1.3.1 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/tidwall/gjson v1.17.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect @@ -72,5 +74,4 @@ require ( google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect google.golang.org/grpc v1.75.0 // indirect google.golang.org/protobuf v1.36.10 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index fd5c8e7..3abb5fe 100644 --- a/go.sum +++ b/go.sum @@ -141,6 +141,8 @@ github.com/sosodev/duration v1.3.1/go.mod h1:RQIBBX0+fMLc/D9+Jb/fwvVmo0eZvDDEERA github.com/sparetimecoders/goamqp v0.3.3 h1:z/nfTPmrjeU/rIVuNOgsVLCimp3WFoNFvS3ZzXRJ6HE= github.com/sparetimecoders/goamqp v0.3.3/go.mod h1:W9NRCpWLE+Vruv2dcRSbszNil2O826d2Nv6kAkETW5o= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= diff --git a/graph/cosmo.go b/graph/cosmo.go index 150e162..69deb88 100644 --- a/graph/cosmo.go +++ b/graph/cosmo.go @@ -1,54 +1,106 @@ package graph import ( - "encoding/json" "fmt" + "os" + "os/exec" + "path/filepath" + + "gopkg.in/yaml.v3" "gitlab.com/unboundsoftware/schemas/graph/model" ) // GenerateCosmoRouterConfig generates a Cosmo Router execution config from subgraphs +// using the official wgc CLI tool via npx func GenerateCosmoRouterConfig(subGraphs []*model.SubGraph) (string, error) { - // Build the Cosmo router config structure - // This is a simplified version - you may need to adjust based on actual Cosmo requirements - config := map[string]interface{}{ - "version": "1", - "subgraphs": convertSubGraphsToCosmo(subGraphs), - // Add other Cosmo-specific configuration as needed + if len(subGraphs) == 0 { + return "", fmt.Errorf("no subgraphs provided") } - // Marshal to JSON - configJSON, err := json.MarshalIndent(config, "", " ") + // Create a temporary directory for composition + tmpDir, err := os.MkdirTemp("", "cosmo-compose-*") if err != nil { - return "", fmt.Errorf("marshal cosmo config: %w", err) + return "", fmt.Errorf("create temp dir: %w", err) + } + defer os.RemoveAll(tmpDir) + + // Write each subgraph SDL to a file + type SubgraphConfig struct { + Name string `yaml:"name"` + RoutingURL string `yaml:"routing_url,omitempty"` + Schema map[string]string `yaml:"schema"` + Subscription map[string]interface{} `yaml:"subscription,omitempty"` } - return string(configJSON), nil -} + type InputConfig struct { + Version int `yaml:"version"` + Subgraphs []SubgraphConfig `yaml:"subgraphs"` + } -func convertSubGraphsToCosmo(subGraphs []*model.SubGraph) []map[string]interface{} { - cosmoSubgraphs := make([]map[string]interface{}, 0, len(subGraphs)) + inputConfig := InputConfig{ + Version: 1, + Subgraphs: make([]SubgraphConfig, 0, len(subGraphs)), + } for _, sg := range subGraphs { - cosmoSg := map[string]interface{}{ - "name": sg.Service, - "sdl": sg.Sdl, + // Write SDL to a temp file + schemaFile := filepath.Join(tmpDir, fmt.Sprintf("%s.graphql", sg.Service)) + if err := os.WriteFile(schemaFile, []byte(sg.Sdl), 0o644); err != nil { + return "", fmt.Errorf("write schema file for %s: %w", sg.Service, err) + } + + subgraphCfg := SubgraphConfig{ + Name: sg.Service, + Schema: map[string]string{ + "file": schemaFile, + }, } if sg.URL != nil { - cosmoSg["routing_url"] = *sg.URL + subgraphCfg.RoutingURL = *sg.URL } if sg.WsURL != nil { - cosmoSg["subscription"] = map[string]interface{}{ + subgraphCfg.Subscription = map[string]interface{}{ "url": *sg.WsURL, "protocol": "ws", "websocket_subprotocol": "graphql-ws", } } - cosmoSubgraphs = append(cosmoSubgraphs, cosmoSg) + inputConfig.Subgraphs = append(inputConfig.Subgraphs, subgraphCfg) } - return cosmoSubgraphs + // Write input config YAML + inputFile := filepath.Join(tmpDir, "input.yaml") + inputYAML, err := yaml.Marshal(inputConfig) + if err != nil { + return "", fmt.Errorf("marshal input config: %w", err) + } + if err := os.WriteFile(inputFile, inputYAML, 0o644); err != nil { + return "", fmt.Errorf("write input config: %w", err) + } + + // Execute wgc router compose + // wgc is installed globally in the Docker image + outputFile := filepath.Join(tmpDir, "config.json") + cmd := exec.Command("wgc", "router", "compose", + "--input", inputFile, + "--out", outputFile, + "--suppress-warnings", + ) + + output, err := cmd.CombinedOutput() + if err != nil { + return "", fmt.Errorf("wgc router compose failed: %w\nOutput: %s", err, string(output)) + } + + // Read the generated config + configJSON, err := os.ReadFile(outputFile) + if err != nil { + return "", fmt.Errorf("read output config: %w", err) + } + + return string(configJSON), nil } diff --git a/graph/cosmo_test.go b/graph/cosmo_test.go index 0f6de36..60cd3a0 100644 --- a/graph/cosmo_test.go +++ b/graph/cosmo_test.go @@ -203,55 +203,6 @@ func TestGenerateCosmoRouterConfig(t *testing.T) { } } -func TestConvertSubGraphsToCosmo(t *testing.T) { - tests := []struct { - name string - subGraphs []*model.SubGraph - wantLen int - validate func(t *testing.T, result []map[string]interface{}) - }{ - { - name: "preserves subgraph order", - subGraphs: []*model.SubGraph{ - {Service: "alpha", URL: stringPtr("http://a"), Sdl: "a"}, - {Service: "beta", URL: stringPtr("http://b"), Sdl: "b"}, - {Service: "gamma", URL: stringPtr("http://c"), Sdl: "c"}, - }, - wantLen: 3, - validate: func(t *testing.T, result []map[string]interface{}) { - assert.Equal(t, "alpha", result[0]["name"]) - assert.Equal(t, "beta", result[1]["name"]) - assert.Equal(t, "gamma", result[2]["name"]) - }, - }, - { - name: "includes SDL exactly as provided", - subGraphs: []*model.SubGraph{ - { - Service: "test", - URL: stringPtr("http://test"), - Sdl: "type Query { special: String! }", - }, - }, - wantLen: 1, - validate: func(t *testing.T, result []map[string]interface{}) { - assert.Equal(t, "type Query { special: String! }", result[0]["sdl"]) - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := convertSubGraphsToCosmo(tt.subGraphs) - assert.Len(t, result, tt.wantLen) - - if tt.validate != nil { - tt.validate(t, result) - } - }) - } -} - // Helper function for tests func stringPtr(s string) *string { return &s diff --git a/graph/generated/generated.go b/graph/generated/generated.go index b4e5fec..7ba0598 100644 --- a/graph/generated/generated.go +++ b/graph/generated/generated.go @@ -74,6 +74,7 @@ type ComplexityRoot struct { } Query struct { + LatestSchema func(childComplexity int, ref string) int Organizations func(childComplexity int) int Supergraph func(childComplexity int, ref string, isAfter *string) int } @@ -124,6 +125,7 @@ type MutationResolver interface { type QueryResolver interface { Organizations(ctx context.Context) ([]*model.Organization, error) Supergraph(ctx context.Context, ref string, isAfter *string) (model.Supergraph, error) + LatestSchema(ctx context.Context, ref string) (*model.SchemaUpdate, error) } type SubscriptionResolver interface { SchemaUpdates(ctx context.Context, ref string) (<-chan *model.SchemaUpdate, error) @@ -250,6 +252,17 @@ func (e *executableSchema) Complexity(ctx context.Context, typeName, field strin return e.complexity.Organization.Users(childComplexity), true + case "Query.latestSchema": + if e.complexity.Query.LatestSchema == nil { + break + } + + args, err := ec.field_Query_latestSchema_args(ctx, rawArgs) + if err != nil { + return 0, false + } + + return e.complexity.Query.LatestSchema(childComplexity, args["ref"].(string)), true case "Query.organizations": if e.complexity.Query.Organizations == nil { break @@ -520,6 +533,7 @@ var sources = []*ast.Source{ {Name: "../schema.graphqls", Input: `type Query { organizations: [Organization!]! @auth(user: true) supergraph(ref: String!, isAfter: String): Supergraph! @auth(organization: true) + latestSchema(ref: String!): SchemaUpdate! @auth(organization: true) } type Mutation { @@ -671,6 +685,17 @@ func (ec *executionContext) field_Query___type_args(ctx context.Context, rawArgs return args, nil } +func (ec *executionContext) field_Query_latestSchema_args(ctx context.Context, rawArgs map[string]any) (map[string]any, error) { + var err error + args := map[string]any{} + arg0, err := graphql.ProcessArgField(ctx, rawArgs, "ref", ec.unmarshalNString2string) + if err != nil { + return nil, err + } + args["ref"] = arg0 + return args, nil +} + func (ec *executionContext) field_Query_supergraph_args(ctx context.Context, rawArgs map[string]any) (map[string]any, error) { var err error args := map[string]any{} @@ -1434,6 +1459,75 @@ func (ec *executionContext) fieldContext_Query_supergraph(ctx context.Context, f return fc, nil } +func (ec *executionContext) _Query_latestSchema(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { + return graphql.ResolveField( + ctx, + ec.OperationContext, + field, + ec.fieldContext_Query_latestSchema, + func(ctx context.Context) (any, error) { + fc := graphql.GetFieldContext(ctx) + return ec.resolvers.Query().LatestSchema(ctx, fc.Args["ref"].(string)) + }, + func(ctx context.Context, next graphql.Resolver) graphql.Resolver { + directive0 := next + + directive1 := func(ctx context.Context) (any, error) { + organization, err := ec.unmarshalOBoolean2ᚖbool(ctx, true) + if err != nil { + var zeroVal *model.SchemaUpdate + return zeroVal, err + } + if ec.directives.Auth == nil { + var zeroVal *model.SchemaUpdate + return zeroVal, errors.New("directive auth is not implemented") + } + return ec.directives.Auth(ctx, nil, directive0, nil, organization) + } + + next = directive1 + return next + }, + ec.marshalNSchemaUpdate2ᚖgitlabᚗcomᚋunboundsoftwareᚋschemasᚋgraphᚋmodelᚐSchemaUpdate, + true, + true, + ) +} + +func (ec *executionContext) fieldContext_Query_latestSchema(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Query", + Field: field, + IsMethod: true, + IsResolver: true, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + switch field.Name { + case "ref": + return ec.fieldContext_SchemaUpdate_ref(ctx, field) + case "id": + return ec.fieldContext_SchemaUpdate_id(ctx, field) + case "subGraphs": + return ec.fieldContext_SchemaUpdate_subGraphs(ctx, field) + case "cosmoRouterConfig": + return ec.fieldContext_SchemaUpdate_cosmoRouterConfig(ctx, field) + } + return nil, fmt.Errorf("no field named %q was found under type SchemaUpdate", field.Name) + }, + } + defer func() { + if r := recover(); r != nil { + err = ec.Recover(ctx, r) + ec.Error(ctx, err) + } + }() + ctx = graphql.WithFieldContext(ctx, fc) + if fc.Args, err = ec.field_Query_latestSchema_args(ctx, field.ArgumentMap(ec.Variables)); err != nil { + ec.Error(ctx, err) + return fc, err + } + return fc, nil +} + func (ec *executionContext) _Query___type(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { return graphql.ResolveField( ctx, @@ -3997,6 +4091,28 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) }) } + out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return rrm(innerCtx) }) + case "latestSchema": + field := field + + innerFunc := func(ctx context.Context, fs *graphql.FieldSet) (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._Query_latestSchema(ctx, field) + if res == graphql.Null { + atomic.AddUint32(&fs.Invalids, 1) + } + return res + } + + rrm := func(ctx context.Context) graphql.Marshaler { + return ec.OperationContext.RootResolverMiddleware(ctx, + func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) }) + } + out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return rrm(innerCtx) }) case "__type": out.Values[i] = ec.OperationContext.RootResolverMiddleware(innerCtx, func(ctx context.Context) (res graphql.Marshaler) { diff --git a/graph/schema.graphqls b/graph/schema.graphqls index 97d82cb..ad1df55 100644 --- a/graph/schema.graphqls +++ b/graph/schema.graphqls @@ -1,6 +1,7 @@ type Query { organizations: [Organization!]! @auth(user: true) supergraph(ref: String!, isAfter: String): Supergraph! @auth(organization: true) + latestSchema(ref: String!): SchemaUpdate! @auth(organization: true) } type Mutation { diff --git a/graph/schema.resolvers.go b/graph/schema.resolvers.go index 00f6202..b426df5 100644 --- a/graph/schema.resolvers.go +++ b/graph/schema.resolvers.go @@ -238,6 +238,72 @@ func (r *queryResolver) Supergraph(ctx context.Context, ref string, isAfter *str }, nil } +// LatestSchema is the resolver for the latestSchema field. +func (r *queryResolver) LatestSchema(ctx context.Context, ref string) (*model.SchemaUpdate, error) { + orgId := middleware.OrganizationFromContext(ctx) + + r.Logger.Info("LatestSchema query", + "ref", ref, + "orgId", orgId, + ) + + _, err := r.apiKeyCanAccessRef(ctx, ref, false) + if err != nil { + r.Logger.Error("API key cannot access ref", "error", err, "ref", ref) + return nil, err + } + + // Get current services and schema + services, lastUpdate := r.Cache.Services(orgId, ref, "") + r.Logger.Info("Fetching latest schema", + "ref", ref, + "orgId", orgId, + "lastUpdate", lastUpdate, + "servicesCount", len(services), + ) + + subGraphs := make([]*model.SubGraph, len(services)) + for i, id := range services { + sg, err := r.fetchSubGraph(ctx, id) + if err != nil { + r.Logger.Error("fetch subgraph", "error", err, "id", id) + return nil, err + } + subGraphs[i] = &model.SubGraph{ + ID: sg.ID.String(), + Service: sg.Service, + URL: sg.Url, + WsURL: sg.WSUrl, + Sdl: sg.Sdl, + ChangedBy: sg.ChangedBy, + ChangedAt: sg.ChangedAt, + } + } + + // Generate Cosmo router config + cosmoConfig, err := GenerateCosmoRouterConfig(subGraphs) + if err != nil { + r.Logger.Error("generate cosmo config", "error", err) + cosmoConfig = "" // Return empty if generation fails + } + + update := &model.SchemaUpdate{ + Ref: ref, + ID: lastUpdate, + SubGraphs: subGraphs, + CosmoRouterConfig: &cosmoConfig, + } + + r.Logger.Info("Latest schema fetched", + "ref", update.Ref, + "id", update.ID, + "subGraphsCount", len(update.SubGraphs), + "cosmoConfigLength", len(cosmoConfig), + ) + + return update, nil +} + // SchemaUpdates is the resolver for the schemaUpdates field. func (r *subscriptionResolver) SchemaUpdates(ctx context.Context, ref string) (<-chan *model.SchemaUpdate, error) { orgId := middleware.OrganizationFromContext(ctx) From df44ddbb8ed406913a282c821295baf19d518d43 Mon Sep 17 00:00:00 2001 From: Joakim Olsson Date: Thu, 20 Nov 2025 17:06:45 +0100 Subject: [PATCH 4/5] test: enhance assertions for version and subscription config Update version check to validate it is a non-empty string. Improve assertions for the subscription configuration by ensuring the presence of required fields and correct types. Adapt checks for routing URLs and decentralize subscription validation for more robust testing. These changes ensure better verification of configuration integrity and correctness in test scenarios. --- graph/cosmo_test.go | 132 +++++++++++++++++++++++++++++--------------- 1 file changed, 89 insertions(+), 43 deletions(-) diff --git a/graph/cosmo_test.go b/graph/cosmo_test.go index 60cd3a0..202bdc8 100644 --- a/graph/cosmo_test.go +++ b/graph/cosmo_test.go @@ -33,7 +33,10 @@ func TestGenerateCosmoRouterConfig(t *testing.T) { err := json.Unmarshal([]byte(config), &result) require.NoError(t, err, "Config should be valid JSON") - assert.Equal(t, "1", result["version"], "Version should be 1") + // Version is a UUID string from wgc + version, ok := result["version"].(string) + require.True(t, ok, "Version should be a string") + assert.NotEmpty(t, version, "Version should not be empty") subgraphs, ok := result["subgraphs"].([]interface{}) require.True(t, ok, "subgraphs should be an array") @@ -41,14 +44,26 @@ func TestGenerateCosmoRouterConfig(t *testing.T) { sg := subgraphs[0].(map[string]interface{}) assert.Equal(t, "test-service", sg["name"]) - assert.Equal(t, "http://localhost:4001/query", sg["routing_url"]) - assert.Equal(t, "type Query { test: String }", sg["sdl"]) + assert.Equal(t, "http://localhost:4001/query", sg["routingUrl"]) - subscription, ok := sg["subscription"].(map[string]interface{}) + // Check that datasource configurations include subscription settings + engineConfig, ok := result["engineConfig"].(map[string]interface{}) + require.True(t, ok, "Should have engineConfig") + + dsConfigs, ok := engineConfig["datasourceConfigurations"].([]interface{}) + require.True(t, ok && len(dsConfigs) > 0, "Should have datasource configurations") + + ds := dsConfigs[0].(map[string]interface{}) + customGraphql, ok := ds["customGraphql"].(map[string]interface{}) + require.True(t, ok, "Should have customGraphql config") + + subscription, ok := customGraphql["subscription"].(map[string]interface{}) require.True(t, ok, "Should have subscription config") - assert.Equal(t, "ws://localhost:4001/query", subscription["url"]) - assert.Equal(t, "ws", subscription["protocol"]) - assert.Equal(t, "graphql-ws", subscription["websocket_subprotocol"]) + assert.True(t, subscription["enabled"].(bool), "Subscription should be enabled") + + subUrl, ok := subscription["url"].(map[string]interface{}) + require.True(t, ok, "Should have subscription URL") + assert.Equal(t, "ws://localhost:4001/query", subUrl["staticVariableContent"]) }, }, { @@ -80,18 +95,28 @@ func TestGenerateCosmoRouterConfig(t *testing.T) { subgraphs := result["subgraphs"].([]interface{}) assert.Len(t, subgraphs, 3, "Should have 3 subgraphs") - // Check first service has no subscription + // Check service names sg1 := subgraphs[0].(map[string]interface{}) assert.Equal(t, "service-1", sg1["name"]) - _, hasSubscription := sg1["subscription"] - assert.False(t, hasSubscription, "service-1 should not have subscription config") - // Check third service has subscription sg3 := subgraphs[2].(map[string]interface{}) assert.Equal(t, "service-3", sg3["name"]) - subscription, hasSubscription := sg3["subscription"] - assert.True(t, hasSubscription, "service-3 should have subscription config") - assert.NotNil(t, subscription) + + // Check that datasource configurations include subscription for service-3 + engineConfig, ok := result["engineConfig"].(map[string]interface{}) + require.True(t, ok, "Should have engineConfig") + + dsConfigs, ok := engineConfig["datasourceConfigurations"].([]interface{}) + require.True(t, ok && len(dsConfigs) == 3, "Should have 3 datasource configurations") + + // Find service-3's datasource config (should have subscription enabled) + ds3 := dsConfigs[2].(map[string]interface{}) + customGraphql, ok := ds3["customGraphql"].(map[string]interface{}) + require.True(t, ok, "Service-3 should have customGraphql config") + + subscription, ok := customGraphql["subscription"].(map[string]interface{}) + require.True(t, ok, "Service-3 should have subscription config") + assert.True(t, subscription["enabled"].(bool), "Service-3 subscription should be enabled") }, }, { @@ -113,39 +138,43 @@ func TestGenerateCosmoRouterConfig(t *testing.T) { subgraphs := result["subgraphs"].([]interface{}) sg := subgraphs[0].(map[string]interface{}) - // Should not have routing_url or subscription fields if URLs are nil - _, hasRoutingURL := sg["routing_url"] - assert.False(t, hasRoutingURL, "Should not have routing_url when URL is nil") + // Should not have routing URL when URL is nil + _, hasRoutingURL := sg["routingUrl"] + assert.False(t, hasRoutingURL, "Should not have routingUrl when URL is nil") - _, hasSubscription := sg["subscription"] - assert.False(t, hasSubscription, "Should not have subscription when WsURL is nil") + // Check datasource configurations don't have subscription enabled + engineConfig, ok := result["engineConfig"].(map[string]interface{}) + require.True(t, ok, "Should have engineConfig") + + dsConfigs, ok := engineConfig["datasourceConfigurations"].([]interface{}) + require.True(t, ok && len(dsConfigs) > 0, "Should have datasource configurations") + + ds := dsConfigs[0].(map[string]interface{}) + customGraphql, ok := ds["customGraphql"].(map[string]interface{}) + require.True(t, ok, "Should have customGraphql config") + + subscription, ok := customGraphql["subscription"].(map[string]interface{}) + if ok { + // wgc always enables subscription but URL should be empty when WsURL is nil + subUrl, hasUrl := subscription["url"].(map[string]interface{}) + if hasUrl { + _, hasStaticContent := subUrl["staticVariableContent"] + assert.False(t, hasStaticContent, "Subscription URL should be empty when WsURL is nil") + } + } }, }, { name: "empty subgraphs", subGraphs: []*model.SubGraph{}, - wantErr: false, - validate: func(t *testing.T, config string) { - var result map[string]interface{} - err := json.Unmarshal([]byte(config), &result) - require.NoError(t, err) - - subgraphs := result["subgraphs"].([]interface{}) - assert.Len(t, subgraphs, 0, "Should have empty subgraphs array") - }, + wantErr: true, + validate: nil, }, { name: "nil subgraphs", subGraphs: nil, - wantErr: false, - validate: func(t *testing.T, config string) { - var result map[string]interface{} - err := json.Unmarshal([]byte(config), &result) - require.NoError(t, err) - - subgraphs := result["subgraphs"].([]interface{}) - assert.Len(t, subgraphs, 0, "Should handle nil subgraphs as empty array") - }, + wantErr: true, + validate: nil, }, { name: "complex SDL with multiple types", @@ -173,13 +202,30 @@ func TestGenerateCosmoRouterConfig(t *testing.T) { err := json.Unmarshal([]byte(config), &result) require.NoError(t, err) - subgraphs := result["subgraphs"].([]interface{}) - sg := subgraphs[0].(map[string]interface{}) - sdl := sg["sdl"].(string) + // Check the composed graphqlSchema contains the types + engineConfig, ok := result["engineConfig"].(map[string]interface{}) + require.True(t, ok, "Should have engineConfig") - assert.Contains(t, sdl, "type Query") - assert.Contains(t, sdl, "type User") - assert.Contains(t, sdl, "email: String!") + graphqlSchema, ok := engineConfig["graphqlSchema"].(string) + require.True(t, ok, "Should have graphqlSchema") + + assert.Contains(t, graphqlSchema, "Query", "Schema should contain Query type") + assert.Contains(t, graphqlSchema, "User", "Schema should contain User type") + + // Check datasource has the original SDL + dsConfigs, ok := engineConfig["datasourceConfigurations"].([]interface{}) + require.True(t, ok && len(dsConfigs) > 0, "Should have datasource configurations") + + ds := dsConfigs[0].(map[string]interface{}) + customGraphql, ok := ds["customGraphql"].(map[string]interface{}) + require.True(t, ok, "Should have customGraphql config") + + federation, ok := customGraphql["federation"].(map[string]interface{}) + require.True(t, ok, "Should have federation config") + + serviceSdl, ok := federation["serviceSdl"].(string) + require.True(t, ok, "Should have serviceSdl") + assert.Contains(t, serviceSdl, "email: String!", "SDL should contain email field") }, }, } From 47dbf827f2cf994074c5e792767c55f60fd6da0e Mon Sep 17 00:00:00 2001 From: Joakim Olsson Date: Thu, 20 Nov 2025 21:09:00 +0100 Subject: [PATCH 5/5] fix: add command executor interface for better testing Introduce the CommandExecutor interface to abstract command execution, allowing for easier mocking in tests. Implement DefaultCommandExecutor to use the os/exec package for executing commands. Update the GenerateCosmoRouterConfig function to utilize the new GenerateCosmoRouterConfigWithExecutor function that accepts a command executor parameter. Add a MockCommandExecutor for simulating command execution in unit tests with realistic behavior based on input YAML files. This enhances test coverage and simplifies error handling. --- graph/cosmo.go | 25 +++++- graph/cosmo_test.go | 212 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 233 insertions(+), 4 deletions(-) diff --git a/graph/cosmo.go b/graph/cosmo.go index 69deb88..ac62fe2 100644 --- a/graph/cosmo.go +++ b/graph/cosmo.go @@ -11,9 +11,30 @@ import ( "gitlab.com/unboundsoftware/schemas/graph/model" ) +// CommandExecutor is an interface for executing external commands +// This allows for mocking in tests +type CommandExecutor interface { + Execute(name string, args ...string) ([]byte, error) +} + +// DefaultCommandExecutor implements CommandExecutor using os/exec +type DefaultCommandExecutor struct{} + +// Execute runs a command and returns its combined output +func (e *DefaultCommandExecutor) Execute(name string, args ...string) ([]byte, error) { + cmd := exec.Command(name, args...) + return cmd.CombinedOutput() +} + // GenerateCosmoRouterConfig generates a Cosmo Router execution config from subgraphs // using the official wgc CLI tool via npx func GenerateCosmoRouterConfig(subGraphs []*model.SubGraph) (string, error) { + return GenerateCosmoRouterConfigWithExecutor(subGraphs, &DefaultCommandExecutor{}) +} + +// GenerateCosmoRouterConfigWithExecutor generates a Cosmo Router execution config from subgraphs +// using the provided command executor (useful for testing) +func GenerateCosmoRouterConfigWithExecutor(subGraphs []*model.SubGraph, executor CommandExecutor) (string, error) { if len(subGraphs) == 0 { return "", fmt.Errorf("no subgraphs provided") } @@ -85,13 +106,11 @@ func GenerateCosmoRouterConfig(subGraphs []*model.SubGraph) (string, error) { // Execute wgc router compose // wgc is installed globally in the Docker image outputFile := filepath.Join(tmpDir, "config.json") - cmd := exec.Command("wgc", "router", "compose", + output, err := executor.Execute("wgc", "router", "compose", "--input", inputFile, "--out", outputFile, "--suppress-warnings", ) - - output, err := cmd.CombinedOutput() if err != nil { return "", fmt.Errorf("wgc router compose failed: %w\nOutput: %s", err, string(output)) } diff --git a/graph/cosmo_test.go b/graph/cosmo_test.go index 202bdc8..f560fbf 100644 --- a/graph/cosmo_test.go +++ b/graph/cosmo_test.go @@ -2,14 +2,185 @@ package graph import ( "encoding/json" + "fmt" + "os" + "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" "gitlab.com/unboundsoftware/schemas/graph/model" ) +// MockCommandExecutor implements CommandExecutor for testing +type MockCommandExecutor struct { + // CallCount tracks how many times Execute was called + CallCount int + // LastArgs stores the arguments from the last call + LastArgs []string + // Error can be set to simulate command failures + Error error +} + +// Execute mocks the wgc command by generating a realistic config.json file +func (m *MockCommandExecutor) Execute(name string, args ...string) ([]byte, error) { + m.CallCount++ + m.LastArgs = append([]string{name}, args...) + + if m.Error != nil { + return nil, m.Error + } + + // Parse the input file to understand what subgraphs we're composing + var inputFile, outputFile string + for i, arg := range args { + if arg == "--input" && i+1 < len(args) { + inputFile = args[i+1] + } + if arg == "--out" && i+1 < len(args) { + outputFile = args[i+1] + } + } + + if inputFile == "" || outputFile == "" { + return nil, fmt.Errorf("missing required arguments") + } + + // Read the input YAML to get subgraph information + inputData, err := os.ReadFile(inputFile) + if err != nil { + return nil, fmt.Errorf("failed to read input file: %w", err) + } + + var input struct { + Version int `yaml:"version"` + Subgraphs []struct { + Name string `yaml:"name"` + RoutingURL string `yaml:"routing_url,omitempty"` + Schema map[string]string `yaml:"schema"` + Subscription map[string]interface{} `yaml:"subscription,omitempty"` + } `yaml:"subgraphs"` + } + + if err := yaml.Unmarshal(inputData, &input); err != nil { + return nil, fmt.Errorf("failed to parse input YAML: %w", err) + } + + // Generate a realistic Cosmo Router config based on the input + config := map[string]interface{}{ + "version": "mock-version-uuid", + "subgraphs": func() []map[string]interface{} { + subgraphs := make([]map[string]interface{}, len(input.Subgraphs)) + for i, sg := range input.Subgraphs { + subgraph := map[string]interface{}{ + "id": fmt.Sprintf("mock-id-%d", i), + "name": sg.Name, + } + if sg.RoutingURL != "" { + subgraph["routingUrl"] = sg.RoutingURL + } + subgraphs[i] = subgraph + } + return subgraphs + }(), + "engineConfig": map[string]interface{}{ + "graphqlSchema": generateMockSchema(input.Subgraphs), + "datasourceConfigurations": func() []map[string]interface{} { + dsConfigs := make([]map[string]interface{}, len(input.Subgraphs)) + for i, sg := range input.Subgraphs { + // Read SDL from file + sdl := "" + if schemaFile, ok := sg.Schema["file"]; ok { + if sdlData, err := os.ReadFile(schemaFile); err == nil { + sdl = string(sdlData) + } + } + + dsConfig := map[string]interface{}{ + "id": fmt.Sprintf("datasource-%d", i), + "kind": "GRAPHQL", + "customGraphql": map[string]interface{}{ + "federation": map[string]interface{}{ + "enabled": true, + "serviceSdl": sdl, + }, + "subscription": func() map[string]interface{} { + if len(sg.Subscription) > 0 { + return map[string]interface{}{ + "enabled": true, + "url": map[string]interface{}{ + "staticVariableContent": sg.Subscription["url"], + }, + "protocol": sg.Subscription["protocol"], + "websocketSubprotocol": sg.Subscription["websocket_subprotocol"], + } + } + return map[string]interface{}{ + "enabled": false, + } + }(), + }, + } + dsConfigs[i] = dsConfig + } + return dsConfigs + }(), + }, + } + + // Write the config to the output file + configJSON, err := json.Marshal(config) + if err != nil { + return nil, fmt.Errorf("failed to marshal config: %w", err) + } + + if err := os.WriteFile(outputFile, configJSON, 0o644); err != nil { + return nil, fmt.Errorf("failed to write output file: %w", err) + } + + return []byte("Success"), nil +} + +// generateMockSchema creates a simple merged schema from subgraphs +func generateMockSchema(subgraphs []struct { + Name string `yaml:"name"` + RoutingURL string `yaml:"routing_url,omitempty"` + Schema map[string]string `yaml:"schema"` + Subscription map[string]interface{} `yaml:"subscription,omitempty"` +}, +) string { + schema := strings.Builder{} + schema.WriteString("schema {\n query: Query\n") + + // Check if any subgraph has subscriptions + hasSubscriptions := false + for _, sg := range subgraphs { + if len(sg.Subscription) > 0 { + hasSubscriptions = true + break + } + } + + if hasSubscriptions { + schema.WriteString(" subscription: Subscription\n") + } + schema.WriteString("}\n\n") + + // Add types by reading SDL files + for _, sg := range subgraphs { + if schemaFile, ok := sg.Schema["file"]; ok { + if sdlData, err := os.ReadFile(schemaFile); err == nil { + schema.WriteString(string(sdlData)) + schema.WriteString("\n") + } + } + } + + return schema.String() +} + func TestGenerateCosmoRouterConfig(t *testing.T) { tests := []struct { name string @@ -232,16 +403,28 @@ func TestGenerateCosmoRouterConfig(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - config, err := GenerateCosmoRouterConfig(tt.subGraphs) + // Use mock executor for all tests + mockExecutor := &MockCommandExecutor{} + config, err := GenerateCosmoRouterConfigWithExecutor(tt.subGraphs, mockExecutor) if tt.wantErr { assert.Error(t, err) + // Verify executor was not called for error cases + if len(tt.subGraphs) == 0 { + assert.Equal(t, 0, mockExecutor.CallCount, "Should not call executor for empty subgraphs") + } return } require.NoError(t, err) assert.NotEmpty(t, config, "Config should not be empty") + // Verify executor was called correctly + assert.Equal(t, 1, mockExecutor.CallCount, "Should call executor once") + assert.Equal(t, "wgc", mockExecutor.LastArgs[0], "Should call wgc command") + assert.Contains(t, mockExecutor.LastArgs, "router", "Should include 'router' arg") + assert.Contains(t, mockExecutor.LastArgs, "compose", "Should include 'compose' arg") + if tt.validate != nil { tt.validate(t, config) } @@ -249,6 +432,33 @@ func TestGenerateCosmoRouterConfig(t *testing.T) { } } +// TestGenerateCosmoRouterConfig_MockError tests error handling with mock executor +func TestGenerateCosmoRouterConfig_MockError(t *testing.T) { + subGraphs := []*model.SubGraph{ + { + Service: "test-service", + URL: stringPtr("http://localhost:4001/query"), + Sdl: "type Query { test: String }", + }, + } + + // Create a mock executor that returns an error + mockExecutor := &MockCommandExecutor{ + Error: fmt.Errorf("simulated wgc failure"), + } + + config, err := GenerateCosmoRouterConfigWithExecutor(subGraphs, mockExecutor) + + // Verify error is propagated + assert.Error(t, err) + assert.Contains(t, err.Error(), "wgc router compose failed") + assert.Contains(t, err.Error(), "simulated wgc failure") + assert.Empty(t, config) + + // Verify executor was called + assert.Equal(t, 1, mockExecutor.CallCount, "Should have attempted to call executor") +} + // Helper function for tests func stringPtr(s string) *string { return &s