diff --git a/Dockerfile b/Dockerfile index 3512d54..67e36fd 100644 --- a/Dockerfile +++ b/Dockerfile @@ -24,9 +24,17 @@ RUN GOOS=linux GOARCH=amd64 go build \ FROM scratch as export COPY --from=build /build/coverage.txt / -FROM scratch +FROM node:22-alpine ENV TZ Europe/Stockholm + +# Install wgc CLI globally for Cosmo Router composition +RUN npm install -g wgc@latest + +# Copy timezone data and certificates COPY --from=build /usr/share/zoneinfo /usr/share/zoneinfo COPY --from=build /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ + +# Copy the service binary COPY --from=build /release/service / + CMD ["/service"] diff --git a/cmd/service/service.go b/cmd/service/service.go index dbefcf2..147d612 100644 --- a/cmd/service/service.go +++ b/cmd/service/service.go @@ -30,6 +30,7 @@ import ( "gitlab.com/unboundsoftware/schemas/domain" "gitlab.com/unboundsoftware/schemas/graph" "gitlab.com/unboundsoftware/schemas/graph/generated" + "gitlab.com/unboundsoftware/schemas/hash" "gitlab.com/unboundsoftware/schemas/logging" "gitlab.com/unboundsoftware/schemas/middleware" "gitlab.com/unboundsoftware/schemas/monitoring" @@ -210,6 +211,24 @@ func start(closeEvents chan error, logger *slog.Logger, connectToAmqpFunc func(u srv.AddTransport(transport.Websocket{ KeepAlivePingInterval: 10 * time.Second, + InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { + // Extract API key from WebSocket connection_init payload + if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" { + logger.Info("WebSocket connection with API key", "has_key", true) + ctx = context.WithValue(ctx, middleware.ApiKey, apiKey) + + // Look up organization by API key (same logic as auth middleware) + if organization := serviceCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { + logger.Info("WebSocket: Organization found for API key", "org_id", organization.ID.String()) + ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization) + } else { + logger.Warn("WebSocket: No organization found for API key") + } + } else { + logger.Info("WebSocket connection without API key") + } + return ctx, &initPayload, nil + }, }) srv.AddTransport(transport.Options{}) srv.AddTransport(transport.GET{}) diff --git a/cmd/service/service_test.go b/cmd/service/service_test.go new file mode 100644 index 0000000..c9680d9 --- /dev/null +++ b/cmd/service/service_test.go @@ -0,0 +1,336 @@ +package main + +import ( + "context" + "testing" + + "github.com/99designs/gqlgen/graphql/handler/transport" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gitlab.com/unboundsoftware/eventsourced/eventsourced" + + "gitlab.com/unboundsoftware/schemas/domain" + "gitlab.com/unboundsoftware/schemas/hash" + "gitlab.com/unboundsoftware/schemas/middleware" +) + +// MockCache is a mock implementation for testing +type MockCache struct { + organizations map[string]*domain.Organization +} + +func (m *MockCache) OrganizationByAPIKey(apiKey string) *domain.Organization { + return m.organizations[apiKey] +} + +func TestWebSocketInitFunc_WithValidAPIKey(t *testing.T) { + // Setup + orgID := uuid.New() + org := &domain.Organization{ + BaseAggregate: eventsourced.BaseAggregate{ + ID: eventsourced.IdFromString(orgID.String()), + }, + Name: "Test Organization", + } + + apiKey := "test-api-key-123" + hashedKey := hash.String(apiKey) + + mockCache := &MockCache{ + organizations: map[string]*domain.Organization{ + hashedKey: org, + }, + } + + // Create InitFunc (simulating the WebSocket InitFunc logic) + initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { + // Extract API key from WebSocket connection_init payload + if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" { + ctx = context.WithValue(ctx, middleware.ApiKey, apiKey) + + // Look up organization by API key + if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { + ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization) + } + } + return ctx, &initPayload, nil + } + + // Test + ctx := context.Background() + initPayload := transport.InitPayload{ + "X-Api-Key": apiKey, + } + + resultCtx, resultPayload, err := initFunc(ctx, initPayload) + + // Assert + require.NoError(t, err) + require.NotNil(t, resultPayload) + + // Check API key is in context + if value := resultCtx.Value(middleware.ApiKey); value != nil { + assert.Equal(t, apiKey, value.(string)) + } else { + t.Fatal("API key not found in context") + } + + // Check organization is in context + if value := resultCtx.Value(middleware.OrganizationKey); value != nil { + capturedOrg, ok := value.(domain.Organization) + require.True(t, ok, "Organization should be of correct type") + assert.Equal(t, org.Name, capturedOrg.Name) + assert.Equal(t, org.ID.String(), capturedOrg.ID.String()) + } else { + t.Fatal("Organization not found in context") + } +} + +func TestWebSocketInitFunc_WithInvalidAPIKey(t *testing.T) { + // Setup + mockCache := &MockCache{ + organizations: map[string]*domain.Organization{}, + } + + apiKey := "invalid-api-key" + + // Create InitFunc + initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { + // Extract API key from WebSocket connection_init payload + if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" { + ctx = context.WithValue(ctx, middleware.ApiKey, apiKey) + + // Look up organization by API key + if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { + ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization) + } + } + return ctx, &initPayload, nil + } + + // Test + ctx := context.Background() + initPayload := transport.InitPayload{ + "X-Api-Key": apiKey, + } + + resultCtx, resultPayload, err := initFunc(ctx, initPayload) + + // Assert + require.NoError(t, err) + require.NotNil(t, resultPayload) + + // Check API key is in context + if value := resultCtx.Value(middleware.ApiKey); value != nil { + assert.Equal(t, apiKey, value.(string)) + } else { + t.Fatal("API key not found in context") + } + + // Check organization is NOT in context (since API key is invalid) + value := resultCtx.Value(middleware.OrganizationKey) + assert.Nil(t, value, "Organization should not be set for invalid API key") +} + +func TestWebSocketInitFunc_WithoutAPIKey(t *testing.T) { + // Setup + mockCache := &MockCache{ + organizations: map[string]*domain.Organization{}, + } + + // Create InitFunc + initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { + // Extract API key from WebSocket connection_init payload + if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" { + ctx = context.WithValue(ctx, middleware.ApiKey, apiKey) + + // Look up organization by API key + if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { + ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization) + } + } + return ctx, &initPayload, nil + } + + // Test + ctx := context.Background() + initPayload := transport.InitPayload{} + + resultCtx, resultPayload, err := initFunc(ctx, initPayload) + + // Assert + require.NoError(t, err) + require.NotNil(t, resultPayload) + + // Check API key is NOT in context + value := resultCtx.Value(middleware.ApiKey) + assert.Nil(t, value, "API key should not be set when not provided") + + // Check organization is NOT in context + value = resultCtx.Value(middleware.OrganizationKey) + assert.Nil(t, value, "Organization should not be set when API key is not provided") +} + +func TestWebSocketInitFunc_WithEmptyAPIKey(t *testing.T) { + // Setup + mockCache := &MockCache{ + organizations: map[string]*domain.Organization{}, + } + + // Create InitFunc + initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { + // Extract API key from WebSocket connection_init payload + if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" { + ctx = context.WithValue(ctx, middleware.ApiKey, apiKey) + + // Look up organization by API key + if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { + ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization) + } + } + return ctx, &initPayload, nil + } + + // Test + ctx := context.Background() + initPayload := transport.InitPayload{ + "X-Api-Key": "", // Empty string + } + + resultCtx, resultPayload, err := initFunc(ctx, initPayload) + + // Assert + require.NoError(t, err) + require.NotNil(t, resultPayload) + + // Check API key is NOT in context (because empty string fails the condition) + value := resultCtx.Value(middleware.ApiKey) + assert.Nil(t, value, "API key should not be set when empty") + + // Check organization is NOT in context + value = resultCtx.Value(middleware.OrganizationKey) + assert.Nil(t, value, "Organization should not be set when API key is empty") +} + +func TestWebSocketInitFunc_WithWrongTypeAPIKey(t *testing.T) { + // Setup + mockCache := &MockCache{ + organizations: map[string]*domain.Organization{}, + } + + // Create InitFunc + initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { + // Extract API key from WebSocket connection_init payload + if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" { + ctx = context.WithValue(ctx, middleware.ApiKey, apiKey) + + // Look up organization by API key + if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { + ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization) + } + } + return ctx, &initPayload, nil + } + + // Test + ctx := context.Background() + initPayload := transport.InitPayload{ + "X-Api-Key": 12345, // Wrong type (int instead of string) + } + + resultCtx, resultPayload, err := initFunc(ctx, initPayload) + + // Assert + require.NoError(t, err) + require.NotNil(t, resultPayload) + + // Check API key is NOT in context (type assertion fails) + value := resultCtx.Value(middleware.ApiKey) + assert.Nil(t, value, "API key should not be set when wrong type") + + // Check organization is NOT in context + value = resultCtx.Value(middleware.OrganizationKey) + assert.Nil(t, value, "Organization should not be set when API key has wrong type") +} + +func TestWebSocketInitFunc_WithMultipleOrganizations(t *testing.T) { + // Setup - create multiple organizations + org1ID := uuid.New() + org1 := &domain.Organization{ + BaseAggregate: eventsourced.BaseAggregate{ + ID: eventsourced.IdFromString(org1ID.String()), + }, + Name: "Organization 1", + } + + org2ID := uuid.New() + org2 := &domain.Organization{ + BaseAggregate: eventsourced.BaseAggregate{ + ID: eventsourced.IdFromString(org2ID.String()), + }, + Name: "Organization 2", + } + + apiKey1 := "api-key-org-1" + apiKey2 := "api-key-org-2" + hashedKey1 := hash.String(apiKey1) + hashedKey2 := hash.String(apiKey2) + + mockCache := &MockCache{ + organizations: map[string]*domain.Organization{ + hashedKey1: org1, + hashedKey2: org2, + }, + } + + // Create InitFunc + initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { + // Extract API key from WebSocket connection_init payload + if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" { + ctx = context.WithValue(ctx, middleware.ApiKey, apiKey) + + // Look up organization by API key + if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { + ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization) + } + } + return ctx, &initPayload, nil + } + + // Test with first API key + ctx1 := context.Background() + initPayload1 := transport.InitPayload{ + "X-Api-Key": apiKey1, + } + + resultCtx1, _, err := initFunc(ctx1, initPayload1) + require.NoError(t, err) + + if value := resultCtx1.Value(middleware.OrganizationKey); value != nil { + capturedOrg, ok := value.(domain.Organization) + require.True(t, ok) + assert.Equal(t, org1.Name, capturedOrg.Name) + assert.Equal(t, org1.ID.String(), capturedOrg.ID.String()) + } else { + t.Fatal("Organization 1 not found in context") + } + + // Test with second API key + ctx2 := context.Background() + initPayload2 := transport.InitPayload{ + "X-Api-Key": apiKey2, + } + + resultCtx2, _, err := initFunc(ctx2, initPayload2) + require.NoError(t, err) + + if value := resultCtx2.Value(middleware.OrganizationKey); value != nil { + capturedOrg, ok := value.(domain.Organization) + require.True(t, ok) + assert.Equal(t, org2.Name, capturedOrg.Name) + assert.Equal(t, org2.ID.String(), capturedOrg.ID.String()) + } else { + t.Fatal("Organization 2 not found in context") + } +} diff --git a/go.mod b/go.mod index 127627f..6a4d2f0 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/apex/log v1.9.0 github.com/auth0/go-jwt-middleware/v2 v2.3.0 github.com/golang-jwt/jwt/v5 v5.3.0 + github.com/google/uuid v1.6.0 github.com/jmoiron/sqlx v1.4.0 github.com/pkg/errors v0.9.1 github.com/pressly/goose/v3 v3.26.0 @@ -30,6 +31,7 @@ require ( go.opentelemetry.io/otel/sdk/log v0.14.0 go.opentelemetry.io/otel/sdk/metric v1.38.0 go.opentelemetry.io/otel/trace v1.38.0 + gopkg.in/yaml.v3 v3.0.1 ) require ( @@ -41,7 +43,6 @@ require ( github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/gorilla/websocket v1.5.1 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect @@ -51,6 +52,7 @@ require ( github.com/rabbitmq/amqp091-go v1.10.0 // indirect github.com/sethvargo/go-retry v0.3.0 // indirect github.com/sosodev/duration v1.3.1 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/tidwall/gjson v1.17.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect @@ -72,5 +74,4 @@ require ( google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect google.golang.org/grpc v1.75.0 // indirect google.golang.org/protobuf v1.36.10 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index fd5c8e7..3abb5fe 100644 --- a/go.sum +++ b/go.sum @@ -141,6 +141,8 @@ github.com/sosodev/duration v1.3.1/go.mod h1:RQIBBX0+fMLc/D9+Jb/fwvVmo0eZvDDEERA github.com/sparetimecoders/goamqp v0.3.3 h1:z/nfTPmrjeU/rIVuNOgsVLCimp3WFoNFvS3ZzXRJ6HE= github.com/sparetimecoders/goamqp v0.3.3/go.mod h1:W9NRCpWLE+Vruv2dcRSbszNil2O826d2Nv6kAkETW5o= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= diff --git a/graph/cosmo.go b/graph/cosmo.go index 150e162..ac62fe2 100644 --- a/graph/cosmo.go +++ b/graph/cosmo.go @@ -1,54 +1,125 @@ package graph import ( - "encoding/json" "fmt" + "os" + "os/exec" + "path/filepath" + + "gopkg.in/yaml.v3" "gitlab.com/unboundsoftware/schemas/graph/model" ) -// GenerateCosmoRouterConfig generates a Cosmo Router execution config from subgraphs -func GenerateCosmoRouterConfig(subGraphs []*model.SubGraph) (string, error) { - // Build the Cosmo router config structure - // This is a simplified version - you may need to adjust based on actual Cosmo requirements - config := map[string]interface{}{ - "version": "1", - "subgraphs": convertSubGraphsToCosmo(subGraphs), - // Add other Cosmo-specific configuration as needed - } - - // Marshal to JSON - configJSON, err := json.MarshalIndent(config, "", " ") - if err != nil { - return "", fmt.Errorf("marshal cosmo config: %w", err) - } - - return string(configJSON), nil +// CommandExecutor is an interface for executing external commands +// This allows for mocking in tests +type CommandExecutor interface { + Execute(name string, args ...string) ([]byte, error) } -func convertSubGraphsToCosmo(subGraphs []*model.SubGraph) []map[string]interface{} { - cosmoSubgraphs := make([]map[string]interface{}, 0, len(subGraphs)) +// DefaultCommandExecutor implements CommandExecutor using os/exec +type DefaultCommandExecutor struct{} + +// Execute runs a command and returns its combined output +func (e *DefaultCommandExecutor) Execute(name string, args ...string) ([]byte, error) { + cmd := exec.Command(name, args...) + return cmd.CombinedOutput() +} + +// GenerateCosmoRouterConfig generates a Cosmo Router execution config from subgraphs +// using the official wgc CLI tool via npx +func GenerateCosmoRouterConfig(subGraphs []*model.SubGraph) (string, error) { + return GenerateCosmoRouterConfigWithExecutor(subGraphs, &DefaultCommandExecutor{}) +} + +// GenerateCosmoRouterConfigWithExecutor generates a Cosmo Router execution config from subgraphs +// using the provided command executor (useful for testing) +func GenerateCosmoRouterConfigWithExecutor(subGraphs []*model.SubGraph, executor CommandExecutor) (string, error) { + if len(subGraphs) == 0 { + return "", fmt.Errorf("no subgraphs provided") + } + + // Create a temporary directory for composition + tmpDir, err := os.MkdirTemp("", "cosmo-compose-*") + if err != nil { + return "", fmt.Errorf("create temp dir: %w", err) + } + defer os.RemoveAll(tmpDir) + + // Write each subgraph SDL to a file + type SubgraphConfig struct { + Name string `yaml:"name"` + RoutingURL string `yaml:"routing_url,omitempty"` + Schema map[string]string `yaml:"schema"` + Subscription map[string]interface{} `yaml:"subscription,omitempty"` + } + + type InputConfig struct { + Version int `yaml:"version"` + Subgraphs []SubgraphConfig `yaml:"subgraphs"` + } + + inputConfig := InputConfig{ + Version: 1, + Subgraphs: make([]SubgraphConfig, 0, len(subGraphs)), + } for _, sg := range subGraphs { - cosmoSg := map[string]interface{}{ - "name": sg.Service, - "sdl": sg.Sdl, + // Write SDL to a temp file + schemaFile := filepath.Join(tmpDir, fmt.Sprintf("%s.graphql", sg.Service)) + if err := os.WriteFile(schemaFile, []byte(sg.Sdl), 0o644); err != nil { + return "", fmt.Errorf("write schema file for %s: %w", sg.Service, err) + } + + subgraphCfg := SubgraphConfig{ + Name: sg.Service, + Schema: map[string]string{ + "file": schemaFile, + }, } if sg.URL != nil { - cosmoSg["routing_url"] = *sg.URL + subgraphCfg.RoutingURL = *sg.URL } if sg.WsURL != nil { - cosmoSg["subscription"] = map[string]interface{}{ + subgraphCfg.Subscription = map[string]interface{}{ "url": *sg.WsURL, "protocol": "ws", "websocket_subprotocol": "graphql-ws", } } - cosmoSubgraphs = append(cosmoSubgraphs, cosmoSg) + inputConfig.Subgraphs = append(inputConfig.Subgraphs, subgraphCfg) } - return cosmoSubgraphs + // Write input config YAML + inputFile := filepath.Join(tmpDir, "input.yaml") + inputYAML, err := yaml.Marshal(inputConfig) + if err != nil { + return "", fmt.Errorf("marshal input config: %w", err) + } + if err := os.WriteFile(inputFile, inputYAML, 0o644); err != nil { + return "", fmt.Errorf("write input config: %w", err) + } + + // Execute wgc router compose + // wgc is installed globally in the Docker image + outputFile := filepath.Join(tmpDir, "config.json") + output, err := executor.Execute("wgc", "router", "compose", + "--input", inputFile, + "--out", outputFile, + "--suppress-warnings", + ) + if err != nil { + return "", fmt.Errorf("wgc router compose failed: %w\nOutput: %s", err, string(output)) + } + + // Read the generated config + configJSON, err := os.ReadFile(outputFile) + if err != nil { + return "", fmt.Errorf("read output config: %w", err) + } + + return string(configJSON), nil } diff --git a/graph/cosmo_test.go b/graph/cosmo_test.go index 0f6de36..f560fbf 100644 --- a/graph/cosmo_test.go +++ b/graph/cosmo_test.go @@ -2,14 +2,185 @@ package graph import ( "encoding/json" + "fmt" + "os" + "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" "gitlab.com/unboundsoftware/schemas/graph/model" ) +// MockCommandExecutor implements CommandExecutor for testing +type MockCommandExecutor struct { + // CallCount tracks how many times Execute was called + CallCount int + // LastArgs stores the arguments from the last call + LastArgs []string + // Error can be set to simulate command failures + Error error +} + +// Execute mocks the wgc command by generating a realistic config.json file +func (m *MockCommandExecutor) Execute(name string, args ...string) ([]byte, error) { + m.CallCount++ + m.LastArgs = append([]string{name}, args...) + + if m.Error != nil { + return nil, m.Error + } + + // Parse the input file to understand what subgraphs we're composing + var inputFile, outputFile string + for i, arg := range args { + if arg == "--input" && i+1 < len(args) { + inputFile = args[i+1] + } + if arg == "--out" && i+1 < len(args) { + outputFile = args[i+1] + } + } + + if inputFile == "" || outputFile == "" { + return nil, fmt.Errorf("missing required arguments") + } + + // Read the input YAML to get subgraph information + inputData, err := os.ReadFile(inputFile) + if err != nil { + return nil, fmt.Errorf("failed to read input file: %w", err) + } + + var input struct { + Version int `yaml:"version"` + Subgraphs []struct { + Name string `yaml:"name"` + RoutingURL string `yaml:"routing_url,omitempty"` + Schema map[string]string `yaml:"schema"` + Subscription map[string]interface{} `yaml:"subscription,omitempty"` + } `yaml:"subgraphs"` + } + + if err := yaml.Unmarshal(inputData, &input); err != nil { + return nil, fmt.Errorf("failed to parse input YAML: %w", err) + } + + // Generate a realistic Cosmo Router config based on the input + config := map[string]interface{}{ + "version": "mock-version-uuid", + "subgraphs": func() []map[string]interface{} { + subgraphs := make([]map[string]interface{}, len(input.Subgraphs)) + for i, sg := range input.Subgraphs { + subgraph := map[string]interface{}{ + "id": fmt.Sprintf("mock-id-%d", i), + "name": sg.Name, + } + if sg.RoutingURL != "" { + subgraph["routingUrl"] = sg.RoutingURL + } + subgraphs[i] = subgraph + } + return subgraphs + }(), + "engineConfig": map[string]interface{}{ + "graphqlSchema": generateMockSchema(input.Subgraphs), + "datasourceConfigurations": func() []map[string]interface{} { + dsConfigs := make([]map[string]interface{}, len(input.Subgraphs)) + for i, sg := range input.Subgraphs { + // Read SDL from file + sdl := "" + if schemaFile, ok := sg.Schema["file"]; ok { + if sdlData, err := os.ReadFile(schemaFile); err == nil { + sdl = string(sdlData) + } + } + + dsConfig := map[string]interface{}{ + "id": fmt.Sprintf("datasource-%d", i), + "kind": "GRAPHQL", + "customGraphql": map[string]interface{}{ + "federation": map[string]interface{}{ + "enabled": true, + "serviceSdl": sdl, + }, + "subscription": func() map[string]interface{} { + if len(sg.Subscription) > 0 { + return map[string]interface{}{ + "enabled": true, + "url": map[string]interface{}{ + "staticVariableContent": sg.Subscription["url"], + }, + "protocol": sg.Subscription["protocol"], + "websocketSubprotocol": sg.Subscription["websocket_subprotocol"], + } + } + return map[string]interface{}{ + "enabled": false, + } + }(), + }, + } + dsConfigs[i] = dsConfig + } + return dsConfigs + }(), + }, + } + + // Write the config to the output file + configJSON, err := json.Marshal(config) + if err != nil { + return nil, fmt.Errorf("failed to marshal config: %w", err) + } + + if err := os.WriteFile(outputFile, configJSON, 0o644); err != nil { + return nil, fmt.Errorf("failed to write output file: %w", err) + } + + return []byte("Success"), nil +} + +// generateMockSchema creates a simple merged schema from subgraphs +func generateMockSchema(subgraphs []struct { + Name string `yaml:"name"` + RoutingURL string `yaml:"routing_url,omitempty"` + Schema map[string]string `yaml:"schema"` + Subscription map[string]interface{} `yaml:"subscription,omitempty"` +}, +) string { + schema := strings.Builder{} + schema.WriteString("schema {\n query: Query\n") + + // Check if any subgraph has subscriptions + hasSubscriptions := false + for _, sg := range subgraphs { + if len(sg.Subscription) > 0 { + hasSubscriptions = true + break + } + } + + if hasSubscriptions { + schema.WriteString(" subscription: Subscription\n") + } + schema.WriteString("}\n\n") + + // Add types by reading SDL files + for _, sg := range subgraphs { + if schemaFile, ok := sg.Schema["file"]; ok { + if sdlData, err := os.ReadFile(schemaFile); err == nil { + schema.WriteString(string(sdlData)) + schema.WriteString("\n") + } + } + } + + return schema.String() +} + func TestGenerateCosmoRouterConfig(t *testing.T) { tests := []struct { name string @@ -33,7 +204,10 @@ func TestGenerateCosmoRouterConfig(t *testing.T) { err := json.Unmarshal([]byte(config), &result) require.NoError(t, err, "Config should be valid JSON") - assert.Equal(t, "1", result["version"], "Version should be 1") + // Version is a UUID string from wgc + version, ok := result["version"].(string) + require.True(t, ok, "Version should be a string") + assert.NotEmpty(t, version, "Version should not be empty") subgraphs, ok := result["subgraphs"].([]interface{}) require.True(t, ok, "subgraphs should be an array") @@ -41,14 +215,26 @@ func TestGenerateCosmoRouterConfig(t *testing.T) { sg := subgraphs[0].(map[string]interface{}) assert.Equal(t, "test-service", sg["name"]) - assert.Equal(t, "http://localhost:4001/query", sg["routing_url"]) - assert.Equal(t, "type Query { test: String }", sg["sdl"]) + assert.Equal(t, "http://localhost:4001/query", sg["routingUrl"]) - subscription, ok := sg["subscription"].(map[string]interface{}) + // Check that datasource configurations include subscription settings + engineConfig, ok := result["engineConfig"].(map[string]interface{}) + require.True(t, ok, "Should have engineConfig") + + dsConfigs, ok := engineConfig["datasourceConfigurations"].([]interface{}) + require.True(t, ok && len(dsConfigs) > 0, "Should have datasource configurations") + + ds := dsConfigs[0].(map[string]interface{}) + customGraphql, ok := ds["customGraphql"].(map[string]interface{}) + require.True(t, ok, "Should have customGraphql config") + + subscription, ok := customGraphql["subscription"].(map[string]interface{}) require.True(t, ok, "Should have subscription config") - assert.Equal(t, "ws://localhost:4001/query", subscription["url"]) - assert.Equal(t, "ws", subscription["protocol"]) - assert.Equal(t, "graphql-ws", subscription["websocket_subprotocol"]) + assert.True(t, subscription["enabled"].(bool), "Subscription should be enabled") + + subUrl, ok := subscription["url"].(map[string]interface{}) + require.True(t, ok, "Should have subscription URL") + assert.Equal(t, "ws://localhost:4001/query", subUrl["staticVariableContent"]) }, }, { @@ -80,18 +266,28 @@ func TestGenerateCosmoRouterConfig(t *testing.T) { subgraphs := result["subgraphs"].([]interface{}) assert.Len(t, subgraphs, 3, "Should have 3 subgraphs") - // Check first service has no subscription + // Check service names sg1 := subgraphs[0].(map[string]interface{}) assert.Equal(t, "service-1", sg1["name"]) - _, hasSubscription := sg1["subscription"] - assert.False(t, hasSubscription, "service-1 should not have subscription config") - // Check third service has subscription sg3 := subgraphs[2].(map[string]interface{}) assert.Equal(t, "service-3", sg3["name"]) - subscription, hasSubscription := sg3["subscription"] - assert.True(t, hasSubscription, "service-3 should have subscription config") - assert.NotNil(t, subscription) + + // Check that datasource configurations include subscription for service-3 + engineConfig, ok := result["engineConfig"].(map[string]interface{}) + require.True(t, ok, "Should have engineConfig") + + dsConfigs, ok := engineConfig["datasourceConfigurations"].([]interface{}) + require.True(t, ok && len(dsConfigs) == 3, "Should have 3 datasource configurations") + + // Find service-3's datasource config (should have subscription enabled) + ds3 := dsConfigs[2].(map[string]interface{}) + customGraphql, ok := ds3["customGraphql"].(map[string]interface{}) + require.True(t, ok, "Service-3 should have customGraphql config") + + subscription, ok := customGraphql["subscription"].(map[string]interface{}) + require.True(t, ok, "Service-3 should have subscription config") + assert.True(t, subscription["enabled"].(bool), "Service-3 subscription should be enabled") }, }, { @@ -113,39 +309,43 @@ func TestGenerateCosmoRouterConfig(t *testing.T) { subgraphs := result["subgraphs"].([]interface{}) sg := subgraphs[0].(map[string]interface{}) - // Should not have routing_url or subscription fields if URLs are nil - _, hasRoutingURL := sg["routing_url"] - assert.False(t, hasRoutingURL, "Should not have routing_url when URL is nil") + // Should not have routing URL when URL is nil + _, hasRoutingURL := sg["routingUrl"] + assert.False(t, hasRoutingURL, "Should not have routingUrl when URL is nil") - _, hasSubscription := sg["subscription"] - assert.False(t, hasSubscription, "Should not have subscription when WsURL is nil") + // Check datasource configurations don't have subscription enabled + engineConfig, ok := result["engineConfig"].(map[string]interface{}) + require.True(t, ok, "Should have engineConfig") + + dsConfigs, ok := engineConfig["datasourceConfigurations"].([]interface{}) + require.True(t, ok && len(dsConfigs) > 0, "Should have datasource configurations") + + ds := dsConfigs[0].(map[string]interface{}) + customGraphql, ok := ds["customGraphql"].(map[string]interface{}) + require.True(t, ok, "Should have customGraphql config") + + subscription, ok := customGraphql["subscription"].(map[string]interface{}) + if ok { + // wgc always enables subscription but URL should be empty when WsURL is nil + subUrl, hasUrl := subscription["url"].(map[string]interface{}) + if hasUrl { + _, hasStaticContent := subUrl["staticVariableContent"] + assert.False(t, hasStaticContent, "Subscription URL should be empty when WsURL is nil") + } + } }, }, { name: "empty subgraphs", subGraphs: []*model.SubGraph{}, - wantErr: false, - validate: func(t *testing.T, config string) { - var result map[string]interface{} - err := json.Unmarshal([]byte(config), &result) - require.NoError(t, err) - - subgraphs := result["subgraphs"].([]interface{}) - assert.Len(t, subgraphs, 0, "Should have empty subgraphs array") - }, + wantErr: true, + validate: nil, }, { name: "nil subgraphs", subGraphs: nil, - wantErr: false, - validate: func(t *testing.T, config string) { - var result map[string]interface{} - err := json.Unmarshal([]byte(config), &result) - require.NoError(t, err) - - subgraphs := result["subgraphs"].([]interface{}) - assert.Len(t, subgraphs, 0, "Should handle nil subgraphs as empty array") - }, + wantErr: true, + validate: nil, }, { name: "complex SDL with multiple types", @@ -173,29 +373,58 @@ func TestGenerateCosmoRouterConfig(t *testing.T) { err := json.Unmarshal([]byte(config), &result) require.NoError(t, err) - subgraphs := result["subgraphs"].([]interface{}) - sg := subgraphs[0].(map[string]interface{}) - sdl := sg["sdl"].(string) + // Check the composed graphqlSchema contains the types + engineConfig, ok := result["engineConfig"].(map[string]interface{}) + require.True(t, ok, "Should have engineConfig") - assert.Contains(t, sdl, "type Query") - assert.Contains(t, sdl, "type User") - assert.Contains(t, sdl, "email: String!") + graphqlSchema, ok := engineConfig["graphqlSchema"].(string) + require.True(t, ok, "Should have graphqlSchema") + + assert.Contains(t, graphqlSchema, "Query", "Schema should contain Query type") + assert.Contains(t, graphqlSchema, "User", "Schema should contain User type") + + // Check datasource has the original SDL + dsConfigs, ok := engineConfig["datasourceConfigurations"].([]interface{}) + require.True(t, ok && len(dsConfigs) > 0, "Should have datasource configurations") + + ds := dsConfigs[0].(map[string]interface{}) + customGraphql, ok := ds["customGraphql"].(map[string]interface{}) + require.True(t, ok, "Should have customGraphql config") + + federation, ok := customGraphql["federation"].(map[string]interface{}) + require.True(t, ok, "Should have federation config") + + serviceSdl, ok := federation["serviceSdl"].(string) + require.True(t, ok, "Should have serviceSdl") + assert.Contains(t, serviceSdl, "email: String!", "SDL should contain email field") }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - config, err := GenerateCosmoRouterConfig(tt.subGraphs) + // Use mock executor for all tests + mockExecutor := &MockCommandExecutor{} + config, err := GenerateCosmoRouterConfigWithExecutor(tt.subGraphs, mockExecutor) if tt.wantErr { assert.Error(t, err) + // Verify executor was not called for error cases + if len(tt.subGraphs) == 0 { + assert.Equal(t, 0, mockExecutor.CallCount, "Should not call executor for empty subgraphs") + } return } require.NoError(t, err) assert.NotEmpty(t, config, "Config should not be empty") + // Verify executor was called correctly + assert.Equal(t, 1, mockExecutor.CallCount, "Should call executor once") + assert.Equal(t, "wgc", mockExecutor.LastArgs[0], "Should call wgc command") + assert.Contains(t, mockExecutor.LastArgs, "router", "Should include 'router' arg") + assert.Contains(t, mockExecutor.LastArgs, "compose", "Should include 'compose' arg") + if tt.validate != nil { tt.validate(t, config) } @@ -203,53 +432,31 @@ func TestGenerateCosmoRouterConfig(t *testing.T) { } } -func TestConvertSubGraphsToCosmo(t *testing.T) { - tests := []struct { - name string - subGraphs []*model.SubGraph - wantLen int - validate func(t *testing.T, result []map[string]interface{}) - }{ +// TestGenerateCosmoRouterConfig_MockError tests error handling with mock executor +func TestGenerateCosmoRouterConfig_MockError(t *testing.T) { + subGraphs := []*model.SubGraph{ { - name: "preserves subgraph order", - subGraphs: []*model.SubGraph{ - {Service: "alpha", URL: stringPtr("http://a"), Sdl: "a"}, - {Service: "beta", URL: stringPtr("http://b"), Sdl: "b"}, - {Service: "gamma", URL: stringPtr("http://c"), Sdl: "c"}, - }, - wantLen: 3, - validate: func(t *testing.T, result []map[string]interface{}) { - assert.Equal(t, "alpha", result[0]["name"]) - assert.Equal(t, "beta", result[1]["name"]) - assert.Equal(t, "gamma", result[2]["name"]) - }, - }, - { - name: "includes SDL exactly as provided", - subGraphs: []*model.SubGraph{ - { - Service: "test", - URL: stringPtr("http://test"), - Sdl: "type Query { special: String! }", - }, - }, - wantLen: 1, - validate: func(t *testing.T, result []map[string]interface{}) { - assert.Equal(t, "type Query { special: String! }", result[0]["sdl"]) - }, + Service: "test-service", + URL: stringPtr("http://localhost:4001/query"), + Sdl: "type Query { test: String }", }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := convertSubGraphsToCosmo(tt.subGraphs) - assert.Len(t, result, tt.wantLen) - - if tt.validate != nil { - tt.validate(t, result) - } - }) + // Create a mock executor that returns an error + mockExecutor := &MockCommandExecutor{ + Error: fmt.Errorf("simulated wgc failure"), } + + config, err := GenerateCosmoRouterConfigWithExecutor(subGraphs, mockExecutor) + + // Verify error is propagated + assert.Error(t, err) + assert.Contains(t, err.Error(), "wgc router compose failed") + assert.Contains(t, err.Error(), "simulated wgc failure") + assert.Empty(t, config) + + // Verify executor was called + assert.Equal(t, 1, mockExecutor.CallCount, "Should have attempted to call executor") } // Helper function for tests diff --git a/graph/generated/generated.go b/graph/generated/generated.go index b4e5fec..7ba0598 100644 --- a/graph/generated/generated.go +++ b/graph/generated/generated.go @@ -74,6 +74,7 @@ type ComplexityRoot struct { } Query struct { + LatestSchema func(childComplexity int, ref string) int Organizations func(childComplexity int) int Supergraph func(childComplexity int, ref string, isAfter *string) int } @@ -124,6 +125,7 @@ type MutationResolver interface { type QueryResolver interface { Organizations(ctx context.Context) ([]*model.Organization, error) Supergraph(ctx context.Context, ref string, isAfter *string) (model.Supergraph, error) + LatestSchema(ctx context.Context, ref string) (*model.SchemaUpdate, error) } type SubscriptionResolver interface { SchemaUpdates(ctx context.Context, ref string) (<-chan *model.SchemaUpdate, error) @@ -250,6 +252,17 @@ func (e *executableSchema) Complexity(ctx context.Context, typeName, field strin return e.complexity.Organization.Users(childComplexity), true + case "Query.latestSchema": + if e.complexity.Query.LatestSchema == nil { + break + } + + args, err := ec.field_Query_latestSchema_args(ctx, rawArgs) + if err != nil { + return 0, false + } + + return e.complexity.Query.LatestSchema(childComplexity, args["ref"].(string)), true case "Query.organizations": if e.complexity.Query.Organizations == nil { break @@ -520,6 +533,7 @@ var sources = []*ast.Source{ {Name: "../schema.graphqls", Input: `type Query { organizations: [Organization!]! @auth(user: true) supergraph(ref: String!, isAfter: String): Supergraph! @auth(organization: true) + latestSchema(ref: String!): SchemaUpdate! @auth(organization: true) } type Mutation { @@ -671,6 +685,17 @@ func (ec *executionContext) field_Query___type_args(ctx context.Context, rawArgs return args, nil } +func (ec *executionContext) field_Query_latestSchema_args(ctx context.Context, rawArgs map[string]any) (map[string]any, error) { + var err error + args := map[string]any{} + arg0, err := graphql.ProcessArgField(ctx, rawArgs, "ref", ec.unmarshalNString2string) + if err != nil { + return nil, err + } + args["ref"] = arg0 + return args, nil +} + func (ec *executionContext) field_Query_supergraph_args(ctx context.Context, rawArgs map[string]any) (map[string]any, error) { var err error args := map[string]any{} @@ -1434,6 +1459,75 @@ func (ec *executionContext) fieldContext_Query_supergraph(ctx context.Context, f return fc, nil } +func (ec *executionContext) _Query_latestSchema(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { + return graphql.ResolveField( + ctx, + ec.OperationContext, + field, + ec.fieldContext_Query_latestSchema, + func(ctx context.Context) (any, error) { + fc := graphql.GetFieldContext(ctx) + return ec.resolvers.Query().LatestSchema(ctx, fc.Args["ref"].(string)) + }, + func(ctx context.Context, next graphql.Resolver) graphql.Resolver { + directive0 := next + + directive1 := func(ctx context.Context) (any, error) { + organization, err := ec.unmarshalOBoolean2ᚖbool(ctx, true) + if err != nil { + var zeroVal *model.SchemaUpdate + return zeroVal, err + } + if ec.directives.Auth == nil { + var zeroVal *model.SchemaUpdate + return zeroVal, errors.New("directive auth is not implemented") + } + return ec.directives.Auth(ctx, nil, directive0, nil, organization) + } + + next = directive1 + return next + }, + ec.marshalNSchemaUpdate2ᚖgitlabᚗcomᚋunboundsoftwareᚋschemasᚋgraphᚋmodelᚐSchemaUpdate, + true, + true, + ) +} + +func (ec *executionContext) fieldContext_Query_latestSchema(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Query", + Field: field, + IsMethod: true, + IsResolver: true, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + switch field.Name { + case "ref": + return ec.fieldContext_SchemaUpdate_ref(ctx, field) + case "id": + return ec.fieldContext_SchemaUpdate_id(ctx, field) + case "subGraphs": + return ec.fieldContext_SchemaUpdate_subGraphs(ctx, field) + case "cosmoRouterConfig": + return ec.fieldContext_SchemaUpdate_cosmoRouterConfig(ctx, field) + } + return nil, fmt.Errorf("no field named %q was found under type SchemaUpdate", field.Name) + }, + } + defer func() { + if r := recover(); r != nil { + err = ec.Recover(ctx, r) + ec.Error(ctx, err) + } + }() + ctx = graphql.WithFieldContext(ctx, fc) + if fc.Args, err = ec.field_Query_latestSchema_args(ctx, field.ArgumentMap(ec.Variables)); err != nil { + ec.Error(ctx, err) + return fc, err + } + return fc, nil +} + func (ec *executionContext) _Query___type(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { return graphql.ResolveField( ctx, @@ -3997,6 +4091,28 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) }) } + out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return rrm(innerCtx) }) + case "latestSchema": + field := field + + innerFunc := func(ctx context.Context, fs *graphql.FieldSet) (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._Query_latestSchema(ctx, field) + if res == graphql.Null { + atomic.AddUint32(&fs.Invalids, 1) + } + return res + } + + rrm := func(ctx context.Context) graphql.Marshaler { + return ec.OperationContext.RootResolverMiddleware(ctx, + func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) }) + } + out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return rrm(innerCtx) }) case "__type": out.Values[i] = ec.OperationContext.RootResolverMiddleware(innerCtx, func(ctx context.Context) (res graphql.Marshaler) { diff --git a/graph/schema.graphqls b/graph/schema.graphqls index 97d82cb..ad1df55 100644 --- a/graph/schema.graphqls +++ b/graph/schema.graphqls @@ -1,6 +1,7 @@ type Query { organizations: [Organization!]! @auth(user: true) supergraph(ref: String!, isAfter: String): Supergraph! @auth(organization: true) + latestSchema(ref: String!): SchemaUpdate! @auth(organization: true) } type Mutation { diff --git a/graph/schema.resolvers.go b/graph/schema.resolvers.go index 6c68229..b426df5 100644 --- a/graph/schema.resolvers.go +++ b/graph/schema.resolvers.go @@ -123,6 +123,13 @@ func (r *mutationResolver) UpdateSubGraph(ctx context.Context, input model.Input // Publish schema update to subscribers go func() { services, lastUpdate := r.Cache.Services(orgId, input.Ref, "") + r.Logger.Info("Publishing schema update after subgraph change", + "ref", input.Ref, + "orgId", orgId, + "lastUpdate", lastUpdate, + "servicesCount", len(services), + ) + subGraphs := make([]*model.SubGraph, len(services)) for i, id := range services { sg, err := r.fetchSubGraph(context.Background(), id) @@ -149,12 +156,21 @@ func (r *mutationResolver) UpdateSubGraph(ctx context.Context, input model.Input } // Publish to all subscribers of this ref - r.PubSub.Publish(input.Ref, &model.SchemaUpdate{ + update := &model.SchemaUpdate{ Ref: input.Ref, ID: lastUpdate, SubGraphs: subGraphs, CosmoRouterConfig: &cosmoConfig, - }) + } + + r.Logger.Info("Publishing schema update to subscribers", + "ref", update.Ref, + "id", update.ID, + "subGraphsCount", len(update.SubGraphs), + "cosmoConfigLength", len(cosmoConfig), + ) + + r.PubSub.Publish(input.Ref, update) }() return r.toGqlSubGraph(subGraph), nil @@ -222,11 +238,84 @@ func (r *queryResolver) Supergraph(ctx context.Context, ref string, isAfter *str }, nil } +// LatestSchema is the resolver for the latestSchema field. +func (r *queryResolver) LatestSchema(ctx context.Context, ref string) (*model.SchemaUpdate, error) { + orgId := middleware.OrganizationFromContext(ctx) + + r.Logger.Info("LatestSchema query", + "ref", ref, + "orgId", orgId, + ) + + _, err := r.apiKeyCanAccessRef(ctx, ref, false) + if err != nil { + r.Logger.Error("API key cannot access ref", "error", err, "ref", ref) + return nil, err + } + + // Get current services and schema + services, lastUpdate := r.Cache.Services(orgId, ref, "") + r.Logger.Info("Fetching latest schema", + "ref", ref, + "orgId", orgId, + "lastUpdate", lastUpdate, + "servicesCount", len(services), + ) + + subGraphs := make([]*model.SubGraph, len(services)) + for i, id := range services { + sg, err := r.fetchSubGraph(ctx, id) + if err != nil { + r.Logger.Error("fetch subgraph", "error", err, "id", id) + return nil, err + } + subGraphs[i] = &model.SubGraph{ + ID: sg.ID.String(), + Service: sg.Service, + URL: sg.Url, + WsURL: sg.WSUrl, + Sdl: sg.Sdl, + ChangedBy: sg.ChangedBy, + ChangedAt: sg.ChangedAt, + } + } + + // Generate Cosmo router config + cosmoConfig, err := GenerateCosmoRouterConfig(subGraphs) + if err != nil { + r.Logger.Error("generate cosmo config", "error", err) + cosmoConfig = "" // Return empty if generation fails + } + + update := &model.SchemaUpdate{ + Ref: ref, + ID: lastUpdate, + SubGraphs: subGraphs, + CosmoRouterConfig: &cosmoConfig, + } + + r.Logger.Info("Latest schema fetched", + "ref", update.Ref, + "id", update.ID, + "subGraphsCount", len(update.SubGraphs), + "cosmoConfigLength", len(cosmoConfig), + ) + + return update, nil +} + // SchemaUpdates is the resolver for the schemaUpdates field. func (r *subscriptionResolver) SchemaUpdates(ctx context.Context, ref string) (<-chan *model.SchemaUpdate, error) { orgId := middleware.OrganizationFromContext(ctx) + + r.Logger.Info("SchemaUpdates subscription started", + "ref", ref, + "orgId", orgId, + ) + _, err := r.apiKeyCanAccessRef(ctx, ref, false) if err != nil { + r.Logger.Error("API key cannot access ref", "error", err, "ref", ref) return nil, err } @@ -235,12 +324,22 @@ func (r *subscriptionResolver) SchemaUpdates(ctx context.Context, ref string) (< // Send initial state immediately go func() { + // Use background context for async operation + bgCtx := context.Background() + services, lastUpdate := r.Cache.Services(orgId, ref, "") + r.Logger.Info("Preparing initial schema update", + "ref", ref, + "orgId", orgId, + "lastUpdate", lastUpdate, + "servicesCount", len(services), + ) + subGraphs := make([]*model.SubGraph, len(services)) for i, id := range services { - sg, err := r.fetchSubGraph(ctx, id) + sg, err := r.fetchSubGraph(bgCtx, id) if err != nil { - r.Logger.Error("fetch subgraph for initial update", "error", err) + r.Logger.Error("fetch subgraph for initial update", "error", err, "id", id) continue } subGraphs[i] = &model.SubGraph{ @@ -262,12 +361,21 @@ func (r *subscriptionResolver) SchemaUpdates(ctx context.Context, ref string) (< } // Send initial update - ch <- &model.SchemaUpdate{ + update := &model.SchemaUpdate{ Ref: ref, ID: lastUpdate, SubGraphs: subGraphs, CosmoRouterConfig: &cosmoConfig, } + + r.Logger.Info("Sending initial schema update", + "ref", update.Ref, + "id", update.ID, + "subGraphsCount", len(update.SubGraphs), + "cosmoConfigLength", len(cosmoConfig), + ) + + ch <- update }() // Clean up subscription when context is done diff --git a/middleware/auth.go b/middleware/auth.go index 6704946..bbe3a9e 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -49,7 +49,9 @@ func (m *AuthMiddleware) Handler(next http.Handler) http.Handler { _, _ = w.Write([]byte("Invalid API Key format")) return } - if organization := m.cache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil { + hashedKey := hash.String(apiKey) + organization := m.cache.OrganizationByAPIKey(hashedKey) + if organization != nil { ctx = context.WithValue(ctx, OrganizationKey, *organization) } diff --git a/middleware/auth_test.go b/middleware/auth_test.go new file mode 100644 index 0000000..3e456c8 --- /dev/null +++ b/middleware/auth_test.go @@ -0,0 +1,467 @@ +package middleware + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + mw "github.com/auth0/go-jwt-middleware/v2" + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "gitlab.com/unboundsoftware/eventsourced/eventsourced" + + "gitlab.com/unboundsoftware/schemas/domain" + "gitlab.com/unboundsoftware/schemas/hash" +) + +// MockCache is a mock implementation of the Cache interface +type MockCache struct { + mock.Mock +} + +func (m *MockCache) OrganizationByAPIKey(apiKey string) *domain.Organization { + args := m.Called(apiKey) + if args.Get(0) == nil { + return nil + } + return args.Get(0).(*domain.Organization) +} + +func TestAuthMiddleware_Handler_WithValidAPIKey(t *testing.T) { + // Setup + mockCache := new(MockCache) + authMiddleware := NewAuth(mockCache) + + orgID := uuid.New() + expectedOrg := &domain.Organization{ + BaseAggregate: eventsourced.BaseAggregate{ + ID: eventsourced.IdFromString(orgID.String()), + }, + Name: "Test Organization", + } + + apiKey := "test-api-key-123" + hashedKey := hash.String(apiKey) + + mockCache.On("OrganizationByAPIKey", hashedKey).Return(expectedOrg) + + // Create a test handler that checks the context + var capturedOrg *domain.Organization + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if org := r.Context().Value(OrganizationKey); org != nil { + if o, ok := org.(domain.Organization); ok { + capturedOrg = &o + } + } + w.WriteHeader(http.StatusOK) + }) + + // Create request with API key in context + req := httptest.NewRequest(http.MethodGet, "/test", nil) + ctx := context.WithValue(req.Context(), ApiKey, apiKey) + req = req.WithContext(ctx) + + rec := httptest.NewRecorder() + + // Execute + authMiddleware.Handler(testHandler).ServeHTTP(rec, req) + + // Assert + assert.Equal(t, http.StatusOK, rec.Code) + require.NotNil(t, capturedOrg) + assert.Equal(t, expectedOrg.Name, capturedOrg.Name) + assert.Equal(t, expectedOrg.ID.String(), capturedOrg.ID.String()) + mockCache.AssertExpectations(t) +} + +func TestAuthMiddleware_Handler_WithInvalidAPIKey(t *testing.T) { + // Setup + mockCache := new(MockCache) + authMiddleware := NewAuth(mockCache) + + apiKey := "invalid-api-key" + hashedKey := hash.String(apiKey) + + mockCache.On("OrganizationByAPIKey", hashedKey).Return(nil) + + // Create a test handler that checks the context + var capturedOrg *domain.Organization + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if org := r.Context().Value(OrganizationKey); org != nil { + if o, ok := org.(domain.Organization); ok { + capturedOrg = &o + } + } + w.WriteHeader(http.StatusOK) + }) + + // Create request with API key in context + req := httptest.NewRequest(http.MethodGet, "/test", nil) + ctx := context.WithValue(req.Context(), ApiKey, apiKey) + req = req.WithContext(ctx) + + rec := httptest.NewRecorder() + + // Execute + authMiddleware.Handler(testHandler).ServeHTTP(rec, req) + + // Assert + assert.Equal(t, http.StatusOK, rec.Code) + assert.Nil(t, capturedOrg, "Organization should not be set for invalid API key") + mockCache.AssertExpectations(t) +} + +func TestAuthMiddleware_Handler_WithoutAPIKey(t *testing.T) { + // Setup + mockCache := new(MockCache) + authMiddleware := NewAuth(mockCache) + + // The middleware always hashes the API key (even if empty) and calls the cache + emptyKeyHash := hash.String("") + mockCache.On("OrganizationByAPIKey", emptyKeyHash).Return(nil) + + // Create a test handler that checks the context + var capturedOrg *domain.Organization + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if org := r.Context().Value(OrganizationKey); org != nil { + if o, ok := org.(domain.Organization); ok { + capturedOrg = &o + } + } + w.WriteHeader(http.StatusOK) + }) + + // Create request without API key + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + + // Execute + authMiddleware.Handler(testHandler).ServeHTTP(rec, req) + + // Assert + assert.Equal(t, http.StatusOK, rec.Code) + assert.Nil(t, capturedOrg, "Organization should not be set without API key") + mockCache.AssertExpectations(t) +} + +func TestAuthMiddleware_Handler_WithValidJWT(t *testing.T) { + // Setup + mockCache := new(MockCache) + authMiddleware := NewAuth(mockCache) + + // The middleware always hashes the API key (even if empty) and calls the cache + emptyKeyHash := hash.String("") + mockCache.On("OrganizationByAPIKey", emptyKeyHash).Return(nil) + + userID := "user-123" + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "sub": userID, + }) + + // Create a test handler that checks the context + var capturedUser string + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if user := r.Context().Value(UserKey); user != nil { + if u, ok := user.(string); ok { + capturedUser = u + } + } + w.WriteHeader(http.StatusOK) + }) + + // Create request with JWT token in context + req := httptest.NewRequest(http.MethodGet, "/test", nil) + ctx := context.WithValue(req.Context(), mw.ContextKey{}, token) + req = req.WithContext(ctx) + + rec := httptest.NewRecorder() + + // Execute + authMiddleware.Handler(testHandler).ServeHTTP(rec, req) + + // Assert + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, userID, capturedUser) +} + +func TestAuthMiddleware_Handler_APIKeyErrorHandling(t *testing.T) { + // Setup + mockCache := new(MockCache) + authMiddleware := NewAuth(mockCache) + + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + // Create request with invalid API key type in context + req := httptest.NewRequest(http.MethodGet, "/test", nil) + ctx := context.WithValue(req.Context(), ApiKey, 12345) // Invalid type + req = req.WithContext(ctx) + + rec := httptest.NewRecorder() + + // Execute + authMiddleware.Handler(testHandler).ServeHTTP(rec, req) + + // Assert + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Contains(t, rec.Body.String(), "Invalid API Key format") +} + +func TestAuthMiddleware_Handler_JWTErrorHandling(t *testing.T) { + // Setup + mockCache := new(MockCache) + authMiddleware := NewAuth(mockCache) + + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + // Create request with invalid JWT token type in context + req := httptest.NewRequest(http.MethodGet, "/test", nil) + ctx := context.WithValue(req.Context(), mw.ContextKey{}, "not-a-token") // Invalid type + req = req.WithContext(ctx) + + rec := httptest.NewRecorder() + + // Execute + authMiddleware.Handler(testHandler).ServeHTTP(rec, req) + + // Assert + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Contains(t, rec.Body.String(), "Invalid JWT token format") +} + +func TestAuthMiddleware_Handler_BothJWTAndAPIKey(t *testing.T) { + // Setup + mockCache := new(MockCache) + authMiddleware := NewAuth(mockCache) + + orgID := uuid.New() + expectedOrg := &domain.Organization{ + BaseAggregate: eventsourced.BaseAggregate{ + ID: eventsourced.IdFromString(orgID.String()), + }, + Name: "Test Organization", + } + + userID := "user-123" + apiKey := "test-api-key-123" + hashedKey := hash.String(apiKey) + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "sub": userID, + }) + + mockCache.On("OrganizationByAPIKey", hashedKey).Return(expectedOrg) + + // Create a test handler that checks both user and organization in context + var capturedUser string + var capturedOrg *domain.Organization + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if user := r.Context().Value(UserKey); user != nil { + if u, ok := user.(string); ok { + capturedUser = u + } + } + if org := r.Context().Value(OrganizationKey); org != nil { + if o, ok := org.(domain.Organization); ok { + capturedOrg = &o + } + } + w.WriteHeader(http.StatusOK) + }) + + // Create request with both JWT and API key in context + req := httptest.NewRequest(http.MethodGet, "/test", nil) + ctx := context.WithValue(req.Context(), mw.ContextKey{}, token) + ctx = context.WithValue(ctx, ApiKey, apiKey) + req = req.WithContext(ctx) + + rec := httptest.NewRecorder() + + // Execute + authMiddleware.Handler(testHandler).ServeHTTP(rec, req) + + // Assert + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, userID, capturedUser) + require.NotNil(t, capturedOrg) + assert.Equal(t, expectedOrg.Name, capturedOrg.Name) + mockCache.AssertExpectations(t) +} + +func TestUserFromContext(t *testing.T) { + tests := []struct { + name string + ctx context.Context + expected string + }{ + { + name: "with valid user", + ctx: context.WithValue(context.Background(), UserKey, "user-123"), + expected: "user-123", + }, + { + name: "without user", + ctx: context.Background(), + expected: "", + }, + { + name: "with invalid type", + ctx: context.WithValue(context.Background(), UserKey, 123), + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := UserFromContext(tt.ctx) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestOrganizationFromContext(t *testing.T) { + orgID := uuid.New() + org := domain.Organization{ + BaseAggregate: eventsourced.BaseAggregate{ + ID: eventsourced.IdFromString(orgID.String()), + }, + Name: "Test Org", + } + + tests := []struct { + name string + ctx context.Context + expected string + }{ + { + name: "with valid organization", + ctx: context.WithValue(context.Background(), OrganizationKey, org), + expected: orgID.String(), + }, + { + name: "without organization", + ctx: context.Background(), + expected: "", + }, + { + name: "with invalid type", + ctx: context.WithValue(context.Background(), OrganizationKey, "not-an-org"), + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := OrganizationFromContext(tt.ctx) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestAuthMiddleware_Directive_RequiresUser(t *testing.T) { + mockCache := new(MockCache) + authMiddleware := NewAuth(mockCache) + + requireUser := true + + // Test with user present + ctx := context.WithValue(context.Background(), UserKey, "user-123") + _, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) { + return "success", nil + }, &requireUser, nil) + assert.NoError(t, err) + + // Test without user + ctx = context.Background() + _, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) { + return "success", nil + }, &requireUser, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no user available in request") +} + +func TestAuthMiddleware_Directive_RequiresOrganization(t *testing.T) { + mockCache := new(MockCache) + authMiddleware := NewAuth(mockCache) + + requireOrg := true + orgID := uuid.New() + org := domain.Organization{ + BaseAggregate: eventsourced.BaseAggregate{ + ID: eventsourced.IdFromString(orgID.String()), + }, + Name: "Test Org", + } + + // Test with organization present + ctx := context.WithValue(context.Background(), OrganizationKey, org) + _, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) { + return "success", nil + }, nil, &requireOrg) + assert.NoError(t, err) + + // Test without organization + ctx = context.Background() + _, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) { + return "success", nil + }, nil, &requireOrg) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no organization available in request") +} + +func TestAuthMiddleware_Directive_RequiresBoth(t *testing.T) { + mockCache := new(MockCache) + authMiddleware := NewAuth(mockCache) + + requireUser := true + requireOrg := true + orgID := uuid.New() + org := domain.Organization{ + BaseAggregate: eventsourced.BaseAggregate{ + ID: eventsourced.IdFromString(orgID.String()), + }, + Name: "Test Org", + } + + // Test with both present + ctx := context.WithValue(context.Background(), UserKey, "user-123") + ctx = context.WithValue(ctx, OrganizationKey, org) + _, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) { + return "success", nil + }, &requireUser, &requireOrg) + assert.NoError(t, err) + + // Test with only user + ctx = context.WithValue(context.Background(), UserKey, "user-123") + _, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) { + return "success", nil + }, &requireUser, &requireOrg) + assert.Error(t, err) + + // Test with only organization + ctx = context.WithValue(context.Background(), OrganizationKey, org) + _, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) { + return "success", nil + }, &requireUser, &requireOrg) + assert.Error(t, err) +} + +func TestAuthMiddleware_Directive_NoRequirements(t *testing.T) { + mockCache := new(MockCache) + authMiddleware := NewAuth(mockCache) + + // Test with no requirements + ctx := context.Background() + result, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) { + return "success", nil + }, nil, nil) + assert.NoError(t, err) + assert.Equal(t, "success", result) +}