Merge branch 'enhance-api-key-handling-logging' into 'main'
fix: enhance API key handling and logging in middleware See merge request unboundsoftware/schemas!623
This commit was merged in pull request #627.
This commit is contained in:
+9
-1
@@ -24,9 +24,17 @@ RUN GOOS=linux GOARCH=amd64 go build \
|
||||
FROM scratch as export
|
||||
COPY --from=build /build/coverage.txt /
|
||||
|
||||
FROM scratch
|
||||
FROM node:22-alpine
|
||||
ENV TZ Europe/Stockholm
|
||||
|
||||
# Install wgc CLI globally for Cosmo Router composition
|
||||
RUN npm install -g wgc@latest
|
||||
|
||||
# Copy timezone data and certificates
|
||||
COPY --from=build /usr/share/zoneinfo /usr/share/zoneinfo
|
||||
COPY --from=build /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/
|
||||
|
||||
# Copy the service binary
|
||||
COPY --from=build /release/service /
|
||||
|
||||
CMD ["/service"]
|
||||
|
||||
@@ -30,6 +30,7 @@ import (
|
||||
"gitlab.com/unboundsoftware/schemas/domain"
|
||||
"gitlab.com/unboundsoftware/schemas/graph"
|
||||
"gitlab.com/unboundsoftware/schemas/graph/generated"
|
||||
"gitlab.com/unboundsoftware/schemas/hash"
|
||||
"gitlab.com/unboundsoftware/schemas/logging"
|
||||
"gitlab.com/unboundsoftware/schemas/middleware"
|
||||
"gitlab.com/unboundsoftware/schemas/monitoring"
|
||||
@@ -210,6 +211,24 @@ func start(closeEvents chan error, logger *slog.Logger, connectToAmqpFunc func(u
|
||||
|
||||
srv.AddTransport(transport.Websocket{
|
||||
KeepAlivePingInterval: 10 * time.Second,
|
||||
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||
// Extract API key from WebSocket connection_init payload
|
||||
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||
logger.Info("WebSocket connection with API key", "has_key", true)
|
||||
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||
|
||||
// Look up organization by API key (same logic as auth middleware)
|
||||
if organization := serviceCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||
logger.Info("WebSocket: Organization found for API key", "org_id", organization.ID.String())
|
||||
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||
} else {
|
||||
logger.Warn("WebSocket: No organization found for API key")
|
||||
}
|
||||
} else {
|
||||
logger.Info("WebSocket connection without API key")
|
||||
}
|
||||
return ctx, &initPayload, nil
|
||||
},
|
||||
})
|
||||
srv.AddTransport(transport.Options{})
|
||||
srv.AddTransport(transport.GET{})
|
||||
|
||||
@@ -0,0 +1,336 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/99designs/gqlgen/graphql/handler/transport"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gitlab.com/unboundsoftware/eventsourced/eventsourced"
|
||||
|
||||
"gitlab.com/unboundsoftware/schemas/domain"
|
||||
"gitlab.com/unboundsoftware/schemas/hash"
|
||||
"gitlab.com/unboundsoftware/schemas/middleware"
|
||||
)
|
||||
|
||||
// MockCache is a mock implementation for testing
|
||||
type MockCache struct {
|
||||
organizations map[string]*domain.Organization
|
||||
}
|
||||
|
||||
func (m *MockCache) OrganizationByAPIKey(apiKey string) *domain.Organization {
|
||||
return m.organizations[apiKey]
|
||||
}
|
||||
|
||||
func TestWebSocketInitFunc_WithValidAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
orgID := uuid.New()
|
||||
org := &domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(orgID.String()),
|
||||
},
|
||||
Name: "Test Organization",
|
||||
}
|
||||
|
||||
apiKey := "test-api-key-123"
|
||||
hashedKey := hash.String(apiKey)
|
||||
|
||||
mockCache := &MockCache{
|
||||
organizations: map[string]*domain.Organization{
|
||||
hashedKey: org,
|
||||
},
|
||||
}
|
||||
|
||||
// Create InitFunc (simulating the WebSocket InitFunc logic)
|
||||
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||
// Extract API key from WebSocket connection_init payload
|
||||
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||
|
||||
// Look up organization by API key
|
||||
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||
}
|
||||
}
|
||||
return ctx, &initPayload, nil
|
||||
}
|
||||
|
||||
// Test
|
||||
ctx := context.Background()
|
||||
initPayload := transport.InitPayload{
|
||||
"X-Api-Key": apiKey,
|
||||
}
|
||||
|
||||
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
||||
|
||||
// Assert
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resultPayload)
|
||||
|
||||
// Check API key is in context
|
||||
if value := resultCtx.Value(middleware.ApiKey); value != nil {
|
||||
assert.Equal(t, apiKey, value.(string))
|
||||
} else {
|
||||
t.Fatal("API key not found in context")
|
||||
}
|
||||
|
||||
// Check organization is in context
|
||||
if value := resultCtx.Value(middleware.OrganizationKey); value != nil {
|
||||
capturedOrg, ok := value.(domain.Organization)
|
||||
require.True(t, ok, "Organization should be of correct type")
|
||||
assert.Equal(t, org.Name, capturedOrg.Name)
|
||||
assert.Equal(t, org.ID.String(), capturedOrg.ID.String())
|
||||
} else {
|
||||
t.Fatal("Organization not found in context")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSocketInitFunc_WithInvalidAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := &MockCache{
|
||||
organizations: map[string]*domain.Organization{},
|
||||
}
|
||||
|
||||
apiKey := "invalid-api-key"
|
||||
|
||||
// Create InitFunc
|
||||
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||
// Extract API key from WebSocket connection_init payload
|
||||
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||
|
||||
// Look up organization by API key
|
||||
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||
}
|
||||
}
|
||||
return ctx, &initPayload, nil
|
||||
}
|
||||
|
||||
// Test
|
||||
ctx := context.Background()
|
||||
initPayload := transport.InitPayload{
|
||||
"X-Api-Key": apiKey,
|
||||
}
|
||||
|
||||
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
||||
|
||||
// Assert
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resultPayload)
|
||||
|
||||
// Check API key is in context
|
||||
if value := resultCtx.Value(middleware.ApiKey); value != nil {
|
||||
assert.Equal(t, apiKey, value.(string))
|
||||
} else {
|
||||
t.Fatal("API key not found in context")
|
||||
}
|
||||
|
||||
// Check organization is NOT in context (since API key is invalid)
|
||||
value := resultCtx.Value(middleware.OrganizationKey)
|
||||
assert.Nil(t, value, "Organization should not be set for invalid API key")
|
||||
}
|
||||
|
||||
func TestWebSocketInitFunc_WithoutAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := &MockCache{
|
||||
organizations: map[string]*domain.Organization{},
|
||||
}
|
||||
|
||||
// Create InitFunc
|
||||
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||
// Extract API key from WebSocket connection_init payload
|
||||
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||
|
||||
// Look up organization by API key
|
||||
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||
}
|
||||
}
|
||||
return ctx, &initPayload, nil
|
||||
}
|
||||
|
||||
// Test
|
||||
ctx := context.Background()
|
||||
initPayload := transport.InitPayload{}
|
||||
|
||||
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
||||
|
||||
// Assert
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resultPayload)
|
||||
|
||||
// Check API key is NOT in context
|
||||
value := resultCtx.Value(middleware.ApiKey)
|
||||
assert.Nil(t, value, "API key should not be set when not provided")
|
||||
|
||||
// Check organization is NOT in context
|
||||
value = resultCtx.Value(middleware.OrganizationKey)
|
||||
assert.Nil(t, value, "Organization should not be set when API key is not provided")
|
||||
}
|
||||
|
||||
func TestWebSocketInitFunc_WithEmptyAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := &MockCache{
|
||||
organizations: map[string]*domain.Organization{},
|
||||
}
|
||||
|
||||
// Create InitFunc
|
||||
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||
// Extract API key from WebSocket connection_init payload
|
||||
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||
|
||||
// Look up organization by API key
|
||||
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||
}
|
||||
}
|
||||
return ctx, &initPayload, nil
|
||||
}
|
||||
|
||||
// Test
|
||||
ctx := context.Background()
|
||||
initPayload := transport.InitPayload{
|
||||
"X-Api-Key": "", // Empty string
|
||||
}
|
||||
|
||||
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
||||
|
||||
// Assert
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resultPayload)
|
||||
|
||||
// Check API key is NOT in context (because empty string fails the condition)
|
||||
value := resultCtx.Value(middleware.ApiKey)
|
||||
assert.Nil(t, value, "API key should not be set when empty")
|
||||
|
||||
// Check organization is NOT in context
|
||||
value = resultCtx.Value(middleware.OrganizationKey)
|
||||
assert.Nil(t, value, "Organization should not be set when API key is empty")
|
||||
}
|
||||
|
||||
func TestWebSocketInitFunc_WithWrongTypeAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := &MockCache{
|
||||
organizations: map[string]*domain.Organization{},
|
||||
}
|
||||
|
||||
// Create InitFunc
|
||||
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||
// Extract API key from WebSocket connection_init payload
|
||||
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||
|
||||
// Look up organization by API key
|
||||
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||
}
|
||||
}
|
||||
return ctx, &initPayload, nil
|
||||
}
|
||||
|
||||
// Test
|
||||
ctx := context.Background()
|
||||
initPayload := transport.InitPayload{
|
||||
"X-Api-Key": 12345, // Wrong type (int instead of string)
|
||||
}
|
||||
|
||||
resultCtx, resultPayload, err := initFunc(ctx, initPayload)
|
||||
|
||||
// Assert
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resultPayload)
|
||||
|
||||
// Check API key is NOT in context (type assertion fails)
|
||||
value := resultCtx.Value(middleware.ApiKey)
|
||||
assert.Nil(t, value, "API key should not be set when wrong type")
|
||||
|
||||
// Check organization is NOT in context
|
||||
value = resultCtx.Value(middleware.OrganizationKey)
|
||||
assert.Nil(t, value, "Organization should not be set when API key has wrong type")
|
||||
}
|
||||
|
||||
func TestWebSocketInitFunc_WithMultipleOrganizations(t *testing.T) {
|
||||
// Setup - create multiple organizations
|
||||
org1ID := uuid.New()
|
||||
org1 := &domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(org1ID.String()),
|
||||
},
|
||||
Name: "Organization 1",
|
||||
}
|
||||
|
||||
org2ID := uuid.New()
|
||||
org2 := &domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(org2ID.String()),
|
||||
},
|
||||
Name: "Organization 2",
|
||||
}
|
||||
|
||||
apiKey1 := "api-key-org-1"
|
||||
apiKey2 := "api-key-org-2"
|
||||
hashedKey1 := hash.String(apiKey1)
|
||||
hashedKey2 := hash.String(apiKey2)
|
||||
|
||||
mockCache := &MockCache{
|
||||
organizations: map[string]*domain.Organization{
|
||||
hashedKey1: org1,
|
||||
hashedKey2: org2,
|
||||
},
|
||||
}
|
||||
|
||||
// Create InitFunc
|
||||
initFunc := func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
|
||||
// Extract API key from WebSocket connection_init payload
|
||||
if apiKey, ok := initPayload["X-Api-Key"].(string); ok && apiKey != "" {
|
||||
ctx = context.WithValue(ctx, middleware.ApiKey, apiKey)
|
||||
|
||||
// Look up organization by API key
|
||||
if organization := mockCache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||
ctx = context.WithValue(ctx, middleware.OrganizationKey, *organization)
|
||||
}
|
||||
}
|
||||
return ctx, &initPayload, nil
|
||||
}
|
||||
|
||||
// Test with first API key
|
||||
ctx1 := context.Background()
|
||||
initPayload1 := transport.InitPayload{
|
||||
"X-Api-Key": apiKey1,
|
||||
}
|
||||
|
||||
resultCtx1, _, err := initFunc(ctx1, initPayload1)
|
||||
require.NoError(t, err)
|
||||
|
||||
if value := resultCtx1.Value(middleware.OrganizationKey); value != nil {
|
||||
capturedOrg, ok := value.(domain.Organization)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, org1.Name, capturedOrg.Name)
|
||||
assert.Equal(t, org1.ID.String(), capturedOrg.ID.String())
|
||||
} else {
|
||||
t.Fatal("Organization 1 not found in context")
|
||||
}
|
||||
|
||||
// Test with second API key
|
||||
ctx2 := context.Background()
|
||||
initPayload2 := transport.InitPayload{
|
||||
"X-Api-Key": apiKey2,
|
||||
}
|
||||
|
||||
resultCtx2, _, err := initFunc(ctx2, initPayload2)
|
||||
require.NoError(t, err)
|
||||
|
||||
if value := resultCtx2.Value(middleware.OrganizationKey); value != nil {
|
||||
capturedOrg, ok := value.(domain.Organization)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, org2.Name, capturedOrg.Name)
|
||||
assert.Equal(t, org2.ID.String(), capturedOrg.ID.String())
|
||||
} else {
|
||||
t.Fatal("Organization 2 not found in context")
|
||||
}
|
||||
}
|
||||
@@ -9,6 +9,7 @@ require (
|
||||
github.com/apex/log v1.9.0
|
||||
github.com/auth0/go-jwt-middleware/v2 v2.3.0
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jmoiron/sqlx v1.4.0
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/pressly/goose/v3 v3.26.0
|
||||
@@ -30,6 +31,7 @@ require (
|
||||
go.opentelemetry.io/otel/sdk/log v0.14.0
|
||||
go.opentelemetry.io/otel/sdk/metric v1.38.0
|
||||
go.opentelemetry.io/otel/trace v1.38.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -41,7 +43,6 @@ require (
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/gorilla/websocket v1.5.1 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
|
||||
@@ -51,6 +52,7 @@ require (
|
||||
github.com/rabbitmq/amqp091-go v1.10.0 // indirect
|
||||
github.com/sethvargo/go-retry v0.3.0 // indirect
|
||||
github.com/sosodev/duration v1.3.1 // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
github.com/tidwall/gjson v1.17.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
@@ -72,5 +74,4 @@ require (
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect
|
||||
google.golang.org/grpc v1.75.0 // indirect
|
||||
google.golang.org/protobuf v1.36.10 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@@ -141,6 +141,8 @@ github.com/sosodev/duration v1.3.1/go.mod h1:RQIBBX0+fMLc/D9+Jb/fwvVmo0eZvDDEERA
|
||||
github.com/sparetimecoders/goamqp v0.3.3 h1:z/nfTPmrjeU/rIVuNOgsVLCimp3WFoNFvS3ZzXRJ6HE=
|
||||
github.com/sparetimecoders/goamqp v0.3.3/go.mod h1:W9NRCpWLE+Vruv2dcRSbszNil2O826d2Nv6kAkETW5o=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
|
||||
+98
-27
@@ -1,54 +1,125 @@
|
||||
package graph
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"gitlab.com/unboundsoftware/schemas/graph/model"
|
||||
)
|
||||
|
||||
// GenerateCosmoRouterConfig generates a Cosmo Router execution config from subgraphs
|
||||
func GenerateCosmoRouterConfig(subGraphs []*model.SubGraph) (string, error) {
|
||||
// Build the Cosmo router config structure
|
||||
// This is a simplified version - you may need to adjust based on actual Cosmo requirements
|
||||
config := map[string]interface{}{
|
||||
"version": "1",
|
||||
"subgraphs": convertSubGraphsToCosmo(subGraphs),
|
||||
// Add other Cosmo-specific configuration as needed
|
||||
}
|
||||
|
||||
// Marshal to JSON
|
||||
configJSON, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal cosmo config: %w", err)
|
||||
}
|
||||
|
||||
return string(configJSON), nil
|
||||
// CommandExecutor is an interface for executing external commands
|
||||
// This allows for mocking in tests
|
||||
type CommandExecutor interface {
|
||||
Execute(name string, args ...string) ([]byte, error)
|
||||
}
|
||||
|
||||
func convertSubGraphsToCosmo(subGraphs []*model.SubGraph) []map[string]interface{} {
|
||||
cosmoSubgraphs := make([]map[string]interface{}, 0, len(subGraphs))
|
||||
// DefaultCommandExecutor implements CommandExecutor using os/exec
|
||||
type DefaultCommandExecutor struct{}
|
||||
|
||||
// Execute runs a command and returns its combined output
|
||||
func (e *DefaultCommandExecutor) Execute(name string, args ...string) ([]byte, error) {
|
||||
cmd := exec.Command(name, args...)
|
||||
return cmd.CombinedOutput()
|
||||
}
|
||||
|
||||
// GenerateCosmoRouterConfig generates a Cosmo Router execution config from subgraphs
|
||||
// using the official wgc CLI tool via npx
|
||||
func GenerateCosmoRouterConfig(subGraphs []*model.SubGraph) (string, error) {
|
||||
return GenerateCosmoRouterConfigWithExecutor(subGraphs, &DefaultCommandExecutor{})
|
||||
}
|
||||
|
||||
// GenerateCosmoRouterConfigWithExecutor generates a Cosmo Router execution config from subgraphs
|
||||
// using the provided command executor (useful for testing)
|
||||
func GenerateCosmoRouterConfigWithExecutor(subGraphs []*model.SubGraph, executor CommandExecutor) (string, error) {
|
||||
if len(subGraphs) == 0 {
|
||||
return "", fmt.Errorf("no subgraphs provided")
|
||||
}
|
||||
|
||||
// Create a temporary directory for composition
|
||||
tmpDir, err := os.MkdirTemp("", "cosmo-compose-*")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create temp dir: %w", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Write each subgraph SDL to a file
|
||||
type SubgraphConfig struct {
|
||||
Name string `yaml:"name"`
|
||||
RoutingURL string `yaml:"routing_url,omitempty"`
|
||||
Schema map[string]string `yaml:"schema"`
|
||||
Subscription map[string]interface{} `yaml:"subscription,omitempty"`
|
||||
}
|
||||
|
||||
type InputConfig struct {
|
||||
Version int `yaml:"version"`
|
||||
Subgraphs []SubgraphConfig `yaml:"subgraphs"`
|
||||
}
|
||||
|
||||
inputConfig := InputConfig{
|
||||
Version: 1,
|
||||
Subgraphs: make([]SubgraphConfig, 0, len(subGraphs)),
|
||||
}
|
||||
|
||||
for _, sg := range subGraphs {
|
||||
cosmoSg := map[string]interface{}{
|
||||
"name": sg.Service,
|
||||
"sdl": sg.Sdl,
|
||||
// Write SDL to a temp file
|
||||
schemaFile := filepath.Join(tmpDir, fmt.Sprintf("%s.graphql", sg.Service))
|
||||
if err := os.WriteFile(schemaFile, []byte(sg.Sdl), 0o644); err != nil {
|
||||
return "", fmt.Errorf("write schema file for %s: %w", sg.Service, err)
|
||||
}
|
||||
|
||||
subgraphCfg := SubgraphConfig{
|
||||
Name: sg.Service,
|
||||
Schema: map[string]string{
|
||||
"file": schemaFile,
|
||||
},
|
||||
}
|
||||
|
||||
if sg.URL != nil {
|
||||
cosmoSg["routing_url"] = *sg.URL
|
||||
subgraphCfg.RoutingURL = *sg.URL
|
||||
}
|
||||
|
||||
if sg.WsURL != nil {
|
||||
cosmoSg["subscription"] = map[string]interface{}{
|
||||
subgraphCfg.Subscription = map[string]interface{}{
|
||||
"url": *sg.WsURL,
|
||||
"protocol": "ws",
|
||||
"websocket_subprotocol": "graphql-ws",
|
||||
}
|
||||
}
|
||||
|
||||
cosmoSubgraphs = append(cosmoSubgraphs, cosmoSg)
|
||||
inputConfig.Subgraphs = append(inputConfig.Subgraphs, subgraphCfg)
|
||||
}
|
||||
|
||||
return cosmoSubgraphs
|
||||
// Write input config YAML
|
||||
inputFile := filepath.Join(tmpDir, "input.yaml")
|
||||
inputYAML, err := yaml.Marshal(inputConfig)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal input config: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(inputFile, inputYAML, 0o644); err != nil {
|
||||
return "", fmt.Errorf("write input config: %w", err)
|
||||
}
|
||||
|
||||
// Execute wgc router compose
|
||||
// wgc is installed globally in the Docker image
|
||||
outputFile := filepath.Join(tmpDir, "config.json")
|
||||
output, err := executor.Execute("wgc", "router", "compose",
|
||||
"--input", inputFile,
|
||||
"--out", outputFile,
|
||||
"--suppress-warnings",
|
||||
)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("wgc router compose failed: %w\nOutput: %s", err, string(output))
|
||||
}
|
||||
|
||||
// Read the generated config
|
||||
configJSON, err := os.ReadFile(outputFile)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read output config: %w", err)
|
||||
}
|
||||
|
||||
return string(configJSON), nil
|
||||
}
|
||||
|
||||
+293
-86
@@ -2,14 +2,185 @@ package graph
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"gitlab.com/unboundsoftware/schemas/graph/model"
|
||||
)
|
||||
|
||||
// MockCommandExecutor implements CommandExecutor for testing
|
||||
type MockCommandExecutor struct {
|
||||
// CallCount tracks how many times Execute was called
|
||||
CallCount int
|
||||
// LastArgs stores the arguments from the last call
|
||||
LastArgs []string
|
||||
// Error can be set to simulate command failures
|
||||
Error error
|
||||
}
|
||||
|
||||
// Execute mocks the wgc command by generating a realistic config.json file
|
||||
func (m *MockCommandExecutor) Execute(name string, args ...string) ([]byte, error) {
|
||||
m.CallCount++
|
||||
m.LastArgs = append([]string{name}, args...)
|
||||
|
||||
if m.Error != nil {
|
||||
return nil, m.Error
|
||||
}
|
||||
|
||||
// Parse the input file to understand what subgraphs we're composing
|
||||
var inputFile, outputFile string
|
||||
for i, arg := range args {
|
||||
if arg == "--input" && i+1 < len(args) {
|
||||
inputFile = args[i+1]
|
||||
}
|
||||
if arg == "--out" && i+1 < len(args) {
|
||||
outputFile = args[i+1]
|
||||
}
|
||||
}
|
||||
|
||||
if inputFile == "" || outputFile == "" {
|
||||
return nil, fmt.Errorf("missing required arguments")
|
||||
}
|
||||
|
||||
// Read the input YAML to get subgraph information
|
||||
inputData, err := os.ReadFile(inputFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read input file: %w", err)
|
||||
}
|
||||
|
||||
var input struct {
|
||||
Version int `yaml:"version"`
|
||||
Subgraphs []struct {
|
||||
Name string `yaml:"name"`
|
||||
RoutingURL string `yaml:"routing_url,omitempty"`
|
||||
Schema map[string]string `yaml:"schema"`
|
||||
Subscription map[string]interface{} `yaml:"subscription,omitempty"`
|
||||
} `yaml:"subgraphs"`
|
||||
}
|
||||
|
||||
if err := yaml.Unmarshal(inputData, &input); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse input YAML: %w", err)
|
||||
}
|
||||
|
||||
// Generate a realistic Cosmo Router config based on the input
|
||||
config := map[string]interface{}{
|
||||
"version": "mock-version-uuid",
|
||||
"subgraphs": func() []map[string]interface{} {
|
||||
subgraphs := make([]map[string]interface{}, len(input.Subgraphs))
|
||||
for i, sg := range input.Subgraphs {
|
||||
subgraph := map[string]interface{}{
|
||||
"id": fmt.Sprintf("mock-id-%d", i),
|
||||
"name": sg.Name,
|
||||
}
|
||||
if sg.RoutingURL != "" {
|
||||
subgraph["routingUrl"] = sg.RoutingURL
|
||||
}
|
||||
subgraphs[i] = subgraph
|
||||
}
|
||||
return subgraphs
|
||||
}(),
|
||||
"engineConfig": map[string]interface{}{
|
||||
"graphqlSchema": generateMockSchema(input.Subgraphs),
|
||||
"datasourceConfigurations": func() []map[string]interface{} {
|
||||
dsConfigs := make([]map[string]interface{}, len(input.Subgraphs))
|
||||
for i, sg := range input.Subgraphs {
|
||||
// Read SDL from file
|
||||
sdl := ""
|
||||
if schemaFile, ok := sg.Schema["file"]; ok {
|
||||
if sdlData, err := os.ReadFile(schemaFile); err == nil {
|
||||
sdl = string(sdlData)
|
||||
}
|
||||
}
|
||||
|
||||
dsConfig := map[string]interface{}{
|
||||
"id": fmt.Sprintf("datasource-%d", i),
|
||||
"kind": "GRAPHQL",
|
||||
"customGraphql": map[string]interface{}{
|
||||
"federation": map[string]interface{}{
|
||||
"enabled": true,
|
||||
"serviceSdl": sdl,
|
||||
},
|
||||
"subscription": func() map[string]interface{} {
|
||||
if len(sg.Subscription) > 0 {
|
||||
return map[string]interface{}{
|
||||
"enabled": true,
|
||||
"url": map[string]interface{}{
|
||||
"staticVariableContent": sg.Subscription["url"],
|
||||
},
|
||||
"protocol": sg.Subscription["protocol"],
|
||||
"websocketSubprotocol": sg.Subscription["websocket_subprotocol"],
|
||||
}
|
||||
}
|
||||
return map[string]interface{}{
|
||||
"enabled": false,
|
||||
}
|
||||
}(),
|
||||
},
|
||||
}
|
||||
dsConfigs[i] = dsConfig
|
||||
}
|
||||
return dsConfigs
|
||||
}(),
|
||||
},
|
||||
}
|
||||
|
||||
// Write the config to the output file
|
||||
configJSON, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal config: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(outputFile, configJSON, 0o644); err != nil {
|
||||
return nil, fmt.Errorf("failed to write output file: %w", err)
|
||||
}
|
||||
|
||||
return []byte("Success"), nil
|
||||
}
|
||||
|
||||
// generateMockSchema creates a simple merged schema from subgraphs
|
||||
func generateMockSchema(subgraphs []struct {
|
||||
Name string `yaml:"name"`
|
||||
RoutingURL string `yaml:"routing_url,omitempty"`
|
||||
Schema map[string]string `yaml:"schema"`
|
||||
Subscription map[string]interface{} `yaml:"subscription,omitempty"`
|
||||
},
|
||||
) string {
|
||||
schema := strings.Builder{}
|
||||
schema.WriteString("schema {\n query: Query\n")
|
||||
|
||||
// Check if any subgraph has subscriptions
|
||||
hasSubscriptions := false
|
||||
for _, sg := range subgraphs {
|
||||
if len(sg.Subscription) > 0 {
|
||||
hasSubscriptions = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if hasSubscriptions {
|
||||
schema.WriteString(" subscription: Subscription\n")
|
||||
}
|
||||
schema.WriteString("}\n\n")
|
||||
|
||||
// Add types by reading SDL files
|
||||
for _, sg := range subgraphs {
|
||||
if schemaFile, ok := sg.Schema["file"]; ok {
|
||||
if sdlData, err := os.ReadFile(schemaFile); err == nil {
|
||||
schema.WriteString(string(sdlData))
|
||||
schema.WriteString("\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return schema.String()
|
||||
}
|
||||
|
||||
func TestGenerateCosmoRouterConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -33,7 +204,10 @@ func TestGenerateCosmoRouterConfig(t *testing.T) {
|
||||
err := json.Unmarshal([]byte(config), &result)
|
||||
require.NoError(t, err, "Config should be valid JSON")
|
||||
|
||||
assert.Equal(t, "1", result["version"], "Version should be 1")
|
||||
// Version is a UUID string from wgc
|
||||
version, ok := result["version"].(string)
|
||||
require.True(t, ok, "Version should be a string")
|
||||
assert.NotEmpty(t, version, "Version should not be empty")
|
||||
|
||||
subgraphs, ok := result["subgraphs"].([]interface{})
|
||||
require.True(t, ok, "subgraphs should be an array")
|
||||
@@ -41,14 +215,26 @@ func TestGenerateCosmoRouterConfig(t *testing.T) {
|
||||
|
||||
sg := subgraphs[0].(map[string]interface{})
|
||||
assert.Equal(t, "test-service", sg["name"])
|
||||
assert.Equal(t, "http://localhost:4001/query", sg["routing_url"])
|
||||
assert.Equal(t, "type Query { test: String }", sg["sdl"])
|
||||
assert.Equal(t, "http://localhost:4001/query", sg["routingUrl"])
|
||||
|
||||
subscription, ok := sg["subscription"].(map[string]interface{})
|
||||
// Check that datasource configurations include subscription settings
|
||||
engineConfig, ok := result["engineConfig"].(map[string]interface{})
|
||||
require.True(t, ok, "Should have engineConfig")
|
||||
|
||||
dsConfigs, ok := engineConfig["datasourceConfigurations"].([]interface{})
|
||||
require.True(t, ok && len(dsConfigs) > 0, "Should have datasource configurations")
|
||||
|
||||
ds := dsConfigs[0].(map[string]interface{})
|
||||
customGraphql, ok := ds["customGraphql"].(map[string]interface{})
|
||||
require.True(t, ok, "Should have customGraphql config")
|
||||
|
||||
subscription, ok := customGraphql["subscription"].(map[string]interface{})
|
||||
require.True(t, ok, "Should have subscription config")
|
||||
assert.Equal(t, "ws://localhost:4001/query", subscription["url"])
|
||||
assert.Equal(t, "ws", subscription["protocol"])
|
||||
assert.Equal(t, "graphql-ws", subscription["websocket_subprotocol"])
|
||||
assert.True(t, subscription["enabled"].(bool), "Subscription should be enabled")
|
||||
|
||||
subUrl, ok := subscription["url"].(map[string]interface{})
|
||||
require.True(t, ok, "Should have subscription URL")
|
||||
assert.Equal(t, "ws://localhost:4001/query", subUrl["staticVariableContent"])
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -80,18 +266,28 @@ func TestGenerateCosmoRouterConfig(t *testing.T) {
|
||||
subgraphs := result["subgraphs"].([]interface{})
|
||||
assert.Len(t, subgraphs, 3, "Should have 3 subgraphs")
|
||||
|
||||
// Check first service has no subscription
|
||||
// Check service names
|
||||
sg1 := subgraphs[0].(map[string]interface{})
|
||||
assert.Equal(t, "service-1", sg1["name"])
|
||||
_, hasSubscription := sg1["subscription"]
|
||||
assert.False(t, hasSubscription, "service-1 should not have subscription config")
|
||||
|
||||
// Check third service has subscription
|
||||
sg3 := subgraphs[2].(map[string]interface{})
|
||||
assert.Equal(t, "service-3", sg3["name"])
|
||||
subscription, hasSubscription := sg3["subscription"]
|
||||
assert.True(t, hasSubscription, "service-3 should have subscription config")
|
||||
assert.NotNil(t, subscription)
|
||||
|
||||
// Check that datasource configurations include subscription for service-3
|
||||
engineConfig, ok := result["engineConfig"].(map[string]interface{})
|
||||
require.True(t, ok, "Should have engineConfig")
|
||||
|
||||
dsConfigs, ok := engineConfig["datasourceConfigurations"].([]interface{})
|
||||
require.True(t, ok && len(dsConfigs) == 3, "Should have 3 datasource configurations")
|
||||
|
||||
// Find service-3's datasource config (should have subscription enabled)
|
||||
ds3 := dsConfigs[2].(map[string]interface{})
|
||||
customGraphql, ok := ds3["customGraphql"].(map[string]interface{})
|
||||
require.True(t, ok, "Service-3 should have customGraphql config")
|
||||
|
||||
subscription, ok := customGraphql["subscription"].(map[string]interface{})
|
||||
require.True(t, ok, "Service-3 should have subscription config")
|
||||
assert.True(t, subscription["enabled"].(bool), "Service-3 subscription should be enabled")
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -113,39 +309,43 @@ func TestGenerateCosmoRouterConfig(t *testing.T) {
|
||||
subgraphs := result["subgraphs"].([]interface{})
|
||||
sg := subgraphs[0].(map[string]interface{})
|
||||
|
||||
// Should not have routing_url or subscription fields if URLs are nil
|
||||
_, hasRoutingURL := sg["routing_url"]
|
||||
assert.False(t, hasRoutingURL, "Should not have routing_url when URL is nil")
|
||||
// Should not have routing URL when URL is nil
|
||||
_, hasRoutingURL := sg["routingUrl"]
|
||||
assert.False(t, hasRoutingURL, "Should not have routingUrl when URL is nil")
|
||||
|
||||
_, hasSubscription := sg["subscription"]
|
||||
assert.False(t, hasSubscription, "Should not have subscription when WsURL is nil")
|
||||
// Check datasource configurations don't have subscription enabled
|
||||
engineConfig, ok := result["engineConfig"].(map[string]interface{})
|
||||
require.True(t, ok, "Should have engineConfig")
|
||||
|
||||
dsConfigs, ok := engineConfig["datasourceConfigurations"].([]interface{})
|
||||
require.True(t, ok && len(dsConfigs) > 0, "Should have datasource configurations")
|
||||
|
||||
ds := dsConfigs[0].(map[string]interface{})
|
||||
customGraphql, ok := ds["customGraphql"].(map[string]interface{})
|
||||
require.True(t, ok, "Should have customGraphql config")
|
||||
|
||||
subscription, ok := customGraphql["subscription"].(map[string]interface{})
|
||||
if ok {
|
||||
// wgc always enables subscription but URL should be empty when WsURL is nil
|
||||
subUrl, hasUrl := subscription["url"].(map[string]interface{})
|
||||
if hasUrl {
|
||||
_, hasStaticContent := subUrl["staticVariableContent"]
|
||||
assert.False(t, hasStaticContent, "Subscription URL should be empty when WsURL is nil")
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty subgraphs",
|
||||
subGraphs: []*model.SubGraph{},
|
||||
wantErr: false,
|
||||
validate: func(t *testing.T, config string) {
|
||||
var result map[string]interface{}
|
||||
err := json.Unmarshal([]byte(config), &result)
|
||||
require.NoError(t, err)
|
||||
|
||||
subgraphs := result["subgraphs"].([]interface{})
|
||||
assert.Len(t, subgraphs, 0, "Should have empty subgraphs array")
|
||||
},
|
||||
wantErr: true,
|
||||
validate: nil,
|
||||
},
|
||||
{
|
||||
name: "nil subgraphs",
|
||||
subGraphs: nil,
|
||||
wantErr: false,
|
||||
validate: func(t *testing.T, config string) {
|
||||
var result map[string]interface{}
|
||||
err := json.Unmarshal([]byte(config), &result)
|
||||
require.NoError(t, err)
|
||||
|
||||
subgraphs := result["subgraphs"].([]interface{})
|
||||
assert.Len(t, subgraphs, 0, "Should handle nil subgraphs as empty array")
|
||||
},
|
||||
wantErr: true,
|
||||
validate: nil,
|
||||
},
|
||||
{
|
||||
name: "complex SDL with multiple types",
|
||||
@@ -173,29 +373,58 @@ func TestGenerateCosmoRouterConfig(t *testing.T) {
|
||||
err := json.Unmarshal([]byte(config), &result)
|
||||
require.NoError(t, err)
|
||||
|
||||
subgraphs := result["subgraphs"].([]interface{})
|
||||
sg := subgraphs[0].(map[string]interface{})
|
||||
sdl := sg["sdl"].(string)
|
||||
// Check the composed graphqlSchema contains the types
|
||||
engineConfig, ok := result["engineConfig"].(map[string]interface{})
|
||||
require.True(t, ok, "Should have engineConfig")
|
||||
|
||||
assert.Contains(t, sdl, "type Query")
|
||||
assert.Contains(t, sdl, "type User")
|
||||
assert.Contains(t, sdl, "email: String!")
|
||||
graphqlSchema, ok := engineConfig["graphqlSchema"].(string)
|
||||
require.True(t, ok, "Should have graphqlSchema")
|
||||
|
||||
assert.Contains(t, graphqlSchema, "Query", "Schema should contain Query type")
|
||||
assert.Contains(t, graphqlSchema, "User", "Schema should contain User type")
|
||||
|
||||
// Check datasource has the original SDL
|
||||
dsConfigs, ok := engineConfig["datasourceConfigurations"].([]interface{})
|
||||
require.True(t, ok && len(dsConfigs) > 0, "Should have datasource configurations")
|
||||
|
||||
ds := dsConfigs[0].(map[string]interface{})
|
||||
customGraphql, ok := ds["customGraphql"].(map[string]interface{})
|
||||
require.True(t, ok, "Should have customGraphql config")
|
||||
|
||||
federation, ok := customGraphql["federation"].(map[string]interface{})
|
||||
require.True(t, ok, "Should have federation config")
|
||||
|
||||
serviceSdl, ok := federation["serviceSdl"].(string)
|
||||
require.True(t, ok, "Should have serviceSdl")
|
||||
assert.Contains(t, serviceSdl, "email: String!", "SDL should contain email field")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config, err := GenerateCosmoRouterConfig(tt.subGraphs)
|
||||
// Use mock executor for all tests
|
||||
mockExecutor := &MockCommandExecutor{}
|
||||
config, err := GenerateCosmoRouterConfigWithExecutor(tt.subGraphs, mockExecutor)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
// Verify executor was not called for error cases
|
||||
if len(tt.subGraphs) == 0 {
|
||||
assert.Equal(t, 0, mockExecutor.CallCount, "Should not call executor for empty subgraphs")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, config, "Config should not be empty")
|
||||
|
||||
// Verify executor was called correctly
|
||||
assert.Equal(t, 1, mockExecutor.CallCount, "Should call executor once")
|
||||
assert.Equal(t, "wgc", mockExecutor.LastArgs[0], "Should call wgc command")
|
||||
assert.Contains(t, mockExecutor.LastArgs, "router", "Should include 'router' arg")
|
||||
assert.Contains(t, mockExecutor.LastArgs, "compose", "Should include 'compose' arg")
|
||||
|
||||
if tt.validate != nil {
|
||||
tt.validate(t, config)
|
||||
}
|
||||
@@ -203,53 +432,31 @@ func TestGenerateCosmoRouterConfig(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertSubGraphsToCosmo(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
subGraphs []*model.SubGraph
|
||||
wantLen int
|
||||
validate func(t *testing.T, result []map[string]interface{})
|
||||
}{
|
||||
// TestGenerateCosmoRouterConfig_MockError tests error handling with mock executor
|
||||
func TestGenerateCosmoRouterConfig_MockError(t *testing.T) {
|
||||
subGraphs := []*model.SubGraph{
|
||||
{
|
||||
name: "preserves subgraph order",
|
||||
subGraphs: []*model.SubGraph{
|
||||
{Service: "alpha", URL: stringPtr("http://a"), Sdl: "a"},
|
||||
{Service: "beta", URL: stringPtr("http://b"), Sdl: "b"},
|
||||
{Service: "gamma", URL: stringPtr("http://c"), Sdl: "c"},
|
||||
},
|
||||
wantLen: 3,
|
||||
validate: func(t *testing.T, result []map[string]interface{}) {
|
||||
assert.Equal(t, "alpha", result[0]["name"])
|
||||
assert.Equal(t, "beta", result[1]["name"])
|
||||
assert.Equal(t, "gamma", result[2]["name"])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "includes SDL exactly as provided",
|
||||
subGraphs: []*model.SubGraph{
|
||||
{
|
||||
Service: "test",
|
||||
URL: stringPtr("http://test"),
|
||||
Sdl: "type Query { special: String! }",
|
||||
},
|
||||
},
|
||||
wantLen: 1,
|
||||
validate: func(t *testing.T, result []map[string]interface{}) {
|
||||
assert.Equal(t, "type Query { special: String! }", result[0]["sdl"])
|
||||
},
|
||||
Service: "test-service",
|
||||
URL: stringPtr("http://localhost:4001/query"),
|
||||
Sdl: "type Query { test: String }",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := convertSubGraphsToCosmo(tt.subGraphs)
|
||||
assert.Len(t, result, tt.wantLen)
|
||||
|
||||
if tt.validate != nil {
|
||||
tt.validate(t, result)
|
||||
}
|
||||
})
|
||||
// Create a mock executor that returns an error
|
||||
mockExecutor := &MockCommandExecutor{
|
||||
Error: fmt.Errorf("simulated wgc failure"),
|
||||
}
|
||||
|
||||
config, err := GenerateCosmoRouterConfigWithExecutor(subGraphs, mockExecutor)
|
||||
|
||||
// Verify error is propagated
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "wgc router compose failed")
|
||||
assert.Contains(t, err.Error(), "simulated wgc failure")
|
||||
assert.Empty(t, config)
|
||||
|
||||
// Verify executor was called
|
||||
assert.Equal(t, 1, mockExecutor.CallCount, "Should have attempted to call executor")
|
||||
}
|
||||
|
||||
// Helper function for tests
|
||||
|
||||
@@ -74,6 +74,7 @@ type ComplexityRoot struct {
|
||||
}
|
||||
|
||||
Query struct {
|
||||
LatestSchema func(childComplexity int, ref string) int
|
||||
Organizations func(childComplexity int) int
|
||||
Supergraph func(childComplexity int, ref string, isAfter *string) int
|
||||
}
|
||||
@@ -124,6 +125,7 @@ type MutationResolver interface {
|
||||
type QueryResolver interface {
|
||||
Organizations(ctx context.Context) ([]*model.Organization, error)
|
||||
Supergraph(ctx context.Context, ref string, isAfter *string) (model.Supergraph, error)
|
||||
LatestSchema(ctx context.Context, ref string) (*model.SchemaUpdate, error)
|
||||
}
|
||||
type SubscriptionResolver interface {
|
||||
SchemaUpdates(ctx context.Context, ref string) (<-chan *model.SchemaUpdate, error)
|
||||
@@ -250,6 +252,17 @@ func (e *executableSchema) Complexity(ctx context.Context, typeName, field strin
|
||||
|
||||
return e.complexity.Organization.Users(childComplexity), true
|
||||
|
||||
case "Query.latestSchema":
|
||||
if e.complexity.Query.LatestSchema == nil {
|
||||
break
|
||||
}
|
||||
|
||||
args, err := ec.field_Query_latestSchema_args(ctx, rawArgs)
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return e.complexity.Query.LatestSchema(childComplexity, args["ref"].(string)), true
|
||||
case "Query.organizations":
|
||||
if e.complexity.Query.Organizations == nil {
|
||||
break
|
||||
@@ -520,6 +533,7 @@ var sources = []*ast.Source{
|
||||
{Name: "../schema.graphqls", Input: `type Query {
|
||||
organizations: [Organization!]! @auth(user: true)
|
||||
supergraph(ref: String!, isAfter: String): Supergraph! @auth(organization: true)
|
||||
latestSchema(ref: String!): SchemaUpdate! @auth(organization: true)
|
||||
}
|
||||
|
||||
type Mutation {
|
||||
@@ -671,6 +685,17 @@ func (ec *executionContext) field_Query___type_args(ctx context.Context, rawArgs
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func (ec *executionContext) field_Query_latestSchema_args(ctx context.Context, rawArgs map[string]any) (map[string]any, error) {
|
||||
var err error
|
||||
args := map[string]any{}
|
||||
arg0, err := graphql.ProcessArgField(ctx, rawArgs, "ref", ec.unmarshalNString2string)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
args["ref"] = arg0
|
||||
return args, nil
|
||||
}
|
||||
|
||||
func (ec *executionContext) field_Query_supergraph_args(ctx context.Context, rawArgs map[string]any) (map[string]any, error) {
|
||||
var err error
|
||||
args := map[string]any{}
|
||||
@@ -1434,6 +1459,75 @@ func (ec *executionContext) fieldContext_Query_supergraph(ctx context.Context, f
|
||||
return fc, nil
|
||||
}
|
||||
|
||||
func (ec *executionContext) _Query_latestSchema(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) {
|
||||
return graphql.ResolveField(
|
||||
ctx,
|
||||
ec.OperationContext,
|
||||
field,
|
||||
ec.fieldContext_Query_latestSchema,
|
||||
func(ctx context.Context) (any, error) {
|
||||
fc := graphql.GetFieldContext(ctx)
|
||||
return ec.resolvers.Query().LatestSchema(ctx, fc.Args["ref"].(string))
|
||||
},
|
||||
func(ctx context.Context, next graphql.Resolver) graphql.Resolver {
|
||||
directive0 := next
|
||||
|
||||
directive1 := func(ctx context.Context) (any, error) {
|
||||
organization, err := ec.unmarshalOBoolean2ᚖbool(ctx, true)
|
||||
if err != nil {
|
||||
var zeroVal *model.SchemaUpdate
|
||||
return zeroVal, err
|
||||
}
|
||||
if ec.directives.Auth == nil {
|
||||
var zeroVal *model.SchemaUpdate
|
||||
return zeroVal, errors.New("directive auth is not implemented")
|
||||
}
|
||||
return ec.directives.Auth(ctx, nil, directive0, nil, organization)
|
||||
}
|
||||
|
||||
next = directive1
|
||||
return next
|
||||
},
|
||||
ec.marshalNSchemaUpdate2ᚖgitlabᚗcomᚋunboundsoftwareᚋschemasᚋgraphᚋmodelᚐSchemaUpdate,
|
||||
true,
|
||||
true,
|
||||
)
|
||||
}
|
||||
|
||||
func (ec *executionContext) fieldContext_Query_latestSchema(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) {
|
||||
fc = &graphql.FieldContext{
|
||||
Object: "Query",
|
||||
Field: field,
|
||||
IsMethod: true,
|
||||
IsResolver: true,
|
||||
Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) {
|
||||
switch field.Name {
|
||||
case "ref":
|
||||
return ec.fieldContext_SchemaUpdate_ref(ctx, field)
|
||||
case "id":
|
||||
return ec.fieldContext_SchemaUpdate_id(ctx, field)
|
||||
case "subGraphs":
|
||||
return ec.fieldContext_SchemaUpdate_subGraphs(ctx, field)
|
||||
case "cosmoRouterConfig":
|
||||
return ec.fieldContext_SchemaUpdate_cosmoRouterConfig(ctx, field)
|
||||
}
|
||||
return nil, fmt.Errorf("no field named %q was found under type SchemaUpdate", field.Name)
|
||||
},
|
||||
}
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = ec.Recover(ctx, r)
|
||||
ec.Error(ctx, err)
|
||||
}
|
||||
}()
|
||||
ctx = graphql.WithFieldContext(ctx, fc)
|
||||
if fc.Args, err = ec.field_Query_latestSchema_args(ctx, field.ArgumentMap(ec.Variables)); err != nil {
|
||||
ec.Error(ctx, err)
|
||||
return fc, err
|
||||
}
|
||||
return fc, nil
|
||||
}
|
||||
|
||||
func (ec *executionContext) _Query___type(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) {
|
||||
return graphql.ResolveField(
|
||||
ctx,
|
||||
@@ -3997,6 +4091,28 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr
|
||||
func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) })
|
||||
}
|
||||
|
||||
out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return rrm(innerCtx) })
|
||||
case "latestSchema":
|
||||
field := field
|
||||
|
||||
innerFunc := func(ctx context.Context, fs *graphql.FieldSet) (res graphql.Marshaler) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
ec.Error(ctx, ec.Recover(ctx, r))
|
||||
}
|
||||
}()
|
||||
res = ec._Query_latestSchema(ctx, field)
|
||||
if res == graphql.Null {
|
||||
atomic.AddUint32(&fs.Invalids, 1)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
rrm := func(ctx context.Context) graphql.Marshaler {
|
||||
return ec.OperationContext.RootResolverMiddleware(ctx,
|
||||
func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) })
|
||||
}
|
||||
|
||||
out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return rrm(innerCtx) })
|
||||
case "__type":
|
||||
out.Values[i] = ec.OperationContext.RootResolverMiddleware(innerCtx, func(ctx context.Context) (res graphql.Marshaler) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
type Query {
|
||||
organizations: [Organization!]! @auth(user: true)
|
||||
supergraph(ref: String!, isAfter: String): Supergraph! @auth(organization: true)
|
||||
latestSchema(ref: String!): SchemaUpdate! @auth(organization: true)
|
||||
}
|
||||
|
||||
type Mutation {
|
||||
|
||||
+113
-5
@@ -123,6 +123,13 @@ func (r *mutationResolver) UpdateSubGraph(ctx context.Context, input model.Input
|
||||
// Publish schema update to subscribers
|
||||
go func() {
|
||||
services, lastUpdate := r.Cache.Services(orgId, input.Ref, "")
|
||||
r.Logger.Info("Publishing schema update after subgraph change",
|
||||
"ref", input.Ref,
|
||||
"orgId", orgId,
|
||||
"lastUpdate", lastUpdate,
|
||||
"servicesCount", len(services),
|
||||
)
|
||||
|
||||
subGraphs := make([]*model.SubGraph, len(services))
|
||||
for i, id := range services {
|
||||
sg, err := r.fetchSubGraph(context.Background(), id)
|
||||
@@ -149,12 +156,21 @@ func (r *mutationResolver) UpdateSubGraph(ctx context.Context, input model.Input
|
||||
}
|
||||
|
||||
// Publish to all subscribers of this ref
|
||||
r.PubSub.Publish(input.Ref, &model.SchemaUpdate{
|
||||
update := &model.SchemaUpdate{
|
||||
Ref: input.Ref,
|
||||
ID: lastUpdate,
|
||||
SubGraphs: subGraphs,
|
||||
CosmoRouterConfig: &cosmoConfig,
|
||||
})
|
||||
}
|
||||
|
||||
r.Logger.Info("Publishing schema update to subscribers",
|
||||
"ref", update.Ref,
|
||||
"id", update.ID,
|
||||
"subGraphsCount", len(update.SubGraphs),
|
||||
"cosmoConfigLength", len(cosmoConfig),
|
||||
)
|
||||
|
||||
r.PubSub.Publish(input.Ref, update)
|
||||
}()
|
||||
|
||||
return r.toGqlSubGraph(subGraph), nil
|
||||
@@ -222,11 +238,84 @@ func (r *queryResolver) Supergraph(ctx context.Context, ref string, isAfter *str
|
||||
}, nil
|
||||
}
|
||||
|
||||
// LatestSchema is the resolver for the latestSchema field.
|
||||
func (r *queryResolver) LatestSchema(ctx context.Context, ref string) (*model.SchemaUpdate, error) {
|
||||
orgId := middleware.OrganizationFromContext(ctx)
|
||||
|
||||
r.Logger.Info("LatestSchema query",
|
||||
"ref", ref,
|
||||
"orgId", orgId,
|
||||
)
|
||||
|
||||
_, err := r.apiKeyCanAccessRef(ctx, ref, false)
|
||||
if err != nil {
|
||||
r.Logger.Error("API key cannot access ref", "error", err, "ref", ref)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get current services and schema
|
||||
services, lastUpdate := r.Cache.Services(orgId, ref, "")
|
||||
r.Logger.Info("Fetching latest schema",
|
||||
"ref", ref,
|
||||
"orgId", orgId,
|
||||
"lastUpdate", lastUpdate,
|
||||
"servicesCount", len(services),
|
||||
)
|
||||
|
||||
subGraphs := make([]*model.SubGraph, len(services))
|
||||
for i, id := range services {
|
||||
sg, err := r.fetchSubGraph(ctx, id)
|
||||
if err != nil {
|
||||
r.Logger.Error("fetch subgraph", "error", err, "id", id)
|
||||
return nil, err
|
||||
}
|
||||
subGraphs[i] = &model.SubGraph{
|
||||
ID: sg.ID.String(),
|
||||
Service: sg.Service,
|
||||
URL: sg.Url,
|
||||
WsURL: sg.WSUrl,
|
||||
Sdl: sg.Sdl,
|
||||
ChangedBy: sg.ChangedBy,
|
||||
ChangedAt: sg.ChangedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// Generate Cosmo router config
|
||||
cosmoConfig, err := GenerateCosmoRouterConfig(subGraphs)
|
||||
if err != nil {
|
||||
r.Logger.Error("generate cosmo config", "error", err)
|
||||
cosmoConfig = "" // Return empty if generation fails
|
||||
}
|
||||
|
||||
update := &model.SchemaUpdate{
|
||||
Ref: ref,
|
||||
ID: lastUpdate,
|
||||
SubGraphs: subGraphs,
|
||||
CosmoRouterConfig: &cosmoConfig,
|
||||
}
|
||||
|
||||
r.Logger.Info("Latest schema fetched",
|
||||
"ref", update.Ref,
|
||||
"id", update.ID,
|
||||
"subGraphsCount", len(update.SubGraphs),
|
||||
"cosmoConfigLength", len(cosmoConfig),
|
||||
)
|
||||
|
||||
return update, nil
|
||||
}
|
||||
|
||||
// SchemaUpdates is the resolver for the schemaUpdates field.
|
||||
func (r *subscriptionResolver) SchemaUpdates(ctx context.Context, ref string) (<-chan *model.SchemaUpdate, error) {
|
||||
orgId := middleware.OrganizationFromContext(ctx)
|
||||
|
||||
r.Logger.Info("SchemaUpdates subscription started",
|
||||
"ref", ref,
|
||||
"orgId", orgId,
|
||||
)
|
||||
|
||||
_, err := r.apiKeyCanAccessRef(ctx, ref, false)
|
||||
if err != nil {
|
||||
r.Logger.Error("API key cannot access ref", "error", err, "ref", ref)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -235,12 +324,22 @@ func (r *subscriptionResolver) SchemaUpdates(ctx context.Context, ref string) (<
|
||||
|
||||
// Send initial state immediately
|
||||
go func() {
|
||||
// Use background context for async operation
|
||||
bgCtx := context.Background()
|
||||
|
||||
services, lastUpdate := r.Cache.Services(orgId, ref, "")
|
||||
r.Logger.Info("Preparing initial schema update",
|
||||
"ref", ref,
|
||||
"orgId", orgId,
|
||||
"lastUpdate", lastUpdate,
|
||||
"servicesCount", len(services),
|
||||
)
|
||||
|
||||
subGraphs := make([]*model.SubGraph, len(services))
|
||||
for i, id := range services {
|
||||
sg, err := r.fetchSubGraph(ctx, id)
|
||||
sg, err := r.fetchSubGraph(bgCtx, id)
|
||||
if err != nil {
|
||||
r.Logger.Error("fetch subgraph for initial update", "error", err)
|
||||
r.Logger.Error("fetch subgraph for initial update", "error", err, "id", id)
|
||||
continue
|
||||
}
|
||||
subGraphs[i] = &model.SubGraph{
|
||||
@@ -262,12 +361,21 @@ func (r *subscriptionResolver) SchemaUpdates(ctx context.Context, ref string) (<
|
||||
}
|
||||
|
||||
// Send initial update
|
||||
ch <- &model.SchemaUpdate{
|
||||
update := &model.SchemaUpdate{
|
||||
Ref: ref,
|
||||
ID: lastUpdate,
|
||||
SubGraphs: subGraphs,
|
||||
CosmoRouterConfig: &cosmoConfig,
|
||||
}
|
||||
|
||||
r.Logger.Info("Sending initial schema update",
|
||||
"ref", update.Ref,
|
||||
"id", update.ID,
|
||||
"subGraphsCount", len(update.SubGraphs),
|
||||
"cosmoConfigLength", len(cosmoConfig),
|
||||
)
|
||||
|
||||
ch <- update
|
||||
}()
|
||||
|
||||
// Clean up subscription when context is done
|
||||
|
||||
+3
-1
@@ -49,7 +49,9 @@ func (m *AuthMiddleware) Handler(next http.Handler) http.Handler {
|
||||
_, _ = w.Write([]byte("Invalid API Key format"))
|
||||
return
|
||||
}
|
||||
if organization := m.cache.OrganizationByAPIKey(hash.String(apiKey)); organization != nil {
|
||||
hashedKey := hash.String(apiKey)
|
||||
organization := m.cache.OrganizationByAPIKey(hashedKey)
|
||||
if organization != nil {
|
||||
ctx = context.WithValue(ctx, OrganizationKey, *organization)
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,467 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
mw "github.com/auth0/go-jwt-middleware/v2"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gitlab.com/unboundsoftware/eventsourced/eventsourced"
|
||||
|
||||
"gitlab.com/unboundsoftware/schemas/domain"
|
||||
"gitlab.com/unboundsoftware/schemas/hash"
|
||||
)
|
||||
|
||||
// MockCache is a mock implementation of the Cache interface
|
||||
type MockCache struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockCache) OrganizationByAPIKey(apiKey string) *domain.Organization {
|
||||
args := m.Called(apiKey)
|
||||
if args.Get(0) == nil {
|
||||
return nil
|
||||
}
|
||||
return args.Get(0).(*domain.Organization)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_WithValidAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
orgID := uuid.New()
|
||||
expectedOrg := &domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(orgID.String()),
|
||||
},
|
||||
Name: "Test Organization",
|
||||
}
|
||||
|
||||
apiKey := "test-api-key-123"
|
||||
hashedKey := hash.String(apiKey)
|
||||
|
||||
mockCache.On("OrganizationByAPIKey", hashedKey).Return(expectedOrg)
|
||||
|
||||
// Create a test handler that checks the context
|
||||
var capturedOrg *domain.Organization
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if org := r.Context().Value(OrganizationKey); org != nil {
|
||||
if o, ok := org.(domain.Organization); ok {
|
||||
capturedOrg = &o
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request with API key in context
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
ctx := context.WithValue(req.Context(), ApiKey, apiKey)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
require.NotNil(t, capturedOrg)
|
||||
assert.Equal(t, expectedOrg.Name, capturedOrg.Name)
|
||||
assert.Equal(t, expectedOrg.ID.String(), capturedOrg.ID.String())
|
||||
mockCache.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_WithInvalidAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
apiKey := "invalid-api-key"
|
||||
hashedKey := hash.String(apiKey)
|
||||
|
||||
mockCache.On("OrganizationByAPIKey", hashedKey).Return(nil)
|
||||
|
||||
// Create a test handler that checks the context
|
||||
var capturedOrg *domain.Organization
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if org := r.Context().Value(OrganizationKey); org != nil {
|
||||
if o, ok := org.(domain.Organization); ok {
|
||||
capturedOrg = &o
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request with API key in context
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
ctx := context.WithValue(req.Context(), ApiKey, apiKey)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Nil(t, capturedOrg, "Organization should not be set for invalid API key")
|
||||
mockCache.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_WithoutAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
// The middleware always hashes the API key (even if empty) and calls the cache
|
||||
emptyKeyHash := hash.String("")
|
||||
mockCache.On("OrganizationByAPIKey", emptyKeyHash).Return(nil)
|
||||
|
||||
// Create a test handler that checks the context
|
||||
var capturedOrg *domain.Organization
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if org := r.Context().Value(OrganizationKey); org != nil {
|
||||
if o, ok := org.(domain.Organization); ok {
|
||||
capturedOrg = &o
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request without API key
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Nil(t, capturedOrg, "Organization should not be set without API key")
|
||||
mockCache.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_WithValidJWT(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
// The middleware always hashes the API key (even if empty) and calls the cache
|
||||
emptyKeyHash := hash.String("")
|
||||
mockCache.On("OrganizationByAPIKey", emptyKeyHash).Return(nil)
|
||||
|
||||
userID := "user-123"
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"sub": userID,
|
||||
})
|
||||
|
||||
// Create a test handler that checks the context
|
||||
var capturedUser string
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if user := r.Context().Value(UserKey); user != nil {
|
||||
if u, ok := user.(string); ok {
|
||||
capturedUser = u
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request with JWT token in context
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
ctx := context.WithValue(req.Context(), mw.ContextKey{}, token)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, userID, capturedUser)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_APIKeyErrorHandling(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request with invalid API key type in context
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
ctx := context.WithValue(req.Context(), ApiKey, 12345) // Invalid type
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Invalid API Key format")
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_JWTErrorHandling(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request with invalid JWT token type in context
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
ctx := context.WithValue(req.Context(), mw.ContextKey{}, "not-a-token") // Invalid type
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Invalid JWT token format")
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_BothJWTAndAPIKey(t *testing.T) {
|
||||
// Setup
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
orgID := uuid.New()
|
||||
expectedOrg := &domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(orgID.String()),
|
||||
},
|
||||
Name: "Test Organization",
|
||||
}
|
||||
|
||||
userID := "user-123"
|
||||
apiKey := "test-api-key-123"
|
||||
hashedKey := hash.String(apiKey)
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"sub": userID,
|
||||
})
|
||||
|
||||
mockCache.On("OrganizationByAPIKey", hashedKey).Return(expectedOrg)
|
||||
|
||||
// Create a test handler that checks both user and organization in context
|
||||
var capturedUser string
|
||||
var capturedOrg *domain.Organization
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if user := r.Context().Value(UserKey); user != nil {
|
||||
if u, ok := user.(string); ok {
|
||||
capturedUser = u
|
||||
}
|
||||
}
|
||||
if org := r.Context().Value(OrganizationKey); org != nil {
|
||||
if o, ok := org.(domain.Organization); ok {
|
||||
capturedOrg = &o
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create request with both JWT and API key in context
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
ctx := context.WithValue(req.Context(), mw.ContextKey{}, token)
|
||||
ctx = context.WithValue(ctx, ApiKey, apiKey)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
authMiddleware.Handler(testHandler).ServeHTTP(rec, req)
|
||||
|
||||
// Assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, userID, capturedUser)
|
||||
require.NotNil(t, capturedOrg)
|
||||
assert.Equal(t, expectedOrg.Name, capturedOrg.Name)
|
||||
mockCache.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestUserFromContext(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "with valid user",
|
||||
ctx: context.WithValue(context.Background(), UserKey, "user-123"),
|
||||
expected: "user-123",
|
||||
},
|
||||
{
|
||||
name: "without user",
|
||||
ctx: context.Background(),
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "with invalid type",
|
||||
ctx: context.WithValue(context.Background(), UserKey, 123),
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := UserFromContext(tt.ctx)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrganizationFromContext(t *testing.T) {
|
||||
orgID := uuid.New()
|
||||
org := domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(orgID.String()),
|
||||
},
|
||||
Name: "Test Org",
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "with valid organization",
|
||||
ctx: context.WithValue(context.Background(), OrganizationKey, org),
|
||||
expected: orgID.String(),
|
||||
},
|
||||
{
|
||||
name: "without organization",
|
||||
ctx: context.Background(),
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "with invalid type",
|
||||
ctx: context.WithValue(context.Background(), OrganizationKey, "not-an-org"),
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := OrganizationFromContext(tt.ctx)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Directive_RequiresUser(t *testing.T) {
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
requireUser := true
|
||||
|
||||
// Test with user present
|
||||
ctx := context.WithValue(context.Background(), UserKey, "user-123")
|
||||
_, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, &requireUser, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test without user
|
||||
ctx = context.Background()
|
||||
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, &requireUser, nil)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no user available in request")
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Directive_RequiresOrganization(t *testing.T) {
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
requireOrg := true
|
||||
orgID := uuid.New()
|
||||
org := domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(orgID.String()),
|
||||
},
|
||||
Name: "Test Org",
|
||||
}
|
||||
|
||||
// Test with organization present
|
||||
ctx := context.WithValue(context.Background(), OrganizationKey, org)
|
||||
_, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, nil, &requireOrg)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test without organization
|
||||
ctx = context.Background()
|
||||
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, nil, &requireOrg)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no organization available in request")
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Directive_RequiresBoth(t *testing.T) {
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
requireUser := true
|
||||
requireOrg := true
|
||||
orgID := uuid.New()
|
||||
org := domain.Organization{
|
||||
BaseAggregate: eventsourced.BaseAggregate{
|
||||
ID: eventsourced.IdFromString(orgID.String()),
|
||||
},
|
||||
Name: "Test Org",
|
||||
}
|
||||
|
||||
// Test with both present
|
||||
ctx := context.WithValue(context.Background(), UserKey, "user-123")
|
||||
ctx = context.WithValue(ctx, OrganizationKey, org)
|
||||
_, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, &requireUser, &requireOrg)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test with only user
|
||||
ctx = context.WithValue(context.Background(), UserKey, "user-123")
|
||||
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, &requireUser, &requireOrg)
|
||||
assert.Error(t, err)
|
||||
|
||||
// Test with only organization
|
||||
ctx = context.WithValue(context.Background(), OrganizationKey, org)
|
||||
_, err = authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, &requireUser, &requireOrg)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Directive_NoRequirements(t *testing.T) {
|
||||
mockCache := new(MockCache)
|
||||
authMiddleware := NewAuth(mockCache)
|
||||
|
||||
// Test with no requirements
|
||||
ctx := context.Background()
|
||||
result, err := authMiddleware.Directive(ctx, nil, func(ctx context.Context) (interface{}, error) {
|
||||
return "success", nil
|
||||
}, nil, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "success", result)
|
||||
}
|
||||
Reference in New Issue
Block a user