Merge branch 'enhance-api-key-handling-logging' into 'main'

fix: enhance API key handling and logging in middleware

See merge request unboundsoftware/schemas!623
This commit was merged in pull request #627.
This commit is contained in:
2025-11-20 21:26:21 +01:00
12 changed files with 1460 additions and 122 deletions
+9 -1
View File
@@ -24,9 +24,17 @@ RUN GOOS=linux GOARCH=amd64 go build \
FROM scratch as export
COPY --from=build /build/coverage.txt /
FROM scratch
FROM node:22-alpine
ENV TZ Europe/Stockholm
# Install wgc CLI globally for Cosmo Router composition
RUN npm install -g wgc@latest
# Copy timezone data and certificates
COPY --from=build /usr/share/zoneinfo /usr/share/zoneinfo
COPY --from=build /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/
# Copy the service binary
COPY --from=build /release/service /
CMD ["/service"]
+19
View File
@@ -30,6 +30,7 @@ import (
"gitlab.com/unboundsoftware/schemas/domain"
"gitlab.com/unboundsoftware/schemas/graph"
"gitlab.com/unboundsoftware/schemas/graph/generated"
"gitlab.com/unboundsoftware/schemas/hash"
"gitlab.com/unboundsoftware/schemas/logging"
"gitlab.com/unboundsoftware/schemas/middleware"
"gitlab.com/unboundsoftware/schemas/monitoring"
@@ -210,6 +211,24 @@ func start(closeEvents chan error, logger *slog.Logger, connectToAmqpFunc func(u
srv.AddTransport(transport.Websocket{
KeepAlivePingInterval: 10 * time.Second,
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
// Extract API key from WebSocket connection_init payload
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
logger.Info("WebSocket connection with API key", "has_key", true)
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
// Look up organization by API key (same logic as auth middleware)
if organization := serviceCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
logger.Info("WebSocket: Organization found for API key", "org_id", organization.ID.String())
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
} else {
logger.Warn("WebSocket: No organization found for API key")
}
} else {
logger.Info("WebSocket connection without API key")
}
return ctx, &initPayload, nil
},
})
srv.AddTransport(transport.Options{})
srv.AddTransport(transport.GET{})
+336
View File
@@ -0,0 +1,336 @@
package main
import (
"context"
"testing"
"github.com/99designs/gqlgen/graphql/handler/transport"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gitlab.com/unboundsoftware/eventsourced/eventsourced"
"gitlab.com/unboundsoftware/schemas/domain"
"gitlab.com/unboundsoftware/schemas/hash"
"gitlab.com/unboundsoftware/schemas/middleware"
)
// MockCache is a mock implementation for testing
type MockCache struct {
organizations map[string]*domain.Organization
}
func (m *MockCache) OrganizationByAPIKey(apiKey string) *domain.Organization {
return m.organizations[apiKey]
}
func TestWebSocketInitFunc_WithValidAPIKey(t *testing.T) {
// Setup
orgID := uuid.New()
org := &domain.Organization{
BaseAggregate: eventsourced.BaseAggregate{
ID: eventsourced.IdFromString(orgID.String()),
},
Name: "Test Organization",
}
apiKey := "test-api-key-123"
hashedKey := hash.String(apiKey)
mockCache := &MockCache{
organizations: map[string]*domain.Organization{
hashedKey: org,
},
}
// Create InitFunc (simulating the WebSocket InitFunc logic)
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
// Extract API key from WebSocket connection_init payload
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
// Look up organization by API key
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
}
}
return ctx, &initPayload, nil
}
// Test
ctx := context.Background()
initPayload := transport.InitPayload{
"X-Api-Key": apiKey,
}
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
// Assert
require.NoError(t, err)
require.NotNil(t, resultPayload)
// Check API key is in context
if value := resultCtx.Value(middleware.ApiKey); value != nil {
assert.Equal(t, apiKey, value.(string))
} else {
t.Fatal("API key not found in context")
}
// Check organization is in context
if value := resultCtx.Value(middleware.OrganizationKey); value != nil {
capturedOrg, ok := value.(domain.Organization)
require.True(t, ok, "Organization should be of correct type")
assert.Equal(t, org.Name, capturedOrg.Name)
assert.Equal(t, org.ID.String(), capturedOrg.ID.String())
} else {
t.Fatal("Organization not found in context")
}
}
func TestWebSocketInitFunc_WithInvalidAPIKey(t *testing.T) {
// Setup
mockCache := &MockCache{
organizations: map[string]*domain.Organization{},
}
apiKey := "invalid-api-key"
// Create InitFunc
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
// Extract API key from WebSocket connection_init payload
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
// Look up organization by API key
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
}
}
return ctx, &initPayload, nil
}
// Test
ctx := context.Background()
initPayload := transport.InitPayload{
"X-Api-Key": apiKey,
}
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
// Assert
require.NoError(t, err)
require.NotNil(t, resultPayload)
// Check API key is in context
if value := resultCtx.Value(middleware.ApiKey); value != nil {
assert.Equal(t, apiKey, value.(string))
} else {
t.Fatal("API key not found in context")
}
// Check organization is NOT in context (since API key is invalid)
value := resultCtx.Value(middleware.OrganizationKey)
assert.Nil(t, value, "Organization should not be set for invalid API key")
}
func TestWebSocketInitFunc_WithoutAPIKey(t *testing.T) {
// Setup
mockCache := &MockCache{
organizations: map[string]*domain.Organization{},
}
// Create InitFunc
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
// Extract API key from WebSocket connection_init payload
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
// Look up organization by API key
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
}
}
return ctx, &initPayload, nil
}
// Test
ctx := context.Background()
initPayload := transport.InitPayload{}
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
// Assert
require.NoError(t, err)
require.NotNil(t, resultPayload)
// Check API key is NOT in context
value := resultCtx.Value(middleware.ApiKey)
assert.Nil(t, value, "API key should not be set when not provided")
// Check organization is NOT in context
value = resultCtx.Value(middleware.OrganizationKey)
assert.Nil(t, value, "Organization should not be set when API key is not provided")
}
func TestWebSocketInitFunc_WithEmptyAPIKey(t *testing.T) {
// Setup
mockCache := &MockCache{
organizations: map[string]*domain.Organization{},
}
// Create InitFunc
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
// Extract API key from WebSocket connection_init payload
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
// Look up organization by API key
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
}
}
return ctx, &initPayload, nil
}
// Test
ctx := context.Background()
initPayload := transport.InitPayload{
"X-Api-Key": "", // Empty string
}
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
// Assert
require.NoError(t, err)
require.NotNil(t, resultPayload)
// Check API key is NOT in context (because empty string fails the condition)
value := resultCtx.Value(middleware.ApiKey)
assert.Nil(t, value, "API key should not be set when empty")
// Check organization is NOT in context
value = resultCtx.Value(middleware.OrganizationKey)
assert.Nil(t, value, "Organization should not be set when API key is empty")
}
func TestWebSocketInitFunc_WithWrongTypeAPIKey(t *testing.T) {
// Setup
mockCache := &MockCache{
organizations: map[string]*domain.Organization{},
}
// Create InitFunc
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
// Extract API key from WebSocket connection_init payload
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
// Look up organization by API key
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
}
}
return ctx, &initPayload, nil
}
// Test
ctx := context.Background()
initPayload := transport.InitPayload{
"X-Api-Key": 12345, // Wrong type (int instead of string)
}
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
// Assert
require.NoError(t, err)
require.NotNil(t, resultPayload)
// Check API key is NOT in context (type assertion fails)
value := resultCtx.Value(middleware.ApiKey)
assert.Nil(t, value, "API key should not be set when wrong type")
// Check organization is NOT in context
value = resultCtx.Value(middleware.OrganizationKey)
assert.Nil(t, value, "Organization should not be set when API key has wrong type")
}
func TestWebSocketInitFunc_WithMultipleOrganizations(t *testing.T) {
// Setup - create multiple organizations
org1ID := uuid.New()
org1 := &domain.Organization{
BaseAggregate: eventsourced.BaseAggregate{
ID: eventsourced.IdFromString(org1ID.String()),
},
Name: "Organization 1",
}
org2ID := uuid.New()
org2 := &domain.Organization{
BaseAggregate: eventsourced.BaseAggregate{
ID: eventsourced.IdFromString(org2ID.String()),
},
Name: "Organization 2",
}
apiKey1 := "api-key-org-1"
apiKey2 := "api-key-org-2"
hashedKey1 := hash.String(apiKey1)
hashedKey2 := hash.String(apiKey2)
mockCache := &MockCache{
organizations: map[string]*domain.Organization{
hashedKey1: org1,
hashedKey2: org2,
},
}
// Create InitFunc
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
// Extract API key from WebSocket connection_init payload
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
// Look up organization by API key
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
}
}
return ctx, &initPayload, nil
}
// Test with first API key
ctx1 := context.Background()
initPayload1 := transport.InitPayload{
"X-Api-Key": apiKey1,
}
resultCtx1, _, err := initFunc(ctx1, initPayload1)
require.NoError(t, err)
if value := resultCtx1.Value(middleware.OrganizationKey); value != nil {
capturedOrg, ok := value.(domain.Organization)
require.True(t, ok)
assert.Equal(t, org1.Name, capturedOrg.Name)
assert.Equal(t, org1.ID.String(), capturedOrg.ID.String())
} else {
t.Fatal("Organization 1 not found in context")
}
// Test with second API key
ctx2 := context.Background()
initPayload2 := transport.InitPayload{
"X-Api-Key": apiKey2,
}
resultCtx2, _, err := initFunc(ctx2, initPayload2)
require.NoError(t, err)
if value := resultCtx2.Value(middleware.OrganizationKey); value != nil {
capturedOrg, ok := value.(domain.Organization)
require.True(t, ok)
assert.Equal(t, org2.Name, capturedOrg.Name)
assert.Equal(t, org2.ID.String(), capturedOrg.ID.String())
} else {
t.Fatal("Organization 2 not found in context")
}
}
+3 -2
View File
@@ -9,6 +9,7 @@ require (
github.com/apex/log v1.9.0
github.com/auth0/go-jwt-middleware/v2 v2.3.0
github.com/golang-jwt/jwt/v5 v5.3.0
github.com/google/uuid v1.6.0
github.com/jmoiron/sqlx v1.4.0
github.com/pkg/errors v0.9.1
github.com/pressly/goose/v3 v3.26.0
@@ -30,6 +31,7 @@ require (
go.opentelemetry.io/otel/sdk/log v0.14.0
go.opentelemetry.io/otel/sdk/metric v1.38.0
go.opentelemetry.io/otel/trace v1.38.0
gopkg.in/yaml.v3 v3.0.1
)
require (
@@ -41,7 +43,6 @@ require (
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/gorilla/websocket v1.5.1 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
@@ -51,6 +52,7 @@ require (
github.com/rabbitmq/amqp091-go v1.10.0 // indirect
github.com/sethvargo/go-retry v0.3.0 // indirect
github.com/sosodev/duration v1.3.1 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/tidwall/gjson v1.17.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
@@ -72,5 +74,4 @@ require (
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect
google.golang.org/grpc v1.75.0 // indirect
google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
+2
View File
@@ -141,6 +141,8 @@ github.com/sosodev/duration v1.3.1/go.mod h1:RQIBBX0+fMLc/D9+Jb/fwvVmo0eZvDDEERA
github.com/sparetimecoders/goamqp v0.3.3 h1:z/nfTPmrjeU/rIVuNOgsVLCimp3WFoNFvS3ZzXRJ6HE=
github.com/sparetimecoders/goamqp v0.3.3/go.mod h1:W9NRCpWLE+Vruv2dcRSbszNil2O826d2Nv6kAkETW5o=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
+98 -27
View File
@@ -1,54 +1,125 @@
package graph
import (
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"gopkg.in/yaml.v3"
"gitlab.com/unboundsoftware/schemas/graph/model"
)
// GenerateCosmoRouterConfig generates a Cosmo Router execution config from subgraphs
func GenerateCosmoRouterConfig(subGraphs []*model.SubGraph) (string, error) {
// Build the Cosmo router config structure
// This is a simplified version - you may need to adjust based on actual Cosmo requirements
config := map[string]interface{}{
"version": "1",
"subgraphs": convertSubGraphsToCosmo(subGraphs),
// Add other Cosmo-specific configuration as needed
}
// Marshal to JSON
configJSON, err := json.MarshalIndent(config, "", " ")
if err != nil {
return "", fmt.Errorf("marshal cosmo config: %w", err)
}
return string(configJSON), nil
// CommandExecutor is an interface for executing external commands
// This allows for mocking in tests
type CommandExecutor interface {
Execute(name string, args ...string) ([]byte, error)
}
func convertSubGraphsToCosmo(subGraphs []*model.SubGraph) []map[string]interface{} {
cosmoSubgraphs := make([]map[string]interface{}, 0, len(subGraphs))
// DefaultCommandExecutor implements CommandExecutor using os/exec
type DefaultCommandExecutor struct{}
// Execute runs a command and returns its combined output
func (e *DefaultCommandExecutor) Execute(name string, args ...string) ([]byte, error) {
cmd := exec.Command(name, args...)
return cmd.CombinedOutput()
}
// GenerateCosmoRouterConfig generates a Cosmo Router execution config from subgraphs
// using the official wgc CLI tool via npx
func GenerateCosmoRouterConfig(subGraphs []*model.SubGraph) (string, error) {
return GenerateCosmoRouterConfigWithExecutor(subGraphs, &DefaultCommandExecutor{})
}
// GenerateCosmoRouterConfigWithExecutor generates a Cosmo Router execution config from subgraphs
// using the provided command executor (useful for testing)
func GenerateCosmoRouterConfigWithExecutor(subGraphs []*model.SubGraph, executor CommandExecutor) (string, error) {
if len(subGraphs) == 0 {
return "", fmt.Errorf("no subgraphs provided")
}
// Create a temporary directory for composition
tmpDir, err := os.MkdirTemp("", "cosmo-compose-*")
if err != nil {
return "", fmt.Errorf("create temp dir: %w", err)
}
defer os.RemoveAll(tmpDir)
// Write each subgraph SDL to a file
type SubgraphConfig struct {
Name string `yaml:"name"`
RoutingURL string `yaml:"routing_url,omitempty"`
Schema map[string]string `yaml:"schema"`
Subscription map[string]interface{} `yaml:"subscription,omitempty"`
}
type InputConfig struct {
Version int `yaml:"version"`
Subgraphs []SubgraphConfig `yaml:"subgraphs"`
}
inputConfig := InputConfig{
Version: 1,
Subgraphs: make([]SubgraphConfig, 0, len(subGraphs)),
}
for _, sg := range subGraphs {
cosmoSg := map[string]interface{}{
"name": sg.Service,
"sdl": sg.Sdl,
// Write SDL to a temp file
schemaFile := filepath.Join(tmpDir, fmt.Sprintf("%s.graphql", sg.Service))
if err := os.WriteFile(schemaFile, []byte(sg.Sdl), 0o644); err != nil {
return "", fmt.Errorf("write schema file for %s: %w", sg.Service, err)
}
subgraphCfg := SubgraphConfig{
Name: sg.Service,
Schema: map[string]string{
"file": schemaFile,
},
}
if sg.URL != nil {
cosmoSg["routing_url"] = *sg.URL
subgraphCfg.RoutingURL = *sg.URL
}
if sg.WsURL != nil {
cosmoSg["subscription"] = map[string]interface{}{
subgraphCfg.Subscription = map[string]interface{}{
"url": *sg.WsURL,
"protocol": "ws",
"websocket_subprotocol": "graphql-ws",
}
}
cosmoSubgraphs = append(cosmoSubgraphs, cosmoSg)
inputConfig.Subgraphs = append(inputConfig.Subgraphs, subgraphCfg)
}
return cosmoSubgraphs
// Write input config YAML
inputFile := filepath.Join(tmpDir, "input.yaml")
inputYAML, err := yaml.Marshal(inputConfig)
if err != nil {
return "", fmt.Errorf("marshal input config: %w", err)
}
if err := os.WriteFile(inputFile, inputYAML, 0o644); err != nil {
return "", fmt.Errorf("write input config: %w", err)
}
// Execute wgc router compose
// wgc is installed globally in the Docker image
outputFile := filepath.Join(tmpDir, "config.json")
output, err := executor.Execute("wgc", "router", "compose",
"--input", inputFile,
"--out", outputFile,
"--suppress-warnings",
)
if err != nil {
return "", fmt.Errorf("wgc router compose failed: %w\nOutput: %s", err, string(output))
}
// Read the generated config
configJSON, err := os.ReadFile(outputFile)
if err != nil {
return "", fmt.Errorf("read output config: %w", err)
}
return string(configJSON), nil
}
+293 -86
View File
@@ -2,14 +2,185 @@ package graph
import (
"encoding/json"
"fmt"
"os"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v3"
"gitlab.com/unboundsoftware/schemas/graph/model"
)
// MockCommandExecutor implements CommandExecutor for testing
type MockCommandExecutor struct {
// CallCount tracks how many times Execute was called
CallCount int
// LastArgs stores the arguments from the last call
LastArgs []string
// Error can be set to simulate command failures
Error error
}
// Execute mocks the wgc command by generating a realistic config.json file
func (m *MockCommandExecutor) Execute(name string, args ...string) ([]byte, error) {
m.CallCount++
m.LastArgs = append([]string{name}, args...)
if m.Error != nil {
return nil, m.Error
}
// Parse the input file to understand what subgraphs we're composing
var inputFile, outputFile string
for i, arg := range args {
if arg == "--input" && i+1 < len(args) {
inputFile = args[i+1]
}
if arg == "--out" && i+1 < len(args) {
outputFile = args[i+1]
}
}
if inputFile == "" || outputFile == "" {
return nil, fmt.Errorf("missing required arguments")
}
// Read the input YAML to get subgraph information
inputData, err := os.ReadFile(inputFile)
if err != nil {
return nil, fmt.Errorf("failed to read input file: %w", err)
}
var input struct {
Version int `yaml:"version"`
Subgraphs []struct {
Name string `yaml:"name"`
RoutingURL string `yaml:"routing_url,omitempty"`
Schema map[string]string `yaml:"schema"`
Subscription map[string]interface{} `yaml:"subscription,omitempty"`
} `yaml:"subgraphs"`
}
if err := yaml.Unmarshal(inputData, &input); err != nil {
return nil, fmt.Errorf("failed to parse input YAML: %w", err)
}
// Generate a realistic Cosmo Router config based on the input
config := map[string]interface{}{
"version": "mock-version-uuid",
"subgraphs": func() []map[string]interface{} {
subgraphs := make([]map[string]interface{}, len(input.Subgraphs))
for i, sg := range input.Subgraphs {
subgraph := map[string]interface{}{
"id": fmt.Sprintf("mock-id-%d", i),
"name": sg.Name,
}
if sg.RoutingURL != "" {
subgraph["routingUrl"] = sg.RoutingURL
}
subgraphs[i] = subgraph
}
return subgraphs
}(),
"engineConfig": map[string]interface{}{
"graphqlSchema": generateMockSchema(input.Subgraphs),
"datasourceConfigurations": func() []map[string]interface{} {
dsConfigs := make([]map[string]interface{}, len(input.Subgraphs))
for i, sg := range input.Subgraphs {
// Read SDL from file
sdl := ""
if schemaFile, ok := sg.Schema["file"]; ok {
if sdlData, err := os.ReadFile(schemaFile); err == nil {
sdl = string(sdlData)
}
}
dsConfig := map[string]interface{}{
"id": fmt.Sprintf("datasource-%d", i),
"kind": "GRAPHQL",
"customGraphql": map[string]interface{}{
"federation": map[string]interface{}{
"enabled": true,
"serviceSdl": sdl,
},
"subscription": func() map[string]interface{} {
if len(sg.Subscription) > 0 {
return map[string]interface{}{
"enabled": true,
"url": map[string]interface{}{
"staticVariableContent": sg.Subscription["url"],
},
"protocol": sg.Subscription["protocol"],
"websocketSubprotocol": sg.Subscription["websocket_subprotocol"],
}
}
return map[string]interface{}{
"enabled": false,
}
}(),
},
}
dsConfigs[i] = dsConfig
}
return dsConfigs
}(),
},
}
// Write the config to the output file
configJSON, err := json.Marshal(config)
if err != nil {
return nil, fmt.Errorf("failed to marshal config: %w", err)
}
if err := os.WriteFile(outputFile, configJSON, 0o644); err != nil {
return nil, fmt.Errorf("failed to write output file: %w", err)
}
return []byte("Success"), nil
}
// generateMockSchema creates a simple merged schema from subgraphs
func generateMockSchema(subgraphs []struct {
Name string `yaml:"name"`
RoutingURL string `yaml:"routing_url,omitempty"`
Schema map[string]string `yaml:"schema"`
Subscription map[string]interface{} `yaml:"subscription,omitempty"`
},
) string {
schema := strings.Builder{}
schema.WriteString("schema {\n query: Query\n")
// Check if any subgraph has subscriptions
hasSubscriptions := false
for _, sg := range subgraphs {
if len(sg.Subscription) > 0 {
hasSubscriptions = true
break
}
}
if hasSubscriptions {
schema.WriteString(" subscription: Subscription\n")
}
schema.WriteString("}\n\n")
// Add types by reading SDL files
for _, sg := range subgraphs {
if schemaFile, ok := sg.Schema["file"]; ok {
if sdlData, err := os.ReadFile(schemaFile); err == nil {
schema.WriteString(string(sdlData))
schema.WriteString("\n")
}
}
}
return schema.String()
}
func TestGenerateCosmoRouterConfig(t *testing.T) {
tests := []struct {
name string
@@ -33,7 +204,10 @@ func TestGenerateCosmoRouterConfig(t *testing.T) {
err := json.Unmarshal([]byte(config), &result)
require.NoError(t, err, "Config should be valid JSON")
assert.Equal(t, "1", result["version"], "Version should be 1")
// Version is a UUID string from wgc
version, ok := result["version"].(string)
require.True(t, ok, "Version should be a string")
assert.NotEmpty(t, version, "Version should not be empty")
subgraphs, ok := result["subgraphs"].([]interface{})
require.True(t, ok, "subgraphs should be an array")
@@ -41,14 +215,26 @@ func TestGenerateCosmoRouterConfig(t *testing.T) {
sg := subgraphs[0].(map[string]interface{})
assert.Equal(t, "test-service", sg["name"])
assert.Equal(t, "http://localhost:4001/query", sg["routing_url"])
assert.Equal(t, "type Query { test: String }", sg["sdl"])
assert.Equal(t, "http://localhost:4001/query", sg["routingUrl"])
subscription, ok := sg["subscription"].(map[string]interface{})
// Check that datasource configurations include subscription settings
engineConfig, ok := result["engineConfig"].(map[string]interface{})
require.True(t, ok, "Should have engineConfig")
dsConfigs, ok := engineConfig["datasourceConfigurations"].([]interface{})
require.True(t, ok && len(dsConfigs) > 0, "Should have datasource configurations")
ds := dsConfigs[0].(map[string]interface{})
customGraphql, ok := ds["customGraphql"].(map[string]interface{})
require.True(t, ok, "Should have customGraphql config")
subscription, ok := customGraphql["subscription"].(map[string]interface{})
require.True(t, ok, "Should have subscription config")
assert.Equal(t, "ws://localhost:4001/query", subscription["url"])
assert.Equal(t, "ws", subscription["protocol"])
assert.Equal(t, "graphql-ws", subscription["websocket_subprotocol"])
assert.True(t, subscription["enabled"].(bool), "Subscription should be enabled")
subUrl, ok := subscription["url"].(map[string]interface{})
require.True(t, ok, "Should have subscription URL")
assert.Equal(t, "ws://localhost:4001/query", subUrl["staticVariableContent"])
},
},
{
@@ -80,18 +266,28 @@ func TestGenerateCosmoRouterConfig(t *testing.T) {
subgraphs := result["subgraphs"].([]interface{})
assert.Len(t, subgraphs, 3, "Should have 3 subgraphs")
// Check first service has no subscription
// Check service names
sg1 := subgraphs[0].(map[string]interface{})
assert.Equal(t, "service-1", sg1["name"])
_, hasSubscription := sg1["subscription"]
assert.False(t, hasSubscription, "service-1 should not have subscription config")
// Check third service has subscription
sg3 := subgraphs[2].(map[string]interface{})
assert.Equal(t, "service-3", sg3["name"])
subscription, hasSubscription := sg3["subscription"]
assert.True(t, hasSubscription, "service-3 should have subscription config")
assert.NotNil(t, subscription)
// Check that datasource configurations include subscription for service-3
engineConfig, ok := result["engineConfig"].(map[string]interface{})
require.True(t, ok, "Should have engineConfig")
dsConfigs, ok := engineConfig["datasourceConfigurations"].([]interface{})
require.True(t, ok && len(dsConfigs) == 3, "Should have 3 datasource configurations")
// Find service-3's datasource config (should have subscription enabled)
ds3 := dsConfigs[2].(map[string]interface{})
customGraphql, ok := ds3["customGraphql"].(map[string]interface{})
require.True(t, ok, "Service-3 should have customGraphql config")
subscription, ok := customGraphql["subscription"].(map[string]interface{})
require.True(t, ok, "Service-3 should have subscription config")
assert.True(t, subscription["enabled"].(bool), "Service-3 subscription should be enabled")
},
},
{
@@ -113,39 +309,43 @@ func TestGenerateCosmoRouterConfig(t *testing.T) {
subgraphs := result["subgraphs"].([]interface{})
sg := subgraphs[0].(map[string]interface{})
// Should not have routing_url or subscription fields if URLs are nil
_, hasRoutingURL := sg["routing_url"]
assert.False(t, hasRoutingURL, "Should not have routing_url when URL is nil")
// Should not have routing URL when URL is nil
_, hasRoutingURL := sg["routingUrl"]
assert.False(t, hasRoutingURL, "Should not have routingUrl when URL is nil")
_, hasSubscription := sg["subscription"]
assert.False(t, hasSubscription, "Should not have subscription when WsURL is nil")
// Check datasource configurations don't have subscription enabled
engineConfig, ok := result["engineConfig"].(map[string]interface{})
require.True(t, ok, "Should have engineConfig")
dsConfigs, ok := engineConfig["datasourceConfigurations"].([]interface{})
require.True(t, ok && len(dsConfigs) > 0, "Should have datasource configurations")
ds := dsConfigs[0].(map[string]interface{})
customGraphql, ok := ds["customGraphql"].(map[string]interface{})
require.True(t, ok, "Should have customGraphql config")
subscription, ok := customGraphql["subscription"].(map[string]interface{})
if ok {
// wgc always enables subscription but URL should be empty when WsURL is nil
subUrl, hasUrl := subscription["url"].(map[string]interface{})
if hasUrl {
_, hasStaticContent := subUrl["staticVariableContent"]
assert.False(t, hasStaticContent, "Subscription URL should be empty when WsURL is nil")
}
}
},
},
{
name: "empty subgraphs",
subGraphs: []*model.SubGraph{},
wantErr: false,
validate: func(t *testing.T, config string) {
var result map[string]interface{}
err := json.Unmarshal([]byte(config), &result)
require.NoError(t, err)
subgraphs := result["subgraphs"].([]interface{})
assert.Len(t, subgraphs, 0, "Should have empty subgraphs array")
},
wantErr: true,
validate: nil,
},
{
name: "nil subgraphs",
subGraphs: nil,
wantErr: false,
validate: func(t *testing.T, config string) {
var result map[string]interface{}
err := json.Unmarshal([]byte(config), &result)
require.NoError(t, err)
subgraphs := result["subgraphs"].([]interface{})
assert.Len(t, subgraphs, 0, "Should handle nil subgraphs as empty array")
},
wantErr: true,
validate: nil,
},
{
name: "complex SDL with multiple types",
@@ -173,29 +373,58 @@ func TestGenerateCosmoRouterConfig(t *testing.T) {
err := json.Unmarshal([]byte(config), &result)
require.NoError(t, err)
subgraphs := result["subgraphs"].([]interface{})
sg := subgraphs[0].(map[string]interface{})
sdl := sg["sdl"].(string)
// Check the composed graphqlSchema contains the types
engineConfig, ok := result["engineConfig"].(map[string]interface{})
require.True(t, ok, "Should have engineConfig")
assert.Contains(t, sdl, "type Query")
assert.Contains(t, sdl, "type User")
assert.Contains(t, sdl, "email: String!")
graphqlSchema, ok := engineConfig["graphqlSchema"].(string)
require.True(t, ok, "Should have graphqlSchema")
assert.Contains(t, graphqlSchema, "Query", "Schema should contain Query type")
assert.Contains(t, graphqlSchema, "User", "Schema should contain User type")
// Check datasource has the original SDL
dsConfigs, ok := engineConfig["datasourceConfigurations"].([]interface{})
require.True(t, ok && len(dsConfigs) > 0, "Should have datasource configurations")
ds := dsConfigs[0].(map[string]interface{})
customGraphql, ok := ds["customGraphql"].(map[string]interface{})
require.True(t, ok, "Should have customGraphql config")
federation, ok := customGraphql["federation"].(map[string]interface{})
require.True(t, ok, "Should have federation config")
serviceSdl, ok := federation["serviceSdl"].(string)
require.True(t, ok, "Should have serviceSdl")
assert.Contains(t, serviceSdl, "email: String!", "SDL should contain email field")
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config, err := GenerateCosmoRouterConfig(tt.subGraphs)
// Use mock executor for all tests
mockExecutor := &MockCommandExecutor{}
config, err := GenerateCosmoRouterConfigWithExecutor(tt.subGraphs, mockExecutor)
if tt.wantErr {
assert.Error(t, err)
// Verify executor was not called for error cases
if len(tt.subGraphs) == 0 {
assert.Equal(t, 0, mockExecutor.CallCount, "Should not call executor for empty subgraphs")
}
return
}
require.NoError(t, err)
assert.NotEmpty(t, config, "Config should not be empty")
// Verify executor was called correctly
assert.Equal(t, 1, mockExecutor.CallCount, "Should call executor once")
assert.Equal(t, "wgc", mockExecutor.LastArgs[0], "Should call wgc command")
assert.Contains(t, mockExecutor.LastArgs, "router", "Should include 'router' arg")
assert.Contains(t, mockExecutor.LastArgs, "compose", "Should include 'compose' arg")
if tt.validate != nil {
tt.validate(t, config)
}
@@ -203,53 +432,31 @@ func TestGenerateCosmoRouterConfig(t *testing.T) {
}
}
func TestConvertSubGraphsToCosmo(t *testing.T) {
tests := []struct {
name string
subGraphs []*model.SubGraph
wantLen int
validate func(t *testing.T, result []map[string]interface{})
}{
// TestGenerateCosmoRouterConfig_MockError tests error handling with mock executor
func TestGenerateCosmoRouterConfig_MockError(t *testing.T) {
subGraphs := []*model.SubGraph{
{
name: "preserves subgraph order",
subGraphs: []*model.SubGraph{
{Service: "alpha", URL: stringPtr("http://a"), Sdl: "a"},
{Service: "beta", URL: stringPtr("http://b"), Sdl: "b"},
{Service: "gamma", URL: stringPtr("http://c"), Sdl: "c"},
},
wantLen: 3,
validate: func(t *testing.T, result []map[string]interface{}) {
assert.Equal(t, "alpha", result[0]["name"])
assert.Equal(t, "beta", result[1]["name"])
assert.Equal(t, "gamma", result[2]["name"])
},
},
{
name: "includes SDL exactly as provided",
subGraphs: []*model.SubGraph{
{
Service: "test",
URL: stringPtr("http://test"),
Sdl: "type Query { special: String! }",
},
},
wantLen: 1,
validate: func(t *testing.T, result []map[string]interface{}) {
assert.Equal(t, "type Query { special: String! }", result[0]["sdl"])
},
Service: "test-service",
URL: stringPtr("http://localhost:4001/query"),
Sdl: "type Query { test: String }",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := convertSubGraphsToCosmo(tt.subGraphs)
assert.Len(t, result, tt.wantLen)
// Create a mock executor that returns an error
mockExecutor := &MockCommandExecutor{
Error: fmt.Errorf("simulated wgc failure"),
}
if tt.validate != nil {
tt.validate(t, result)
}
})
}
config, err := GenerateCosmoRouterConfigWithExecutor(subGraphs, mockExecutor)
// Verify error is propagated
assert.Error(t, err)
assert.Contains(t, err.Error(), "wgc router compose failed")
assert.Contains(t, err.Error(), "simulated wgc failure")
assert.Empty(t, config)
// Verify executor was called
assert.Equal(t, 1, mockExecutor.CallCount, "Should have attempted to call executor")
}
// Helper function for tests
+116
View File
@@ -74,6 +74,7 @@ type ComplexityRoot struct {
}
Query struct {
LatestSchema func(childComplexity int, ref string) int
Organizations func(childComplexity int) int
Supergraph func(childComplexity int, ref string, isAfter *string) int
}
@@ -124,6 +125,7 @@ type MutationResolver interface {
type QueryResolver interface {
Organizations(ctx context.Context) ([]*model.Organization, error)
Supergraph(ctx context.Context, ref string, isAfter *string) (model.Supergraph, error)
LatestSchema(ctx context.Context, ref string) (*model.SchemaUpdate, error)
}
type SubscriptionResolver interface {
SchemaUpdates(ctx context.Context, ref string) (<-chan *model.SchemaUpdate, error)
@@ -250,6 +252,17 @@ func (e *executableSchema) Complexity(ctx context.Context, typeName, field strin
return e.complexity.Organization.Users(childComplexity), true
case "Query.latestSchema":
if e.complexity.Query.LatestSchema == nil {
break
}
args, err := ec.field_Query_latestSchema_args(ctx, rawArgs)
if err != nil {
return 0, false
}
return e.complexity.Query.LatestSchema(childComplexity, args["ref"].(string)), true
case "Query.organizations":
if e.complexity.Query.Organizations == nil {
break
@@ -520,6 +533,7 @@ var sources = []*ast.Source{
{Name: "../schema.graphqls", Input: `type Query {
organizations: [Organization!]! @auth(user: true)
supergraph(ref: String!, isAfter: String): Supergraph! @auth(organization: true)
latestSchema(ref: String!): SchemaUpdate! @auth(organization: true)
}
type Mutation {
@@ -671,6 +685,17 @@ func (ec *executionContext) field_Query___type_args(ctx context.Context, rawArgs
return args, nil
}
func (ec *executionContext) field_Query_latestSchema_args(ctx context.Context, rawArgs map[string]any) (map[string]any, error) {
var err error
args := map[string]any{}
arg0, err := graphql.ProcessArgField(ctx, rawArgs, "ref", ec.unmarshalNString2string)
if err != nil {
return nil, err
}
args["ref"] = arg0
return args, nil
}
func (ec *executionContext) field_Query_supergraph_args(ctx context.Context, rawArgs map[string]any) (map[string]any, error) {
var err error
args := map[string]any{}
@@ -1434,6 +1459,75 @@ func (ec *executionContext) fieldContext_Query_supergraph(ctx context.Context, f
return fc, nil
}
func (ec *executionContext) _Query_latestSchema(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) {
return graphql.ResolveField(
ctx,
ec.OperationContext,
field,
ec.fieldContext_Query_latestSchema,
func(ctx context.Context) (any, error) {
fc := graphql.GetFieldContext(ctx)
return ec.resolvers.Query().LatestSchema(ctx, fc.Args["ref"].(string))
},
func(ctx context.Context, next graphql.Resolver) graphql.Resolver {
directive0 := next
directive1 := func(ctx context.Context) (any, error) {
organization, err := ec.unmarshalOBoolean2ᚖbool(ctx, true)
if err != nil {
var zeroVal *model.SchemaUpdate
return zeroVal, err
}
if ec.directives.Auth == nil {
var zeroVal *model.SchemaUpdate
return zeroVal, errors.New("directive auth is not implemented")
}
return ec.directives.Auth(ctx, nil, directive0, nil, organization)
}
next = directive1
return next
},
ec.marshalNSchemaUpdate2ᚖgitlabᚗcomᚋunboundsoftwareᚋschemasᚋgraphᚋmodelᚐSchemaUpdate,
true,
true,
)
}
func (ec *executionContext) fieldContext_Query_latestSchema(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) {
fc = &graphql.FieldContext{
Object: "Query",
Field: field,
IsMethod: true,
IsResolver: true,
Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) {
switch field.Name {
case "ref":
return ec.fieldContext_SchemaUpdate_ref(ctx, field)
case "id":
return ec.fieldContext_SchemaUpdate_id(ctx, field)
case "subGraphs":
return ec.fieldContext_SchemaUpdate_subGraphs(ctx, field)
case "cosmoRouterConfig":
return ec.fieldContext_SchemaUpdate_cosmoRouterConfig(ctx, field)
}
return nil, fmt.Errorf("no field named %q was found under type SchemaUpdate", field.Name)
},
}
defer func() {
if r := recover(); r != nil {
err = ec.Recover(ctx, r)
ec.Error(ctx, err)
}
}()
ctx = graphql.WithFieldContext(ctx, fc)
if fc.Args, err = ec.field_Query_latestSchema_args(ctx, field.ArgumentMap(ec.Variables)); err != nil {
ec.Error(ctx, err)
return fc, err
}
return fc, nil
}
func (ec *executionContext) _Query___type(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) {
return graphql.ResolveField(
ctx,
@@ -3997,6 +4091,28 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr
func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) })
}
out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return rrm(innerCtx) })
case "latestSchema":
field := field
innerFunc := func(ctx context.Context, fs *graphql.FieldSet) (res graphql.Marshaler) {
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
}
}()
res = ec._Query_latestSchema(ctx, field)
if res == graphql.Null {
atomic.AddUint32(&fs.Invalids, 1)
}
return res
}
rrm := func(ctx context.Context) graphql.Marshaler {
return ec.OperationContext.RootResolverMiddleware(ctx,
func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) })
}
out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return rrm(innerCtx) })
case "__type":
out.Values[i] = ec.OperationContext.RootResolverMiddleware(innerCtx, func(ctx context.Context) (res graphql.Marshaler) {
+1
View File
@@ -1,6 +1,7 @@
type Query {
organizations: [Organization!]! @auth(user: true)
supergraph(ref: String!, isAfter: String): Supergraph! @auth(organization: true)
latestSchema(ref: String!): SchemaUpdate! @auth(organization: true)
}
type Mutation {
+113 -5
View File
@@ -123,6 +123,13 @@ func (r *mutationResolver) UpdateSubGraph(ctx context.Context, input model.Input
// Publish schema update to subscribers
go func() {
services, lastUpdate := r.Cache.Services(orgId, input.Ref, "")
r.Logger.Info("Publishing schema update after subgraph change",
"ref", input.Ref,
"orgId", orgId,
"lastUpdate", lastUpdate,
"servicesCount", len(services),
)
subGraphs := make([]*model.SubGraph, len(services))
for i, id := range services {
sg, err := r.fetchSubGraph(context.Background(), id)
@@ -149,12 +156,21 @@ func (r *mutationResolver) UpdateSubGraph(ctx context.Context, input model.Input
}
// Publish to all subscribers of this ref
r.PubSub.Publish(input.Ref, &model.SchemaUpdate{
update := &model.SchemaUpdate{
Ref: input.Ref,
ID: lastUpdate,
SubGraphs: subGraphs,
CosmoRouterConfig: &cosmoConfig,
})
}
r.Logger.Info("Publishing schema update to subscribers",
"ref", update.Ref,
"id", update.ID,
"subGraphsCount", len(update.SubGraphs),
"cosmoConfigLength", len(cosmoConfig),
)
r.PubSub.Publish(input.Ref, update)
}()
return r.toGqlSubGraph(subGraph), nil
@@ -222,11 +238,84 @@ func (r *queryResolver) Supergraph(ctx context.Context, ref string, isAfter *str
}, nil
}
// LatestSchema is the resolver for the latestSchema field.
func (r *queryResolver) LatestSchema(ctx context.Context, ref string) (*model.SchemaUpdate, error) {
orgId := middleware.OrganizationFromContext(ctx)
r.Logger.Info("LatestSchema query",
"ref", ref,
"orgId", orgId,
)
_, err := r.apiKeyCanAccessRef(ctx, ref, false)
if err != nil {
r.Logger.Error("API key cannot access ref", "error", err, "ref", ref)
return nil, err
}
// Get current services and schema
services, lastUpdate := r.Cache.Services(orgId, ref, "")
r.Logger.Info("Fetching latest schema",
"ref", ref,
"orgId", orgId,
"lastUpdate", lastUpdate,
"servicesCount", len(services),
)
subGraphs := make([]*model.SubGraph, len(services))
for i, id := range services {
sg, err := r.fetchSubGraph(ctx, id)
if err != nil {
r.Logger.Error("fetch subgraph", "error", err, "id", id)
return nil, err
}
subGraphs[i] = &model.SubGraph{
ID: sg.ID.String(),
Service: sg.Service,
URL: sg.Url,
WsURL: sg.WSUrl,
Sdl: sg.Sdl,
ChangedBy: sg.ChangedBy,
ChangedAt: sg.ChangedAt,
}
}
// Generate Cosmo router config
cosmoConfig, err := GenerateCosmoRouterConfig(subGraphs)
if err != nil {
r.Logger.Error("generate cosmo config", "error", err)
cosmoConfig = "" // Return empty if generation fails
}
update := &model.SchemaUpdate{
Ref: ref,
ID: lastUpdate,
SubGraphs: subGraphs,
CosmoRouterConfig: &cosmoConfig,
}
r.Logger.Info("Latest schema fetched",
"ref", update.Ref,
"id", update.ID,
"subGraphsCount", len(update.SubGraphs),
"cosmoConfigLength", len(cosmoConfig),
)
return update, nil
}
// SchemaUpdates is the resolver for the schemaUpdates field.
func (r *subscriptionResolver) SchemaUpdates(ctx context.Context, ref string) (<-chan *model.SchemaUpdate, error) {
orgId := middleware.OrganizationFromContext(ctx)
r.Logger.Info("SchemaUpdates subscription started",
"ref", ref,
"orgId", orgId,
)
_, err := r.apiKeyCanAccessRef(ctx, ref, false)
if err != nil {
r.Logger.Error("API key cannot access ref", "error", err, "ref", ref)
return nil, err
}
@@ -235,12 +324,22 @@ func (r *subscriptionResolver) SchemaUpdates(ctx context.Context, ref string) (<
// Send initial state immediately
go func() {
// Use background context for async operation
bgCtx := context.Background()
services, lastUpdate := r.Cache.Services(orgId, ref, "")
r.Logger.Info("Preparing initial schema update",
"ref", ref,
"orgId", orgId,
"lastUpdate", lastUpdate,
"servicesCount", len(services),
)
subGraphs := make([]*model.SubGraph, len(services))
for i, id := range services {
sg, err := r.fetchSubGraph(ctx, id)
sg, err := r.fetchSubGraph(bgCtx, id)
if err != nil {
r.Logger.Error("fetch subgraph for initial update", "error", err)
r.Logger.Error("fetch subgraph for initial update", "error", err, "id", id)
continue
}
subGraphs[i] = &model.SubGraph{
@@ -262,12 +361,21 @@ func (r *subscriptionResolver) SchemaUpdates(ctx context.Context, ref string) (<
}
// Send initial update
ch <- &model.SchemaUpdate{
update := &model.SchemaUpdate{
Ref: ref,
ID: lastUpdate,
SubGraphs: subGraphs,
CosmoRouterConfig: &cosmoConfig,
}
r.Logger.Info("Sending initial schema update",
"ref", update.Ref,
"id", update.ID,
"subGraphsCount", len(update.SubGraphs),
"cosmoConfigLength", len(cosmoConfig),
)
ch <- update
}()
// Clean up subscription when context is done
+3 -1
View File
@@ -49,7 +49,9 @@ func (m *AuthMiddleware) Handler(next http.Handler) http.Handler {
_, _ = w.Write([]byte("Invalid API Key format"))
return
}
if organization := m.cache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
hashedKey := hash.String(apiKey)
organization := m.cache.OrganizationByAPIKey(hashedKey)
if organization != nil {
ctx = context.WithValue(ctx, OrganizationKey, *organization)
}
+467
View File
@@ -0,0 +1,467 @@
package middleware
import (
"context"
"net/http"
"net/http/httptest"
"testing"
mw "github.com/auth0/go-jwt-middleware/v2"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"gitlab.com/unboundsoftware/eventsourced/eventsourced"
"gitlab.com/unboundsoftware/schemas/domain"
"gitlab.com/unboundsoftware/schemas/hash"
)
// MockCache is a mock implementation of the Cache interface
type MockCache struct {
mock.Mock
}
func (m *MockCache) OrganizationByAPIKey(apiKey string) *domain.Organization {
args := m.Called(apiKey)
if args.Get(0) == nil {
return nil
}
return args.Get(0).(*domain.Organization)
}
func TestAuthMiddleware_Handler_WithValidAPIKey(t *testing.T) {
// Setup
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
orgID := uuid.New()
expectedOrg := &domain.Organization{
BaseAggregate: eventsourced.BaseAggregate{
ID: eventsourced.IdFromString(orgID.String()),
},
Name: "Test Organization",
}
apiKey := "test-api-key-123"
hashedKey := hash.String(apiKey)
mockCache.On("OrganizationByAPIKey", hashedKey).Return(expectedOrg)
// Create a test handler that checks the context
var capturedOrg *domain.Organization
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if org := r.Context().Value(OrganizationKey); org != nil {
if o, ok := org.(domain.Organization); ok {
capturedOrg = &o
}
}
w.WriteHeader(http.StatusOK)
})
// Create request with API key in context
req := httptest.NewRequest(http.MethodGet, "/test", nil)
ctx := context.WithValue(req.Context(), ApiKey, apiKey)
req = req.WithContext(ctx)
rec := httptest.NewRecorder()
// Execute
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
// Assert
assert.Equal(t, http.StatusOK, rec.Code)
require.NotNil(t, capturedOrg)
assert.Equal(t, expectedOrg.Name, capturedOrg.Name)
assert.Equal(t, expectedOrg.ID.String(), capturedOrg.ID.String())
mockCache.AssertExpectations(t)
}
func TestAuthMiddleware_Handler_WithInvalidAPIKey(t *testing.T) {
// Setup
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
apiKey := "invalid-api-key"
hashedKey := hash.String(apiKey)
mockCache.On("OrganizationByAPIKey", hashedKey).Return(nil)
// Create a test handler that checks the context
var capturedOrg *domain.Organization
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if org := r.Context().Value(OrganizationKey); org != nil {
if o, ok := org.(domain.Organization); ok {
capturedOrg = &o
}
}
w.WriteHeader(http.StatusOK)
})
// Create request with API key in context
req := httptest.NewRequest(http.MethodGet, "/test", nil)
ctx := context.WithValue(req.Context(), ApiKey, apiKey)
req = req.WithContext(ctx)
rec := httptest.NewRecorder()
// Execute
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
// Assert
assert.Equal(t, http.StatusOK, rec.Code)
assert.Nil(t, capturedOrg, "Organization should not be set for invalid API key")
mockCache.AssertExpectations(t)
}
func TestAuthMiddleware_Handler_WithoutAPIKey(t *testing.T) {
// Setup
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
// The middleware always hashes the API key (even if empty) and calls the cache
emptyKeyHash := hash.String("")
mockCache.On("OrganizationByAPIKey", emptyKeyHash).Return(nil)
// Create a test handler that checks the context
var capturedOrg *domain.Organization
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if org := r.Context().Value(OrganizationKey); org != nil {
if o, ok := org.(domain.Organization); ok {
capturedOrg = &o
}
}
w.WriteHeader(http.StatusOK)
})
// Create request without API key
req := httptest.NewRequest(http.MethodGet, "/test", nil)
rec := httptest.NewRecorder()
// Execute
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
// Assert
assert.Equal(t, http.StatusOK, rec.Code)
assert.Nil(t, capturedOrg, "Organization should not be set without API key")
mockCache.AssertExpectations(t)
}
func TestAuthMiddleware_Handler_WithValidJWT(t *testing.T) {
// Setup
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
// The middleware always hashes the API key (even if empty) and calls the cache
emptyKeyHash := hash.String("")
mockCache.On("OrganizationByAPIKey", emptyKeyHash).Return(nil)
userID := "user-123"
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"sub": userID,
})
// Create a test handler that checks the context
var capturedUser string
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if user := r.Context().Value(UserKey); user != nil {
if u, ok := user.(string); ok {
capturedUser = u
}
}
w.WriteHeader(http.StatusOK)
})
// Create request with JWT token in context
req := httptest.NewRequest(http.MethodGet, "/test", nil)
ctx := context.WithValue(req.Context(), mw.ContextKey{}, token)
req = req.WithContext(ctx)
rec := httptest.NewRecorder()
// Execute
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
// Assert
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, userID, capturedUser)
}
func TestAuthMiddleware_Handler_APIKeyErrorHandling(t *testing.T) {
// Setup
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Create request with invalid API key type in context
req := httptest.NewRequest(http.MethodGet, "/test", nil)
ctx := context.WithValue(req.Context(), ApiKey, 12345) // Invalid type
req = req.WithContext(ctx)
rec := httptest.NewRecorder()
// Execute
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
// Assert
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Contains(t, rec.Body.String(), "Invalid API Key format")
}
func TestAuthMiddleware_Handler_JWTErrorHandling(t *testing.T) {
// Setup
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Create request with invalid JWT token type in context
req := httptest.NewRequest(http.MethodGet, "/test", nil)
ctx := context.WithValue(req.Context(), mw.ContextKey{}, "not-a-token") // Invalid type
req = req.WithContext(ctx)
rec := httptest.NewRecorder()
// Execute
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
// Assert
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Contains(t, rec.Body.String(), "Invalid JWT token format")
}
func TestAuthMiddleware_Handler_BothJWTAndAPIKey(t *testing.T) {
// Setup
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
orgID := uuid.New()
expectedOrg := &domain.Organization{
BaseAggregate: eventsourced.BaseAggregate{
ID: eventsourced.IdFromString(orgID.String()),
},
Name: "Test Organization",
}
userID := "user-123"
apiKey := "test-api-key-123"
hashedKey := hash.String(apiKey)
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"sub": userID,
})
mockCache.On("OrganizationByAPIKey", hashedKey).Return(expectedOrg)
// Create a test handler that checks both user and organization in context
var capturedUser string
var capturedOrg *domain.Organization
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if user := r.Context().Value(UserKey); user != nil {
if u, ok := user.(string); ok {
capturedUser = u
}
}
if org := r.Context().Value(OrganizationKey); org != nil {
if o, ok := org.(domain.Organization); ok {
capturedOrg = &o
}
}
w.WriteHeader(http.StatusOK)
})
// Create request with both JWT and API key in context
req := httptest.NewRequest(http.MethodGet, "/test", nil)
ctx := context.WithValue(req.Context(), mw.ContextKey{}, token)
ctx = context.WithValue(ctx, ApiKey, apiKey)
req = req.WithContext(ctx)
rec := httptest.NewRecorder()
// Execute
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
// Assert
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, userID, capturedUser)
require.NotNil(t, capturedOrg)
assert.Equal(t, expectedOrg.Name, capturedOrg.Name)
mockCache.AssertExpectations(t)
}
func TestUserFromContext(t *testing.T) {
tests := []struct {
name string
ctx context.Context
expected string
}{
{
name: "with valid user",
ctx: context.WithValue(context.Background(), UserKey, "user-123"),
expected: "user-123",
},
{
name: "without user",
ctx: context.Background(),
expected: "",
},
{
name: "with invalid type",
ctx: context.WithValue(context.Background(), UserKey, 123),
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := UserFromContext(tt.ctx)
assert.Equal(t, tt.expected, result)
})
}
}
func TestOrganizationFromContext(t *testing.T) {
orgID := uuid.New()
org := domain.Organization{
BaseAggregate: eventsourced.BaseAggregate{
ID: eventsourced.IdFromString(orgID.String()),
},
Name: "Test Org",
}
tests := []struct {
name string
ctx context.Context
expected string
}{
{
name: "with valid organization",
ctx: context.WithValue(context.Background(), OrganizationKey, org),
expected: orgID.String(),
},
{
name: "without organization",
ctx: context.Background(),
expected: "",
},
{
name: "with invalid type",
ctx: context.WithValue(context.Background(), OrganizationKey, "not-an-org"),
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := OrganizationFromContext(tt.ctx)
assert.Equal(t, tt.expected, result)
})
}
}
func TestAuthMiddleware_Directive_RequiresUser(t *testing.T) {
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
requireUser := true
// Test with user present
ctx := context.WithValue(context.Background(), UserKey, "user-123")
_, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
return "success", nil
}, &requireUser, nil)
assert.NoError(t, err)
// Test without user
ctx = context.Background()
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
return "success", nil
}, &requireUser, nil)
assert.Error(t, err)
assert.Contains(t, err.Error(), "no user available in request")
}
func TestAuthMiddleware_Directive_RequiresOrganization(t *testing.T) {
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
requireOrg := true
orgID := uuid.New()
org := domain.Organization{
BaseAggregate: eventsourced.BaseAggregate{
ID: eventsourced.IdFromString(orgID.String()),
},
Name: "Test Org",
}
// Test with organization present
ctx := context.WithValue(context.Background(), OrganizationKey, org)
_, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
return "success", nil
}, nil, &requireOrg)
assert.NoError(t, err)
// Test without organization
ctx = context.Background()
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
return "success", nil
}, nil, &requireOrg)
assert.Error(t, err)
assert.Contains(t, err.Error(), "no organization available in request")
}
func TestAuthMiddleware_Directive_RequiresBoth(t *testing.T) {
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
requireUser := true
requireOrg := true
orgID := uuid.New()
org := domain.Organization{
BaseAggregate: eventsourced.BaseAggregate{
ID: eventsourced.IdFromString(orgID.String()),
},
Name: "Test Org",
}
// Test with both present
ctx := context.WithValue(context.Background(), UserKey, "user-123")
ctx = context.WithValue(ctx, OrganizationKey, org)
_, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
return "success", nil
}, &requireUser, &requireOrg)
assert.NoError(t, err)
// Test with only user
ctx = context.WithValue(context.Background(), UserKey, "user-123")
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
return "success", nil
}, &requireUser, &requireOrg)
assert.Error(t, err)
// Test with only organization
ctx = context.WithValue(context.Background(), OrganizationKey, org)
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
return "success", nil
}, &requireUser, &requireOrg)
assert.Error(t, err)
}
func TestAuthMiddleware_Directive_NoRequirements(t *testing.T) {
mockCache := new(MockCache)
authMiddleware := NewAuth(mockCache)
// Test with no requirements
ctx := context.Background()
result, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
return "success", nil
}, nil, nil)
assert.NoError(t, err)
assert.Equal(t, "success", result)
}